Source code for tbp.monty.frameworks.experiments.object_recognition_experiments

# Copyright 2025 Thousand Brains Project
# Copyright 2023-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import logging
import os

import torch

from tbp.monty.frameworks.environments.embodied_data import (
    SaccadeOnImageEnvironmentInterface,
)

from .monty_experiment import MontyExperiment

logger = logging.getLogger(__name__)


[docs]class MontyObjectRecognitionExperiment(MontyExperiment): """Experiment customized for object-pose recognition with a single object. Adds additional logging of the target object and pose for each episode and specific terminal states for object recognition. It also adds code for handling a matching and an exploration phase during each episode when training. Note that this experiment assumes a particular model configuration, in order for the show_observations method to work: a zoomed out "view_finder" rgba sensor and an up-close "patch" depth sensor """
[docs] def run_episode(self): """Episode that checks the terminal states of an object recognition episode.""" self.pre_episode() last_step = self.run_episode_steps() self.post_episode(last_step)
[docs] def pre_episode(self): """Pre-episode hook. Pre episode where we pass the primary target object, as well as the mapping between semantic ID to labels, both for logging/evaluation purposes. """ # TODO, eventually it would be better to pass # self.env_interface.semantic_id_to_label via an "Observation" object when this # is eventually implemented, such that we can ensure this information is never # inappropriately accessed and used if hasattr(self.env_interface, "semantic_id_to_label"): self.model.pre_episode( self.env_interface.primary_target, self.env_interface.semantic_id_to_label, ) else: self.model.pre_episode(self.env_interface.primary_target) self.env_interface.pre_episode() self.max_steps = self.max_train_steps if self.model.experiment_mode != "train": self.max_steps = self.max_eval_steps self.logger_handler.pre_episode(self.logger_args) if self.show_sensor_output: self.live_plotter.initialize_online_plotting()
[docs] def run_episode_steps(self): """Runs one episode of the experiment. At each step, observations are collected from the env_interface and either passed to the model or sent directly to the motor system. We also check if a terminal condition was reached at each step and increment step counters. Returns: The number of total steps taken in the episode. """ for loader_step, observation in enumerate(self.env_interface): if self.show_sensor_output: is_saccade_on_image_data_loader = isinstance( self.env_interface, SaccadeOnImageEnvironmentInterface ) self.live_plotter.show_observations( *self.live_plotter.hardcoded_assumptions(observation, self.model), loader_step, is_saccade_on_image_data_loader, ) if self.model.check_reached_max_matching_steps(self.max_steps): logger.info( f"Terminated due to maximum matching steps : {self.max_steps}" ) # Need to break here already, otherwise there are problems # when the object is recognized in the last step return loader_step if loader_step >= (self.max_total_steps): logger.info(f"Terminated due to maximum episode steps : {loader_step}") self.model.deal_with_time_out() return loader_step if self.model.is_motor_only_step: logger.debug( "Performing a motor-only step, so passing info straight to motor" ) # On these sensations, we just want to pass information to the motor # system, so bypass the main model step (i.e. updating of LMs) self.model.pass_features_directly_to_motor_system(observation) else: self.model.step(observation) if self.model.is_done: # Check this right after step to avoid setting time out # after object was already recognized. return loader_step # handle case where spiral policy calls StopIterator in motor policy self.model.set_is_done() return loader_step
[docs]class MontyGeneralizationExperiment(MontyObjectRecognitionExperiment): """Remove the tested object model from memory to see what is recognized instead."""
[docs] def pre_episode(self): """Pre episode where we pass target object to the model for logging.""" if "model.pt" not in self.model_path: model_path = os.path.join(self.model_path, "model.pt") state_dict = torch.load(model_path) print(f"loading models again from {model_path}") self.model.load_state_dict(state_dict) super().pre_episode() target_object = self.env_interface.primary_target["object"] print(f"removing {target_object}") for lm in self.model.learning_modules: lm.graph_memory.remove_graph_from_memory(target_object) print(f"graphs in memory: {lm.get_all_known_object_ids()}")