# Source code for pymablock.algorithm_parsing

"""Tools for compiling optimized series computations."""

from __future__ import annotations

import ast
import dataclasses
import inspect
import textwrap
from collections import Counter, defaultdict
from collections.abc import Callable  # noqa: TC003 (sphinx needs unconditional import)
from enum import Enum
from functools import cache
from itertools import chain
from operator import matmul
from typing import Any

import numpy as np
from sympy.physics.quantum import Dagger

from pymablock.linalg import aslinearoperator
from pymablock.series import BlockSeries, cauchy_dot_product, one, zero

__all__ = ["series_computation"]

# AST node that loads the accumulator variable `result` inside generated `series_eval`
# functions; it is reused wherever generated code needs to read `result`.
result = ast.Name("result", ast.Load())


@dataclasses.dataclass
class _Series:
    """Series properties."""

    name: str = None
    start: str = None
    uses: list = dataclasses.field(default_factory=list)
    definition: list[ast.stmt] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class _Product:
    """Product properties."""

    terms: list[str] = dataclasses.field(default_factory=list)
    hermitian: list = False

    @property
    def name(self) -> str:
        """Name that represents the product."""
        return " @ ".join(self.terms)


class _EvalTransformer(ast.NodeTransformer):
    """Transform a series `with` statement into a callable eval understood by `BlockSeries`.

    Visiting the `with` statement of a series yields a module that defines
    ``series_eval(*index)``. The generated function accumulates every applicable
    expression of the series definition into ``result`` and returns it.
    """

    def __init__(self, to_delete):
        # Iterable of (term, adjoint, eval_type) triples, as produced per series by
        # `_find_delete_candidates`. After an expression of a matching eval type is
        # evaluated, a `del_` call is emitted for each of these terms.
        self.to_delete = to_delete

    def visit_With(self, node: ast.With) -> ast.Module:
        """Build a function definition from a `with` statement.

        The generated ``series_eval(*index)`` selects ``which`` (the linear-operator
        series if ``use_linear_operator[index[:2]]`` is set, the plain series
        otherwise), initializes ``result = zero``, runs the transformed body
        statements, and returns ``result``.
        """
        linear_operator_select = ast.parse(
            "which = linear_operator_series if use_linear_operator[index[:2]] else series"
        ).body
        result_is_zero = ast.parse("result = zero").body

        module = ast.Module(
            body=[
                ast.FunctionDef(
                    name="series_eval",
                    args=ast.arguments(
                        posonlyargs=[],
                        args=[],
                        vararg=ast.arg(arg="index"),
                        kwonlyargs=[],
                        kw_defaults=[],
                        defaults=[],
                    ),
                    body=[
                        *linear_operator_select,
                        *result_is_zero,
                        *(
                            line
                            for expr in node.body
                            for line in self._visit_Line(expr)
                            if line is not None
                        ),
                        ast.Return(value=result),
                    ],
                    decorator_list=[],
                )
            ],
            type_ignores=[],
        )
        ast.fix_missing_locations(module)
        return module

    def _visit_Line(self, node: ast.AST) -> list[ast.AST | None]:
        """Transform each line of the function body.

        Assign statements are removed (``[None]`` is returned and filtered out by the
        caller). If statements whose condition is a recognized name (``diagonal``,
        ``offdiagonal`` or ``lower``) are transformed to a valid index test; only the
        first statement of their body is used. Other if statements are kept
        unchanged. Expressions are transformed using `_visit_Eval`.
        """
        if isinstance(node, ast.Assign):
            # Delete start = ... statements
            return [None]
        if isinstance(node, ast.If):
            # Only a bare name can be a recognized condition.
            eval_type = (
                _EvalType.from_condition(node.test.id)
                if isinstance(node.test, ast.Name)
                else None
            )
            if eval_type is None:
                # Keep unrecognized if statements as they are. The caller iterates
                # over the returned statements, so this must be a list.
                return [node]
            node.test = eval_type.test
            # The diagonal blocks are wrapped inside `diag`
            if eval_type == _EvalType.diagonal:
                node.body[0] = ast.Expr(
                    ast.Call(
                        ast.Name(id="diag", ctx=ast.Load()), [node.body[0].value], []
                    )
                )
            branches = [node]
            # If an offdiagonal eval is present, we need to evaluate
            # this wrapped with `offdiag` for diagonal blocks.
            if eval_type == _EvalType.offdiagonal:
                branches.append(
                    ast.If(
                        test=ast.BoolOp(
                            op=ast.And(),
                            values=[
                                ast.parse("offdiag is not None").body[0].value,
                                _EvalType.diagonal.test,
                            ],
                        ),
                        body=[
                            ast.Expr(
                                ast.Call(
                                    ast.Name(id="offdiag", ctx=ast.Load()),
                                    [node.body[0].value],
                                    [],
                                )
                            )
                        ],
                        orelse=[],
                    )
                )
            # NOTE: both branches share the original expression node, and the
            # transformers edit nodes in place. The second pass relies on already
            # transformed parts being left alone (see `_LiteralTransformer.visit_Subscript`).
            for branch in branches:
                branch.body = self._visit_Eval(branch.body[0], eval_type)
            if eval_type == _EvalType.lower:
                # Evaluation of lower-triangle blocks stops here; later statements
                # of the series definition are skipped.
                branches[0].body.append(ast.Return(value=result))

            return branches

        return self._visit_Eval(node, _EvalType.default)

    def _visit_Eval(self, node: ast.Expr, eval_type: _EvalType) -> list[ast.AST]:
        """Transform an evaluation expression to executable AST.

        The expression is added to ``result`` and transformed with
        `_SumTransformer`, `_DivideTransformer`, `_FunctionTransformer` and
        `_LiteralTransformer` (in this order). The outcome is assigned back to
        ``result``. Then a ``del_`` call is emitted for every term scheduled for
        deletion after evaluations of this eval type.
        """
        diagonal = eval_type == _EvalType.diagonal
        eval_transformers = [
            _SumTransformer(),
            _DivideTransformer(),
            _FunctionTransformer(),
            _LiteralTransformer(diagonal=diagonal),
        ]
        node = node.value  # Get the expression from the Expr node.
        node = ast.BinOp(left=result, op=ast.Add(), right=node)
        for transformer in eval_transformers:
            node = transformer.visit(node)
        return [
            ast.Assign(
                targets=[ast.Name(id="result", ctx=ast.Store())],
                value=node,
            ),
            # Delete the accessed block: the transposed one for adjoint accesses.
            *(
                ast.Expr(
                    value=ast.Call(
                        func=ast.Name(id="del_", ctx=ast.Load()),
                        args=[
                            ast.Constant(value=term),
                            _LiteralTransformer._index(adjoint),
                        ],
                        keywords=[],
                    )
                )
                for term, adjoint, _eval_type in self.to_delete
                if _eval_type == eval_type
            ),
        ]


