Source code for tbp.monty.frameworks.actions.actions

# Copyright 2025 Thousand Brains Project
# Copyright 2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from __future__ import annotations

from json import JSONDecoder, JSONEncoder
from typing import Any, Generator, Tuple

from typing_extensions import (
    Protocol,  # Enables default __init__ in Protocol classes
    runtime_checkable,  # For JSONEncoder instance checks
)

__all__ = [
    # Actions
    "Action",
    "LookDown",
    "LookDownActionSampler",
    "LookDownActuator",
    "LookUp",
    "LookUpActionSampler",
    "LookUpActuator",
    "MoveForward",
    "MoveForwardActionSampler",
    "MoveForwardActuator",
    "MoveTangentially",
    "MoveTangentiallyActionSampler",
    "MoveTangentiallyActuator",
    "OrientHorizontal",
    "OrientHorizontalActionSampler",
    "OrientHorizontalActuator",
    "OrientVertical",
    "OrientVerticalActionSampler",
    "OrientVerticalActuator",
    "SetAgentPitch",
    "SetAgentPitchActionSampler",
    "SetAgentPitchActuator",
    "SetAgentPose",
    "SetAgentPoseActionSampler",
    "SetAgentPoseActuator",
    "SetSensorPitch",
    "SetSensorPitchActionSampler",
    "SetSensorPitchActuator",
    "SetSensorPose",
    "SetSensorPoseActionSampler",
    "SetSensorPoseActuator",
    "SetSensorRotation",
    "SetSensorRotationActionSampler",
    "SetSensorRotationActuator",
    "SetYaw",
    "SetYawActionSampler",
    "SetYawActuator",
    "TurnLeft",
    "TurnLeftActionSampler",
    "TurnLeftActuator",
    "TurnRight",
    "TurnRightActionSampler",
    "TurnRightActuator",
    # Serialization (public classes; listed so star-imports and API docs
    # include them alongside the actions they encode/decode)
    "ActionJSONDecoder",
    "ActionJSONEncoder",
    # Spatial representations
    "QuaternionWXYZ",
    "VectorXYZ",
]

# Type aliases for the plain-tuple spatial values carried by actions.

# Three float components ordered (x, y, z).
VectorXYZ = Tuple[float, float, float]

# Four float components of a quaternion, ordered (w, x, y, z).
QuaternionWXYZ = Tuple[float, float, float, float]


[docs]@runtime_checkable class Action(Protocol): """An action that can be taken by an agent. Actions are generated by the MotorSystem and are executed by an Actuator. """ agent_id: str """The ID of the agent that will take this action.""" @staticmethod def _camel_case_to_snake_case(name: str) -> str: """Expecting a class name in CamelCase returns it in snake_case. Returns: The class name in snake_case. """ return "".join( ["_" + char.lower() if char.isupper() else char for char in name] ).lstrip("_")
[docs] @classmethod def action_name(cls) -> str: """Generate action name based on class. Used in static configuration, e.g., `FakeAction.action_name()`. Returns: The action name in snake_case. """ return Action._camel_case_to_snake_case(cls.__name__)
def __init__(self, agent_id: str) -> None: """Initialize the action with the agent ID. Args: agent_id: The ID of the agent that will take this action. """ self.agent_id = agent_id @property def name(self) -> str: """Used for checking action name on an Action instance.""" return self.__class__.action_name() def __iter__(self) -> Generator[tuple[str, Any]]: """Yields the action name and all action parameters. Useful if you want to do something like: `dict(action_instance)`. """ yield "action", self.name for key, value in self.__dict__.items(): if key == "name": continue yield key, value
class LookDownActionSampler(Protocol):
    """Anything able to produce a LookDown action for an agent."""

    def sample_look_down(self, agent_id: str) -> LookDown:
        """Return a LookDown action for the agent ``agent_id``."""


class LookDownActuator(Protocol):
    """Anything able to execute a LookDown action."""

    def actuate_look_down(self, action: LookDown) -> None:
        """Carry out ``action``."""


