# Source code for tbp.monty.frameworks.experiments.data_collection_experiments

# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import os

import torch
from tqdm import tqdm

from .object_recognition_experiments import MontyObjectRecognitionExperiment


class DataCollectionExperiment(MontyObjectRecognitionExperiment):
    """Collect observations in an environment without running inference.

    A stripped-down experiment that explores points on an object and saves
    JUST the resulting observations as a .pt file. The saved data can then be
    used offline to quickly test other, non-Monty methods (such as ICP). This
    is mostly useful for methods that need batches of observations and do not
    work with inference through movement over the object; otherwise it is
    recommended to implement approaches directly in the Monty framework rather
    than using offline data.
    """

    def run_episode(self):
        """Run a single data-collection episode.

        Steps the environment interface up to ``self.max_steps`` times,
        optionally rendering each observation, and feeds every observation
        through :meth:`pass_features_to_motor_system`.
        """
        self.pre_episode()
        for step, observation in tqdm(enumerate(self.env_interface)):
            if step > self.max_steps:
                break
            if self.show_sensor_output:
                plotter = self.live_plotter
                plotter.show_observations(
                    *plotter.hardcoded_assumptions(observation, self.model),
                    step,
                )
            self.pass_features_to_motor_system(observation, step)
        self.post_episode()

    def pass_features_to_motor_system(self, observation, step):
        """Process one observation, annotate it, and hand it to the policy.

        Annotates the most recent processed observation with the target
        object name and the action that produced it, then prunes the
        previous observation unless it directly preceded a
        ``move_tangentially`` action.
        """
        self.model.aggregate_sensory_inputs(observation)
        policy = self.model.motor_system._policy
        policy.processed_observations = self.model.sensor_module_outputs[0]

        # Add the object and action to the observation dict.
        obs_log = self.model.sensor_modules[0].processed_obs
        obs_log[-1]["object"] = self.env_interface.primary_target["object"]
        action = policy.action
        if action is None:
            obs_log[-1]["action"] = None
        else:
            obs_log[-1]["action"] = f"{action.agent_id}.{action.name}"

        # Only keep observations coming right before a move_tangentially
        # action; everything else from the previous step is discarded.
        keep_previous = action is not None and action.name == "move_tangentially"
        if step > 0 and not keep_previous:
            del obs_log[-2]

    def pre_episode(self):
        """Set up the model, environment, and logging before an episode."""
        self.model.pre_episode()
        self.env_interface.pre_episode()
        self.max_steps = self.max_train_steps
        self.logger_handler.pre_episode(self.logger_args)
        if self.show_sensor_output:
            self.live_plotter.initialize_online_plotting()

    def post_episode(self):
        """Save the collected observations to disk and advance the counter."""
        save_path = os.path.join(
            self.output_dir, f"observations{self.train_episodes}.pt"
        )
        torch.save(self.model.sensor_modules[0].processed_obs[:-1], save_path)
        self.env_interface.post_episode()
        self.train_episodes += 1

    def post_epoch(self):
        # This stripped-down experiment only allows for one pass.
        pass