Source code for pysiglib.static_kernels

# Copyright 2025 Daniil Shmelev
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

from abc import ABC, abstractmethod
import math
import warnings
import torch

class Context:
    """
    Carrier object that lets a static kernel pass values from one stage of
    backpropagation to the next.

    Instances are normally created for you; the class is documented so that
    authors of custom static kernels know which slots are available.
    """

    def __init__(self):
        # Both slots start out as empty tuples, so a truthiness check such as
        # ``if ctx.saved_for_y:`` is False until something has been stored.
        self.saved_tensors = tuple()
        self.saved_for_y = tuple()

    def save_for_backward(self, *args):
        """
        Keep objects produced by the forward pass for the backward pass.

        The positional arguments are stored, in order, as the tuple
        ``saved_tensors``; any previously stored tuple is replaced.
        """
        self.saved_tensors = args

    def save_for_grad_y(self, *args):
        """
        Keep objects produced while computing the gradient with respect to x
        so that the gradient with respect to y can reuse them.

        The positional arguments are stored, in order, as the tuple
        ``saved_for_y``; any previously stored tuple is replaced.
        """
        self.saved_for_y = args
class StaticKernel(ABC):
    """
    Abstract interface implemented by every static kernel in this module.

    A concrete kernel supplies the forward computation (``__call__``) and the
    two backward computations (``grad_x`` and ``grad_y``).
    """

    @abstractmethod
    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Returns the gram matrix of second-order differences of the static kernel:

        .. math::

            \\{ \\kappa(x_s, y_t) - \\kappa(x_{s-1}, y_t) - \\kappa(x_s, y_{t-1})
            + \\kappa(x_{s-1}, y_{t-1}) \\}_{1 \\leq s < L_1, 1 \\leq t < L_2}

        as a tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``, where
        :math:`L_1` = ``length_1`` is the length of :math:`x`, :math:`L_2` =
        ``length_2`` is the length of :math:`y`, and path points are indexed
        from zero.

        :param ctx: ``pysiglib.Context`` object for backpropagation
        :type ctx: pysiglib.Context
        :param x: Path :math:`x` of shape ``(batch_size, length_1, dimension)``.
        :type x: torch.Tensor
        :param y: Path :math:`y` of shape ``(batch_size, length_2, dimension)``.
        :type y: torch.Tensor
        :return: Batch of gram matrices of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        :rtype: torch.Tensor
        """

    @abstractmethod
    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Pulls ``derivs`` back through the static kernel computation, giving the
        derivatives with respect to the path :math:`x`.

        :param ctx: ``pysiglib.Context`` object for backpropagation
        :type ctx: pysiglib.Context
        :param derivs: Derivatives with respect to the gram matrices outputted by
            ``__call__``, of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        :type derivs: torch.Tensor
        :return: Derivatives with respect to the path :math:`x` of shape
            ``(batch_size, length_1, dimension)``.
        :rtype: torch.Tensor
        """

    @abstractmethod
    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Pulls ``derivs`` back through the static kernel computation, giving the
        derivatives with respect to the path :math:`y`.

        :param ctx: ``pysiglib.Context`` object for backpropagation
        :type ctx: pysiglib.Context
        :param derivs: Derivatives with respect to the gram matrices outputted by
            ``__call__``, of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        :type derivs: torch.Tensor
        :return: Derivatives with respect to the path :math:`y` of shape
            ``(batch_size, length_2, dimension)``.
        :rtype: torch.Tensor
        """
class LinearKernel(StaticKernel):
    """
    The linear kernel, defined by :math:`\\kappa(x, y) = \\langle x, y \\rangle`.

    Because the kernel is bilinear, its second-order difference is just the
    inner product of the path increments, so only ``diff(x)`` and ``diff(y)``
    are ever needed.
    """

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Compute the batch of gram matrices of increment inner products.

        :param ctx: Context that receives the increments of ``x`` and ``y``.
        :param x: Paths of shape ``(batch_size, length_1, dimension)``.
        :param y: Paths of shape ``(batch_size, length_2, dimension)``.
        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        incr_x = torch.diff(x, dim=1)
        incr_y = torch.diff(y, dim=1)
        ctx.save_for_backward(incr_x, incr_y)
        return torch.bmm(incr_x, incr_y.permute(0, 2, 1))

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to the path ``x``.

        :param ctx: Context populated by ``__call__``.
        :param derivs: Derivatives with respect to the gram matrices.
        :return: Derivatives with respect to ``x``, one row per point of ``x``.
        """
        incr_x, incr_y = ctx.saved_tensors
        shape = (incr_x.shape[0], incr_x.shape[1] + 1, incr_y.shape[1])
        adjoint = torch.empty(shape, dtype=incr_x.dtype, device=derivs.device)
        # Adjoint of ``diff`` along dim 1: row s receives +derivs[s - 1] and
        # -derivs[s] (whichever of the two exist).
        adjoint[:, 0, :].zero_()
        adjoint[:, 1:, :].copy_(derivs)
        adjoint[:, :-1, :].sub_(derivs)
        return torch.bmm(adjoint, incr_y)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to the path ``y``.

        :param ctx: Context populated by ``__call__``.
        :param derivs: Derivatives with respect to the gram matrices.
        :return: Derivatives with respect to ``y``, one row per point of ``y``.
        """
        incr_x, incr_y = ctx.saved_tensors
        shape = (incr_x.shape[0], incr_x.shape[1], incr_y.shape[1] + 1)
        adjoint = torch.empty(shape, dtype=incr_x.dtype, device=derivs.device)
        # Same adjoint-of-diff construction as in grad_x, along dim 2.
        adjoint[:, :, 0].zero_()
        adjoint[:, :, 1:].copy_(derivs)
        adjoint[:, :, :-1].sub_(derivs)
        return torch.bmm(adjoint.permute(0, 2, 1), incr_x)
class ScaledLinearKernel(StaticKernel):
    """
    The scaled linear kernel, defined by
    :math:`\\kappa(x, y) = \\langle \\alpha x, \\alpha y \\rangle = \\alpha^2 \\langle x, y \\rangle`,
    where :math:`\\alpha` is given by the parameter ``scale``. A choice of
    ``scale=1.0`` corresponds to the standard linear kernel.

    The work is delegated to an internal :class:`LinearKernel`; the factor
    :math:`\\alpha^2` is folded into ``x`` before the delegate sees it.
    """

    def __init__(self, scale : float = 1.):
        """
        :param scale: The factor :math:`\\alpha`; must be non-negative.
        :raises ValueError: If ``scale`` is negative.
        """
        if scale < 0:
            raise ValueError(f"ScaledLinearKernel: scale must be >= 0, got {scale}")
        self.scale = scale
        self._scale_sq = scale ** 2
        self.linear_kernel = LinearKernel()

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Evaluate the linear kernel on ``scale**2 * x`` and ``y``.

        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        scaled_x = x * self._scale_sq
        return self.linear_kernel(ctx, scaled_x, y)

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``x``.

        The delegate differentiates with respect to the pre-scaled path, so the
        chain rule contributes one more factor of ``scale**2``.
        """
        unscaled = self.linear_kernel.grad_x(ctx, derivs)
        return unscaled * self._scale_sq

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``y``.

        No extra factor is applied here: the increments of ``x`` saved by the
        delegate were computed from the already-scaled path.
        """
        return self.linear_kernel.grad_y(ctx, derivs)
class RBFKernel(StaticKernel):
    """
    The RBF kernel, defined by
    :math:`\\kappa(x, y) = \\exp\\left( -\\frac{\\lVert x - y \\rVert^2}{\\sigma} \\right)`.
    """

    def __init__(self, sigma : float):
        """
        :param sigma: Divisor of the squared distance; must be strictly positive.
        :raises ValueError: If ``sigma`` is not strictly positive.
        """
        if sigma <= 0:
            raise ValueError(f"RBFKernel: sigma must be > 0, got {sigma}")
        self.sigma = sigma
        self._one_over_sigma = 1. / sigma
        self._scale = 2 * self._one_over_sigma

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Compute the second-order differences of the RBF gram matrix.

        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        # -|x - y|^2 / sigma, expanded as (2<x, y> - |x|^2 - |y|^2) / sigma.
        cross = torch.bmm(x * self._scale, y.permute(0, 2, 1))
        sq_x = torch.sum(torch.pow(x, 2), dim=2) * self._one_over_sigma
        sq_y = torch.sum(torch.pow(y, 2), dim=2) * self._one_over_sigma
        col = torch.reshape(sq_x, (x.shape[0], x.shape[1], 1))
        row = torch.reshape(sq_y, (x.shape[0], 1, y.shape[1]))
        gram = torch.exp(cross - (col + row))
        ctx.save_for_backward(x, y, gram.clone())
        return torch.diff(torch.diff(gram, dim=1), dim=2)

    def _weights(self, derivs, gram):
        # derivs pulled back onto the full gram matrix, times (2 / sigma) * kappa.
        w = _undo_double_diff(derivs, gram)
        w *= gram
        w *= 2. * self._one_over_sigma
        return w

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``x`` and cache the weights for ``grad_y``.

        :return: Tensor with the same shape as ``x``.
        """
        x, y, gram = ctx.saved_tensors
        w = self._weights(derivs, gram)
        ctx.save_for_grad_y(x, y, w)
        return torch.bmm(w, y) - x * torch.sum(w, dim=2, keepdim=True)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``y``.

        If ``grad_x`` has already stored its weights on ``ctx`` they are reused
        (``derivs`` is then not consulted); otherwise they are recomputed.

        :return: Tensor with the same shape as ``y``.
        """
        if ctx.saved_for_y:
            x, y, w = ctx.saved_for_y
        else:
            x, y, gram = ctx.saved_tensors
            w = self._weights(derivs, gram)
        return torch.bmm(w.permute(0, 2, 1), x) - y * torch.sum(w, dim=1).unsqueeze(-1)
def _squared_dist(x, y):
    """
    Pairwise squared Euclidean distances between the points of ``x`` and ``y``.

    Uses the expansion ``|x|^2 - 2<x, y> + |y|^2`` and clamps the result at
    zero so that rounding can never yield a negative squared distance.

    :param x: Tensor indexed as ``(batch, point, coordinate)``.
    :param y: Tensor indexed as ``(batch, point, coordinate)``.
    :return: Tensor indexed as ``(batch, point of x, point of y)``.
    """
    sq_norm_x = torch.sum(x * x, dim=2, keepdim=True)
    sq_norm_y = torch.sum(y * y, dim=2).unsqueeze(1)
    cross = torch.bmm(x, y.permute(0, 2, 1))
    return torch.clamp(sq_norm_x - 2 * cross + sq_norm_y, min=0)


def _undo_double_diff(derivs, like):
    """
    Adjoint of ``diff(diff(K, dim=1), dim=2)``.

    Each entry of ``derivs`` is scattered onto the four corners of the 2x2
    cell of ``K`` it was computed from, with signs ``+ - - +``.

    :param derivs: Derivatives with respect to the double-differenced matrix.
    :param like: Tensor whose shape, dtype and device the result copies.
    :return: Derivatives with respect to the undifferenced matrix.
    """
    grad = torch.zeros_like(like)
    grad[:, 1:, 1:].copy_(derivs)
    grad[:, :-1, :-1].add_(derivs)
    grad[:, 1:, :-1].sub_(derivs)
    grad[:, :-1, 1:].sub_(derivs)
    return grad
class PolynomialKernel(StaticKernel):
    """
    The polynomial kernel, defined by
    :math:`\\kappa(x, y) = \\text{scale} \\cdot \\left( \\langle x, y \\rangle + \\gamma \\right)^d`,
    where :math:`d` is the ``degree`` parameter.

    For a non-integer ``degree`` a negative base has no real power, so
    negative values of :math:`\\langle x, y \\rangle + \\gamma` are clamped to
    zero (with a one-time ``RuntimeWarning``). Integer degrees of any size are
    evaluated on the unclamped base.
    """

    def __init__(self, degree : float = 3., gamma : float = 1., scale : float = 1.):
        """
        :param degree: Exponent :math:`d`; must be non-negative.
        :param gamma: Offset added to the inner product.
        :param scale: Multiplicative factor; must be non-negative.
        :raises ValueError: If ``degree`` or ``scale`` is negative.
        """
        if degree < 0:
            raise ValueError(f"PolynomialKernel: degree must be >= 0, got {degree}")
        if scale < 0:
            raise ValueError(f"PolynomialKernel: scale must be >= 0, got {scale}")
        self.degree = degree
        self.gamma = gamma
        self.scale = scale
        is_integer_degree = degree == int(degree)
        # Degrees 1..5 use the multiplication fast path in ``_pow``.
        self._int_degree = int(degree) if is_integer_degree and 1 <= degree <= 5 else None
        # Clamp only when the power of a negative base is genuinely undefined.
        # Tying this to ``_int_degree`` would also clamp integer degrees > 5.
        self._needs_base_clamp = not is_integer_degree
        self._warned_negative_base = False

    def _pow(self, base, exp):
        # Small integer powers by repeated multiplication; everything else
        # falls through to torch.pow.
        if self._int_degree is not None and exp == int(exp) and 0 <= exp <= 4:
            n = int(exp)
            if n == 0:
                return torch.ones_like(base)
            if n == 1:
                return base
            squared = base * base
            if n == 2:
                return squared
            if n == 3:
                return squared * base
            return squared * squared
        return torch.pow(base, exp)

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Compute the second-order differences of the polynomial gram matrix.

        :param ctx: Context that receives ``x``, ``y`` and the (possibly
            clamped) base ``<x, y> + gamma``.
        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        base = torch.bmm(x, y.permute(0, 2, 1)) + self.gamma
        if self._needs_base_clamp:
            if not self._warned_negative_base and (base < 0).any():
                self._warned_negative_base = True
                warnings.warn(
                    "PolynomialKernel: non-integer degree with negative base values "
                    "(<x, y> + gamma < 0). These entries are clamped to 0. Consider "
                    "increasing gamma to ensure all base values are non-negative.",
                    RuntimeWarning, stacklevel=2
                )
            base = torch.clamp(base, min=0)
        K = self.scale * self._pow(base, self.degree)
        ctx.save_for_backward(x, y, base)
        return torch.diff(torch.diff(K, dim=1), dim=2)

    def _weights(self, derivs, base):
        # derivs pulled back onto the gram matrix, times d kappa / d base
        # = scale * degree * base ** (degree - 1).
        w = _undo_double_diff(derivs, base)
        w *= self.scale * self.degree
        w *= self._pow(base, self.degree - 1)
        return w

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``x`` and cache the weights for ``grad_y``.

        :return: Tensor with the same shape as ``x`` (all zeros when
            ``degree == 0``, since the kernel is then constant).
        """
        if self.degree == 0:
            return torch.zeros_like(ctx.saved_tensors[0])
        x, y, base = ctx.saved_tensors
        w = self._weights(derivs, base)
        ctx.save_for_grad_y(x, y, w)
        return torch.bmm(w, y)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``y``.

        If ``grad_x`` has already stored its weights on ``ctx`` they are reused
        (``derivs`` is then not consulted); otherwise they are recomputed.

        :return: Tensor with the same shape as ``y`` (all zeros when
            ``degree == 0``).
        """
        if self.degree == 0:
            return torch.zeros_like(ctx.saved_tensors[1])
        if ctx.saved_for_y:
            x, y, w = ctx.saved_for_y
        else:
            x, y, base = ctx.saved_tensors
            w = self._weights(derivs, base)
        return torch.bmm(w.permute(0, 2, 1), x)
class Matern12Kernel(StaticKernel):
    """
    The Matern-1/2 kernel (exponential kernel), defined by
    :math:`\\kappa(x, y) = \\exp\\left( -\\frac{\\lVert x - y \\rVert}{\\sigma} \\right)`.
    """

    def __init__(self, sigma : float):
        """
        :param sigma: Length scale; must be strictly positive.
        :raises ValueError: If ``sigma`` is not strictly positive.
        """
        if sigma <= 0:
            raise ValueError(f"Matern12Kernel: sigma must be > 0, got {sigma}")
        self.sigma = sigma
        self._one_over_sigma = 1. / sigma

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Compute the second-order differences of the Matern-1/2 gram matrix.

        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        # The 1e-30 offset keeps the argument of sqrt strictly positive.
        dist = torch.sqrt(_squared_dist(x, y) + 1e-30)
        gram = torch.exp(-dist * self._one_over_sigma)
        ctx.save_for_backward(x, y, dist)
        return torch.diff(torch.diff(gram, dim=1), dim=2)

    def _weights(self, derivs, dist):
        # derivs pulled back onto the gram matrix, times kappa / (sigma * dist).
        # The gram matrix is recomputed from the saved distances.
        gram = torch.exp(-dist * self._one_over_sigma)
        w = _undo_double_diff(derivs, dist)
        w *= gram
        w *= self._one_over_sigma / torch.clamp(dist, min=1e-15)
        return w

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``x`` and cache the weights for ``grad_y``.

        :return: Tensor with the same shape as ``x``.
        """
        x, y, dist = ctx.saved_tensors
        w = self._weights(derivs, dist)
        ctx.save_for_grad_y(x, y, w)
        return torch.bmm(w, y) - x * torch.sum(w, dim=2, keepdim=True)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``y``.

        If ``grad_x`` has already stored its weights on ``ctx`` they are reused
        (``derivs`` is then not consulted); otherwise they are recomputed.

        :return: Tensor with the same shape as ``y``.
        """
        if ctx.saved_for_y:
            x, y, w = ctx.saved_for_y
        else:
            x, y, dist = ctx.saved_tensors
            w = self._weights(derivs, dist)
        return torch.bmm(w.permute(0, 2, 1), x) - y * torch.sum(w, dim=1).unsqueeze(-1)
class Matern32Kernel(StaticKernel):
    """
    The Matern-3/2 kernel, defined by
    :math:`\\kappa(x, y) = \\left(1 + \\frac{\\sqrt{3} \\lVert x - y \\rVert}{\\sigma}\\right) \\exp\\left( -\\frac{\\sqrt{3} \\lVert x - y \\rVert}{\\sigma} \\right)`.
    """

    def __init__(self, sigma : float):
        """
        :param sigma: Length scale; must be strictly positive.
        :raises ValueError: If ``sigma`` is not strictly positive.
        """
        if sigma <= 0:
            raise ValueError(f"Matern32Kernel: sigma must be > 0, got {sigma}")
        self.sigma = sigma
        self._sqrt3_over_sigma = math.sqrt(3.) / sigma
        self._3_over_sigma_sq = 3. / (sigma ** 2)

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Compute the second-order differences of the Matern-3/2 gram matrix.

        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        # u = sqrt(3) * |x - y| / sigma; the 1e-30 offset keeps the argument
        # of sqrt strictly positive.
        u = torch.sqrt(_squared_dist(x, y) + 1e-30) * self._sqrt3_over_sigma
        decay = torch.exp(-u)
        gram = (1. + u) * decay
        ctx.save_for_backward(x, y, decay)
        return torch.diff(torch.diff(gram, dim=1), dim=2)

    def _weights(self, derivs, decay):
        # derivs pulled back onto the gram matrix, times (3 / sigma^2) * exp(-u);
        # the distance cancels out of the derivative, so only exp(-u) is needed.
        w = _undo_double_diff(derivs, decay)
        w *= decay
        w *= self._3_over_sigma_sq
        return w

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``x`` and cache the weights for ``grad_y``.

        :return: Tensor with the same shape as ``x``.
        """
        x, y, decay = ctx.saved_tensors
        w = self._weights(derivs, decay)
        ctx.save_for_grad_y(x, y, w)
        return torch.bmm(w, y) - x * torch.sum(w, dim=2, keepdim=True)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``y``.

        If ``grad_x`` has already stored its weights on ``ctx`` they are reused
        (``derivs`` is then not consulted); otherwise they are recomputed.

        :return: Tensor with the same shape as ``y``.
        """
        if ctx.saved_for_y:
            x, y, w = ctx.saved_for_y
        else:
            x, y, decay = ctx.saved_tensors
            w = self._weights(derivs, decay)
        return torch.bmm(w.permute(0, 2, 1), x) - y * torch.sum(w, dim=1).unsqueeze(-1)
class Matern52Kernel(StaticKernel):
    """
    The Matern-5/2 kernel, defined by
    :math:`\\kappa(x, y) = \\left(1 + \\frac{\\sqrt{5} \\lVert x - y \\rVert}{\\sigma} + \\frac{5 \\lVert x - y \\rVert^2}{3\\sigma^2}\\right) \\exp\\left( -\\frac{\\sqrt{5} \\lVert x - y \\rVert}{\\sigma} \\right)`.
    """

    def __init__(self, sigma : float):
        """
        :param sigma: Length scale; must be strictly positive.
        :raises ValueError: If ``sigma`` is not strictly positive.
        """
        if sigma <= 0:
            raise ValueError(f"Matern52Kernel: sigma must be > 0, got {sigma}")
        self.sigma = sigma
        self._sqrt5_over_sigma = math.sqrt(5.) / sigma
        self._5_over_3sigma_sq = 5. / (3. * sigma ** 2)

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Compute the second-order differences of the Matern-5/2 gram matrix.

        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        # u = sqrt(5) * |x - y| / sigma; the 1e-30 offset keeps the argument
        # of sqrt strictly positive.
        u = torch.sqrt(_squared_dist(x, y) + 1e-30) * self._sqrt5_over_sigma
        decay = torch.exp(-u)
        gram = (1. + u + u * u / 3.) * decay
        ctx.save_for_backward(x, y, u, decay)
        return torch.diff(torch.diff(gram, dim=1), dim=2)

    def _weights(self, derivs, u, decay):
        # derivs pulled back onto the gram matrix, times
        # (5 / (3 sigma^2)) * (1 + u) * exp(-u).
        w = _undo_double_diff(derivs, decay)
        w *= decay
        w *= 1. + u
        w *= self._5_over_3sigma_sq
        return w

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``x`` and cache the weights for ``grad_y``.

        :return: Tensor with the same shape as ``x``.
        """
        x, y, u, decay = ctx.saved_tensors
        w = self._weights(derivs, u, decay)
        ctx.save_for_grad_y(x, y, w)
        return torch.bmm(w, y) - x * torch.sum(w, dim=2, keepdim=True)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``y``.

        If ``grad_x`` has already stored its weights on ``ctx`` they are reused
        (``derivs`` is then not consulted); otherwise they are recomputed.

        :return: Tensor with the same shape as ``y``.
        """
        if ctx.saved_for_y:
            x, y, w = ctx.saved_for_y
        else:
            x, y, u, decay = ctx.saved_tensors
            w = self._weights(derivs, u, decay)
        return torch.bmm(w.permute(0, 2, 1), x) - y * torch.sum(w, dim=1).unsqueeze(-1)
class RationalQuadraticKernel(StaticKernel):
    """
    The rational quadratic kernel, defined by
    :math:`\\kappa(x, y) = \\left(1 + \\frac{\\lVert x - y \\rVert^2}{2 \\alpha \\sigma^2}\\right)^{-\\alpha}`.
    """

    def __init__(self, sigma : float, alpha : float = 1.):
        """
        :param sigma: Length scale; must be strictly positive.
        :param alpha: Shape parameter; must be strictly positive.
        :raises ValueError: If ``sigma`` or ``alpha`` is not strictly positive.
        """
        if sigma <= 0:
            raise ValueError(f"RationalQuadraticKernel: sigma must be > 0, got {sigma}")
        if alpha <= 0:
            raise ValueError(f"RationalQuadraticKernel: alpha must be > 0, got {alpha}")
        self.sigma = sigma
        self.alpha = alpha
        self._c = 2. * alpha * sigma ** 2
        self._one_over_sigma_sq = 1. / (sigma ** 2)

    def __call__(self, ctx : Context, x : torch.Tensor, y : torch.Tensor):
        """
        Compute the second-order differences of the rational quadratic gram matrix.

        :return: Tensor of shape ``(batch_size, length_1 - 1, length_2 - 1)``.
        """
        base = 1. + _squared_dist(x, y) / self._c
        gram = torch.pow(base, -self.alpha)
        # kappa / base = base ** (-alpha - 1) is the only quantity the
        # gradients need, so it is what gets saved.
        ctx.save_for_backward(x, y, gram / base)
        return torch.diff(torch.diff(gram, dim=1), dim=2)

    def _weights(self, derivs, ratio):
        # derivs pulled back onto the gram matrix, times (kappa / base) / sigma^2.
        w = _undo_double_diff(derivs, ratio)
        w *= ratio
        w *= self._one_over_sigma_sq
        return w

    def grad_x(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``x`` and cache the weights for ``grad_y``.

        :return: Tensor with the same shape as ``x``.
        """
        x, y, ratio = ctx.saved_tensors
        w = self._weights(derivs, ratio)
        ctx.save_for_grad_y(x, y, w)
        return torch.bmm(w, y) - x * torch.sum(w, dim=2, keepdim=True)

    def grad_y(self, ctx : Context, derivs : torch.Tensor):
        """
        Backpropagate ``derivs`` to ``y``.

        If ``grad_x`` has already stored its weights on ``ctx`` they are reused
        (``derivs`` is then not consulted); otherwise they are recomputed.

        :return: Tensor with the same shape as ``y``.
        """
        if ctx.saved_for_y:
            x, y, w = ctx.saved_for_y
        else:
            x, y, ratio = ctx.saved_tensors
            w = self._weights(derivs, ratio)
        return torch.bmm(w.permute(0, 2, 1), x) - y * torch.sum(w, dim=1).unsqueeze(-1)