class LookDown(Action):
    """Rotate the agent downwards by a specified number of degrees."""

    @staticmethod
    def sample(agent_id: str, sampler: LookDownActionSampler) -> LookDown:
        """Delegate creation of a LookDown action to ``sampler``.

        Returns:
            The result of ``sampler.sample_look_down(agent_id)``.
        """
        sampled = sampler.sample_look_down(agent_id)
        return sampled

    def __init__(
        self,
        agent_id: str,
        rotation_degrees: float,
        constraint_degrees: float = 90.0,
    ) -> None:
        """Create a LookDown action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: How far to rotate downwards, in degrees.
            constraint_degrees: Constraint value in degrees, stored for the
                actuator to interpret (defaults to 90.0).
        """
        super().__init__(agent_id=agent_id)
        self.constraint_degrees = constraint_degrees
        self.rotation_degrees = rotation_degrees

    def act(self, actuator: LookDownActuator) -> None:
        """Hand this action to ``actuator.actuate_look_down``."""
        actuator.actuate_look_down(self)
class LookUpActionSampler(Protocol):
    """Anything able to produce a LookUp action for an agent."""

    def sample_look_up(self, agent_id: str) -> LookUp:
        """Return a LookUp action for the agent ``agent_id``."""


class LookUpActuator(Protocol):
    """Anything able to execute a LookUp action."""

    def actuate_look_up(self, action: LookUp) -> None:
        """Carry out ``action``."""


class LookUp(Action):
    """Rotate the agent upwards by a specified number of degrees."""

    @staticmethod
    def sample(agent_id: str, sampler: LookUpActionSampler) -> LookUp:
        """Delegate creation of a LookUp action to ``sampler``.

        Returns:
            The result of ``sampler.sample_look_up(agent_id)``.
        """
        sampled = sampler.sample_look_up(agent_id)
        return sampled

    def __init__(
        self,
        agent_id: str,
        rotation_degrees: float,
        constraint_degrees: float = 90.0,
    ) -> None:
        """Create a LookUp action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: How far to rotate upwards, in degrees.
            constraint_degrees: Constraint value in degrees, stored for the
                actuator to interpret (defaults to 90.0).
        """
        super().__init__(agent_id=agent_id)
        self.constraint_degrees = constraint_degrees
        self.rotation_degrees = rotation_degrees

    def act(self, actuator: LookUpActuator) -> None:
        """Hand this action to ``actuator.actuate_look_up``."""
        actuator.actuate_look_up(self)
class MoveForwardActionSampler(Protocol):
    """Anything able to produce a MoveForward action for an agent."""

    def sample_move_forward(self, agent_id: str) -> MoveForward:
        """Return a MoveForward action for the agent ``agent_id``."""


class MoveForwardActuator(Protocol):
    """Anything able to execute a MoveForward action."""

    def actuate_move_forward(self, action: MoveForward) -> None:
        """Carry out ``action``."""


class MoveForward(Action):
    """Move the agent forward by a specified distance."""

    @staticmethod
    def sample(agent_id: str, sampler: MoveForwardActionSampler) -> MoveForward:
        """Delegate creation of a MoveForward action to ``sampler``.

        Returns:
            The result of ``sampler.sample_move_forward(agent_id)``.
        """
        sampled = sampler.sample_move_forward(agent_id)
        return sampled

    def __init__(self, agent_id: str, distance: float) -> None:
        """Create a MoveForward action.

        Args:
            agent_id: The ID of the agent that will take this action.
            distance: How far to move forward (units are whatever the
                actuator/environment uses).
        """
        super().__init__(agent_id=agent_id)
        self.distance = distance

    def act(self, actuator: MoveForwardActuator) -> None:
        """Hand this action to ``actuator.actuate_move_forward``."""
        actuator.actuate_move_forward(self)
class MoveTangentiallyActionSampler(Protocol):
    """Anything able to produce a MoveTangentially action for an agent."""

    def sample_move_tangentially(self, agent_id: str) -> MoveTangentially:
        """Return a MoveTangentially action for the agent ``agent_id``."""


class MoveTangentiallyActuator(Protocol):
    """Anything able to execute a MoveTangentially action."""

    def actuate_move_tangentially(self, action: MoveTangentially) -> None:
        """Carry out ``action``."""


