Source code for tbp.monty.frameworks.models.displacement_matching

# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import logging

import numpy as np
import torch
from scipy.spatial.transform import Rotation

from tbp.monty.frameworks.environment_utils.graph_utils import get_edge_index
from tbp.monty.frameworks.models.graph_matching import GraphLM, GraphMemory
from tbp.monty.frameworks.models.object_model import GraphObjectModel
from tbp.monty.frameworks.utils.graph_matching_utils import is_in_ranges
from tbp.monty.frameworks.utils.sensor_processing import point_pair_features


class DisplacementGraphLM(GraphLM):
    """Learning module that uses displacements stored in graphs to recognize objects."""

    def __init__(
        self,
        k=None,
        match_attribute=None,
        tolerance=0.001,
        use_relative_len=False,
        graph_delta_thresholds=None,
    ):
        """Initialize Learning Module.

        Args:
            k: How many nearest neighbors nodes in the graphs should connect to.
            match_attribute: Which displacement to use for matching. Should be in
                ['displacement', 'PPF'].
            tolerance: How close an observed displacement has to be to a stored
                displacement to be considered a match. Defaults to 0.001.
            use_relative_len: Whether to scale the displacements to achieve scale
                invariance. Only possible when using PPF.
            graph_delta_thresholds: Thresholds used to compare nodes in the graphs
                being learned, and thereby whether to include a new point or not.
                By default, we only consider the distance between points, using a
                threshold of 0.001 (determined in remove_close_points). Thresholds
                can also be specified based on e.g. point-normal angle difference
                or principal-curvature magnitude difference.
        """
        super(DisplacementGraphLM, self).__init__()
        self.graph_memory = DisplacementGraphMemory(
            graph_delta_thresholds=graph_delta_thresholds,
            k=k,
            match_attribute=match_attribute,
        )
        self.match_attribute = match_attribute
        self.tolerance = tolerance
        self.use_relative_len = use_relative_len

    # =============== Public Interface Functions ===============
    # ------------------- Main Algorithm -----------------------
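    # Minimal usage sketch (illustrative, not part of the module; the parameter
    # values below are made up). It only uses the constructor arguments and the
    # reset() method defined in this class.
    #
    #     lm = DisplacementGraphLM(
    #         k=5,
    #         match_attribute="PPF",
    #         tolerance=0.001,
    #         use_relative_len=True,
    #     )
    #     lm.reset()  # call before each episode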
    def reset(self):
        """Call this before each episode."""
        # reset possible matches for paths on objects
        (
            self.possible_matches,
            self.possible_paths,
            self.next_possible_paths,
            self.scale_factors,
        ) = self.graph_memory.get_initial_hypotheses()
    # ------------------ Getters & Setters ---------------------
    def get_unique_pose_if_available(self, object_id):
        """Compute (location, rotation, scale) of object.

        If we are sure about where on the object we are, compare the sensed
        displacements to the stored model displacements to calculate the pose;
        otherwise return None.

        Returns:
            The pose and scale of the object.
        """
        pose_and_scale = None
        possible_paths = self.get_possible_paths()[object_id]
        # If multiple paths are still possible, return None.
        if len(possible_paths) == 1:
            # TODO H: Do we want to clean up all this first channel stuff
            # in the old LMs?
            first_channel = self.buffer.get_first_sensory_input_channel()
            detected_path = possible_paths[0]
            # get locations in model RF for nodes (int IDs) in the detected path
            detected_path_locs = self.graph_memory.get_locations_in_graph(
                object_id, input_channel=first_channel
            )[detected_path]
            # The location in the object RF where the sensor is right now is the
            # last one in the detected path.
            current_model_loc = detected_path_locs[-1]
            model_displacements = np.array(
                [
                    np.array(detected_path_locs[i + 1] - detected_path_locs[i])
                    for i in range(len(detected_path_locs) - 1)
                ]
            )
            r_euler, _, r = self.get_object_rotation(
                sensed_displacements=np.array(
                    self.buffer.displacements[first_channel]["displacement"][1:]
                ),
                model_displacements=model_displacements,
                get_reverse_r=False,
            )
            # If r_euler is not None, we have a unique rotation
            if r_euler is not None:
                self.detected_rotation_r = r
                scale = self.get_object_scale(
                    np.array(
                        self.buffer.get_nth_displacement(
                            1, input_channel=first_channel
                        )
                    ),
                    model_displacements[0],
                )
                pose_and_scale = np.concatenate(
                    [current_model_loc, r_euler, [scale]]
                )
                self.detected_pose = pose_and_scale
                lm_episode_stats = {
                    "detected_path": detected_path,
                    "detected_location_on_model": current_model_loc,
                    "detected_location_rel_body": self.buffer.get_current_location(
                        input_channel=first_channel
                    ),
                    "detected_rotation": r_euler,
                    "detected_rotation_quat": r.as_quat(),
                    "detected_scale": scale,
                }
                self.buffer.add_overall_stats(lm_episode_stats)
        logging.debug(f"(location, rotation, scale): {pose_and_scale}")
        return pose_and_scale
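    # Worked sketch (hypothetical numbers): how a detected path of node locations
    # is turned into model displacements before the pose is estimated above. The
    # sensed displacements from the buffer are then compared against these.
    #
    #     import numpy as np
    #
    #     detected_path_locs = np.array(
    #         [[0.0, 0.0, 0.0], [0.0, 0.01, 0.0], [0.01, 0.01, 0.0]]
    #     )
    #     # Displacement between consecutive nodes on the path:
    #     model_displacements = np.diff(detected_path_locs, axis=0)
    #     # -> [[0.0, 0.01, 0.0], [0.01, 0.0, 0.0]]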
    def get_object_rotation(
        self, sensed_displacements, model_displacements, get_reverse_r=False
    ):
        """Calculate the rotation between two sets of displacement vectors.

        Args:
            sensed_displacements: The displacements that were sensed.
            model_displacements: The displacements in the model that were matched
                to the sensed displacements.
            get_reverse_r: Whether to get the rotation that turns the model such
                that it would produce the sensed_displacements (False) or the
                rotation needed to turn the sensed_displacements into the model
                displacements (True).

        Returns:
            The rotation in Euler angles, as a matrix, and as a Rotation object.
        """
        try:
            if get_reverse_r:
                r, msr = Rotation.align_vectors(
                    sensed_displacements, model_displacements
                )
            else:
                r, msr = Rotation.align_vectors(
                    model_displacements, sensed_displacements
                )
        except UserWarning:
            # This can happen if the displacements that were sampled lie in one
            # plane, such that we cannot determine the rotation around all three
            # axes.
            print("could not determine rotation uniquely -> keep moving!")
            return None, None, None
        r_euler = np.round(r.as_euler("xyz", degrees=True), 3)
        r_matrix = r.as_matrix()
        return r_euler, r_matrix, r
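    # Illustrative sketch (made-up vectors): how Rotation.align_vectors recovers
    # a rotation relating two sets of displacement vectors, as used above. With
    # align_vectors(a, b), the returned rotation r satisfies r.apply(b) ~= a.
    #
    #     import numpy as np
    #     from scipy.spatial.transform import Rotation
    #
    #     model_disps = np.array([[0.0, 0.01, 0.0], [0.01, 0.0, 0.0]])
    #     true_rot = Rotation.from_euler("xyz", [0, 0, 90], degrees=True)
    #     sensed_disps = true_rot.apply(model_disps)
    #     r, _ = Rotation.align_vectors(sensed_disps, model_disps)
    #     np.round(r.as_euler("xyz", degrees=True), 3)  # -> [0., 0., 90.]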
    def get_object_scale(self, sensed_displacement, model_displacement):
        """Calculate the object's scale given sensed and model displacements.

        Returns:
            The scale of the object.
        """
        scale = np.linalg.norm(sensed_displacement) / np.linalg.norm(
            model_displacement
        )
        return scale
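    # Worked sketch (made-up numbers): if the sensed displacement is twice as long
    # as the matched model displacement, the object is detected at scale 2.
    #
    #     import numpy as np
    #
    #     sensed = np.array([0.02, 0.0, 0.0])
    #     stored = np.array([0.01, 0.0, 0.0])
    #     np.linalg.norm(sensed) / np.linalg.norm(stored)  # -> 2.0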
    # ------------------ Logging & Saving ----------------------

    # ======================= Private ==========================
    # ------------------- Main Algorithm -----------------------
    def _compute_possible_matches(self, observation, not_moved=False):
        """Use the current observation to narrow down the possible matches.

        This is framed as a prediction problem. We take the current observation
        as a query and try to predict whether after the displacement we will
        still be on the object. In a next step we could also predict the feature
        that we will sense next.

        The prediction is then compared with the actual observation (currently
        just whether we sensed on_object or not). If there is a prediction
        error, we remove the object from the possible matches.

        Args:
            observation: The current observation.
            not_moved: Whether the agent has not moved yet. True on the first
                step.
        """
        if not_moved:
            return

        if self.match_attribute == "displacement":
            query = self.buffer.get_current_displacement(input_channel="first")
        elif self.match_attribute == "PPF":
            query = self.buffer.get_current_ppf(input_channel="first")
        else:
            logging.error("match_attribute not defined")
        # This is just whether we are on the object or not here.
        target = self._select_features_to_use(observation)
        if self.match_attribute == "PPF" and self.use_relative_len:
            query[0] = query[0] / self.buffer.get_first_displacement_len(
                input_channel="first"
            )
        logging.debug(f"query: {query}")
        self._update_possible_matches(query=query, target=target)

    def _update_possible_matches(self, query, target, threshold=0):
        """Update the list of possible matches.

        This is done by excluding objects that had a prediction error > threshold.

        Args:
            query: Incoming displacement.
            target: Whether we expect to be on the object or not after the given
                displacement.
            threshold: How high the prediction error can be for the object to
                still be considered. With binary predictions this is just 0. With
                features we may want to adapt this.
        """
        predictions = self._make_predictions(
            query=query,
            use_relative_len=self.use_relative_len,
        )
        prediction_error = self._get_prediction_error(predictions, target=target)
        for graph_id in prediction_error:
            if prediction_error[graph_id] > threshold:
                self.possible_matches.pop(graph_id)

    def _make_predictions(self, query, use_relative_len):
        """Predict whether we will still be on the object given a displacement.

        Args:
            query: x,y,z query (where the sensor moved). For
                _predict_using_displacements this is a displacement vector.
            use_relative_len: When matching, use relative displacement lengths
                instead of absolute. This may help with scale invariance.

        Returns:
            dict: Binary prediction for each graph in possible_matches of whether
                the object will be at the new location.
        """
        predictions = {}
        for graph_id in self.possible_matches:
            prediction = self._predict_using_displacements(
                np.array(query), graph_id, use_relative_len
            )
            predictions[graph_id] = prediction
        return predictions
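    # Illustrative sketch (toy values): the binary predict-and-prune step above.
    # Objects whose graphs cannot produce the sensed displacement get a
    # prediction error of 1 and are removed from the possible matches.
    #
    #     possible_matches = {"mug": "graph_mug", "bowl": "graph_bowl"}
    #     predictions = {"mug": 1, "bowl": 0}  # 1 = displacement fits the graph
    #     target = 1                           # sensor is still on the object
    #     errors = {g: int(target != p) for g, p in predictions.items()}
    #     for graph_id, err in errors.items():
    #         if err > 0:
    #             possible_matches.pop(graph_id)  # only "mug" remains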
    def _predict_using_displacements(
        self,
        displacement,
        graph_id,
        use_relative_len,
    ):
        """Predict whether we will still be on the object given a displacement.

        Takes a displacement as input (the last action that was performed) and
        checks for a specific object in memory whether the displacement could end
        up on one of its nodes, given the current possible nodes after the past
        series of displacements.

        Args:
            displacement: 3D displacement vector of the last action. Will be
                compared to displacements between nodes of a graph, within
                self.tolerance (i.e. how exactly they need to match).
            graph_id: id of the graph used to make predictions.
            use_relative_len: When matching, use relative displacement lengths
                instead of absolute. This may help with scale invariance.

        Returns:
            int: Whether the displacement is on the object. 0 if not, 1 if it is.
        """
        # TODO: Due to the use of node IDs as path start IDs it is a bit tricky to
        # use multiple input channels & I am not sure if it is worth the time
        # investment at the moment since we don't actively use this LM. So for now
        # we just take the first input channel here.
        first_input_channel = list(self.possible_matches[graph_id].keys())[0]

        displacement_plus_tolerance = np.stack(
            [displacement - self.tolerance, displacement + self.tolerance],
            axis=1,
        )
        self.possible_paths[graph_id] = self.next_possible_paths[graph_id]
        # possible_next_nodes = []
        new_possible_paths = []
        current_possible_paths = []
        path_scale_factors = []
        # for node in self.possible_nodes[graph_id]:
        for path_id, path in enumerate(self.possible_paths[graph_id]):
            previous_node = path[-2]
            current_node = path[-1]
            edge_id = get_edge_index(
                self.possible_matches[graph_id][first_input_channel],
                previous_node,
                current_node,
            )
            node_displacement = (
                self.possible_matches[graph_id][first_input_channel]
                .edge_attr[edge_id]
                .detach()
                .clone()
            )
            if use_relative_len:
                node_displacement[0] = (
                    node_displacement[0] / self.scale_factors[graph_id][path_id]
                )
            if is_in_ranges(node_displacement, displacement_plus_tolerance):
                current_possible_paths.append(path)
                edges_of_node = np.where(
                    self.possible_matches[graph_id][first_input_channel].edge_index[0]
                    == current_node
                )[0]
                next_nodes = self.possible_matches[graph_id][
                    first_input_channel
                ].edge_index[1][edges_of_node]
                for next_node in next_nodes:
                    new_possible_paths.append(np.append(path, int(next_node)))
                    path_scale_factors.append(self.scale_factors[graph_id][path_id])

        self.possible_paths[graph_id] = current_possible_paths
        self.next_possible_paths[graph_id] = new_possible_paths
        self.scale_factors[graph_id] = path_scale_factors
        # logging.info(
        #     "possible paths for "
        #     + graph_id
        #     + ": "
        #     + str(self.possible_paths[graph_id])
        # )
        # logging.info(
        #     "next possible paths for "
        #     + graph_id
        #     + ": "
        #     + str(self.next_possible_paths[graph_id])
        # )
        if len(self.possible_paths[graph_id]) == 0:
            return 0
        else:
            return 1

    def _get_prediction_error(self, predictions, target):
        """Calculate the prediction error (binary if not using features).

        Args:
            predictions: A binary prediction of the object's morphology (object
                there or not) per graph.
            target: The actual sensation at the new location (also binary).

        Returns:
            prediction_error: Binary prediction error for each graph:
                int(target != prediction)
        """
        prediction_error = {}
        for graph_id in predictions:
            prediction_error[graph_id] = int(target != predictions[graph_id])
        return prediction_error
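    # Illustrative sketch (toy numbers): the tolerance window built above. A
    # stored edge displacement matches a sensed displacement if every component
    # falls inside [sensed - tolerance, sensed + tolerance]; a plain-numpy version
    # of that check is shown here (the module itself uses is_in_ranges).
    #
    #     import numpy as np
    #
    #     sensed = np.array([0.010, 0.000, 0.005])
    #     tolerance = 0.001
    #     ranges = np.stack([sensed - tolerance, sensed + tolerance], axis=1)
    #     stored = np.array([0.0104, 0.0002, 0.0048])
    #     bool(np.all((stored >= ranges[:, 0]) & (stored <= ranges[:, 1])))  # True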
""" displacement = np.zeros(3) ppf = np.zeros(4) # TODO S: calculate displacements for each separately (mostly for rotation disp) obs_to_use = obs[0] if len(self.buffer) > 0: # TODO S: Make sure result of get_current_location() and get_current_pose() # is on object (should always be atm). displacement = np.array( obs_to_use.location ) - self.buffer.get_current_location(input_channel=obs_to_use.sender_id) pos1 = torch.tensor( self.buffer.get_current_location(input_channel=obs_to_use.sender_id) ) pos2 = torch.tensor(obs_to_use.location) norm1 = torch.tensor( # element 0 of current pose is location, element 1 is point normal self.buffer.get_current_pose(input_channel=obs_to_use.sender_id)[1], dtype=torch.float64, ) norm2 = torch.tensor( obs_to_use.get_nth_pose_vector(pose_vector_index=0), dtype=torch.float64, ) ppf = point_pair_features(pos1, pos2, norm1, norm2) for o in obs: o.set_displacement(displacement=displacement, ppf=ppf) return obs def _select_features_to_use(self, states): """Extract on_object from observed features to use as target. Returns: int: Whether we are on the object or not. """ morph_features = states[0].morphological_features # TODO S: decide if we want to store on_object in state if "on_object" in morph_features: on_object = morph_features["on_object"] else: on_object = 1 return int(on_object) # ----------------------- Logging -------------------------- def _add_detailed_stats(self, stats): stats["possible_paths"] = self.get_possible_paths() return stats
class DisplacementGraphMemory(GraphMemory):
    """Graph memory that stores graphs with displacements as edges."""

    def __init__(self, match_attribute, *args, **kwargs):
        """Initialize Graph memory."""
        super(DisplacementGraphMemory, self).__init__(*args, **kwargs)
        self.match_attribute = match_attribute

    # =============== Public Interface Functions ===============
    # ------------------- Main Algorithm -----------------------
    def get_initial_hypotheses(self):
        possible_matches = self.get_all_models_in_memory()
        possible_paths = {}
        next_possible_paths = {}
        # Need this for scale factors to work
        scale_factors = {}
        for graph_id in self.get_memory_ids():
            first_input_channel = self.get_input_channels_in_graph(graph_id)[0]
            next_possible_paths[graph_id] = np.swapaxes(
                self.get_graph(graph_id, first_input_channel).edge_index, 0, 1
            )
            if self.get_graph(graph_id, first_input_channel).x.dim() > 1:
                # Features of nodes contain more than just IDs (i.e. RGBA)
                possible_paths[graph_id] = self.get_graph_node_ids(
                    graph_id, first_input_channel
                )
            else:
                possible_paths[graph_id] = np.array(
                    self.get_graph(graph_id, first_input_channel).x
                )
            scale_factors[graph_id] = np.array(
                self.get_graph(graph_id, first_input_channel).edge_attr[:, 0]
            )
        return (
            possible_matches,
            possible_paths,
            next_possible_paths,
            scale_factors,
        )
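    # Illustrative sketch (toy graph): the initial next_possible_paths above are
    # simply all edges of a graph as (start_node, end_node) pairs, obtained by
    # transposing the [2, num_edges] edge_index.
    #
    #     import numpy as np
    #     import torch
    #
    #     edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])  # 3 directed edges
    #     np.swapaxes(edge_index, 0, 1)
    #     # -> [[0, 1], [0, 2], [1, 2]], i.e. one 2-node path per edge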
    # ------------------ Getters & Setters ---------------------
    # ------------------ Logging & Saving ----------------------
    def load_state_dict(self, state_dict):
        """Load graphs into memory from a state_dict and add point pair features."""
        logging.info("loading models")
        for obj_name, model in state_dict.items():
            logging.debug(f"loading {obj_name}: {model}")
            for input_channel in model:
                if (self.match_attribute == "PPF") and (
                    model[input_channel].has_ppf is False
                ):
                    model[input_channel].add_ppf_to_graph()
            self._add_graph_to_memory(model, obj_name)
    # ======================= Private ==========================
    # ------------------- Main Algorithm -----------------------
    def _build_graph(self, locations, features, graph_id, input_channel):
        """Build a k-nearest-neighbor graph from a list of observations.

        Custom version of super()._build_graph that adds edges to the graph and
        attaches point pair features to them if this is the match_attribute.

        Args:
            locations: List of x,y,z locations.
            features: Features observed at the locations.
            graph_id: Name of the object.
            input_channel: Name of the input channel the observations came from.
        """
        logging.info("Adding a new graph to memory.")

        model = GraphObjectModel(
            object_id=graph_id,
        )
        graph_delta_thresholds = (
            None
            if self.graph_delta_thresholds is None
            else self.graph_delta_thresholds[input_channel]
        )
        model.build_model(
            locations,
            features,
            k_n=self.k,
            graph_delta_thresholds=graph_delta_thresholds,
        )
        if self.match_attribute == "PPF":
            model.add_ppf_to_graph()

        if graph_id not in self.models_in_memory:
            self.models_in_memory[graph_id] = dict()
        self.models_in_memory[graph_id][input_channel] = model
        logging.info(f"Added new graph with id {graph_id} to memory.")
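    # Illustrative sketch (hypothetical data): the graph-building steps above,
    # shown on their own. `locations` and `features` are made-up stand-ins for
    # the observations the learning module would normally pass in.
    #
    #     import numpy as np
    #     from tbp.monty.frameworks.models.object_model import GraphObjectModel
    #
    #     locations = np.random.rand(50, 3)  # 50 observed x,y,z points
    #     features = {...}                   # features observed at those points
    #     model = GraphObjectModel(object_id="new_object0")
    #     model.build_model(
    #         locations, features, k_n=5, graph_delta_thresholds=None
    #     )
    #     model.add_ppf_to_graph()           # attach PPFs to the edges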
    # ------------------------ Helper --------------------------
    # ----------------------- Logging --------------------------