# Source code for pysiglib.trees

# Copyright 2026 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

"""
Decorated rooted tree enumeration and indexing for branched signatures.

Trees use the kauri tuple convention::

    Empty tree:   None
    Leaf:         (label,)
    Internal:     (child1, child2, ..., root_label)

where children are sorted and ``label`` is an int in ``[0, dimension)``.
"""

from functools import cache

import kauri

from .param_checks import check_type, check_non_neg


def _check_tree_order(tree_order):
    if tree_order not in ("recursive", "canonical"):
        raise ValueError(f"tree_order must be 'recursive' or 'canonical', got {tree_order!r}")


def _canonical_to_recursive_perm(dimension, degree, planar):
    """Return kauri's canonical-to-recursive permutation.

    Dispatches to the planar variant when ``planar`` is truthy and to the
    non-planar variant otherwise; ``dimension`` and ``degree`` are forwarded
    unchanged.
    """
    build_perm = (
        kauri.planar_canonical_to_recursive_permutation
        if planar
        else kauri.canonical_to_recursive_permutation
    )
    return build_perm(dimension, degree)


def _recursive_to_canonical_perm(dimension, degree, planar):
    """Return kauri's recursive-to-canonical permutation.

    Dispatches to the planar variant when ``planar`` is truthy and to the
    non-planar variant otherwise; ``dimension`` and ``degree`` are forwarded
    unchanged.
    """
    build_perm = (
        kauri.planar_recursive_to_canonical_permutation
        if planar
        else kauri.recursive_to_canonical_permutation
    )
    return build_perm(dimension, degree)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def trees_of_order(
    dimension: int,
    order: int,
    *,
    tree_order: str = "canonical",
    planar: bool = False,
) -> tuple[tuple]:
    """
    Enumerate every decorated rooted tree that has exactly ``order`` nodes,
    in the requested ordering.

    Arguments are validated here; the enumeration itself is delegated to a
    memoized helper, so repeated calls with the same arguments are cheap.

    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param order: Exact number of nodes.
    :type order: int
    :param tree_order: Tree ordering convention. ``"canonical"`` (default)
        is the shape-first order matching :func:`tree_to_idx` and
        ``branched_sig(..., tree_order="canonical")``. ``"recursive"`` is the
        recursive bottom-up construction order matching
        ``branched_sig(..., tree_order="recursive")`` (the ``branched_sig``
        default).
    :type tree_order: str
    :param planar: If True, enumerate planar (ordered) trees, in which the
        order of children matters.
    :type planar: bool
    :return: Tuple of trees as tuples in kauri convention. For ``order == 0``
        this is ``(None,)`` (the empty tree only).
    :rtype: tuple[tuple]
    :raises ValueError: If ``tree_order`` is not ``"recursive"`` or
        ``"canonical"``.

    Example:
    ---------
    .. code-block:: python

        import pysiglib

        # All single-node trees over dimension 2
        t = pysiglib.trees_of_order(2, 1)
        print(t)  # ((0,), (1,))
    """
    _check_tree_order(tree_order)

    # Type checks run first, then non-negativity checks, in a fixed order so
    # the first reported problem is deterministic.
    for value, label, expected in (
        (dimension, "dimension", int),
        (order, "order", int),
        (planar, "planar", bool),
    ):
        check_type(value, label, expected)
    for value, label in ((dimension, "dimension"), (order, "order")):
        check_non_neg(value, label)

    return _trees_of_order_cached(dimension, order, tree_order, planar)