class _HermitianTransformer(ast.NodeTransformer):
    """Transform hermitian attributes into if statements."""

    def __init__(self, term):
        self.term = term

    def visit_Expr(self, node: ast.Expr) -> ast.AST:
        """Insert a conditional evaluation for hermitian and antihermitian attributes.

        It adds an if statement with `lower` that matches indices in the lower triangle.
        The evaluation result is either the conjugate adjoint of itself in case of `hermitian`
        or the negation of that in case of `antihermitian`.
        """
        if not isinstance(node.value, ast.Name):
            return self.generic_visit(node)

        term = ast.Attribute(
            value=ast.Constant(value=self.term), attr="adj", ctx=ast.Load()
        )

        match node.value.id:
            case "hermitian":
                pass
            case "antihermitian":
                term = ast.UnaryOp(op=ast.USub(), operand=term)
            case _:
                return self.generic_visit(node)

        return ast.If(
            test=ast.Name(id="lower", ctx=ast.Load()),
            body=[ast.Expr(value=term)],
            orelse=[],
        )


class _UseCounter(ast.NodeVisitor):
    """Count uses of terms in an expression.

    The result is later used to determine which terms are accessed exactly once,
    which can be deleted from the series after accessing them.
    """

    def __init__(self):
        self.uses = []  # List of (term, adjoint)

    def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
        """Count an adjoint access."""
        # We assume the attribute is `.adj`.
        self.uses.append((node.value.value, True))

    def visit_Constant(self, node: ast.Constant) -> ast.AST:
        """Count a regular access."""
        if not isinstance(node.value, str):
            return
        self.uses.append((node.value, False))


