# Source code for tbp.monty.frameworks.models.mixins.no_reset_evidence
# Copyright 2025 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from typing import Any, Dict
from scipy.spatial.transform import Rotation
from tbp.monty.frameworks.models.evidence_matching import EvidenceGraphLM
from tbp.monty.frameworks.utils.logging_utils import compute_pose_error
class TheoreticalLimitLMLoggingMixin:
    """Mixin that adds theoretical limit and pose error logging for learning modules.

    This mixin augments the learning module with methods to compute and log:
      - The maximum evidence score for each object.
      - The theoretical lower bound of pose error on the target object, assuming
        Monty had selected the best possible hypothesis (oracle performance).
      - The actual pose error of the most likely hypothesis (MLH) on the target
        object.

    These metrics are useful for analyzing the performance gap between the model's
    current inference and its best achievable potential given its internal
    hypotheses.

    Compatible with:
      - EvidenceGraphLM
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Ensure the mixin is used only with compatible learning modules.

        The direct bases of the newly created subclass are inspected; at least
        one of them must be EvidenceGraphLM or a subclass of it.

        Raises:
            TypeError: If the mixin is used with a non-compatible learning module.
        """
        super().__init_subclass__(**kwargs)
        if not any(issubclass(base, EvidenceGraphLM) for base in cls.__bases__):
            raise TypeError(
                "TheoreticalLimitLMLoggingMixin must be mixed in with a subclass of "
                f"EvidenceGraphLM, got {cls.__bases__}"
            )

    def _add_detailed_stats(self, stats: Dict[str, Any]) -> Dict[str, Any]:
        """Add detailed statistics to the logging dictionary.

        This includes metrics like the max evidence score per object, the
        theoretical limit of Monty (i.e., pose error of Monty's best potential
        hypothesis on the target object), and the pose error of the MLH
        hypothesis on the target object.

        Objects whose evidence container is empty (no hypotheses) are omitted
        from ``stats["max_evidence"]``: there is no maximum to report, and
        ``max()`` would otherwise raise ``ValueError`` and abort logging.

        Args:
            stats: The existing statistics dictionary to augment. It is
                modified in place and also returned.

        Returns:
            The updated statistics dictionary, with the keys
            ``"max_evidence"``, ``"target_object_theoretical_limit"`` and
            ``"target_object_pose_error"`` set.
        """
        # NOTE(review): this override does not call super()._add_detailed_stats,
        # so if a base class defines one, its detailed stats are not added here.
        # Confirm that replacing (rather than extending) them is intended.
        stats["max_evidence"] = {
            object_id: max(evidence)
            for object_id, evidence in self.evidence.items()
            # Builtin max() raises on an empty sequence; skip such objects.
            if len(evidence) > 0
        }
        stats["target_object_theoretical_limit"] = (
            self._theoretical_limit_target_object_pose_error()
        )
        stats["target_object_pose_error"] = self._mlh_target_object_pose_error()
        return stats

    def _theoretical_limit_target_object_pose_error(self) -> float:
        """Compute the theoretical minimum rotation error on the target object.

        This considers all possible hypothesis rotations on the target object
        and compares them to the target's rotation. The theoretical limit conveys
        the best achievable performance if Monty selects the best hypothesis as
        its most likely hypothesis (MLH).

        Note that having a low pose error for the theoretical limit may not be
        sufficient for deciding on the quality of the hypothesis. Despite good
        hypotheses being generally correlated with good theoretical limit, it is
        possible for rotation error to be small (i.e., low geodesic distance to
        ground-truth rotation), while the hypothesis is on a different location
        of the object.

        Returns:
            The minimum achievable rotation error (in radians).
        """
        # Hypothesis poses are read as rotation matrices and inverted before the
        # comparison. NOTE(review): presumably possible_poses holds the inverse
        # of the object pose; confirm against EvidenceGraphLM.
        hyp_rotations = Rotation.from_matrix(
            self.possible_poses[self.primary_target]
        ).inv()
        # scipy's Rotation.from_quat expects scalar-last (x, y, z, w) by default.
        target_rotation = Rotation.from_quat(self.primary_target_rotation_quat)
        return compute_pose_error(hyp_rotations, target_rotation)

    def _mlh_target_object_pose_error(self) -> float:
        """Compute the actual rotation error between predicted and target pose.

        This compares the most likely hypothesis pose (based on evidence) on the
        target object with the ground truth rotation of the target object.

        Returns:
            The rotation error (in radians).
        """
        # Inverted for the same reason as the hypotheses in the theoretical
        # limit computation, so both metrics are measured consistently.
        obj_rotation = self.get_mlh_for_object(self.primary_target)["rotation"].inv()
        # scipy's Rotation.from_quat expects scalar-last (x, y, z, w) by default.
        target_rotation = Rotation.from_quat(self.primary_target_rotation_quat)
        return compute_pose_error(obj_rotation, target_rotation)