# Source code for pysiglib.trees

# Copyright 2026 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

"""
Decorated basis enumeration and indexing for branched signatures.

Non-planar BCK signatures are indexed by decorated rooted trees. Planar MKW
signatures are indexed by ordered forests of decorated planar rooted trees.

Trees use the native pySigLib tuple convention::

    Empty tree:   None
    Leaf:         (label,)
    Internal:     (child1, child2, ..., root_label)

where children are sorted in the non-planar case and ``label`` is an int in
``[0, dimension)``. Ordered forests are represented as ``(tree1, tree2, ...)``.
The order is the implementation order used by the native CPU and CUDA caches.
"""

from functools import cache

from .param_checks import check_type, check_non_neg


def _tree_node_count(tree):
    return 1 + sum(_tree_node_count(child) for child in tree[:-1])


def _is_planar_forest_tuple(obj):
    return isinstance(obj, tuple) and len(obj) > 0 and not isinstance(obj[-1], int)


def _as_planar_forest_tuple(tree):
    if tree is None:
        return None
    if _is_planar_forest_tuple(tree):
        return tree
    return (tree,)


def _child_sequences(target_nodes, trees, planar):
    out = []
    current = []

    def enumerate_(remaining, min_idx):
        if remaining == 0:
            out.append(tuple(current))
            return
        start = 0 if planar else min_idx
        for idx in range(start, len(trees)):
            nodes = _tree_node_count(trees[idx])
            if nodes > remaining:
                break
            current.append(idx)
            enumerate_(remaining - nodes, idx)
            current.pop()

    enumerate_(target_nodes, 0)
    return tuple(out)


@cache
def _native_tree_basis_cached(dimension, degree, planar):
    trees = []
    for order in range(1, degree + 1):
        if order == 1:
            for label in range(dimension):
                trees.append((label,))
            continue

        current_trees = tuple(trees)
        for children in _child_sequences(order - 1, current_trees, planar):
            if not children:
                continue
            for label in range(dimension):
                trees.append(tuple(current_trees[idx] for idx in children) + (label,))
    return tuple(trees)


@cache
def _native_forest_basis_cached(dimension, degree):
    """
    Build every non-empty ordered forest of planar trees with at most
    ``degree`` nodes in total.

    Forests are grouped by increasing total node count. Within one node
    count they are enumerated depth-first over the planar tree basis, so
    the first tree of the forest varies slowest.

    :param dimension: Number of labels; labels are ``0 .. dimension - 1``.
    :param degree: Maximum total number of nodes per forest.
    :return: Tuple of forests, each a tuple of trees in the native tuple
        convention.
    """
    planar_trees = _native_tree_basis_cached(dimension, degree, True)
    # Pair each tree with its node count once; the counts are invariant and
    # would otherwise be recomputed in every recursive call below.
    sized_trees = [(_tree_node_count(tree), tree) for tree in planar_trees]
    forests = []
    prefix = []

    def extend(remaining):
        if remaining == 0:
            forests.append(tuple(prefix))
            return
        for size, tree in sized_trees:
            if size > remaining:
                # The tree basis is ordered by node count, so nothing after
                # this tree can fit either.
                break
            prefix.append(tree)
            extend(remaining - size)
            prefix.pop()

    for total_nodes in range(1, degree + 1):
        extend(total_nodes)
    return tuple(forests)


@cache
def _basis_cached(dimension, degree, planar):
    """
    Return the full basis up to ``degree``, with the empty element first.

    The empty element is ``None``. It is followed by the ordered-forest basis
    when ``planar`` is true, and by the non-planar tree basis otherwise.
    """
    non_empty = (
        _native_forest_basis_cached(dimension, degree)
        if planar
        else _native_tree_basis_cached(dimension, degree, False)
    )
    return (None, *non_empty)


