# Source code for pysiglib.static_kernels

# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

from abc import ABC, abstractmethod
import torch

class Context:
    """
    Carrier for values shared between a static kernel's forward computation and
    its gradient computations.

    Users do not normally need to build instances of this class themselves; it is
    documented as a reference for anyone writing a custom static kernel.

    Two attributes hold the stored values, each initially an empty tuple:
    ``saved_tensors`` (filled by :meth:`save_for_backward`) and ``saved_for_y``
    (filled by :meth:`save_for_grad_y`).
    """

    def __init__(self):
        self.saved_tensors, self.saved_for_y = (), ()

    def save_for_backward(self, *args):
        """
        Store ``args`` from the forward pass so they can be re-used on the
        backward pass. Any previously stored objects are replaced.
        """
        self.saved_tensors = tuple(args)

    def save_for_grad_y(self, *args):
        """
        Store ``args`` computed while differentiating with respect to ``x`` so
        they can be re-used when differentiating with respect to ``y``. Any
        previously stored objects are replaced.
        """
        self.saved_for_y = tuple(args)
class StaticKernel(ABC):
    """
    Abstract base class for static kernels :math:`\\kappa`.

    Subclasses implement ``__call__`` (the forward gram-matrix computation) together
    with ``grad_x`` and ``grad_y`` (backpropagation through it). Values needed on the
    backward pass are exchanged through a ``pysiglib.Context``.
    """

    @abstractmethod
    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Returns the gram matrix of static kernel increments:

        .. math::

            \\{ \\kappa(x_s, y_t) - \\kappa(x_{s-1}, y_t) - \\kappa(x_s, y_{t-1}) + \\kappa(x_{s-1}, y_{t-1}) \\}_{1 \\leq s \\leq L_1 - 1, \\ 1 \\leq t \\leq L_2 - 1}

        as a tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``, where
        :math:`L_1` (``length_1``) is the length of :math:`x`, :math:`L_2`
        (``length_2``) is the length of :math:`y`, and path points are indexed
        from :math:`0`.

        :param ctx: ``pysiglib.Context`` object for backpropagation
        :type ctx: pysiglib.Context
        :param x: Path :math:`x` of shape ``(batch_size, length_1, dimension)``.
        :type x: torch.Tensor
        :param y: Path :math:`y` of shape ``(batch_size, length_2, dimension)``.
        :type y: torch.Tensor
        :return: Batch of gram matrices of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        :rtype: torch.Tensor
        """
        pass

    @abstractmethod
    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagates ``derivs`` through the static kernel computation and returns
        the derivatives with respect to the path :math:`x`.

        :param ctx: ``pysiglib.Context`` object for backpropagation
        :type ctx: pysiglib.Context
        :param derivs: Derivatives with respect to the gram matrices outputted by
            ``__call__``, of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        :type derivs: torch.Tensor
        :return: Derivatives with respect to the path :math:`x` of shape
            ``(batch_size, length_1, dimension)``.
        :rtype: torch.Tensor
        """
        pass

    @abstractmethod
    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagates ``derivs`` through the static kernel computation and returns
        the derivatives with respect to the path :math:`y`.

        :param ctx: ``pysiglib.Context`` object for backpropagation
        :type ctx: pysiglib.Context
        :param derivs: Derivatives with respect to the gram matrices outputted by
            ``__call__``, of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        :type derivs: torch.Tensor
        :return: Derivatives with respect to the path :math:`y` of shape
            ``(batch_size, length_2, dimension)``.
        :rtype: torch.Tensor
        """
        pass
class LinearKernel(StaticKernel):
    """
    The linear kernel, defined by :math:`\\kappa(x, y) = \\langle x, y \\rangle`.

    For this kernel the gram matrix of increments is simply ``dx @ dy^T``, where
    ``dx`` and ``dy`` are the first differences of ``x`` and ``y`` along dimension 1.
    """

    @staticmethod
    def _diff_transpose(derivs, shape, dim, dtype):
        """
        Apply the adjoint of ``torch.diff`` along ``dim`` to ``derivs``.

        Returns a new tensor of the given ``shape`` (one longer than ``derivs`` along
        ``dim``) whose entry ``i`` along ``dim`` is ``derivs[i - 1] - derivs[i]``,
        with out-of-range terms treated as zero.
        """
        grad = torch.empty(shape, dtype=dtype, device=derivs.device)
        lead = (slice(None),) * dim  # selects every index on the axes before ``dim``
        grad[lead + (0,)] = 0
        grad[lead + (slice(1, None),)] = derivs
        grad[lead + (slice(None, -1),)] -= derivs
        return grad

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        inc_x, inc_y = torch.diff(x, dim=1), torch.diff(y, dim=1)
        # Only the increments are needed to backpropagate a bilinear map.
        ctx.save_for_backward(inc_x, inc_y)
        return torch.bmm(inc_x, inc_y.transpose(1, 2))

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        inc_x, inc_y = ctx.saved_tensors
        n_batch, n_x, n_y = inc_x.shape[0], inc_x.shape[1], inc_y.shape[1]
        # d(inc_x[s]) / d(x[s + 1]) = +1 and d(inc_x[s]) / d(x[s]) = -1.
        weights = self._diff_transpose(derivs, (n_batch, n_x + 1, n_y), 1, inc_x.dtype)
        return torch.bmm(weights, inc_y)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        inc_x, inc_y = ctx.saved_tensors
        n_batch, n_x, n_y = inc_x.shape[0], inc_x.shape[1], inc_y.shape[1]
        weights = self._diff_transpose(derivs, (n_batch, n_x, n_y + 1), 2, inc_x.dtype)
        return torch.bmm(weights.transpose(1, 2), inc_x)
class ScaledLinearKernel(StaticKernel):
    """
    The scaled linear kernel
    :math:`\\kappa(x, y) = \\langle \\alpha x, \\alpha y \\rangle = \\alpha^2 \\langle x, y \\rangle`,
    with :math:`\\alpha` supplied as ``scale``. Setting ``scale=1.0`` gives the
    standard linear kernel.

    :param scale: Scaling factor :math:`\\alpha`.
    :type scale: float
    """

    def __init__(self, scale : float = 1.):
        self.scale = scale
        self._scale_sq = scale ** 2
        self.linear_kernel = LinearKernel()

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        # The whole factor alpha^2 is applied to x: <alpha^2 dx, dy> = alpha^2 <dx, dy>.
        # As a result the wrapped linear kernel saves (alpha^2 dx, dy) in ctx.
        scaled_x = x * self._scale_sq
        return self.linear_kernel(ctx, scaled_x, y)

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        # The linear kernel differentiates w.r.t. alpha^2 x; apply the chain rule.
        grad_scaled_x = self.linear_kernel.grad_x(ctx, derivs)
        return grad_scaled_x * self._scale_sq

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        # No extra factor here: the saved increments of x already carry alpha^2.
        return self.linear_kernel.grad_y(ctx, derivs)
class RBFKernel(StaticKernel):
    """
    The RBF kernel, defined by
    :math:`\\kappa(x, y) = \\exp\\left( -\\frac{\\lVert x - y \\rVert^2}{\\sigma} \\right)`.

    Note that :math:`\\sigma` divides the squared distance directly (there is no
    factor of 2 and no squaring of :math:`\\sigma`).

    :param sigma: Bandwidth :math:`\\sigma`.
    :type sigma: float
    """

    def __init__(self, sigma : float):
        self.sigma = sigma
        self._one_over_sigma = 1. / sigma
        # 2 / sigma: the coefficient of the cross term <x, y> in the exponent, and
        # also the factor appearing in the derivatives of the kernel.
        self._scale = 2 * self._one_over_sigma

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        # Exponent: -||x_s - y_t||^2 / sigma = (2 / sigma) <x_s, y_t> - ||x_s||^2 / sigma - ||y_t||^2 / sigma
        dist = torch.bmm(x * self._scale, y.permute(0, 2, 1))
        x2 = torch.sum(torch.pow(x, 2), dim=2) * self._one_over_sigma
        y2 = torch.sum(torch.pow(y, 2), dim=2) * self._one_over_sigma
        dist -= x2.unsqueeze(2) + y2.unsqueeze(1)
        # In-place exp_ (rather than exp(..., out=...)) stays autograd-compatible.
        gram = dist.exp_()
        # gram is not modified below, so it can be saved without a defensive clone.
        ctx.save_for_backward(x, y, gram)
        # Mixed increments gram[s+1, t+1] - gram[s, t+1] - gram[s+1, t] + gram[s, t].
        # Plain diffs avoid resize_ / out=, which fail on tensors requiring grad and
        # would leave the result holding the full-size storage.
        return torch.diff(torch.diff(gram, dim=1), dim=2)

    def _gram_grad(self, gram : torch.Tensor, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` through the mixed increments and the exponential.

        Returns ``(2 / sigma) * gram * dgram``, where ``dgram`` is the derivative with
        respect to the kernel matrix ``gram``. Both ``grad_x`` and ``grad_y`` are
        built from this common factor.
        """
        dgram = torch.zeros_like(gram)
        dgram[:, 1:, 1:] += derivs
        dgram[:, :-1, :-1] += derivs
        dgram[:, 1:, :-1] -= derivs
        dgram[:, :-1, 1:] -= derivs
        dgram *= gram
        dgram *= self._scale
        return dgram

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        x, y, gram = ctx.saved_tensors
        weights = self._gram_grad(gram, derivs)
        # Cache the weights so grad_y does not need to recompute them.
        ctx.save_for_grad_y(x, y, weights)
        # d kappa(x_s, y_t) / d x_s = (2 / sigma) * kappa * (y_t - x_s)
        return torch.bmm(weights, y) - x * torch.sum(weights, dim=2).unsqueeze(-1)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        if ctx.saved_for_y:
            # NOTE: reuses the weights computed by grad_x, so ``derivs`` here is
            # ignored and assumed to equal the tensor previously passed to grad_x.
            x, y, weights = ctx.saved_for_y
        else:
            x, y, gram = ctx.saved_tensors
            weights = self._gram_grad(gram, derivs)
        # d kappa(x_s, y_t) / d y_t = (2 / sigma) * kappa * (x_s - y_t)
        return torch.bmm(weights.permute(0, 2, 1), x) - y * torch.sum(weights, dim=1).unsqueeze(-1)