# Copyright 2025 Thousand Brains Project
# Copyright 2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional, Type, cast

import quaternion as qt
from numpy import cos, pi, sin, sqrt
from numpy.random import Generator, default_rng

from tbp.monty.frameworks.actions.actions import (
    Action,
    LookDown,
    LookUp,
    MoveForward,
    MoveTangentially,
    OrientHorizontal,
    OrientVertical,
    QuaternionWXYZ,
    SetAgentPitch,
    SetAgentPose,
    SetSensorPitch,
    SetSensorPose,
    SetSensorRotation,
    SetYaw,
    TurnLeft,
    TurnRight,
    VectorXYZ,
)
# Public names exported by `from <this module> import *`.
__all__ = [
    "ActionSampler",
    "ConstantSampler",
    "UniformlyDistributedSampler",
]
class ActionSampler(ABC):
    """Declares the interface for an abstract Action factory.

    Used to generate Actions by sampling from a set of available action types.

    Subclasses implement one `sample_<action_name>` method per action type.
    `sample` picks one of the action types given at construction (uniformly at
    random, via `rng.choice`) and dispatches to the matching
    `sample_<action_name>` method, where `<action_name>` is the value returned
    by that type's `action_name()`.

    Args:
        rng: Random number generator used by `sample`. A fresh `default_rng()`
            is created when omitted.
        actions: Action types that `sample` may choose from. The list is
            copied, so later changes to the caller's list have no effect.
    """

    def __init__(
        self,
        rng: Optional[Generator] = None,
        actions: Optional[List[Type[Action]]] = None,
    ):
        self.rng = rng if rng is not None else default_rng()
        # Copy so the cached names below stay in sync with `_actions` even if
        # the caller later mutates the list it passed in.
        self._actions = list(actions) if actions is not None else []
        self._action_names = [action.action_name() for action in self._actions]
        # Names of the factory methods that `sample` dispatches to.
        self._method_names = [
            f"sample_{action_name}" for action_name in self._action_names
        ]

    @abstractmethod
    def sample_look_down(self, agent_id: str) -> LookDown:
        """Create a `LookDown` action for `agent_id`."""

    @abstractmethod
    def sample_look_up(self, agent_id: str) -> LookUp:
        """Create a `LookUp` action for `agent_id`."""

    @abstractmethod
    def sample_move_forward(self, agent_id: str) -> MoveForward:
        """Create a `MoveForward` action for `agent_id`."""

    @abstractmethod
    def sample_move_tangentially(self, agent_id: str) -> MoveTangentially:
        """Create a `MoveTangentially` action for `agent_id`."""

    @abstractmethod
    def sample_orient_horizontal(self, agent_id: str) -> OrientHorizontal:
        """Create an `OrientHorizontal` action for `agent_id`."""

    @abstractmethod
    def sample_orient_vertical(self, agent_id: str) -> OrientVertical:
        """Create an `OrientVertical` action for `agent_id`."""

    @abstractmethod
    def sample_set_agent_pitch(self, agent_id: str) -> SetAgentPitch:
        """Create a `SetAgentPitch` action for `agent_id`."""

    @abstractmethod
    def sample_set_agent_pose(self, agent_id: str) -> SetAgentPose:
        """Create a `SetAgentPose` action for `agent_id`."""

    @abstractmethod
    def sample_set_sensor_pitch(self, agent_id: str) -> SetSensorPitch:
        """Create a `SetSensorPitch` action for `agent_id`."""

    @abstractmethod
    def sample_set_sensor_pose(self, agent_id: str) -> SetSensorPose:
        """Create a `SetSensorPose` action for `agent_id`."""

    @abstractmethod
    def sample_set_sensor_rotation(self, agent_id: str) -> SetSensorRotation:
        """Create a `SetSensorRotation` action for `agent_id`."""

    @abstractmethod
    def sample_set_yaw(self, agent_id: str) -> SetYaw:
        """Create a `SetYaw` action for `agent_id`."""

    @abstractmethod
    def sample_turn_left(self, agent_id: str) -> TurnLeft:
        """Create a `TurnLeft` action for `agent_id`."""

    @abstractmethod
    def sample_turn_right(self, agent_id: str) -> TurnRight:
        """Create a `TurnRight` action for `agent_id`."""

    def sample(self, agent_id: str) -> Action:
        """Sample a random action from the available action types.

        Args:
            agent_id: ID of the agent the sampled action is created for.

        Returns:
            A random action from the available action types.

        Raises:
            ValueError: If the sampler was constructed without any action types.
        """
        if not self._method_names:
            # `rng.choice([])` would raise an opaque numpy error; fail with a
            # message that points at the actual misconfiguration instead.
            raise ValueError(
                f"{type(self).__name__} has no action types to sample from; "
                "pass a non-empty `actions` list when constructing it."
            )
        random_create_method_name = self.rng.choice(self._method_names)
        random_create_method = getattr(self, random_create_method_name)
        action = random_create_method(agent_id)
        return cast(Action, action)