class MoveTangentially(Action):
    """Move the agent tangentially.

    Moves the agent tangentially to the current orientation by a specified
    distance along a specified direction.
    """

    @staticmethod
    def sample(
        agent_id: str, sampler: MoveTangentiallyActionSampler
    ) -> MoveTangentially:
        """Delegate creation of a MoveTangentially action to ``sampler``.

        Returns:
            The result of ``sampler.sample_move_tangentially(agent_id)``.
        """
        sampled = sampler.sample_move_tangentially(agent_id)
        return sampled

    def __init__(self, agent_id: str, distance: float, direction: VectorXYZ) -> None:
        """Create a MoveTangentially action.

        Args:
            agent_id: The ID of the agent that will take this action.
            distance: How far to move.
            direction: (x, y, z) direction of the move, stored as given.
        """
        super().__init__(agent_id=agent_id)
        self.distance = distance
        self.direction = direction

    def act(self, actuator: MoveTangentiallyActuator) -> None:
        """Hand this action to ``actuator.actuate_move_tangentially``."""
        actuator.actuate_move_tangentially(self)
class OrientHorizontalActionSampler(Protocol):
    """Anything able to produce an OrientHorizontal action for an agent."""

    def sample_orient_horizontal(self, agent_id: str) -> OrientHorizontal:
        """Return an OrientHorizontal action for the agent ``agent_id``."""


class OrientHorizontalActuator(Protocol):
    """Anything able to execute an OrientHorizontal action."""

    def actuate_orient_horizontal(self, action: OrientHorizontal) -> None:
        """Carry out ``action``."""


class OrientHorizontal(Action):
    """Move the agent in the horizontal plane.

    Moves the agent in the horizontal plane compensating for the horizontal
    motion with a rotation in the horizontal plane.
    """

    @staticmethod
    def sample(
        agent_id: str, sampler: OrientHorizontalActionSampler
    ) -> OrientHorizontal:
        """Delegate creation of an OrientHorizontal action to ``sampler``.

        Returns:
            The result of ``sampler.sample_orient_horizontal(agent_id)``.
        """
        sampled = sampler.sample_orient_horizontal(agent_id)
        return sampled

    def __init__(
        self,
        agent_id: str,
        rotation_degrees: float,
        left_distance: float,
        forward_distance: float,
    ) -> None:
        """Create an OrientHorizontal action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Compensating rotation, in degrees.
            left_distance: Distance component to the left.
            forward_distance: Distance component forward.
        """
        super().__init__(agent_id=agent_id)
        self.rotation_degrees = rotation_degrees
        self.left_distance = left_distance
        self.forward_distance = forward_distance

    def act(self, actuator: OrientHorizontalActuator) -> None:
        """Hand this action to ``actuator.actuate_orient_horizontal``."""
        actuator.actuate_orient_horizontal(self)
class OrientVerticalActionSampler(Protocol):
    """Anything able to produce an OrientVertical action for an agent."""

    def sample_orient_vertical(self, agent_id: str) -> OrientVertical:
        """Return an OrientVertical action for the agent ``agent_id``."""


class OrientVerticalActuator(Protocol):
    """Anything able to execute an OrientVertical action."""

    def actuate_orient_vertical(self, action: OrientVertical) -> None:
        """Carry out ``action``."""


class OrientVertical(Action):
    """Move the agent in the vertical plane.

    Moves the agent in the vertical plane compensating for the vertical
    motion with a rotation in the vertical plane.
    """

    @staticmethod
    def sample(agent_id: str, sampler: OrientVerticalActionSampler) -> OrientVertical:
        """Delegate creation of an OrientVertical action to ``sampler``.

        Returns:
            The result of ``sampler.sample_orient_vertical(agent_id)``.
        """
        sampled = sampler.sample_orient_vertical(agent_id)
        return sampled

    def __init__(
        self,
        agent_id: str,
        rotation_degrees: float,
        down_distance: float,
        forward_distance: float,
    ) -> None:
        """Create an OrientVertical action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Compensating rotation, in degrees.
            down_distance: Distance component downwards.
            forward_distance: Distance component forward.
        """
        super().__init__(agent_id=agent_id)
        self.rotation_degrees = rotation_degrees
        self.down_distance = down_distance
        self.forward_distance = forward_distance

    def act(self, actuator: OrientVerticalActuator) -> None:
        """Hand this action to ``actuator.actuate_orient_vertical``."""
        actuator.actuate_orient_vertical(self)
