Source code for tbp.monty.frameworks.models.salience.return_inhibitor

# Copyright 2025-2026 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations

import numpy as np


[docs]class DecayKernel: """Decay kernel represents a previously visited location. Returns the product of time- and space-dependent exponentials. """
[docs] def __init__( self, location: np.ndarray, tau_t: float = 10.0, tau_s: float = 0.01, spatial_cutoff: float | None = 0.02, w_t_min: float = 0.1, ): self._location = location self._tau_t = tau_t self._tau_s = tau_s self._spatial_cutoff = spatial_cutoff self._w_t_min = w_t_min self._t = 0
[docs] def w_t(self) -> float: """Compute the time-dependent weight at the current step. The weight is computed as `exp(-t / lam)`, where `t` is the number of steps since the kernel was created, and `lam` is equal to `tau_t / log(2)`. Returns: The weight, bounded to [0, 1]. """ return np.exp(-self._t / (self._tau_t / float(np.log(2))))
[docs] def w_s(self, points: np.ndarray) -> np.ndarray: """Compute the distance-dependent weight. The weight is computed as `exp(-z / lam)`, where `z` is the distance between the kernel's center and the given point(s), and `lam` is equal to `tau_s / log(2)`. Args: points: An (num_points, 3) array of points. Returns: The weight, bounded to [0, 1]. Has shape (num_points,). """ dist = np.linalg.norm(self._location - points, axis=1) out = np.exp(-dist / (self._tau_s / np.log(2))) if self._spatial_cutoff is not None: out[dist > self._spatial_cutoff] = 0.0 return out
[docs] def step(self) -> bool: """Increment the step counter, and check if the kernel is expired. Returns: True if the kernel is expired, False otherwise. """ self._t += 1 return self.w_t() < self._w_t_min
def __call__(self, points: np.ndarray) -> np.ndarray: """Compute the time- and distance-dependent weight at a given point. Computes the product of the time- and distance-dependent weights. Weights are bounded to [0, 1], where values close to 1 indicate the kernel has a large influence on the given point(s). Args: points: An (num_points, 3) array of points. Returns: The weights, bounded to [0, 1]. Has shape (num_points,). """ assert points.ndim == 2 and points.shape[1] == 3 return self.w_t() * self.w_s(points)
[docs]class DecayKernelFactory:
[docs] def __init__( self, tau_t: float = 10.0, tau_s: float = 0.01, spatial_cutoff: float | None = 0.02, w_t_min: float = 0.1, ): self._tau_t = tau_t self._tau_s = tau_s self._spatial_cutoff = spatial_cutoff self._w_t_min = w_t_min
def __call__(self, location: np.ndarray) -> DecayKernel: return DecayKernel( location, self._tau_t, self._tau_s, self._spatial_cutoff, self._w_t_min )
[docs]class DecayField: """Manages a collection of decay kernels."""
[docs] def __init__( self, kernel_factory: DecayKernelFactory | None = None, ): self._kernel_factory = ( DecayKernelFactory() if kernel_factory is None else kernel_factory ) self._kernels: list[DecayKernel] = []
[docs] def reset(self) -> None: self._kernels.clear()
[docs] def add(self, location: np.ndarray) -> None: """Add a kernel to the field.""" self._kernels.append(self._kernel_factory(location))
[docs] def step(self) -> None: """Step each kernel to increment its counter, and keep only non-expired ones.""" self._kernels = [k for k in self._kernels if not k.step()]
[docs] def compute_weights(self, points: np.ndarray) -> np.ndarray: assert points.ndim == 2 and points.shape[1] == 3 if not self._kernels: return np.zeros(points.shape[0]) # Stack kernel parameters and compute in batch results = np.array([k(points) for k in self._kernels]) return np.max(results, axis=0)
[docs]class ReturnInhibitor:
[docs] def __init__( self, decay_field: DecayField | None = None, ): self._decay_field = DecayField() if decay_field is None else decay_field
[docs] def reset(self) -> None: self._decay_field.reset()
def __call__( self, visited_location: np.ndarray | None, query_locations: np.ndarray, ) -> np.ndarray: if visited_location is not None: self._decay_field.add(visited_location) ior_vals = self._decay_field.compute_weights(query_locations) self._decay_field.step() return ior_vals