experimaestro-2.0.0a8-py3-none-any.whl → experimaestro-2.0.0b8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic; see the package's registry page for more details.

Files changed (122):
  1. experimaestro/__init__.py +10 -11
  2. experimaestro/annotations.py +167 -206
  3. experimaestro/cli/__init__.py +278 -7
  4. experimaestro/cli/filter.py +42 -74
  5. experimaestro/cli/jobs.py +157 -106
  6. experimaestro/cli/refactor.py +249 -0
  7. experimaestro/click.py +0 -1
  8. experimaestro/commandline.py +19 -3
  9. experimaestro/connectors/__init__.py +20 -1
  10. experimaestro/connectors/local.py +12 -0
  11. experimaestro/core/arguments.py +182 -46
  12. experimaestro/core/identifier.py +107 -6
  13. experimaestro/core/objects/__init__.py +6 -0
  14. experimaestro/core/objects/config.py +542 -25
  15. experimaestro/core/objects/config_walk.py +20 -0
  16. experimaestro/core/serialization.py +91 -34
  17. experimaestro/core/subparameters.py +164 -0
  18. experimaestro/core/types.py +175 -38
  19. experimaestro/exceptions.py +26 -0
  20. experimaestro/experiments/cli.py +111 -25
  21. experimaestro/generators.py +50 -9
  22. experimaestro/huggingface.py +3 -1
  23. experimaestro/launcherfinder/parser.py +29 -0
  24. experimaestro/launchers/__init__.py +26 -1
  25. experimaestro/launchers/direct.py +12 -0
  26. experimaestro/launchers/slurm/base.py +154 -2
  27. experimaestro/mkdocs/metaloader.py +0 -1
  28. experimaestro/mypy.py +452 -7
  29. experimaestro/notifications.py +63 -13
  30. experimaestro/progress.py +0 -2
  31. experimaestro/rpyc.py +0 -1
  32. experimaestro/run.py +19 -6
  33. experimaestro/scheduler/base.py +510 -125
  34. experimaestro/scheduler/dependencies.py +43 -28
  35. experimaestro/scheduler/dynamic_outputs.py +259 -130
  36. experimaestro/scheduler/experiment.py +256 -31
  37. experimaestro/scheduler/interfaces.py +501 -0
  38. experimaestro/scheduler/jobs.py +216 -206
  39. experimaestro/scheduler/remote/__init__.py +31 -0
  40. experimaestro/scheduler/remote/client.py +874 -0
  41. experimaestro/scheduler/remote/protocol.py +467 -0
  42. experimaestro/scheduler/remote/server.py +423 -0
  43. experimaestro/scheduler/remote/sync.py +144 -0
  44. experimaestro/scheduler/services.py +323 -23
  45. experimaestro/scheduler/state_db.py +437 -0
  46. experimaestro/scheduler/state_provider.py +2766 -0
  47. experimaestro/scheduler/state_sync.py +891 -0
  48. experimaestro/scheduler/workspace.py +52 -10
  49. experimaestro/scriptbuilder.py +7 -0
  50. experimaestro/server/__init__.py +147 -57
  51. experimaestro/server/data/index.css +0 -125
  52. experimaestro/server/data/index.css.map +1 -1
  53. experimaestro/server/data/index.js +194 -58
  54. experimaestro/server/data/index.js.map +1 -1
  55. experimaestro/settings.py +44 -5
  56. experimaestro/sphinx/__init__.py +3 -3
  57. experimaestro/taskglobals.py +20 -0
  58. experimaestro/tests/conftest.py +80 -0
  59. experimaestro/tests/core/test_generics.py +2 -2
  60. experimaestro/tests/identifier_stability.json +45 -0
  61. experimaestro/tests/launchers/bin/sacct +6 -2
  62. experimaestro/tests/launchers/bin/sbatch +4 -2
  63. experimaestro/tests/launchers/test_slurm.py +80 -0
  64. experimaestro/tests/tasks/test_dynamic.py +231 -0
  65. experimaestro/tests/test_cli_jobs.py +615 -0
  66. experimaestro/tests/test_deprecated.py +630 -0
  67. experimaestro/tests/test_environment.py +200 -0
  68. experimaestro/tests/test_file_progress_integration.py +1 -1
  69. experimaestro/tests/test_forward.py +3 -3
  70. experimaestro/tests/test_identifier.py +372 -41
  71. experimaestro/tests/test_identifier_stability.py +458 -0
  72. experimaestro/tests/test_instance.py +3 -3
  73. experimaestro/tests/test_multitoken.py +442 -0
  74. experimaestro/tests/test_mypy.py +433 -0
  75. experimaestro/tests/test_objects.py +312 -5
  76. experimaestro/tests/test_outputs.py +2 -2
  77. experimaestro/tests/test_param.py +8 -12
  78. experimaestro/tests/test_partial_paths.py +231 -0
  79. experimaestro/tests/test_progress.py +0 -48
  80. experimaestro/tests/test_remote_state.py +671 -0
  81. experimaestro/tests/test_resumable_task.py +480 -0
  82. experimaestro/tests/test_serializers.py +141 -1
  83. experimaestro/tests/test_state_db.py +434 -0
  84. experimaestro/tests/test_subparameters.py +160 -0
  85. experimaestro/tests/test_tags.py +136 -0
  86. experimaestro/tests/test_tasks.py +107 -121
  87. experimaestro/tests/test_token_locking.py +252 -0
  88. experimaestro/tests/test_tokens.py +17 -13
  89. experimaestro/tests/test_types.py +123 -1
  90. experimaestro/tests/test_workspace_triggers.py +158 -0
  91. experimaestro/tests/token_reschedule.py +4 -2
  92. experimaestro/tests/utils.py +2 -2
  93. experimaestro/tokens.py +154 -57
  94. experimaestro/tools/diff.py +1 -1
  95. experimaestro/tui/__init__.py +8 -0
  96. experimaestro/tui/app.py +2395 -0
  97. experimaestro/tui/app.tcss +353 -0
  98. experimaestro/tui/log_viewer.py +228 -0
  99. experimaestro/utils/__init__.py +23 -0
  100. experimaestro/utils/environment.py +148 -0
  101. experimaestro/utils/git.py +129 -0
  102. experimaestro/utils/resources.py +1 -1
  103. experimaestro/version.py +34 -0
  104. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/METADATA +68 -38
  105. experimaestro-2.0.0b8.dist-info/RECORD +187 -0
  106. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/WHEEL +1 -1
  107. experimaestro-2.0.0b8.dist-info/entry_points.txt +16 -0
  108. experimaestro/compat.py +0 -6
  109. experimaestro/core/objects.pyi +0 -221
  110. experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
  111. experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
  112. experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
  113. experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
  114. experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
  115. experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
  116. experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
  117. experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
  118. experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
  119. experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
  120. experimaestro-2.0.0a8.dist-info/RECORD +0 -166
  121. experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
  122. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,7 @@
