# Source code for tbp.monty.frameworks.models.salience.sensor_module

# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations

from typing import Any

import numpy as np

from tbp.monty.frameworks.models.abstract_monty_classes import SensorModule
from tbp.monty.frameworks.models.salience.on_object_observation import (
    on_object_observation,
)
from tbp.monty.frameworks.models.salience.return_inhibitor import ReturnInhibitor
from tbp.monty.frameworks.models.salience.strategies import (
    SalienceStrategy,
    UniformSalienceStrategy,
)
from tbp.monty.frameworks.models.sensor_modules import SnapshotTelemetry
from tbp.monty.frameworks.models.states import GoalState, State


class HabitatSalienceSM(SensorModule):
    """Sensor module that proposes salience-driven goal states.

    Each step, a salience map is computed from the raw RGBA/depth
    observation, restricted to on-object locations, down-weighted at
    previously visited locations (inhibition of return), perturbed with a
    small amount of noise to break ties, and min-max normalized. The
    resulting per-location salience values are emitted as ``GoalState``
    proposals retrievable via :meth:`propose_goal_states`.
    """

    # Fraction of salience suppressed at inhibited (recently visited)
    # locations; see _decay_salience.
    _DECAY_FACTOR = 0.75
    # Standard deviation of the Gaussian tie-breaking noise added to
    # salience values; see _randomize_salience.
    _RANDOMNESS_FACTOR = 0.05

    def __init__(
        self,
        rng,
        sensor_module_id: str,
        save_raw_obs: bool = False,
        salience_strategy_class: type[SalienceStrategy] = UniformSalienceStrategy,
        salience_strategy_args: dict[str, Any] | None = None,
        return_inhibitor_class: type[ReturnInhibitor] = ReturnInhibitor,
        return_inhibitor_args: dict[str, Any] | None = None,
        snapshot_telemetry_class: type[SnapshotTelemetry] = SnapshotTelemetry,
    ) -> None:
        """Initialize the sensor module.

        Args:
            rng: Random number generator used for the tie-breaking noise
                (expects a ``numpy.random.Generator``-like ``normal``
                method).
            sensor_module_id: Identifier attached to emitted goal states.
            save_raw_obs: Whether to record raw observations for telemetry.
            salience_strategy_class: Strategy used to compute the salience
                map from raw observations.
            salience_strategy_args: Keyword arguments for
                ``salience_strategy_class``.
            return_inhibitor_class: Inhibition-of-return implementation.
            return_inhibitor_args: Keyword arguments for
                ``return_inhibitor_class``.
            snapshot_telemetry_class: Telemetry recorder class.
        """
        self._rng = rng
        self._sensor_module_id = sensor_module_id
        self._save_raw_obs = save_raw_obs
        # Copy caller-supplied kwarg dicts so later mutations don't leak
        # in either direction.
        salience_strategy_args = (
            dict(salience_strategy_args) if salience_strategy_args else {}
        )
        self._salience_strategy = salience_strategy_class(**salience_strategy_args)
        return_inhibitor_args = (
            dict(return_inhibitor_args) if return_inhibitor_args else {}
        )
        self._return_inhibitor = return_inhibitor_class(**return_inhibitor_args)
        self._goals: list[GoalState] = []
        self._snapshot_telemetry = snapshot_telemetry_class()
        # TODO: Goes away once experiment code is extracted
        self.is_exploring = False

    def state_dict(self):
        """Return the telemetry state dictionary for this module."""
        return self._snapshot_telemetry.state_dict()

    def update_state(self, state):
        """Update the state of the sensor module."""
        self.state = state

    def step(self, data) -> State | None:
        """Generate goal states for the current step.

        Args:
            data: Raw sensor observations

        Returns:
            A Percept, if one is generated.
        """
        if self._save_raw_obs and not self.is_exploring:
            self._snapshot_telemetry.raw_observation(
                data,
                self.state["rotation"],
                # Agent state may expose either key depending on the
                # sensor configuration.
                self.state["location"]
                if "location" in self.state
                else self.state["position"],
            )
        salience_map = self._salience_strategy(rgba=data["rgba"], depth=data["depth"])
        on_object = on_object_observation(data, salience_map)
        ior_weights = self._return_inhibitor(
            on_object.center_location, on_object.locations
        )
        salience = self._weight_salience(on_object.salience, ior_weights)
        self._goals = [
            GoalState(
                location=location,
                morphological_features=None,
                non_morphological_features=None,
                confidence=confidence,
                use_state=True,
                sender_id=self._sensor_module_id,
                sender_type="SM",
                goal_tolerances=None,
            )
            for location, confidence in zip(on_object.locations, salience)
        ]
        return None

    def _weight_salience(
        self,
        salience: np.ndarray,
        ior_weights: np.ndarray,
    ) -> np.ndarray:
        """Combine raw salience with inhibition-of-return weights.

        Applies decay at inhibited locations, adds tie-breaking noise,
        and rescales the result to [0, 1].
        """
        weighted_salience = self._decay_salience(salience, ior_weights)
        weighted_salience = self._randomize_salience(weighted_salience)
        return self._normalize_salience(weighted_salience)

    def _decay_salience(
        self, salience: np.ndarray, ior_weights: np.ndarray
    ) -> np.ndarray:
        """Suppress salience in proportion to inhibition-of-return weights."""
        return salience - self._DECAY_FACTOR * ior_weights

    def _randomize_salience(self, weighted_salience: np.ndarray) -> np.ndarray:
        """Add Gaussian noise so equally salient locations are not always
        selected in the same order.
        """
        noise = self._rng.normal(
            loc=0, scale=self._RANDOMNESS_FACTOR, size=weighted_salience.shape
        )
        # Out-of-place addition avoids mutating the caller's array.
        return weighted_salience + noise

    def _normalize_salience(self, weighted_salience: np.ndarray) -> np.ndarray:
        """Min-max normalize salience values to the range [0, 1].

        Empty inputs are returned unchanged. Constant inputs have a
        degenerate scale, so they are clipped to [0, 1] instead of
        divided by ~zero.
        """
        if weighted_salience.size == 0:
            return weighted_salience
        min_ = weighted_salience.min()
        max_ = weighted_salience.max()
        scale = max_ - min_
        if np.isclose(scale, 0):
            return np.clip(weighted_salience, 0, 1)
        return (weighted_salience - min_) / scale

    def pre_episode(self):
        """Reset per-episode state before each episode."""
        self._goals.clear()
        self._return_inhibitor.reset()
        self._snapshot_telemetry.reset()
        self.is_exploring = False

    def propose_goal_states(self) -> list[GoalState]:
        """Return the goal states generated by the most recent step."""
        return self._goals