class _LiteralTransformer(ast.NodeTransformer):
    """Transform string literals to `series[term][index]`."""

    def __init__(self, diagonal: bool):
        self.diagonal = diagonal

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        # Do not visit subscripts as these are already transformed.
        return node

    def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
        """Transform adjoint terms."""
        # We assume the attribute is `.adj`.
        return self._to_series_index(node.value, adjoint=True)

    def visit_Constant(self, node: ast.Constant) -> ast.AST:
        """Transform regular terms."""
        if not isinstance(node.value, str):
            return node
        return self._to_series_index(node, adjoint=False)

    @staticmethod
    def _to_series(node: ast.Constant) -> ast.AST:
        """Build series[term] as AST."""
        return ast.Subscript(
            value=ast.Name(id="which", ctx=ast.Load()),
            slice=ast.Constant(value=node.value),
            ctx=ast.Load(),
        )

    def _to_series_index(self, node: ast.Constant, adjoint: bool) -> ast.AST:
        """Build series[term][index] as AST."""
        result = ast.Subscript(
            value=self._to_series(node),
            slice=ast.Index(value=self._index(adjoint and (not self.diagonal))),
            ctx=ast.Load(),
        )
        if adjoint:
            result = ast.Call(
                func=ast.Name(id="Dagger", ctx=ast.Load()),
                args=[result],
                keywords=[],
            )
        return result

    @staticmethod
    def _index(adjoint) -> ast.AST:
        """Build the (adjoint) index as AST."""
        if not adjoint:
            return ast.Name(id="index", ctx=ast.Load())
        return ast.parse("(index[1], index[0], *index[2:])").body[0].value


class _SumTransformer(ast.NodeTransformer):
    """Transform additive operations to `_zero_sum`."""

    @staticmethod
    def _is_zero_sum(node: ast.AST) -> bool:
        """Whether a node is a call to `_zero_sum`."""
        return isinstance(node, ast.Call) and node.func.id == "_zero_sum"

    @staticmethod
    def _zero_sum(args: list[ast.AST]) -> ast.Call:
        """Build AST representation of `_zero_sum` of args."""
        return ast.Call(
            func=ast.Name(id="_zero_sum", ctx=ast.Load()),
            args=args,
            keywords=[],
        )

    @staticmethod
    def _negate(node: ast.AST) -> ast.AST:
        """Negate a node. Return the original node if already negated."""
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return node.operand
        return ast.UnaryOp(
            op=ast.USub(),
            operand=node,
        )

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        """Recursively transform subtraction and addition to a `_zero_sum` call."""
        if not (isinstance(node.op, ast.Add) or isinstance(node.op, ast.Sub)):
            return self.generic_visit(node)

        left = self.visit(node.left)
        right = self.visit(node.right)

        left_args = left.args if self._is_zero_sum(left) else [left]
        right_args = right.args if self._is_zero_sum(right) else [right]

        if isinstance(node.op, ast.Sub):
            right_args = [self._negate(arg) for arg in right_args]

        return self._zero_sum(left_args + right_args)


class _DivideTransformer(ast.NodeTransformer):
    """Replace division with `_safe_divide`."""

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        """Transform division to a `_safe_divide` call."""
        if not isinstance(node.op, ast.Div):
            return self.generic_visit(node)

        return ast.Call(
            func=ast.Name(id="_safe_divide", ctx=ast.Load()),
            args=[node.left, node.right],
            keywords=[],
        )


class _FunctionTransformer(ast.NodeTransformer):
    """Transforms function calls.

    The internal functions `_safe_divide` and `_zero_sum` are not modified.
    Other functions are changed as follows:
    - If an argument to the function is a series, it is transformed into `series["arg"]`.
    - All other arguments are left unchanged.
    - The index is added as the last argument.
    """

    def visit_Call(self, node: ast.Call) -> ast.AST:
        if not isinstance(node.func, ast.Name):
            return self.generic_visit(node)

        # Functions introduced internally, should not be modified.
        if node.func.id in ["_safe_divide", "_zero_sum"]:
            return self.generic_visit(node)

        return self._visit_series_argument(node)

    def _visit_series_argument(self, node: ast.Call) -> ast.AST:
        """Transform functions that have series as arguments.

        Inserts the index as the last argument, preceded by all series passed as arguments.
        The series arguments are string literals, which are transformed to `series["arg"]`.
        """
        node.args = [
            *(
                _LiteralTransformer._to_series(arg)
                if (isinstance(arg, ast.Constant) and isinstance(arg.value, str))
                else arg
                for arg in node.args
            ),
            ast.Name(id="index", ctx=ast.Load()),
        ]
        return node