class SetAgentPitchActionSampler(Protocol):
    """Anything able to produce a SetAgentPitch action for an agent."""

    def sample_set_agent_pitch(self, agent_id: str) -> SetAgentPitch:
        """Return a SetAgentPitch action for the agent ``agent_id``."""


class SetAgentPitchActuator(Protocol):
    """Anything able to execute a SetAgentPitch action."""

    def actuate_set_agent_pitch(self, action: SetAgentPitch) -> None:
        """Carry out ``action``."""


class SetAgentPitch(Action):
    """Set the agent pitch rotation in degrees.

    Note that unless otherwise changed, the sensors maintain identity
    orientation with regard to the agent. So, this will also adjust the pitch
    of agent's sensors with regard to the environment.
    """

    @staticmethod
    def sample(agent_id: str, sampler: SetAgentPitchActionSampler) -> SetAgentPitch:
        """Delegate creation of a SetAgentPitch action to ``sampler``.

        Returns:
            The result of ``sampler.sample_set_agent_pitch(agent_id)``.
        """
        sampled = sampler.sample_set_agent_pitch(agent_id)
        return sampled

    def __init__(self, agent_id: str, pitch_degrees: float) -> None:
        """Create a SetAgentPitch action.

        Args:
            agent_id: The ID of the agent that will take this action.
            pitch_degrees: Target agent pitch, in degrees.
        """
        super().__init__(agent_id=agent_id)
        self.pitch_degrees = pitch_degrees

    def act(self, actuator: SetAgentPitchActuator) -> None:
        """Hand this action to ``actuator.actuate_set_agent_pitch``."""
        actuator.actuate_set_agent_pitch(self)
class SetAgentPoseActionSampler(Protocol):
    """Anything able to produce a SetAgentPose action for an agent."""

    def sample_set_agent_pose(self, agent_id: str) -> SetAgentPose:
        """Return a SetAgentPose action for the agent ``agent_id``."""


class SetAgentPoseActuator(Protocol):
    """Anything able to execute a SetAgentPose action."""

    def actuate_set_agent_pose(self, action: SetAgentPose) -> None:
        """Carry out ``action``."""


class SetAgentPose(Action):
    """Set the agent pose.

    Set the agent pose to absolute location coordinates and orientation in
    the environment.
    """

    @staticmethod
    def sample(agent_id: str, sampler: SetAgentPoseActionSampler) -> SetAgentPose:
        """Delegate creation of a SetAgentPose action to ``sampler``.

        Returns:
            The result of ``sampler.sample_set_agent_pose(agent_id)``.
        """
        sampled = sampler.sample_set_agent_pose(agent_id)
        return sampled

    def __init__(
        self,
        agent_id: str,
        location: VectorXYZ,
        rotation_quat: QuaternionWXYZ,
    ) -> None:
        """Create a SetAgentPose action.

        Args:
            agent_id: The ID of the agent that will take this action.
            location: Absolute (x, y, z) location.
            rotation_quat: Absolute orientation as a (w, x, y, z) quaternion.
        """
        super().__init__(agent_id=agent_id)
        self.location = location
        self.rotation_quat = rotation_quat

    def act(self, actuator: SetAgentPoseActuator) -> None:
        """Hand this action to ``actuator.actuate_set_agent_pose``."""
        actuator.actuate_set_agent_pose(self)
class SetSensorPitchActionSampler(Protocol):
    """Anything able to produce a SetSensorPitch action for an agent."""

    def sample_set_sensor_pitch(self, agent_id: str) -> SetSensorPitch:
        """Return a SetSensorPitch action for the agent ``agent_id``."""


class SetSensorPitchActuator(Protocol):
    """Anything able to execute a SetSensorPitch action."""

    def actuate_set_sensor_pitch(self, action: SetSensorPitch) -> None:
        """Carry out ``action``."""


