# Source code for tbp.monty.frameworks.models.no_reset_evidence_matching
# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from typing import List
import numpy as np
from tbp.monty.frameworks.models.evidence_matching import (
EvidenceGraphLM,
MontyForEvidenceGraphMatching,
)
from tbp.monty.frameworks.models.mixins.no_reset_evidence import (
TheoreticalLimitLMLoggingMixin,
)
from tbp.monty.frameworks.models.states import State
class MontyForNoResetEvidenceGraphMatching(MontyForEvidenceGraphMatching):
    """Monty class for unsupervised inference without explicit episode resets.

    A variant of `MontyForEvidenceGraphMatching` meant for unsupervised
    inference experiments in which the sensed object can change at any time and
    no reset signal announces the change. In contrast to standard experiments,
    Monty's internal state (e.g., hypothesis space, evidence scores) is kept
    from one episode to the next.

    This mirrors real-world conditions more closely: object boundaries are
    ambiguous and there is no supervisory signal telling the system that a new
    object has appeared. Only minimal state — such as step counters and
    termination flags — is reset, and module buffers are emptied so that they do
    not keep growing across objects. In addition, Monty is currently forced back
    into the matching step at every episode start. Unsupervised inference is
    evaluated over a fixed number of matching steps per object.

    *Intended for evaluation-only runs using pre-trained models, with Monty
    remaining in the matching phase throughout.*
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Records whether `pre_episode` has already been executed once. The
        # very first call has to go through the parent's full `pre_episode`,
        # for two independent reasons:
        #
        # 1. A number of SM and LM attributes (e.g., `stepwise_targets_list`,
        #    `terminal_state`, `is_exploring`, `visited_locs`) are created only
        #    inside `pre_episode` and not in `__init__`. Ideally they would be
        #    created in `__init__` and merely reset in `pre_episode`, but that
        #    would mean touching several classes.
        #
        # 2. Ordering: graphs are loaded into the LMs *after* the Monty object
        #    has been constructed but *before* `pre_episode` is called. Some
        #    functions (e.g., in `EvidenceGraphLM`) need the loaded graph to
        #    compute the initial possible matches inside `pre_episode`, so that
        #    work cannot safely be moved into `__init__`.
        #
        # The workaround: let `pre_episode` run in full exactly once (to finish
        # the required initialization) and skip the full reset afterwards.
        # TODO: Remove initialization logic from `pre_episode`
        self.init_pre_episode = False

    def pre_episode(self, primary_target, semantic_id_to_label=None):
        """Prepare Monty for the next episode without wiping inference state.

        The first invocation delegates to the parent implementation. Every
        later invocation only clears the done flag, resets the episode step
        counters, switches to the matching step, refreshes the logging targets
        and empties the module buffers.

        Args:
            primary_target: Target description of the upcoming episode. On the
                lightweight path its `"object"` and `"quat_rotation"` entries
                are copied onto every learning module.
            semantic_id_to_label: Optional mapping stored on the instance.
                Defaults to None.

        Returns:
            Whatever the parent's `pre_episode` returns on the first call, and
            None on all subsequent calls.
        """
        if self.init_pre_episode:
            # Lightweight path: initialization already happened earlier.
            # Clear the terminal state and restart the per-episode counters.
            self._is_done = False
            self.reset_episode_steps()
            self.switch_to_matching_step()

            # Targets are refreshed purely so that logging stays correct.
            self.primary_target = primary_target
            self.semantic_id_to_label = semantic_id_to_label
            for module in self.learning_modules:
                module.primary_target = primary_target["object"]
                module.primary_target_rotation_quat = primary_target["quat_rotation"]

            # Empty the LM and SM buffers to save memory.
            self._reset_modules_buffers()
            return None

        # First call: flip the flag, then run the parent's full initialization.
        self.init_pre_episode = True
        return super().pre_episode(primary_target, semantic_id_to_label)

    def _reset_modules_buffers(self):
        """Resets buffers for LMs and SMs."""
        for learning_module in self.learning_modules:
            learning_module.buffer.reset()

        # Each sensor module receives its own fresh, empty list per attribute.
        sm_buffer_names = ("raw_observations", "sm_properties", "processed_obs")
        for sensor_module in self.sensor_modules:
            for buffer_name in sm_buffer_names:
                setattr(sensor_module, buffer_name, [])
class NoResetEvidenceGraphLM(TheoreticalLimitLMLoggingMixin, EvidenceGraphLM):
    """Evidence graph LM for inference without reset signals between objects.

    The module keeps, per input channel (keyed by `sender_id`), the most recent
    location it observed in `last_location`. Displacements are derived from
    these stored locations, and their presence is also used to decide whether
    the agent has moved since the last reset.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Most recent location per input channel, keyed by `sender_id`.
        self.last_location = {}
        # it does not make sense for the wait factor to exponentially
        # grow when objects are swapped without any supervisory signal.
        self.gsg.wait_growth_multiplier = 1

    def reset(self) -> None:
        """Reset the LM, then clear the evidence and the stored locations.

        Runs the parent's `reset` first and afterwards replaces both
        `evidence` and `last_location` with empty dictionaries.
        """
        super().reset()
        self.evidence = {}
        self.last_location = {}

    def _add_displacements(self, obs: List[State]) -> List[State]:
        """Add displacements to the current observation.

        For each input channel, the displacement vector is the current
        location minus the last location observed on that channel. The current
        location is then stored in `self.last_location` for use in the next
        step.

        In this unsupervised inference setting, the displacement is set to a
        zero vector of length 3 whenever no previous location is stored for
        the channel, i.e., at the beginning of the first episode or right
        after `reset` has cleared `self.last_location`.

        Args:
            obs (List[State]): A list of observations to which displacements
                will be added. The observations are updated in place.

        Returns:
            obs (List[State]): The same list of observations, each updated
                with a displacement vector.
        """
        for o in obs:
            if o.sender_id in self.last_location:
                displacement = o.location - self.last_location[o.sender_id]
            else:
                # No earlier location on this channel: nothing to move from.
                displacement = np.zeros(3)
            o.set_displacement(displacement)
            self.last_location[o.sender_id] = o.location
        return obs

    def _agent_moved_since_reset(self) -> bool:
        """Overwrites the logic of whether the agent has moved since the last reset.

        In unsupervised inference, the first movement is detected on the first
        episode only. If any `last_location` entry exists, the first movement
        is considered to have occurred.

        Returns:
            Whether the agent has moved since the last reset, i.e., whether
            `self.last_location` is non-empty.
        """
        return bool(self.last_location)