@cache def _trees_of_order_cached(dimension, order, tree_order, planar): if order == 0: return (None,) if tree_order == "canonical": if planar: return tuple(t.sorted_list_repr() for t in kauri.colored_planar_trees_of_order(order, dimension)) return tuple(t.sorted_list_repr() for t in kauri.colored_trees_of_order(order, dimension)) # Recursive ordering: compute all trees up to `order` in recursive order, # then slice off the last stratum (trees of exactly `order` nodes). from .branched_sig import branched_sig_length all_trees = _trees_cached(dimension, order, "recursive", planar) lower = branched_sig_length(dimension, order - 1, scalar_term=True, planar=planar) upper = branched_sig_length(dimension, order, scalar_term=True, planar=planar) return all_trees[lower:upper]
def trees(
    dimension: int,
    degree: int,
    *,
    tree_order: str = "canonical",
    planar: bool = False,
) -> tuple[tuple]:
    """
    Enumerate every decorated rooted tree with at most ``degree`` nodes, in
    the requested ordering. The empty tree (``None``) comes first.

    Arguments are validated here; the enumeration itself is delegated to a
    memoized helper.

    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param degree: Maximum number of nodes per tree.
    :type degree: int
    :param tree_order: Tree ordering convention. ``"canonical"`` (default)
        is the shape-first order matching :func:`tree_to_idx` and
        ``branched_sig(..., tree_order="canonical")``. ``"recursive"`` is the
        recursive bottom-up construction order matching
        ``branched_sig(..., tree_order="recursive")`` (the ``branched_sig``
        default).
    :type tree_order: str
    :param planar: If True, enumerate planar (ordered) trees, in which the
        order of children matters.
    :type planar: bool
    :return: All decorated rooted trees up to the given degree.
    :rtype: tuple[tuple]
    :raises ValueError: If ``tree_order`` is not ``"recursive"`` or
        ``"canonical"``.

    Example:
    ---------
    .. code-block:: python

        import pysiglib

        t = pysiglib.trees(2, 2)
        print(t)
        # (None, (0,), (1,), ((0,), 0), ((1,), 0), ((0,), 1), ((1,), 1))
    """
    _check_tree_order(tree_order)

    # Type checks run first, then non-negativity checks, in a fixed order so
    # the first reported problem is deterministic.
    for value, label, expected in (
        (dimension, "dimension", int),
        (degree, "degree", int),
        (planar, "planar", bool),
    ):
        check_type(value, label, expected)
    for value, label in ((dimension, "dimension"), (degree, "degree")):
        check_non_neg(value, label)

    return _trees_cached(dimension, degree, tree_order, planar)
@cache
def _trees_cached(dimension, degree, tree_order, planar):
    """Memoized worker behind ``trees``; arguments are assumed validated.

    The canonical ordering is taken straight from kauri. The recursive
    ordering is derived from it by permuting the non-empty trees.
    """
    if tree_order == "canonical":
        source = (
            kauri.colored_planar_trees_up_to_order(degree, dimension)
            if planar
            else kauri.colored_trees(dimension, degree)
        )
        return tuple(node.sorted_list_repr() for node in source)

    # Recursive ordering: begin with the canonical listing.
    base = _trees_cached(dimension, degree, "canonical", planar)
    if len(base) <= 1:
        # At most one entry, so there is nothing to reorder.
        return base

    perm = _canonical_to_recursive_perm(dimension, degree, planar)

    # perm[i] is the recursive position of the i-th canonical non-empty tree.
    # Slot 0 is left as None for the empty tree, hence the +1 on the target.
    reordered = [None] * len(base)
    for i, tree in enumerate(base[1:]):
        reordered[perm[i] + 1] = tree
    return tuple(reordered)
def tree_to_idx(
    tree,
    dimension: int,
    degree: int,
    *,
    tree_order: str = "canonical",
    planar: bool = False,
    scalar_term: bool = False,
) -> int:
    """
    Map a decorated rooted tree to its flat index in the branched-signature
    coefficient vector.

    Trees use the kauri tuple convention:

    - Empty tree: ``None`` -- index 0 when ``scalar_term=True``; invalid
      otherwise
    - Leaf: ``(label,)`` where ``label`` is in ``[0, dimension)``
    - Internal node: ``(child_1, child_2, ..., root_label)``

    With ``scalar_term=True`` the empty tree occupies index 0. With
    ``scalar_term=False`` (default) there is no empty-tree entry, so every
    index is one lower and ``None`` is rejected.

    :param tree: Decorated rooted tree as a tuple (or None for the empty tree
        when ``scalar_term=True``).
    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param degree: Maximum number of nodes (same as ``degree`` in
        :func:`branched_sig`).
    :type degree: int
    :param tree_order: Tree ordering convention. ``"canonical"`` (default)
        matches ``branched_sig(..., tree_order="canonical")``.
        ``"recursive"`` matches ``branched_sig(..., tree_order="recursive")``
        (the ``branched_sig`` default).
    :type tree_order: str
    :param planar: If True, use the planar (ordered) enumeration matching
        ``branched_sig(..., planar=True)``.
    :type planar: bool
    :param scalar_term: Whether the target branched signature includes the
        leading scalar 1 at index 0. Must match the format of the bsig you
        intend to index. Default ``False``.
    :type scalar_term: bool
    :return: Flat index in the branched signature vector.
    :rtype: int
    :raises ValueError: If ``tree_order`` is not ``"recursive"`` or
        ``"canonical"``, or if ``tree`` is ``None`` while ``scalar_term`` is
        false.

    Example:
    ---------
    .. code-block:: python

        import torch
        import pysiglib

        path = torch.rand(size=(100, 2))
        pysiglib.prepare_branched_sig(2, 3)
        bsig = pysiglib.branched_sig(path, 3, tree_order="canonical")  # scalar_term=False

        tree = ((1,), 0)
        idx = pysiglib.tree_to_idx(tree, dimension=2, degree=3)
        print(f"Index: {idx}, Coefficient: {bsig[idx]}")
    """
    _check_tree_order(tree_order)

    # Type checks run first, then non-negativity checks, in a fixed order so
    # the first reported problem is deterministic.
    for value, label, expected in (
        (dimension, "dimension", int),
        (degree, "degree", int),
        (planar, "planar", bool),
    ):
        check_type(value, label, expected)
    for value, label in ((dimension, "dimension"), (degree, "degree")):
        check_non_neg(value, label)

    # Without a scalar term the vector has no slot for the empty tree.
    if tree is None:
        if not scalar_term:
            raise ValueError(
                "The empty tree has no index in a branched signature with scalar_term=False. "
                "Pass scalar_term=True if your bsig includes the leading scalar 1."
            )

    return _tree_to_idx_cached(tree, dimension, degree, tree_order, planar, scalar_term)
