Source code for tbp.monty.frameworks.utils.dataclass_utils
# Copyright 2025 Thousand Brains Project
# Copyright 2022-2024 Numenta Inc.
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import dataclasses
import importlib
from inspect import Parameter, signature
from typing import Callable, Optional, Type
__all__ = [
"as_dataclass_dict",
"create_dataclass_args",
"from_dataclass_dict",
]
# Keeps track of the dataclass type in a serializable dataclass dict
_DATACLASS_TYPE = "__dataclass_type__"
[docs]def as_dataclass_dict(obj):
"""Convert a dataclass instance to a serializable dataclass dict.
Args:
obj: The dataclass instance to convert
Returns:
A dictionary with the dataclass fields and values
Raises:
TypeError: If the object is not a dataclass instance
"""
if not dataclasses.is_dataclass(obj):
raise TypeError(f"Expecting dataclass instance but got {type(obj)}")
result = {_DATACLASS_TYPE: f"{obj.__module__}.{obj.__class__.__name__}"}
for f in dataclasses.fields(obj):
value = getattr(obj, f.name)
# Handle nested dataclassess
if dataclasses.is_dataclass(value):
value = as_dataclass_dict(value)
result[f.name] = value
return result
[docs]def from_dataclass_dict(datadict):
"""Convert a serializable dataclass dict back into the original dataclass.
Expecting that the serializable dataclass dict was created via :func:`.asdict`.
Args:
datadict: The serializable dataclass dict to convert
Returns:
The original dataclass instance
Raises:
TypeError: If the object is not a dict instance
"""
if not isinstance(datadict, dict):
raise TypeError(f"Expecting dict instance but got {type(datadict)}")
# Check for nested dataclass
kwargs = {}
for k, v in datadict.items():
if isinstance(v, dict):
v = from_dataclass_dict(v)
kwargs[k] = v
if _DATACLASS_TYPE not in kwargs:
# Not a dataclass dict
return kwargs
# Get dataclass module and type
module_name, class_name = kwargs.pop(_DATACLASS_TYPE).rsplit(".", 1)
# Load dataclass module and recreate the instance
module = importlib.import_module(module_name)
dataclass_type = getattr(module, class_name)
return dataclass_type(**kwargs)
def extract_fields(function):
# Extract function signature
sig = signature(function)
_fields = []
# Convert function parameters to dataclass fields
for p in sig.parameters.values():
# Ignore "self" in case of class methods
if p.name == "self":
continue
f = [p.name]
if p.default != Parameter.empty:
# Infer data type from default value
if p.annotation != Parameter.empty:
t = type(p.default)
else:
t = p.annotation
f.extend([t, dataclasses.field(default=p.default)])
elif p.annotation != Parameter.empty:
f.append(p.annotation)
_fields.append(f)
return _fields
[docs]def create_dataclass_args(
dataclass_name: str,
function: Callable,
base: Optional[Type] = None,
):
"""Creates configuration dataclass args from a given function arguments.
When the function arguments have type annotations these annotations will be
passed to the dataclass fields, otherwise the type will be inferred from the
argument default value, if any.
For example::
SingleSensorAgentArgs = create_dataclass_args(
"SingleSensorAgentArgs", SingleSensorAgent.__init__)
# Is equivalent to
@dataclass(frozen=True)
class SingleSensorAgentArgs:
agent_id: str
sensor_id: str
position: Tuple[float. float, float] = (0.0, 1.5, 0.0)
rotation: Tuple[float, float, float, float] = (1.0, 0.0, 0.0, 0.0)
height: float = 0.0
:
Args:
dataclass_name: The name of the new dataclass
function: The function used to extract the parameters for the dataclass
base: Optional base class for newly created dataclass
Returns:
New dataclass with fields defined by the function arguments.
"""
_fields = extract_fields(function)
# Add base class to new dataclass if given. Limited to a single base class
bases = (base,) if base is not None else ()
return dataclasses.make_dataclass(dataclass_name, _fields, bases=bases, frozen=True)
def config_to_dict(config):
"""Convert config composed of mixed dataclass and dict elements to pure dict.
We want to convert configs composed of mixed dataclass and dict elements to
pure dicts without dataclasses for backward compatibility.
TODO: Remove once all other configs are converted to dict only
Returns:
Pure dict version of config.
"""
if isinstance(config, dict):
return {k: config_to_dict(v) for k, v in config.items()}
if dataclasses.is_dataclass(config):
return dataclasses.asdict(config)
return config
def get_subset_of_args(arguments, function):
dict_args = config_to_dict(arguments)
_fields = extract_fields(function)
common_fields = dict()
for field in _fields:
field_name = field[0]
if field_name in dict_args:
common_fields[field_name] = dict_args[field_name]
return common_fields