# Copyright 2025 Thousand Brains Project
# Copyright 2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations
from json import JSONDecoder, JSONEncoder
from typing import Any, Generator, Tuple
from typing_extensions import (
Protocol, # Enables default __init__ in Protocol classes
runtime_checkable, # For JSONEncoder instance checks
)
# Public API of this module: every concrete action together with its sampler
# and actuator protocols, the Action base protocol, and the spatial type
# aliases used as action parameters.
# NOTE(review): ActionJSONEncoder and ActionJSONDecoder are defined in this
# module but are not listed here; confirm that is intentional.
__all__ = [
    # Actions
    "Action",
    "LookDown",
    "LookDownActionSampler",
    "LookDownActuator",
    "LookUp",
    "LookUpActionSampler",
    "LookUpActuator",
    "MoveForward",
    "MoveForwardActionSampler",
    "MoveForwardActuator",
    "MoveTangentially",
    "MoveTangentiallyActionSampler",
    "MoveTangentiallyActuator",
    "OrientHorizontal",
    "OrientHorizontalActionSampler",
    "OrientHorizontalActuator",
    "OrientVertical",
    "OrientVerticalActionSampler",
    "OrientVerticalActuator",
    "SetAgentPitch",
    "SetAgentPitchActionSampler",
    "SetAgentPitchActuator",
    "SetAgentPose",
    "SetAgentPoseActionSampler",
    "SetAgentPoseActuator",
    "SetSensorPitch",
    "SetSensorPitchActionSampler",
    "SetSensorPitchActuator",
    "SetSensorPose",
    "SetSensorPoseActionSampler",
    "SetSensorPoseActuator",
    "SetSensorRotation",
    "SetSensorRotationActionSampler",
    "SetSensorRotationActuator",
    "SetYaw",
    "SetYawActionSampler",
    "SetYawActuator",
    "TurnLeft",
    "TurnLeftActionSampler",
    "TurnLeftActuator",
    "TurnRight",
    "TurnRightActionSampler",
    "TurnRightActuator",
    # Spatial representations
    "QuaternionWXYZ",
    "VectorXYZ",
]
# A 3D vector given as an (x, y, z) tuple of floats.
VectorXYZ = Tuple[float, float, float]
# A quaternion given as a tuple of four floats, in the component order the
# name indicates: (w, x, y, z).
QuaternionWXYZ = Tuple[float, float, float, float]
[docs]@runtime_checkable
class Action(Protocol):
"""An action that can be taken by an agent.
Actions are generated by the MotorSystem and are executed by an Actuator.
"""
agent_id: str
"""The ID of the agent that will take this action."""
@staticmethod
def _camel_case_to_snake_case(name: str) -> str:
"""Expecting a class name in CamelCase returns it in snake_case.
Returns:
The class name in snake_case.
"""
return "".join(
["_" + char.lower() if char.isupper() else char for char in name]
).lstrip("_")
[docs] @classmethod
def action_name(cls) -> str:
"""Generate action name based on class.
Used in static configuration, e.g., `FakeAction.action_name()`.
Returns:
The action name in snake_case.
"""
return Action._camel_case_to_snake_case(cls.__name__)
def __init__(self, agent_id: str) -> None:
"""Initialize the action with the agent ID.
Args:
agent_id: The ID of the agent that will take this action.
"""
self.agent_id = agent_id
@property
def name(self) -> str:
"""Used for checking action name on an Action instance."""
return self.__class__.action_name()
def __iter__(self) -> Generator[tuple[str, Any]]:
"""Yields the action name and all action parameters.
Useful if you want to do something like: `dict(action_instance)`.
"""
yield "action", self.name
for key, value in self.__dict__.items():
if key == "name":
continue
yield key, value
class LookDownActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`LookDown` action."""

    def sample_look_down(self, agent_id: str) -> LookDown:
        """Return a LookDown action for the agent with ID ``agent_id``."""


class LookDownActuator(Protocol):
    """Protocol for objects that can execute a :class:`LookDown` action."""

    def actuate_look_down(self, action: LookDown) -> None:
        """Carry out ``action``."""


class LookDown(Action):
    """Rotate the agent downwards by a specified number of degrees."""

    def __init__(
        self, agent_id: str, rotation_degrees: float, constraint_degrees: float = 90.0
    ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Number of degrees to rotate downwards.
            constraint_degrees: Stored as-is (default 90.0); presumably a limit
                on the downward rotation enforced by the actuator. Confirm.
        """
        super().__init__(agent_id)
        # Assignment order is the key order produced by ``dict(self)``.
        self.constraint_degrees = constraint_degrees
        self.rotation_degrees = rotation_degrees

    @staticmethod
    def sample(agent_id: str, sampler: LookDownActionSampler) -> LookDown:
        """Obtain a LookDown action for ``agent_id`` from ``sampler``."""
        return sampler.sample_look_down(agent_id)

    def act(self, actuator: LookDownActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_look_down(self)
class LookUpActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`LookUp` action."""

    def sample_look_up(self, agent_id: str) -> LookUp:
        """Return a LookUp action for the agent with ID ``agent_id``."""


class LookUpActuator(Protocol):
    """Protocol for objects that can execute a :class:`LookUp` action."""

    def actuate_look_up(self, action: LookUp) -> None:
        """Carry out ``action``."""


class LookUp(Action):
    """Rotate the agent upwards by a specified number of degrees."""

    def __init__(
        self, agent_id: str, rotation_degrees: float, constraint_degrees: float = 90.0
    ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Number of degrees to rotate upwards.
            constraint_degrees: Stored as-is (default 90.0); presumably a limit
                on the upward rotation enforced by the actuator. Confirm.
        """
        super().__init__(agent_id)
        # Assignment order is the key order produced by ``dict(self)``.
        self.constraint_degrees = constraint_degrees
        self.rotation_degrees = rotation_degrees

    @staticmethod
    def sample(agent_id: str, sampler: LookUpActionSampler) -> LookUp:
        """Obtain a LookUp action for ``agent_id`` from ``sampler``."""
        return sampler.sample_look_up(agent_id)

    def act(self, actuator: LookUpActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_look_up(self)
class MoveForwardActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`MoveForward` action."""

    def sample_move_forward(self, agent_id: str) -> MoveForward:
        """Return a MoveForward action for the agent with ID ``agent_id``."""


class MoveForwardActuator(Protocol):
    """Protocol for objects that can execute a :class:`MoveForward` action."""

    def actuate_move_forward(self, action: MoveForward) -> None:
        """Carry out ``action``."""


class MoveForward(Action):
    """Move the agent forward by a specified distance."""

    def __init__(self, agent_id: str, distance: float) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            distance: How far to move forward (units are up to the actuator).
        """
        super().__init__(agent_id)
        self.distance = distance

    @staticmethod
    def sample(agent_id: str, sampler: MoveForwardActionSampler) -> MoveForward:
        """Obtain a MoveForward action for ``agent_id`` from ``sampler``."""
        return sampler.sample_move_forward(agent_id)

    def act(self, actuator: MoveForwardActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_move_forward(self)
class MoveTangentiallyActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`MoveTangentially` action."""

    def sample_move_tangentially(self, agent_id: str) -> MoveTangentially:
        """Return a MoveTangentially action for the agent with ID ``agent_id``."""


class MoveTangentiallyActuator(Protocol):
    """Protocol for objects that can execute a :class:`MoveTangentially` action."""

    def actuate_move_tangentially(self, action: MoveTangentially) -> None:
        """Carry out ``action``."""


class MoveTangentially(Action):
    """Move the agent tangentially.

    Moves the agent tangentially to the current orientation by a specified distance
    along a specified direction.
    """

    def __init__(self, agent_id: str, distance: float, direction: VectorXYZ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            distance: How far to move (units are up to the actuator).
            direction: Direction of motion as an (x, y, z) vector; the frame it
                is expressed in is decided by the actuator.
        """
        super().__init__(agent_id)
        # Assignment order is the key order produced by ``dict(self)``.
        self.distance = distance
        self.direction = direction

    @staticmethod
    def sample(
        agent_id: str, sampler: MoveTangentiallyActionSampler
    ) -> MoveTangentially:
        """Obtain a MoveTangentially action for ``agent_id`` from ``sampler``."""
        return sampler.sample_move_tangentially(agent_id)

    def act(self, actuator: MoveTangentiallyActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_move_tangentially(self)
class OrientHorizontalActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`OrientHorizontal` action."""

    def sample_orient_horizontal(self, agent_id: str) -> OrientHorizontal:
        """Return an OrientHorizontal action for the agent with ID ``agent_id``."""


class OrientHorizontalActuator(Protocol):
    """Protocol for objects that can execute a :class:`OrientHorizontal` action."""

    def actuate_orient_horizontal(self, action: OrientHorizontal) -> None:
        """Carry out ``action``."""


class OrientHorizontal(Action):
    """Move the agent in the horizontal plane.

    Moves the agent in the horizontal plane compensating for the horizontal
    motion with a rotation in the horizontal plane.
    """

    def __init__(
        self,
        agent_id: str,
        rotation_degrees: float,
        left_distance: float,
        forward_distance: float,
    ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Rotation in the horizontal plane, in degrees.
            left_distance: Distance component to the left.
            forward_distance: Distance component forward.
        """
        super().__init__(agent_id)
        # Assignment order is the key order produced by ``dict(self)``.
        self.rotation_degrees = rotation_degrees
        self.left_distance = left_distance
        self.forward_distance = forward_distance

    @staticmethod
    def sample(
        agent_id: str, sampler: OrientHorizontalActionSampler
    ) -> OrientHorizontal:
        """Obtain an OrientHorizontal action for ``agent_id`` from ``sampler``."""
        return sampler.sample_orient_horizontal(agent_id)

    def act(self, actuator: OrientHorizontalActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_orient_horizontal(self)
class OrientVerticalActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`OrientVertical` action."""

    def sample_orient_vertical(self, agent_id: str) -> OrientVertical:
        """Return an OrientVertical action for the agent with ID ``agent_id``."""


class OrientVerticalActuator(Protocol):
    """Protocol for objects that can execute a :class:`OrientVertical` action."""

    def actuate_orient_vertical(self, action: OrientVertical) -> None:
        """Carry out ``action``."""


class OrientVertical(Action):
    """Move the agent in the vertical plane.

    Moves the agent in the vertical plane compensating for the vertical motion
    with a rotation in the vertical plane.
    """

    def __init__(
        self,
        agent_id: str,
        rotation_degrees: float,
        down_distance: float,
        forward_distance: float,
    ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Rotation in the vertical plane, in degrees.
            down_distance: Distance component downwards.
            forward_distance: Distance component forward.
        """
        super().__init__(agent_id)
        # Assignment order is the key order produced by ``dict(self)``.
        self.rotation_degrees = rotation_degrees
        self.down_distance = down_distance
        self.forward_distance = forward_distance

    @staticmethod
    def sample(agent_id: str, sampler: OrientVerticalActionSampler) -> OrientVertical:
        """Obtain an OrientVertical action for ``agent_id`` from ``sampler``."""
        return sampler.sample_orient_vertical(agent_id)

    def act(self, actuator: OrientVerticalActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_orient_vertical(self)
class SetAgentPitchActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`SetAgentPitch` action."""

    def sample_set_agent_pitch(self, agent_id: str) -> SetAgentPitch:
        """Return a SetAgentPitch action for the agent with ID ``agent_id``."""


class SetAgentPitchActuator(Protocol):
    """Protocol for objects that can execute a :class:`SetAgentPitch` action."""

    def actuate_set_agent_pitch(self, action: SetAgentPitch) -> None:
        """Carry out ``action``."""


class SetAgentPitch(Action):
    """Set the agent pitch rotation in degrees.

    Note that unless otherwise changed, the sensors maintain identity orientation
    with regard to the agent. So, this will also adjust the pitch of the agent's
    sensors with regard to the environment.
    """

    def __init__(self, agent_id: str, pitch_degrees: float) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            pitch_degrees: The pitch to set, in degrees.
        """
        super().__init__(agent_id)
        self.pitch_degrees = pitch_degrees

    @staticmethod
    def sample(agent_id: str, sampler: SetAgentPitchActionSampler) -> SetAgentPitch:
        """Obtain a SetAgentPitch action for ``agent_id`` from ``sampler``."""
        return sampler.sample_set_agent_pitch(agent_id)

    def act(self, actuator: SetAgentPitchActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_set_agent_pitch(self)
class SetAgentPoseActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`SetAgentPose` action."""

    def sample_set_agent_pose(self, agent_id: str) -> SetAgentPose:
        """Return a SetAgentPose action for the agent with ID ``agent_id``."""


class SetAgentPoseActuator(Protocol):
    """Protocol for objects that can execute a :class:`SetAgentPose` action."""

    def actuate_set_agent_pose(self, action: SetAgentPose) -> None:
        """Carry out ``action``."""


class SetAgentPose(Action):
    """Set the agent pose.

    Set the agent pose to absolute location coordinates and orientation in the
    environment.
    """

    def __init__(
        self, agent_id: str, location: VectorXYZ, rotation_quat: QuaternionWXYZ
    ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            location: Absolute (x, y, z) location in the environment.
            rotation_quat: Absolute orientation as a (w, x, y, z) quaternion.
        """
        super().__init__(agent_id)
        # Assignment order is the key order produced by ``dict(self)``.
        self.location = location
        self.rotation_quat = rotation_quat

    @staticmethod
    def sample(agent_id: str, sampler: SetAgentPoseActionSampler) -> SetAgentPose:
        """Obtain a SetAgentPose action for ``agent_id`` from ``sampler``."""
        return sampler.sample_set_agent_pose(agent_id)

    def act(self, actuator: SetAgentPoseActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_set_agent_pose(self)
class SetSensorPitchActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`SetSensorPitch` action."""

    def sample_set_sensor_pitch(self, agent_id: str) -> SetSensorPitch:
        """Return a SetSensorPitch action for the agent with ID ``agent_id``."""


class SetSensorPitchActuator(Protocol):
    """Protocol for objects that can execute a :class:`SetSensorPitch` action."""

    def actuate_set_sensor_pitch(self, action: SetSensorPitch) -> None:
        """Carry out ``action``."""


class SetSensorPitch(Action):
    """Set the sensor pitch rotation.

    Note that this does not update the pitch of the agent. Imagine the body associated
    with the eye remaining in place, while the eye moves.
    """

    def __init__(self, agent_id: str, pitch_degrees: float) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            pitch_degrees: The sensor pitch to set, in degrees.
        """
        super().__init__(agent_id)
        self.pitch_degrees = pitch_degrees

    @staticmethod
    def sample(agent_id: str, sampler: SetSensorPitchActionSampler) -> SetSensorPitch:
        """Obtain a SetSensorPitch action for ``agent_id`` from ``sampler``."""
        return sampler.sample_set_sensor_pitch(agent_id)

    def act(self, actuator: SetSensorPitchActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_set_sensor_pitch(self)
class SetSensorPoseActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`SetSensorPose` action."""

    def sample_set_sensor_pose(self, agent_id: str) -> SetSensorPose:
        """Return a SetSensorPose action for the agent with ID ``agent_id``."""


class SetSensorPoseActuator(Protocol):
    """Protocol for objects that can execute a :class:`SetSensorPose` action."""

    def actuate_set_sensor_pose(self, action: SetSensorPose) -> None:
        """Carry out ``action``."""


class SetSensorPose(Action):
    """Set the sensor pose.

    Set the sensor pose to absolute location coordinates and orientation in the
    environment.
    """

    def __init__(
        self, agent_id: str, location: VectorXYZ, rotation_quat: QuaternionWXYZ
    ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            location: Absolute (x, y, z) sensor location in the environment.
            rotation_quat: Absolute sensor orientation as a (w, x, y, z)
                quaternion.
        """
        super().__init__(agent_id)
        # Assignment order is the key order produced by ``dict(self)``.
        self.location = location
        self.rotation_quat = rotation_quat

    @staticmethod
    def sample(agent_id: str, sampler: SetSensorPoseActionSampler) -> SetSensorPose:
        """Obtain a SetSensorPose action for ``agent_id`` from ``sampler``."""
        return sampler.sample_set_sensor_pose(agent_id)

    def act(self, actuator: SetSensorPoseActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_set_sensor_pose(self)
class SetSensorRotationActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`SetSensorRotation` action."""

    def sample_set_sensor_rotation(self, agent_id: str) -> SetSensorRotation:
        """Return a SetSensorRotation action for the agent with ID ``agent_id``."""


class SetSensorRotationActuator(Protocol):
    """Protocol for objects that can execute a :class:`SetSensorRotation` action."""

    def actuate_set_sensor_rotation(self, action: SetSensorRotation) -> None:
        """Carry out ``action``."""


class SetSensorRotation(Action):
    """Set the sensor rotation relative to the agent."""

    def __init__(self, agent_id: str, rotation_quat: QuaternionWXYZ) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_quat: Sensor rotation relative to the agent, as a
                (w, x, y, z) quaternion.
        """
        super().__init__(agent_id)
        self.rotation_quat = rotation_quat

    @staticmethod
    def sample(
        agent_id: str, sampler: SetSensorRotationActionSampler
    ) -> SetSensorRotation:
        """Obtain a SetSensorRotation action for ``agent_id`` from ``sampler``."""
        return sampler.sample_set_sensor_rotation(agent_id)

    def act(self, actuator: SetSensorRotationActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_set_sensor_rotation(self)
class SetYawActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`SetYaw` action."""

    def sample_set_yaw(self, agent_id: str) -> SetYaw:
        """Return a SetYaw action for the agent with ID ``agent_id``."""


class SetYawActuator(Protocol):
    """Protocol for objects that can execute a :class:`SetYaw` action."""

    def actuate_set_yaw(self, action: SetYaw) -> None:
        """Carry out ``action``."""


class SetYaw(Action):
    """Set the agent body yaw rotation."""

    def __init__(self, agent_id: str, rotation_degrees: float) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: The yaw to set, in degrees.
        """
        super().__init__(agent_id)
        self.rotation_degrees = rotation_degrees

    @staticmethod
    def sample(agent_id: str, sampler: SetYawActionSampler) -> SetYaw:
        """Obtain a SetYaw action for ``agent_id`` from ``sampler``."""
        return sampler.sample_set_yaw(agent_id)

    def act(self, actuator: SetYawActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_set_yaw(self)
class TurnLeftActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`TurnLeft` action."""

    def sample_turn_left(self, agent_id: str) -> TurnLeft:
        """Return a TurnLeft action for the agent with ID ``agent_id``."""


class TurnLeftActuator(Protocol):
    """Protocol for objects that can execute a :class:`TurnLeft` action."""

    def actuate_turn_left(self, action: TurnLeft) -> None:
        """Carry out ``action``."""


class TurnLeft(Action):
    """Rotate the agent to the left."""

    def __init__(self, agent_id: str, rotation_degrees: float) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Number of degrees to rotate to the left.
        """
        super().__init__(agent_id)
        self.rotation_degrees = rotation_degrees

    @staticmethod
    def sample(agent_id: str, sampler: TurnLeftActionSampler) -> TurnLeft:
        """Obtain a TurnLeft action for ``agent_id`` from ``sampler``."""
        return sampler.sample_turn_left(agent_id)

    def act(self, actuator: TurnLeftActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_turn_left(self)
class TurnRightActionSampler(Protocol):
    """Protocol for objects that can sample a :class:`TurnRight` action."""

    def sample_turn_right(self, agent_id: str) -> TurnRight:
        """Return a TurnRight action for the agent with ID ``agent_id``."""


class TurnRightActuator(Protocol):
    """Protocol for objects that can execute a :class:`TurnRight` action."""

    def actuate_turn_right(self, action: TurnRight) -> None:
        """Carry out ``action``."""


class TurnRight(Action):
    """Rotate the agent to the right."""

    def __init__(self, agent_id: str, rotation_degrees: float) -> None:
        """Create the action.

        Args:
            agent_id: The ID of the agent that will take this action.
            rotation_degrees: Number of degrees to rotate to the right.
        """
        super().__init__(agent_id)
        self.rotation_degrees = rotation_degrees

    @staticmethod
    def sample(agent_id: str, sampler: TurnRightActionSampler) -> TurnRight:
        """Obtain a TurnRight action for ``agent_id`` from ``sampler``."""
        return sampler.sample_turn_right(agent_id)

    def act(self, actuator: TurnRightActuator) -> None:
        """Have ``actuator`` execute this action."""
        actuator.actuate_turn_right(self)
class ActionJSONEncoder(JSONEncoder):
    """JSON encoder that serializes :class:`Action` instances.

    An action becomes a JSON object whose ``"action"`` key holds the action name
    and whose remaining keys are the action's parameters (built via
    ``dict(action)``). Any other unsupported object is handed to
    :meth:`json.JSONEncoder.default`, which raises ``TypeError``.
    """

    def default(self, obj: Any) -> Any:
        """Return a JSON-serializable form of ``obj``."""
        if not isinstance(obj, Action):
            return super().default(obj)
        return dict(obj)
class ActionJSONDecoder(JSONDecoder):
    """Decodes JSON objects into Actions.

    Every JSON object in the document is passed through :meth:`object_hook`, so
    each one must describe an action. It needs:

    * an ``"action"`` key holding the action name (as returned by
      ``Action.action_name()``),
    * an ``"agent_id"`` key, and
    * a key for every other constructor parameter of that action, including
      ones that have a constructor default, such as ``constraint_degrees`` for
      LookDown and LookUp.

    Keys that are not recognized are ignored. As a consequence, a document that
    wraps actions inside another object (e.g. ``{"actions": [...]}``) is
    rejected, because the wrapper object has no ``"action"`` key.

    Typical use is ``json.loads(text, cls=ActionJSONDecoder)``. Extra keyword
    arguments (e.g. ``parse_float``, ``strict``) are forwarded to
    :class:`json.JSONDecoder`.
    """

    # Maps action name -> (action class, constructor parameters other than
    # ``agent_id``). Parameters are read from the JSON object in this order, so
    # the first missing one is the one reported in the ``KeyError``.
    _ACTION_PARAMETERS: dict[str, tuple[type[Action], tuple[str, ...]]] = {
        LookDown.action_name(): (
            LookDown,
            ("rotation_degrees", "constraint_degrees"),
        ),
        LookUp.action_name(): (
            LookUp,
            ("rotation_degrees", "constraint_degrees"),
        ),
        MoveForward.action_name(): (MoveForward, ("distance",)),
        MoveTangentially.action_name(): (
            MoveTangentially,
            ("distance", "direction"),
        ),
        OrientHorizontal.action_name(): (
            OrientHorizontal,
            ("rotation_degrees", "left_distance", "forward_distance"),
        ),
        OrientVertical.action_name(): (
            OrientVertical,
            ("rotation_degrees", "down_distance", "forward_distance"),
        ),
        SetAgentPitch.action_name(): (SetAgentPitch, ("pitch_degrees",)),
        SetAgentPose.action_name(): (SetAgentPose, ("location", "rotation_quat")),
        SetSensorPitch.action_name(): (SetSensorPitch, ("pitch_degrees",)),
        SetSensorPose.action_name(): (SetSensorPose, ("location", "rotation_quat")),
        SetSensorRotation.action_name(): (SetSensorRotation, ("rotation_quat",)),
        SetYaw.action_name(): (SetYaw, ("rotation_degrees",)),
        TurnLeft.action_name(): (TurnLeft, ("rotation_degrees",)),
        TurnRight.action_name(): (TurnRight, ("rotation_degrees",)),
    }

    # Parameters typed as VectorXYZ / QuaternionWXYZ tuples. JSON has no tuple
    # type, so they arrive as lists and are converted back here.
    _TUPLE_PARAMETERS = frozenset({"direction", "location", "rotation_quat"})

    def __init__(self, **kwargs: Any) -> None:
        """Create a decoder that turns JSON objects into Actions.

        Args:
            **kwargs: Forwarded to :class:`json.JSONDecoder` (e.g.
                ``parse_float``, ``strict``).

        Raises:
            TypeError: If ``object_hook`` or ``object_pairs_hook`` is given;
                the decoder installs its own ``object_hook``, and an
                ``object_pairs_hook`` would silently take precedence over it.
        """
        if "object_hook" in kwargs or "object_pairs_hook" in kwargs:
            raise TypeError(
                "ActionJSONDecoder installs its own object_hook; "
                "object_hook and object_pairs_hook cannot be overridden."
            )
        super().__init__(object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj: dict[str, Any]) -> Any:
        """Convert one decoded JSON object into the Action it describes.

        Args:
            obj: A decoded JSON object.

        Returns:
            The Action instance described by ``obj``.

        Raises:
            ValueError: If ``obj`` has no ``"action"`` key or names an unknown
                action.
            KeyError: If a required action parameter is missing from ``obj``.
        """
        if "action" not in obj:
            raise ValueError("Invalid action object: missing 'action' key.")
        action = obj["action"]
        try:
            action_cls, parameter_names = self._ACTION_PARAMETERS[action]
        except (KeyError, TypeError):
            # TypeError covers unhashable values (e.g. a JSON list); those can
            # never name an action either.
            raise ValueError(
                f"Invalid action object: unknown action '{action}'."
            ) from None

        arguments: dict[str, Any] = {"agent_id": obj["agent_id"]}
        for parameter in parameter_names:
            value = obj[parameter]
            if parameter in self._TUPLE_PARAMETERS:
                value = tuple(value)
            arguments[parameter] = value
        return action_cls(**arguments)