@cache
def _basis_index_cached(dimension, degree, planar):
    """
    Return a dict mapping each basis element to its position in
    ``_basis_cached(dimension, degree, planar)``.
    """
    basis = _basis_cached(dimension, degree, planar)
    return dict(zip(basis, range(len(basis))))


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def trees_of_order(
    dimension: int,
    order: int,
    *,
    planar: bool = False,
) -> tuple[tuple]:
    """
    Returns all basis elements with exactly ``order`` nodes.

    For ``planar=False`` these are non-planar decorated rooted trees. For
    ``planar=True`` these are ordered forests of decorated planar rooted
    trees, with ``order`` equal to the total number of nodes in the forest.
    For ``order == 0`` the result is ``(None,)``, i.e. only the empty tree.
    The order matches the native pySigLib coefficient layout.

    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param order: Exact number of nodes.
    :type order: int
    :param planar: If True, enumerate ordered forests of planar rooted trees.
    :type planar: bool
    :return: Tuple of basis elements as tuples in native pySigLib convention.
    :rtype: tuple[tuple]

    Example:
    ---------
    .. code-block:: python

        import pysiglib

        # All single-node trees over dimension 2
        t = pysiglib.trees_of_order(2, 1)
        print(t)  # ((0,), (1,))
    """
    # Type checks run before the non-negativity checks, in argument order.
    for value, name, expected in (
        (dimension, "dimension", int),
        (order, "order", int),
        (planar, "planar", bool),
    ):
        check_type(value, name, expected)
    for value, name in ((dimension, "dimension"), (order, "order")):
        check_non_neg(value, name)

    if order == 0:
        return (None,)

    # Drop the leading empty element, then keep only the elements whose node
    # count is exactly ``order``.
    candidates = _basis_cached(dimension, order, planar)[1:]
    selected = [
        element
        for element in candidates
        if _basis_element_order(element, planar) == order
    ]
    return tuple(selected)
def trees(
    dimension: int,
    degree: int,
    *,
    planar: bool = False,
) -> tuple[tuple]:
    """
    Returns all basis elements up to a given degree, starting with the
    empty tree (``None``).

    For ``planar=False`` these are non-planar decorated rooted trees. For
    ``planar=True`` these are ordered forests of decorated planar rooted
    trees, with degree equal to the maximum total number of nodes in the
    forest. The order matches the native pySigLib coefficient layout.

    Elements are grouped by increasing node count. Among trees built on the
    same children, the root label varies fastest.

    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param degree: Maximum number of nodes per tree or forest.
    :type degree: int
    :param planar: If True, enumerate ordered forests of planar rooted trees.
    :type planar: bool
    :return: All basis elements up to the given degree.
    :rtype: tuple[tuple]

    Example:
    ---------
    .. code-block:: python

        import pysiglib

        t = pysiglib.trees(2, 2)
        print(t)
        # (None, (0,), (1,), ((0,), 0), ((0,), 1), ((1,), 0), ((1,), 1))
    """
    check_type(dimension, "dimension", int)
    check_type(degree, "degree", int)
    check_type(planar, "planar", bool)
    check_non_neg(dimension, "dimension")
    check_non_neg(degree, "degree")

    return _basis_cached(dimension, degree, planar)
def _basis_element_order(basis_element, planar):
    """
    Return the number of nodes in a non-empty basis element.

    With ``planar`` true the element is an ordered forest and the node counts
    of its trees are added up; otherwise the element is a single tree.
    """
    if not planar:
        return _tree_node_count(basis_element)

    total = 0
    for tree in basis_element:
        total += _tree_node_count(tree)
    return total