class SetSensorPitch(Action):
    """Set the sensor pitch rotation.

    Note that this does not update the pitch of the agent. Imagine the body
    associated with the eye remaining in place, while the eye moves.
    """

    @staticmethod
    def sample(agent_id: str, sampler: SetSensorPitchActionSampler) -> SetSensorPitch:
        """Delegate creation of a SetSensorPitch action to ``sampler``.

        Returns:
            The result of ``sampler.sample_set_sensor_pitch(agent_id)``.
        """
        sampled = sampler.sample_set_sensor_pitch(agent_id)
        return sampled

    def __init__(self, agent_id: str, pitch_degrees: float) -> None:
        """Create a SetSensorPitch action.

        Args:
            agent_id: The ID of the agent that will take this action.
            pitch_degrees: Target sensor pitch, in degrees.
        """
        super().__init__(agent_id=agent_id)
        self.pitch_degrees = pitch_degrees

    def act(self, actuator: SetSensorPitchActuator) -> None:
        """Hand this action to ``actuator.actuate_set_sensor_pitch``."""
        actuator.actuate_set_sensor_pitch(self)
class SetSensorPoseActionSampler(Protocol):
    """Anything able to produce a SetSensorPose action for an agent."""

    def sample_set_sensor_pose(self, agent_id: str) -> SetSensorPose:
        """Return a SetSensorPose action for the agent ``agent_id``."""


class SetSensorPoseActuator(Protocol):
    """Anything able to execute a SetSensorPose action."""

    def actuate_set_sensor_pose(self, action: SetSensorPose) -> None:
        """Carry out ``action``."""


class SetSensorPose(Action):
    """Set the sensor pose.

    Set the sensor pose to absolute location coordinates and orientation in
    the environment.
    """

    @staticmethod
    def sample(agent_id: str, sampler: SetSensorPoseActionSampler) -> SetSensorPose:
        """Delegate creation of a SetSensorPose action to ``sampler``.

        Returns:
            The result of ``sampler.sample_set_sensor_pose(agent_id)``.
        """
        sampled = sampler.sample_set_sensor_pose(agent_id)
        return sampled

    def __init__(
        self,
        agent_id: str,
        location: VectorXYZ,
        rotation_quat: QuaternionWXYZ,
    ) -> None:
        """Create a SetSensorPose action.

        Args:
            agent_id: The ID of the agent that will take this action.
            location: Absolute (x, y, z) location.
            rotation_quat: Absolute orientation as a (w, x, y, z) quaternion.
        """
        super().__init__(agent_id=agent_id)
        self.location = location
        self.rotation_quat = rotation_quat

    def act(self, actuator: SetSensorPoseActuator) -> None:
        """Hand this action to ``actuator.actuate_set_sensor_pose``."""
        actuator.actuate_set_sensor_pose(self)
class SetSensorRotationActionSampler(Protocol):
    """Anything able to produce a SetSensorRotation action for an agent."""

    def sample_set_sensor_rotation(self, agent_id: str) -> SetSensorRotation:
        """Return a SetSensorRotation action for the agent ``agent_id``."""


class SetSensorRotationActuator(Protocol):
    """Anything able to execute a SetSensorRotation action."""

    def actuate_set_sensor_rotation(self, action: SetSensorRotation) -> None:
        """Carry out ``action``."""


class SetSensorRotation(Action):
    """Set the sensor rotation relative to the agent."""

    @staticmethod
    def sample(
        agent_id: str, sampler: SetSensorRotationActionSampler
    ) -> SetSensorRotation:
        """Delegate creation of a SetSensorRotation action to ``sampler``.

        Returns:
            The result of ``sampler.sample_set_sensor_rotation(agent_id)``.
        """
        sampled = sampler.sample_set_sensor_rotation(agent_id)
        return sampled

    def __init__(self, agent_id: str, rotation_quat: QuaternionWXYZ) -> None:
        """Create a SetSensorRotation action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_quat: Sensor rotation as a (w, x, y, z) quaternion.
        """
        super().__init__(agent_id=agent_id)
        self.rotation_quat = rotation_quat

    def act(self, actuator: SetSensorRotationActuator) -> None:
        """Hand this action to ``actuator.actuate_set_sensor_rotation``."""
        actuator.actuate_set_sensor_rotation(self)
class SetYawActionSampler(Protocol):
    """Anything able to produce a SetYaw action for an agent."""

    def sample_set_yaw(self, agent_id: str) -> SetYaw:
        """Return a SetYaw action for the agent ``agent_id``."""


class SetYawActuator(Protocol):
    """Anything able to execute a SetYaw action."""

    def actuate_set_yaw(self, action: SetYaw) -> None:
        """Carry out ``action``."""


