# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
from typing import Union, Tuple, Optional
from ctypes import POINTER, cast
from math import prod
import numpy as np
import torch
from .transform_path import transform_path
from .transform_path_backprop import transform_path_backprop
from .sig_kernel import sig_kernel, _ensure_3d
from .param_checks import check_type, parse_dyadic_order, check_n_jobs
from .error_codes import err_msg
from .dtypes import CPSIG_SIG_KERNEL_BACKPROP, DTYPES, CUSIG_SIG_KERNEL_BACKPROP_CUDA
from .data_handlers import MultiplePathInputHandler, ScalarInputHandler, GridOutputHandler, PathInputHandler
from .static_kernels import StaticKernel, LinearKernel, Context
def gram_deriv(
    derivs_data,
    data,
    gram : Union[np.ndarray, torch.tensor],
    k_grid_data : Union[np.ndarray, torch.tensor],
    dyadic_order_1,
    dyadic_order_2,
    return_grid : bool = False,
    n_jobs : int = 1
) -> Union[np.ndarray, torch.tensor]:
    """
    Backpropagate derivatives through the signature kernel PDE onto the static
    kernel gram matrix, by calling the compiled CPU or CUDA backend.

    :param derivs_data: Input handler wrapping the derivatives being
        backpropagated; its ``data_ptr`` is handed to the backend.
    :param data: Input handler for the two path batches; supplies the device,
        dtype, batch size, dimension and lengths passed to the backend.
    :param gram: Static kernel gram values. Must be a torch tensor, since
        ``data_ptr()`` and ``contiguous()`` are called on it.
    :param k_grid_data: Input handler wrapping the signature kernel PDE grid.
    :param dyadic_order_1: First dyadic order, forwarded to the backend.
    :param dyadic_order_2: Second dyadic order, forwarded to the backend.
    :param return_grid: Forwarded to the backend; ``True`` when ``derivs_data``
        holds derivatives with respect to the whole PDE grid.
    :param n_jobs: Number of CPU threads. Not used by the CUDA backend.
    :return: torch tensor of derivatives with respect to the gram matrix,
        allocated via ``GridOutputHandler`` with grid dimensions
        ``(length_1 - 1, length_2 - 1)``.
    :raises Exception: If the backend returns a non-zero error code.
    """
    # Output buffer: derivatives with respect to the gram matrix.
    result = GridOutputHandler(data.length[0] - 1, data.length[1] - 1, derivs_data)

    # The backend reads ``gram`` through a raw pointer, presumably as a dense
    # C-contiguous buffer. ``contiguous()`` is a no-op for tensors already laid
    # out that way and copies otherwise (e.g. a transposed/expanded view from a
    # custom static kernel). Rebinding the local keeps any copy alive for the
    # duration of the backend call.
    gram = gram.contiguous()
    # str(torch.float64) == "torch.float64"; drop the "torch." prefix to get
    # the DTYPES key.
    gram_ptr = cast(gram.data_ptr(), POINTER(DTYPES[str(gram.dtype)[6:]]))

    if data.device == "cpu":
        err_code = CPSIG_SIG_KERNEL_BACKPROP[data.dtype](
            gram_ptr, result.data_ptr, derivs_data.data_ptr, k_grid_data.data_ptr,
            data.batch_size, data.dimension, data.length[0], data.length[1],
            dyadic_order_1, dyadic_order_2, return_grid, n_jobs)
    else:
        # The CUDA entry point takes no thread count.
        err_code = CUSIG_SIG_KERNEL_BACKPROP_CUDA[data.dtype](
            gram_ptr, result.data_ptr, derivs_data.data_ptr, k_grid_data.data_ptr,
            data.batch_size, data.dimension, data.length[0], data.length[1],
            dyadic_order_1, dyadic_order_2, return_grid)

    if err_code:
        raise Exception("Error in pysiglib.sig_kernel_backprop: " + err_msg(err_code))
    return torch.as_tensor(result.data)


