# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import copy
import numpy as np
import pandas as pd
import wandb
from sklearn.preprocessing import LabelEncoder
from tbp.monty.frameworks.loggers.exp_logger import BaseMontyLogger
from tbp.monty.frameworks.utils.logging_utils import (
get_stats_per_lm,
target_data_to_dict,
total_size,
)
"""
Loggers define a pool of data in the attribute self.data. The job of loggers is to
create and manage this pool of data, including flushing it periodically. The job of
handlers is to grab subsets of the data pool, format it, and send it to a location
(wandb, or a certain file). The data pool can be updated at any of the callback points
in the experiment, for example, `post_episode`. The loggers receive a standard set of
arguments used for building the data pool, most notably a reference to the model. During
reporting, all handlers are called upon to send data to a destination. Handlers also all
receive a standard set of arguments, including most importantly the self.data pool.
The Basic logger is associated with the BASIC monty_log_level specified in
config_args.logging_config. The Detailed logger is associated with the DETAILED logging
level. The Detailed logger data pool contains a superset of the data in the basic logger
self.data attribute.
The term logger is used unconventionally here. Normally, the thing being logged is a
string of text and the destination is a file or the terminal screen. Here, the stuff
being logged is structured data. Perhaps the term DataReporter is more apt.
NOTE: previously, we updated the data pool every episode, and logged at most once per
episode. Reporting and flushing frequency were based on the size of self.data. To make
things easier to think about for handlers, we moved reporting and flushing both to take
place at the end of an episode. Why all the callbacks then, like post_step, post_train,
etc.? We are building a plane while flying it. Best not to throw the landing gear out
the window just because it isn't being used right now.
"""
class BasicGraphMatchingLogger(BaseMontyLogger):
    """Basic logger that logs or saves information when logging is called.

    Keeps a per-episode data pool in ``self.data`` (sections ``"BASIC"`` and
    ``"DETAILED"``) together with running train/eval statistics, and passes the
    data pool to every handler at the end of each episode.
    """

    def __init__(
        self,
        handlers,
    ):
        """Initialize logger.

        Args:
            handlers: Iterable of handler objects. Each one is called as
                ``handler.report_episode(data, output_dir, episode, mode)``.
        """
        self.handlers = handlers
        self.data = self._empty_data()
        self.overall_train_stats = {
            "num_episodes": 0,
            "num_correct": 0,
            "num_correct_mlh": 0,
            "num_no_match": 0,
            "num_confused": 0,
            "num_confused_mlh": 0,
            "num_pose_time_out": 0,
            "num_time_out": 0,
            "num_patch_off_object": 0,
            "num_no_label": 0,
            "num_correct_per_lm": 0,
            "num_correct_mlh_per_lm": 0,
            "num_no_match_per_lm": 0,
            "num_confused_per_lm": 0,
            "num_confused_mlh_per_lm": 0,
            "num_pose_time_out_per_lm": 0,
            "num_time_out_per_lm": 0,
            "num_patch_off_object_per_lm": 0,
            "num_no_label_per_lm": 0,
            "episode_correct": 0,
            "episode_correct_mlh": 0,
            "episode_no_match": 0,
            "episode_confused": 0,
            "episode_confused_mlh": 0,
            "episode_pose_time_out": 0,
            "episode_time_out": 0,
            "episode_lm_performances": [],
            # Total number of steps performed during the episode,
            # including steps where no sensory data was passed to the
            # learning-modules:
            "monty_steps": [],
            # Number of global monty *matching* steps. Counts steps when at least
            # one LM was updated:
            "monty_matching_steps": [],
            # Number of steps associated with an individual LM processing data,
            # i.e. can differ across the LMs of a Monty model:
            "episode_lm_steps": [],
            "episode_lm_steps_indv_ts": [],
            "episode_symmetry_evidence": [],
            "rotation_errors": [],
            "run_times": [],
            # Policy stats
            "goal_states_attempted": 0,
            "goal_state_success_rate": 0,
        }
        # Deep copy so the list-valued entries are not shared between modes.
        self.overall_eval_stats = copy.deepcopy(self.overall_train_stats)

        # Names of the per-LM entries ("LM_*"); filled on the first call to
        # update_episode_data.
        self.lms = []

        # Order of performance_options matters since we check them in sequence for
        # each lm. The lower in the list, the stronger it is to determine overall
        # performance. Performance lower down in the list will always trump
        # higher-up performance values. For example, if we have N LMs and one of
        # them has performance "correct", the overall episode performance is
        # correct. It doesn't matter if all the other LMs have a time_out
        # performance.
        # The three strong terminal condition cases (no_match, confused, correct)
        # need to be listed last. The weaker time out conditions first so they
        # don't overwrite the performance of an LM that caused the episode to end.
        # TODO: what if 2 LMs reach a strong terminal state at the same step? For
        # example if one LM reaches the correct state and the other confused the
        # performance will be logged as correct. However, in this case we would
        # probably want to keep moving or log a conflicting performance.
        # TODO: If we have a time out and look at the mlh, we should take the
        # majority vote and not let correct_mlh win
        self.performance_options = [
            "patch_off_object",
            "no_label",
            "pose_time_out",
            "time_out",
            "confused_mlh",
            "correct_mlh",
            "no_match",
            "confused",
            "correct",
        ]
        self.performance_encoder = LabelEncoder()
        self.performance_encoder.fit(self.performance_options)
        self.use_parallel_wandb_logging = False
        # NOTE: this changes a process-wide pandas display option.
        pd.set_option("display.max_rows", False)

    @staticmethod
    def _empty_data():
        """Return a fresh, empty data pool with the layout used by this class."""
        basic_sections = (
            "train_stats",
            "train_overall_stats",
            "train_targets",
            "train_actions",
            "train_timing",
            "eval_stats",
            "eval_overall_stats",
            "eval_actions",
            "eval_targets",
            "eval_timing",
        )
        return {
            "BASIC": {section: {} for section in basic_sections},
            "DETAILED": {},
        }

    def flush(self):
        """Drop all buffered episode data by resetting ``self.data``.

        The running ``overall_train_stats`` / ``overall_eval_stats`` are kept.
        """
        self.data = self._empty_data()

    def log_episode(self, logger_args, output_dir, model):
        """Report the current data pool to every handler, then flush it.

        Args:
            logger_args: Mapping that contains the ``"{mode}_episodes"`` entry.
            output_dir: Passed through to each handler.
            model: Object exposing ``experiment_mode``.
        """
        mode = model.experiment_mode
        episode = logger_args[f"{mode}_episodes"]
        for handler in self.handlers:
            handler.report_episode(self.data, output_dir, episode, mode)
        if not self.use_parallel_wandb_logging:
            # when logging in parallel to wandb we need to wait with flushing
            # until the parallel run script has retrieved the episode stats.
            self.flush()

    def maybe_log(self, logger_args, output_dir, model):
        """Left here in case we go back to size based logging.

        Remove if not used after code has stabilized.
        """
        # If we get above 10Mb, report, flush, continue
        if total_size(self.data) > 10_000_000:
            # NOTE(review): `log` is not defined in this class; presumably it is
            # expected on the base class (or `log_episode` was meant) - confirm.
            self.log(logger_args, output_dir, model)

    def post_episode(self, logger_args, output_dir, model):
        """Record the finished episode in the data pool and report it."""
        self.update_episode_data(logger_args, output_dir, model)
        self.log_episode(logger_args, output_dir, model)

    def update_episode_data(self, logger_args, output_dir, model):
        """Run get_stats_per_lm and add to overall stats.

        Store stats ~
            1 (episode)
                lm_0 (which lm)
                    stats

        Args:
            logger_args: Mapping with at least ``"target"`` and
                ``"{mode}_episodes"``; all entries except ``"target"`` are stored
                as the episode's timing record.
            output_dir: Unused here.
            model: Object exposing ``experiment_mode``, ``episode_steps``,
                ``matching_steps`` and ``motor_system.action_sequence``.
        """
        performance_dict = get_stats_per_lm(model, logger_args["target"])
        target_dict = target_data_to_dict(logger_args["target"])
        if not self.lms:  # first time function is called
            self.lms.extend(key for key in performance_dict if key.startswith("LM_"))

        mode = model.experiment_mode
        episode = logger_args[f"{mode}_episodes"]
        actions = model.motor_system.action_sequence
        logger_time = {k: v for k, v in logger_args.items() if k != "target"}

        basic = self.data["BASIC"]
        # The per-LM stats must be stored before update_overall_stats reads them.
        basic[f"{mode}_stats"][episode] = performance_dict
        self.update_overall_stats(
            mode, episode, model.episode_steps, model.matching_steps
        )
        overall_stats = self.get_formatted_overall_stats(mode, episode)

        basic[f"{mode}_overall_stats"][episode] = overall_stats
        basic[f"{mode}_actions"][episode] = actions
        basic[f"{mode}_targets"][episode] = target_dict
        basic[f"{mode}_timing"][episode] = logger_time
        basic[f"{mode}_stats"][episode]["target"] = target_dict

    def update_overall_stats(self, mode, episode, episode_steps, monty_matching_steps):
        """Update overall run stats for mode.

        Args:
            mode: ``"train"`` updates ``overall_train_stats``; any other value
                updates ``overall_eval_stats``.
            episode: Key of the episode in ``self.data["BASIC"]["{mode}_stats"]``.
            episode_steps: Value appended to ``monty_steps``.
            monty_matching_steps: Value appended to ``monty_matching_steps``.
        """
        if mode == "train":
            stats = self.overall_train_stats
        else:
            stats = self.overall_eval_stats

        lm_performances = []
        for lm in self.lms:
            # This accumulates stats from all LM
            episode_stats = self.data["BASIC"][f"{mode}_stats"][episode][lm]
            performance = episode_stats["primary_performance"]
            if performance is not None:  # in pre training performance is None
                # NOTE(review): the extent of this branch was reconstructed;
                # confirm that the step counters and goal-state stats below are
                # meant to be updated per LM rather than once per episode.
                stats[f"num_{performance}_per_lm"] += 1
                lm_performances.append(performance)
                stats["rotation_errors"].append(episode_stats["rotation_error"])
                stats["run_times"].append(episode_stats["time"])
                stats["episode_lm_steps"].append(episode_stats["num_steps"])
                stats["episode_lm_steps_indv_ts"].append(
                    episode_stats["individual_ts_reached_at_step"]
                )
                stats["episode_symmetry_evidence"].append(
                    episode_stats["symmetry_evidence"]
                )
                stats["monty_steps"].append(episode_steps)
                stats["monty_matching_steps"].append(monty_matching_steps)

                stats["goal_states_attempted"] = episode_stats["goal_states_attempted"]
                stats["goal_state_success_rate"] = (
                    episode_stats["goal_state_achieved"]
                    / episode_stats["goal_states_attempted"]
                    if episode_stats["goal_states_attempted"]
                    else 0  # Handles division by 0
                )

        stats["episode_lm_performances"].append(lm_performances)

        # Stays None when no LM reported a performance (e.g. pre training), in
        # which case no episode_* flag is set and no num_* counter is incremented.
        episode_performance = None
        for p in self.performance_options:
            if p in lm_performances:
                # order of performance_options matters since we overwrite here!
                # episode_performance is only no_match if no lm had another
                # performance. That makes it possible for some lms to have no match
                # but still have an overall performance of correct (or other).
                episode_performance = p

        for p in self.performance_options:
            is_episode_result = int(p == episode_performance)
            stats[f"episode_{p}"] = is_episode_result
            stats[f"num_{p}"] += is_episode_result
        stats["num_episodes"] += 1
