# Source code for pysiglib.branched_sig_kernel

# Copyright 2026 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

from typing import Optional, Union
from ctypes import POINTER, cast
import warnings

import numpy as np
import torch

from .data_handlers import MultiplePathInputHandler, ScalarOutputHandler, GridOutputHandler
from .dtypes import (
    CPSIG_BRANCHED_SIG_KERNEL,
    CUSIG_BRANCHED_SIG_KERNEL_CUDA,
    DTYPES,
)
from .error_codes import err_msg
from .param_checks import (
    check_non_neg,
    check_n_jobs,
    check_type,
    dyadic_grid_length,
    parse_dyadic_order,
)
from .sig_kernel import _ensure_3d, _safe_normalize
from .static_kernels import Context, LinearKernel, StaticKernel
from .transform_path import transform_path


def branched_sig_kernel(
        path1: Union[np.ndarray, torch.Tensor],
        path2: Union[np.ndarray, torch.Tensor],
        depth: int,
        dyadic_order: Union[int, tuple],
        *,
        static_kernel: Optional[StaticKernel] = None,
        time_aug: bool = False,
        lead_lag: bool = False,
        end_time: float = 1.,
        n_jobs: int = 1,
        return_grid: bool = False,
        normalize: bool = False,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Computes a single branched signature kernel or a batch of branched signature kernels.

    This is the non-planar BCK branched rough path kernel of Chevyrev and Oberhauser.
    It uses the depth recursion ``K_0 = 1`` and ``K_m = exp(integral K_{m-1} dk)``
    with the same static-kernel increment abstraction as :func:`pysiglib.sig_kernel`.
    The parameter ``depth`` truncates forests by tree depth, not by the number of nodes.

    :param path1: First path or batch of paths, shape ``(..., length_1, dimension)``.
    :type path1: numpy.ndarray | torch.Tensor
    :param path2: Second path or batch of paths, shape ``(..., length_2, dimension)``.
    :type path2: numpy.ndarray | torch.Tensor
    :param depth: Forest depth truncation for the branched kernel.
    :type depth: int
    :param dyadic_order: Dyadic refinement order, or a pair of orders.
    :type dyadic_order: int | tuple
    :param static_kernel: Static kernel. If ``None``, the linear kernel is used.
    :type static_kernel: None | pysiglib.StaticKernel
    :param time_aug: Whether to time augment the paths.
    :type time_aug: bool
    :param lead_lag: Whether to apply the lead-lag transformation.
    :type lead_lag: bool
    :param end_time: End time for time augmentation.
    :type end_time: float
    :param n_jobs: Number of CPU worker threads.
    :type n_jobs: int
    :param return_grid: If ``True``, returns the final-depth grid.
    :type return_grid: bool
    :param normalize: If ``True``, normalizes by ``sqrt(K(x, x) * K(y, y))``.
        Cannot be used with ``return_grid=True``.
    :type normalize: bool
    :return: Branched signature kernel value or final-depth grid.
    :rtype: numpy.ndarray | torch.Tensor
    :raises ValueError: If ``normalize`` and ``return_grid`` are both ``True``, or if
        ``static_kernel`` is not an instance of :class:`pysiglib.StaticKernel`.
    :raises Exception: If the native backend reports a non-zero error code.

    A :class:`RuntimeWarning` is emitted if the result contains NaN or Inf values.
    """
    check_type(depth, "depth", int)
    check_non_neg(depth, "depth")
    check_type(time_aug, "time_aug", bool)
    check_type(lead_lag, "lead_lag", bool)
    check_n_jobs(n_jobs)
    if normalize and return_grid:
        raise ValueError("normalize=True cannot be used with return_grid=True")

    dyadic_order_1, dyadic_order_2 = parse_dyadic_order(dyadic_order)

    # Transform the paths up front; the input handler below is then told not to
    # apply any augmentation itself (False, False, 0.).
    if time_aug or lead_lag:
        path1 = transform_path(path1, time_aug=time_aug, lead_lag=lead_lag,
                               end_time=end_time, n_jobs=n_jobs)
        path2 = transform_path(path2, time_aug=time_aug, lead_lag=lead_lag,
                               end_time=end_time, n_jobs=n_jobs)

    data = MultiplePathInputHandler([path1, path2], False, False, 0., ["path1", "path2"])

    if not return_grid:
        result = ScalarOutputHandler(data)
    else:
        dyadic_len_1 = dyadic_grid_length(data.length[0], dyadic_order_1)
        dyadic_len_2 = dyadic_grid_length(data.length[1], dyadic_order_2)
        result = GridOutputHandler(dyadic_len_1, dyadic_len_2, data)

    if data.batch_size == 0:
        return result.data

    torch_path1 = _ensure_3d(torch.as_tensor(data.path[0]))
    torch_path2 = _ensure_3d(torch.as_tensor(data.path[1]))

    ctx = Context()
    if static_kernel is None:
        static_kernel = LinearKernel()
    elif not isinstance(static_kernel, StaticKernel):
        raise ValueError("kernel must be a child class of pysiglib.StaticKernel")

    gram = static_kernel(ctx, torch_path1, torch_path2)
    # The native backend reads this buffer through a raw pointer. It chooses its
    # element type from the paths' dtype and its backend (CPU/CUDA) from the
    # paths' device. A user-defined StaticKernel is not guaranteed to return a
    # dense tensor with that dtype/device, so coerce it here. Both calls are
    # no-ops when the tensor already matches.
    gram = gram.to(device=torch_path1.device, dtype=torch_path1.dtype).contiguous()
    gram_ptr = cast(gram.data_ptr(), POINTER(DTYPES[str(gram.dtype)[6:]]))

    if data.device == "cpu":
        err_code = CPSIG_BRANCHED_SIG_KERNEL[data.dtype](
            gram_ptr,
            result.data_ptr,
            data.batch_size,
            data.dimension,
            data.length[0],
            data.length[1],
            depth,
            dyadic_order_1,
            dyadic_order_2,
            return_grid,
            n_jobs)
    else:
        err_code = CUSIG_BRANCHED_SIG_KERNEL_CUDA[data.dtype](
            gram_ptr,
            result.data_ptr,
            data.batch_size,
            data.dimension,
            data.length[0],
            data.length[1],
            depth,
            dyadic_order_1,
            dyadic_order_2,
            return_grid)

    if err_code:
        raise Exception("Error in pysiglib.branched_sig_kernel: " + err_msg(err_code))

    # isfinite is False for both NaN and +/-Inf, so one reduction covers both.
    if isinstance(result.data, np.ndarray):
        all_finite = bool(np.isfinite(result.data).all())
    else:
        all_finite = bool(torch.isfinite(result.data).all().item())
    if not all_finite:
        warnings.warn(
            "branched_sig_kernel produced NaN or Inf values. This is typically "
            "caused by large static-kernel increments. Consider scaling paths "
            "or using a bounded static kernel.",
            RuntimeWarning,
            stacklevel=2,
        )

    if normalize:
        # path1/path2 are already transformed at this point, so the diagonal
        # calls deliberately do not pass time_aug/lead_lag again.
        k1 = branched_sig_kernel(path1, path1, depth, dyadic_order,
                                 static_kernel=static_kernel, n_jobs=n_jobs)
        k2 = branched_sig_kernel(path2, path2, depth, dyadic_order,
                                 static_kernel=static_kernel, n_jobs=n_jobs)
        result.data = _safe_normalize(result.data, k1, k2,
                                      "branched_sig_kernel(normalize=True)")

    return result.data
def _branched_sig_kernel_diag(paths: torch.Tensor, chunk_size: int, kernel_kwargs: dict) -> torch.Tensor:
    """Computes ``K(x_i, x_i)`` for every path in a 3D batch, at most ``chunk_size`` pairs per native call."""
    parts = [
        branched_sig_kernel(paths[start:start + chunk_size], paths[start:start + chunk_size], **kernel_kwargs)
        for start in range(0, paths.shape[0], chunk_size)
    ]
    # Flatten each chunk so a single-path chunk concatenates correctly even if
    # the scalar output for a batch of one comes back 0-d.
    return torch.cat([part.reshape(-1) for part in parts])


def branched_sig_kernel_gram(
        path1: Union[np.ndarray, torch.Tensor],
        path2: Union[np.ndarray, torch.Tensor],
        depth: int,
        dyadic_order: Union[int, tuple],
        *,
        static_kernel: Optional[StaticKernel] = None,
        time_aug: bool = False,
        lead_lag: bool = False,
        end_time: float = 1.,
        n_jobs: int = 1,
        max_batch: int = -1,
        return_grid: bool = False,
        normalize: bool = False,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Computes a Gram matrix of branched signature kernels.

    Every path of ``path1`` is paired with every path of ``path2`` and the pairs are
    evaluated with :func:`branched_sig_kernel`. When ``path1 is path2`` (the same
    object), only the upper triangle is computed and mirrored.

    :param path1: First path batch, shape ``(*batch_shape_1, length_1, dimension)``.
    :type path1: numpy.ndarray | torch.Tensor
    :param path2: Second path batch, shape ``(*batch_shape_2, length_2, dimension)``.
    :type path2: numpy.ndarray | torch.Tensor
    :param depth: Forest depth truncation for the branched kernel.
    :type depth: int
    :param dyadic_order: Dyadic refinement order, or a pair of orders.
    :type dyadic_order: int | tuple
    :param static_kernel: Static kernel. If ``None``, the linear kernel is used.
    :type static_kernel: None | pysiglib.StaticKernel
    :param time_aug: Whether to time augment the paths.
    :type time_aug: bool
    :param lead_lag: Whether to apply the lead-lag transformation.
    :type lead_lag: bool
    :param end_time: End time for time augmentation.
    :type end_time: float
    :param n_jobs: Number of CPU worker threads.
    :type n_jobs: int
    :param max_batch: Maximum side length of pair chunks: each native call evaluates
        at most ``max_batch ** 2`` pairs. ``-1`` uses all pairs.
    :type max_batch: int
    :param return_grid: If ``True``, returns final-depth grids per pair.
    :type return_grid: bool
    :param normalize: If ``True``, normalizes the Gram matrix.
    :type normalize: bool
    :return: Gram matrix of branched signature kernels, shape
        ``(*batch_shape_1, *batch_shape_2)``, with two trailing grid dimensions
        appended when ``return_grid=True``.
    :rtype: numpy.ndarray | torch.Tensor
    :raises ValueError: If ``max_batch`` is neither positive nor ``-1``, or if
        ``normalize`` and ``return_grid`` are both ``True``.
    """
    check_type(depth, "depth", int)
    check_non_neg(depth, "depth")
    check_type(time_aug, "time_aug", bool)
    check_type(lead_lag, "lead_lag", bool)
    # Validate here too: the per-chunk calls are skipped for empty batches.
    check_n_jobs(n_jobs)
    check_type(max_batch, "max_batch", int)
    if max_batch == 0 or max_batch < -1:
        raise ValueError("max_batch must be a positive integer or -1")
    if normalize and return_grid:
        raise ValueError("normalize=True cannot be used with return_grid=True")

    # Identity (not equality) decides whether only the upper triangle is needed.
    symmetric = path1 is path2
    batch_shape_1 = tuple(path1.shape[:-2])
    batch_shape_2 = batch_shape_1 if symmetric else tuple(path2.shape[:-2])

    path1 = _ensure_3d(path1)
    path2 = path1 if symmetric else _ensure_3d(path2)

    data = MultiplePathInputHandler([path1, path2], time_aug, lead_lag, end_time,
                                    ["path1", "path2"], False)
    path1 = torch.as_tensor(data.path[0])
    path2 = torch.as_tensor(data.path[1])
    batch1 = path1.shape[0]
    batch2 = path2.shape[0]

    do1, do2 = parse_dyadic_order(dyadic_order)
    if return_grid:
        gl1 = dyadic_grid_length(data.length[0], do1)
        gl2 = dyadic_grid_length(data.length[1], do2)
        # Mirroring transposes a grid, which only fits when the grid is square.
        if symmetric and gl1 != gl2:
            symmetric = False

    if batch1 == 0 or batch2 == 0:
        out_shape = batch_shape_1 + batch_shape_2
        if return_grid:
            out_shape = out_shape + (gl1, gl2)
        res = torch.empty(out_shape, dtype=path1.dtype, device=path1.device)
        return res.numpy() if data.type_ == "numpy" else res

    if max_batch == -1:
        max_batch = max(batch1, batch2)
    chunk_size = max_batch * max_batch

    if symmetric:
        idx_i, idx_j = torch.triu_indices(batch1, batch1, device=path1.device)
        src1, src2 = path1, path1
    else:
        idx_i = torch.arange(batch1, device=path1.device).repeat_interleave(batch2)
        idx_j = torch.arange(batch2, device=path2.device).repeat(batch1)
        src1, src2 = path1, path2

    if return_grid:
        res = torch.empty(batch1, batch2, gl1, gl2, dtype=path1.dtype, device=path1.device)
    else:
        res = torch.empty(batch1, batch2, dtype=path1.dtype, device=path1.device)

    # Shared by the pairwise loop and the normalisation diagonals so that both
    # use the same kernel configuration.
    kernel_kwargs = {
        "depth": depth,
        "dyadic_order": dyadic_order,
        "static_kernel": static_kernel,
        "time_aug": time_aug,
        "lead_lag": lead_lag,
        "end_time": end_time,
        "n_jobs": n_jobs,
    }

    n_pairs = idx_i.shape[0]
    for start in range(0, n_pairs, chunk_size):
        end = min(start + chunk_size, n_pairs)
        ci = idx_i[start:end]
        cj = idx_j[start:end]
        k = branched_sig_kernel(src1[ci], src2[cj], return_grid=return_grid, **kernel_kwargs)
        res[ci, cj] = k
        if symmetric:
            off = ci != cj
            if off.any():
                k_mirror = k[off]
                if return_grid:
                    k_mirror = k_mirror.transpose(-2, -1)
                res[cj[off], ci[off]] = k_mirror

    if normalize:
        if symmetric:
            # K(x_i, x_i) was already computed (triu_indices includes the
            # diagonal), so reuse it instead of a second native pass.
            d1 = res.diagonal().clone()
            d2 = d1
        else:
            # Respect the max_batch pair budget for the diagonals as well.
            d1 = _branched_sig_kernel_diag(path1, chunk_size, kernel_kwargs)
            d2 = _branched_sig_kernel_diag(path2, chunk_size, kernel_kwargs)
        res = _safe_normalize(res, d1.unsqueeze(1), d2.unsqueeze(0),
                              "branched_sig_kernel_gram(normalize=True)")

    out_shape = batch_shape_1 + batch_shape_2
    if return_grid:
        out_shape = out_shape + (res.shape[-2], res.shape[-1])
    res = res.reshape(out_shape) if out_shape else res.squeeze()

    if data.type_ == "numpy":
        return res.numpy()
    return res