@cache def _tree_to_idx_cached(tree, dimension, degree, tree_order, planar, scalar_term): if tree is None: return 0 if planar: canonical_idx = kauri.colored_planar_tree_to_idx(kauri.PlanarTree(tree), dimension, degree) else: canonical_idx = kauri.colored_tree_to_idx(kauri.Tree(tree), dimension, degree) if tree_order == "canonical": idx = canonical_idx else: perm = _canonical_to_recursive_perm(dimension, degree, planar) idx = int(perm[canonical_idx - 1]) + 1 return idx if scalar_term else idx - 1
def idx_to_tree(
    idx: int,
    dimension: int,
    degree: int,
    *,
    tree_order: str = "canonical",
    planar: bool = False,
    scalar_term: bool = False,
) -> tuple:
    """
    Inverse of :func:`tree_to_idx`: map a flat index in the
    branched-signature coefficient vector back to its decorated rooted tree.

    With ``scalar_term=True``, index 0 is the empty tree (``None``). With
    ``scalar_term=False`` (default), every index is one lower (index 0 is the
    first non-empty tree) and the empty tree cannot be reached.

    :param idx: Flat index in the branched signature vector.
    :type idx: int
    :param dimension: Path dimension (alphabet size).
    :type dimension: int
    :param degree: Maximum number of nodes (same as ``degree`` in
        :func:`branched_sig`).
    :type degree: int
    :param tree_order: Tree ordering convention. ``"canonical"`` (default)
        matches ``branched_sig(..., tree_order="canonical")``.
        ``"recursive"`` matches ``branched_sig(..., tree_order="recursive")``
        (the ``branched_sig`` default).
    :type tree_order: str
    :param planar: If True, interpret ``idx`` in the planar (ordered)
        enumeration matching ``branched_sig(..., planar=True)``.
    :type planar: bool
    :param scalar_term: Whether the source branched signature includes the
        leading scalar 1 at index 0. Must match the format of the bsig the
        index was taken from. Default ``False``.
    :type scalar_term: bool
    :return: Decorated rooted tree (None for the empty tree when
        ``scalar_term=True``, tuple otherwise).
    :rtype: tuple or None
    :raises ValueError: If ``tree_order`` is not ``"recursive"`` or
        ``"canonical"``.

    Example:
    ---------
    .. code-block:: python

        import pysiglib

        tree = pysiglib.idx_to_tree(3, dimension=2, degree=3)
        print(tree)
    """
    # Type checks, then non-negativity checks, then the ordering name, in a
    # fixed order so the first reported problem is deterministic.
    for value, label, expected in (
        (idx, "idx", int),
        (dimension, "dimension", int),
        (degree, "degree", int),
        (planar, "planar", bool),
    ):
        check_type(value, label, expected)
    for value, label in ((idx, "idx"), (dimension, "dimension"), (degree, "degree")):
        check_non_neg(value, label)
    _check_tree_order(tree_order)

    return _idx_to_tree_cached(idx, dimension, degree, tree_order, planar, scalar_term)
@cache def _idx_to_tree_cached(idx, dimension, degree, tree_order, planar, scalar_term): if not scalar_term: idx = idx + 1 if idx == 0: return None if tree_order == "recursive": inv_perm = _recursive_to_canonical_perm(dimension, degree, planar) canonical_idx = int(inv_perm[idx - 1]) + 1 else: canonical_idx = idx if planar: kt = kauri.idx_to_colored_planar_tree(canonical_idx, dimension, degree) else: kt = kauri.idx_to_colored_tree(canonical_idx, dimension, degree) return kt.sorted_list_repr()