def _parse_return(node: ast.Return) -> list[str]:
    """Parse return statement to list of series names."""
    if isinstance(node, ast.Return):
        if isinstance(node.value, ast.Constant):
            return [node.value.value]
        if isinstance(node.value, ast.Tuple):
            return [element.value for element in node.value.elts]
    return []


def _preprocess_algorithm(
    definition: ast.FunctionDef,
) -> tuple[list[_Series], list[_Product], list[str]]:
    """Read and preprocess series, products and outputs definition."""
    series = []
    products = []
    outputs = []
    for node in definition.body:
        if isinstance(node, ast.With):
            if "@" in node.items[0].context_expr.value:
                products.append(_read_product(node))
            else:
                series.append(_preprocess_series(node))
        if isinstance(node, ast.Return):
            outputs = _parse_return(node)

    return series, products, outputs


def _find_delete_candidates(
    series: list[_Series], products: list[_Product], outputs: list[str]
) -> dict[str, set[tuple[str, bool, _EvalType]]]:
    """Determine the intermediate terms to delete.

    All terms that are accessed exactly once are detected. Terms that are factors of
    products, inputs (series that are accessed but not defined by the algorithm),
    and outputs are never deleted.

    Parameters
    ----------
    series :
        Preprocessed series. Each ``uses`` entry is a ``(term, adjoint, eval_type)``
        triple, assumed to be in order of appearance in the series definition.
    products :
        Preprocessed products.
    outputs :
        Names of the output series.

    Returns
    -------
    A ``defaultdict(set)`` that maps the name of the series whose evaluation
    accesses a term to the set of ``(term, adjoint, eval_type)`` triples. Each
    triple can be deleted once that access has happened.

    """
    # We should never delete terms that appear in products, are part of the input, or
    # are part of the output.
    terms_in_products = set(chain.from_iterable(product.terms for product in products))
    computed = set(term.name for term in series)
    inputs = (
        set(
            needed_term
            for term in series
            for needed_term, _, _ in term.uses
            if "@" not in needed_term  # Products are defined in a different way.
        )
        - computed
    )
    delete_blacklist = terms_in_products | inputs | set(outputs)

    uses = []  # List of (term, index)
    source_map = {}  # Map from (term, index) to (origin, adjoint, eval_type)

    # Collect all accessed terms and indices of the entire algorithm.
    for origin in series:
        # These indices are valid for 2x2 matrices, but with larger sizes they
        # keep track of the diagonal/offdiagonal structure.
        remaining_indices = {(0, 0), (0, 1), (1, 0), (1, 1)}
        last_eval_type = None

        for term, adjoint, eval_type in origin.uses:
            if eval_type != last_eval_type:
                # Update indices used for this eval_type.
                # We assume the uses are ordered by their appearance in the series definition.
                # NOTE(review): block indices claimed by an earlier eval type are not
                # counted for later ones, while the generated `series_eval` only
                # returns early after `lower`. Confirm this is intended when, e.g., a
                # default expression follows an `if diagonal:` block.
                indices = set(
                    index for index in remaining_indices if eval_type.matches(index)
                )
                remaining_indices -= indices
                last_eval_type = eval_type

            if term in delete_blacklist:
                continue

            for index in indices:
                # An adjoint access reads the transposed block.
                if adjoint:
                    index = (index[1], index[0])
                uses.append((term, index))
                # This gets overwritten if the term is used multiple times.
                # This is fine since we only care about terms that are used once.
                source_map[(term, index)] = (origin.name, adjoint, eval_type)

    # Find terms that are used exactly once.
    delete_items = [item for item, count in Counter(uses).items() if count == 1]

    # Group terms by their origin and collect the adjoint and eval_type.
    result = defaultdict(set)
    for term, index in delete_items:
        origin, adjoint, eval_type = source_map[(term, index)]
        result[origin].add((term, adjoint, eval_type))

    return result


