# Source code for tbp.monty.frameworks.models.evidence_matching.graph_memory

# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import logging

from tbp.monty.frameworks.models.graph_matching import GraphMemory
from tbp.monty.frameworks.models.object_model import (
    GraphObjectModel,
    GridObjectModel,
    GridTooSmallError,
)

logger = logging.getLogger(__name__)


class EvidenceGraphMemory(GraphMemory):
    """Custom GraphMemory that stores GridObjectModel instead of GraphObjectModel.

    Every model added to ``self.models_in_memory`` through this class is a
    ``GridObjectModel`` configured with the grid parameters given to
    ``__init__``. Pretrained ``GraphObjectModel`` channels are converted to
    ``GridObjectModel`` when loaded.
    """

    def __init__(
        self,
        max_nodes_per_graph,
        max_graph_size,
        num_model_voxels_per_dim,
        *args,
        **kwargs,
    ):
        """Initialize the memory and remember the grid model parameters.

        Args:
            max_nodes_per_graph: Passed to GridObjectModel as ``max_nodes``.
            max_graph_size: Passed to GridObjectModel as ``max_size``.
            num_model_voxels_per_dim: Passed to GridObjectModel as
                ``num_voxels_per_dim``.
            *args: Forwarded unchanged to ``GraphMemory.__init__``.
            **kwargs: Forwarded unchanged to ``GraphMemory.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.max_nodes_per_graph = max_nodes_per_graph
        self.max_graph_size = max_graph_size
        self.num_model_voxels_per_dim = num_model_voxels_per_dim

    # =============== Public Interface Functions ===============

    # ------------------- Main Algorithm -----------------------

    # ------------------ Getters & Setters ---------------------

    def get_initial_hypotheses(self):
        """Return the graph ids to use as initial hypotheses.

        Returns:
            The result of ``self.get_memory_ids()`` (presumably the ids of all
            graphs currently in memory).
        """
        return self.get_memory_ids()

    def get_rotation_features_at_all_nodes(self, graph_id, input_channel):
        """Get rotation features from all N nodes of a graph.

        Args:
            graph_id: ID of the graph to read from.
            input_channel: Input channel of the graph to read from.

        Returns:
            The "pose_vectors" features of all N nodes, reshaped to
            shape=(N, 3, 3).
        """
        all_node_r_features = self.get_features_at_node(
            graph_id,
            input_channel,
            self.get_graph_node_ids(graph_id, input_channel),
            feature_keys=["pose_vectors"],
        )
        node_directions = all_node_r_features["pose_vectors"]
        num_nodes = len(node_directions)
        # Pose vectors are stored flat per node (presumably 9 values each);
        # restore the per-node 3x3 layout.
        node_directions = node_directions.reshape((num_nodes, 3, 3))
        return node_directions

    # ======================= Private ==========================

    # ------------------- Main Algorithm -----------------------

    def _add_graph_to_memory(self, model, graph_id):
        """Add pretrained graph to memory.

        Channel models that are ``GraphObjectModel`` instances are converted
        to ``GridObjectModel`` via ``_initialize_model_with_graph`` (which calls
        ``set_graph``); any other channel model is stored as-is.

        Args:
            model: New model to be added to memory; a mapping from input
                channel to that channel's object model.
            graph_id: ID of graph that should be added.
        """
        # Any previous entry for graph_id is replaced by a fresh dict.
        self.models_in_memory[graph_id] = {}
        for input_channel in model.keys():
            channel_model = model[input_channel]
            try:
                if isinstance(channel_model, GraphObjectModel):
                    # When loading a model trained with a different LM, need to
                    # convert it to the GridObjectModel (with
                    # use_original_graph == True)
                    loaded_graph = channel_model._graph
                    channel_model = self._initialize_model_with_graph(
                        graph_id, loaded_graph
                    )
                # NOTE(review): this logs the whole multi-channel `model`
                # mapping, not just `channel_model` -- confirm intended.
                logger.info(f"Loaded {model} for {input_channel}")
                self.models_in_memory[graph_id][input_channel] = channel_model
            except GridTooSmallError:
                # Skip only this channel; other channels are still loaded.
                logger.info(
                    "Grid too small for given locations. Not adding to memory."
                )

    def _initialize_model_with_graph(self, graph_id, graph):
        """Create a GridObjectModel wrapping an already-built graph.

        Args:
            graph_id: ID used as the new model's ``object_id``.
            graph: Graph taken from a pretrained model, passed to
                ``set_graph``.

        Returns:
            A new GridObjectModel with ``use_original_graph`` set to True.
        """
        model = GridObjectModel(
            object_id=graph_id,
            max_nodes=self.max_nodes_per_graph,
            max_size=self.max_graph_size,
            num_voxels_per_dim=self.num_model_voxels_per_dim,
        )
        # Keep benchmark results constant by still using original graph for
        # matching when loading pretrained models.
        model.use_original_graph = True
        model.set_graph(graph)
        return model

    def _build_graph(self, locations, features, graph_id, input_channel):
        """Build a graph from a list of features at locations and add to memory.

        This initializes a new GridObjectModel and calls model.build_model. If
        the grid is too small for the locations, nothing is added to memory.

        Args:
            locations: List of x,y,z locations.
            features: List of features.
            graph_id: name of new graph.
            input_channel: Input channel under which the new model is stored
                in ``self.models_in_memory[graph_id]``.
        """
        logger.info("Adding a new graph to memory.")
        model = GridObjectModel(
            object_id=graph_id,
            max_nodes=self.max_nodes_per_graph,
            max_size=self.max_graph_size,
            num_voxels_per_dim=self.num_model_voxels_per_dim,
        )
        try:
            model.build_model(locations=locations, features=features)
            if graph_id not in self.models_in_memory:
                self.models_in_memory[graph_id] = {}
            self.models_in_memory[graph_id][input_channel] = model
            logger.info(f"Added new graph with id {graph_id} to memory.")
            logger.info(model)
        except GridTooSmallError:
            logger.info(
                "Grid too small for given locations. Not building a model "
                f"for {graph_id}"
            )

    def _extend_graph(
        self,
        locations,
        features,
        graph_id,
        input_channel,
        object_location_rel_body,
        location_rel_model,
        object_rotation,
        object_scale,
    ):
        """Add new observations into an existing graph.

        Args:
            locations: List of x,y,z locations.
            features: Features observed at the locations.
            graph_id: ID of the existing graph.
            input_channel: Input channel whose model in
                ``self.models_in_memory[graph_id]`` is updated.
            object_location_rel_body: location of the sensor in body reference
                frame
            location_rel_model: location of sensor in model reference frame
            object_rotation: rotation of the sensed object relative to the
                model
            object_scale: scale of the object relative to the model of it.
                Currently not forwarded to ``update_model`` (unused here).
        """
        logger.info(f"Updating existing graph for {graph_id}")
        try:
            self.models_in_memory[graph_id][input_channel].update_model(
                locations=locations,
                features=features,
                location_rel_model=location_rel_model,
                object_location_rel_body=object_location_rel_body,
                object_rotation=object_rotation,
            )
            logger.info(
                f"Extended graph {graph_id} with new points. New model:\n"
                f"{self.models_in_memory[graph_id]}"
            )
        except GridTooSmallError:
            logger.info("Grid too small for given locations. Not updating model.")