class ConstantSampler(ActionSampler):
    """An Action factory using constant, prespecified action parameters.

    This Action factory samples actions with constant parameters.

    The values of action parameters used are set at initialization time and
    remain the same for all actions created by this factory. For example,
    if you specify `rotation_degrees=5.0`, all actions created by this factory
    that take a `rotation_degrees` parameter will have it set to `5.0`.

    When sampling an Action, only applicable parameters are used. For example,
    when sampling a `MoveForward` action, only the ConstantSampler's
    `translation_distance` parameter is used to determine the action's
    `distance` parameter.

    Args:
        absolute_degrees: Angle used by `SetAgentPitch`, `SetSensorPitch`
            (as `pitch_degrees`) and `SetYaw` (as `rotation_degrees`).
        actions: Action types that `sample` may choose from.
        direction: Direction passed to `MoveTangentially`. Defaults to
            `[0.0, 0.0, 0.0]`.
        location: Location passed to `SetAgentPose` and `SetSensorPose`.
            Defaults to `[0.0, 0.0, 0.0]`.
        rng: Random number generator used by `sample` to pick an action type.
            A fresh `default_rng()` is created when omitted.
        rotation_degrees: Rotation used by `LookDown`, `LookUp`,
            `OrientHorizontal`, `OrientVertical`, `TurnLeft` and `TurnRight`.
        rotation_quat: Rotation passed to `SetAgentPose`, `SetSensorPose` and
            `SetSensorRotation`. Defaults to `qt.one`.
        translation_distance: Distance used by `MoveForward`,
            `MoveTangentially`, and as both distances of `OrientHorizontal`
            and `OrientVertical`.
        **kwargs: Ignored. Accepted for compatibility (presumably so that
            configs shared with other samplers can be passed unchanged).
    """

    def __init__(
        self,
        absolute_degrees: float = 0.0,
        actions: Optional[List[Type[Action]]] = None,
        direction: Optional[VectorXYZ] = None,
        location: Optional[VectorXYZ] = None,
        rng: Optional[Generator] = None,
        rotation_degrees: float = 5.0,
        rotation_quat: Optional[QuaternionWXYZ] = None,
        translation_distance: float = 0.004,
        **kwargs,  # Accept arbitrary keyword arguments for compatibility
    ):
        super().__init__(actions=actions, rng=rng)
        self.absolute_degrees = absolute_degrees
        # `None` defaults plus a fresh list here avoid sharing one mutable
        # default argument across sampler instances.
        self.direction = direction if direction is not None else [0.0, 0.0, 0.0]
        self.location = location if location is not None else [0.0, 0.0, 0.0]
        self.rotation_degrees = rotation_degrees
        self.rotation_quat = rotation_quat if rotation_quat is not None else qt.one
        self.translation_distance = translation_distance

    def sample_look_down(self, agent_id: str) -> LookDown:
        """Create a `LookDown` by the constant `rotation_degrees`."""
        return LookDown(agent_id=agent_id, rotation_degrees=self.rotation_degrees)

    def sample_look_up(self, agent_id: str) -> LookUp:
        """Create a `LookUp` by the constant `rotation_degrees`."""
        return LookUp(agent_id=agent_id, rotation_degrees=self.rotation_degrees)

    def sample_move_forward(self, agent_id: str) -> MoveForward:
        """Create a `MoveForward` of the constant `translation_distance`."""
        return MoveForward(agent_id=agent_id, distance=self.translation_distance)

    def sample_move_tangentially(self, agent_id: str) -> MoveTangentially:
        """Create a `MoveTangentially` along the constant `direction`."""
        return MoveTangentially(
            agent_id=agent_id,
            distance=self.translation_distance,
            direction=self.direction,
        )

    def sample_orient_horizontal(self, agent_id: str) -> OrientHorizontal:
        """Create an `OrientHorizontal` using the constant parameters."""
        return OrientHorizontal(
            agent_id=agent_id,
            rotation_degrees=self.rotation_degrees,
            left_distance=self.translation_distance,
            forward_distance=self.translation_distance,
        )

    def sample_orient_vertical(self, agent_id: str) -> OrientVertical:
        """Create an `OrientVertical` using the constant parameters."""
        return OrientVertical(
            agent_id=agent_id,
            rotation_degrees=self.rotation_degrees,
            down_distance=self.translation_distance,
            forward_distance=self.translation_distance,
        )

    def sample_set_agent_pitch(self, agent_id: str) -> SetAgentPitch:
        """Create a `SetAgentPitch` to the constant `absolute_degrees`."""
        return SetAgentPitch(agent_id=agent_id, pitch_degrees=self.absolute_degrees)

    def sample_set_agent_pose(self, agent_id: str) -> SetAgentPose:
        """Create a `SetAgentPose` from the constant location and rotation."""
        return SetAgentPose(
            agent_id=agent_id, location=self.location, rotation_quat=self.rotation_quat
        )

    def sample_set_sensor_pitch(self, agent_id: str) -> SetSensorPitch:
        """Create a `SetSensorPitch` to the constant `absolute_degrees`."""
        return SetSensorPitch(agent_id=agent_id, pitch_degrees=self.absolute_degrees)

    def sample_set_sensor_pose(self, agent_id: str) -> SetSensorPose:
        """Create a `SetSensorPose` from the constant location and rotation."""
        return SetSensorPose(
            agent_id=agent_id, location=self.location, rotation_quat=self.rotation_quat
        )

    def sample_set_sensor_rotation(self, agent_id: str) -> SetSensorRotation:
        """Create a `SetSensorRotation` to the constant `rotation_quat`."""
        return SetSensorRotation(agent_id=agent_id, rotation_quat=self.rotation_quat)

    def sample_set_yaw(self, agent_id: str) -> SetYaw:
        """Create a `SetYaw` to the constant `absolute_degrees`."""
        return SetYaw(agent_id=agent_id, rotation_degrees=self.absolute_degrees)

    def sample_turn_left(self, agent_id: str) -> TurnLeft:
        """Create a `TurnLeft` by the constant `rotation_degrees`."""
        return TurnLeft(agent_id=agent_id, rotation_degrees=self.rotation_degrees)

    def sample_turn_right(self, agent_id: str) -> TurnRight:
        """Create a `TurnRight` by the constant `rotation_degrees`."""
        return TurnRight(agent_id=agent_id, rotation_degrees=self.rotation_degrees)