# Copyright 2025 Thousand Brains Project
# Copyright 2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import logging
import os
import shutil
import numpy as np
from tqdm import tqdm
from tbp.monty.frameworks.models.evidence_matching import EvidenceGraphLM
class LoggerSDR:
"""A simple logger that saves the data passed to it.
This logger maintains an episode counter and logs
the data it receives under different files named
by the episode counter.
*See more information about what data is being logged
under the `log_episode` function.*
"""
# TODO: Needs to be removed. The logger should be part
# of Monty loggers. Fix issue #328 first.
def __init__(self, path):
if path is None:
logging.warning("EvidenceSDR log path is set to None.")
return
path = os.path.expanduser(path)
# overwrite existing logs
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)
self.path = path
self.episode = 0
    def log_episode(self, data):
        """Receives a data dictionary and saves it as a .npy file.
This function will save all the data passed to it. Here is
a breakdown of the data to be logged by this function.
The data dictionary contains these key-value pairs:
- mask: 2d tensor of the available overlap targets after this
episode
- target_overlap: 2d tensor of the target overlap at the end of
this episode
- training: Dictionary of training statistics for every epoch.
Includes overlap_error, training_summed_distance,
dense representations, and sdrs
- obj2id: Objects to ids dictionary mapping
- id2obj: Ids to objects dictionary mapping
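
        Example:
            An illustrative sketch of reading a logged episode back; the file
            name assumes the default zero-padded episode counter, and
            `allow_pickle` is needed because the dictionary is stored as a
            0-d object array::

                data = np.load("episode_000.npy", allow_pickle=True).item()
                print(sorted(data.keys()))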
"""
if hasattr(self, "path"):
np.save(
os.path.join(self.path, f"episode_{str(self.episode).zfill(3)}.npy"),
data,
)
self.episode += 1
class EncoderSDR:
"""The SDR Encoder class.
This class keeps track of the dense representations, and trains them to output SDRs
when binarized. This class also contains its own optimizer and function to add more
objects/representations.
    The representations are stored as dense vectors and binarized using
    top-k to convert them to SDRs. During training, the pairwise overlaps between the
    SDRs are compared to the target overlaps. This error signal trains the dense
    representations.

    Refer to the `self.train_sdrs` function for more information on the training
    details.
Attributes:
sdr_length: The size of the SDRs (total number of bits).
sdr_on_bits: The number of on bits in the SDRs. Controls sparsity.
lr: The learning rate of the encoding algorithm.
        n_epochs: The number of training epochs per episode.
stability: The stability parameter controls by how much old SDRs
change relative to new SDRs. Value range is [0.0, 1.0], where 0.0 is no
stability constraint applied and 1.0 is fixed SDRs. Values in between are
for partial stability.
log_flag: Flag to activate the logger.
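
    Example:
        A minimal sketch with toy sizes rather than the defaults; the target
        of 2 overlap bits is arbitrary::

            enc = EncoderSDR(sdr_length=16, sdr_on_bits=4, n_epochs=10)
            enc.add_objects(2)
            targets = np.full((2, 2), np.nan)
            targets[0, 1] = 2.0  # ask objects 0 and 1 to share 2 on bits
            enc.train_sdrs(targets)
            sdrs = enc.sdrs  # shape (2, 16); each row has exactly 4 on bits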
"""
def __init__(
self,
sdr_length=2048,
sdr_on_bits=41,
lr=1e-2,
n_epochs=1000,
stability=0.0,
log_flag=False,
):
if sdr_on_bits >= sdr_length or sdr_on_bits <= 0:
logging.warning(
f"Invalid sparsity: sdr_on_bits set to 2% ({round(sdr_length*0.02)})"
)
sdr_on_bits = round(sdr_length * 0.02)
self.sdr_length, self.sdr_on_bits = sdr_length, sdr_on_bits
self.lr = lr
self.n_epochs = n_epochs
self.stability = stability
if self.stability > 1.0 or self.stability < 0.0:
self.stability = np.clip(self.stability, 0.0, 1.0)
logging.warning(
f"Invalid stability parameter: stability clamped to {self.stability}"
)
self.log_flag = log_flag
        # Initialize an empty object SDR array; representations are added
        # later via `add_objects`.
self.obj_sdrs = np.zeros((0, self.sdr_length))
@property
def n_objects(self):
"""Return the available number of objects."""
return self.obj_sdrs.shape[0]
@property
def sdrs(self):
"""Return the available SDRs."""
return self.binarize(self.obj_sdrs)
    def get_sdr(self, index):
"""Return the SDR at a specific index.
This index refers to the object index in the SDRs dictionary
(i.e., self.obj_sdrs)
"""
return self.sdrs[index]
    def optimize(self, overlap_error, mask):
"""Compute and apply local gradient descent.
Compute based on the overlap error and mask.
Note there is no use of the chain rule, i.e. each SDR is optimized based on the
derivative of its clustering error with respect to its values, with no
intermediary functions.
The overlap error helps correct the sign and also provides a magnitude for the
representation updates.
Args:
overlap_error: The difference between target and predicted overlaps.
mask: Mask indicating valid entries in the overlap matrix.
        Note:
            A vectorized version of the algorithm is provided below, although it
            would need to be modified to avoid repeated creation of arrays in order
            to be more efficient. Leaving it for now as this algorithm is not a
            bottleneck (roughly 10-20 seconds to learn 60 object SDRs)::
# Initialize gradients
grad = np.zeros_like(self.obj_sdrs)
# Compute the pairwise differences between SDRs
diff_matrix = (
self.obj_sdrs[:, np.newaxis, :] - self.obj_sdrs[np.newaxis, :, :]
)
# Compute the absolute differences for each pair
abs_diff = np.sum(np.abs(diff_matrix), axis=2)
# Create a mask for non-zero differences
non_zero_mask = abs_diff > 0
# Apply the mask to the original mask
valid_mask = mask & non_zero_mask
# Calculate the summed distance and gradient contributions where the mask
# is valid for logging
summed_distance = np.sum(overlap_error * valid_mask * abs_diff)
# Calculate the gradients
grad_contrib = overlap_error[:, :, np.newaxis] * 2 * diff_matrix
grad += np.sum(grad_contrib * valid_mask[:, :, np.newaxis], axis=1)
grad -= np.sum(grad_contrib * valid_mask[:, :, np.newaxis], axis=0)
# Update the SDRs using the gradient
self.obj_sdrs -= self.lr * grad
Returns:
The summed distance for logging.
"""
# Initialize the gradient array
grad = np.zeros_like(self.obj_sdrs)
# Track the summed distance for logging, i.e. to be able to visualize that
# it is decreasing during each update
summed_distance = 0
# Calculate the gradient for each pair of objects
for i in range(self.n_objects):
for j in range(self.n_objects):
if mask[i, j]:
# As we're optimizing the L2 norm of the difference between
# the two SDRs, the gradient is the difference between the
# two non-binarized (dense) representations.
diff = self.obj_sdrs[i] - self.obj_sdrs[j]
if np.sum(np.abs(diff)) > 0:
summed_distance += overlap_error[i, j] * np.sum(np.abs(diff))
# Multiply the gradient by 2 to exactly match the
# derivative of the L2 norm
grad[i] += overlap_error[i, j] * 2 * diff
grad[j] -= overlap_error[i, j] * 2 * diff
# Update the SDRs using the gradient
self.obj_sdrs -= self.lr * grad
return summed_distance
    def add_objects(self, n_objects):
"""Adds more objects to the available objects and re-initializes the optimizer.
We keep track of the stable representation ids (old objects) when
adding new objects.
Args:
n_objects: Number of objects to add
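
        Example:
            An illustrative sketch: starting from an encoder holding 3 objects,
            adding 2 more marks the original 3 as stable::

                enc.add_objects(2)
                enc.stable_ids  # array([0, 1, 2])
                enc.n_objects   # 5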
"""
if n_objects == 0:
return
# store stable data and ids
stable_data = self.obj_sdrs.copy()
self.stable_ids = np.arange(stable_data.shape[0])
new_obj_sdrs = np.random.randn(
stable_data.shape[0] + n_objects, self.sdr_length
)
new_obj_sdrs[: stable_data.shape[0]] = stable_data
self.obj_sdrs = new_obj_sdrs
    def train_sdrs(self, target_overlaps, log_epoch_every=10):
"""Main SDR training function.
This function receives a copy of the average target overlap 2D tensor
and trains the sdr representations for `n_epochs` to achieve these target
overlap scores.
        We use the overlap target as a learning signal to move the dense
        representations towards or away from each other. The magnitude of the
        overlap error controls how strongly the dense representations move,
        while its sign controls whether they move towards or away from each
        other.

        We want to limit the amount by which trained representations change
        relative to untrained object representations such that higher-level LMs
        would not suffer from significant changes in lower-level
        representations that were used to build higher-level graphs.
When adding new representations, we keep track of the ids of the older
representations (i.e., `self.stable_ids`). This allows us to control by how much
the older representations move relative to the newer ones during training. This
behavior is controlled by the stability value. During each training iteration,
we update these older representations with an average of the optimizer output
and the original representation (weighted by the stability value). Note that
too much stability restricts the SDRs from adapting to desired changes in the
target overlaps caused by normalization or distribution shift, affecting the
overall encoding performance.
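
        The stability update itself is a convex combination. As a sketch, with
        stability = 0.8 a stable object's dense vector keeps 80% of its
        pre-update value and takes only 20% of the optimizer's step::

            updated = 0.8 * sdrs_stable_before + 0.2 * sdrs_stable_after
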
Consider two dense representations, A_dense and B_dense. We apply top-k
operation on both to convert them to A_sdr and B_sdr, then calculate their
overlaps. If the overlap is less than the target overlap, we move dense
        representations (A_dense and B_dense) closer to each other with strength
proportional to the error in overlaps. We move them apart if they have more
overlap than the target.
Note:
The `distance_matrix` variable is calculated using the cdist function
and it denotes the pairwise euclidean distances between *dense*
representations. The term "overlap" always refers to the overlap in bits
between SDRs.
Note:
The overlap_error is only used to weight the distance_matrix for
each pair of objects, and gradients *do not* flow through the sparse
overlap calculations.
Returns:
The stats dictionary for logging.
"""
# return if no target provided
stats = {}
if np.all(np.isnan(target_overlaps)):
logging.warning("Empty overlap targets. No training needed.")
return stats
if np.all(np.array(target_overlaps.shape) > self.n_objects):
            logging.warning(
                "Overlap targets have a larger size than "
                + f"{(self.n_objects, self.n_objects)}"
            )
)
target_overlaps = target_overlaps[: self.n_objects, : self.n_objects]
# Calculate the training mask and target
# The mask determines the valid entries with value:
# - 1: where overlap target exists
# - 0: where overlap target does not exist as of this episode.
mask = ~np.isnan(target_overlaps)
overlaps = np.nan_to_num(target_overlaps, nan=0)
# logging details
if self.log_flag:
stats["mask"] = mask
stats["target_overlap"] = overlaps
stats["training"] = {}
for epoch in tqdm(range(self.n_epochs)):
            # These values are used to pull back the representations from moving
            # too far during training. Note this is only applied to self.stable_ids.
sdrs_stable_before = self.obj_sdrs[self.stable_ids].copy()
# calculate predicted overlaps from existing representations
reps = self.obj_sdrs
bins = self.binarize(reps)
pred_overlaps = bins @ bins.T
# calculate error and optimize
overlap_error = overlaps - pred_overlaps
summed_distance = self.optimize(overlap_error, mask)
# stabilize the SDRs at `self.stable_ids` by pulling them back towards
# sdrs_stable_before
sdrs_stable_after = self.obj_sdrs[self.stable_ids].copy()
self.obj_sdrs[self.stable_ids] = (self.stability * sdrs_stable_before) + (
(1 - self.stability) * sdrs_stable_after
)
# logging details
if self.log_flag and epoch % log_epoch_every == 0:
stats["training"][epoch] = {}
stats["training"][epoch]["obj_dense"] = self.obj_sdrs.copy()
stats["training"][epoch]["obj_sdr"] = bins.copy()
stats["training"][epoch]["overlap_error"] = overlap_error
stats["training"][epoch]["summed_distance"] = summed_distance
# Reset stable ids.
# Stability training only used after adding new objects
self.stable_ids = np.array([]).astype(int)
return stats
    def binarize(self, emb):
        """Convert dense representations to SDRs (0s and 1s) using a top-k function.
Returns:
The SDRs.
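
        Example:
            An illustrative toy input with `sdr_on_bits = 2` (`encoder` denotes
            an EncoderSDR instance); the two largest values (0.9 and 0.5)
            become the on bits::

                emb = np.array([[0.9, -0.2, 0.5, 0.1]])
                encoder.binarize(emb)  # -> array([[1., 0., 1., 0.]])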
"""
topk_indices = np.argsort(emb, axis=1)[:, -self.sdr_on_bits :]
mask = np.zeros_like(emb)
np.put_along_axis(mask, topk_indices, 1, axis=1)
return mask
class EvidenceSDRTargetOverlaps:
"""Keep track of the running average of target overlaps for each episode.
    The target overlaps are implemented as a 2D tensor where the indices
of the tensor represent the ids of the objects, and the values of the
tensor represent a running average of the overlap target.
To achieve this, we implement functions for expanding the size of the
overlap target tensor, linear mapping for normalization, and updating
the overlap tensor (i.e., running average).
Note:
We are averaging over multiple overlap targets.
Multiple targets can happen for different reasons:
- Asymmetric evidences: the target overlap for object 1 w.r.t object 2
([2,1]) is averaged with object 2 w.r.t object 1 ([1,2]). This is
possible because we sort the ids when we add the evidences to
overlaps. Both evidences get added to the location [1,2].
        - Additional episodes with similar MLO (most-likely object): More episodes
          can accumulate additional evidence onto the same key if the MLO
          is similar to the previous MLO of another episode.
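
    Example:
        An illustrative sketch (variable names hypothetical): relative
        evidences are mapped into [0, sdr_on_bits] and averaged into the
        running target::

            targets = EvidenceSDRTargetOverlaps()
            targets.add_objects(3)
            ev = np.full((3, 3), np.nan)
            ev[0, 1], ev[1, 2] = 0.0, -1.0
            targets.add_evidence(ev, [0, 41])
            targets.overlaps[0, 1]  # 41.0, the strongest relative evidence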
"""
def __init__(self):
"""Initialize with overlap tensor.
Initialize the class with overlap tensor to store the running average of the
target scores. Additionally we store the counts to easily calculate the running
average.
"""
self._overlaps = np.full((0, 0), np.nan)
self._counts = np.zeros_like(self._overlaps)
@property
def overlaps(self):
"""Returns the target overlap values rounded to the nearest integer."""
# TODO: Experiment without rounding. Shouldn't make much of a difference
# since this is only used to weight the encoding distances.
return np.round(self._overlaps)
    def add_objects(self, new_size):
        """Expands the overlaps and counts 2D tensors to accommodate new objects."""
# expand the overlaps tensor to the new size
new_overlaps = np.full((new_size, new_size), np.nan)
new_overlaps[: self._overlaps.shape[0], : self._overlaps.shape[0]] = (
self._overlaps
)
self._overlaps = new_overlaps
# expand the counts tensor to the new size
new_counts = np.zeros((new_size, new_size))
new_counts[: self._counts.shape[0], : self._counts.shape[0]] = self._counts
self._counts = new_counts
    def map_to_overlaps(self, evidence, output_range):
        """Linear mapping of values from the input range to the output range.

        Only applies to real values (i.e., ignores nan values).

        Returns:
            The evidence tensor with its valid (non-nan) entries linearly
            mapped to `output_range`. Note that the tensor is modified in
            place.
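
        Example:
            An illustrative mapping of relative evidences [-4, -2, 0] onto the
            overlap range [0, 41]::

                ev = np.array([-4.0, -2.0, 0.0])
                # map_to_overlaps(ev, [0, 41]) -> array([ 0. , 20.5, 41. ])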
"""
valid_ix = ~np.isnan(evidence)
min_evidence = np.nanmin(evidence[valid_ix])
max_evidence = np.nanmax(evidence[valid_ix])
input_range = [min_evidence, max_evidence]
output_range_diff = output_range[1] - output_range[0]
input_range_diff = input_range[1] - input_range[0]
evidence[valid_ix] = (evidence[valid_ix] - input_range[0]) * (
output_range_diff
) / input_range_diff + output_range[0]
return evidence
    def add_overlaps(self, mapped_overlaps):
        """Main function for updating the running average with overlaps.

        The running average equation we use is:

            new_average = ((old_average * counts) + new_val) / (counts + 1)

        This calculates an equally-weighted average, assuming that we keep track
        of the counts and increment them every time we add a new value to the
        average.
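
        For example, if the stored average for a pair is 10.0 with a count of
        2, adding a new overlap of 16.0 yields (10.0 * 2 + 16.0) / 3 = 12.0,
        and the count becomes 3.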
"""
# calculate the mask of indices for existing avg overlaps and
# new overlaps. The mask should be True where both values are
# not nan
mask_avg = np.logical_and(~np.isnan(self._overlaps), ~np.isnan(mapped_overlaps))
# apply the running average equation explained in the docstring
self._overlaps[mask_avg] = (
(self._overlaps[mask_avg] * self._counts[mask_avg])
+ mapped_overlaps[mask_avg]
) / (self._counts[mask_avg] + 1)
# calculate the mask of indices with True values where existing overlaps
# are nan and new overlaps are not nan.
mask_overwrite = np.logical_and(
np.isnan(self._overlaps), ~np.isnan(mapped_overlaps)
)
        # overwrite existing nan values in `self._overlaps` with the new overlap
        # values in `mapped_overlaps`
self._overlaps[mask_overwrite] = mapped_overlaps[mask_overwrite]
# update counts of all new entries
self._counts[np.logical_or(mask_avg, mask_overwrite)] += 1
    def add_evidence(self, evidence, mapping_output_range):
"""Main function for updating the running average with evidence.
This function receives as input the relative evidence scores and
maps them to overlaps in the `mapping_output_range`. The mapped
overlaps are added to the running average in the function `add_overlaps`.
"""
# map relative evidences of the current episode to the output range
mapped_overlaps = self.map_to_overlaps(evidence, mapping_output_range)
# add overlaps to running average
self.add_overlaps(mapped_overlaps)
class EvidenceSDRLMMixin:
"""This Mixin adds training of SDR representations to the EvidenceGraphLM.
    It overrides the `__init__` and `post_episode` functions of the LM.
To use this Mixin, pass the EvidenceSDRGraphLM class as the `learning_module_class`
in the `learning_module_configs`.
Additionally pass the `sdr_args` dictionary as an additional key
in the `learning_module_args`.
The sdr_args dictionary should contain:
        - `log_path` (string): A string that points to a temporary location for saving
          experiment logs. `None` means don't save to file.
- `sdr_length` (int): The size of the SDR to be used for encoding
- `sdr_on_bits` (int): The number of active bits to be used with these SDRs
- `sdr_lr` (float): The learning rate of the encoding algorithm
- `n_sdr_epochs` (int): The number of epochs to train the encoding algorithm
- `stability` (float): Stability of older object SDRs.
Value range is [0.0, 1.0], where 0.0 is no stability
applied and 1.0 is fixed SDRs.
- `sdr_log_flag` (bool): Flag indicating whether to log the results or not
See the `monty_lab` repo for reference. Specifically,
`experiments/configs/evidence_sdr_evaluation.py`
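
    Example:
        A minimal sketch of the extra config; the values shown are this
        module's defaults, not a recommendation::

            sdr_args = dict(
                log_path=None,     # None means don't save logs to file
                sdr_length=2048,
                sdr_on_bits=41,
                sdr_lr=1e-2,
                n_sdr_epochs=1000,
                stability=0.0,
                sdr_log_flag=False,
            )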
"""
def __init__(self, *args, **kwargs):
"""The mixin overrides the `__init__` function of the Learning Module.
The encoding algorithm is initialized here and it stores the actual SDRs.
Also, a temporary logging function is initialized here.
"""
self.sdr_args = kwargs.pop("sdr_args")
super().__init__(*args, **kwargs)
# keeps track of the Graph objects and their ids
self.obj2id = {}
self.id2obj = {}
# keeps track of overlap running average values
self.target_overlaps = EvidenceSDRTargetOverlaps()
# initialize the encoding algorithm
        self.sdr_encoder = EncoderSDR(
            sdr_length=self.sdr_args["sdr_length"],
            sdr_on_bits=self.sdr_args["sdr_on_bits"],
            lr=self.sdr_args["sdr_lr"],
            n_epochs=self.sdr_args["n_sdr_epochs"],
            # Pass the stability value documented in `sdr_args`; default to
            # 0.0 (no stability constraint) if the key is absent.
            stability=self.sdr_args.get("stability", 0.0),
            log_flag=self.sdr_args["sdr_log_flag"],
        )
# TODO: remove this logger and merge with the Monty Loggers after
# issue #328 is fixed.
if self.sdr_args["sdr_log_flag"]:
self.tmp_logger = LoggerSDR(self.sdr_args["log_path"])
    def collect_evidences(self):
"""Collect evidence scores from the Learning Module.
We do this in three steps:
- Step 1: We use the number of objects in the LM
to update the sdr_encoder and id <-> obj tracking
dictionaries, as well as the target overlap tensor
- Step 2: We collect evidences relative to the current
most likely hypothesis (mlh). Evidences are stored
in a 2d tensor.
- Step 3: We use the stored evidences to update the target overlap
which stores the running average.
Refer to `EvidenceSDRTargetOverlaps` for more details.
        **Note:** We sort the ids in step 2 because the overlap values are
        supposed to be symmetric (e.g., "2,5" = "5,2"). This way the target
overlaps for the ids "x,y" and "y,x" will be averaged together in the
`EvidenceSDRTargetOverlaps` class.
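
        For instance, if the current MLO has id 5 and it is compared with the
        object with id 2, `sorted([5, 2])` yields `[2, 5]`, so the relative
        evidence is stored at location [2, 5] regardless of which object was
        the MLO.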
"""
# TODO: add more sophisticated logic to sync the SDR representations
# with available objects in the graph memory. This should facilitate
# merging or removing objects. The SDR representations should always
# be in sync with graphs in memory
# Step 1: add new objects if needed. Useful in learning from scratch experiments
available_objects = self.get_all_known_object_ids()
for obj in available_objects:
if obj not in self.obj2id:
self.obj2id[obj] = len(self.obj2id)
self.id2obj[len(self.id2obj)] = obj
self.sdr_encoder.add_objects(len(self.id2obj) - self.sdr_encoder.n_objects)
self.target_overlaps.add_objects(len(self.id2obj))
# Step 2: collect evidences
mlh_object = self.get_current_mlh()["graph_id"]
if mlh_object == "no_observations_yet" or self.sdr_encoder.n_objects == 1:
return
mlh_object_id = self.obj2id[mlh_object]
mlh_evidence = np.max(self.evidence[mlh_object])
relative_evidences = np.full_like(self.target_overlaps.overlaps, np.nan)
for obj in self.evidence.keys():
ids = sorted([mlh_object_id, self.obj2id[obj]])
ev = np.max(self.evidence[obj]) - mlh_evidence
relative_evidences[ids[0], ids[1]] = ev
# Step 3: update running average with new evidence scores
self.target_overlaps.add_evidence(
relative_evidences, [0, self.sdr_args["sdr_on_bits"]]
)
    def post_episode(self, *args, **kwargs):
"""Overrides the LM post_episode function.
This function collects evidences, trains SDRs and logs the output.
"""
super().post_episode(*args, **kwargs)
# collect the evidences from Learning Module
self.collect_evidences()
# Train the SDR Encoder based on overlap targets
stats = self.sdr_encoder.train_sdrs(self.target_overlaps.overlaps)
# logging episode information if flag set to True
if self.sdr_args["sdr_log_flag"]:
stats.update(
{
"obj2id": self.obj2id,
"id2obj": self.id2obj,
}
)
self.tmp_logger.log_episode(stats)
def _check_use_features_for_matching(self):
"""Check if features should be used for matching.
EvidenceGraphLM bypasses comparing object ID by checking the number
of features on the input channel. In this Mixin we want to use all
features for matching.
Returns:
A dictionary indicating whether to use features for each input channel.
"""
use_features = dict()
for input_channel in self.tolerances.keys():
if input_channel not in self.feature_weights.keys():
use_features[input_channel] = False
elif self.feature_evidence_increment <= 0:
use_features[input_channel] = False
else:
use_features[input_channel] = True
return use_features
def _object_id_to_features(self, object_id):
"""Retrieves the trained SDR corresponding to the object ID.
Returns:
The trained SDR corresponding to the object ID.
"""
if object_id in self.obj2id:
return self.sdr_encoder.get_sdr(self.obj2id[object_id])
else:
return np.zeros(self.sdr_args["sdr_length"])
def _calculate_feature_evidence_sdr_for_all_nodes(
self, query_features, input_channel, graph_id
):
"""Calculate overlap between stored and query SDR features.
Calculates the overlap between the SDR features stored at every location in
the graph and the query SDR feature. This overlap is then compared to the
tolerance value and the result is used for adjusting the evidence score.
        We use the tolerance (in overlap bits) for generalization. If two objects are
        close enough, their overlap in bits should be higher than the set tolerance
        value.
        The tolerance sets the lowest overlap for adding evidence; the range
        [tolerance, sdr_on_bits] is mapped to [0, 1] evidence points. Any overlap less
        than the tolerance will not add any evidence. These evidence scores are then
        multiplied by the feature weight of object_id, which scales all of the
        evidence points to the range [0, feature_weights[input_channel]["object_id"]].
The below variables have the following shapes:
- feature_array: (n, sdr_length)
- query_features[input_channel]["object_id"]: (sdr_length)
- query_feat: (sdr_length, 1)
- np.matmul(feature_array, query_feat): (n, 1)
- overlaps: (n)
Returns:
The normalized overlaps.
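
        Example:
            A worked instance (illustrative numbers): with tolerance = 20 and
            sdr_on_bits = 41, an overlap of 30 bits maps to
            (30 - 20) / (41 - 20) ≈ 0.48 evidence points before the feature
            weight is applied, while an overlap of 15 bits is clipped to 0.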
"""
feature_array = self.graph_memory.get_feature_array(graph_id)[input_channel]
query_feat = np.expand_dims(query_features[input_channel]["object_id"], 1)
tolerance = self.tolerances[input_channel]["object_id"]
sdr_on_bits = query_feat.sum(axis=0)
overlaps = feature_array @ query_feat.squeeze(-1)
normalized_overlaps = (overlaps - tolerance) / (sdr_on_bits - tolerance)
normalized_overlaps[normalized_overlaps < 0] = 0.0
normalized_overlaps *= self.feature_weights[input_channel]["object_id"]
return normalized_overlaps
def _calculate_feature_evidence_for_all_nodes(
self, query_features, input_channel, graph_id
):
"""Calculates feature evidence for all nodes stored in a graph.
This override method tests if the input_channel is a learning_module. If so,
a different function is used for feature comparison.
        Note: This assumes that learning modules always output one feature, object_id.
        If the learning modules output more features than object_id, we need to
        compare these according to their weights.
Returns:
The feature evidence for all nodes.
"""
if input_channel.startswith("learning_module"):
return self._calculate_feature_evidence_sdr_for_all_nodes(
query_features, input_channel, graph_id
)
return super()._calculate_feature_evidence_for_all_nodes(
query_features, input_channel, graph_id
)
class EvidenceSDRGraphLM(EvidenceSDRLMMixin, EvidenceGraphLM):
"""Class that incorporates the EvidenceSDR Mixin with the EvidenceGraphLM."""
pass