heracls 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heracls-0.1.0/LICENSE +21 -0
- heracls-0.1.0/PKG-INFO +108 -0
- heracls-0.1.0/README.md +83 -0
- heracls-0.1.0/heracls/__init__.py +14 -0
- heracls-0.1.0/heracls/patch.py +35 -0
- heracls-0.1.0/heracls/transforms.py +132 -0
- heracls-0.1.0/heracls.egg-info/PKG-INFO +108 -0
- heracls-0.1.0/heracls.egg-info/SOURCES.txt +12 -0
- heracls-0.1.0/heracls.egg-info/dependency_links.txt +1 -0
- heracls-0.1.0/heracls.egg-info/requires.txt +6 -0
- heracls-0.1.0/heracls.egg-info/top_level.txt +1 -0
- heracls-0.1.0/pyproject.toml +59 -0
- heracls-0.1.0/setup.cfg +4 -0
- heracls-0.1.0/tests/test_transforms.py +43 -0
heracls-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 François Rozet
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
heracls-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: heracls
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Slayer of Hydra
|
|
5
|
+
Author-email: François Rozet <francois.rozet@outlook.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: documentation, https://github.com/francois-rozet/heracls
|
|
8
|
+
Project-URL: source, https://github.com/francois-rozet/heracls
|
|
9
|
+
Project-URL: tracker, https://github.com/francois-rozet/heracls/issues
|
|
10
|
+
Keywords: config,dataclass
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Natural Language :: English
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Requires-Python: >=3.8
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE
|
|
19
|
+
Requires-Dist: dacite>=1.9.2
|
|
20
|
+
Requires-Dist: omegaconf>=2.3.0
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest>=8.3.5; extra == "dev"
|
|
23
|
+
Requires-Dist: ruff==0.15.2; extra == "dev"
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
|
|
26
|
+
# Heracls - Slayer of Hydra
|
|
27
|
+
|
|
28
|
+
`heracls` is a tiny utility package to instantiate typed dataclasses from flexible config sources (dictionary, OmegaConf, YAML, dotlist, ...). It is designed for projects that want strict, typed config objects while supporting the dynamic overrides commonly used in scripts and experiments.
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
|
|
32
|
+
The `heracls` package is available on [PyPI](https://pypi.org/project/heracls) and can be installed with `pip`.
|
|
33
|
+
|
|
34
|
+
```
|
|
35
|
+
pip install heracls
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
Alternatively, if you need the latest features, you can install it from source.
|
|
39
|
+
|
|
40
|
+
```
|
|
41
|
+
pip install git+https://github.com/francois-rozet/heracls
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Getting started
|
|
45
|
+
|
|
46
|
+
The following example demonstrates how to declare a nested dataclass config, instantiate it while overriding default fields with a dotlist, serialize it to YAML, and use its fields in a script.
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
import heracls
|
|
50
|
+
|
|
51
|
+
from dataclasses import dataclass, field
|
|
52
|
+
from typing import Any, Dict, Literal, Tuple, Union
|
|
53
|
+
|
|
54
|
+
@dataclass
|
|
55
|
+
class AdamConfig:
|
|
56
|
+
optimizer: Literal["adam", "adamw"] = "adam"
|
|
57
|
+
betas: Tuple[float, float] = (0.95, 0.95)
|
|
58
|
+
learning_rate: float = 1e-3
|
|
59
|
+
weight_decay: float = 0.0
|
|
60
|
+
|
|
61
|
+
@dataclass
|
|
62
|
+
class SGDConfig:
|
|
63
|
+
optimizer: Literal["sgd"] = "sgd"
|
|
64
|
+
momentum: float = 0.0
|
|
65
|
+
learning_rate: float = 1e-3
|
|
66
|
+
weight_decay: float = 0.0
|
|
67
|
+
nesterov: bool = False
|
|
68
|
+
|
|
69
|
+
@dataclass
|
|
70
|
+
class TrainConfig:
|
|
71
|
+
optim: Union[AdamConfig, SGDConfig] = field(default_factory=AdamConfig)
|
|
72
|
+
n_epochs: int = 1024
|
|
73
|
+
data_splits: Tuple[float, float] = (0.8, 0.1)
|
|
74
|
+
slurm: Dict[str, Any] = field(default_factory=dict)
|
|
75
|
+
|
|
76
|
+
dotlist = [ # usually retrieved from the command line
|
|
77
|
+
"optim.optimizer=sgd",
|
|
78
|
+
"data_splits=[0.7,0.2]",
|
|
79
|
+
"slurm.account=frozet",
|
|
80
|
+
]
|
|
81
|
+
|
|
82
|
+
cfg = heracls.from_dotlist(TrainConfig, dotlist)
|
|
83
|
+
|
|
84
|
+
with open("config.yaml", "w") as f:
|
|
85
|
+
f.write(heracls.to_yaml(cfg))
|
|
86
|
+
|
|
87
|
+
trainset, validset, testset = load_dataset(cfg.data_splits)
|
|
88
|
+
|
|
89
|
+
for epoch in range(cfg.n_epochs):
|
|
90
|
+
...
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
The fields that are not specified in the dotlist are instantiated with their default value, as can be seen in the dumped `config.yaml` file.
|
|
94
|
+
|
|
95
|
+
```yaml
|
|
96
|
+
optim:
|
|
97
|
+
optimizer: sgd
|
|
98
|
+
momentum: 0.0
|
|
99
|
+
learning_rate: 0.001
|
|
100
|
+
weight_decay: 0.0
|
|
101
|
+
nesterov: false
|
|
102
|
+
n_epochs: 1024
|
|
103
|
+
data_splits:
|
|
104
|
+
- 0.7
|
|
105
|
+
- 0.2
|
|
106
|
+
slurm:
|
|
107
|
+
account: frozet
|
|
108
|
+
```
|
heracls-0.1.0/README.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# Heracls - Slayer of Hydra
|
|
2
|
+
|
|
3
|
+
`heracls` is a tiny utility package to instantiate typed dataclasses from flexible config sources (dictionary, OmegaConf, YAML, dotlist, ...). It is designed for projects that want strict, typed config objects while supporting the dynamic overrides commonly used in scripts and experiments.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
The `heracls` package is available on [PyPI](https://pypi.org/project/heracls) and can be installed with `pip`.
|
|
8
|
+
|
|
9
|
+
```
|
|
10
|
+
pip install heracls
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
Alternatively, if you need the latest features, you can install it from source.
|
|
14
|
+
|
|
15
|
+
```
|
|
16
|
+
pip install git+https://github.com/francois-rozet/heracls
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Getting started
|
|
20
|
+
|
|
21
|
+
The following example demonstrates how to declare a nested dataclass config, instantiate it while overriding default fields with a dotlist, serialize it to YAML, and use its fields in a script.
|
|
22
|
+
|
|
23
|
+
```python
|
|
24
|
+
import heracls
|
|
25
|
+
|
|
26
|
+
from dataclasses import dataclass, field
|
|
27
|
+
from typing import Any, Dict, Literal, Tuple, Union
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class AdamConfig:
|
|
31
|
+
optimizer: Literal["adam", "adamw"] = "adam"
|
|
32
|
+
betas: Tuple[float, float] = (0.95, 0.95)
|
|
33
|
+
learning_rate: float = 1e-3
|
|
34
|
+
weight_decay: float = 0.0
|
|
35
|
+
|
|
36
|
+
@dataclass
|
|
37
|
+
class SGDConfig:
|
|
38
|
+
optimizer: Literal["sgd"] = "sgd"
|
|
39
|
+
momentum: float = 0.0
|
|
40
|
+
learning_rate: float = 1e-3
|
|
41
|
+
weight_decay: float = 0.0
|
|
42
|
+
nesterov: bool = False
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class TrainConfig:
|
|
46
|
+
optim: Union[AdamConfig, SGDConfig] = field(default_factory=AdamConfig)
|
|
47
|
+
n_epochs: int = 1024
|
|
48
|
+
data_splits: Tuple[float, float] = (0.8, 0.1)
|
|
49
|
+
slurm: Dict[str, Any] = field(default_factory=dict)
|
|
50
|
+
|
|
51
|
+
dotlist = [ # usually retrieved from the command line
|
|
52
|
+
"optim.optimizer=sgd",
|
|
53
|
+
"data_splits=[0.7,0.2]",
|
|
54
|
+
"slurm.account=frozet",
|
|
55
|
+
]
|
|
56
|
+
|
|
57
|
+
cfg = heracls.from_dotlist(TrainConfig, dotlist)
|
|
58
|
+
|
|
59
|
+
with open("config.yaml", "w") as f:
|
|
60
|
+
f.write(heracls.to_yaml(cfg))
|
|
61
|
+
|
|
62
|
+
trainset, validset, testset = load_dataset(cfg.data_splits)
|
|
63
|
+
|
|
64
|
+
for epoch in range(cfg.n_epochs):
|
|
65
|
+
...
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
The fields that are not specified in the dotlist are instantiated with their default value, as can be seen in the dumped `config.yaml` file.
|
|
69
|
+
|
|
70
|
+
```yaml
|
|
71
|
+
optim:
|
|
72
|
+
optimizer: sgd
|
|
73
|
+
momentum: 0.0
|
|
74
|
+
learning_rate: 0.001
|
|
75
|
+
weight_decay: 0.0
|
|
76
|
+
nesterov: false
|
|
77
|
+
n_epochs: 1024
|
|
78
|
+
data_splits:
|
|
79
|
+
- 0.7
|
|
80
|
+
- 0.2
|
|
81
|
+
slurm:
|
|
82
|
+
account: frozet
|
|
83
|
+
```
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Monkey patching."""
|
|
2
|
+
|
|
3
|
+
import dacite
|
|
4
|
+
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
from dacite.core import Config, _build_value, extract_generic, is_subclass
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _build_value_for_collection(collection: type, data: Any, config: Config) -> Any:
    """Build the items of a collection value according to its annotated type.

    The container is always rebuilt with the runtime type of `data`, while each
    item is delegated to dacite's `_build_value`.

    Arguments:
        collection: The annotated collection type of the field.
        data: The raw value to convert.
        config: The dacite configuration forwarded to `_build_value`.

    Returns:
        A container of the same type as `data` with converted items, or `data`
        itself when none of the supported cases applies.
    """
    rebuild = type(data)

    def convert(item: Any, item_type: Any) -> Any:
        return _build_value(type_=item_type, data=item, config=config)

    # Dictionary data for a mapping annotation: only the values are converted.
    if isinstance(data, dict) and is_subclass(collection, Mapping):
        kv_types = extract_generic(collection, defaults=(Any, Any))
        return rebuild((key, convert(value, kv_types[1])) for key, value in data.items())

    # Anything that is neither a list nor a tuple is returned untouched.
    if not isinstance(data, (list, tuple)):
        return data

    if is_subclass(collection, tuple):
        item_types = extract_generic(collection)
        # Variadic form `Tuple[X, ...]`: every item shares the first type.
        if len(item_types) == 2 and item_types[1] is Ellipsis:
            return rebuild(convert(item, item_types[0]) for item in data)
        # Fixed-size form: items and types are paired positionally.
        if len(data) == len(item_types):
            return rebuild(
                convert(item, item_type) for item, item_type in zip(data, item_types)
            )
        # Length mismatch: the data is handed back as is.
        return data

    # List or tuple data for any other annotation: homogeneous item type.
    item_types = extract_generic(collection, defaults=(Any,))
    return rebuild(convert(item, item_types[0]) for item in data)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Monkey patch: make dacite use the collection builder defined in this module.
dacite.core._build_value_for_collection = _build_value_for_collection
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Data transformations."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"from_dict",
|
|
5
|
+
"from_dotlist",
|
|
6
|
+
"from_omega",
|
|
7
|
+
"from_yaml",
|
|
8
|
+
"to_dict",
|
|
9
|
+
"to_omega",
|
|
10
|
+
"to_yaml",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
import dacite
|
|
14
|
+
|
|
15
|
+
from dataclasses import fields as iter_fields
|
|
16
|
+
from dataclasses import is_dataclass
|
|
17
|
+
from omegaconf import DictConfig, OmegaConf
|
|
18
|
+
from typing import Any, ClassVar, Dict, List, Protocol, Type, TypeVar
|
|
19
|
+
|
|
20
|
+
T = TypeVar("T")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DataClass(Protocol):
    """Structural type for objects that expose `__dataclass_fields__`."""

    # Class-level mapping from field names to arbitrary values.
    __dataclass_fields__: ClassVar[Dict[str, Any]]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def from_dict(data_cls: Type[T], data: Dict[str, Any]) -> T:
    """Build a dataclass instance out of a plain dictionary.

    The conversion is delegated to :func:`dacite.from_dict`, configured with
    `cast=[tuple]`, `strict=True` and `strict_unions_match=True`.

    Arguments:
        data_cls: A dataclass type.
        data: A dictionary.

    Returns:
        A `data_cls` instance.
    """
    options = dacite.Config(
        cast=[tuple],
        strict=True,
        strict_unions_match=True,
    )
    return dacite.from_dict(data_class=data_cls, data=data, config=options)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def to_dict(data: DataClass, /, recursive: bool = False) -> Dict[str, Any]:
    """Convert a dataclass instance to a dictionary.

    Arguments:
        data: A dataclass instance. Any other object is returned unchanged.
        recursive: Recursively convert nested dataclasses when `True`, including
            those stored inside plain `list`, `tuple` and `dict` field values.

    Returns:
        A dictionary representation of `data`.
    """
    # `type(data)` makes sure dataclass *types* are not mistaken for instances.
    if not is_dataclass(type(data)):
        return data
    result = {f.name: getattr(data, f.name) for f in iter_fields(data)}
    if recursive:
        result = {name: _to_plain(value) for name, value in result.items()}
    return result


def _to_plain(value: Any) -> Any:
    """Replace every dataclass instance reachable from `value` by a dictionary."""
    if is_dataclass(type(value)):
        return to_dict(value, recursive=True)
    # Only the exact built-in containers are rebuilt; subclasses (named tuples,
    # ordered dictionaries, ...) are left untouched as their constructors may differ.
    if type(value) is dict:
        return {key: _to_plain(item) for key, item in value.items()}
    if type(value) in (list, tuple):
        return type(value)(_to_plain(item) for item in value)
    return value
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def from_omega(data_cls: Type[T], data: DictConfig) -> T:
    """Build a dataclass instance out of an :mod:`omegaconf` config.

    The config is first turned into a plain container with
    `resolve=True` and `throw_on_missing=True`, then passed to `from_dict`.

    Arguments:
        data_cls: A dataclass type.
        data: A config object.

    Returns:
        A `data_cls` instance.
    """
    container = OmegaConf.to_container(data, resolve=True, throw_on_missing=True)
    return from_dict(data_cls, container)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def to_omega(data: DataClass) -> DictConfig:
    """Turn a dataclass instance into an :mod:`omegaconf` config.

    Arguments:
        data: A dataclass instance.

    Returns:
        A config representation of `data`.
    """
    # Flatten nested dataclasses first, then let OmegaConf wrap the dictionary.
    as_dict = to_dict(data, recursive=True)
    return OmegaConf.create(as_dict)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def from_yaml(data_cls: Type[T], data: str) -> T:
    """Build a dataclass instance out of a YAML string.

    Arguments:
        data_cls: A dataclass type.
        data: A YAML string.

    Returns:
        A `data_cls` instance.
    """
    parsed = OmegaConf.create(data)
    return from_omega(data_cls, parsed)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def to_yaml(data: DataClass, sort_keys: bool = False) -> str:
    """Dump a dataclass instance as a YAML string.

    Arguments:
        data: A dataclass instance.
        sort_keys: Sort mapping keys in YAML output when `True`.

    Returns:
        A YAML string representation of `data`.
    """
    config = to_omega(data)
    return OmegaConf.to_yaml(config, sort_keys=sort_keys)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def from_dotlist(data_cls: Type[T], data: List[str]) -> T:
    """Build a dataclass instance out of a list of dot-style strings.

    Arguments:
        data_cls: A dataclass type.
        data: A list of dot-style strings.
            For example, `["foo.bar=1", "foo.bis=[baz,qux]"]`.

    Returns:
        A `data_cls` instance.
    """
    overrides = OmegaConf.from_dotlist(data)
    return from_omega(data_cls, overrides)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: heracls
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Slayer of Hydra
|
|
5
|
+
Author-email: François Rozet <francois.rozet@outlook.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: documentation, https://github.com/francois-rozet/heracls
|
|
8
|
+
Project-URL: source, https://github.com/francois-rozet/heracls
|
|
9
|
+
Project-URL: tracker, https://github.com/francois-rozet/heracls/issues
|
|
10
|
+
Keywords: config,dataclass
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Natural Language :: English
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Requires-Python: >=3.8
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE
|
|
19
|
+
Requires-Dist: dacite>=1.9.2
|
|
20
|
+
Requires-Dist: omegaconf>=2.3.0
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest>=8.3.5; extra == "dev"
|
|
23
|
+
Requires-Dist: ruff==0.15.2; extra == "dev"
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
|
|
26
|
+
# Heracls - Slayer of Hydra
|
|
27
|
+
|
|
28
|
+
`heracls` is a tiny utility package to instantiate typed dataclasses from flexible config sources (dictionary, OmegaConf, YAML, dotlist, ...). It is designed for projects that want strict, typed config objects while supporting the dynamic overrides commonly used in scripts and experiments.
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
|
|
32
|
+
The `heracls` package is available on [PyPI](https://pypi.org/project/heracls) and can be installed with `pip`.
|
|
33
|
+
|
|
34
|
+
```
|
|
35
|
+
pip install heracls
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
Alternatively, if you need the latest features, you can install it from source.
|
|
39
|
+
|
|
40
|
+
```
|
|
41
|
+
pip install git+https://github.com/francois-rozet/heracls
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Getting started
|
|
45
|
+
|
|
46
|
+
The following example demonstrates how to declare a nested dataclass config, instantiate it while overriding default fields with a dotlist, serialize it to YAML, and use its fields in a script.
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
import heracls
|
|
50
|
+
|
|
51
|
+
from dataclasses import dataclass, field
|
|
52
|
+
from typing import Any, Dict, Literal, Tuple, Union
|
|
53
|
+
|
|
54
|
+
@dataclass
|
|
55
|
+
class AdamConfig:
|
|
56
|
+
optimizer: Literal["adam", "adamw"] = "adam"
|
|
57
|
+
betas: Tuple[float, float] = (0.95, 0.95)
|
|
58
|
+
learning_rate: float = 1e-3
|
|
59
|
+
weight_decay: float = 0.0
|
|
60
|
+
|
|
61
|
+
@dataclass
|
|
62
|
+
class SGDConfig:
|
|
63
|
+
optimizer: Literal["sgd"] = "sgd"
|
|
64
|
+
momentum: float = 0.0
|
|
65
|
+
learning_rate: float = 1e-3
|
|
66
|
+
weight_decay: float = 0.0
|
|
67
|
+
nesterov: bool = False
|
|
68
|
+
|
|
69
|
+
@dataclass
|
|
70
|
+
class TrainConfig:
|
|
71
|
+
optim: Union[AdamConfig, SGDConfig] = field(default_factory=AdamConfig)
|
|
72
|
+
n_epochs: int = 1024
|
|
73
|
+
data_splits: Tuple[float, float] = (0.8, 0.1)
|
|
74
|
+
slurm: Dict[str, Any] = field(default_factory=dict)
|
|
75
|
+
|
|
76
|
+
dotlist = [ # usually retrieved from the command line
|
|
77
|
+
"optim.optimizer=sgd",
|
|
78
|
+
"data_splits=[0.7,0.2]",
|
|
79
|
+
"slurm.account=frozet",
|
|
80
|
+
]
|
|
81
|
+
|
|
82
|
+
cfg = heracls.from_dotlist(TrainConfig, dotlist)
|
|
83
|
+
|
|
84
|
+
with open("config.yaml", "w") as f:
|
|
85
|
+
f.write(heracls.to_yaml(cfg))
|
|
86
|
+
|
|
87
|
+
trainset, validset, testset = load_dataset(cfg.data_splits)
|
|
88
|
+
|
|
89
|
+
for epoch in range(cfg.n_epochs):
|
|
90
|
+
...
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
The fields that are not specified in the dotlist are instantiated with their default value, as can be seen in the dumped `config.yaml` file.
|
|
94
|
+
|
|
95
|
+
```yaml
|
|
96
|
+
optim:
|
|
97
|
+
optimizer: sgd
|
|
98
|
+
momentum: 0.0
|
|
99
|
+
learning_rate: 0.001
|
|
100
|
+
weight_decay: 0.0
|
|
101
|
+
nesterov: false
|
|
102
|
+
n_epochs: 1024
|
|
103
|
+
data_splits:
|
|
104
|
+
- 0.7
|
|
105
|
+
- 0.2
|
|
106
|
+
slurm:
|
|
107
|
+
account: frozet
|
|
108
|
+
```
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
heracls/__init__.py
|
|
5
|
+
heracls/patch.py
|
|
6
|
+
heracls/transforms.py
|
|
7
|
+
heracls.egg-info/PKG-INFO
|
|
8
|
+
heracls.egg-info/SOURCES.txt
|
|
9
|
+
heracls.egg-info/dependency_links.txt
|
|
10
|
+
heracls.egg-info/requires.txt
|
|
11
|
+
heracls.egg-info/top_level.txt
|
|
12
|
+
tests/test_transforms.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
heracls
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
build-backend = "setuptools.build_meta"
|
|
3
|
+
requires = ["setuptools>=61.0"]
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "heracls"
|
|
7
|
+
description = "Slayer of Hydra"
|
|
8
|
+
authors = [
|
|
9
|
+
{name = "François Rozet", email = "francois.rozet@outlook.com"}
|
|
10
|
+
]
|
|
11
|
+
classifiers = [
|
|
12
|
+
"Intended Audience :: Developers",
|
|
13
|
+
"Intended Audience :: Science/Research",
|
|
14
|
+
"Natural Language :: English",
|
|
15
|
+
"Operating System :: OS Independent",
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
]
|
|
18
|
+
dependencies = [
|
|
19
|
+
"dacite>=1.9.2",
|
|
20
|
+
"omegaconf>=2.3.0",
|
|
21
|
+
]
|
|
22
|
+
dynamic = ["version"]
|
|
23
|
+
keywords = ["config", "dataclass"]
|
|
24
|
+
license = "MIT"
|
|
25
|
+
readme = "README.md"
|
|
26
|
+
requires-python = ">=3.8"
|
|
27
|
+
|
|
28
|
+
[project.optional-dependencies]
|
|
29
|
+
dev = [
|
|
30
|
+
"pytest>=8.3.5",
|
|
31
|
+
"ruff==0.15.2",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[project.urls]
|
|
35
|
+
documentation = "https://github.com/francois-rozet/heracls"
|
|
36
|
+
source = "https://github.com/francois-rozet/heracls"
|
|
37
|
+
tracker = "https://github.com/francois-rozet/heracls/issues"
|
|
38
|
+
|
|
39
|
+
[tool.ruff]
|
|
40
|
+
line-length = 99
|
|
41
|
+
|
|
42
|
+
[tool.ruff.lint]
|
|
43
|
+
extend-select = ["B", "I", "W"]
|
|
44
|
+
ignore = ["FA100"]
|
|
45
|
+
preview = true
|
|
46
|
+
|
|
47
|
+
[tool.ruff.lint.isort]
|
|
48
|
+
lines-between-types = 1
|
|
49
|
+
relative-imports-order = "closest-to-furthest"
|
|
50
|
+
section-order = ["future", "third-party", "first-party", "local-folder"]
|
|
51
|
+
|
|
52
|
+
[tool.ruff.format]
|
|
53
|
+
preview = true
|
|
54
|
+
|
|
55
|
+
[tool.setuptools.dynamic]
|
|
56
|
+
version = {attr = "heracls.__version__"}
|
|
57
|
+
|
|
58
|
+
[tool.setuptools.packages.find]
|
|
59
|
+
include = ["heracls*"]
|
heracls-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Tests for the heracls.transforms module."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, Literal, Tuple, Union
|
|
5
|
+
|
|
6
|
+
from heracls.transforms import from_dotlist, from_yaml, to_yaml
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def test_dl_config():
    """Check dotlist overrides on a nested config and the YAML round trip."""

    @dataclass
    class AdamConfig:
        optimizer: Literal["adam", "adamw"] = "adam"
        betas: Tuple[float, float] = (0.95, 0.95)
        learning_rate: float = 1e-3
        weight_decay: float = 0.0

    @dataclass
    class SGDConfig:
        optimizer: Literal["sgd"] = "sgd"
        momentum: float = 0.0
        learning_rate: float = 1e-3
        weight_decay: float = 0.0
        nesterov: bool = False

    @dataclass
    class TrainConfig:
        optim: Union[AdamConfig, SGDConfig] = field(default_factory=AdamConfig)
        n_epochs: int = 1024
        data_splits: Tuple[float, float] = (0.8, 0.1)
        slurm: Dict[str, Any] = field(default_factory=dict)

    # Overrides as they would typically come from the command line.
    overrides = [
        "optim.optimizer=sgd",
        "data_splits=[0.7,0.2]",
        "slurm.account=frozet",
    ]
    cfg = from_dotlist(TrainConfig, overrides)

    # Fields absent from the overrides must keep their default values.
    expected = TrainConfig(
        optim=SGDConfig(),
        data_splits=(0.7, 0.2),
        slurm={"account": "frozet"},
    )
    assert cfg == expected

    # Serializing to YAML and parsing it back must give an equal config.
    text = to_yaml(cfg)
    restored = from_yaml(TrainConfig, text)
    assert restored == cfg
|