def tree_to_idx(
    tree,
    dimension: int,
    degree: int,
    *,
    planar: bool = False,
    scalar_term: bool = False,
) -> int:
    """
    Given a basis element, returns its flat index in the branched-signature
    coefficient vector.

    Trees use the native pySigLib tuple convention:

    - Empty tree: ``None`` - index 0 when ``scalar_term=True``; invalid
      otherwise
    - Leaf: ``(label,)`` where ``label`` is in ``[0, dimension)``
    - Internal node: ``(child_1, child_2, ..., root_label)``
    - Planar ordered forest: ``(tree_1, tree_2, ...)`` when ``planar=True``

    With ``scalar_term=True`` the empty tree sits at index 0. With
    ``scalar_term=False`` (default) there is no empty-tree entry, so all
    indices shift down by 1 and ``None`` is invalid.

    In the planar case, the branched-signature basis is indexed by ordered
    forests. Passing a single tree is accepted as shorthand for the one-tree
    forest.

    :param tree: Decorated rooted tree, ordered forest, or None for empty
        when ``scalar_term=True``.
    :type tree: tuple | None
    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param degree: Maximum number of nodes (same as ``degree`` in
        :func:`branched_sig`).
    :type degree: int
    :param planar: If True, use the planar ordered-forest enumeration
        matching ``branched_sig(..., planar=True)``.
    :type planar: bool
    :param scalar_term: Whether the target branched signature includes the
        leading scalar 1 at index 0. Must match the format of the bsig you
        intend to index. Default ``False``.
    :type scalar_term: bool
    :return: Flat index in the branched signature vector.
    :rtype: int
    :raises ValueError: If ``tree`` is ``None`` while ``scalar_term`` is
        false, or if ``tree`` is not an element of the requested basis.

    Example:
    ---------
    .. code-block:: python

        import torch
        import pysiglib

        path = torch.rand(size=(100, 2))
        pysiglib.prepare_branched_sig(2, 3)
        bsig = pysiglib.branched_sig(path, 3)  # scalar_term=False

        tree = ((1,), 0)
        idx = pysiglib.tree_to_idx(tree, dimension=2, degree=3)
        print(f"Index: {idx}, Coefficient: {bsig[idx]}")
    """
    check_type(dimension, "dimension", int)
    check_type(degree, "degree", int)
    check_type(planar, "planar", bool)
    check_non_neg(dimension, "dimension")
    check_non_neg(degree, "degree")

    if not scalar_term and tree is None:
        raise ValueError(
            "The empty tree has no index in a branched signature with scalar_term=False. "
            "Pass scalar_term=True if your bsig includes the leading scalar 1."
        )

    # Planar lookups are keyed by forests, so wrap a lone tree first.
    key = tree
    if planar:
        key = _as_planar_forest_tuple(tree)

    positions = _basis_index_cached(dimension, degree, planar)
    if key not in positions:
        raise ValueError("tree is not in the requested branched signature basis")

    # The cached positions include the empty tree at 0; without a scalar
    # term every index moves down by one.
    shift = 0 if scalar_term else 1
    return positions[key] - shift
def idx_to_tree(
    idx: int,
    dimension: int,
    degree: int,
    *,
    planar: bool = False,
    scalar_term: bool = False,
) -> tuple:
    """
    Inverse of :func:`tree_to_idx`. Given a flat index in the
    branched-signature coefficient vector, returns the corresponding basis
    element.

    With ``scalar_term=True``, index 0 maps to the empty tree (``None``).
    With ``scalar_term=False`` (default), all indices shift down by 1 (index
    0 maps to the first non-empty tree) and the empty tree is unreachable.

    In the planar case, the returned tuple is an ordered forest, represented
    as ``(tree_1, tree_2, ...)``.

    :param idx: Flat index in the branched signature vector.
    :type idx: int
    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param degree: Maximum number of nodes (same as ``degree`` in
        :func:`branched_sig`).
    :type degree: int
    :param planar: If True, interpret ``idx`` in the planar ordered-forest
        enumeration matching ``branched_sig(..., planar=True)``.
    :type planar: bool
    :param scalar_term: Whether the source branched signature includes the
        leading scalar 1 at index 0. Must match the format of the bsig the
        index was taken from. Default ``False``.
    :type scalar_term: bool
    :return: Decorated rooted tree, ordered forest, or None for empty tree
        when ``scalar_term=True``.
    :rtype: tuple or None
    :raises ValueError: If ``idx`` lies beyond the end of the requested basis.

    Example:
    ---------
    .. code-block:: python

        import pysiglib

        tree = pysiglib.idx_to_tree(3, dimension=2, degree=3)
        print(tree)
    """
    check_type(idx, "idx", int)
    check_type(dimension, "dimension", int)
    check_type(degree, "degree", int)
    check_type(planar, "planar", bool)
    check_non_neg(idx, "idx")
    check_non_neg(dimension, "dimension")
    check_non_neg(degree, "degree")

    # The cached basis always starts with the empty tree, so skip over it
    # when the signature has no scalar term.
    position = idx if scalar_term else idx + 1

    basis = _basis_cached(dimension, degree, planar)
    if position < len(basis):
        return basis[position]
    raise ValueError("idx is out of range for the requested branched signature basis")