laco 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- laco/__init__.py +33 -0
- laco/__init__.pyi +5 -0
- laco/_dump.py +104 -0
- laco/_lazy.py +139 -0
- laco/_loader.py +194 -0
- laco/_overrides.py +53 -0
- laco/_resolvers.py +17 -0
- laco/builtins.py +20 -0
- laco/cli.py +2 -0
- laco/env.py +111 -0
- laco/examples/mlp.py +28 -0
- laco/handler.py +100 -0
- laco/keys.py +19 -0
- laco/language.py +322 -0
- laco/py.typed +0 -0
- laco/readers/wandb.py +17 -0
- laco/utils.py +220 -0
- laco-0.0.0.dist-info/LICENSE +21 -0
- laco-0.0.0.dist-info/METADATA +28 -0
- laco-0.0.0.dist-info/RECORD +23 -0
- laco-0.0.0.dist-info/WHEEL +5 -0
- laco-0.0.0.dist-info/entry_points.txt +8 -0
- laco-0.0.0.dist-info/top_level.txt +1 -0
laco/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lazy configuration system, inspired by and based on Detectron2 and Hydra.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from ._lazy import *
|
|
6
|
+
from ._loader import *
|
|
7
|
+
from ._overrides import *
|
|
8
|
+
from ._resolvers import *
|
|
9
|
+
|
|
10
|
+
# Submodules imported on first attribute access (see ``__getattr__`` below);
# each name must appear only once.
__lazy__ = ("env", "language", "builtins", "cli", "handler", "keys", "utils")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def __getattr__(name: str):
    """
    Module-level attribute hook (PEP 562).

    Submodules listed in ``__lazy__`` are imported on first access, and
    ``__version__`` is read from the installed distribution metadata
    (``"unknown"`` when the distribution is not installed).

    Raises
    ------
    AttributeError
        For any other name.
    """
    from importlib import import_module
    from importlib.metadata import PackageNotFoundError, version

    if name in __lazy__:
        # The leading dot makes this a relative import of ``<package>.<name>``.
        # Without it, ``package=`` is ignored and a top-level module of the same
        # name (or nothing at all) would be imported instead.
        return import_module(f".{name}", package=__name__)
    if name == "__version__":
        try:
            return version(__name__)
        except PackageNotFoundError:
            return "unknown"
    msg = f"Module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)
|
|
26
|
+
|
|
27
|
+
def __dir__() -> list[str]:
    """
    List the package's public names (used by ``dir()`` and tab completion).

    The result combines the lazily imported submodules, ``__version__`` and
    the ``__all__`` of every star-imported private module. It is sorted and
    has no duplicates.
    """
    from ._lazy import __all__ as all_lazy
    from ._loader import __all__ as all_loader
    from ._overrides import __all__ as all_overrides
    from ._resolvers import __all__ as all_resolvers

    # ``__lazy__`` is a tuple while the ``__all__`` values are lists, so
    # concatenating them with ``+`` raises TypeError. Unpack everything into a
    # set instead, which also drops duplicates.
    names = {
        *__lazy__,
        "__version__",
        *all_lazy,
        *all_loader,
        *all_overrides,
        *all_resolvers,
    }
    return sorted(names)
