experimaestro 2.0.0a8__py3-none-any.whl → 2.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of experimaestro might be problematic. Click here for more details.
- experimaestro/__init__.py +10 -11
- experimaestro/annotations.py +167 -206
- experimaestro/cli/__init__.py +130 -5
- experimaestro/cli/filter.py +42 -74
- experimaestro/cli/jobs.py +157 -106
- experimaestro/cli/refactor.py +249 -0
- experimaestro/click.py +0 -1
- experimaestro/commandline.py +19 -3
- experimaestro/connectors/__init__.py +20 -1
- experimaestro/connectors/local.py +12 -0
- experimaestro/core/arguments.py +182 -46
- experimaestro/core/identifier.py +107 -6
- experimaestro/core/objects/__init__.py +6 -0
- experimaestro/core/objects/config.py +542 -25
- experimaestro/core/objects/config_walk.py +20 -0
- experimaestro/core/serialization.py +91 -34
- experimaestro/core/subparameters.py +164 -0
- experimaestro/core/types.py +175 -38
- experimaestro/exceptions.py +26 -0
- experimaestro/experiments/cli.py +107 -25
- experimaestro/generators.py +50 -9
- experimaestro/huggingface.py +3 -1
- experimaestro/launcherfinder/parser.py +29 -0
- experimaestro/launchers/__init__.py +26 -1
- experimaestro/launchers/direct.py +12 -0
- experimaestro/launchers/slurm/base.py +154 -2
- experimaestro/mkdocs/metaloader.py +0 -1
- experimaestro/mypy.py +452 -7
- experimaestro/notifications.py +63 -13
- experimaestro/progress.py +0 -2
- experimaestro/rpyc.py +0 -1
- experimaestro/run.py +19 -6
- experimaestro/scheduler/base.py +489 -125
- experimaestro/scheduler/dependencies.py +43 -28
- experimaestro/scheduler/dynamic_outputs.py +259 -130
- experimaestro/scheduler/experiment.py +225 -30
- experimaestro/scheduler/interfaces.py +474 -0
- experimaestro/scheduler/jobs.py +216 -206
- experimaestro/scheduler/services.py +186 -12
- experimaestro/scheduler/state_db.py +388 -0
- experimaestro/scheduler/state_provider.py +2345 -0
- experimaestro/scheduler/state_sync.py +834 -0
- experimaestro/scheduler/workspace.py +52 -10
- experimaestro/scriptbuilder.py +7 -0
- experimaestro/server/__init__.py +147 -57
- experimaestro/server/data/index.css +0 -125
- experimaestro/server/data/index.css.map +1 -1
- experimaestro/server/data/index.js +194 -58
- experimaestro/server/data/index.js.map +1 -1
- experimaestro/settings.py +44 -5
- experimaestro/sphinx/__init__.py +3 -3
- experimaestro/taskglobals.py +20 -0
- experimaestro/tests/conftest.py +80 -0
- experimaestro/tests/core/test_generics.py +2 -2
- experimaestro/tests/identifier_stability.json +45 -0
- experimaestro/tests/launchers/bin/sacct +6 -2
- experimaestro/tests/launchers/bin/sbatch +4 -2
- experimaestro/tests/launchers/test_slurm.py +80 -0
- experimaestro/tests/tasks/test_dynamic.py +231 -0
- experimaestro/tests/test_cli_jobs.py +615 -0
- experimaestro/tests/test_deprecated.py +630 -0
- experimaestro/tests/test_environment.py +200 -0
- experimaestro/tests/test_file_progress_integration.py +1 -1
- experimaestro/tests/test_forward.py +3 -3
- experimaestro/tests/test_identifier.py +372 -41
- experimaestro/tests/test_identifier_stability.py +458 -0
- experimaestro/tests/test_instance.py +3 -3
- experimaestro/tests/test_multitoken.py +442 -0
- experimaestro/tests/test_mypy.py +433 -0
- experimaestro/tests/test_objects.py +312 -5
- experimaestro/tests/test_outputs.py +2 -2
- experimaestro/tests/test_param.py +8 -12
- experimaestro/tests/test_partial_paths.py +231 -0
- experimaestro/tests/test_progress.py +0 -48
- experimaestro/tests/test_resumable_task.py +480 -0
- experimaestro/tests/test_serializers.py +141 -1
- experimaestro/tests/test_state_db.py +434 -0
- experimaestro/tests/test_subparameters.py +160 -0
- experimaestro/tests/test_tags.py +136 -0
- experimaestro/tests/test_tasks.py +107 -121
- experimaestro/tests/test_token_locking.py +252 -0
- experimaestro/tests/test_tokens.py +17 -13
- experimaestro/tests/test_types.py +123 -1
- experimaestro/tests/test_workspace_triggers.py +158 -0
- experimaestro/tests/token_reschedule.py +4 -2
- experimaestro/tests/utils.py +2 -2
- experimaestro/tokens.py +154 -57
- experimaestro/tools/diff.py +1 -1
- experimaestro/tui/__init__.py +8 -0
- experimaestro/tui/app.py +2303 -0
- experimaestro/tui/app.tcss +353 -0
- experimaestro/tui/log_viewer.py +228 -0
- experimaestro/utils/__init__.py +23 -0
- experimaestro/utils/environment.py +148 -0
- experimaestro/utils/git.py +129 -0
- experimaestro/utils/resources.py +1 -1
- experimaestro/version.py +34 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/METADATA +68 -38
- experimaestro-2.0.0b4.dist-info/RECORD +181 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/WHEEL +1 -1
- experimaestro-2.0.0b4.dist-info/entry_points.txt +16 -0
- experimaestro/compat.py +0 -6
- experimaestro/core/objects.pyi +0 -221
- experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
- experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
- experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
- experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
- experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
- experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
- experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
- experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
- experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
- experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
- experimaestro-2.0.0a8.dist-info/RECORD +0 -166
- experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""Tests for environment capture utilities"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from experimaestro.utils.git import get_git_info
|
|
7
|
+
from experimaestro.utils.environment import (
|
|
8
|
+
get_environment_info,
|
|
9
|
+
get_editable_packages_git_info,
|
|
10
|
+
save_environment_info,
|
|
11
|
+
load_environment_info,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestGetGitInfo:
    """Tests for get_git_info function

    Tests that inspect the current working directory skip (instead of
    failing) when it is not inside a git checkout, e.g. when the test-suite
    is run from an installed wheel.
    """

    @staticmethod
    def _cwd_git_info():
        """Return get_git_info() for the CWD, skipping the test outside a git repo"""
        git_info = get_git_info()
        if git_info is None:
            pytest.skip("Not in a git repository")
        return git_info

    def test_returns_dict_in_git_repo(self):
        """Test that get_git_info returns a dict when in a git repo"""
        # Use the current working directory, which may or may not be a git repo
        git_info = self._cwd_git_info()

        assert isinstance(git_info, dict)
        for key in ("commit", "commit_short", "branch", "dirty", "message", "author", "date"):
            assert key in git_info

    def test_commit_format(self):
        """Test that commit hashes have correct format"""
        git_info = self._cwd_git_info()
        commit = git_info["commit"]

        # Full commit should be 40 hex characters
        assert len(commit) == 40
        assert all(c in "0123456789abcdef" for c in commit)

        # The short commit abbreviates the full one; git may use more than
        # 7 characters when that is needed to keep the abbreviation unique
        commit_short = git_info["commit_short"]
        assert len(commit_short) >= 7
        assert commit.startswith(commit_short)

    def test_returns_none_for_non_git_dir(self, tmp_path):
        """Test that get_git_info returns None for non-git directories"""
        git_info = get_git_info(tmp_path)
        assert git_info is None

    def test_dirty_flag(self):
        """Test that dirty flag is a boolean"""
        git_info = self._cwd_git_info()

        assert isinstance(git_info["dirty"], bool)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TestGetEnvironmentInfo:
    """Checks on the dictionary produced by get_environment_info"""

    def test_returns_dict_with_required_keys(self):
        """The result is a dict exposing python version, packages and editable packages"""
        info = get_environment_info()

        assert isinstance(info, dict)
        for required in ("python_version", "packages", "editable_packages"):
            assert required in info

    def test_python_version_format(self):
        """python_version is made of three dot-separated integers (X.Y.Z)"""
        components = get_environment_info()["python_version"].split(".")

        assert len(components) == 3
        assert all(map(str.isdigit, components))

    def test_packages_is_dict(self):
        """packages maps distribution names (str) to versions (str)"""
        packages = get_environment_info()["packages"]

        assert isinstance(packages, dict)
        # At least some installed distributions must be reported
        assert len(packages) > 0
        assert all(
            isinstance(name, str) and isinstance(version, str)
            for name, version in packages.items()
        )

    def test_experimaestro_is_editable(self):
        """experimaestro itself shows up among the editable packages"""
        editable = get_environment_info()["editable_packages"]

        # Assumes the test-suite runs against an editable (development) install
        assert "experimaestro" in editable
        entry = editable["experimaestro"]
        for key in ("version", "path", "git"):
            assert key in entry

    def test_editable_package_has_git_info(self):
        """Editable packages living in a git checkout carry commit/dirty info"""
        editable = get_environment_info()["editable_packages"]
        if "experimaestro" not in editable:
            return

        git_info = editable["experimaestro"]["git"]
        # git is None when the package sources are not inside a git repository
        if git_info is None:
            return
        assert "commit" in git_info
        assert "dirty" in git_info
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class TestGetEditablePackagesGitInfo:
    """Checks on get_editable_packages_git_info"""

    def test_returns_dict(self):
        """The function hands back a dictionary"""
        assert isinstance(get_editable_packages_git_info(), dict)

    def test_contains_experimaestro(self):
        """experimaestro is listed (assumes an editable install of the package)"""
        editable_git = get_editable_packages_git_info()
        assert "experimaestro" in editable_git
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class TestSaveAndLoadEnvironmentInfo:
    """Round-trip checks for save_environment_info / load_environment_info"""

    def test_save_creates_file(self, tmp_path):
        """Saving writes the target file and returns the saved dictionary"""
        target = tmp_path / "environment.json"

        info = save_environment_info(target)

        assert target.exists()
        assert isinstance(info, dict)

    def test_save_writes_valid_json(self, tmp_path):
        """The written file is parseable JSON with the expected top-level keys"""
        target = tmp_path / "environment.json"
        save_environment_info(target)

        with target.open() as fp:
            content = json.load(fp)
        for key in ("python_version", "packages", "editable_packages"):
            assert key in content

    def test_load_reads_saved_data(self, tmp_path):
        """Loading gives back exactly what was saved"""
        target = tmp_path / "environment.json"

        # Save first: the file must exist before it is read back
        saved = save_environment_info(target)
        assert load_environment_info(target) == saved

    def test_load_returns_none_for_missing_file(self, tmp_path):
        """A missing file loads as None"""
        assert load_environment_info(tmp_path / "nonexistent.json") is None

    def test_load_returns_none_for_invalid_json(self, tmp_path):
        """A file with malformed JSON loads as None"""
        broken = tmp_path / "invalid.json"
        broken.write_text("not valid json{")

        assert load_environment_info(broken) is None
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class TestExperimentEnvironmentSaving:
    """Integration tests for environment saving in experiments"""

    def test_experiment_saves_environment_info(self, xpmdirectory):
        """Test that experiment saves environment.json on start"""
        from experimaestro import experiment

        # Just enter the experiment context, no need to run any tasks
        with experiment(xpmdirectory, "test-env-save", port=-1) as xp:
            env_path = xp.workdir / "environment.json"
            # Checked inside the context: the file must be written when the
            # experiment starts, not only once it has finished
            assert env_path.exists()

        assert env_path.exists()

        env_info = json.loads(env_path.read_text())
        for key in ("python_version", "packages", "editable_packages"):
            assert key in env_info
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from experimaestro import Param, Config
|
|
1
|
+
from experimaestro import field, Param, Config
|
|
2
2
|
from experimaestro.click import forwardoption
|
|
3
3
|
import click
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def test_main():
|
|
7
7
|
class MyModel(Config):
|
|
8
|
-
epochs: Param[int] = 100
|
|
8
|
+
epochs: Param[int] = field(ignore_default=100)
|
|
9
9
|
"""Number of learning epochs"""
|
|
10
10
|
|
|
11
11
|
@forwardoption.epochs(MyModel)
|
|
@@ -19,7 +19,7 @@ def test_main():
|
|
|
19
19
|
|
|
20
20
|
def test_rename():
|
|
21
21
|
class MyModel(Config):
|
|
22
|
-
epochs: Param[int] = 100
|
|
22
|
+
epochs: Param[int] = field(ignore_default=100)
|
|
23
23
|
"""Number of learning epochs"""
|
|
24
24
|
|
|
25
25
|
@forwardoption.epochs(MyModel, "my-epochs")
|
|
@@ -5,8 +5,8 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Dict, List, Optional
|
|
6
6
|
from experimaestro import (
|
|
7
7
|
Param,
|
|
8
|
-
deprecate,
|
|
9
8
|
Config,
|
|
9
|
+
InstanceConfig,
|
|
10
10
|
Constant,
|
|
11
11
|
Meta,
|
|
12
12
|
Option,
|
|
@@ -14,6 +14,8 @@ from experimaestro import (
|
|
|
14
14
|
field,
|
|
15
15
|
Task,
|
|
16
16
|
LightweightTask,
|
|
17
|
+
subparameters,
|
|
18
|
+
param_group,
|
|
17
19
|
)
|
|
18
20
|
from experimaestro.core.objects import (
|
|
19
21
|
ConfigInformation,
|
|
@@ -34,7 +36,7 @@ class B(Config):
|
|
|
34
36
|
|
|
35
37
|
|
|
36
38
|
class C(Config):
|
|
37
|
-
a: Param[int] = 1
|
|
39
|
+
a: Param[int] = field(ignore_default=1)
|
|
38
40
|
b: Param[int]
|
|
39
41
|
|
|
40
42
|
|
|
@@ -129,7 +131,7 @@ def test_identifier_option():
|
|
|
129
131
|
class OptionConfig(Config):
|
|
130
132
|
__xpmid__ = "test.identifier.option"
|
|
131
133
|
a: Param[int]
|
|
132
|
-
b: Option[int] = 1
|
|
134
|
+
b: Option[int] = field(ignore_default=1)
|
|
133
135
|
|
|
134
136
|
assert_notequal(OptionConfig.C(a=2), OptionConfig.C(a=1))
|
|
135
137
|
assert_equal(OptionConfig.C(a=1, b=2), OptionConfig.C(a=1))
|
|
@@ -229,7 +231,7 @@ def test_identifier_defaultnew():
|
|
|
229
231
|
__xpmid__ = "defaultnew"
|
|
230
232
|
|
|
231
233
|
a: Param[int]
|
|
232
|
-
b: Param[int] = 1
|
|
234
|
+
b: Param[int] = field(ignore_default=1)
|
|
233
235
|
|
|
234
236
|
class A(Config):
|
|
235
237
|
__xpmid__ = "defaultnew"
|
|
@@ -281,41 +283,6 @@ def test_identifier_constant():
|
|
|
281
283
|
assert_notequal(A1.C(), A2.C())
|
|
282
284
|
|
|
283
285
|
|
|
284
|
-
def test_identifier_deprecated_class():
|
|
285
|
-
"""Test that when submitting the task, the computed identifier is the one of
|
|
286
|
-
the new class"""
|
|
287
|
-
|
|
288
|
-
class NewConfig(Config):
|
|
289
|
-
__xpmid__ = "new"
|
|
290
|
-
|
|
291
|
-
@deprecate
|
|
292
|
-
class OldConfig(NewConfig):
|
|
293
|
-
__xpmid__ = "old"
|
|
294
|
-
|
|
295
|
-
class DerivedConfig(NewConfig):
|
|
296
|
-
__xpmid__ = "derived"
|
|
297
|
-
|
|
298
|
-
assert_notequal(
|
|
299
|
-
NewConfig.C(), DerivedConfig.C(), "A derived configuration has another ID"
|
|
300
|
-
)
|
|
301
|
-
assert_equal(
|
|
302
|
-
NewConfig.C(),
|
|
303
|
-
OldConfig.C(),
|
|
304
|
-
"Deprecated and new configuration have the same ID",
|
|
305
|
-
)
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
def test_identifier_deprecated_attribute():
|
|
309
|
-
class Values(Config):
|
|
310
|
-
values: Param[List[int]] = []
|
|
311
|
-
|
|
312
|
-
@deprecate
|
|
313
|
-
def value(self, x):
|
|
314
|
-
self.values = [x]
|
|
315
|
-
|
|
316
|
-
assert_equal(Values.C(values=[1]), Values.C(value=1))
|
|
317
|
-
|
|
318
|
-
|
|
319
286
|
class MetaA(Config):
|
|
320
287
|
x: Param[int]
|
|
321
288
|
|
|
@@ -365,7 +332,7 @@ def test_identifier_meta():
|
|
|
365
332
|
|
|
366
333
|
def test_identifier_meta_default_dict():
|
|
367
334
|
class DictConfig(Config):
|
|
368
|
-
params: Param[Dict[str, MetaA]] = {}
|
|
335
|
+
params: Param[Dict[str, MetaA]] = field(ignore_default={})
|
|
369
336
|
|
|
370
337
|
assert_equal(
|
|
371
338
|
DictConfig.C(params={}),
|
|
@@ -381,7 +348,7 @@ def test_identifier_meta_default_dict():
|
|
|
381
348
|
|
|
382
349
|
def test_identifier_meta_default_array():
|
|
383
350
|
class ArrayConfigWithDefault(Config):
|
|
384
|
-
array: Param[List[MetaA]] = []
|
|
351
|
+
array: Param[List[MetaA]] = field(ignore_default=[])
|
|
385
352
|
|
|
386
353
|
# Array (with default) with mixed
|
|
387
354
|
assert_equal(
|
|
@@ -590,3 +557,367 @@ def test_identifier_loop():
|
|
|
590
557
|
for i in range(len(configs)):
|
|
591
558
|
for j in range(1, len(configs)):
|
|
592
559
|
assert identifiers[i][0] == identifiers[i][j]
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
# --- Test InstanceConfig ---
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
class SubModel(InstanceConfig):
    """Test InstanceConfig - instances are distinguished even with same params"""

    # No parameters are declared, so every SubModel.C() has identical
    # parameter values; only instance identity can tell two of them apart
    pass
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
class SubModelAsConfig(Config):
    """Same as SubModel but as regular Config for backwards compat testing"""

    # Same identifier name as the one test_instanceconfig_backwards_compat
    # assigns to SubModel, so identifiers of the two types can be compared
    __xpmid__ = "test.SubModel"
    pass
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
class Model(Config):
    """Model that can contain SubModel instances"""

    # Two slots of the same InstanceConfig type: lets tests compare one
    # shared instance (m1 is m2) against two distinct instances
    m1: Param[SubModel]
    m2: Param[SubModel]
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
class ModelWithRegularConfig(Config):
    """Model using regular Config instead of InstanceConfig"""

    # Mirrors Model (same identifier name, same parameter names) but with a
    # plain Config sub-model: the reference for backwards compatibility
    __xpmid__ = "test.Model"
    m1: Param[SubModelAsConfig]
    m2: Param[SubModelAsConfig]
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _override_type_name(config, name, stack):
    """Set the identifier name of ``config``'s type, restoring it when ``stack`` closes.

    ``__xpmtype__`` is presumably class-level (shared by every instance), so
    the rename would otherwise leak into other tests using the same class.
    ``stack`` is a ``contextlib.ExitStack``.
    """
    identifier = config.__xpmtype__.identifier
    # The current name is captured now and written back on stack exit
    stack.callback(setattr, identifier, "name", identifier.name)
    identifier.name = name


def test_instanceconfig_backwards_compat():
    """Model using single InstanceConfig should have same ID as with regular Config"""
    from contextlib import ExitStack

    with ExitStack() as restore_names:
        # Using InstanceConfig (first occurrence only, no instance marker added)
        sm1 = SubModel.C()
        # Match the __xpmid__ of SubModelAsConfig
        _override_type_name(sm1, "test.SubModel", restore_names)
        m_instance = Model.C(m1=sm1, m2=sm1)
        _override_type_name(m_instance, "test.Model", restore_names)

        # Using regular Config
        sc1 = SubModelAsConfig.C()
        m_regular = ModelWithRegularConfig.C(m1=sc1, m2=sc1)

        # Should have same identifier (backwards compatible)
        assert_equal(
            m_instance, m_regular, "Single InstanceConfig should be backwards compatible"
        )
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
def test_instanceconfig_same_params_different_instances():
    """Sharing one InstanceConfig must not hash like two equal-valued instances"""
    first = SubModel.C()
    second = SubModel.C()

    shared = Model.C(m1=first, m2=first)  # one instance plugged in twice
    separate = Model.C(m1=first, m2=second)  # two distinct instances

    # `second` has the same parameters as `first` but is another instance
    assert_notequal(
        shared, separate, "Models with shared vs separate instances should differ"
    )
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def test_instanceconfig_reused_instance():
    """Models built from the very same InstanceConfig object share one ID"""
    shared_sub = SubModel.C()

    # The exact same sub-model object fills both slots of both models
    model_a = Model.C(m1=shared_sub, m2=shared_sub)
    model_b = Model.C(m1=shared_sub, m2=shared_sub)

    assert_equal(model_a, model_b, "Models with same instance objects should be equal")
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def test_instanceconfig_serialization():
    """A model with two distinct InstanceConfig keeps its identifier across reload"""
    # Keyword arguments are evaluated left to right: m1's instance is created first
    model = Model.C(m1=SubModel.C(), m2=SubModel.C())
    expected = getidentifier(model)

    # Round-trip through serialization
    check_reload(model)

    assert getidentifier(model) == expected
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
# --- Test ignore_default vs default in field() ---
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def test_identifier_field_ignore_default():
    """field(ignore_default=X): a value equal to X is left out of the identifier"""

    class ConfigWithIgnoreDefault(Config):
        __xpmid__ = "test.identifier.field_ignore_default"
        a: Param[int] = field(ignore_default=1)
        b: Param[int]

    # Same identifier name, but `a` is not declared at all
    class ConfigWithoutA(Config):
        __xpmid__ = "test.identifier.field_ignore_default"
        b: Param[int]

    # a=1 is the ignored default: identical to leaving `a` unset...
    explicit = ConfigWithIgnoreDefault.C(a=1, b=2)
    implicit = ConfigWithIgnoreDefault.C(b=2)
    assert_equal(
        explicit, implicit, "field(ignore_default=1) should ignore a=1 in identifier"
    )

    # ...and to a configuration that never had `a`
    explicit = ConfigWithIgnoreDefault.C(a=1, b=2)
    without_a = ConfigWithoutA.C(b=2)
    assert_equal(
        explicit,
        without_a,
        "Config with ignore_default should match config without that param",
    )

    # Any other value of `a` is hashed
    changed = ConfigWithIgnoreDefault.C(a=2, b=2)
    implicit = ConfigWithIgnoreDefault.C(b=2)
    assert_notequal(
        changed, implicit, "field(ignore_default=1) should include a=2 in identifier"
    )
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def test_identifier_field_default():
    """field(default=X): a value equal to X still contributes to the identifier"""

    class ConfigWithDefault(Config):
        __xpmid__ = "test.identifier.field_default"
        a: Param[int] = field(default=1)
        b: Param[int]

    class ConfigWithoutA(Config):
        __xpmid__ = "test.identifier.field_default"
        b: Param[int]

    # a=1 equals the default yet is hashed, so the result must differ from a
    # configuration that does not declare `a`
    explicit = ConfigWithDefault.C(a=1, b=2)
    without_a = ConfigWithoutA.C(b=2)
    assert_notequal(
        explicit, without_a, "field(default=1) should include a=1 in identifier"
    )

    # Identical values still give identical identifiers
    left = ConfigWithDefault.C(a=1, b=2)
    right = ConfigWithDefault.C(a=1, b=2)
    assert_equal(left, right, "Same values should have same identifier")
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
def test_identifier_field_default_vs_ignore_default():
    """Contrast field(default=X) with field(ignore_default=X)"""

    class ConfigWithDefault(Config):
        __xpmid__ = "test.identifier.field_default_vs_ignore"
        a: Param[int] = field(default=1)
        b: Param[int]

    class ConfigWithIgnoreDefault(Config):
        __xpmid__ = "test.identifier.field_default_vs_ignore"
        a: Param[int] = field(ignore_default=1)
        b: Param[int]

    # At a=1 only the `default` variant hashes `a`
    hashed = ConfigWithDefault.C(a=1, b=2)
    ignored = ConfigWithIgnoreDefault.C(a=1, b=2)
    assert_notequal(
        hashed,
        ignored,
        "field(default=1) vs field(ignore_default=1) should differ when a=1",
    )

    # Away from the default both variants hash `a` the same way
    hashed = ConfigWithDefault.C(a=2, b=2)
    ignored = ConfigWithIgnoreDefault.C(a=2, b=2)
    assert_equal(
        hashed,
        ignored,
        "field(default=1) vs field(ignore_default=1) should be same when a!=1",
    )
|
|
743
|
+
|
|
744
|
+
|
|
745
|
+
# --- Test partial identifiers (subparameters) ---
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
# Parameter groups shared by the partial-identifier tests below
# (created in this order: "iter" first, then "model")
iter_group, model_group = param_group("iter"), param_group("model")
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
def get_partial_identifier(config, sp):
    """Return the ``.all`` component of ``config``'s partial identifier for ``sp``.

    ``sp`` is a subparameters declaration (e.g. ``SomeConfig.checkpoints``).
    """
    partial = config.__xpm__.get_partial_identifier(sp)
    return partial.all
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def test_partial_identifier_excludes_grouped_params():
    """Parameters of an excluded group do not reach the partial identifier"""

    class ConfigWithGroups(Config):
        checkpoints = subparameters(exclude_groups=[iter_group])
        max_iter: Param[int] = field(groups=[iter_group])
        learning_rate: Param[float]

    short_run = ConfigWithGroups.C(max_iter=100, learning_rate=0.1)
    long_run = ConfigWithGroups.C(max_iter=200, learning_rate=0.1)

    # max_iter is part of the full identifier...
    assert_notequal(
        short_run, long_run, "Full identifiers should differ when max_iter differs"
    )

    # ...but not of the "checkpoints" partial identifier
    short_pid, long_pid = (
        get_partial_identifier(cfg, ConfigWithGroups.checkpoints)
        for cfg in (short_run, long_run)
    )
    assert (
        short_pid == long_pid
    ), "Partial identifiers should match when only excluded params differ"
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def test_partial_identifier_includes_ungrouped_params():
    """Parameters outside the excluded groups still shape the partial identifier"""

    class ConfigWithGroups(Config):
        checkpoints = subparameters(exclude_groups=[iter_group])
        max_iter: Param[int] = field(groups=[iter_group])
        learning_rate: Param[float]

    base = ConfigWithGroups.C(max_iter=100, learning_rate=0.1)
    other_lr = ConfigWithGroups.C(max_iter=100, learning_rate=0.2)

    # learning_rate has no group, so the "checkpoints" view still sees it
    base_pid, other_pid = (
        get_partial_identifier(cfg, ConfigWithGroups.checkpoints)
        for cfg in (base, other_lr)
    )
    assert (
        base_pid != other_pid
    ), "Partial identifiers should differ when non-excluded params differ"
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
def test_partial_identifier_matches_config_without_excluded():
    """Excluding a group is equivalent to never declaring its parameters"""

    class ConfigWithIter(Config):
        __xpmid__ = "test.partial_identifier.config"
        checkpoints = subparameters(exclude_groups=[iter_group])
        max_iter: Param[int] = field(groups=[iter_group])
        learning_rate: Param[float]

    # Same identifier name, without the iteration parameter
    class ConfigWithoutIter(Config):
        __xpmid__ = "test.partial_identifier.config"
        learning_rate: Param[float]

    with_iter = ConfigWithIter.C(max_iter=100, learning_rate=0.1)
    without_iter = ConfigWithoutIter.C(learning_rate=0.1)

    partial_id = get_partial_identifier(with_iter, ConfigWithIter.checkpoints)
    reference_id = getidentifier(without_iter)
    assert (
        partial_id == reference_id
    ), "Partial identifier should match config without excluded fields"
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def test_partial_identifier_multiple_groups():
    """A parameter in several groups is dropped as soon as one group is excluded"""

    class ConfigMultiGroup(Config):
        checkpoints = subparameters(exclude_groups=[iter_group])
        # Belongs to both groups; only iter_group is excluded
        x: Param[int] = field(groups=[iter_group, model_group])
        y: Param[float]

    configs = [ConfigMultiGroup.C(x=x_value, y=0.1) for x_value in (1, 2)]

    first_pid, second_pid = (
        get_partial_identifier(cfg, ConfigMultiGroup.checkpoints) for cfg in configs
    )
    assert (
        first_pid == second_pid
    ), "Partial identifiers should match when param is in any excluded group"
|
|
841
|
+
|
|
842
|
+
|
|
843
|
+
def test_partial_identifier_include_overrides_exclude():
    """A group listed in include_groups wins over exclude_groups"""

    class ConfigIncludeOverride(Config):
        # iter_group is both excluded and included: inclusion takes precedence
        partial = subparameters(
            exclude_groups=[iter_group, model_group], include_groups=[iter_group]
        )
        x: Param[int] = field(groups=[iter_group])
        y: Param[int] = field(groups=[model_group])
        z: Param[float]

    reference = ConfigIncludeOverride.C(x=1, y=1, z=0.1)
    changed_x = ConfigIncludeOverride.C(x=2, y=1, z=0.1)
    changed_y = ConfigIncludeOverride.C(x=1, y=2, z=0.1)

    def partial_id(cfg):
        return get_partial_identifier(cfg, ConfigIncludeOverride.partial)

    reference_pid = partial_id(reference)

    # x (iter_group, re-included) is hashed
    assert (
        reference_pid != partial_id(changed_x)
    ), "Include should override exclude - x should be included"

    # y (model_group, only excluded) is not hashed
    assert (
        reference_pid == partial_id(changed_y)
    ), "y is excluded - different y should give same partial ID"
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
def test_partial_identifier_exclude_all():
    """exclude_all drops everything except explicitly included groups"""

    class ConfigExcludeAll(Config):
        # Drop everything, then bring model_group back
        partial = subparameters(exclude_all=True, include_groups=[model_group])
        x: Param[int] = field(groups=[iter_group])
        y: Param[int] = field(groups=[model_group])
        z: Param[float]  # No group

    # Insertion order fixes both construction and hashing order
    variants = {
        "base": ConfigExcludeAll.C(x=1, y=1, z=0.1),
        "x": ConfigExcludeAll.C(x=2, y=1, z=0.1),  # iter_group: excluded
        "y": ConfigExcludeAll.C(x=1, y=2, z=0.1),  # model_group: included
        "z": ConfigExcludeAll.C(x=1, y=1, z=0.2),  # no group: excluded
    }
    pids = {
        key: get_partial_identifier(cfg, ConfigExcludeAll.partial)
        for key, cfg in variants.items()
    }

    assert pids["base"] == pids["x"], "x is excluded - should have same partial ID"
    assert pids["base"] != pids["y"], "y is included - should have different partial ID"
    assert (
        pids["base"] == pids["z"]
    ), "z (no group) is excluded by exclude_all - should have same partial ID"
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
def test_partial_identifier_exclude_no_group():
    """exclude_no_group drops only the parameters that belong to no group"""

    class ConfigExcludeNoGroup(Config):
        partial = subparameters(exclude_no_group=True)
        x: Param[int] = field(groups=[iter_group])
        y: Param[float]  # No group

    base = ConfigExcludeNoGroup.C(x=1, y=0.1)
    other_x = ConfigExcludeNoGroup.C(x=2, y=0.1)  # grouped: kept
    other_y = ConfigExcludeNoGroup.C(x=1, y=0.2)  # ungrouped: dropped

    base_pid, x_pid, y_pid = (
        get_partial_identifier(cfg, ConfigExcludeNoGroup.partial)
        for cfg in (base, other_x, other_y)
    )

    assert base_pid != x_pid, "x has group - should have different partial ID"
    assert base_pid == y_pid, "y has no group - should have same partial ID"
|