Source code for tbp.monty.frameworks.environments.object_init_samplers

# Copyright 2025-2026 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations

from typing import Sequence, TypedDict, cast

import numpy as np
import numpy.typing as npt
from scipy.spatial.transform import Rotation
from typing_extensions import NotRequired

from tbp.monty.frameworks.experiments.mode import ExperimentMode
from tbp.monty.frameworks.experiments.seed import episode_seed
from tbp.monty.frameworks.utils.transform_utils import (
    rotation_as_quat,
)
from tbp.monty.math import EulerAnglesXYZ, QuaternionWXYZ, VectorXYZ


class MultiObjectNames(TypedDict):
    """Typed dictionary bundling object-name lists with a distractor count.

    NOTE(review): the key descriptions below are inferred from the key names
    only; confirm them against the code that builds and consumes this dict.

    Keys:
        targets_list: Sequence of object names, presumably the target objects.
        source_object_list: Sequence of object names, presumably the pool that
            additional objects are drawn from.
        num_distractors: Integer count, presumably the number of distractor
            objects to add.
    """

    targets_list: Sequence[str]
    source_object_list: Sequence[str]
    num_distractors: int
class ObjectInitParams(TypedDict):
    """Initial pose and scale of an object, as returned by the samplers here.

    Keys:
        position: Object position as an ``(x, y, z)`` vector.
        rotation: Object orientation as a ``(w, x, y, z)`` quaternion.
        scale: Per-axis scale as an ``(x, y, z)`` vector.
        euler_rotation: The same orientation as "xyz" Euler angles. The
            samplers in this module specify these angles in degrees.
        quat_rotation: Optional key. When present, the samplers in this module
            fill it with the array returned by SciPy's ``Rotation.as_quat()``.
    """

    position: VectorXYZ
    rotation: QuaternionWXYZ
    scale: VectorXYZ
    euler_rotation: npt.NDArray[np.float64] | EulerAnglesXYZ
    quat_rotation: NotRequired[npt.NDArray[np.float64]]
class Default:
    """Sampler that draws a random orientation and a random x position."""

    def __call__(
        self,
        seed: int,
        mode: ExperimentMode,
        epoch: int,  # noqa: ARG002
        episode: int,
    ) -> ObjectInitParams:
        """Sample the initial parameters of an object for one episode.

        Args:
            seed: Base seed, combined with ``mode`` and ``episode`` by
                ``episode_seed`` to seed the random number generator.
            mode: Experiment mode, forwarded to ``episode_seed``.
            epoch: Unused by this sampler.
            episode: Episode index, forwarded to ``episode_seed``.

        Returns:
            Parameters with "xyz" Euler angles drawn uniformly from [0, 360)
            degrees, the matching wxyz quaternion, an x position drawn
            uniformly from [-0.5, 0.5) with y and z at 0.0, and unit scale.
        """
        rng = np.random.RandomState(episode_seed(seed, mode, episode))

        # Draw order matters for reproducibility: the angles first, then x.
        angles_deg = rng.uniform(0, 360, 3)
        orientation = Rotation.from_euler("xyz", angles_deg, degrees=True)
        x_offset = rng.uniform(-0.5, 0.5)

        quat_wxyz = cast("QuaternionWXYZ", tuple(rotation_as_quat(orientation)))
        return {
            "rotation": quat_wxyz,
            "euler_rotation": angles_deg,
            "position": (x_offset, 0.0, 0.0),
            "scale": (1.0, 1.0, 1.0),
        }
