experimaestro 2.0.0a8__py3-none-any.whl → 2.0.0b8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic; see the package registry's release page for more details.

Files changed (122)
  1. experimaestro/__init__.py +10 -11
  2. experimaestro/annotations.py +167 -206
  3. experimaestro/cli/__init__.py +278 -7
  4. experimaestro/cli/filter.py +42 -74
  5. experimaestro/cli/jobs.py +157 -106
  6. experimaestro/cli/refactor.py +249 -0
  7. experimaestro/click.py +0 -1
  8. experimaestro/commandline.py +19 -3
  9. experimaestro/connectors/__init__.py +20 -1
  10. experimaestro/connectors/local.py +12 -0
  11. experimaestro/core/arguments.py +182 -46
  12. experimaestro/core/identifier.py +107 -6
  13. experimaestro/core/objects/__init__.py +6 -0
  14. experimaestro/core/objects/config.py +542 -25
  15. experimaestro/core/objects/config_walk.py +20 -0
  16. experimaestro/core/serialization.py +91 -34
  17. experimaestro/core/subparameters.py +164 -0
  18. experimaestro/core/types.py +175 -38
  19. experimaestro/exceptions.py +26 -0
  20. experimaestro/experiments/cli.py +111 -25
  21. experimaestro/generators.py +50 -9
  22. experimaestro/huggingface.py +3 -1
  23. experimaestro/launcherfinder/parser.py +29 -0
  24. experimaestro/launchers/__init__.py +26 -1
  25. experimaestro/launchers/direct.py +12 -0
  26. experimaestro/launchers/slurm/base.py +154 -2
  27. experimaestro/mkdocs/metaloader.py +0 -1
  28. experimaestro/mypy.py +452 -7
  29. experimaestro/notifications.py +63 -13
  30. experimaestro/progress.py +0 -2
  31. experimaestro/rpyc.py +0 -1
  32. experimaestro/run.py +19 -6
  33. experimaestro/scheduler/base.py +510 -125
  34. experimaestro/scheduler/dependencies.py +43 -28
  35. experimaestro/scheduler/dynamic_outputs.py +259 -130
  36. experimaestro/scheduler/experiment.py +256 -31
  37. experimaestro/scheduler/interfaces.py +501 -0
  38. experimaestro/scheduler/jobs.py +216 -206
  39. experimaestro/scheduler/remote/__init__.py +31 -0
  40. experimaestro/scheduler/remote/client.py +874 -0
  41. experimaestro/scheduler/remote/protocol.py +467 -0
  42. experimaestro/scheduler/remote/server.py +423 -0
  43. experimaestro/scheduler/remote/sync.py +144 -0
  44. experimaestro/scheduler/services.py +323 -23
  45. experimaestro/scheduler/state_db.py +437 -0
  46. experimaestro/scheduler/state_provider.py +2766 -0
  47. experimaestro/scheduler/state_sync.py +891 -0
  48. experimaestro/scheduler/workspace.py +52 -10
  49. experimaestro/scriptbuilder.py +7 -0
  50. experimaestro/server/__init__.py +147 -57
  51. experimaestro/server/data/index.css +0 -125
  52. experimaestro/server/data/index.css.map +1 -1
  53. experimaestro/server/data/index.js +194 -58
  54. experimaestro/server/data/index.js.map +1 -1
  55. experimaestro/settings.py +44 -5
  56. experimaestro/sphinx/__init__.py +3 -3
  57. experimaestro/taskglobals.py +20 -0
  58. experimaestro/tests/conftest.py +80 -0
  59. experimaestro/tests/core/test_generics.py +2 -2
  60. experimaestro/tests/identifier_stability.json +45 -0
  61. experimaestro/tests/launchers/bin/sacct +6 -2
  62. experimaestro/tests/launchers/bin/sbatch +4 -2
  63. experimaestro/tests/launchers/test_slurm.py +80 -0
  64. experimaestro/tests/tasks/test_dynamic.py +231 -0
  65. experimaestro/tests/test_cli_jobs.py +615 -0
  66. experimaestro/tests/test_deprecated.py +630 -0
  67. experimaestro/tests/test_environment.py +200 -0
  68. experimaestro/tests/test_file_progress_integration.py +1 -1
  69. experimaestro/tests/test_forward.py +3 -3
  70. experimaestro/tests/test_identifier.py +372 -41
  71. experimaestro/tests/test_identifier_stability.py +458 -0
  72. experimaestro/tests/test_instance.py +3 -3
  73. experimaestro/tests/test_multitoken.py +442 -0
  74. experimaestro/tests/test_mypy.py +433 -0
  75. experimaestro/tests/test_objects.py +312 -5
  76. experimaestro/tests/test_outputs.py +2 -2
  77. experimaestro/tests/test_param.py +8 -12
  78. experimaestro/tests/test_partial_paths.py +231 -0
  79. experimaestro/tests/test_progress.py +0 -48
  80. experimaestro/tests/test_remote_state.py +671 -0
  81. experimaestro/tests/test_resumable_task.py +480 -0
  82. experimaestro/tests/test_serializers.py +141 -1
  83. experimaestro/tests/test_state_db.py +434 -0
  84. experimaestro/tests/test_subparameters.py +160 -0
  85. experimaestro/tests/test_tags.py +136 -0
  86. experimaestro/tests/test_tasks.py +107 -121
  87. experimaestro/tests/test_token_locking.py +252 -0
  88. experimaestro/tests/test_tokens.py +17 -13
  89. experimaestro/tests/test_types.py +123 -1
  90. experimaestro/tests/test_workspace_triggers.py +158 -0
  91. experimaestro/tests/token_reschedule.py +4 -2
  92. experimaestro/tests/utils.py +2 -2
  93. experimaestro/tokens.py +154 -57
  94. experimaestro/tools/diff.py +1 -1
  95. experimaestro/tui/__init__.py +8 -0
  96. experimaestro/tui/app.py +2395 -0
  97. experimaestro/tui/app.tcss +353 -0
  98. experimaestro/tui/log_viewer.py +228 -0
  99. experimaestro/utils/__init__.py +23 -0
  100. experimaestro/utils/environment.py +148 -0
  101. experimaestro/utils/git.py +129 -0
  102. experimaestro/utils/resources.py +1 -1
  103. experimaestro/version.py +34 -0
  104. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/METADATA +68 -38
  105. experimaestro-2.0.0b8.dist-info/RECORD +187 -0
  106. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/WHEEL +1 -1
  107. experimaestro-2.0.0b8.dist-info/entry_points.txt +16 -0
  108. experimaestro/compat.py +0 -6
  109. experimaestro/core/objects.pyi +0 -221
  110. experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
  111. experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
  112. experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
  113. experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
  114. experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
  115. experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
  116. experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
  117. experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
  118. experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
  119. experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
  120. experimaestro-2.0.0a8.dist-info/RECORD +0 -166
  121. experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
  122. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,200 @@
