Source code for tbp.monty.frameworks.models.evidence_matching.hypotheses_displacer

# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from __future__ import annotations

import logging
from typing import Protocol, Type

import numpy as np

from tbp.monty.frameworks.models.evidence_matching.feature_evidence.calculator import (
    DefaultFeatureEvidenceCalculator,
    FeatureEvidenceCalculator,
)
from tbp.monty.frameworks.models.evidence_matching.graph_memory import (
    EvidenceGraphMemory,
)
from tbp.monty.frameworks.models.evidence_matching.hypotheses import ChannelHypotheses
from tbp.monty.frameworks.utils.graph_matching_utils import (
    get_custom_distances,
    get_relevant_curvature,
)
from tbp.monty.frameworks.utils.spatial_arithmetics import (
    get_angles_for_all_hypotheses,
    rotate_pose_dependent_features,
)

logger = logging.getLogger(__name__)


class HypothesesDisplacer(Protocol):
    def displace_hypotheses_and_compute_evidence(
        self,
        channel_displacement: np.ndarray,
        channel_features: dict,
        evidence_update_threshold: float,
        graph_id: str,
        possible_hypotheses: ChannelHypotheses,
        total_hypotheses_count: int,
    ) -> ChannelHypotheses:
        """Updates evidence by comparing features after applying sensed displacement.

        This function applies the sensor displacement to the existing hypotheses
        and uses the result as search locations for comparing the sensed features.
        This comparison is used to update the evidence scores of the existing
        hypotheses. The hypotheses locations are updated to the new locations
        (i.e., after displacement).

        Args:
            channel_displacement (np.ndarray): Channel-specific sensor displacement.
            channel_features (dict): Channel-specific input features.
            evidence_update_threshold (float): Evidence update threshold.
            graph_id (str): The ID of the current graph.
            possible_hypotheses (ChannelHypotheses): Channel-specific possible
                hypotheses.
            total_hypotheses_count (int): Total number of hypotheses in the graph.

        Returns:
            ChannelHypotheses: Displaced hypotheses with computed evidence.
        """
        ...
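
# Illustrative sketch (not from the library): how a caller might use an object
# satisfying this protocol. The names `displacer`, `disp`, `features`, and
# `hyps` are hypothetical placeholders, not identifiers from this module.
#
#   updated = displacer.displace_hypotheses_and_compute_evidence(
#       channel_displacement=disp,      # (3,) channel-specific displacement
#       channel_features=features,      # sensed features for this channel
#       evidence_update_threshold=0.5,  # only test hypotheses with evidence >= 0.5
#       graph_id="mug",                 # graph to match against
#       possible_hypotheses=hyps,       # ChannelHypotheses for one input channel
#       total_hypotheses_count=len(hyps.evidence),
#   )
#   # updated.locations holds the displaced search locations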


class DefaultHypothesesDisplacer:
    def __init__(
        self,
        feature_weights: dict,
        graph_memory: EvidenceGraphMemory,
        max_match_distance: float,
        tolerances: dict,
        use_features_for_matching: dict[str, bool],
        feature_evidence_calculator: Type[
            FeatureEvidenceCalculator
        ] = DefaultFeatureEvidenceCalculator,
        feature_evidence_increment: int = 1,
        max_nneighbors: int = 3,
        past_weight: float = 1,
        present_weight: float = 1,
    ):
        """Initializes the DefaultHypothesesDisplacer.

        Args:
            feature_weights (dict): How much each feature should be weighted when
                calculating the evidence update for a hypothesis. Weights are
                stored in a dictionary with keys corresponding to features (same
                as the keys in tolerances).
            graph_memory (EvidenceGraphMemory): The graph memory to read graphs
                from.
            max_match_distance (float): Maximum distance between a tested and a
                stored location for them to be matched.
            tolerances (dict): How much each observed feature may deviate from
                the stored features to still be considered a match.
            use_features_for_matching (dict[str, bool]): Dictionary mapping input
                channels to booleans indicating whether to use features for
                matching.
            feature_evidence_calculator (Type[FeatureEvidenceCalculator]): Class
                used to calculate feature evidence for all nodes. Defaults to the
                default calculator.
            feature_evidence_increment (int): Feature evidence (between 0 and 1)
                is multiplied by this value before being added to the overall
                evidence of a hypothesis. This factor is only applied to the
                feature evidence (not the pose evidence, unlike present_weight).
                Defaults to 1.
            max_nneighbors (int): Maximum number of nearest neighbors to consider
                in the radius of a hypothesis when calculating the evidence.
                Defaults to 3.
            past_weight (float): How much the evidence accumulated so far should
                be weighted when combined with the evidence from the most recent
                observation. Defaults to 1.
            present_weight (float): How much the current evidence should be
                weighted when added to the previous evidence. If past_weight and
                present_weight add up to 1, the evidence is bounded and can't
                grow infinitely. Defaults to 1.
                NOTE: Right now this doesn't perform as well as unbounded
                evidence since we don't keep a full history of what we saw. With
                a more efficient policy and better parameters that may become
                feasible, and it could help when moving from one object to
                another and generally make setting thresholds more intuitive.
        """
        self.feature_evidence_calculator = feature_evidence_calculator
        self.feature_evidence_increment = feature_evidence_increment
        self.feature_weights = feature_weights
        self.graph_memory = graph_memory
        self.max_match_distance = max_match_distance
        self.max_nneighbors = max_nneighbors
        self.past_weight = past_weight
        self.present_weight = present_weight
        self.tolerances = tolerances
        self.use_features_for_matching = use_features_for_matching
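
    # Worked example (illustrative, not part of the class): how past_weight and
    # present_weight combine old and new evidence. If they sum to 1, e.g.
    # past_weight=0.7 and present_weight=0.3, the update is a running average,
    #
    #   evidence_new = 0.7 * evidence_old + 0.3 * evidence_update,
    #
    # and because each per-step update lies in [-1, 2] (pose evidence in [-1, 1]
    # plus feature evidence in [0, 1]), the accumulated evidence stays bounded
    # in [-1, 2]. With the defaults past_weight = present_weight = 1, updates
    # are simply added and evidence can grow without bound.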

    def displace_hypotheses_and_compute_evidence(
        self,
        channel_displacement: np.ndarray,
        channel_features: dict,
        evidence_update_threshold: float,
        graph_id: str,
        possible_hypotheses: ChannelHypotheses,
        total_hypotheses_count: int,
    ) -> ChannelHypotheses:
        # Have to do this for all hypotheses so we don't lose the path information
        rotated_displacements = possible_hypotheses.poses.dot(channel_displacement)
        search_locations = possible_hypotheses.locations + rotated_displacements

        # Get indices of hypotheses with evidence >= threshold
        hyp_ids_to_test = np.where(
            possible_hypotheses.evidence >= evidence_update_threshold
        )[0]
        num_hypotheses_to_test = hyp_ids_to_test.shape[0]
        if num_hypotheses_to_test > 0:
            logger.info(
                f"Testing {num_hypotheses_to_test} out of "
                f"{total_hypotheses_count} hypotheses for {graph_id} "
                f"(evidence >= {evidence_update_threshold})"
            )
            # Get evidence update for all hypotheses with evidence >= the current
            # evidence_update_threshold
            new_evidence = self._calculate_evidence_for_new_locations(
                graph_id=graph_id,
                input_channel=possible_hypotheses.input_channel,
                search_locations=search_locations[hyp_ids_to_test],
                channel_possible_poses=possible_hypotheses.poses[hyp_ids_to_test],
                channel_features=channel_features,
            )
            min_update = np.clip(np.min(new_evidence), 0, np.inf)
            # Alternatives (no update to other Hs or adding avg) left in
            # here in case we want to revert back to those.
            # avg_update = np.mean(new_evidence)
            # evidence_to_add = np.zeros_like(channel_hypotheses_evidence)
            evidence_to_add = np.ones_like(possible_hypotheses.evidence) * min_update
            evidence_to_add[hyp_ids_to_test] = new_evidence

            # If past and present weight add up to 1, this is equivalent to
            # np.average and evidence will be bounded to [-1, 2]. Otherwise it
            # keeps growing.
            evidence = (
                possible_hypotheses.evidence * self.past_weight
                + evidence_to_add * self.present_weight
            )
        else:
            evidence = possible_hypotheses.evidence

        return ChannelHypotheses(
            input_channel=possible_hypotheses.input_channel,
            evidence=evidence,
            locations=search_locations,
            poses=possible_hypotheses.poses,
        )
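
    # Illustrative numbers (not part of the method): each hypothesis rotates the
    # sensed displacement into its own reference frame before applying it. For a
    # hypothesis whose pose is a 90-degree rotation about z, a displacement of
    # [1, 0, 0] becomes [0, 1, 0]:
    #
    #   pose = np.array([[0., -1., 0.],
    #                    [1.,  0., 0.],
    #                    [0.,  0., 1.]])  # 90 deg rotation about z
    #   pose.dot(np.array([1., 0., 0.]))  # -> array([0., 1., 0.])
    #
    # so hypotheses with different poses end up with different search locations.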

    def _calculate_evidence_for_new_locations(
        self,
        graph_id: str,
        input_channel: str,
        search_locations: np.ndarray,
        channel_possible_poses: np.ndarray,
        channel_features: dict,
    ):
        """Use search locations, sensed features and graph model to calculate evidence.

        First, the search locations are used to find the nearest nodes in the
        graph model. Then we calculate the error between the stored pose features
        and the sensed ones. Additionally we look at whether the non-pose features
        match at the neighboring nodes. Everything is weighted by the node's
        distance from the search location. If there are no nodes in the search
        radius (max_match_distance), evidence = -1.

        We do this for every incoming input channel and its features if they are
        stored in the graph and take the average over the evidence from all input
        channels.

        Returns:
            The location evidence.
        """
        logger.debug(
            f"Calculating evidence for {graph_id} using input from {input_channel}"
        )
        pose_transformed_features = rotate_pose_dependent_features(
            channel_features,
            channel_possible_poses,
        )
        # Get max_nneighbors nearest nodes to search locations.
        nearest_node_ids = self.graph_memory.get_graph(
            graph_id, input_channel
        ).find_nearest_neighbors(
            search_locations,
            num_neighbors=self.max_nneighbors,
        )
        if self.max_nneighbors == 1:
            nearest_node_ids = np.expand_dims(nearest_node_ids, axis=1)

        nearest_node_locs = self.graph_memory.get_locations_in_graph(
            graph_id, input_channel
        )[nearest_node_ids]
        max_abs_curvature = get_relevant_curvature(channel_features)
        custom_nearest_node_dists = get_custom_distances(
            nearest_node_locs,
            search_locations,
            pose_transformed_features["pose_vectors"][:, 0],
            max_abs_curvature,
        )
        # shape=(H, K)
        node_distance_weights = self._get_node_distance_weights(
            custom_nearest_node_dists
        )
        # Get IDs where custom_nearest_node_dists > max_match_distance
        mask = node_distance_weights <= 0

        new_pos_features = self.graph_memory.get_features_at_node(
            graph_id,
            input_channel,
            nearest_node_ids,
            feature_keys=["pose_vectors", "pose_fully_defined"],
        )
        # Calculate the pose error for each hypothesis
        # shape=(H, K)
        radius_evidence = self._get_pose_evidence_matrix(
            pose_transformed_features,
            new_pos_features,
            input_channel,
            node_distance_weights,
        )
        # Set the evidence of nodes that are too far away to -1
        radius_evidence[mask] = -1
        # If a node is too far away, weight the negative evidence fully (*1). This
        # only comes into play if there are no nearby nodes in the radius; then we
        # want an evidence of -1 for this hypothesis.
        # NOTE: Currently we don't weight the evidence by distance so this doesn't
        # matter.
        node_distance_weights[mask] = 1

        # If no feature weights are provided besides the ones for point_normal
        # and curvature_directions we don't need to calculate feature evidence.
        if self.use_features_for_matching[input_channel]:
            # add evidence if features match
            node_feature_evidence = self.feature_evidence_calculator.calculate(
                channel_feature_array=self.graph_memory.get_feature_array(graph_id)[
                    input_channel
                ],
                channel_feature_order=self.graph_memory.get_feature_order(graph_id)[
                    input_channel
                ],
                channel_feature_weights=self.feature_weights[input_channel],
                channel_query_features=channel_features,
                channel_tolerances=self.tolerances[input_channel],
                input_channel=input_channel,
            )
            hypothesis_radius_feature_evidence = node_feature_evidence[
                nearest_node_ids
            ]
            # Set feature evidence of nearest neighbors that are too far away to 0
            hypothesis_radius_feature_evidence[mask] = 0
            # Take the maximum feature evidence of the nearest neighbors in the
            # search radius, weighted by distance to the search location.
            # Evidence will be in [0, 1] and is only 1 if all features match
            # perfectly and the node is at the search location.
            radius_evidence = (
                radius_evidence
                + hypothesis_radius_feature_evidence * self.feature_evidence_increment
            )
        # We take the maximum to be better able to deal with parts of the model
        # where features change quickly and we may have noisy location information.
        # This way we check whether we can find a good match of pose features
        # within the search radius. It doesn't matter if there are also points
        # stored nearby in the model that are not a good match.
        # Uncommenting node_distance_weights below weights the evidence by the
        # node's distance from the search location. However, empirically this did
        # not seem to help.
        # shape=(H,)
        location_evidence = np.max(
            radius_evidence,  # * node_distance_weights,
            axis=1,
        )
        return location_evidence

    def _get_node_distance_weights(self, distances):
        node_distance_weights = (
            self.max_match_distance - distances
        ) / self.max_match_distance
        return node_distance_weights

    def _get_pose_evidence_matrix(
        self,
        query_features,
        node_features,
        input_channel,
        node_distance_weights,
    ):
        """Get angle mismatch error of the three pose features for multiple points.

        Args:
            query_features: Observed features.
            node_features: Features at nodes that are being tested.
            input_channel: Input channel for which we want to calculate the pose
                evidence. These are all input channels that are received at the
                current time step and are also stored in the graph.
            node_distance_weights: Weights for each node's error (determined by
                distance to the search location). Currently not used, except for
                shape.

        Returns:
            The sum of angle evidence weighted by weights. In range [-1, 1].
        """
        # TODO S: simplify by looping over pose vectors
        evidences_shape = node_distance_weights.shape[:2]
        pose_evidence_weighted = np.zeros(evidences_shape)
        # TODO H: at higher level LMs we may want to look at all pose vectors.
        # Currently we skip the third since the second curv dir is always 90
        # degrees from the first.
        # Get angles between the sensed and stored pose features
        pn_error = get_angles_for_all_hypotheses(
            # shape of node_features[input_channel]["pose_vectors"]: (nH, knn, 9)
            node_features["pose_vectors"][:, :, :3],
            query_features["pose_vectors"][:, 0],  # shape (nH, 3)
        )
        # Divide error by 2 so it is in range [0, pi/2]
        # Apply sin -> [0, 1]. Subtract 0.5 -> [-0.5, 0.5]
        # Negate the error to get evidence (lower error means higher evidence)
        pn_evidence = -(np.sin(pn_error / 2) - 0.5)
        pn_weight = self.feature_weights[input_channel]["pose_vectors"][0]

        # If the curvatures are the same, the directions are meaningless
        # -> set curvature angle error to zero.
        if not query_features["pose_fully_defined"]:
            cd1_weight = 0
            # Only calculate curv dir angle if sensed curv dirs are meaningful
            cd1_evidence = np.zeros(pn_error.shape)
        else:
            cd1_weight = self.feature_weights[input_channel]["pose_vectors"][1]
            # Also check if curv dirs stored at node are meaningful
            use_cd = np.array(
                node_features["pose_fully_defined"][:, :, 0],
                dtype=bool,
            )
            cd1_angle = get_angles_for_all_hypotheses(
                node_features["pose_vectors"][:, :, 3:6],
                query_features["pose_vectors"][:, 1],
            )
            # Since curvature directions could be rotated 180 degrees we define
            # the error to be largest when the angle is pi/2 (90 deg) and angles
            # 0 and pi are equal. This means the angle error will be between 0
            # and pi/2.
            cd1_error = np.pi / 2 - np.abs(cd1_angle - np.pi / 2)
            # We then apply the same operations as on pn error to get
            # cd1_evidence in range [-0.5, 0.5]
            cd1_evidence = -(np.sin(cd1_error) - 0.5)
            # Nodes where pc1==pc2 receive no cd evidence but twice the pn
            # evidence -> overall evidence can be in range [-1, 1]
            cd1_evidence = cd1_evidence * use_cd + pn_evidence * np.logical_not(
                use_cd
            )
        # Weight angle errors by feature weights.
        # If sensed pc1==pc2, cd1_weight==0 and overall evidence is in
        # [-0.5, 0.5]; otherwise it is in [-1, 1].
        pose_evidence_weighted += pn_evidence * pn_weight + cd1_evidence * cd1_weight
        return pose_evidence_weighted
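
# Worked example (illustrative): the angle-to-evidence mapping used in
# _get_pose_evidence_matrix. For point-normal errors of 0, pi/2, and pi radians:
#
#   -(np.sin(0.0 / 2) - 0.5)         # -> 0.5    (normals aligned, perfect match)
#   -(np.sin((np.pi / 2) / 2) - 0.5)  # -> ~-0.207 (error of pi/2)
#   -(np.sin(np.pi / 2) - 0.5)       # -> -0.5   (normals opposite)
#
# Each pose-vector term therefore contributes evidence in [-0.5, 0.5], and the
# weighted sum of the point-normal and curvature-direction terms spans [-1, 1]
# with the default feature weights of 1.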