Source code for tbp.monty.frameworks.loggers.wandb_handlers

# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations

import json

import numpy as np
import pandas as pd
import wandb
from typing_extensions import override

from tbp.monty.frameworks.loggers.monty_handlers import MontyHandler
from tbp.monty.frameworks.utils.logging_utils import (
    format_columns_for_wandb,
    get_rgba_frames_single_sm,
    lm_stats_to_dataframe,
)
from tbp.monty.frameworks.utils.plot_utils import mark_obs


[docs]class WandbWrapper(MontyHandler): """Container for wandb handlers. Loops over a series of handlers which log different information without commiting (sending it to wandb). The wrapper finally commits all logs at once. This allows us to maintain control over the wandb global step. This class assumes reporting takes place once per episode, hence the wandb handlers have `report_episode` methods. """
[docs] def __init__( self, wandb_handlers: list, run_name: str, wandb_group: str | None = None, config: dict | None = None, resume_wandb_run: bool = False, wandb_id: str | None = None, ): self.name = run_name self.group = wandb_group self.config = config self.wandb_logger = wandb.init( name=self.name, group=self.group, project="Monty", config=config, resume=resume_wandb_run, id=wandb_id, ) self.wandb_handlers = [wandb_handler() for wandb_handler in wandb_handlers]
[docs] def report_episode(self, data, output_dir, episode, mode="train", **kwargs): for handler in self.wandb_handlers: handler.report_episode(data, output_dir, episode, mode=mode, **kwargs) wandb.log({}) # TODO: What is this for?
[docs] @classmethod def log_level(cls): return ""
[docs] def close(self): self.wandb_logger.finish()
[docs]class WandbHandler(MontyHandler): """Parent class for wandb loggers."""
[docs] def __init__(self): self.report_count = 0 self.variable_length_columns = [ "possible_object_poses", "possible_object_locations", "possible_object_sources", "possible_match_sources", "detected_path", ] self.post_init()
[docs] def post_init(self): """Handle additional initialization for subclasses. Call this to handle any additional initializations for subclasses not covered by init of `WandbHandler`. """ pass
[docs] @classmethod def log_level(cls): return ""
[docs] def report_episode(self, data, output_dir, mode="train", **kwargs): pass
[docs] def close(self): pass
[docs]class BasicWandbTableStatsHandler(WandbHandler): """Log LM episode stats to wandb as tables."""
[docs] @classmethod def log_level(cls): return "BASIC"
[docs] @override def report_episode(self, data, output_dir, episode, mode="train", **kwargs): ### # Log basic statistics # Ignore the episode value ### # Get stats data depending on mode (train or eval) basic_logs = data["BASIC"] mode_key = f"{mode}_stats" stats_table = f"{mode}_stats_table" stats = basic_logs.get(mode_key, {}) # if len(stats) > 0: df = lm_stats_to_dataframe(stats, format_for_wandb=True) # Filter df to only include columns without variable length entries, like # possible_object_poses const_columns = list(set(df.columns) - set(self.variable_length_columns)) const_df = df[const_columns] # shorthand for self.train_table = df or self.eval_table = df if not hasattr(self, stats_table): setattr(self, stats_table, const_df) else: # Don't log first episode twice setattr( self, stats_table, pd.concat([getattr(self, stats_table), const_df]), ) # print(getattr(self, stats_table)) table = wandb.Table(dataframe=getattr(self, stats_table)) wandb.log({stats_table: table}, commit=False) self.report_count += 1
[docs]class DetailedWandbTableStatsHandler(BasicWandbTableStatsHandler): """Log LM stats and actions to wandb as tables. This is a modified version of BasicWandbTableStatsHandler that, in addition to the stats, logs the actions exectuted in each episode to wandb as tables (one table per episode). """
[docs] def __init__(self): super().__init__()
[docs] @classmethod def log_level(cls): return "DETAILED"
[docs] def report_episode(self, data, output_dir, episode, mode="train", **kwargs): super().report_episode(data, output_dir, episode, mode, **kwargs) basic_logs = data["BASIC"] # Get actions depending on mode (train or eval) action_key = f"{mode}_actions" action_data = basic_logs.get(action_key, {}) assert len(action_data) == 1, "why do we have keys for multiple or no episodes" # Log one table of actions per episode # for episode in action_data.keys(): # TODO: is table the best format for this? episode = list(action_data.keys())[0] table_name = f"{mode}_actions/episode_{episode}_table" actions = action_data[episode] for i, action in enumerate(actions): a = action[0] if a is not None: o = {} for key, value in dict(a).items(): if key in {"action", "agent_id"}: continue # don't duplicate action or agent_id in "params" if isinstance(value, np.ndarray): o[key] = value.tolist() else: o[key] = value actions[i][0] = { f"{a.agent_id}": {"action": a.name, "params": json.dumps(o)} } actions_df = pd.DataFrame(actions) table = wandb.Table(dataframe=actions_df) wandb.log({table_name: table}, step=episode)
[docs]class BasicWandbChartStatsHandler(WandbHandler): """Log LM episode stats to wandb with one chart per measure."""
[docs] @classmethod def log_level(cls): return "BASIC"
[docs] @override def report_episode(self, data, output_dir, episode, mode="train", **kwargs): basic_logs = data["BASIC"] mode_key = f"{mode}_overall_stats" stats = basic_logs.get(mode_key, {}) wandb.log(stats[episode], step=episode, commit=False)
[docs] def get_safe_columns_per_lm(self, stats): """Format each episode by looping over learning modules and formatting each one. Args: stats: dict ~ {LM_0: dict, LM_1: dict} Returns: The formatted stats. """ safe_stats = {} for lm, lm_dict in stats.items(): safe_lm_dict = { lm_col: lm_val for lm_col, lm_val in lm_dict.items() if lm_col not in self.variable_length_columns } formatted_lm_dict = format_columns_for_wandb(safe_lm_dict) safe_stats[lm] = formatted_lm_dict return safe_stats
[docs]class DetailedWandbHandler(WandbHandler): """Make animations from sequences of observations on wandb. NOTE: not yet generalized for different model architectures. This assumes SM_0 is the patch, SM_1 is the view finder. """
[docs] def post_init(self): self.report_key = "raw_rgba"
[docs] def get_episode_frames(self, episode_stats): frames_per_sm = {} sm_ids = [sm for sm in episode_stats.keys() if sm.startswith("SM_")] for sm in sm_ids: observations = episode_stats[sm]["raw_observations"] frames_per_sm[sm] = get_rgba_frames_single_sm(observations) return frames_per_sm
[docs] @override def report_episode(self, data, output_dir, episode, mode="train", **kwargs): # mode is ignored when reporting this episode detailed_stats = data["DETAILED"] frames_per_sm = self.get_episode_frames(detailed_stats[episode]) for sm, frames in frames_per_sm.items(): wandb.log( { f"episode_{episode}_{self.report_key}_{sm}": wandb.Video( frames, format="gif" ) }, step=episode, commit=False, )
[docs]class DetailedWandbMarkedObsHandler(DetailedWandbHandler): """Just like DetailedWandbHandler, but use fancier observations. NOTE: this assumes sm1 and sm0 are the view finder and patch modules respectively, meaning this logger is specific to the model architecture NOTE: this is slow, adding ~ a few seconds per function call. The intended use case is for debugging and error analysis, so speed should not be an issue when the number of episodes is small. But probably do not use this fi you are running a large number of experiments. """
[docs] def post_init(self): self.report_key = "marked_obs"
[docs] def get_episode_frames(self, episode_stats): frame_key = "patch_view" frame_dict = {frame_key: []} for step in range(len(episode_stats["SM_1"]["raw_observations"])): viz_obs = episode_stats["SM_1"]["raw_observations"][step] patch_obs = episode_stats["SM_0"]["raw_observations"][step] frame = mark_obs(viz_obs, patch_obs) wandb_frame = np.moveaxis(frame, [0, 1, 2], [1, 2, 0]) # format for wandb frame_dict[frame_key].append(wandb_frame) frame_dict[frame_key] = np.array(frame_dict[frame_key]) return frame_dict