1
1
  from pathlib import Path
2
2
 
3
3
  import pytest
4
- from experimaestro import Config, Task, Annotated, copyconfig, default
4
+ from experimaestro import Config, Task, Annotated, copyconfig, field
5
5
  from experimaestro.core.arguments import Param
6
6
  from experimaestro.core.objects import ConfigMixin
7
7
  from experimaestro.generators import pathgenerator
@@ -16,17 +16,17 @@ def xp():
16
16
 
17
17
 
18
18
  class A(Config):
19
- x: Param[int] = 3
19
+ x: Param[int] = field(ignore_default=3)
20
20
 
21
21
 
22
22
  def test_object_default():
23
23
  """Test plain default value"""
24
- a = A()
24
+ a = A.C()
25
25
  assert a.x == 3
26
26
 
27
27
 
28
28
  class B(Config):
29
- a: Param[A] = A.C(x=3)
29
+ a: Param[A] = field(ignore_default=A.C(x=3))
30
30
 
31
31
 
32
32
  class C(B):
@@ -38,7 +38,7 @@ class D(B, A):
38
38
 
39
39
 
40
40
  class DefaultAnnotationConfig(Config):
41
- a: Annotated[A, default(A.C(x=3))]
41
+ a: Param[A] = field(default=A.C(x=3))
42
42
 
43
43
 