class Predefined(Default):
    """Sampler that cycles through explicitly listed positions, rotations, scales."""

    def __init__(
        self,
        positions: Sequence[VectorXYZ] | None = None,
        rotations: Sequence[EulerAnglesXYZ] | None = None,
        scales: Sequence[VectorXYZ] | None = None,
        change_every_episode: bool | None = None,
    ):
        """Store the parameter lists to cycle through.

        Args:
            positions: Candidate positions. A falsy value (None or empty)
                falls back to ``[(0.0, 1.5, 0.0)]``.
            rotations: Candidate "xyz" Euler rotations in degrees. A falsy
                value falls back to ``[(0.0, 0.0, 0.0), (45.0, 0.0, 0.0)]``.
            scales: Candidate scales. A falsy value falls back to
                ``[(1.0, 1.0, 1.0)]``.
            change_every_episode: If truthy, the episode index selects the
                parameters; otherwise the epoch index does.
        """
        # NOTE: change_every_episode makes it possible to run an experiment
        # with an exact list of objects and a specific pose per object.
        # Without it, every specified pose is looped over for all objects.
        self.positions = positions if positions else [(0.0, 1.5, 0.0)]
        self.rotations = (
            rotations if rotations else [(0.0, 0.0, 0.0), (45.0, 0.0, 0.0)]
        )
        self.scales = scales if scales else [(1.0, 1.0, 1.0)]
        self.change_every_episode = change_every_episode

    def __call__(
        self,
        seed: int,  # noqa: ARG002
        mode: ExperimentMode,  # noqa: ARG002
        epoch: int,
        episode: int,
    ) -> ObjectInitParams:
        """Return the predefined parameters selected by the current counter.

        Args:
            seed: Unused by this sampler.
            mode: Unused by this sampler.
            epoch: Epoch index; used as the counter unless
                ``change_every_episode`` is truthy.
            episode: Episode index; used as the counter when
                ``change_every_episode`` is truthy.

        Returns:
            Parameters whose rotation, position and scale are each picked from
            their own list at ``counter % len(list)``.
        """
        counter = episode if self.change_every_episode else epoch

        # Each list wraps around independently of the others.
        euler_angles = self.rotations[counter % len(self.rotations)]
        orientation = Rotation.from_euler("xyz", euler_angles, degrees=True)
        quat_wxyz = cast("QuaternionWXYZ", tuple(rotation_as_quat(orientation)))
        return {
            "rotation": quat_wxyz,
            "euler_rotation": euler_angles,
            "quat_rotation": orientation.as_quat(),
            "position": self.positions[counter % len(self.positions)],
            "scale": self.scales[counter % len(self.scales)],
        }

    def __len__(self):
        """Return the number of rotation/scale/position combinations."""
        return len(self.all_combinations_of_params())

    def all_combinations_of_params(self):
        """Enumerate every rotation/scale/position combination.

        Returns:
            A list of dicts with the keys ``rotations``, ``scales`` and
            ``positions``, each holding a single-element list. Rotations vary
            slowest and positions vary fastest.
        """
        return [
            {"rotations": [rotation], "scales": [scale], "positions": [position]}
            for rotation in self.rotations
            for scale in self.scales
            for position in self.positions
        ]
class RandomRotation(Default):
    """Sampler with a fixed position and scale but a random orientation."""

    def __init__(
        self,
        position: VectorXYZ | None = None,
        scale: VectorXYZ | None = None,
    ):
        """Store the fixed position and scale.

        Args:
            position: Position used for every sample. Defaults to
                ``(0.0, 1.5, 0.0)`` when None.
            scale: Scale used for every sample. Defaults to
                ``(1.0, 1.0, 1.0)`` when None.
        """
        self.position = (0.0, 1.5, 0.0) if position is None else position
        self.scale = (1.0, 1.0, 1.0) if scale is None else scale

    def __call__(
        self,
        seed: int,
        mode: ExperimentMode,
        epoch: int,  # noqa: ARG002
        episode: int,
    ) -> ObjectInitParams:
        """Sample a random orientation for one episode.

        Args:
            seed: Base seed, combined with ``mode`` and ``episode`` by
                ``episode_seed`` to seed the random number generator.
            mode: Experiment mode, forwarded to ``episode_seed``.
            epoch: Unused by this sampler.
            episode: Episode index, forwarded to ``episode_seed``.

        Returns:
            Parameters with "xyz" Euler angles drawn uniformly from [0, 360)
            degrees, the matching wxyz quaternion, the array from SciPy's
            ``as_quat()``, and the stored position and scale.
        """
        rng = np.random.RandomState(episode_seed(seed, mode, episode))
        angles_deg = rng.uniform(0, 360, 3)
        orientation = Rotation.from_euler("xyz", angles_deg, degrees=True)

        quat_wxyz = cast("QuaternionWXYZ", tuple(rotation_as_quat(orientation)))
        return {
            "rotation": quat_wxyz,
            "euler_rotation": angles_deg,
            "quat_rotation": orientation.as_quat(),
            "position": self.position,
            "scale": self.scale,
        }