class DetailedGraphMatchingLogger(BasicGraphMatchingLogger):
    """Log detailed stats as .json file by saving data for each LM and SM."""

    def __init__(self, handlers):
        """Initialize stats dicts."""
        super().__init__(handlers)
        # Map the train / eval episode counter to the combined (train + eval)
        # episode count under which detailed data is stored.
        self.train_episodes_to_total = {}
        self.eval_episodes_to_total = {}

    def log_episode(self, logger_args, output_dir, model):
        """Report the data pool to every handler, then flush it.

        In addition to the basic arguments, handlers receive the two
        ``*_episodes_to_total`` mappings as keyword arguments.
        """
        mode = model.experiment_mode
        episode = logger_args[f"{mode}_episodes"]
        kwargs = {
            "train_episodes_to_total": self.train_episodes_to_total,
            "eval_episodes_to_total": self.eval_episodes_to_total,
        }
        for handler in self.handlers:
            handler.report_episode(self.data, output_dir, episode, mode, **kwargs)
        self.flush()

    def post_episode(self, logger_args, output_dir, model):
        """Record the finished episode in the data pool and report it."""
        self.update_episode_data(logger_args, output_dir, model)
        self.log_episode(logger_args, output_dir, model)

    def update_episode_data(self, logger_args, output_dir, model):
        """Add episode data to overall buffer_data dict.

        Store stats ~
            1 (episode)
                lm_0 (which lm)
                    stats

        The detailed record is stored in ``self.data["DETAILED"]`` under the
        combined train + eval episode count.
        """
        # update train / eval stats
        super().update_episode_data(logger_args, output_dir, model)

        episodes = logger_args["train_episodes"] + logger_args["eval_episodes"]
        self.train_episodes_to_total[logger_args["train_episodes"]] = episodes
        self.eval_episodes_to_total[logger_args["eval_episodes"]] = episodes

        buffer_data = {}
        for i, lm in enumerate(model.learning_modules):
            # Later entries override earlier ones on key collisions, so the
            # order below is significant.
            lm_dict = dict(logger_args)
            lm_dict["locations"] = lm.buffer.locations
            lm_dict.update(lm.buffer.features)
            lm_dict["displacements"] = lm.buffer.displacements
            lm_dict.update(lm.buffer.stats)
            lm_dict["mode"] = model.experiment_mode
            lm_dict["stepwise_targets_list"] = lm.stepwise_targets_list
            buffer_data[f"LM_{i}"] = lm_dict  # NOTE: probably same for all LMs

        for i, sm in enumerate(model.sensor_modules):
            # Build the state dict once per sensor module instead of twice.
            sm_state = sm.state_dict()
            # Only sensor modules that recorded observations are stored.
            if len(sm_state["raw_observations"]) > 0:
                buffer_data[f"SM_{i}"] = sm_state

        # TODO ensure will work with multiple, independent sensor agents
        motor_data = {"action_sequence": model.motor_system.action_sequence}
        # Some motor systems store additional data specific to their policy, e.g.
        # when principal curvature has informed movements
        if hasattr(model.motor_system, "action_details"):
            motor_data["action_details"] = model.motor_system.action_details
        buffer_data["motor_system"] = motor_data

        self.data["DETAILED"][episodes] = buffer_data