44
44
  def test_object_config_default():
@@ -84,3 +84,310 @@ def test_copyconfig(xp):
84
84
 
85
85
  assert copy_b.x == b.x
86
86
  assert "path" not in copy_b.__xpm__.values
87
+
88
+
89
+ # --- Composition operator tests (GH #33) ---
90
+
91
+
92
+ class CompositionA(Config):
93
+ x: Param[int]
94
+
95
+
96
+ class CompositionSubA(CompositionA):
97
+ """Subclass of CompositionA"""
98
+
99
+ y: Param[int] = field(ignore_default=0)
100
+
101
+
102
+ class CompositionB(Config):
103
+ a: Param[CompositionA]
104
+
105
+
106
+ class CompositionC(Config):
107
+ """Config with two parameters of same type - should be ambiguous"""
108
+
109
+ a1: Param[CompositionA]
110
+ a2: Param[CompositionA]
111
+
112
+
113
+ class CompositionD(Config):
114
+ """Config with no matching parameter"""
115
+
116
+ x: Param[int]
117
+
118
+
119
+ class CompositionE(Config):
120
+ """Config with two parameters, one subclass of the other"""
121
+
122
+ base: Param[CompositionA]
123
+ sub: Param[CompositionSubA]
124
+
125
+
126
+ def test_composition_operator():
127
+ """Test that B() @ A(x=1) is equivalent to B(a=A(x=1))"""
128
+ a = CompositionA.C(x=42)
129
+ b = CompositionB.C() @ a
130
+
131
+ assert b.a is a
132
+ assert b.a.x == 42
133
+
134
+
135
+ def test_composition_operator_chained():
136
+ """Test chaining composition operators
137
+
138
+ Chaining A @ B @ C adds both B and C to A (same outer config).
139
+ For nested structures, use parentheses: A @ (B @ C)
140
+ """
141
+
142
+ class MultiParam(Config):
143
+ a: Param[CompositionA]
144
+ b: Param[CompositionB]
145
+
146
+ # Chaining adds multiple configs to same outer config
147
+ result = MultiParam.C() @ CompositionA.C(x=10) @ CompositionB.C()
148
+
149
+ assert result.a.x == 10
150
+ assert result.b is not None
151
+
152
+
153
+ def test_composition_operator_nested():
154
+ """Test nested composition with parentheses"""
155
+
156
+ class Outer(Config):
157
+ b: Param[CompositionB]
158
+
159
+ # For nested structures, use parentheses
160
+ result = Outer.C() @ (CompositionB.C() @ CompositionA.C(x=10))
161
+
162
+ assert result.b.a.x == 10
163
+
164
+
165
+ def test_composition_operator_ambiguous():
166
+ """Test that ambiguous composition raises ValueError"""
167
+ a = CompositionA.C(x=1)
168
+
169
+ with pytest.raises(ValueError, match="Ambiguous"):
170
+ CompositionC.C() @ a
171
+
172
+
173
+ def test_composition_operator_no_match():
174
+ """Test that composition with no matching param raises ValueError"""
175
+ a = CompositionA.C(x=1)
176
+
177
+ with pytest.raises(ValueError, match="No parameter"):
178
+ CompositionD.C() @ a
179
+
180
+
181
+ def test_composition_operator_subclass():
182
+ """Test composition works with subclasses"""
183
+ sub_a = CompositionSubA.C(x=5, y=10)
184
+ b = CompositionB.C() @ sub_a
185
+
186
+ assert b.a is sub_a
187
+ assert b.a.x == 5
188
+
189
+
190
+ def test_composition_operator_subclass_hierarchy():
191
+ """Test composition when two params have subclass relationship
192
+
193
+ When CompositionSubA is passed, both 'base' (CompositionA) and 'sub'
194
+ (CompositionSubA) match. This should be ambiguous since both accept it.
195
+ """
196
+ sub_a = CompositionSubA.C(x=1, y=2)
197
+
198
+ # SubA matches both base (CompositionA) and sub (CompositionSubA)
199
+ with pytest.raises(ValueError, match="Ambiguous"):
200
+ CompositionE.C() @ sub_a
201
+
202
+
203
+ def test_composition_operator_exact_match():
204
+ """Test composition when base class instance matches only base param"""
205
+ # CompositionA matches only 'base', not 'sub' (which requires SubA)
206
+ a = CompositionA.C(x=1)
207
+ e = CompositionE.C() @ a
208
+
209
+ assert e.base is a
210
+ assert e.base.x == 1
211
+
212
+
213
+ # --- Value class decorator tests (GH #99) ---
214
+
215
+ # Test 1: Basic value class registration
216
+
217
+
218
+ class ValueBasicModel(Config):
219
+ x: Param[int] = field(ignore_default=1)
220
+
221
+
222
+ @ValueBasicModel.value_class()
223
+ class ValueBasicModelImpl(ValueBasicModel):
224
+ def compute(self):
225
+ return self.x * 2
226
+
227
+
228
+ # Test 2: Subclass without explicit value class
229
+
230
+
231
+ class ValueInheritBase(Config):
232
+ x: Param[int] = field(ignore_default=1)
233
+
234
+
235
+ @ValueInheritBase.value_class()
236
+ class ValueInheritBaseImpl(ValueInheritBase):
237
+ pass
238
+
239
+
240
+ class ValueInheritSubNoExplicit(ValueInheritBase):
241
+ """Subclass without explicit value class"""
242
+
243
+ y: Param[int] = field(ignore_default=2)
244
+
245
+
246
+ # Test 3: Value class with proper inheritance
247
+
248
+
249
+ class ValueInheritParent(Config):
250
+ x: Param[int] = field(ignore_default=1)
251
+
252
+
253
+ @ValueInheritParent.value_class()
254
+ class ValueInheritParentImpl(ValueInheritParent):
255
+ def compute(self):
256
+ return self.x * 2
257
+
258
+
259
+ class ValueInheritChild(ValueInheritParent):
260
+ y: Param[int] = field(ignore_default=2)
261
+
262
+
263
+ @ValueInheritChild.value_class()
264
+ class ValueInheritChildImpl(ValueInheritChild, ValueInheritParentImpl):
265
+ def compute_both(self):
266
+ return self.x + self.y
267
+
268
+
269
+ # Test 4: Skip intermediate class (A -> B -> C, only A and C have value classes)
270
+
271
+
272
+ class ValueSkipBase(Config):
273
+ x: Param[int] = field(ignore_default=1)
274
+
275
+
276
+ @ValueSkipBase.value_class()
277
+ class ValueSkipBaseImpl(ValueSkipBase):
278
+ def compute(self):
279
+ return self.x * 2
280
+
281
+
282
+ class ValueSkipIntermediate(ValueSkipBase):
283
+ """Intermediate class without explicit value class"""
284
+
285
+ y: Param[int] = field(ignore_default=2)
286
+
287
+
288
+ class ValueSkipDeep(ValueSkipIntermediate):
289
+ """Deep subclass with value class"""
290
+
291
+ z: Param[int] = field(ignore_default=3)
292
+
293
+
294
+ @ValueSkipDeep.value_class()
295
+ class ValueSkipDeepImpl(ValueSkipDeep, ValueSkipBaseImpl):
296
+ def compute_all(self):
297
+ return self.x + self.y + self.z
298
+
299
+
300
+ # --- Value class tests ---
301
+
302
+
303
+ def test_value_decorator_basic():
304
+ """Test basic value class registration"""
305
+ # XPMValue should return the registered value class
306
+ assert ValueBasicModel.XPMValue is ValueBasicModelImpl
307
+
308
+ # Creating an instance should use the value class
309
+ config = ValueBasicModel.C(x=5)
310
+ instance = config.instance()
311
+
312
+ assert isinstance(instance, ValueBasicModelImpl)
313
+ assert instance.x == 5
314
+ assert instance.compute() == 10
315
+
316
+
317
+ def test_value_decorator_inheritance_no_explicit():
318
+ """Test that subclass without value class uses config class as value"""
319
+ # SubModel has no explicit value class, XPMValue returns the config class
320
+ assert ValueInheritSubNoExplicit.XPMValue is ValueInheritSubNoExplicit
321
+
322
+ config = ValueInheritSubNoExplicit.C(x=3, y=4)
323
+ instance = config.instance()
324
+
325
+ # Instance is created from the config class (no explicit value type)
326
+ assert isinstance(instance, ValueInheritSubNoExplicit)
327
+ assert instance.x == 3
328
+ assert instance.y == 4
329
+
330
+
331
+ def test_value_decorator_inheritance_with_explicit():
332
+ """Test value class with proper inheritance from parent value class"""
333
+ assert ValueInheritChild.XPMValue is ValueInheritChildImpl
334
+
335
+ config = ValueInheritChild.C(x=3, y=4)
336
+ instance = config.instance()
337
+
338
+ assert isinstance(instance, ValueInheritChildImpl)
339
+ assert isinstance(instance, ValueInheritParentImpl)
340
+ assert instance.x == 3
341
+ assert instance.y == 4
342
+ assert instance.compute() == 6 # From parent value class
343
+ assert instance.compute_both() == 7 # From this value class
344
+
345
+
346
+ def test_value_decorator_must_be_subclass():
347
+ """Test that value class must be subclass of config"""
348
+
349
+ class LocalModel(Config):
350
+ x: Param[int]
351
+
352
+ class OtherConfig(Config):
353
+ z: Param[int]
354
+
355
+ with pytest.raises(TypeError, match="must be a subclass of"):
356
+
357
+ @LocalModel.value_class()
358
+ class InvalidValue(OtherConfig): # Not a subclass of LocalModel
359
+ pass
360
+
361
+
362
+ def test_value_decorator_must_inherit_parent_value():
363
+ """Test that value class must inherit from parent value class"""
364
+
365
+ class LocalBase(Config):
366
+ x: Param[int] = field(ignore_default=1)
367
+
368
+ @LocalBase.value_class()
369
+ class LocalBaseImpl(LocalBase):
370
+ pass
371
+
372
+ class LocalChild(LocalBase):
373
+ y: Param[int] = field(ignore_default=2)
374
+
375
+ with pytest.raises(TypeError, match="must be a subclass of.*parent value class"):
376
+
377
+ @LocalChild.value_class()
378
+ class InvalidChildValue(LocalChild): # Missing LocalBaseImpl inheritance
379
+ pass
380
+
381
+
382
+ def test_value_decorator_skip_intermediate():
383
+ """Test value class when intermediate class has no value class"""
384
+ # ValueSkipBase has impl, ValueSkipIntermediate has none, ValueSkipDeep has impl
385
+ assert ValueSkipDeep.XPMValue is ValueSkipDeepImpl
386
+
387
+ config = ValueSkipDeep.C(x=1, y=2, z=3)
388
+ instance = config.instance()
389
+
390
+ assert isinstance(instance, ValueSkipDeepImpl)
391
+ assert isinstance(instance, ValueSkipBaseImpl)
392
+ assert instance.compute() == 2 # From ValueSkipBaseImpl
393
+ assert instance.compute_all() == 6 # From ValueSkipDeepImpl
@@ -1,11 +1,11 @@
1
1
  """Test for task outputs"""