1
+ """Tests for environment capture utilities"""
2
+
3
+ import json
4
+ import pytest
5
+
6
+ from experimaestro.utils.git import get_git_info
7
+ from experimaestro.utils.environment import (
8
+ get_environment_info,
9
+ get_editable_packages_git_info,
10
+ save_environment_info,
11
+ load_environment_info,
12
+ )
13
+
14
+
15
+ class TestGetGitInfo:
16
+ """Tests for get_git_info function"""
17
+
18
+ def test_returns_dict_in_git_repo(self, tmp_path):
19
+ """Test that get_git_info returns a dict when in a git repo"""
20
+ # Use the current working directory which should be a git repo
21
+ git_info = get_git_info()
22
+
23
+ assert git_info is not None
24
+ assert isinstance(git_info, dict)
25
+ assert "commit" in git_info
26
+ assert "commit_short" in git_info
27
+ assert "branch" in git_info
28
+ assert "dirty" in git_info
29
+ assert "message" in git_info
30
+ assert "author" in git_info
31
+ assert "date" in git_info
32
+
33
+ def test_commit_format(self):
34
+ """Test that commit hashes have correct format"""
35
+ git_info = get_git_info()
36
+ if git_info is None:
37
+ pytest.skip("Not in a git repository")
38
+
39
+ # Full commit should be 40 hex characters
40
+ assert len(git_info["commit"]) == 40
41
+ assert all(c in "0123456789abcdef" for c in git_info["commit"])
42
+
43
+ # Short commit should be 7 characters
44
+ assert len(git_info["commit_short"]) == 7
45
+
46
+ def test_returns_none_for_non_git_dir(self, tmp_path):
47
+ """Test that get_git_info returns None for non-git directories"""
48
+ git_info = get_git_info(tmp_path)
49
+ assert git_info is None
50
+
51
+ def test_dirty_flag(self):
52
+ """Test that dirty flag is a boolean"""
53
+ git_info = get_git_info()
54
+ if git_info is None:
55
+ pytest.skip("Not in a git repository")
56
+
57
+ assert isinstance(git_info["dirty"], bool)
58
+
59
+
60
+ class TestGetEnvironmentInfo:
61
+ """Tests for get_environment_info function"""
62
+
63
+ def test_returns_dict_with_required_keys(self):
64
+ """Test that get_environment_info returns dict with required keys"""
65
+ env_info = get_environment_info()
66
+
67
+ assert isinstance(env_info, dict)
68
+ assert "python_version" in env_info
69
+ assert "packages" in env_info
70
+ assert "editable_packages" in env_info
71
+
72
+ def test_python_version_format(self):
73
+ """Test that python_version has correct format"""
74
+ env_info = get_environment_info()
75
+ version = env_info["python_version"]
76
+
77
+ # Should be in format X.Y.Z
78
+ parts = version.split(".")
79
+ assert len(parts) == 3
80
+ assert all(part.isdigit() for part in parts)
81
+
82
+ def test_packages_is_dict(self):
83
+ """Test that packages is a dict of name -> version"""
84
+ env_info = get_environment_info()
85
+ packages = env_info["packages"]
86
+
87
+ assert isinstance(packages, dict)
88
+ assert len(packages) > 0 # Should have at least some packages
89
+
90
+ # Check that all values are strings (versions)
91
+ for name, version in packages.items():
92
+ assert isinstance(name, str)
93
+ assert isinstance(version, str)
94
+
95
+ def test_experimaestro_is_editable(self):
96
+ """Test that experimaestro itself is detected as editable"""
97
+ env_info = get_environment_info()
98
+ editable = env_info["editable_packages"]
99
+
100
+ # When running tests, experimaestro should be installed in editable mode
101
+ assert "experimaestro" in editable
102
+ assert "version" in editable["experimaestro"]
103
+ assert "path" in editable["experimaestro"]
104
+ assert "git" in editable["experimaestro"]
105
+
106
+ def test_editable_package_has_git_info(self):
107
+ """Test that editable packages include git info"""
108
+ env_info = get_environment_info()
109
+ editable = env_info["editable_packages"]
110
+
111
+ # experimaestro should have git info since it's in a git repo
112
+ if "experimaestro" in editable:
113
+ git_info = editable["experimaestro"]["git"]
114
+ if git_info is not None: # May be None if not in git repo
115
+ assert "commit" in git_info
116
+ assert "dirty" in git_info
117
+
118
+
119
+ class TestGetEditablePackagesGitInfo:
120
+ """Tests for get_editable_packages_git_info function"""
121
+
122
+ def test_returns_dict(self):
123
+ """Test that function returns a dict"""
124
+ result = get_editable_packages_git_info()
125
+ assert isinstance(result, dict)
126
+
127
+ def test_contains_experimaestro(self):
128
+ """Test that experimaestro is in the result"""
129
+ result = get_editable_packages_git_info()
130
+ assert "experimaestro" in result
131
+
132
+
133
+ class TestSaveAndLoadEnvironmentInfo:
134
+ """Tests for save_environment_info and load_environment_info functions"""
135
+
136
+ def test_save_creates_file(self, tmp_path):
137
+ """Test that save_environment_info creates a JSON file"""
138
+ path = tmp_path / "environment.json"
139
+
140
+ result = save_environment_info(path)
141
+
142
+ assert path.exists()
143
+ assert isinstance(result, dict)
144
+
145
+ def test_save_writes_valid_json(self, tmp_path):
146
+ """Test that saved file contains valid JSON"""
147
+ path = tmp_path / "environment.json"
148
+
149
+ save_environment_info(path)
150
+
151
+ content = json.loads(path.read_text())
152
+ assert "python_version" in content
153
+ assert "packages" in content
154
+ assert "editable_packages" in content
155
+
156
+ def test_load_reads_saved_data(self, tmp_path):
157
+ """Test that load_environment_info reads back saved data"""
158
+ path = tmp_path / "environment.json"
159
+
160
+ saved = save_environment_info(path)
161
+ loaded = load_environment_info(path)
162
+
163
+ assert loaded == saved
164
+
165
+ def test_load_returns_none_for_missing_file(self, tmp_path):
166
+ """Test that load returns None for non-existent file"""
167
+ path = tmp_path / "nonexistent.json"
168
+
169
+ result = load_environment_info(path)
170
+
171
+ assert result is None
172
+
173
+ def test_load_returns_none_for_invalid_json(self, tmp_path):
174
+ """Test that load returns None for invalid JSON"""
175
+ path = tmp_path / "invalid.json"
176
+ path.write_text("not valid json{")
177
+
178
+ result = load_environment_info(path)
179
+
180
+ assert result is None
181
+
182
+
183
+ class TestExperimentEnvironmentSaving:
184
+ """Integration tests for environment saving in experiments"""
185
+
186
+ def test_experiment_saves_environment_info(self, xpmdirectory):
187
+ """Test that experiment saves environment.json on start"""
188
+ from experimaestro import experiment
189
+
190
+ # Just enter the experiment context, no need to run any tasks
191
+ with experiment(xpmdirectory, "test-env-save", port=-1) as xp:
192
+ pass # environment.json should be saved on __enter__
193
+
194
+ env_path = xp.workdir / "environment.json"
195
+ assert env_path.exists()
196
+
197
+ env_info = json.loads(env_path.read_text())
198
+ assert "python_version" in env_info
199
+ assert "packages" in env_info
200
+ assert "editable_packages" in env_info
@@ -321,7 +321,7 @@ def test_file_progress_concurrent_experiments():
321
321
 