class SetYaw(Action):
    """Set the agent body yaw rotation."""

    @staticmethod
    def sample(agent_id: str, sampler: SetYawActionSampler) -> SetYaw:
        """Delegate creation of a SetYaw action to ``sampler``.

        Returns:
            The result of ``sampler.sample_set_yaw(agent_id)``.
        """
        sampled = sampler.sample_set_yaw(agent_id)
        return sampled

    def __init__(self, agent_id: str, rotation_degrees: float) -> None:
        """Create a SetYaw action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Target yaw, in degrees.
        """
        super().__init__(agent_id=agent_id)
        self.rotation_degrees = rotation_degrees

    def act(self, actuator: SetYawActuator) -> None:
        """Hand this action to ``actuator.actuate_set_yaw``."""
        actuator.actuate_set_yaw(self)
class TurnLeftActionSampler(Protocol):
    """Anything able to produce a TurnLeft action for an agent."""

    def sample_turn_left(self, agent_id: str) -> TurnLeft:
        """Return a TurnLeft action for the agent ``agent_id``."""


class TurnLeftActuator(Protocol):
    """Anything able to execute a TurnLeft action."""

    def actuate_turn_left(self, action: TurnLeft) -> None:
        """Carry out ``action``."""


class TurnLeft(Action):
    """Rotate the agent to the left."""

    @staticmethod
    def sample(agent_id: str, sampler: TurnLeftActionSampler) -> TurnLeft:
        """Delegate creation of a TurnLeft action to ``sampler``.

        Returns:
            The result of ``sampler.sample_turn_left(agent_id)``.
        """
        sampled = sampler.sample_turn_left(agent_id)
        return sampled

    def __init__(self, agent_id: str, rotation_degrees: float) -> None:
        """Create a TurnLeft action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: How far to turn left, in degrees.
        """
        super().__init__(agent_id=agent_id)
        self.rotation_degrees = rotation_degrees

    def act(self, actuator: TurnLeftActuator) -> None:
        """Hand this action to ``actuator.actuate_turn_left``."""
        actuator.actuate_turn_left(self)
class TurnRightActionSampler(Protocol):
    """Anything able to produce a TurnRight action for an agent."""

    def sample_turn_right(self, agent_id: str) -> TurnRight:
        """Return a TurnRight action for the agent ``agent_id``."""


class TurnRightActuator(Protocol):
    """Anything able to execute a TurnRight action."""

    def actuate_turn_right(self, action: TurnRight) -> None:
        """Carry out ``action``."""


class TurnRight(Action):
    """Rotate the agent to the right."""

    @staticmethod
    def sample(agent_id: str, sampler: TurnRightActionSampler) -> TurnRight:
        """Delegate creation of a TurnRight action to ``sampler``.

        Returns:
            The result of ``sampler.sample_turn_right(agent_id)``.
        """
        sampled = sampler.sample_turn_right(agent_id)
        return sampled

    def __init__(self, agent_id: str, rotation_degrees: float) -> None:
        """Create a TurnRight action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: How far to turn right, in degrees.
        """
        super().__init__(agent_id=agent_id)
        self.rotation_degrees = rotation_degrees

    def act(self, actuator: TurnRightActuator) -> None:
        """Hand this action to ``actuator.actuate_turn_right``."""
        actuator.actuate_turn_right(self)
class ActionJSONEncoder(JSONEncoder):
    """Encodes an Action into a JSON object.

    Action name is encoded as the `"action"` parameter. All other Action
    parameters are encoded as key-value pairs in the JSON object.
    """

    def default(self, obj: Any) -> Any:
        """Serialize Actions as dicts and defer everything else to the base class.

        Returns:
            ``dict(obj)`` when ``obj`` is an Action.

        Raises:
            TypeError: From ``JSONEncoder.default`` when ``obj`` is not an
                Action and is not otherwise serializable.
        """
        if isinstance(obj, Action):
            return dict(obj)
        return super().default(obj)


