Source code for tbp.monty.frameworks.experiments.object_recognition_experiments
# Copyright 2025-2026 Thousand Brains Project
# Copyright 2023-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import logging

import torch

from tbp.monty.context import RuntimeContext
from tbp.monty.frameworks.environments.embodied_data import (
    SaccadeOnImageEnvironmentInterface,
)
from tbp.monty.frameworks.experiments.mode import ExperimentMode
from tbp.monty.frameworks.experiments.monty_experiment import (
    MontyExperiment,
)

__all__ = ["MontyGeneralizationExperiment", "MontyObjectRecognitionExperiment"]

logger = logging.getLogger(__name__)


class MontyObjectRecognitionExperiment(MontyExperiment):
    """Experiment customized for object-pose recognition with a single object.

    Adds additional logging of the target object and pose for each episode and
    specific terminal states for object recognition. It also adds code for
    handling a matching and an exploration phase during each episode when training.

    Note that this experiment assumes a particular model configuration in order
    for the show_observations method to work: a zoomed-out "view_finder"
    RGBA sensor and an up-close "patch" depth sensor.
    """

    def run_episode(self):
        """Run one episode, checking the terminal states of object recognition."""
        self.pre_episode()
        last_step = self.run_episode_steps()
        self.post_episode(last_step)

    def pre_episode(self):
        """Pre-episode hook.

        Passes the primary target object and the mapping from semantic IDs to labels
        to the Monty model for logging and reporting evaluation results.
        """
        if self.experiment_mode is ExperimentMode.TRAIN:
            logger.info(
                f"running train epoch {self.train_epochs} "
                f"train episode {self.train_episodes}"
            )
        else:
            logger.info(
                f"running eval epoch {self.eval_epochs} "
                f"eval episode {self.eval_episodes}"
            )
        self.reset_episode_rng()
        # TODO: Eventually it would be better to pass
        # self.env_interface.semantic_id_to_label via an "Observation" object, once
        # that is implemented, so that we can ensure this information is never
        # inappropriately accessed and used.
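        # For illustration only (the object labels below are hypothetical):
        # semantic_id_to_label is a plain mapping from integer semantic IDs to
        # object label strings, e.g. {1: "mug", 2: "bowl"}.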
        if hasattr(self.env_interface, "semantic_id_to_label"):
            # TODO: Fix invalid pre_episode signature call
            self.model.pre_episode(
                self.env_interface.primary_target,
                self.env_interface.semantic_id_to_label,
            )
        else:
            # TODO: Fix invalid pre_episode signature call
            self.model.pre_episode(self.env_interface.primary_target)
        self.env_interface.pre_episode(self.rng)

        self.max_steps = self.max_train_steps
        if self.experiment_mode is not ExperimentMode.TRAIN:
            self.max_steps = self.max_eval_steps

        self.logger_handler.pre_episode(self.logger_args)

        if self.show_sensor_output:
            self.live_plotter.initialize_online_plotting()

    def run_episode_steps(self):
        """Runs one episode of the experiment.

        At each step, observations are collected from the env_interface and either
        passed to the model or sent directly to the motor system. We also check if a
        terminal condition was reached at each step and increment step counters.

        Returns:
            The number of total steps taken in the episode.
        """
        step = 0
        ctx = RuntimeContext(rng=self.rng)
        while True:
            try:
                observations = self.env_interface.step(ctx, first=(step == 0))
            except StopIteration:
                # TODO: StopIteration is being thrown by NaiveScanPolicy to signal
                # episode termination. This is a holdover from when we used
                # iterators. However, this also abdicates control of the
                # experiment to the policy. We should find a better way to handle
                # this, so that the experiment can control the episode termination
                # fully. For example, we know how many steps the policy will take,
                # so the experiment can set max steps based on that knowledge
                # alone.
                self.model.set_is_done()
                return step

            if self.show_sensor_output:
                is_saccade_on_image_data_loader = isinstance(
                    self.env_interface, SaccadeOnImageEnvironmentInterface
                )
                self.live_plotter.show_observations(
                    *self.live_plotter.hardcoded_assumptions(observations, self.model),
                    step,
                    is_saccade_on_image_data_loader,
                )
            if self.model.check_reached_max_matching_steps(self.max_steps):
                logger.info(
                    f"Terminated due to maximum matching steps: {self.max_steps}"
                )
                # Need to return here already, otherwise there are problems
                # when the object is recognized in the last step.
                return step

            if step >= self.max_total_steps:
                logger.info(f"Terminated due to maximum episode steps: {step}")
                self.model.deal_with_time_out()
                return step
            if self.model.is_motor_only_step:
                logger.debug(
                    "Performing a motor-only step, so passing info straight to motor"
                )
                # On these sensations, we just want to pass information to the motor
                # system, so bypass the main model step (i.e. updating of LMs)
                self.model.pass_features_directly_to_motor_system(ctx, observations)
            else:
                self.model.step(ctx, observations)
                if self.model.is_done:
                    # Check this right after step to avoid setting time out
                    # after object was already recognized.
                    return step

            step += 1
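

# A minimal usage sketch for MontyObjectRecognitionExperiment. It assumes the
# MontyExperiment base class provides setup_experiment()/train()/evaluate()
# entry points, which are defined elsewhere and not shown in this module:
#
#   experiment = MontyObjectRecognitionExperiment()
#   experiment.setup_experiment(config)  # `config` is a hypothetical experiment config
#   experiment.train()     # each training episode goes through run_episode() above
#   experiment.evaluate()  # evaluation episodes use max_eval_steps instead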


class MontyGeneralizationExperiment(MontyObjectRecognitionExperiment):
    """Remove the tested object model from memory to see what is recognized instead."""

    def pre_episode(self):
        """Pre-episode hook that reloads the model and removes the target's graph."""
        model_path = self.model_path
        if "model.pt" not in model_path.parts:
            model_path = model_path / "model.pt"
        state_dict = torch.load(model_path)
        print(f"loading models again from {model_path}")
        self.model.load_state_dict(state_dict)
        super().pre_episode()

        target_object = self.env_interface.primary_target["object"]
        print(f"removing {target_object}")
        for lm in self.model.learning_modules:
            lm.graph_memory.remove_graph_from_memory(target_object)
            print(f"graphs in memory: {lm.get_all_known_object_ids()}")