class SelectiveEvidenceLogger(BasicGraphMatchingLogger):
    """Log evidences as .json file by saving data for each LM and SM.

    This is slimmed down to only log data needed for object similarity analysis.
    Data logged:
        - evidences for each object and pose at the end of an episode
        - first frame of the view finder
    """

    def __init__(self, handlers):
        """Initialize stats dicts."""
        super().__init__(handlers)
        # Map the train / eval episode counter to the combined (train + eval)
        # episode count under which detailed data is stored.
        self.train_episodes_to_total = {}
        self.eval_episodes_to_total = {}

    def log_episode(self, logger_args, output_dir, model):
        """Report the data pool to every handler, then flush it.

        In addition to the basic arguments, handlers receive the two
        ``*_episodes_to_total`` mappings as keyword arguments.
        """
        mode = model.experiment_mode
        episode = logger_args[f"{mode}_episodes"]
        kwargs = {
            "train_episodes_to_total": self.train_episodes_to_total,
            "eval_episodes_to_total": self.eval_episodes_to_total,
        }
        for handler in self.handlers:
            handler.report_episode(self.data, output_dir, episode, mode, **kwargs)
        self.flush()

    def post_episode(self, logger_args, output_dir, model):
        """Record the finished episode in the data pool and report it."""
        self.update_episode_data(logger_args, output_dir, model)
        self.log_episode(logger_args, output_dir, model)

    def update_episode_data(self, logger_args, output_dir, model):
        """Add episode data to overall buffer_data dict.

        The slimmed-down record is stored in ``self.data["DETAILED"]`` under the
        combined train + eval episode count.
        """
        # update train / eval stats
        super().update_episode_data(logger_args, output_dir, model)

        episodes = logger_args["train_episodes"] + logger_args["eval_episodes"]
        self.train_episodes_to_total[logger_args["train_episodes"]] = episodes
        self.eval_episodes_to_total[logger_args["eval_episodes"]] = episodes

        buffer_data = {}
        for i, lm in enumerate(model.learning_modules):
            lm_stats = lm.buffer.stats
            buffer_data[f"LM_{i}"] = {
                # Save evidences and hypotheses only for last step to save storage
                "evidences_ls": lm_stats["evidences"][-1],
                "possible_locations_ls": lm_stats["possible_locations"][-1],
                # possible rotations don't change over time so no indexing here
                "possible_rotations_ls": lm_stats["possible_rotations"],
                # Save possible matches, mlh and symmetry evidence for all steps
                "possible_matches": lm_stats["possible_matches"],
                "current_mlh": lm_stats["current_mlh"],
                "symmetry_evidence": lm_stats["symmetry_evidence"],
            }

        for i, sm in enumerate(model.sensor_modules):
            # Fetch the observations once per sensor module instead of twice.
            raw_observations = sm.state_dict()["raw_observations"]
            if len(raw_observations) > 0:
                # Only store first observation
                buffer_data[f"SM_{i}"] = raw_observations[0]

        self.data["DETAILED"][episodes] = buffer_data