|
laco/__init__.pyi
ADDED
laco/_dump.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import pprint
|
|
3
|
+
from contextlib import suppress
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
|
|
6
|
+
import iopathlib
|
|
7
|
+
import yaml
|
|
8
|
+
from omegaconf import DictConfig, OmegaConf, SCMode
|
|
9
|
+
|
|
10
|
+
from . import keys, utils
|
|
11
|
+
|
|
12
|
+
# Public API of this module.
__all__ = ["dump_config", "save_config"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def dump_config(cfg) -> str:  # noqa: C901
    """
    Serialize a config to a YAML string.

    Parameters
    ----------
    cfg
        One of:

        - a ``DictConfig``;
        - a dataclass instance, converted with :func:`dataclasses.asdict`;
        - any other object, converted with ``utils.as_omegadict``.

    Returns
    -------
    str
        YAML text. Interpolations such as ``${a}`` are kept unresolved. When
        the config can be deep-copied, callable lazy-call targets are replaced
        by ``utils.generate_path`` of the callable.

    Raises
    ------
    ValueError
        If the config cannot be converted to plain containers.
    SyntaxError
        If the converted config cannot be dumped as YAML. The message names the
        innermost offending key when one can be identified.
    """
    if not isinstance(cfg, DictConfig):
        cfg = utils.as_omegadict(
            dataclasses.asdict(cfg) if dataclasses.is_dataclass(cfg) else cfg
        )

    def _replace_type_by_name(node):
        # Swap a callable lazy-call target for its path so it can be written as text.
        # Use the configured key (not a hard-coded attribute) and only touch mappings.
        if not isinstance(node, dict | DictConfig) or keys.LAZY_CALL not in node:
            return
        target = node[keys.LAZY_CALL]
        if callable(target):
            with suppress(AttributeError):
                node[keys.LAZY_CALL] = utils.generate_path(target)

    try:
        # Rewriting targets mutates nodes in place, so only do it on a copy.
        cfg = deepcopy(cfg)
    except Exception:  # noqa: BLE001
        # Best effort: an uncopyable config is dumped as-is, without rewriting
        # targets, so that the caller's object is never modified.
        pass
    else:
        utils.apply_recursive(cfg, _replace_type_by_name)

    try:
        cfg_as_dict = OmegaConf.to_container(
            cfg,
            # Do not resolve interpolation when saving, i.e. do not turn ${a} into
            # actual values when saving.
            resolve=False,
            # Save structures (dataclasses) in a format that can be instantiated later.
            # Without this option, the type information of the dataclass will be erased.
            structured_config_mode=SCMode.INSTANTIATE,
        )
    except Exception as err:
        try:
            shown = OmegaConf.to_container(cfg)
        except Exception:  # noqa: BLE001
            # The plain conversion may fail for the same reason; show the node itself
            # rather than masking the original error.
            shown = cfg
        cfg_pretty = pprint.pformat(shown).replace("\n", "\n\t")
        msg = f"Config cannot be converted to a dict!\n\nConfig node:\n{cfg_pretty}"
        raise ValueError(msg) from err

    dump_kwargs = {"default_flow_style": None, "allow_unicode": True}

    def _find_undumpable(node, *, _key=()) -> tuple[str, ...] | None:
        # Depth-first search for the most specific key whose value YAML cannot dump.
        for key, value in node.items():
            try:
                yaml.dump(value, **dump_kwargs)
            except Exception:  # noqa: BLE001
                pass
            else:
                continue
            path = (*_key, str(key))
            if isinstance(value, dict):
                nested = _find_undumpable(value, _key=path)
                if nested:
                    return nested
            # Leaves and non-dict containers are reported at their own key.
            return path
        return None

    try:
        dumped = yaml.dump(cfg_as_dict, **dump_kwargs)
    except Exception as err:
        cfg_pretty = pprint.pformat(cfg_as_dict).replace("\n", "\n\t")
        problem_key = (
            _find_undumpable(cfg_as_dict) if isinstance(cfg_as_dict, dict) else None
        )
        if problem_key:
            problem_key = ".".join(problem_key)
            msg = f"Config cannot be saved due to key {problem_key!r}"
        else:
            msg = "Config cannot be saved due to an unknown entry"
        msg += f"\n\nConfig node:\n\t{cfg_pretty}"
        raise SyntaxError(msg) from err

    return dumped
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def save_config(cfg, path: str):
    """
    Save a config object to a yaml file.

    Parameters
    ----------
    cfg
        An omegaconf config object, or anything accepted by ``dump_config``.
    path
        The file name to save the config file. Must end in ``.yaml``.

    Raises
    ------
    ValueError
        If ``path`` does not end in ``.yaml``, or the config cannot be converted
        (from ``dump_config``).
    SyntaxError
        If the dumped YAML cannot be loaded back, or the file cannot be written.
    """
    local_path = iopathlib.get_local_path(path)  # type: ignore[arg-type]
    if not local_path.endswith(".yaml"):
        msg = f"Config file should be saved as a yaml file! Got: {path}"
        raise ValueError(msg)

    dumped = dump_config(cfg)
    try:
        # Check the round-trip BEFORE writing, so an unreadable dump never
        # overwrites an existing file. unsafe_load is used because the dump may
        # contain Python-specific tags; the input here is our own output.
        _ = yaml.unsafe_load(dumped)
        # The dump allows unicode, so write with an explicit UTF-8 encoding
        # instead of the platform default.
        with open(local_path, "w", encoding="utf-8") as fh:  # noqa: PTH123
            fh.write(dumped)
    except Exception as err:
        msg = f"Config file cannot be saved at {local_path!r}"
        raise SyntaxError(msg) from err
|
laco/_lazy.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Instantiation of configuration objects.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import dataclasses
|
|
6
|
+
import logging
|
|
7
|
+
import pprint
|
|
8
|
+
import types
|
|
9
|
+
import typing
|
|
10
|
+
|
|
11
|
+
import omegaconf
|
|
12
|
+
from omegaconf import DictConfig, ListConfig
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Public API of this module (star-imported by the package ``__init__``).
__all__ = [
    "instantiate",
]
|
|
18
|
+
|
|
19
|
+
if typing.TYPE_CHECKING:
|
|
20
|
+
|
|
21
|
+
class LazyObject[_L]:
|
|
22
|
+
def __getattr__(self, name: str, /) -> typing.Any: ...
|
|
23
|
+
|
|
24
|
+
@typing.override
|
|
25
|
+
def __setattr__(self, name: str, value: typing.Any, /) -> None: ... # noqa: PYI063
|
|
26
|
+
|
|
27
|
+
else:
|
|
28
|
+
|
|
29
|
+
class LazyObject(dict[str, typing.Any]):
|
|
30
|
+
def __class_getitem__(cls, item: typing.Any) -> dict[str, typing.Any]:
|
|
31
|
+
return types.GenericAlias(dict, (str, typing.Any))
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Either kind of OmegaConf container node: a mapping (DictConfig) or a sequence (ListConfig).
type AnyConfig = DictConfig | ListConfig
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def migrate_target(target: typing.Any) -> typing.Any:
    """
    Hook for rewriting a lazy-call target before it is resolved.

    Currently a pass-through: the value is handed back unchanged (same object).
    """
    migrated = target
    return migrated
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Container type used to rebuild a sequence after its items are instantiated.
# OmegaConf's ListConfig becomes a plain list; the builtin container types map
# to themselves.
_INST_SEQ_TYPEMAP: dict[type, type] = {ListConfig: list} | {
    seq_type: seq_type for seq_type in (list, tuple, set, frozenset)
}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def instantiate(cfg: typing.Any, /) -> object:  # noqa: C901, PLR0912
    """
    Recursively instantiate objects defined in dictionaries with keys:

    - Special key ``laco.keys.LAZY_CALL``: the callable/object to be instantiated.
      This is either the callable itself or its import path as a string.
    - Special key ``laco.keys.LAZY_ARGS``: the positional arguments passed to the
      callable. Any sequence except ``str``/``bytes`` is accepted.
    - Other keys define the keyword arguments to be passed to the callable.

    Other inputs are handled as follows:

    - Sequences are instantiated element-wise; a ``ListConfig`` becomes a ``list``.
    - Dataclass-backed ``DictConfig`` nodes are converted to their dataclass.
    - Other mappings become plain ``dict`` objects.
    - Scalars, ``set``/``frozenset`` (elements untouched), types, functions and
      other callables are returned unchanged.

    Raises
    ------
    ImportError
        If a string target resolves to ``None``.
    TypeError
        If the target is not callable, or the positional arguments are not a
        (non-string) sequence.
    RuntimeError
        If calling the target fails; the original error is chained.
    ValueError
        If ``cfg`` is of an unsupported type.
    """
    import laco.env
    import laco.keys
    import laco.utils

    # Leaf values, types and plain functions need no instantiation.
    if cfg is None or isinstance(
        cfg,
        int
        | float
        | bool
        | str
        | set
        | frozenset
        | bytes
        | type
        | types.NoneType
        | types.FunctionType,
    ):
        return cfg  # type: ignore[return-value]

    if laco.env.fetch(bool, "LACO_TRACE", default=False):
        # Use pformat: pprint() writes to stdout and returns None, which is what
        # would otherwise end up in the log record. to_container expects an
        # OmegaConf node, so plain dicts/lists are logged as they are.
        shown = (
            omegaconf.OmegaConf.to_container(cfg)
            if omegaconf.OmegaConf.is_config(cfg)
            else cfg
        )
        logging.getLogger(__name__).info("Instantiating %s", pprint.pformat(shown))

    if isinstance(cfg, typing.Sequence) and not isinstance(
        cfg, typing.Mapping | str | bytes
    ):
        seq_type = _INST_SEQ_TYPEMAP.get(type(cfg), type(cfg))
        items = (instantiate(x) for x in cfg)
        if issubclass(seq_type, tuple) and hasattr(seq_type, "_fields"):
            # Namedtuples take their fields positionally, not as one iterable.
            return seq_type(*items)
        return seq_type(items)

    # If input is a DictConfig backed by dataclasses (structured config)
    # instantiate it to the actual dataclass.
    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(
        cfg._metadata.object_type
    ):
        return omegaconf.OmegaConf.to_object(cfg)

    if isinstance(cfg, typing.Mapping) and laco.keys.LAZY_CALL in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        kwargs = {k: instantiate(v) for k, v in cfg.items()}
        target = instantiate(migrate_target(kwargs.pop(laco.keys.LAZY_CALL)))

        if isinstance(target, str):
            target_name = target
            target = laco.utils.locate_object(target_name)
            # Explicit check instead of `assert`, which is stripped under -O.
            if target is None:
                msg = f"Cannot locate lazy call target {target_name!r}!"
                raise ImportError(msg)
        else:
            try:
                target_name = target.__module__ + "." + target.__qualname__
            except Exception:  # noqa: B902, PIE786
                # target could be anything, so the above could fail
                target_name = str(target)
        if not callable(target):
            msg = f"Non-callable object found: {laco.keys.LAZY_CALL}={target!r}!"
            raise TypeError(msg)

        args = kwargs.pop(laco.keys.LAZY_ARGS, ())
        # A str/bytes is a Sequence too, but splatting it would pass one
        # positional argument per character.
        if not isinstance(args, typing.Sequence) or isinstance(args, str | bytes):
            msg = f"Expected sequence for {laco.keys.LAZY_ARGS}, but got {type(args)}!"
            raise TypeError(msg)

        try:
            return target(*args, **kwargs)
        except Exception as err:
            msg = (
                f"Error instantiating lazy object {target_name}.\n\nConfig node:\n\t{kwargs}!"
            )
            raise RuntimeError(msg) from err

    if isinstance(cfg, typing.Mapping):
        return {k: instantiate(v) for k, v in cfg.items()}  # type: ignore[return-value]

    if callable(cfg):
        return cfg

    msg = f"Cannot instantiate {cfg}, type {type(cfg)}!"
    raise ValueError(msg)
|
laco/_loader.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import builtins
|
|
2
|
+
import os
|
|
3
|
+
import typing
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from uuid import uuid4
|
|
6
|
+
|
|
7
|
+
import iopathlib
|
|
8
|
+
import yaml
|
|
9
|
+
from omegaconf import DictConfig, ListConfig
|
|
10
|
+
|
|
11
|
+
# Public API of this module (star-imported by the package ``__init__``).
__all__ = ["load_config_remote", "load_config_local"]


# Prefix of the synthetic ``__package__`` given to executed config files. The
# patched import hook checks for it to recognize relative imports made from
# within configs.
PATCH_PREFIX: typing.Final = "_laco_"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def load_config_remote(path: str):
    """
    Load a configuration from a remote source.

    Supported sources:

    - `Weights & Biases <https://wandb.ai/>`_ runs: ``wandb-run://<run_id>``

    The W&B integration is imported from the ``unipercept`` package at call time.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not match any supported remote source.
    """
    from unipercept.engine.integrations.wandb_integration import (
        WANDB_RUN_PREFIX,
    )
    from unipercept.engine.integrations.wandb_integration import (
        read_run as wandb_read_run,
    )

    if not path.startswith(WANDB_RUN_PREFIX):
        raise FileNotFoundError(path)

    return DictConfig(wandb_read_run(path).config)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@contextmanager
def _patch_import():  # noqa: C901
    """
    Temporarily replace :func:`builtins.__import__` so that config files can
    import sibling config files with relative imports.

    A relative import is intercepted when it comes from a module whose
    ``__package__`` starts with ``PATCH_PREFIX``. In that case
    ``from .other import name``:

    - resolves ``other.py`` relative to the importing file;
    - executes it in a fresh module;
    - converts each imported ``name`` with ``laco.utils.as_omegadict``.

    Every other import is delegated to the original ``__import__``. The
    original is restored on exit, even when the body raises.
    """
    import importlib.machinery
    import importlib.util

    # `laco` is not a module-level name in this file; bind it here, before
    # patching, so the import hook below can use it.
    import laco.utils

    import_default = builtins.__import__

    def find_relative(original_file, relative_import_path, level):
        # NOTE: "from . import x" is not handled. Because then it's unclear
        # if such import should produce `x` as a python module or DictConfig.
        # This can be discussed further if needed.
        relative_import_err = (
            "Relative import of directories is not allowed within config files. "
            "Within a config file, relative import can only import other config files."
        )
        if not relative_import_path:
            raise ImportError(relative_import_err)

        # Level 1 is the importing file's directory; each extra level goes up one.
        cur_file = os.path.dirname(original_file)  # noqa: PTH120
        for _ in range(level - 1):
            cur_file = os.path.dirname(cur_file)  # noqa: PTH120
        for part in relative_import_path.lstrip(".").split("."):
            cur_file = os.path.join(cur_file, part)  # noqa: PTH118
        if not cur_file.endswith(".py"):
            cur_file += ".py"
        if not iopathlib.isfile(cur_file):
            cur_file_no_suffix = cur_file[: -len(".py")]
            if iopathlib.isdir(cur_file_no_suffix):
                msg = f"Cannot import from {cur_file_no_suffix}. " + relative_import_err
                raise ImportError(msg)
            msg = (
                f"Cannot import name {relative_import_path} from "
                f"{original_file}: {cur_file} does not exist."
            )
            raise ImportError(msg)
        return cur_file

    def import_patched(name, globals=None, locals=None, fromlist=(), level=0):
        # Only deal with relative imports inside config files
        is_config_relative = (
            level != 0
            and globals is not None
            and (globals.get("__package__", "") or "").startswith(PATCH_PREFIX)
        )
        if not is_config_relative:
            return import_default(name, globals, locals, fromlist=fromlist, level=level)

        cur_file = find_relative(globals["__file__"], name, level)
        laco.utils.check_syntax(cur_file)
        spec = importlib.machinery.ModuleSpec(
            _generate_packagename(cur_file), None, origin=cur_file
        )
        module = importlib.util.module_from_spec(spec)
        module.__file__ = cur_file
        with iopathlib.open(cur_file) as f:
            content = f.read()
        exec(compile(content, cur_file, "exec"), module.__dict__)
        for attr in fromlist or ():
            try:
                value = module.__dict__[attr]
            except KeyError:
                # Match Python's own behaviour for `from x import missing`.
                msg = f"Cannot import name {attr!r} from config file {cur_file}"
                raise ImportError(msg) from None
            module.__dict__[attr] = laco.utils.as_omegadict(value)
        return module

    builtins.__import__ = import_patched
    try:
        yield import_patched
    finally:
        # Always restore; otherwise an error while executing a config would leave
        # the patched import installed for the rest of the process.
        builtins.__import__ = import_default
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def load_config_local(path: str):
|
|
106
|
+
"""
|
|
107
|
+
Loads a configuration from a local source.
|
|
108
|
+
|
|
109
|
+
Users should prefer to load configurations via the unified API with
|
|
110
|
+
:func:`unipercept.read_config` instead of calling this method directly.
|
|
111
|
+
"""
|
|
112
|
+
import laco
|
|
113
|
+
import laco.keys
|
|
114
|
+
import laco.utils
|
|
115
|
+
|
|
116
|
+
ext = os.path.splitext(path)[1] # noqa: PTH122
|
|
117
|
+
match ext.lower():
|
|
118
|
+
case ".py":
|
|
119
|
+
laco.utils.check_syntax(path)
|
|
120
|
+
|
|
121
|
+
with _patch_import():
|
|
122
|
+
# Record the filename
|
|
123
|
+
nsp = {
|
|
124
|
+
"__file__": path,
|
|
125
|
+
"__package__": _generate_packagename(path),
|
|
126
|
+
}
|
|
127
|
+
with iopathlib.open(path) as f:
|
|
128
|
+
content = f.read()
|
|
129
|
+
# Compile first with filename to:
|
|
130
|
+
# 1. make filename appears in stacktrace
|
|
131
|
+
# 2. make load_rel able to find its parent's (possibly remote) location
|
|
132
|
+
exec(compile(content, iopathlib.get_local_path(path), "exec"), nsp)
|
|
133
|
+
|
|
134
|
+
export = nsp.get(
|
|
135
|
+
"__all__",
|
|
136
|
+
(
|
|
137
|
+
k
|
|
138
|
+
for k, v in nsp.items()
|
|
139
|
+
if not k.startswith("_")
|
|
140
|
+
and (
|
|
141
|
+
isinstance(
|
|
142
|
+
v,
|
|
143
|
+
dict
|
|
144
|
+
| list
|
|
145
|
+
| DictConfig
|
|
146
|
+
| ListConfig
|
|
147
|
+
| int
|
|
148
|
+
| float
|
|
149
|
+
| str
|
|
150
|
+
| bool,
|
|
151
|
+
)
|
|
152
|
+
or v is None
|
|
153
|
+
)
|
|
154
|
+
),
|
|
155
|
+
)
|
|
156
|
+
obj: dict[str, typing.Any] = {k: v for k, v in nsp.items() if k in export}
|
|
157
|
+
obj.setdefault(laco.keys.CONFIG_NAME, _filepath_to_name(path))
|
|
158
|
+
obj.setdefault(laco.keys.CONFIG_VERSION, laco.__version__)
|
|
159
|
+
|
|
160
|
+
case ".yaml":
|
|
161
|
+
with iopathlib.open(path) as f:
|
|
162
|
+
obj = yaml.unsafe_load(f)
|
|
163
|
+
obj.setdefault(laco.keys.CONFIG_NAME, "unknown")
|
|
164
|
+
obj.setdefault(laco.keys.CONFIG_VERSION, "unknown")
|
|
165
|
+
case _:
|
|
166
|
+
msg = "Unsupported file extension %s!"
|
|
167
|
+
raise ValueError(msg, ext)
|
|
168
|
+
|
|
169
|
+
return laco.utils.as_omegadict(obj)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _filepath_to_name(path: str | iopathlib.Path) -> str | None:
    """
    Derive a config name from a file path.

    Paths under ``./configs`` (relative to the working directory) are named by
    their location inside that directory, without extension; for example
    ``configs/models/mlp.py`` becomes ``models/mlp``. Any other path is named
    ``<parent dir>/<stem>``.

    Returns ``None`` when the derived name is exactly one of ``__init__``,
    ``defaults``, ``unknown``, ``config`` or ``configs``.
    """
    root = iopathlib.Path("./configs").resolve()
    resolved = iopathlib.Path(path).resolve()

    try:
        inner_parent = resolved.relative_to(root).parent.as_posix()
        candidate = f"{inner_parent}/{resolved.stem}"
    except Exception:
        candidate = f"{resolved.parent.stem}/{resolved.stem}"

    candidate = candidate.replace("./", "").replace("//", "/")

    reserved = {"__init__", "defaults", "unknown", "config", "configs"}
    return None if candidate in reserved else candidate.removesuffix(".py")
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _generate_packagename(path: str):
    """
    Build a fresh synthetic package name for executing a config file.

    The name is ``PATCH_PREFIX``, then 4 random hex characters, then ``"."``
    and the file's base name.
    """
    token = uuid4().hex[:4]
    return f"{PATCH_PREFIX}{token}.{iopathlib.Path(path).name}"
|
laco/_overrides.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
|
|
2
|
+
from omegaconf import OmegaConf
|
|
3
|
+
|
|
4
|
+
# Public API of this module (star-imported by the package ``__init__``).
__all__ = ["apply_overrides"]
|
|
5
|
+
|
|
6
|
+
def apply_overrides(cfg, overrides: list[str]):
    """
    Apply Hydra-style overrides to ``cfg`` in place.

    Parameters
    ----------
    cfg
        An omegaconf config object; it is modified in place.
    overrides
        Strings such as ``"a.b=1"``, parsed with Hydra's override grammar.
        See: https://hydra.cc/docs/next/advanced/override_grammar/basic/

    Returns
    -------
    DictConfig
        The same ``cfg`` object that was passed in.

    Raises
    ------
    ImportError
        If Hydra is not installed.
    KeyError
        If an override would descend into an existing value that is not a config.
    NotImplementedError
        For deletion overrides.
    """
    try:
        from hydra.core.override_parser.overrides_parser import OverridesParser
    except ImportError as err:
        msg = "Hydra is not installed. Please install Hydra to use this function."
        raise ImportError(msg) from err

    def safe_update(node, dotted_key, new_value):
        # Refuse to descend through an existing non-config value (e.g. setting
        # `a.b` while `a` is an int). Checking stops at the first prefix that is
        # missing or None.
        segments = dotted_key.split(".")
        for depth in range(1, len(segments)):
            prefix = ".".join(segments[:depth])
            existing = OmegaConf.select(node, prefix, default=None)
            if existing is None:
                break
            if OmegaConf.is_config(existing):
                continue
            msg = (
                f"Trying to update key {dotted_key}, but {prefix} "
                f"is not a config, but has type {type(existing)}."
            )
            raise KeyError(msg)
        OmegaConf.update(node, dotted_key, new_value, merge=True)

    parser = OverridesParser.create()
    for override in parser.parse_overrides(overrides):
        target_key = override.key_or_group
        target_value = override.value()
        if override.is_delete():
            msg = "deletion is not yet a supported override"
            raise NotImplementedError(msg)
        safe_update(cfg, target_key, target_value)

    return cfg
|
laco/_resolvers.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from omegaconf import OmegaConf
|
|
4
|
+
|
|
5
|
+
__all__ = []

# Arithmetic helpers for OmegaConf interpolations, e.g. ``${sum:1,2,3}`` -> 6.
OmegaConf.register_new_resolver("sum", lambda *numbers: sum(numbers))
OmegaConf.register_new_resolver("min", lambda *numbers: min(numbers))
OmegaConf.register_new_resolver("max", lambda *numbers: max(numbers))
OmegaConf.register_new_resolver("div", lambda a, b: a / b)
OmegaConf.register_new_resolver("pow", lambda a, b: a**b)
OmegaConf.register_new_resolver("mod", lambda a, b: a % b)
OmegaConf.register_new_resolver("neg", lambda a: -a)
OmegaConf.register_new_resolver("reciprocal", lambda a: 1 / a)
OmegaConf.register_new_resolver("abs", lambda a: abs(a))
# ``ndigits`` is optional, as with the builtin: ``${round:3.7}`` -> 4 and
# ``${round:3.14159,2}`` -> 3.14.
OmegaConf.register_new_resolver("round", lambda a, b=None: round(a, b))
# ``${math:name,arg1,...}`` calls ``math.<name>(arg1, ...)``. The arguments are
# unpacked; passing the tuple as one argument would break e.g. ``${math:sqrt,4}``.
OmegaConf.register_new_resolver(
    "math", lambda name, *args: getattr(math, name)(*args)
)
|
laco/builtins.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing
|
|
3
|
+
from dataclasses import is_dataclass
|
|
4
|
+
|
|
5
|
+
from omegaconf import DictConfig
|
|
6
|
+
|
|
7
|
+
import laco.utils
|
|
8
|
+
import laco.keys
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def partial(
    **kwargs: typing.Any,
) -> typing.Callable[..., typing.Any]:
    """
    Build a :func:`functools.partial` from keyword arguments.

    The callable is taken from the ``laco.keys.LAZY_PART`` entry. It is either
    the object itself or an import path string, resolved with
    ``laco.utils.locate_object``. All remaining keyword arguments are bound to
    the callable.

    Raises
    ------
    TypeError
        If no callable (or path to one) is given.
    """
    # Pop (not get) so the callable's own entry is not bound as a keyword
    # argument of the callable itself.
    cb = kwargs.pop(laco.keys.LAZY_PART, None)
    if isinstance(cb, str):
        cb = laco.utils.locate_object(cb)
    if not callable(cb):
        msg = f"Expected a callable object or location (str), got {cb} (type {type(cb)})"
        raise TypeError(msg)
    return functools.partial(cb, **kwargs)
|
laco/cli.py
ADDED
laco/env.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
Working with environment variables.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import enum
|
|
6
|
+
import functools
|
|
7
|
+
import os
|
|
8
|
+
import typing
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Public API of this module.
__all__ = ["fetch", "EnvFilter"]

# Target types that ``fetch`` accepts (its type-parameter bound) for converting
# environment variable strings.
type EnvVarCompatible = int | str | bool
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class EnvFilter(enum.StrEnum):
|
|
17
|
+
STRING = enum.auto()
|
|
18
|
+
TRUTHY = enum.auto()
|
|
19
|
+
FALSY = enum.auto()
|
|
20
|
+
POSITIVE = enum.auto()
|
|
21
|
+
NEGATIVE = enum.auto()
|
|
22
|
+
NONNEGATIVE = enum.auto()
|
|
23
|
+
NONPOSITIVE = enum.auto()
|
|
24
|
+
|
|
25
|
+
@staticmethod
|
|
26
|
+
def apply(f: "EnvFilter | str | None", v: typing.Any, /) -> bool: # noqa: PLR0911
|
|
27
|
+
if f is None:
|
|
28
|
+
return True
|
|
29
|
+
if v is None:
|
|
30
|
+
return False
|
|
31
|
+
match EnvFilter(f):
|
|
32
|
+
case EnvFilter.STRING:
|
|
33
|
+
assert isinstance(v, str)
|
|
34
|
+
v = v.lower()
|
|
35
|
+
return v != ""
|
|
36
|
+
case EnvFilter.TRUTHY:
|
|
37
|
+
return bool(v)
|
|
38
|
+
case EnvFilter.FALSY:
|
|
39
|
+
return not bool(v)
|
|
40
|
+
case EnvFilter.POSITIVE:
|
|
41
|
+
return v > 0
|
|
42
|
+
case EnvFilter.NEGATIVE:
|
|
43
|
+
return v < 0
|
|
44
|
+
case EnvFilter.NONNEGATIVE:
|
|
45
|
+
return v >= 0
|
|
46
|
+
case EnvFilter.NONPOSITIVE:
|
|
47
|
+
return v <= 0
|
|
48
|
+
case _:
|
|
49
|
+
msg = f"Invalid filter: {f!r}"
|
|
50
|
+
raise ValueError(msg)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@typing.overload
|
|
54
|
+
def fetch[_T: EnvVarCompatible](
|
|
55
|
+
__type: type[_T],
|
|
56
|
+
/,
|
|
57
|
+
*keys: str,
|
|
58
|
+
default: _T,
|
|
59
|
+
filter: EnvFilter | None = None,
|
|
60
|
+
) -> _T: ...
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@typing.overload
|
|
64
|
+
def fetch[_T: EnvVarCompatible](
|
|
65
|
+
__type: type[_T],
|
|
66
|
+
/,
|
|
67
|
+
*keys: str,
|
|
68
|
+
default: _T | None = None,
|
|
69
|
+
filter: EnvFilter | None = None,
|
|
70
|
+
) -> _T | None: ...
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@functools.cache
|
|
74
|
+
def fetch[_T: EnvVarCompatible](
|
|
75
|
+
__type: type[_T],
|
|
76
|
+
/,
|
|
77
|
+
*keys: str,
|
|
78
|
+
default: _T | None = None,
|
|
79
|
+
filter: EnvFilter | None = None,
|
|
80
|
+
) -> _T | None:
|
|
81
|
+
"""
|
|
82
|
+
Read an environment variable. If the variable is not set, return the default value.
|
|
83
|
+
|
|
84
|
+
If no default is given, an error is raised if the variable is not set.
|
|
85
|
+
"""
|
|
86
|
+
keys_read = []
|
|
87
|
+
for k in keys:
|
|
88
|
+
keys_read.append(k)
|
|
89
|
+
v = os.getenv(k)
|
|
90
|
+
if v is None:
|
|
91
|
+
continue
|
|
92
|
+
v = strtobool(v) if issubclass(__type, bool) else __type(v)
|
|
93
|
+
if not EnvFilter.apply(filter, v):
|
|
94
|
+
continue
|
|
95
|
+
break
|
|
96
|
+
else:
|
|
97
|
+
v = default
|
|
98
|
+
return typing.cast(_T, v)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def strtobool(v: str) -> bool:
    """
    Interpret a textual flag as a boolean (case-insensitive).

    ``true``/``yes``/``on``/``1`` give ``True``, and ``false``/``no``/``off``/``0``
    give ``False``.

    Raises
    ------
    ValueError
        For any other value.
    """
    lowered = v.lower()
    spellings = (
        (True, ("true", "yes", "on", "1")),
        (False, ("false", "no", "off", "0")),
    )
    for result, accepted in spellings:
        if lowered in accepted:
            return result
    msg = f"Invalid boolean value: {lowered!r}"
    raise ValueError(msg)
|