def sig_kernel_backprop(
    derivs : Union[np.ndarray, torch.tensor],
    path1 : Union[np.ndarray, torch.tensor],
    path2 : Union[np.ndarray, torch.tensor],
    dyadic_order : Union[int, tuple],
    *,
    static_kernel : Optional[StaticKernel] = None,
    time_aug : bool = False,
    lead_lag : bool = False,
    end_time : float = 1.,
    left_deriv : bool = True,
    right_deriv : bool = False,
    k_grid : Union[np.ndarray, torch.tensor] = None,
    n_jobs : int = 1,
    return_grid: bool = False
) -> Union[np.ndarray, torch.tensor, Tuple[np.ndarray, np.ndarray], Tuple[torch.tensor, torch.tensor]]:
    """
    This function is required to backpropagate through ``pysiglib.sig_kernel``.
    Given the derivatives of a scalar function :math:`F` with respect to a
    signature kernel, :math:`\\partial F / \\partial \\left< S(x), S(y) \\right>`,
    returns the derivatives of :math:`F` with respect to one or both of the
    underlying paths, :math:`\\{\\partial F / \\partial x_{t_i}\\}_{i=0}^{L_1}` and
    :math:`\\{\\partial F / \\partial y_{t_i}\\}_{i=0}^{L_2}`.

    :param derivs: Derivatives with respect to a signature kernel or batch
        of signature kernels, :math:`\\partial F / \\partial \\left< S(x), S(y) \\right>`.
        If ``return_grid=False``, this should be of shape ``(...)`` matching the leading
        batch dimensions of the paths. If ``return_grid=True``, this should have the same
        shape as the PDE grid returned by ``pysiglib.sig_kernel(..., return_grid=True)``.
        Must be of the same type (numpy / torch) and on the same device as the paths.
    :type derivs: numpy.ndarray | torch.tensor
    :param path1: The first underlying path or batch of paths, of shape
        ``(..., length_1, dimension)``.
    :type path1: numpy.ndarray | torch.tensor
    :param path2: The second underlying path or batch of paths, of shape
        ``(..., length_2, dimension)``. Leading batch dimensions must match those of
        ``path1``.
    :type path2: numpy.ndarray | torch.tensor
    :param dyadic_order: The dyadic order(s) used to compute the signature kernels.
    :type dyadic_order: int | tuple
    :param static_kernel: Static kernel. If ``None`` (default), the linear kernel will be used.
        For details, see the documentation on :doc:`static kernels </pages/signature_kernels/static_kernels>`.
    :type static_kernel: None | pysiglib.StaticKernel
    :param time_aug: Whether the signature kernels were computed with ``time_aug=True``.
    :type time_aug: bool
    :param lead_lag: Whether the signature kernels were computed with ``lead_lag=True``.
    :type lead_lag: bool
    :param end_time: End time for time-augmentation, :math:`t_L`.
    :type end_time: float
    :param left_deriv: If ``True``, computes :math:`\\{\\partial F / \\partial x_{t_i}\\}_{i=0}^{L_1}`.
        If both ``left_deriv`` and ``right_deriv`` are ``False``, nothing is
        computed and ``(None, None)`` is returned.
    :type left_deriv: bool
    :param right_deriv: If ``True``, computes :math:`\\{\\partial F / \\partial y_{t_i}\\}_{i=0}^{L_2}`.
        If both ``left_deriv`` and ``right_deriv`` are ``False``, nothing is
        computed and ``(None, None)`` is returned.
    :type right_deriv: bool
    :param k_grid: Signature kernel PDE grid, as returned by
        ``pysiglib.sig_kernel(..., return_grid=True)``. If ``None``, the grid will be recomputed.
    :type k_grid: numpy.ndarray | torch.tensor
    :param n_jobs: (Only applicable to CPU computation) Number of threads to run in parallel.
        If n_jobs = 1, the computation is run serially. If set to -1, all available threads
        are used. For n_jobs below -1, (max_threads + 1 + n_jobs) threads are used. For example
        if n_jobs = -2, all threads but one are used.
    :type n_jobs: int
    :param return_grid: If ``True``, ``derivs`` are derivatives with respect to the entire
        PDE grid, and these are backpropagated.
    :type return_grid: bool
    :return: Tuple ``(dpath1, dpath2)``, always of length two. ``dpath1`` is
        :math:`\\{\\partial F / \\partial x_{t_i}\\}_{i=0}^{L_1}`, with the same shape as ``path1``,
        if ``left_deriv`` is ``True``, otherwise ``None``. Similarly ``dpath2`` is
        :math:`\\{\\partial F / \\partial y_{t_i}\\}_{i=0}^{L_2}`, with the same shape as ``path2``,
        if ``right_deriv`` is ``True``, otherwise ``None``. The arrays are numpy arrays if
        the inputs are numpy arrays, and torch tensors otherwise.
    :rtype: Tuple[numpy.ndarray | None, numpy.ndarray | None] | Tuple[torch.tensor | None, torch.tensor | None]
    :raises ValueError: If the leading batch dimensions of the paths differ, if ``derivs``
        and the paths differ in type or device, if the batch size of ``derivs`` does not
        match that of the paths, or if ``static_kernel`` is not a ``pysiglib.StaticKernel``.

    Example:
    ---------

    .. code-block:: python

        import torch
        import pysiglib

        path1 = torch.rand((10, 100, 5))
        path2 = torch.rand((10, 100, 5))
        k = pysiglib.sig_kernel(path1, path2, dyadic_order=2)
        derivs = torch.ones_like(k)
        dpath1, _ = pysiglib.sig_kernel_backprop(derivs, path1, path2, dyadic_order=2)
        print(dpath1)

    .. code-block:: python

        # Backprop with a static kernel and time augmentation
        import torch
        import pysiglib

        path1 = torch.rand((10, 100, 5))
        path2 = torch.rand((10, 100, 5))
        rbf = pysiglib.RBFKernel(sigma=1.0)
        k = pysiglib.sig_kernel(
            path1, path2, dyadic_order=2, static_kernel=rbf, time_aug=True,
        )
        derivs = torch.ones_like(k)
        dpath1, _ = pysiglib.sig_kernel_backprop(
            derivs, path1, path2,
            dyadic_order=2,
            static_kernel=rbf,
            time_aug=True,
        )
        print(dpath1)
    """
    check_n_jobs(n_jobs)
    check_type(time_aug, "time_aug", bool)
    check_type(lead_lag, "lead_lag", bool)
    check_type(left_deriv, "left_deriv", bool)
    check_type(right_deriv, "right_deriv", bool)
    if not (left_deriv or right_deriv):
        return None, None
    # Validate before the kernel is used below (e.g. to recompute ``k_grid``),
    # so a bad argument is reported with this function's own message.
    if static_kernel is not None and not isinstance(static_kernel, StaticKernel):
        raise ValueError("kernel must be a child class of pysiglib.StaticKernel")
    dyadic_order_1, dyadic_order_2 = parse_dyadic_order(dyadic_order)

    # More than one leading batch dimension: flatten the leading dimensions,
    # solve the 3D problem recursively, then restore the leading shape.
    if path1.ndim > 3 or path2.ndim > 3:
        leading_shape = tuple(path1.shape[:-2])
        if leading_shape != tuple(path2.shape[:-2]):
            raise ValueError(
                "path1 and path2 must have matching leading batch dimensions; got "
                f"{leading_shape} and {tuple(path2.shape[:-2])}."
            )
        lead_size = prod(leading_shape)
        if return_grid:
            flat_derivs = derivs.reshape(lead_size, *derivs.shape[-2:])
        else:
            flat_derivs = derivs.reshape(lead_size)
        flat_k_grid = None if k_grid is None else k_grid.reshape(lead_size, *k_grid.shape[-2:])
        ld, rd = sig_kernel_backprop(
            flat_derivs, _ensure_3d(path1), _ensure_3d(path2), dyadic_order,
            static_kernel=static_kernel,
            time_aug=time_aug, lead_lag=lead_lag, end_time=end_time,
            left_deriv=left_deriv, right_deriv=right_deriv,
            k_grid=flat_k_grid, n_jobs=n_jobs, return_grid=return_grid,
        )
        if ld is not None:
            ld = ld.reshape(*leading_shape, *ld.shape[-2:])
        if rd is not None:
            rd = rd.reshape(*leading_shape, *rd.shape[-2:])
        return ld, rd

    # The computation below works on 3D and possibly time-augmented /
    # lead-lagged paths. Remember the caller's shapes so the returned
    # derivatives can be reshaped to match them.
    path1_shape = tuple(path1.shape)
    path2_shape = tuple(path2.shape)

    if time_aug or lead_lag:
        path1 = transform_path(path1, time_aug=time_aug, lead_lag=lead_lag, end_time=end_time, n_jobs=n_jobs)
        path2 = transform_path(path2, time_aug=time_aug, lead_lag=lead_lag, end_time=end_time, n_jobs=n_jobs)
    # The transforms have been applied explicitly above, hence False, False.
    data = MultiplePathInputHandler([path1, path2], False, False, end_time, ["path1", "path2"])

    if data.batch_size == 0:
        # Empty batch. The handlers describe the *transformed* paths, so reshape
        # the (empty) outputs back to the caller's path shapes.
        from .data_handlers import PathOutputHandler
        ld = PathOutputHandler(data.data[0].data_length, data.data[0].data_dimension, data.data[0]).data
        rd = PathOutputHandler(data.data[1].data_length, data.data[1].data_dimension, data.data[1]).data
        return (ld.reshape(path1_shape) if left_deriv else None), (rd.reshape(path2_shape) if right_deriv else None)

    if return_grid:
        derivs_data = PathInputHandler(derivs, False, False, 0., "derivs")
    else:
        derivs_data = ScalarInputHandler(derivs, bool(data.batch_shape), "derivs")
    if not (derivs_data.type_ == data.type_ and derivs_data.device == data.device):
        raise ValueError("derivs, path1 and path2 must all be numpy arrays or all torch tensors on the same device")
    if not return_grid and data.batch_size != derivs_data.batch_size:
        raise ValueError("batch size for derivs does not match batch size of paths")

    if k_grid is None:
        # Paths are already transformed, so no further transform is requested.
        k_grid = sig_kernel(torch.as_tensor(path1), torch.as_tensor(path2), dyadic_order, static_kernel=static_kernel, time_aug=False, lead_lag=False, end_time=end_time, n_jobs=n_jobs, return_grid=True)

    # torch.as_tensor avoids a data copy where possible.
    torch_path1 = _ensure_3d(torch.as_tensor(data.path[0]))
    torch_path2 = _ensure_3d(torch.as_tensor(data.path[1]))

    # ``None`` is substituted only here, so ``sig_kernel`` above still
    # receives exactly what the caller passed.
    kernel = LinearKernel() if static_kernel is None else static_kernel
    ctx = Context()
    gram = kernel(ctx, torch_path1, torch_path2).squeeze()
    k_grid_data = PathInputHandler(k_grid, False, False, 0., "k_grid")
    gram_derivs = gram_deriv(derivs_data, data, gram, k_grid_data, dyadic_order_1, dyadic_order_2, return_grid, n_jobs)

    # Chain rule through the static kernel.
    ld = kernel.grad_x(ctx, gram_derivs) if left_deriv else None
    rd = kernel.grad_y(ctx, gram_derivs) if right_deriv else None

    # Chain rule through the path transforms.
    if time_aug or lead_lag:
        if left_deriv:
            ld = transform_path_backprop(ld, time_aug=time_aug, lead_lag=lead_lag, end_time=end_time, n_jobs=n_jobs)
        if right_deriv:
            rd = transform_path_backprop(rd, time_aug=time_aug, lead_lag=lead_lag, end_time=end_time, n_jobs=n_jobs)

    # Match the caller's path shapes (a no-op when they already agree).
    if left_deriv:
        ld = ld.reshape(path1_shape)
    if right_deriv:
        rd = rd.reshape(path2_shape)

    if data.type_ == "numpy":
        ld = ld.numpy() if left_deriv else None
        rd = rd.numpy() if right_deriv else None
    return ld, rd


