Source code for tbp.monty.frameworks.loggers.monty_handlers

# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from __future__ import annotations

import abc
import copy
import json
import logging
import os
from pathlib import Path
from pprint import pformat
from typing import Container, Literal

from typing_extensions import override

from tbp.monty.frameworks.actions.actions import ActionJSONEncoder
from tbp.monty.frameworks.models.buffer import BufferEncoder
from tbp.monty.frameworks.utils.logging_utils import (
    lm_stats_to_dataframe,
    maybe_rename_existing_file,
)

logger = logging.getLogger(__name__)

###
# Template for MontyHandler
###


class MontyHandler(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def report_episode(self, **kwargs):
        pass

    @abc.abstractmethod
    def close(self):
        pass

    @classmethod
    @abc.abstractmethod
    def log_level(cls):
        """Handlers filter information from the data they receive.

        This class method specifies the level they filter at.
        """
        pass

###
# Handler classes
###

class DetailedJSONHandler(MontyHandler):
    """Grab any logs at the DETAILED level and append to a json file."""

    def __init__(
        self,
        detailed_episodes_to_save: Container[int] | Literal["all"] = "all",
        detailed_save_per_episode: bool = False,
        episode_id_parallel: int | None = None,
    ) -> None:
        """Initialize the DetailedJSONHandler.

        Args:
            detailed_episodes_to_save: Container of episodes to save or the string
                ``"all"`` (default) to include every episode.
            detailed_save_per_episode: Whether to save individual episode files or
                consolidate into a single detailed_run_stats.json file. Defaults to
                False.
            episode_id_parallel: Episode id associated with current run, used to
                identify the episode when using run_parallel.
        """
        self.already_renamed = False
        self.detailed_episodes_to_save = detailed_episodes_to_save
        self.detailed_save_per_episode = detailed_save_per_episode
        self.episode_id_parallel = episode_id_parallel

    @classmethod
    def log_level(cls):
        return "DETAILED"

    def _should_save_episode(self, global_episode_id: int) -> bool:
        """Check if episode should be saved.

        Returns:
            True if episode should be saved, False otherwise.
        """
        return (
            self.detailed_episodes_to_save == "all"
            or global_episode_id in self.detailed_episodes_to_save
        )

    def get_episode_id(
        self, local_episode, mode: Literal["train", "eval"], **kwargs
    ) -> int:
        """Get global episode id corresponding to a mode-local episode index.

        This function is needed to determine the correct episode id when using
        run_parallel.

        Returns:
            global_episode_id: Combined train+eval episode id.
        """
        return (
            kwargs[f"{mode}_episodes_to_total"][local_episode]
            if self.episode_id_parallel is None
            else self.episode_id_parallel
        )

    def get_detailed_stats(
        self,
        data,
        global_episode_id: int,
        local_episode: int,
        mode: Literal["train", "eval"],
    ) -> dict:
        """Get detailed episode stats.

        Returns:
            stats: Episode stats.
        """
        output_data = {}
        basic_stats = data["BASIC"][f"{mode}_stats"][local_episode]
        detailed_pool = data["DETAILED"]
        detailed_stats = detailed_pool.get(local_episode)
        if detailed_stats is None:
            detailed_stats = detailed_pool.get(global_episode_id)
        output_data[global_episode_id] = copy.deepcopy(basic_stats)
        output_data[global_episode_id].update(detailed_stats)
        return output_data

    def report_episode(self, data, output_dir, local_episode, mode="train", **kwargs):
        """Report episode data."""
        global_episode_id = self.get_episode_id(local_episode, mode, **kwargs)
        if not self._should_save_episode(global_episode_id):
            logger.debug(
                "Skipping detailed JSON for episode %s (not requested)",
                global_episode_id,
            )
            return
        stats = self.get_detailed_stats(data, global_episode_id, local_episode, mode)
        if self.detailed_save_per_episode:
            self._save_per_episode(output_dir, global_episode_id, stats)
        else:
            self._save_all(global_episode_id, output_dir, stats)

    def _save_per_episode(self, output_dir: str, global_episode_id: int, stats: dict):
        """Save detailed stats for a single episode.

        Args:
            output_dir: Directory where results are written.
            global_episode_id: Combined train+eval episode id used for DETAILED logs.
            stats: Dictionary containing episode stats keyed by global episode id.
        """
        episodes_dir = Path(output_dir) / "detailed_run_stats"
        os.makedirs(episodes_dir, exist_ok=True)
        episode_file = episodes_dir / f"episode_{global_episode_id:06d}.json"
        maybe_rename_existing_file(episode_file)
        with open(episode_file, "w") as f:
            json.dump(
                {global_episode_id: stats[global_episode_id]},
                f,
                cls=BufferEncoder,
            )
        logger.debug(
            "Saved detailed JSON for episode %s to %s",
            global_episode_id,
            str(episode_file),
        )

    def _save_all(self, global_episode_id: int, output_dir: str, stats: dict):
        """Save detailed stats for all episodes."""
        save_stats_path = Path(output_dir) / "detailed_run_stats.json"
        if not self.already_renamed:
            maybe_rename_existing_file(save_stats_path)
            self.already_renamed = True
        with open(save_stats_path, "a") as f:
            json.dump(
                {global_episode_id: stats[global_episode_id]},
                f,
                cls=BufferEncoder,
            )
            f.write(os.linesep)
        logger.debug(
            "Appended detailed stats for episode %s to %s",
            global_episode_id,
            str(save_stats_path),
        )

    def close(self):
        pass

class BasicCSVStatsHandler(MontyHandler):
    """Grab any logs at the BASIC level and append to train or eval CSV files."""

    @classmethod
    def log_level(cls):
        return "BASIC"

    def __init__(self):
        """Initialize with empty dictionary to keep track of writes per file.

        We only want to include the header the first time we write to a file.
        This keeps track of writes per file so we can format the file properly.
        """
        self.reports_per_file = {}

    @override
    def report_episode(self, data, output_dir, episode, mode="train", **kwargs):
        # episode is ignored when reporting stats to CSV

        # Look for train_stats or eval_stats under BASIC logs
        basic_logs = data["BASIC"]
        mode_key = f"{mode}_stats"
        output_file = Path(output_dir) / f"{mode}_stats.csv"
        stats = basic_logs.get(mode_key, {})
        logger.debug(pformat(stats))

        # Rename the file if it existed before to avoid appending to previous results
        if output_file not in self.reports_per_file:
            self.reports_per_file[output_file] = 0
            maybe_rename_existing_file(output_file)
        else:
            self.reports_per_file[output_file] += 1

        # Format stats for a single episode as a dataframe
        dataframe = lm_stats_to_dataframe(stats)

        # Move most relevant columns to front
        if "most_likely_object" in dataframe:
            top_columns = [
                "primary_performance",
                "stepwise_performance",
                "num_steps",
                "rotation_error",
                "result",
                "most_likely_object",
                "primary_target_object",
                "stepwise_target_object",
                "highest_evidence",
                "time",
                "symmetry_evidence",
                "monty_steps",
                "monty_matching_steps",
                "individual_ts_performance",
                "individual_ts_reached_at_step",
                "primary_target_position",
                "primary_target_rotation_euler",
                "most_likely_rotation",
            ]
        else:
            top_columns = [
                "primary_performance",
                "stepwise_performance",
                "num_steps",
                "rotation_error",
                "result",
                "primary_target_object",
                "stepwise_target_object",
                "time",
                "symmetry_evidence",
                "monty_steps",
                "monty_matching_steps",
                "primary_target_position",
                "primary_target_rotation_euler",
            ]
        dataframe = self.move_columns_to_front(
            dataframe,
            top_columns,
        )

        # Only include the header the first time we write to this file
        header = self.reports_per_file[output_file] < 1
        dataframe.to_csv(output_file, mode="a", header=header)

    def move_columns_to_front(self, df, columns):
        for c_key in reversed(columns):
            df.insert(0, c_key, df.pop(c_key))
        return df

    def close(self):
        pass

class ReproduceEpisodeHandler(MontyHandler):
    """Log the actions and target of each episode so the episode can be reproduced."""

    @classmethod
    def log_level(cls):
        return "BASIC"

    @override
    def report_episode(self, data, output_dir, episode, mode="train", **kwargs):
        # Set up data directory with reproducibility info for each episode
        if not hasattr(self, "data_dir"):
            self.data_dir = os.path.join(output_dir, "reproduce_episode_data")
            os.makedirs(self.data_dir, exist_ok=True)

        # TODO: store a pointer to the training model, something like
        #   if train_epochs == 0: use model_name_or_path
        #   else: get checkpoint of most up to date model

        # Write data to action file
        action_file = f"{mode}_episode_{episode}_actions.jsonl"
        action_file_path = os.path.join(self.data_dir, action_file)
        actions = data["BASIC"][f"{mode}_actions"][episode]
        with open(action_file_path, "w") as f:
            f.writelines(
                f"{json.dumps(action[0], cls=ActionJSONEncoder)}\n"
                for action in actions
            )

        # Write data to object params / targets file
        object_file = f"{mode}_episode_{episode}_target.txt"
        object_file_path = os.path.join(self.data_dir, object_file)
        target = data["BASIC"][f"{mode}_targets"][episode]
        with open(object_file_path, "w") as f:
            json.dump(target, f, cls=BufferEncoder)

    def close(self):
        pass