2
2
 
3
- from experimaestro import Config, Task, Param
3
+ from experimaestro import field, Config, Task, Param
4
4
  from experimaestro.scheduler.workspace import RunMode
5
5
 
6
6
 
7
7
  class B(Config):
8
- x: Param[int] = 1
8
+ x: Param[int] = field(ignore_default=1)
9
9
 
10
10
 
11
11
  class A(Config):
@@ -5,7 +5,6 @@ Test annotation handling for configurations and tasks
5
5
 
6
6
  # Annotation specific tests
7
7
 
8
- import sys
9
8
  from pathlib import Path
10
9
  from typing import Dict, Optional, List
11
10
  from experimaestro.core.context import SerializationContext
@@ -17,7 +16,6 @@ from experimaestro import (
17
16
  Constant,
18
17
  Param,
19
18
  Task,
20
- default,
21
19
  Meta,
22
20
  Config,
23
21
  pathgenerator,
@@ -80,8 +78,8 @@ def test_type_hinting():
80
78
  __xpmid__ = "annotations.class_variable.config"
81
79
 
82
80
  x: Param[int]
83
- y: Param[float] = 2.3
84
- y2: Annotated[float, default(2.3)]
81
+ y: Param[float] = field(ignore_default=2.3)
82
+ y2: Param[float] = field(ignore_default=2.3)
85
83
  z: Param[Optional[float]]
86
84
  t: Param[List[float]]
87
85
  w: Param[int]
@@ -180,7 +178,7 @@ def test_config_class():
180
178
 
181
179
  def test_constant():
182
180
  class A(Config):
183
- x: Constant[int] = 2
181
+ x: Constant[int] = field(ignore_default=2)
184
182
 
185
183
  a = A.C()
186
184
  assert a.x == 2, "Constant value not set"
@@ -214,7 +212,7 @@ def test_inheritance():
214
212
  x: Param[int]
215
213
 
216
214
  class B(A):
217
- y: Param[int] = 3
215
+ y: Param[int] = field(ignore_default=3)
218
216
 
219
217
  b = B.C()
220
218
  b.x = 2
@@ -227,7 +225,7 @@ def test_redefined_param():
227
225
  x: Param[int]
228
226
 
229
227
  class B(Config):
230
- x: Param[int] = 3
228
+ x: Param[int] = field(ignore_default=3)
231
229
 
232
230
  atx = A.C.__getxpmtype__().getArgument("x")
233
231
  btx = B.C.__getxpmtype__().getArgument("x")
@@ -284,7 +282,7 @@ def test_default_mismatch():
284
282
  """Test mismatch between default and type"""
285
283
 
286
284
  class A(Config):
287
- x: Param[int] = 0.2
285
+ x: Param[int] = field(ignore_default=0.2)
288
286
 
289
287
  with pytest.raises(TypeError):
290
288
  A.__getxpmtype__().getArgument("x")
@@ -297,7 +295,7 @@ def test_param_default_set():
297
295
  """Test that the default setting is well set"""
298
296
 
299
297
  class A0(Config):
300
- x: Param[int] = 2
298
+ x: Param[int] = field(ignore_default=2)
301
299
 
302
300
  assert A0.C().instance().x == 2
303
301
  assert A0.C(x=3).instance().x == 3
@@ -336,6 +334,4 @@ def test_help():
336
334
  assert xpmtype.description.strip() == "Long description of A."
337
335
  assert xpmtype.arguments["y"].help == "Parameter y"
338
336
 
339
- # Only python >= 3.9
340
- if sys.version_info.major == 3 and sys.version_info.minor > 8:
341
- assert xpmtype.arguments["x"].help == "Parameter x"
337
+ assert xpmtype.arguments["x"].help == "Parameter x"
@@ -0,0 +1,231 @@
1
+ """Integration tests for partial paths and cleanup"""
2
+
3
+ from pathlib import Path
4
+ from experimaestro import (
5
+ Task,
6
+ Param,
7
+ Meta,
8
+ field,
9
+ PathGenerator,
10
+ subparameters,
11
+ param_group,
12
+ )
13
+ from experimaestro.scheduler import JobState
14
+
15
+ from .utils import TemporaryExperiment, TemporaryDirectory
16
+
17
+
18
+ # Define parameter groups
19
+ iter_group = param_group("iter")
20
+
21
+
22
+ class TaskWithPartial(Task):
23
+ """Task that uses subparameters for partial paths"""
24
+
25
+ # Define a subparameters set
26
+ checkpoints = subparameters(exclude_groups=[iter_group])
27
+
28
+ # Parameter in iter_group - excluded from partial identifier
29
+ max_iter: Param[int] = field(groups=[iter_group])
30
+
31
+ # Parameter not in any group - included in partial identifier
32
+ learning_rate: Param[float]
33
+
34
+ # Path generated using the partial identifier
35
+ checkpoint_path: Meta[Path] = field(
36
+ default_factory=PathGenerator("checkpoint", partial=checkpoints)
37
+ )
38
+
39
+ def execute(self):
40
+ # Create the checkpoint directory and a marker file
41
+ self.checkpoint_path.mkdir(parents=True, exist_ok=True)
42
+ (self.checkpoint_path / "model.pt").write_text("checkpoint data")
43
+
44
+
45
+ def test_partial_path_created():
46
+ """Test that partial paths are correctly created during task execution"""
47
+ with TemporaryDirectory(prefix="xpm", suffix="partial") as workdir:
48
+ with TemporaryExperiment("partial_test", workdir=workdir, maxwait=30):
49
+ task = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
50
+
51
+ assert task.__xpm__.job.state == JobState.DONE
52
+
53
+ # Verify the partial path was created
54
+ assert task.checkpoint_path.exists()
55
+ assert (task.checkpoint_path / "model.pt").exists()
56
+
57
+ # Verify the path is in the partials directory
58
+ partials_path = workdir / "partials"
59
+ assert partials_path.exists()
60
+
61
+ # The checkpoint_path should be under partials/TASK_ID/checkpoints/PARTIAL_ID/
62
+ # Use resolve() to handle symlinks like /var -> /private/var on macOS
63
+ assert task.checkpoint_path.resolve().is_relative_to(partials_path.resolve())
64
+
65
+
66
+ def test_partial_path_shared_across_tasks():
67
+ """Test that tasks with same non-excluded params share partial paths"""
68
+ with TemporaryDirectory(prefix="xpm", suffix="partial_shared") as workdir:
69
+ with TemporaryExperiment("partial_shared", workdir=workdir, maxwait=30):
70
+ # Submit two tasks with different max_iter but same learning_rate
71
+ task1 = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
72
+ task2 = TaskWithPartial.C(max_iter=200, learning_rate=0.1).submit()
73
+
74
+ assert task1.__xpm__.job.state == JobState.DONE
75
+ assert task2.__xpm__.job.state == JobState.DONE
76
+
77
+ # They should share the same partial path
78
+ assert task1.checkpoint_path == task2.checkpoint_path
79
+
80
+
81
+ def test_partial_path_different_for_different_params():
82
+ """Test that tasks with different non-excluded params have different partial paths"""
83
+ with TemporaryDirectory(prefix="xpm", suffix="partial_diff") as workdir:
84
+ with TemporaryExperiment("partial_diff", workdir=workdir, maxwait=30):
85
+ # Submit two tasks with different learning_rate
86
+ task1 = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
87
+ task2 = TaskWithPartial.C(max_iter=100, learning_rate=0.2).submit()
88
+
89
+ assert task1.__xpm__.job.state == JobState.DONE
90
+ assert task2.__xpm__.job.state == JobState.DONE
91
+
92
+ # They should have different partial paths
93
+ assert task1.checkpoint_path != task2.checkpoint_path
94
+
95
+
96
+ def test_partial_registered_in_database():
97
+ """Test that partials are registered in the database when jobs are submitted"""
98
+ from experimaestro.scheduler.state_provider import WorkspaceStateProvider
99
+ from experimaestro.scheduler.state_db import PartialModel, JobPartialModel
100
+
101
+ with TemporaryDirectory(prefix="xpm", suffix="partial_db") as workdir:
102
+ with TemporaryExperiment("partial_db", workdir=workdir, maxwait=30) as xp:
103
+ task = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
104
+
105
+ assert task.__xpm__.job.state == JobState.DONE
106
+
107
+ # Get the state provider and check database
108
+ # Note: Must use read_only=False since the experiment left a singleton
109
+ # with read_only=False that hasn't been closed yet
110
+ provider = WorkspaceStateProvider.get_instance(workdir, read_only=False)
111
+
112
+ try:
113
+ with provider.workspace_db.bind_ctx([PartialModel, JobPartialModel]):
114
+ # Check that partial is registered
115
+ partials = list(PartialModel.select())
116
+ assert len(partials) == 1
117
+ assert partials[0].subparameters_name == "checkpoints"
118
+
119
+ # Check that job is linked to partial
120
+ job_partials = list(JobPartialModel.select())
121
+ assert len(job_partials) == 1
122
+ assert job_partials[0].partial_id == partials[0].partial_id
123
+ assert job_partials[0].experiment_id == xp.workdir.name
124
+ finally:
125
+ provider.close()
126
+
127
+
128
+ def test_orphan_partial_cleanup():
129
+ """Test that orphan partials are cleaned up when jobs are deleted"""
130
+ from experimaestro.scheduler.state_provider import WorkspaceStateProvider
131
+ from experimaestro.scheduler.state_db import PartialModel, JobPartialModel
132
+
133
+ with TemporaryDirectory(prefix="xpm", suffix="partial_cleanup") as workdir:
134
+ with TemporaryExperiment("partial_cleanup", workdir=workdir, maxwait=30) as xp:
135
+ task = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
136
+
137
+ assert task.__xpm__.job.state == JobState.DONE
138
+ checkpoint_path = task.checkpoint_path
139
+
140
+ # Verify partial path exists
141
+ assert checkpoint_path.exists()
142
+
143
+ # Get the state provider
144
+ provider = WorkspaceStateProvider.get_instance(workdir, read_only=False)
145
+
146
+ try:
147
+ # Delete the job
148
+ with provider.workspace_db.bind_ctx([PartialModel, JobPartialModel]):
149
+ job_partials = list(JobPartialModel.select())
150
+ assert len(job_partials) == 1
151
+
152
+ # Delete job (this also removes job-partial link)
153
+ provider.delete_job(
154
+ task.__xpm__.job.identifier,
155
+ xp.workdir.name,
156
+ xp.run_id,
157
+ )
158
+
159
+ # Now the partial should be orphaned
160
+ orphans = provider.get_orphan_partials()
161
+ assert len(orphans) == 1
162
+
163
+ # Cleanup orphan partials
164
+ deleted = provider.cleanup_orphan_partials(perform=True)
165
+ assert len(deleted) == 1
166
+
167
+ # Verify partial directory is deleted
168
+ assert not checkpoint_path.exists()
169
+
170
+ # Verify partial is removed from database
171
+ with provider.workspace_db.bind_ctx([PartialModel]):
172
+ partials = list(PartialModel.select())
173
+ assert len(partials) == 0
174
+ finally:
175
+ provider.close()
176
+
177
+
178
+ def test_shared_partial_not_orphaned():
179
+ """Test that partials shared by multiple jobs are not orphaned until all jobs deleted"""
180
+ from experimaestro.scheduler.state_provider import WorkspaceStateProvider
181
+
182
+ with TemporaryDirectory(prefix="xpm", suffix="partial_shared_cleanup") as workdir:
183
+ with TemporaryExperiment(
184
+ "partial_shared_cleanup", workdir=workdir, maxwait=30
185
+ ) as xp:
186
+ # Submit two tasks with same learning_rate (same partial)
187
+ task1 = TaskWithPartial.C(max_iter=100, learning_rate=0.1).submit()
188
+ task2 = TaskWithPartial.C(max_iter=200, learning_rate=0.1).submit()
189
+
190
+ assert task1.__xpm__.job.state == JobState.DONE
191
+ assert task2.__xpm__.job.state == JobState.DONE
192
+
193
+ # They share the same partial path
194
+ checkpoint_path = task1.checkpoint_path
195
+ assert checkpoint_path == task2.checkpoint_path
196
+ assert checkpoint_path.exists()
197
+
198
+ provider = WorkspaceStateProvider.get_instance(workdir, read_only=False)
199
+
200
+ try:
201
+ # Delete first job
202
+ provider.delete_job(
203
+ task1.__xpm__.job.identifier,
204
+ xp.workdir.name,
205
+ xp.run_id,
206
+ )
207
+
208
+ # Partial should NOT be orphaned (still used by task2)
209
+ orphans = provider.get_orphan_partials()
210
+ assert len(orphans) == 0
211
+
212
+ # Partial directory should still exist
213
+ assert checkpoint_path.exists()
214
+
215
+ # Delete second job
216
+ provider.delete_job(
217
+ task2.__xpm__.job.identifier,
218
+ xp.workdir.name,
219
+ xp.run_id,
220
+ )
221
+
222
+ # Now partial should be orphaned
223
+ orphans = provider.get_orphan_partials()
224
+ assert len(orphans) == 1
225
+
226
+ # Cleanup
227
+ deleted = provider.cleanup_orphan_partials(perform=True)
228
+ assert len(deleted) == 1
229
+ assert not checkpoint_path.exists()
230
+ finally:
231
+ provider.close()