# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import os
import pickle
"""
Based on https://github.com/huggingface/transformers/blob/1438c487df5ce38a7b2ae30877b3074b96a423dd/src/transformers/trainer_callback.py
"""
class LoggingCallbackHandler:
    """Dispatch an event (e.g. ``post_train``) to every registered logger.

    For each event, the logger method with the same name as the event is
    called, in list order, with three keyword arguments:

        logger_args: dict with time stamps (steps, epochs, etc.) and
            dataloader.primary_target which contains object id and pose
        output_dir: full path of the directory in which to store log files
        model: the model object given to this handler at construction

    Args:
        loggers: a list of loggers, or a single ``BaseMontyLogger`` instance
            (which is wrapped in a one-element list).
        model: forwarded to every logger call as ``model``.
        output_dir: forwarded to every logger call as ``output_dir``.

    Note:
        This handler is intended primarily for logging.
    """

    def __init__(self, loggers, model, output_dir):
        # Accept a lone logger for convenience; dispatch always iterates.
        self.loggers = (
            [loggers] if isinstance(loggers, BaseMontyLogger) else loggers
        )
        self.model = model
        self.output_dir = output_dir

    @property
    def logger_list(self):
        """Class names of the registered loggers, one per line."""
        names = [logger.__class__.__name__ for logger in self.loggers]
        return "\n".join(names)

    def pre_step(self, logger_args):
        """Forward the ``pre_step`` event to every logger."""
        self.call_event("pre_step", logger_args)

    def post_step(self, logger_args):
        """Forward the ``post_step`` event to every logger."""
        self.call_event("post_step", logger_args)

    def pre_episode(self, logger_args):
        """Forward the ``pre_episode`` event to every logger."""
        self.call_event("pre_episode", logger_args)

    def post_episode(self, logger_args):
        """Forward the ``post_episode`` event to every logger."""
        self.call_event("post_episode", logger_args)

    def pre_epoch(self, logger_args):
        """Forward the ``pre_epoch`` event to every logger."""
        self.call_event("pre_epoch", logger_args)

    def post_epoch(self, logger_args):
        """Forward the ``post_epoch`` event to every logger."""
        self.call_event("post_epoch", logger_args)

    def pre_train(self, logger_args):
        """Forward the ``pre_train`` event to every logger."""
        self.call_event("pre_train", logger_args)

    def post_train(self, logger_args):
        """Forward the ``post_train`` event to every logger."""
        self.call_event("post_train", logger_args)

    def pre_eval(self, logger_args):
        """Forward the ``pre_eval`` event to every logger."""
        self.call_event("pre_eval", logger_args)

    def post_eval(self, logger_args):
        """Forward the ``post_eval`` event to every logger."""
        self.call_event("post_eval", logger_args)

    def close(self, logger_args):
        """Forward the ``close`` event to every logger."""
        self.call_event("close", logger_args)

    def call_event(self, event, logger_args):
        """Call the method named ``event`` on each logger, in order.

        Args:
            event: name of the logger method to invoke (e.g. "post_episode").
            logger_args: passed through unchanged to every logger.
        """
        for logger in self.loggers:
            hook = getattr(logger, event)
            hook(
                logger_args=logger_args,
                output_dir=self.output_dir,
                model=self.model,
            )
class BaseMontyLogger:
    """Basic logger that logs or saves information when logging is called.

    Every event hook here is a no-op; subclasses override the hooks they
    need. Each hook is called with ``logger_args`` (the event context),
    ``output_dir`` (where log files go) and ``model``.

    Args:
        handlers: iterable of handler objects. Each one must provide a
            ``close()`` method, which :meth:`close` calls.
    """

    def __init__(self, handlers):
        self.handlers = handlers

    def pre_step(self, logger_args, output_dir, model):
        """Hook for the ``pre_step`` event; no-op by default."""

    def post_step(self, logger_args, output_dir, model):
        """Hook for the ``post_step`` event; no-op by default."""

    def pre_episode(self, logger_args, output_dir, model):
        """Hook for the ``pre_episode`` event; no-op by default."""

    def post_episode(self, logger_args, output_dir, model):
        """Hook for the ``post_episode`` event; no-op by default."""

    def pre_epoch(self, logger_args, output_dir, model):
        """Hook for the ``pre_epoch`` event; no-op by default."""

    def post_epoch(self, logger_args, output_dir, model):
        """Hook for the ``post_epoch`` event; no-op by default."""

    def pre_train(self, logger_args, output_dir, model):
        """Hook for the ``pre_train`` event; no-op by default."""

    def post_train(self, logger_args, output_dir, model):
        """Hook for the ``post_train`` event; no-op by default."""

    def pre_eval(self, logger_args, output_dir, model):
        """Hook for the ``pre_eval`` event; no-op by default."""

    def post_eval(self, logger_args, output_dir, model):
        """Hook for the ``post_eval`` event; no-op by default."""

    def close(self, logger_args, output_dir, model):
        """Close every handler held by this logger, in iteration order."""
        for h in self.handlers:
            h.close()
class TestLogger(BaseMontyLogger):
    """Logger that records the names of the events it receives.

    Each overridden hook appends its event name to ``self.log``. The
    ``pre_step``/``post_step`` hooks are inherited unchanged from
    ``BaseMontyLogger`` and are therefore not recorded. On ``close`` the
    recorded list is pickled to ``<output_dir>/fake_log.pkl``.

    Args:
        handlers: stored on ``self.handlers`` by ``BaseMontyLogger``.
    """

    # Despite its ``Test`` prefix this is a helper, not a test case. Without
    # this flag pytest tries to collect it in any test module that imports it
    # and emits a PytestCollectionWarning because the class has an __init__.
    __test__ = False

    def __init__(self, handlers):
        super().__init__(handlers)
        # Event names, in the order the hooks were called.
        self.log = []

    def pre_episode(self, logger_args, output_dir, model):
        self.log.append("pre_episode")

    def post_episode(self, logger_args, output_dir, model):
        self.log.append("post_episode")

    def pre_epoch(self, logger_args, output_dir, model):
        self.log.append("pre_epoch")

    def post_epoch(self, logger_args, output_dir, model):
        self.log.append("post_epoch")

    def pre_train(self, logger_args, output_dir, model):
        self.log.append("pre_train")

    def post_train(self, logger_args, output_dir, model):
        self.log.append("post_train")

    def pre_eval(self, logger_args, output_dir, model):
        self.log.append("pre_eval")

    def post_eval(self, logger_args, output_dir, model):
        self.log.append("post_eval")

    def close(self, logger_args, output_dir, model):
        """Pickle the recorded event names to ``fake_log.pkl`` in ``output_dir``."""
        # NOTE(review): unlike BaseMontyLogger.close, this override does not
        # close self.handlers — confirm that is intended.
        with open(os.path.join(output_dir, "fake_log.pkl"), "wb") as f:
            pickle.dump(self.log, f)

    def __deepcopy__(self, memo):
        # Do not create a new copy of this logger. It is created by the tests
        # outside the experiment, so deepcopy returns this same instance.
        return self