class ActionJSONDecoder(JSONDecoder):
    """Decodes JSON objects into Actions.

    Every JSON object in the document is passed to ``object_hook``, so each
    one must contain an "action" key with the name of the action, plus the
    parameters that action requires. ``constraint_degrees`` may be omitted for
    ``look_down``/``look_up``; the action's constructor default is then used.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the decoder with this class's ``object_hook`` installed.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``JSONDecoder``
                (e.g. ``parse_float`` or ``strict``). This lets callers write
                ``json.loads(s, cls=ActionJSONDecoder, parse_float=...)``.
                Any ``object_hook`` given here is replaced by this decoder's
                own hook. Note: ``JSONDecoder`` gives ``object_pairs_hook``
                priority over ``object_hook``, so passing it disables action
                decoding.
        """
        kwargs["object_hook"] = self.object_hook
        super().__init__(**kwargs)

    @staticmethod
    def _optional_params(obj: dict[str, Any], *names: str) -> dict[str, Any]:
        """Return the subset of ``names`` present in ``obj`` as keyword args."""
        return {name: obj[name] for name in names if name in obj}

    def object_hook(self, obj: dict[str, Any]) -> Any:
        """Build the Action described by a decoded JSON object.

        Sequence-valued parameters (direction, location, rotation_quat) are
        converted from JSON arrays to tuples.

        Returns:
            The reconstructed Action instance.

        Raises:
            ValueError: If ``obj`` has no "action" key, or names an unknown
                action.
            KeyError: If a required parameter for the action is missing.
        """
        if "action" not in obj:
            raise ValueError("Invalid action object: missing 'action' key.")

        action = obj["action"]
        if action == LookDown.action_name():
            return LookDown(
                agent_id=obj["agent_id"],
                rotation_degrees=obj["rotation_degrees"],
                # Optional: fall back to the constructor default when absent.
                **self._optional_params(obj, "constraint_degrees"),
            )
        elif action == LookUp.action_name():
            return LookUp(
                agent_id=obj["agent_id"],
                rotation_degrees=obj["rotation_degrees"],
                # Optional: fall back to the constructor default when absent.
                **self._optional_params(obj, "constraint_degrees"),
            )
        elif action == MoveForward.action_name():
            return MoveForward(
                agent_id=obj["agent_id"],
                distance=obj["distance"],
            )
        elif action == MoveTangentially.action_name():
            return MoveTangentially(
                agent_id=obj["agent_id"],
                distance=obj["distance"],
                direction=tuple(obj["direction"]),
            )
        elif action == OrientHorizontal.action_name():
            return OrientHorizontal(
                agent_id=obj["agent_id"],
                rotation_degrees=obj["rotation_degrees"],
                left_distance=obj["left_distance"],
                forward_distance=obj["forward_distance"],
            )
        elif action == OrientVertical.action_name():
            return OrientVertical(
                agent_id=obj["agent_id"],
                rotation_degrees=obj["rotation_degrees"],
                down_distance=obj["down_distance"],
                forward_distance=obj["forward_distance"],
            )
        elif action == SetAgentPitch.action_name():
            return SetAgentPitch(
                agent_id=obj["agent_id"],
                pitch_degrees=obj["pitch_degrees"],
            )
        elif action == SetAgentPose.action_name():
            return SetAgentPose(
                agent_id=obj["agent_id"],
                location=tuple(obj["location"]),
                rotation_quat=tuple(obj["rotation_quat"]),
            )
        elif action == SetSensorPitch.action_name():
            return SetSensorPitch(
                agent_id=obj["agent_id"],
                pitch_degrees=obj["pitch_degrees"],
            )
        elif action == SetSensorPose.action_name():
            return SetSensorPose(
                agent_id=obj["agent_id"],
                location=tuple(obj["location"]),
                rotation_quat=tuple(obj["rotation_quat"]),
            )
        elif action == SetSensorRotation.action_name():
            return SetSensorRotation(
                agent_id=obj["agent_id"],
                rotation_quat=tuple(obj["rotation_quat"]),
            )
        elif action == SetYaw.action_name():
            return SetYaw(
                agent_id=obj["agent_id"],
                rotation_degrees=obj["rotation_degrees"],
            )
        elif action == TurnLeft.action_name():
            return TurnLeft(
                agent_id=obj["agent_id"],
                rotation_degrees=obj["rotation_degrees"],
            )
        elif action == TurnRight.action_name():
            return TurnRight(
                agent_id=obj["agent_id"],
                rotation_degrees=obj["rotation_degrees"],
            )
        else:
            raise ValueError(f"Invalid action object: unknown action '{action}'.")