322
322
  with TemporaryExperiment(
323
323
  "file-progress-concurrent-2",
324
- workdir=xp1.workdir,
324
+ workdir=xp1.workspace.path,
325
325
  maxwait=max_wait,
326
326
  port=0,
327
327
  ) as xp2:
@@ -1,11 +1,11 @@
1
- from experimaestro import Param, Config
1
+ from experimaestro import field, Param, Config
2
2
  from experimaestro.click import forwardoption
3
3
  import click
4
4
 
5
5
 
6
6
  def test_main():
7
7
  class MyModel(Config):
8
- epochs: Param[int] = 100
8
+ epochs: Param[int] = field(ignore_default=100)
9
9
  """Number of learning epochs"""
10
10
 
11
11
  @forwardoption.epochs(MyModel)
@@ -19,7 +19,7 @@ def test_main():
19
19
 
20
20
  def test_rename():
21
21
  class MyModel(Config):
22
- epochs: Param[int] = 100
22
+ epochs: Param[int] = field(ignore_default=100)
23
23
  """Number of learning epochs"""
24
24
 
25
25
  @forwardoption.epochs(MyModel, "my-epochs")
@@ -5,8 +5,8 @@ from pathlib import Path
5
5
  from typing import Dict, List, Optional
6
6
  from experimaestro import (
7
7
  Param,
8
- deprecate,
9
8
  Config,
9
+ InstanceConfig,
10
10
  Constant,
11
11
  Meta,
12
12
  Option,
@@ -14,6 +14,8 @@ from experimaestro import (
14
14
  field,
15
15
  Task,
16
16
  LightweightTask,
17
+ subparameters,
18
+ param_group,
17
19
  )
18
20
  from experimaestro.core.objects import (
19
21
  ConfigInformation,
@@ -34,7 +36,7 @@ class B(Config):
34
36
 
35
37
 
36
38
  class C(Config):
37
- a: Param[int] = 1
39
+ a: Param[int] = field(ignore_default=1)
38
40
  b: Param[int]
39
41
 
40
42
 
@@ -129,7 +131,7 @@ def test_identifier_option():
129
131
  class OptionConfig(Config):
130
132
  __xpmid__ = "test.identifier.option"
131
133
  a: Param[int]
132
- b: Option[int] = 1
134
+ b: Option[int] = field(ignore_default=1)
133
135
 
134
136
  assert_notequal(OptionConfig.C(a=2), OptionConfig.C(a=1))
135
137
  assert_equal(OptionConfig.C(a=1, b=2), OptionConfig.C(a=1))
@@ -229,7 +231,7 @@ def test_identifier_defaultnew():
229
231
  __xpmid__ = "defaultnew"
230
232
 
231
233
  a: Param[int]
232
- b: Param[int] = 1
234
+ b: Param[int] = field(ignore_default=1)
233
235
 
234
236
  class A(Config):
235
237
  __xpmid__ = "defaultnew"
@@ -281,41 +283,6 @@ def test_identifier_constant():
281
283
  assert_notequal(A1.C(), A2.C())
282
284
 
283
285
 
284
- def test_identifier_deprecated_class():
285
- """Test that when submitting the task, the computed identifier is the one of
286
- the new class"""
287
-
288
- class NewConfig(Config):
289
- __xpmid__ = "new"
290
-
291
- @deprecate
292
- class OldConfig(NewConfig):
293
- __xpmid__ = "old"
294
-
295
- class DerivedConfig(NewConfig):
296
- __xpmid__ = "derived"
297
-
298
- assert_notequal(
299
- NewConfig.C(), DerivedConfig.C(), "A derived configuration has another ID"
300
- )
301
- assert_equal(
302
- NewConfig.C(),
303
- OldConfig.C(),
304
- "Deprecated and new configuration have the same ID",
305
- )
306
-
307
-
308
- def test_identifier_deprecated_attribute():
309
- class Values(Config):
310
- values: Param[List[int]] = []
311
-
312
- @deprecate
313
- def value(self, x):
314
- self.values = [x]
315
-
316
- assert_equal(Values.C(values=[1]), Values.C(value=1))
317
-
318
-
319
286
  class MetaA(Config):
320
287
  x: Param[int]
321
288
 
@@ -365,7 +332,7 @@ def test_identifier_meta():
365
332
 
366
333
  def test_identifier_meta_default_dict():
367
334
  class DictConfig(Config):
368
- params: Param[Dict[str, MetaA]] = {}
335
+ params: Param[Dict[str, MetaA]] = field(ignore_default={})
369
336
 
370
337
  assert_equal(
371
338
  DictConfig.C(params={}),
@@ -381,7 +348,7 @@ def test_identifier_meta_default_dict():
381
348
 
382
349
  def test_identifier_meta_default_array():
383
350
  class ArrayConfigWithDefault(Config):
384
- array: Param[List[MetaA]] = []
351
+ array: Param[List[MetaA]] = field(ignore_default=[])
385
352
 
386
353
  # Array (with default) with mixed
387
354
  assert_equal(
@@ -590,3 +557,367 @@ def test_identifier_loop():
590
557
  for i in range(len(configs)):
591
558
  for j in range(1, len(configs)):
592
559
  assert identifiers[i][0] == identifiers[i][j]
560
+
561
+
562
+ # --- Test InstanceConfig
563
+
564
+
565
+ class SubModel(InstanceConfig):
566
+ """Test InstanceConfig - instances are distinguished even with same params"""
567
+
568
+ pass
569
+
570
+
571
+ class SubModelAsConfig(Config):
572
+ """Same as SubModel but as regular Config for backwards compat testing"""
573
+
574
+ __xpmid__ = "test.SubModel"
575
+ pass
576
+
577
+
578
+ class Model(Config):
579
+ """Model that can contain SubModel instances"""
580
+
581
+ m1: Param[SubModel]
582
+ m2: Param[SubModel]
583
+
584
+
585
+ class ModelWithRegularConfig(Config):
586
+ """Model using regular Config instead of InstanceConfig"""
587
+
588
+ __xpmid__ = "test.Model"
589
+ m1: Param[SubModelAsConfig]
590
+ m2: Param[SubModelAsConfig]
591
+
592
+
593
+ def test_instanceconfig_backwards_compat():
594
+ """Model using single InstanceConfig should have same ID as with regular Config"""
595
+ # Using InstanceConfig (first occurrence only, no instance marker added)
596
+ sm1 = SubModel.C()
597
+ sm1.__xpmtype__.identifier.name = "test.SubModel" # Match the __xpmid__
598
+ m_instance = Model.C(m1=sm1, m2=sm1)
599
+ m_instance.__xpmtype__.identifier.name = "test.Model"
600
+
601
+ # Using regular Config
602
+ sc1 = SubModelAsConfig.C()
603
+ m_regular = ModelWithRegularConfig.C(m1=sc1, m2=sc1)
604
+
605
+ # Should have same identifier (backwards compatible)
606
+ assert_equal(
607
+ m_instance, m_regular, "Single InstanceConfig should be backwards compatible"
608
+ )
609
+
610
+
611
+ def test_instanceconfig_same_params_different_instances():
612
+ """Model with separate InstanceConfig instances should differ from shared"""
613
+ sm1 = SubModel.C()
614
+ sm2 = SubModel.C()
615
+
616
+ # Using the same instance twice (shared)
617
+ m1 = Model.C(m1=sm1, m2=sm1)
618
+
619
+ # Using different instances (separate)
620
+ m2 = Model.C(m1=sm1, m2=sm2)
621
+
622
+ # These should be different because sm2 is a second instance with same params
623
+ assert_notequal(m1, m2, "Models with shared vs separate instances should differ")
624
+
625
+
626
+ def test_instanceconfig_reused_instance():
627
+ """Reusing the same InstanceConfig instance should give same ID"""
628
+ sm1 = SubModel.C()
629
+
630
+ # Using the same instance object multiple times should be OK
631
+ m1 = Model.C(m1=sm1, m2=sm1)
632
+ m2 = Model.C(m1=sm1, m2=sm1)
633
+
634
+ # These should be the same because we're reusing the exact same objects
635
+ assert_equal(m1, m2, "Models with same instance objects should be equal")
636
+
637
+
638
+ def test_instanceconfig_serialization():
639
+ """InstanceConfig identifiers should be stable after serialization"""
640
+ sm1 = SubModel.C()
641
+ sm2 = SubModel.C()
642
+
643
+ # Create a model with two different instances
644
+ m1 = Model.C(m1=sm1, m2=sm2)
645
+ original_id = getidentifier(m1)
646
+
647
+ # Serialize and reload
648
+ check_reload(m1)
649
+
650
+ # The identifier should remain the same
651
+ assert getidentifier(m1) == original_id
652
+
653
+
654
+ # --- Test ignore_default vs default in field() ---
655
+
656
+
657
+ def test_identifier_field_ignore_default():
658
+ """Test that field(ignore_default=X) ignores value in identifier when value == X"""
659
+
660
+ class ConfigWithIgnoreDefault(Config):
661
+ __xpmid__ = "test.identifier.field_ignore_default"
662
+ a: Param[int] = field(ignore_default=1)
663
+ b: Param[int]
664
+
665
+ # When a=1 (matches ignore_default), should be same as not specifying a
666
+ class ConfigWithoutA(Config):
667
+ __xpmid__ = "test.identifier.field_ignore_default"
668
+ b: Param[int]
669
+
670
+ assert_equal(
671
+ ConfigWithIgnoreDefault.C(a=1, b=2),
672
+ ConfigWithIgnoreDefault.C(b=2),
673
+ "field(ignore_default=1) should ignore a=1 in identifier",
674
+ )
675
+ assert_equal(
676
+ ConfigWithIgnoreDefault.C(a=1, b=2),
677
+ ConfigWithoutA.C(b=2),
678
+ "Config with ignore_default should match config without that param",
679
+ )
680
+
681
+ # When a=2 (doesn't match ignore_default), should be included
682
+ assert_notequal(
683
+ ConfigWithIgnoreDefault.C(a=2, b=2),
684
+ ConfigWithIgnoreDefault.C(b=2),
685
+ "field(ignore_default=1) should include a=2 in identifier",
686
+ )
687
+
688
+
689
+ def test_identifier_field_default():
690
+ """Test that field(default=X) includes value in identifier even when value == X"""
691
+
692
+ class ConfigWithDefault(Config):
693
+ __xpmid__ = "test.identifier.field_default"
694
+ a: Param[int] = field(default=1)
695
+ b: Param[int]
696
+
697
+ class ConfigWithoutA(Config):
698
+ __xpmid__ = "test.identifier.field_default"
699
+ b: Param[int]
700
+
701
+ # When a=1 (matches default), should still be included in identifier
702
+ # so Config with a=1 should differ from Config without a
703
+ assert_notequal(
704
+ ConfigWithDefault.C(a=1, b=2),
705
+ ConfigWithoutA.C(b=2),
706
+ "field(default=1) should include a=1 in identifier",
707
+ )
708
+
709
+ # But two configs with same a=1 should be equal
710
+ assert_equal(
711
+ ConfigWithDefault.C(a=1, b=2),
712
+ ConfigWithDefault.C(a=1, b=2),
713
+ "Same values should have same identifier",
714
+ )
715
+
716
+
717
+ def test_identifier_field_default_vs_ignore_default():
718
+ """Test difference between field(default=X) and field(ignore_default=X)"""
719
+
720
+ class ConfigWithDefault(Config):
721
+ __xpmid__ = "test.identifier.field_default_vs_ignore"
722
+ a: Param[int] = field(default=1)
723
+ b: Param[int]
724
+
725
+ class ConfigWithIgnoreDefault(Config):
726
+ __xpmid__ = "test.identifier.field_default_vs_ignore"
727
+ a: Param[int] = field(ignore_default=1)
728
+ b: Param[int]
729
+
730
+ # Both with a=1, b=2 - should differ because one includes a, other doesn't
731
+ assert_notequal(
732
+ ConfigWithDefault.C(a=1, b=2),
733
+ ConfigWithIgnoreDefault.C(a=1, b=2),
734
+ "field(default=1) vs field(ignore_default=1) should differ when a=1",
735
+ )
736
+
737
+ # Both with a=2 (not matching default), should be the same
738
+ assert_equal(
739
+ ConfigWithDefault.C(a=2, b=2),
740
+ ConfigWithIgnoreDefault.C(a=2, b=2),
741
+ "field(default=1) vs field(ignore_default=1) should be same when a!=1",
742
+ )
743
+
744
+
745
+ # --- Test partial identifiers (subparameters) ---
746
+
747
+
748
+ # Define parameter groups at module level
749
+ iter_group = param_group("iter")
750
+ model_group = param_group("model")
751
+
752
+
753
+ def get_partial_identifier(config, sp):
754
+ """Helper to get partial identifier for a config and subparameters"""
755
+ return config.__xpm__.get_partial_identifier(sp).all
756
+
757
+
758
+ def test_partial_identifier_excludes_grouped_params():
759
+ """Test that partial identifier excludes parameters in excluded groups"""
760
+
761
+ class ConfigWithGroups(Config):
762
+ checkpoints = subparameters(exclude_groups=[iter_group])
763
+ max_iter: Param[int] = field(groups=[iter_group])
764
+ learning_rate: Param[float]
765
+
766
+ c1 = ConfigWithGroups.C(max_iter=100, learning_rate=0.1)
767
+ c2 = ConfigWithGroups.C(max_iter=200, learning_rate=0.1)
768
+
769
+ # Full identifiers should differ (max_iter is different)
770
+ assert_notequal(c1, c2, "Full identifiers should differ when max_iter differs")
771
+
772
+ # Partial identifiers should be the same (max_iter is excluded)
773
+ pid1 = get_partial_identifier(c1, ConfigWithGroups.checkpoints)
774
+ pid2 = get_partial_identifier(c2, ConfigWithGroups.checkpoints)
775
+ assert (
776
+ pid1 == pid2
777
+ ), "Partial identifiers should match when only excluded params differ"
778
+
779
+
780
+ def test_partial_identifier_includes_ungrouped_params():
781
+ """Test that partial identifier includes parameters not in excluded groups"""
782
+
783
+ class ConfigWithGroups(Config):
784
+ checkpoints = subparameters(exclude_groups=[iter_group])
785
+ max_iter: Param[int] = field(groups=[iter_group])
786
+ learning_rate: Param[float]
787
+
788
+ c1 = ConfigWithGroups.C(max_iter=100, learning_rate=0.1)
789
+ c2 = ConfigWithGroups.C(max_iter=100, learning_rate=0.2)
790
+
791
+ # Partial identifiers should differ (learning_rate is not excluded)
792
+ pid1 = get_partial_identifier(c1, ConfigWithGroups.checkpoints)
793
+ pid2 = get_partial_identifier(c2, ConfigWithGroups.checkpoints)
794
+ assert (
795
+ pid1 != pid2
796
+ ), "Partial identifiers should differ when non-excluded params differ"
797
+
798
+
799
+ def test_partial_identifier_matches_config_without_excluded():
800
+ """Test that partial identifier matches config without the excluded fields"""
801
+
802
+ class ConfigWithIter(Config):
803
+ __xpmid__ = "test.partial_identifier.config"
804
+ checkpoints = subparameters(exclude_groups=[iter_group])
805
+ max_iter: Param[int] = field(groups=[iter_group])
806
+ learning_rate: Param[float]
807
+
808
+ class ConfigWithoutIter(Config):
809
+ __xpmid__ = "test.partial_identifier.config"
810
+ learning_rate: Param[float]
811
+
812
+ c_with = ConfigWithIter.C(max_iter=100, learning_rate=0.1)
813
+ c_without = ConfigWithoutIter.C(learning_rate=0.1)
814
+
815
+ # The partial identifier of c_with should match full identifier of c_without
816
+ pid = get_partial_identifier(c_with, ConfigWithIter.checkpoints)
817
+ full_id = getidentifier(c_without)
818
+ assert (
819
+ pid == full_id
820
+ ), "Partial identifier should match config without excluded fields"
821
+
822
+
823
+ def test_partial_identifier_multiple_groups():
824
+ """Test partial identifier with parameter in multiple groups"""
825
+
826
+ class ConfigMultiGroup(Config):
827
+ checkpoints = subparameters(exclude_groups=[iter_group])
828
+ # This parameter is in both groups - should be excluded if any group is excluded
829
+ x: Param[int] = field(groups=[iter_group, model_group])
830
+ y: Param[float]
831
+
832
+ c1 = ConfigMultiGroup.C(x=1, y=0.1)
833
+ c2 = ConfigMultiGroup.C(x=2, y=0.1)
834
+
835
+ # Partial identifiers should be the same (x is in iter_group which is excluded)
836
+ pid1 = get_partial_identifier(c1, ConfigMultiGroup.checkpoints)
837
+ pid2 = get_partial_identifier(c2, ConfigMultiGroup.checkpoints)
838
+ assert (
839
+ pid1 == pid2
840
+ ), "Partial identifiers should match when param is in any excluded group"
841
+
842
+
843
+ def test_partial_identifier_include_overrides_exclude():
844
+ """Test that include_groups overrides exclude_groups"""
845
+
846
+ class ConfigIncludeOverride(Config):
847
+ # iter_group is excluded but also included, so it should NOT be excluded
848
+ partial = subparameters(
849
+ exclude_groups=[iter_group, model_group], include_groups=[iter_group]
850
+ )
851
+ x: Param[int] = field(groups=[iter_group])
852
+ y: Param[int] = field(groups=[model_group])
853
+ z: Param[float]
854
+
855
+ c1 = ConfigIncludeOverride.C(x=1, y=1, z=0.1)
856
+ c2 = ConfigIncludeOverride.C(x=2, y=1, z=0.1)
857
+ c3 = ConfigIncludeOverride.C(x=1, y=2, z=0.1)
858
+
859
+ # x is in iter_group which is included (overrides exclusion)
860
+ # so different x should give different partial identifiers
861
+ pid1 = get_partial_identifier(c1, ConfigIncludeOverride.partial)
862
+ pid2 = get_partial_identifier(c2, ConfigIncludeOverride.partial)
863
+ assert pid1 != pid2, "Include should override exclude - x should be included"
864
+
865
+ # y is in model_group which is excluded (not included)
866
+ # so different y should give SAME partial identifiers
867
+ pid3 = get_partial_identifier(c3, ConfigIncludeOverride.partial)
868
+ assert pid1 == pid3, "y is excluded - different y should give same partial ID"
869
+
870
+
871
+ def test_partial_identifier_exclude_all():
872
+ """Test exclude_all option"""
873
+
874
+ class ConfigExcludeAll(Config):
875
+ # Exclude all, but include model_group
876
+ partial = subparameters(exclude_all=True, include_groups=[model_group])
877
+ x: Param[int] = field(groups=[iter_group])
878
+ y: Param[int] = field(groups=[model_group])
879
+ z: Param[float] # No group
880
+
881
+ c1 = ConfigExcludeAll.C(x=1, y=1, z=0.1)
882
+ c2 = ConfigExcludeAll.C(x=2, y=1, z=0.1) # Different x (excluded)
883
+ c3 = ConfigExcludeAll.C(x=1, y=2, z=0.1) # Different y (included)
884
+ c4 = ConfigExcludeAll.C(x=1, y=1, z=0.2) # Different z (excluded - no group)
885
+
886
+ pid1 = get_partial_identifier(c1, ConfigExcludeAll.partial)
887
+ pid2 = get_partial_identifier(c2, ConfigExcludeAll.partial)
888
+ pid3 = get_partial_identifier(c3, ConfigExcludeAll.partial)
889
+ pid4 = get_partial_identifier(c4, ConfigExcludeAll.partial)
890
+
891
+ # x is excluded (in iter_group, not included) - same partial ID
892
+ assert pid1 == pid2, "x is excluded - should have same partial ID"
893
+
894
+ # y is included (in model_group) - different partial ID
895
+ assert pid1 != pid3, "y is included - should have different partial ID"
896
+
897
+ # z is excluded (no group, exclude_all=True) - same partial ID
898
+ assert (
899
+ pid1 == pid4
900
+ ), "z (no group) is excluded by exclude_all - should have same partial ID"
901
+
902
+
903
+ def test_partial_identifier_exclude_no_group():
904
+ """Test exclude_no_group option"""
905
+
906
+ class ConfigExcludeNoGroup(Config):
907
+ partial = subparameters(exclude_no_group=True)
908
+ x: Param[int] = field(groups=[iter_group])
909
+ y: Param[float] # No group
910
+
911
+ c1 = ConfigExcludeNoGroup.C(x=1, y=0.1)
912
+ c2 = ConfigExcludeNoGroup.C(x=2, y=0.1) # Different x (has group - not excluded)
913
+ c3 = ConfigExcludeNoGroup.C(x=1, y=0.2) # Different y (no group - excluded)
914
+
915
+ pid1 = get_partial_identifier(c1, ConfigExcludeNoGroup.partial)
916
+ pid2 = get_partial_identifier(c2, ConfigExcludeNoGroup.partial)
917
+ pid3 = get_partial_identifier(c3, ConfigExcludeNoGroup.partial)
918
+
919
+ # x has a group, so it's NOT excluded by exclude_no_group
920
+ assert pid1 != pid2, "x has group - should have different partial ID"
921
+
922
+ # y has no group, so it IS excluded by exclude_no_group
923
+ assert pid1 == pid3, "y has no group - should have same partial ID"