Source code for tbp.monty.frameworks.models.motor_policy_selectors

# Copyright 2026 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol

from tbp.monty.cmp import Goal, Message
from tbp.monty.context import RuntimeContext
from tbp.monty.frameworks.models.abstract_monty_classes import Observations
from tbp.monty.frameworks.models.motor_policies import (
    JumpToGoal,
    MotorPolicy,
    MotorPolicyResult,
    PolicyStatus,
)
from tbp.monty.frameworks.models.motor_system_state import MotorSystemState

if TYPE_CHECKING:
    from tbp.monty.frameworks.models.motor_system import MotorSystem
    from tbp.monty.frameworks.models.salience.motor_policy import LookAtGoal


# Public API of this module (names exported by ``from ... import *``).
__all__ = ["MotorPolicySelector", "SinglePolicySelector", "highest_confidence_goal"]


def highest_confidence_goal(goals: list[Goal]) -> Goal:
    """Return the goal with the highest confidence.

    If there are multiple goals with the same confidence, returns the first one.

    Args:
        goals: A list of goals. Must be non-empty.

    Returns:
        The goal with the highest confidence.

    Raises:
        ValueError: If ``goals`` is empty.
    """
    if not goals:
        raise ValueError("highest_confidence_goal requires a non-empty list of goals")
    # max() keeps the first maximal element it encounters, which gives the
    # documented "first one wins" tie-breaking in a single O(n) pass instead
    # of sorting the whole list just to read its head.
    return max(goals, key=lambda goal: goal.confidence)
class MotorPolicySelector(Protocol):
    """Structural interface for objects that choose the next motor actions.

    A selector is invoked once per step with the current observations, motor
    system state, percept and candidate goals, and returns a
    ``MotorPolicyResult``.
    """

    def pre_episode(self, motor_system: MotorSystem) -> None:
        """Hook invoked with the motor system before an episode starts.

        Implementations in this module forward the call to the policies they
        wrap and clear their per-episode bookkeeping.
        """
        ...

    def state_dict(self) -> dict[str, Any]:
        """Return a dictionary describing the selector's current state."""
        ...

    def __call__(
        self,
        ctx: RuntimeContext,
        observations: Observations,
        state: MotorSystemState,
        percept: Message,
        goals: list[Goal],
    ) -> MotorPolicyResult:
        """Produce the result holding the next actions to take.

        Args:
            ctx: The runtime context.
            observations: The observations from the environment.
            state: The current state of the motor system.
            percept: The percept from (as of this writing) the first sensor
                module.
            goals: The candidate goals to choose among.

        Returns:
            A MotorPolicyResult that contains the actions to take.
        """
        ...
class SinglePolicySelector(MotorPolicySelector):
    """Selector that always delegates to one wrapped motor policy.

    On every call the highest-confidence goal (or ``None`` when no goals are
    supplied) is recorded and handed to the wrapped policy.
    """

    def __init__(self, policy: MotorPolicy):
        """Wrap ``policy`` as the only policy this selector will run.

        Args:
            policy: The motor policy invoked on every step.
        """
        self._policy = policy
        # TODO: Get rid of this once we have another path for telemetry.
        self._selected_goals: list[Goal | None] = []

    def pre_episode(self, motor_system: MotorSystem) -> None:
        """Forward ``pre_episode`` to the policy and start a fresh goal log."""
        self._policy.pre_episode(motor_system)
        self._selected_goals = []

    def state_dict(self) -> dict[str, Any]:
        """Return the wrapped policy's state plus the goals selected so far.

        Note that ``"selected_goals"`` is the live internal list, not a copy.
        """
        snapshot: dict[str, Any] = {}
        snapshot["policy"] = self._policy.state_dict()
        snapshot["selected_goals"] = self._selected_goals
        return snapshot

    def __call__(
        self,
        ctx: RuntimeContext,
        observations: Observations,
        state: MotorSystemState,
        percept: Message,
        goals: list[Goal],
    ) -> MotorPolicyResult:
        """Run the wrapped policy toward the highest-confidence goal.

        Args:
            ctx: The runtime context.
            observations: The observations from the environment.
            state: The current state of the motor system.
            percept: The percept from (as of this writing) the first sensor
                module.
            goals: Candidate goals; may be empty.

        Returns:
            The policy's result, or an empty ``MotorPolicyResult`` when the
            policy returns ``None``.
        """
        if goals:
            chosen = highest_confidence_goal(goals)
        else:
            chosen = None
        self._selected_goals.append(chosen)
        outcome = self._policy(ctx, observations, state, percept, chosen)
        # Normalize a missing result so callers always get a result object.
        return MotorPolicyResult([]) if outcome is None else outcome