def _read_product(definition: ast.With) -> _Product:
    """Build a `_Product` from a product `with` statement.

    The terms are the ``" @ "``-separated names in the `with` label. The product is
    hermitian if its body contains a bare ``hermitian`` statement.
    """
    label = definition.items[0].context_expr.value
    is_hermitian = any(
        isinstance(statement, ast.Expr)
        and isinstance(statement.value, ast.Name)
        and statement.value.id == "hermitian"
        for statement in definition.body
    )
    return _Product(terms=label.split(" @ "), hermitian=is_hermitian)


class _EvalType(Enum):
    """Represents the different types of evaluations."""

    def __init__(self, test: str):
        self.test = ast.parse(test).body[0].value if test else None

    @staticmethod
    def from_condition(value: str) -> _EvalType | None:
        """Get the eval type from an if statement test."""
        try:
            return _EvalType[value]
        except KeyError:
            return None

    def matches(self, index: tuple[int, int]) -> bool:
        """Whether the index matches the condition of this eval type."""
        if self.test is None:
            return True
        return eval(
            compile(ast.Expression(body=self.test), "<string>", mode="eval"),
            {},
            {"index": index},
        )

    default = (None,)
    diagonal = ("index[0] == index[1]",)
    offdiagonal = ("index[0] != index[1]",)
    lower = ("index[0] > index[1]",)


def _preprocess_series(definition: ast.With) -> _Series:
    """Determine the properties of a series from its `with` statement.

    This function:
    - reads the name and the optional ``start = ...`` value;
    - replaces ``hermitian``/``antihermitian`` markers by ``if lower:`` statements
      (via `_HermitianTransformer`);
    - records every term access, together with its evaluation type, in ``uses``.

    If statements whose condition is not a recognized bare name are skipped.
    """
    series = _Series()
    series.name = definition.items[0].context_expr.value
    # NodeTransformer updates `definition` in place, so the loop below already sees
    # the `if lower:` statements that replace `hermitian`/`antihermitian`.
    series.definition = _HermitianTransformer(series.name).visit(definition)

    for node in definition.body:
        # Read and remove start = ... statements.
        if isinstance(node, ast.Assign):
            if node.targets[0].id == "start":
                series.start = _parse_start(node.value.value)
            continue

        # Extract the expression and eval type.
        if isinstance(node, ast.Expr):
            if isinstance(node.value, ast.Name):
                continue
            eval_type = _EvalType.default
            expression = node.value
        elif isinstance(node, ast.If):
            # Only a bare name can be a recognized condition; skip anything else
            # instead of failing on a missing `id`.
            eval_type = (
                _EvalType.from_condition(node.test.id)
                if isinstance(node.test, ast.Name)
                else None
            )
            if eval_type is None:
                continue
            expression = node.body[0].value
        else:
            continue

        # Count uses of terms in the expression to later determine candidates for deletion.
        (counter := _UseCounter()).visit(expression)
        series.uses += [(*use, eval_type) for use in counter.uses]

    return series


def _parse_start(value: str | int) -> str:
    """Parse start value."""
    match value:
        case str(name):
            return name + "_data"
        case 0:
            return "zero_data"
        case 1:
            return "identity_data"


@cache
def _parse_algorithm(func: Callable) -> tuple[list[_Series], list[_Product], list[str]]:
    """Turn a function into an algorithm.

    The algorithm is represented by the body of a function definition. See the
    `series_computation` function for a more complete format description. Results
    are cached per function, so its source is parsed only once.

    Parameters
    ----------
    func :
        The function whose body defines the algorithm. Its source must be available
        to `inspect.getsource`. It may be defined inside another function or class.

    Returns
    -------
    algorithm :
        A tuple containing the series, products, and outputs of the algorithm. The
        ``definition`` of each series is the module AST defining its
        ``series_eval`` function.

    """
    # Dedent so that algorithms defined in a nested scope (whose source is indented)
    # can be parsed as well; this is a no-op for module-level functions.
    source = ast.parse(textwrap.dedent(inspect.getsource(func)))
    series, products, outputs = _preprocess_algorithm(source.body[0])
    to_delete = _find_delete_candidates(series, products, outputs)

    for term in series:
        term.definition = _EvalTransformer(to_delete[term.name]).visit(term.definition)

    return series, products, outputs