def sig_kernel_gram_backprop(
    derivs : Union[np.ndarray, torch.tensor],
    path1 : Union[np.ndarray, torch.tensor],
    path2 : Union[np.ndarray, torch.tensor],
    dyadic_order : Union[int, tuple],
    *,
    static_kernel : Optional[StaticKernel] = None,
    time_aug : bool = False,
    lead_lag : bool = False,
    end_time : float = 1.,
    left_deriv : bool = True,
    right_deriv : bool = False,
    k_grid : Union[np.ndarray, torch.tensor] = None,
    n_jobs : int = 1,
    return_grid : bool = False,
    max_batch : int = -1
) -> Union[np.ndarray, torch.tensor, Tuple[np.ndarray, np.ndarray], Tuple[torch.tensor, torch.tensor]]:
    """
    This function is required to backpropagate through ``pysiglib.sig_kernel_gram``.
    Given the derivatives of a scalar function :math:`F` with respect to a
    gram matrix of signature kernels, :math:`\\partial F / \\partial G`,
    returns the derivatives of :math:`F` with respect to one or both of the
    underlying paths, :math:`\\{\\partial F / \\partial x_{t_i}\\}_{i=0}^{L_1}` and
    :math:`\\{\\partial F / \\partial y_{t_i}\\}_{i=0}^{L_2}`.

    :param derivs: Derivatives with respect to a gram matrix of signature kernels,
        :math:`\\partial F / \\partial G`. Should have the same shape as the output of
        ``pysiglib.sig_kernel_gram`` for the same inputs:
        ``(*batch_shape_1, *batch_shape_2)`` if ``return_grid=False``, or
        ``(*batch_shape_1, *batch_shape_2, dyadic_length_1, dyadic_length_2)`` if
        ``return_grid=True``.
    :type derivs: numpy.ndarray | torch.tensor
    :param path1: A path or batch of paths, of shape ``(*batch_shape_1, length_1, dimension)``.
    :type path1: numpy.ndarray | torch.tensor
    :param path2: A path or batch of paths, of shape ``(*batch_shape_2, length_2, dimension)``.
        Independent of ``path1``'s batch shape. If ``path2`` is the same object as
        ``path1``, the gram matrix is treated as symmetric: only pairs :math:`i \\le j`
        are enumerated, and the mirrored entries are handled alongside them.
    :type path2: numpy.ndarray | torch.tensor
    :param dyadic_order: The dyadic order(s) used to compute the signature kernels.
    :type dyadic_order: int | tuple
    :param static_kernel: Static kernel. If ``None`` (default), the linear kernel will be used.
        For details, see the documentation on :doc:`static kernels </pages/signature_kernels/static_kernels>`.
    :type static_kernel: None | pysiglib.StaticKernel
    :param time_aug: If ``True``, assumes the paths were time augmented.
    :type time_aug: bool
    :param lead_lag: If ``True``, assumes the lead-lag transform was applied.
    :type lead_lag: bool
    :param end_time: End time for time-augmentation, :math:`t_L`.
    :type end_time: float
    :param left_deriv: If ``True``, computes :math:`\\{\\partial F / \\partial x_{t_i}\\}_{i=0}^{L_1}`.
        If both ``left_deriv`` and ``right_deriv`` are ``False``, nothing is
        computed and ``(None, None)`` is returned.
    :type left_deriv: bool
    :param right_deriv: If ``True``, computes :math:`\\{\\partial F / \\partial y_{t_i}\\}_{i=0}^{L_2}`.
        If both ``left_deriv`` and ``right_deriv`` are ``False``, nothing is
        computed and ``(None, None)`` is returned.
    :type right_deriv: bool
    :param k_grid: Signature kernel PDE grid. If ``None``, the grid will be recomputed.
    :type k_grid: numpy.ndarray | torch.tensor
    :param n_jobs: (Only applicable to CPU computation) Number of threads to run in parallel.
        If n_jobs = 1, the computation is run serially. If set to -1, all available threads
        are used. For n_jobs below -1, (max_threads + 1 + n_jobs) threads are used. For example
        if n_jobs = -2, all threads but one are used.
    :type n_jobs: int
    :param return_grid: If ``True``, backpropagates derivatives with respect to the entire PDE grid.
    :type return_grid: bool
    :param max_batch: Maximum batch size to run in parallel. Up to ``max_batch ** 2``
        path pairs are processed per chunk. If the computation is failing due to
        insufficient memory, this parameter should be decreased. If set to -1, the
        entire batch is computed in parallel.
    :type max_batch: int
    :return: Tuple ``(dpath1, dpath2)``, always of length two. ``dpath1`` is
        :math:`\\{\\partial F / \\partial x_{t_i}\\}_{i=0}^{L_1}`, with shape matching ``path1``,
        if ``left_deriv`` is ``True``, otherwise ``None``. Similarly ``dpath2`` is
        :math:`\\{\\partial F / \\partial y_{t_i}\\}_{i=0}^{L_2}`, with shape matching ``path2``,
        if ``right_deriv`` is ``True``, otherwise ``None``. Derivatives are accumulated
        and returned in float64, as numpy arrays for numpy inputs and torch tensors otherwise.
    :rtype: Tuple[numpy.ndarray | None, numpy.ndarray | None] | Tuple[torch.tensor | None, torch.tensor | None]
    :raises ValueError: If ``max_batch`` is neither a positive integer nor -1.

    .. note::

        When called via ``pysiglib.torch_api``, the default behaviour is to pass ``k_grid = None`` and reconstruct the
        PDE grids. This is done to avoid memory allocation issues for large batch sizes.

    Example:
    ---------

    .. code-block:: python

        import torch
        import pysiglib

        path1 = torch.rand((10, 100, 5))
        path2 = torch.rand((8, 100, 5))
        gram = pysiglib.sig_kernel_gram(path1, path2, dyadic_order=2)
        derivs = torch.ones_like(gram)
        dpath1, _ = pysiglib.sig_kernel_gram_backprop(derivs, path1, path2, dyadic_order=2)
        print(dpath1)

    .. code-block:: python

        # Gram backprop with a static kernel
        import torch
        import pysiglib

        path1 = torch.rand((10, 100, 5))
        path2 = torch.rand((8, 100, 5))
        rbf = pysiglib.RBFKernel(sigma=0.5)
        gram = pysiglib.sig_kernel_gram(
            path1, path2, dyadic_order=2, static_kernel=rbf, time_aug=True,
        )
        derivs = torch.ones_like(gram)
        dpath1, dpath2 = pysiglib.sig_kernel_gram_backprop(
            derivs, path1, path2,
            dyadic_order=2,
            static_kernel=rbf,
            time_aug=True,
            left_deriv=True,
            right_deriv=True,
            max_batch=4,
        )
        print(dpath1)
        print(dpath2)
    """
    # We use sig_kernel_backprop for simplicity, rather than directly calling
    # the cpp function.
    # There is clearly more overhead here than is necessary, but it
    # shouldn't be significant for large computations.
    check_n_jobs(n_jobs)
    check_type(time_aug, "time_aug", bool)
    check_type(lead_lag, "lead_lag", bool)
    check_type(left_deriv, "left_deriv", bool)
    check_type(right_deriv, "right_deriv", bool)
    if not (left_deriv or right_deriv):
        return None, None
    check_type(max_batch, "max_batch", int)
    if max_batch == 0 or max_batch < -1:
        raise ValueError("max_batch must be a positive integer or -1")

    # Identity (not equality) check: the same object passed twice means the
    # gram matrix is symmetric.
    symmetric = path1 is path2
    batch_shape_1 = tuple(path1.shape[:-2])
    batch_shape_2 = batch_shape_1 if symmetric else tuple(path2.shape[:-2])
    n1, n2 = len(batch_shape_1), len(batch_shape_2)
    path1 = _ensure_3d(path1)
    path2 = path1 if symmetric else _ensure_3d(path2)

    # Collapse the batch dimensions of derivs (and k_grid) to (B1, B2, ...) so
    # they can be indexed with flat pair indices below.
    if n1 != 1 or n2 != 1:
        flat_B1, flat_B2 = path1.shape[0], path2.shape[0]
        if return_grid:
            derivs = derivs.reshape(flat_B1, flat_B2, *derivs.shape[n1 + n2:])
        else:
            derivs = derivs.reshape(flat_B1, flat_B2)
        if k_grid is not None:
            k_grid = k_grid.reshape(flat_B1, flat_B2, *k_grid.shape[n1 + n2:])

    data = MultiplePathInputHandler([path1, path2], time_aug, lead_lag, end_time, ["path1", "path2"], False)

    # Use torch for simplicity
    derivs = torch.as_tensor(derivs)
    path1 = torch.as_tensor(data.path[0])
    path2 = torch.as_tensor(data.path[1])
    if k_grid is not None:
        k_grid = torch.as_tensor(k_grid)

    batch1 = path1.shape[0]
    batch2 = path2.shape[0]
    if max_batch == -1:
        # Floor of 1: with two empty batches the chunk size would otherwise be
        # 0, and range() rejects a zero step.
        max_batch = max(batch1, batch2, 1)

    # Accumulators (float64), one row per path in each batch.
    ld = torch.zeros(path1.shape, dtype=torch.float64, device=path1.device) if left_deriv else None
    rd = torch.zeros(path2.shape, dtype=torch.float64, device=path2.device) if right_deriv else None

    shared_kwargs = dict(
        static_kernel=static_kernel, time_aug=time_aug, lead_lag=lead_lag,
        end_time=end_time, n_jobs=n_jobs,
    )

    def solve_grids(left_paths, right_paths):
        # PDE grids for the batch of pairs (left_paths[p], right_paths[p]).
        return sig_kernel(left_paths, right_paths, dyadic_order, return_grid=True, **shared_kwargs)

    def backprop_pairs(pair_derivs, left_paths, right_paths, grids):
        # Pairwise backprop through the signature kernels of the given pairs.
        return sig_kernel_backprop(
            pair_derivs, left_paths, right_paths, dyadic_order,
            left_deriv=left_deriv, right_deriv=right_deriv,
            k_grid=grids, return_grid=return_grid, **shared_kwargs,
        )

    ####################################
    # Now run computation in batches
    ####################################
    if symmetric:
        # Upper triangle, diagonal included; mirrored entries handled below.
        idx_i, idx_j = torch.triu_indices(batch1, batch1, device=path1.device)
    else:
        # Every (i, j) pair, row-major.
        idx_i = torch.arange(batch1, device=path1.device).repeat_interleave(batch2)
        idx_j = torch.arange(batch2, device=path2.device).repeat(batch1)
    src2 = path1 if symmetric else path2
    n_pairs = idx_i.shape[0]
    chunk_size = max_batch * max_batch

    # Reusing (transposing) a pair's grid for its mirrored entry is only
    # attempted when both dyadic orders are equal.
    do1, do2 = parse_dyadic_order(dyadic_order)
    can_transpose_k = (do1 == do2)

    for start in range(0, n_pairs, chunk_size):
        end = min(start + chunk_size, n_pairs)
        ci = idx_i[start:end]
        cj = idx_j[start:end]
        path1_ = path1[ci]
        path2_ = src2[cj]
        k = solve_grids(path1_, path2_) if k_grid is None else k_grid[ci, cj]
        ld_, rd_ = backprop_pairs(derivs[ci, cj], path1_, path2_, k)
        if left_deriv:
            ld.index_add_(0, ci, ld_.to(ld.dtype))
        if right_deriv:
            rd.index_add_(0, cj, rd_.to(rd.dtype))

        if not symmetric:
            continue
        # Symmetric: handle transposed off-diagonal entries G[j, i]
        off = ci != cj
        if not off.any():
            continue
        ci_off = ci[off]
        cj_off = cj[off]
        path1_t = src2[cj_off]
        path2_t = path1[ci_off]
        if k_grid is not None:
            k_t = k_grid[cj_off, ci_off]
        elif can_transpose_k:
            k_t = k[off].transpose(-2, -1)
        else:
            k_t = solve_grids(path1_t, path2_t)
        ld_t, rd_t = backprop_pairs(derivs[cj_off, ci_off], path1_t, path2_t, k_t)
        if left_deriv:
            ld.index_add_(0, cj_off, ld_t.to(ld.dtype))
        if right_deriv:
            rd.index_add_(0, ci_off, rd_t.to(rd.dtype))

    # Restore the caller's batch shapes.
    if ld is not None:
        ld = ld.reshape(*batch_shape_1, *ld.shape[-2:])
    if rd is not None:
        rd = rd.reshape(*batch_shape_2, *rd.shape[-2:])
    if data.type_ == "numpy":
        return (ld.numpy() if ld is not None else None), (rd.numpy() if rd is not None else None)
    return ld, rd