class DistantPolicySelector(MotorPolicySelector):
    """Choose between jumping, looking, and a default policy on each step.

    Priority on every call:

    1. While a previous jump is still in progress, keep driving the jump
       policy, passing the highest-confidence ``"GSG"`` goal (or ``None``).
    2. Otherwise, if any goal has ``sender_type == "GSG"``, jump to the
       highest-confidence one.
    3. Otherwise, if any goal has ``sender_type == "SM"``, look at the
       highest-confidence one.
    4. Otherwise, run the default policy with no goal.

    A ``None`` returned by any policy is treated as an empty
    ``MotorPolicyResult``, matching ``SinglePolicySelector``.
    """

    def __init__(
        self,
        jump_to_goal: JumpToGoal,
        look_at_goal: LookAtGoal,
        default: MotorPolicy,
    ):
        """Store the three policies this selector arbitrates between.

        Args:
            jump_to_goal: Policy used for ``"GSG"`` goals and ongoing jumps.
            look_at_goal: Policy used for ``"SM"`` goals.
            default: Policy used when no applicable goal is present.
        """
        # policies
        self._jump_to_goal = jump_to_goal
        self._look_at_goal = look_at_goal
        self._default = default
        # state: True while the last jump result reported IN_PROGRESS.
        self._is_jumping = False
        # telemetry
        # NOTE(review): these lists are recorded but not returned by
        # state_dict(); confirm where they are meant to be consumed.
        self._selected_policies: list[MotorPolicy] = []
        self._selected_goals: list[Goal | None] = []

    def pre_episode(self, motor_system: MotorSystem) -> None:
        """Forward ``pre_episode`` to every policy and reset selector state."""
        self._jump_to_goal.pre_episode(motor_system)
        self._look_at_goal.pre_episode(motor_system)
        self._default.pre_episode(motor_system)
        self._is_jumping = False
        self._selected_policies = []
        self._selected_goals = []

    def state_dict(self) -> dict[str, Any]:
        """Return the state dicts of the three wrapped policies."""
        return {
            "jump_to_goal": self._jump_to_goal.state_dict(),
            "look_at_goal": self._look_at_goal.state_dict(),
            "default": self._default.state_dict(),
        }

    def __call__(
        self,
        ctx: RuntimeContext,
        observations: Observations,
        state: MotorSystemState,
        percept: Message,
        goals: list[Goal],
    ) -> MotorPolicyResult:
        """Select a policy according to the class priority rules and run it.

        Args:
            ctx: The runtime context.
            observations: The observations from the environment.
            state: The current state of the motor system.
            percept: The percept from (as of this writing) the first sensor
                module.
            goals: Candidate goals; filtered by ``sender_type``.

        Returns:
            The selected policy's result (never ``None``).
        """
        gsg_goals = [g for g in goals if g.sender_type == "GSG"]

        # Handle possibly undoing a jump or jumping to a new LM GSG goal.
        if self._is_jumping:
            goal = highest_confidence_goal(gsg_goals) if gsg_goals else None
            result = self._invoke(
                self._jump_to_goal, ctx, observations, state, percept, goal
            )
            self._is_jumping = result.status == PolicyStatus.IN_PROGRESS
            # Only log steps where the jump policy actually produced actions.
            if result.actions:
                self._update_telemetry(policy=self._jump_to_goal, goal=goal)
            return result

        # Handle jumping to an LM GSG's goal.
        if gsg_goals:
            goal = highest_confidence_goal(gsg_goals)
            result = self._invoke(
                self._jump_to_goal, ctx, observations, state, percept, goal
            )
            self._is_jumping = result.status == PolicyStatus.IN_PROGRESS
            self._update_telemetry(policy=self._jump_to_goal, goal=goal)
            return result

        # Handle looking at an SM's goal.
        sm_goals = [g for g in goals if g.sender_type == "SM"]
        if sm_goals:
            goal = highest_confidence_goal(sm_goals)
            result = self._invoke(
                self._look_at_goal, ctx, observations, state, percept, goal
            )
            self._is_jumping = False
            self._update_telemetry(policy=self._look_at_goal, goal=goal)
            return result

        # Fall back to the default policy.
        result = self._invoke(self._default, ctx, observations, state, percept, None)
        self._is_jumping = False
        self._update_telemetry(policy=self._default, goal=None)
        return result

    @staticmethod
    def _invoke(
        policy: MotorPolicy,
        ctx: RuntimeContext,
        observations: Observations,
        state: MotorSystemState,
        percept: Message,
        goal: Goal | None,
    ) -> MotorPolicyResult:
        """Call ``policy`` and coerce a ``None`` result into an empty result."""
        result = policy(ctx, observations, state, percept, goal)
        # Without this, a None result would be returned to the caller as-is,
        # or raise AttributeError on ``.status`` in the jump branches.
        return MotorPolicyResult([]) if result is None else result

    def _update_telemetry(
        self,
        policy: MotorPolicy,
        goal: Goal | None,
    ) -> None:
        """Record which policy ran and with which goal on this step."""
        self._selected_policies.append(policy)
        self._selected_goals.append(goal)