def _zero_sum(*terms: Any) -> Any:
    """Add up the terms, skipping every occurrence of the singleton ``zero``.

    Parameters
    ----------
    terms :
        Terms to add. Terms that are ``zero`` are left out.

    Returns
    -------
    ``zero + term_1 + term_2 + ...`` over the nonzero terms, which is ``zero``
    itself when there are none.

    """
    total = zero
    for term in terms:
        if term is zero:
            continue
        total = total + term
    return total


def _safe_divide(numerator, denominator):
    """Divide unless it's impossible, then multiply by inverse."""
    try:
        return numerator / denominator
    except TypeError:
        return numerator * (1 / denominator)


def series_computation(
    series: dict[str, BlockSeries],
    algorithm: Callable,
    scope: dict | None = None,
    *,
    operator: Callable | None = None,
) -> tuple[dict[str, BlockSeries], dict[str, BlockSeries]]:
    """Compile a `~pymablock.series.BlockSeries` computation.

    Given several series, functions to apply to their elements, and an algorithm,
    return the output series defined by the algorithm.

    The algorithm parsing used by this function applies multiple optimizations used
    in the Pymablock main algorithm. While these could be generated by hand, the
    resulting code is complex and repetitive. The mini-language used to specify the
    algorithm avoids this complexity, enabled multiple improvements of the current
    algorithm, and simplifies the development of new algorithms.
    Specifically, code generation from the mini-language description of an algorithm:

    - Handles initialization of series and the definition of their evaluation
      functions.
    - Utilizes hermiticity and antihermiticity to reduce the number of evaluations.
    - Automatically handles the implicit mode when parts of the series are provided
      as linear operators.
    - Handles deletion of intermediate series terms that are only used once to
      reduce the memory usage.

    Implementing a new algorithm is advanced usage, and familiarity with the codebase
    is highly recommended.

    Parameters
    ----------
    series :
        Dictionary with all input series, where the keys are the names of the
        series. It must contain at least one series. All series must have the same
        dimension names and number of infinite indices. The dictionary is updated
        in place with the series and products defined by the algorithm.
    algorithm :
        Algorithm to use for the block diagonalization. Should be passed as a
        callable whose contents follow the algorithm mini-language, see notes below.
    scope :
        Extra variables to pass to the algorithm. It is particularly relevant for
        passing custom functions or data.
    operator :
        (optional) function to use for matrix multiplication. Defaults to matmul.

    Returns
    -------
    series : dict[str, BlockSeries]
        A dictionary with all the series used in the computation. The keys are the
        names of the series and the values are the corresponding
        `~pymablock.series.BlockSeries`.
    linear_operator_series : dict[str, BlockSeries]
        The same series as above, but wrapped into linear operators. Only used in
        the implicit mode.

    Raises
    ------
    ValueError
        If ``series`` is empty, or if the input series differ in their dimension
        names or number of infinite indices.

    Notes
    -----
    The ``algorithm`` callable is not evaluated directly, but rather parsed to
    extract the computation that needs to be performed. It needs to follow the
    specification below.

    .. warning::

        This domain-specific language is experimental and may change in the future.

    The function body contains multiple `with` statements that define the series
    and products of that algorithm. Throughout the definition the series and
    products are represented by their name using string literals.

    A series definition allows the following statements:

    - ``start = ...`` to define the zeroth order of the series. Allowed values are
      ``"<name>_0"`` (the zeroth order of the input series ``"<name>"``), ``0``, and
      ``1``.
    - ``hermitian`` or ``antihermitian`` to optionally mark the lower triangular
      blocks of a series to be evaluated using a conjugate transpose of the upper
      triangular blocks.
    - One or more expressions that define how to evaluate the series. If there are
      multiple expressions, they are summed together. The expression can contain
      the following:

      - String literals to represent series.
      - Attribute ``.adj`` access to represent the conjugate adjoint of a series.
      - Integer literals.
      - Unary and binary operations.
      - Function calls. Using ``f("series")`` will call the function ``f`` with the
        series and the block index as arguments. Using ``f(expression)`` will call
        the function with the evaluated expression and block index as arguments.

    - ``if <condition>:`` differentiates evaluation based on the requested index.
      Allowed conditions are:

      - ``diagonal``: indices on the main diagonal.
      - ``offdiagonal``: indices *not* on the main diagonal.
      - ``lower``: indices in the lower triangle.

    If a name contains an "@" symbol, it defines a Cauchy product of the terms in
    it. For example ``"A @ B @ C"`` is a Cauchy product of the series ``A``, ``B``,
    and ``C``. A product definition must contain one of the two following
    statements:

    - ``hermitian`` to mark the product as hermitian.
    - ``pass`` otherwise.

    The final return statement in the function body defines a tuple of series that
    are the output of the algorithm and terms of which should not be deleted.

    Example
    -------
    The algorithm definition may look as follows (this example does not do anything
    useful):

    .. code-block:: python

        def my_algorithm():
            with "B":
                start = 0
                hermitian
                if diagonal:
                    "A" + f("B @ C")

            with "C":
                start = "A_0"
                if offdiagonal:
                    "A" + "B" / 2
                "B @ C"

            with "B @ C":
                hermitian

            return "C"

    Here ``"A"`` is an input, ``"B"`` and ``"C"`` are defined in the computation,
    and the function ``f`` must be provided using the scope.

    For an extended example, see the ``main`` function in
    ``pymablock/algorithms.py``.
    """
    if operator is None:
        operator = matmul

    # Without this check `next(iter(...))` below would leak a StopIteration.
    if not series:
        raise ValueError("At least one input series is required.")

    # For now we demand that all series are similar because outputs are like inputs.
    reference = next(iter(series.values()))
    dimension_names = reference.dimension_names
    if any(s.dimension_names != dimension_names for s in series.values()):
        raise ValueError("All series must have the same dimension names.")
    n_infinite = {s.n_infinite for s in series.values()}
    if len(n_infinite) > 1:
        raise ValueError("All series must have the same number of infinite indices.")
    n_infinite = next(iter(n_infinite))
    shape = reference.shape
    zeroth_order = (0,) * n_infinite
    all_blocks = [(i, j) for i in range(shape[0]) for j in range(shape[1])]
    diagonal = [(i, i) for i in range(shape[0])]

    # Start data, keyed by the names produced by `_parse_start`.
    zero_data = {block + zeroth_order: zero for block in all_blocks}
    identity_data = {block + zeroth_order: one for block in diagonal}
    data = {
        "zero_data": zero_data,
        "identity_data": identity_data,
        **{
            f"{name}_0_data": {
                block + zeroth_order: input_series[block + zeroth_order]
                for block in all_blocks
            }
            for name, input_series in series.items()
        },
    }

    # Common series kwargs to avoid some repetition
    series_kwargs = dict(
        shape=shape,
        n_infinite=n_infinite,
        dimension_names=dimension_names,
    )

    def linear_operator_wrapped(original: BlockSeries) -> BlockSeries:
        # Lazily wrap every block of `original` with `aslinearoperator`.
        return BlockSeries(
            eval=(lambda *index: aslinearoperator(original[index])),
            name=original.name,
            **series_kwargs,
        )

    linear_operator_series = {
        name: linear_operator_wrapped(input_series)
        for name, input_series in series.items()
    }

    def del_(series_name: str, index: tuple[int, ...]) -> None:
        # Drop a computed term from both the plain and the linear-operator series.
        series[series_name].pop(index, None)
        linear_operator_series[series_name].pop(index, None)

    eval_scope = {
        # Defined in this function
        "series": series,
        "linear_operator_series": linear_operator_series,
        "del_": del_,
        "use_linear_operator": np.zeros(shape, dtype=bool),
        "offdiag": None,
        "diag": lambda x, index: x[index] if isinstance(x, BlockSeries) else x,
        # Globals
        "zero": zero,
        "_safe_divide": _safe_divide,
        "_zero_sum": _zero_sum,
        "Dagger": Dagger,
        # User-provided, may override the above
        **(scope or {}),
    }

    terms, products, _outputs = _parse_algorithm(algorithm)
    for term in terms:
        # This defines `series_eval` as the eval function for this term.
        exec(compile(term.definition, filename="<string>", mode="exec"), eval_scope)
        series[term.name] = BlockSeries(
            eval=eval_scope["series_eval"],
            data=data.get(term.start, None),
            name=term.name,
            **series_kwargs,
        )
        linear_operator_series[term.name] = linear_operator_wrapped(series[term.name])

    for product in products:
        for which in series, linear_operator_series:
            which[product.name] = cauchy_dot_product(
                *(which[factor] for factor in product.terms),
                operator=operator,
                hermitian=product.hermitian,
            )

    return series, linear_operator_series