Source code for tbp.monty.hydra
# Copyright 2025-2026 Thousand Brains Project
#
# Copyright may exist in Contributors' modifications
# and/or contributions to the work.
#
# Use of this source code is governed by the MIT
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from __future__ import annotations
import contextlib
import importlib
from pathlib import Path
from typing import Any, Callable
import numpy as np
from omegaconf import OmegaConf
[docs]def monty_class_resolver(class_name: str) -> type:
"""Returns a class object by fully qualified path.
TODO: This is an interim solution to retrieve my_class in
the my_class(**my_args) pattern.
"""
parts = class_name.split(".")
module = ".".join(parts[:-1])
klass = parts[-1]
module_obj = importlib.import_module(module)
return getattr(module_obj, klass)
[docs]def ndarray_resolver(list_or_tuple: list | tuple) -> np.ndarray:
"""Returns a numpy array from a list or tuple."""
return np.array(list_or_tuple)
[docs]def ones_resolver(n: int) -> np.ndarray:
"""Returns a numpy array of ones."""
return np.ones(n)
[docs]def numpy_list_eval_resolver(expr_list: list) -> list[float]:
# call str() on each item so we can use number literals
return [eval(str(item)) for item in expr_list] # noqa: S307
[docs]def path_expanduser_resolver(path: str) -> str:
"""Returns a path with ~ expanded to the user's home directory."""
return str(Path(path).expanduser())
[docs]def tests_dir_resolver(path: str) -> str:
return str(Path(__file__).parents[3] / "tests" / Path(path))
[docs]def register_resolvers() -> None:
"""Register custom OmegaConf resolvers for Monty configs.
Skips resolvers that are already registered rather than raising
a ValueError, since multiple entry points (e.g. tests/__init__.py
and update_snapshots.py) may call this function in the same process.
"""
resolvers: dict[str, Callable[..., Any]] = {
"monty.class": monty_class_resolver,
"np.array": ndarray_resolver,
"np.ones": ones_resolver,
"np.list_eval": numpy_list_eval_resolver,
"path.expanduser": path_expanduser_resolver,
"path.tests": tests_dir_resolver,
}
for name, resolver in resolvers.items():
with contextlib.suppress(ValueError):
OmegaConf.register_new_resolver(name, resolver)