experimaestro 2.0.0a8__py3-none-any.whl → 2.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic; see the registry's advisory page for details.

Files changed (116)
  1. experimaestro/__init__.py +10 -11
  2. experimaestro/annotations.py +167 -206
  3. experimaestro/cli/__init__.py +130 -5
  4. experimaestro/cli/filter.py +42 -74
  5. experimaestro/cli/jobs.py +157 -106
  6. experimaestro/cli/refactor.py +249 -0
  7. experimaestro/click.py +0 -1
  8. experimaestro/commandline.py +19 -3
  9. experimaestro/connectors/__init__.py +20 -1
  10. experimaestro/connectors/local.py +12 -0
  11. experimaestro/core/arguments.py +182 -46
  12. experimaestro/core/identifier.py +107 -6
  13. experimaestro/core/objects/__init__.py +6 -0
  14. experimaestro/core/objects/config.py +542 -25
  15. experimaestro/core/objects/config_walk.py +20 -0
  16. experimaestro/core/serialization.py +91 -34
  17. experimaestro/core/subparameters.py +164 -0
  18. experimaestro/core/types.py +175 -38
  19. experimaestro/exceptions.py +26 -0
  20. experimaestro/experiments/cli.py +107 -25
  21. experimaestro/generators.py +50 -9
  22. experimaestro/huggingface.py +3 -1
  23. experimaestro/launcherfinder/parser.py +29 -0
  24. experimaestro/launchers/__init__.py +26 -1
  25. experimaestro/launchers/direct.py +12 -0
  26. experimaestro/launchers/slurm/base.py +154 -2
  27. experimaestro/mkdocs/metaloader.py +0 -1
  28. experimaestro/mypy.py +452 -7
  29. experimaestro/notifications.py +63 -13
  30. experimaestro/progress.py +0 -2
  31. experimaestro/rpyc.py +0 -1
  32. experimaestro/run.py +19 -6
  33. experimaestro/scheduler/base.py +489 -125
  34. experimaestro/scheduler/dependencies.py +43 -28
  35. experimaestro/scheduler/dynamic_outputs.py +259 -130
  36. experimaestro/scheduler/experiment.py +225 -30
  37. experimaestro/scheduler/interfaces.py +474 -0
  38. experimaestro/scheduler/jobs.py +216 -206
  39. experimaestro/scheduler/services.py +186 -12
  40. experimaestro/scheduler/state_db.py +388 -0
  41. experimaestro/scheduler/state_provider.py +2345 -0
  42. experimaestro/scheduler/state_sync.py +834 -0
  43. experimaestro/scheduler/workspace.py +52 -10
  44. experimaestro/scriptbuilder.py +7 -0
  45. experimaestro/server/__init__.py +147 -57
  46. experimaestro/server/data/index.css +0 -125
  47. experimaestro/server/data/index.css.map +1 -1
  48. experimaestro/server/data/index.js +194 -58
  49. experimaestro/server/data/index.js.map +1 -1
  50. experimaestro/settings.py +44 -5
  51. experimaestro/sphinx/__init__.py +3 -3
  52. experimaestro/taskglobals.py +20 -0
  53. experimaestro/tests/conftest.py +80 -0
  54. experimaestro/tests/core/test_generics.py +2 -2
  55. experimaestro/tests/identifier_stability.json +45 -0
  56. experimaestro/tests/launchers/bin/sacct +6 -2
  57. experimaestro/tests/launchers/bin/sbatch +4 -2
  58. experimaestro/tests/launchers/test_slurm.py +80 -0
  59. experimaestro/tests/tasks/test_dynamic.py +231 -0
  60. experimaestro/tests/test_cli_jobs.py +615 -0
  61. experimaestro/tests/test_deprecated.py +630 -0
  62. experimaestro/tests/test_environment.py +200 -0
  63. experimaestro/tests/test_file_progress_integration.py +1 -1
  64. experimaestro/tests/test_forward.py +3 -3
  65. experimaestro/tests/test_identifier.py +372 -41
  66. experimaestro/tests/test_identifier_stability.py +458 -0
  67. experimaestro/tests/test_instance.py +3 -3
  68. experimaestro/tests/test_multitoken.py +442 -0
  69. experimaestro/tests/test_mypy.py +433 -0
  70. experimaestro/tests/test_objects.py +312 -5
  71. experimaestro/tests/test_outputs.py +2 -2
  72. experimaestro/tests/test_param.py +8 -12
  73. experimaestro/tests/test_partial_paths.py +231 -0
  74. experimaestro/tests/test_progress.py +0 -48
  75. experimaestro/tests/test_resumable_task.py +480 -0
  76. experimaestro/tests/test_serializers.py +141 -1
  77. experimaestro/tests/test_state_db.py +434 -0
  78. experimaestro/tests/test_subparameters.py +160 -0
  79. experimaestro/tests/test_tags.py +136 -0
  80. experimaestro/tests/test_tasks.py +107 -121
  81. experimaestro/tests/test_token_locking.py +252 -0
  82. experimaestro/tests/test_tokens.py +17 -13
  83. experimaestro/tests/test_types.py +123 -1
  84. experimaestro/tests/test_workspace_triggers.py +158 -0
  85. experimaestro/tests/token_reschedule.py +4 -2
  86. experimaestro/tests/utils.py +2 -2
  87. experimaestro/tokens.py +154 -57
  88. experimaestro/tools/diff.py +1 -1
  89. experimaestro/tui/__init__.py +8 -0
  90. experimaestro/tui/app.py +2303 -0
  91. experimaestro/tui/app.tcss +353 -0
  92. experimaestro/tui/log_viewer.py +228 -0
  93. experimaestro/utils/__init__.py +23 -0
  94. experimaestro/utils/environment.py +148 -0
  95. experimaestro/utils/git.py +129 -0
  96. experimaestro/utils/resources.py +1 -1
  97. experimaestro/version.py +34 -0
  98. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/METADATA +68 -38
  99. experimaestro-2.0.0b4.dist-info/RECORD +181 -0
  100. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/WHEEL +1 -1
  101. experimaestro-2.0.0b4.dist-info/entry_points.txt +16 -0
  102. experimaestro/compat.py +0 -6
  103. experimaestro/core/objects.pyi +0 -221
  104. experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
  105. experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
  106. experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
  107. experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
  108. experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
  109. experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
  110. experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
  111. experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
  112. experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
  113. experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
  114. experimaestro-2.0.0a8.dist-info/RECORD +0 -166
  115. experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
  116. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/licenses/LICENSE +0 -0
@@ -90,54 +90,6 @@ def test_progress_basic():
90
90
  assert info.progress == v
91
91
 
92
92
 
93
- def test_progress_multiple():
94
- """Test that even with two schedulers, we get notified"""
95
- max_wait = 5
96
-
97
- with TemporaryExperiment(
98
- "progress-progress-multiple-1", maxwait=max_wait, port=0
99
- ) as xp1:
100
- assert xp1.server is not None
101
- assert xp1.server.port > 0
102
-
103
- listener1 = ProgressListener()
104
- xp1.scheduler.addlistener(listener1)
105
-
106
- out = ProgressingTask.C().submit()
107
- path = out.path # type: Path
108
- job = out.__xpm__.job
109
-
110
- logger.info("Waiting for job to start (1)")
111
- while job.state.notstarted():
112
- time.sleep(1e-2)
113
-
114
- with TemporaryExperiment(
115
- "progress-progress-multiple-2",
116
- workdir=xp1.workdir,
117
- maxwait=max_wait,
118
- port=0,
119
- ) as xp2:
120
- assert xp2.server is not None
121
- assert xp2.server.port > 0
122
- listener2 = ProgressListener()
123
- xp2.scheduler.addlistener(listener2)
124
-
125
- out = ProgressingTask.C().submit()
126
- job = out.__xpm__.job # type: CommandLineJob
127
- logger.info("Waiting for job to start (2)")
128
- while job.state.notstarted():
129
- time.sleep(1e-2)
130
-
131
- # Both schedulers should receive the job progress information
132
- logger.info("Checking job progress")
133
- progresses = [i / 10.0 for i in range(11)]
134
- for v in progresses:
135
- writeprogress(path, v)
136
- if v < 1:
137
- assert listener1.progresses.get()[0].progress == v
138
- assert listener2.progresses.get()[0].progress == v
139
-
140
-
141
93
  NestedTasks = Tuple[str, Union[int, List["NestedTasks"]]]
142
94
 
143
95
 
@@ -0,0 +1,480 @@
1
+ """Tests for ResumableTask with timeout retry logic"""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from experimaestro import field, ResumableTask, Task, Param, GracefulTimeout
6
+ from experimaestro.scheduler.workspace import RunMode
7
+ from experimaestro.scheduler import JobState, JobFailureStatus
8
+ from experimaestro.scheduler.jobs import JobStateError
9
+ from experimaestro.scheduler.interfaces import JobState as JobStateClass
10
+ from experimaestro.connectors import Process, ProcessState
11
+ from experimaestro.connectors.local import LocalConnector
12
+ from experimaestro.launchers.direct import DirectLauncher
13
+ from experimaestro.commandline import CommandLineJob
14
+ from .utils import TemporaryExperiment
15
+
16
+
17
+ class MockTimeoutCommandLineJob(CommandLineJob):
18
+ """CommandLineJob that simulates timeouts based on attempt count"""
19
+
20
+ def __init__(self, *args, timeout_count=0, checkpoint_file=None, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+ self.timeout_count = timeout_count
23
+ self.test_checkpoint_file = checkpoint_file
24
+
25
+ async def aio_run(self):
26
+ """Override to simulate timeout behavior"""
27
+ # Execute the actual task in a thread
28
+ from experimaestro.utils.asyncio import asyncThreadcheck
29
+
30
+ await asyncThreadcheck("execute", self.config.execute)
31
+
32
+ # Return a mock process that simulates timeout based on attempt count
33
+ return MockTimeoutProcess(self.test_checkpoint_file, self.timeout_count)
34
+
35
+ async def aio_process(self):
36
+ """No existing process"""
37
+ return None
38
+
39
+
40
+ class MockTimeoutProcess(Process):
41
+ """Process that returns TIMEOUT for first N attempts"""
42
+
43
+ def __init__(self, checkpoint_file: Path, timeout_count: int):
44
+ self.checkpoint_file = checkpoint_file
45
+ self.timeout_count = timeout_count
46
+
47
+ def wait(self) -> int:
48
+ # Always return success (code 0) - timeout detection is in get_job_state
49
+ return 0
50
+
51
+ async def aio_state(self, timeout: float | None = None) -> ProcessState:
52
+ return ProcessState.DONE
53
+
54
+ def get_job_state(self, code: int) -> "JobState":
55
+ """Return TIMEOUT for first timeout_count attempts"""
56
+ # Read attempt count from checkpoint
57
+ attempt = 1
58
+ if self.checkpoint_file.exists():
59
+ attempt = int(self.checkpoint_file.read_text())
60
+
61
+ # Return TIMEOUT for first timeout_count attempts
62
+ if attempt <= self.timeout_count:
63
+ return JobStateError(JobFailureStatus.TIMEOUT)
64
+
65
+ return JobState.DONE
66
+
67
+
68
+ class MockTimeoutLauncher(DirectLauncher):
69
+ """Launcher that creates jobs simulating timeouts"""
70
+
71
+ def __init__(self, timeout_count: int, checkpoint_file: Path):
72
+ super().__init__(LocalConnector())
73
+ self.timeout_count = timeout_count
74
+ self.checkpoint_file = checkpoint_file
75
+
76
+
77
+ # Monkey-patch the task type to use our mock job
78
+ # This is done by overriding the task factory
79
+ def create_mock_timeout_task(timeout_count: int, checkpoint_file: Path):
80
+ """Create a task type that uses MockTimeoutCommandLineJob"""
81
+
82
+ def mock_job_factory(commandline):
83
+ class MockCommandLineTask:
84
+ def __init__(self, commandline):
85
+ self.commandline = commandline
86
+
87
+ def __call__(
88
+ self,
89
+ pyobject,
90
+ *,
91
+ launcher=None,
92
+ workspace=None,
93
+ run_mode=None,
94
+ max_retries=None,
95
+ ):
96
+ return MockTimeoutCommandLineJob(
97
+ self.commandline,
98
+ pyobject,
99
+ launcher=launcher,
100
+ workspace=workspace,
101
+ run_mode=run_mode,
102
+ max_retries=max_retries,
103
+ timeout_count=timeout_count,
104
+ checkpoint_file=checkpoint_file,
105
+ )
106
+
107
+ return MockCommandLineTask(commandline)
108
+
109
+ return mock_job_factory
110
+
111
+
112
+ class CountingResumableTask(ResumableTask):
113
+ """Resumable task that counts execution attempts"""
114
+
115
+ checkpoint: Param[Path]
116
+
117
+ def execute(self):
118
+ # Count attempts in checkpoint file
119
+ attempt = 1
120
+ if self.checkpoint.exists():
121
+ attempt = int(self.checkpoint.read_text()) + 1
122
+
123
+ self.checkpoint.write_text(str(attempt))
124
+
125
+
126
+ class SimpleResumableTask(ResumableTask):
127
+ """Simple resumable task for testing"""
128
+
129
+ def execute(self):
130
+ # This would normally contain checkpoint logic
131
+ pass
132
+
133
+
134
+ class SimpleNonResumableTask(Task):
135
+ """Simple non-resumable task for testing"""
136
+
137
+ def execute(self):
138
+ pass
139
+
140
+
141
+ def test_resumable_task_has_resumable_flag():
142
+ """Test that ResumableTask instances are correctly identified"""
143
+ with TemporaryExperiment("resumable_flag", maxwait=0):
144
+ launcher = DirectLauncher(LocalConnector())
145
+
146
+ # Submit resumable task
147
+ resumable = SimpleResumableTask.C().submit(
148
+ launcher=launcher, run_mode=RunMode.DRY_RUN
149
+ )
150
+ assert resumable.__xpm__.job.resumable is True
151
+
152
+ # Submit non-resumable task
153
+ non_resumable = SimpleNonResumableTask.C().submit(
154
+ launcher=launcher, run_mode=RunMode.DRY_RUN
155
+ )
156
+ assert non_resumable.__xpm__.job.resumable is False
157
+
158
+
159
+ def test_max_retries_default():
160
+ """Test that default max_retries is 3"""
161
+ with TemporaryExperiment("max_retries_default", maxwait=0):
162
+ launcher = DirectLauncher(LocalConnector())
163
+
164
+ task = SimpleResumableTask.C().submit(
165
+ launcher=launcher, run_mode=RunMode.DRY_RUN
166
+ )
167
+
168
+ # Default should be 3 (from workspace settings)
169
+ assert task.__xpm__.job.max_retries == 3
170
+ assert task.__xpm__.job.retry_count == 0
171
+
172
+
173
+ def test_max_retries_custom():
174
+ """Test that custom max_retries parameter is respected"""
175
+ with TemporaryExperiment("max_retries_custom", maxwait=0):
176
+ launcher = DirectLauncher(LocalConnector())
177
+
178
+ task = SimpleResumableTask.C().submit(
179
+ launcher=launcher, run_mode=RunMode.DRY_RUN, max_retries=5
180
+ )
181
+
182
+ assert task.__xpm__.job.max_retries == 5
183
+ assert task.__xpm__.job.retry_count == 0
184
+
185
+
186
+ def test_max_retries_zero():
187
+ """Test that max_retries=0 is allowed (no retries)"""
188
+ with TemporaryExperiment("max_retries_zero", maxwait=0):
189
+ launcher = DirectLauncher(LocalConnector())
190
+
191
+ task = SimpleResumableTask.C().submit(
192
+ launcher=launcher, run_mode=RunMode.DRY_RUN, max_retries=0
193
+ )
194
+
195
+ assert task.__xpm__.job.max_retries == 0
196
+ assert task.__xpm__.job.retry_count == 0
197
+
198
+
199
+ def test_resumable_task_succeeds_after_timeouts():
200
+ """Test that a resumable task retries and succeeds after timeouts"""
201
+ with TemporaryExperiment("resumable_timeout_success", maxwait=20) as xp:
202
+ checkpoint_file = xp.workspace.path / "checkpoint.txt"
203
+ launcher = DirectLauncher(LocalConnector())
204
+
205
+ # Create task config
206
+ task_config = CountingResumableTask.C(checkpoint=checkpoint_file)
207
+
208
+ # Create mock job directly
209
+ job = MockTimeoutCommandLineJob(
210
+ None, # commandline not used
211
+ task_config,
212
+ workspace=xp.workspace,
213
+ launcher=launcher,
214
+ max_retries=5,
215
+ timeout_count=2,
216
+ checkpoint_file=checkpoint_file,
217
+ )
218
+
219
+ # Set the job on the config
220
+ task_config.__xpm__.job = job
221
+
222
+ # Submit to scheduler
223
+ from experimaestro.scheduler import experiment
224
+
225
+ experiment.CURRENT.submit(job)
226
+
227
+ # Should succeed after 3 attempts (2 timeouts + 1 success)
228
+ state = job.wait()
229
+ assert state == JobState.DONE
230
+ assert job.retry_count == 2
231
+ # Checkpoint should show 3 executions
232
+ assert int(checkpoint_file.read_text()) == 3
233
+
234
+
235
+ def test_resumable_task_fails_after_max_retries():
236
+ """Test that a resumable task fails after exceeding max_retries"""
237
+ from experimaestro.scheduler import FailedExperiment
238
+ import pytest
239
+
240
+ with pytest.raises(FailedExperiment):
241
+ with TemporaryExperiment("resumable_timeout_fail", maxwait=20) as xp:
242
+ checkpoint_file = xp.workspace.path / "checkpoint.txt"
243
+ launcher = DirectLauncher(LocalConnector())
244
+
245
+ # Create task config
246
+ task_config = CountingResumableTask.C(checkpoint=checkpoint_file)
247
+
248
+ # Create mock job directly
249
+ job = MockTimeoutCommandLineJob(
250
+ None, # commandline not used
251
+ task_config,
252
+ workspace=xp.workspace,
253
+ launcher=launcher,
254
+ max_retries=3,
255
+ timeout_count=10,
256
+ checkpoint_file=checkpoint_file,
257
+ )
258
+
259
+ # Set the job on the config
260
+ task_config.__xpm__.job = job
261
+
262
+ # Submit to scheduler
263
+ from experimaestro.scheduler import experiment
264
+
265
+ experiment.CURRENT.submit(job)
266
+
267
+ # Wait for job to complete (will fail)
268
+ state = job.wait()
269
+
270
+ # Verify job failed correctly
271
+ assert isinstance(state, JobStateError)
272
+ assert state.failure_reason == JobFailureStatus.TIMEOUT
273
+ # retry_count should be max_retries + 1 (initial attempt + 3 retries = 4 total)
274
+ assert job.retry_count == 4
275
+ # Checkpoint should show 4 executions
276
+ assert int(checkpoint_file.read_text()) == 4
277
+
278
+
279
+ # =============================================================================
280
+ # Tests for JobState.from_path and .failed file format
281
+ # =============================================================================
282
+
283
+
284
+ def test_job_state_from_path_json_timeout(tmp_path):
285
+ """Test JobState.from_path reads JSON format with timeout reason"""
286
+ failed_file = tmp_path / "test.failed"
287
+ failed_file.write_text(
288
+ json.dumps({"code": 1, "reason": "timeout", "message": "Graceful"})
289
+ )
290
+
291
+ state = JobStateClass.from_path(tmp_path, "test")
292
+ assert isinstance(state, JobStateError)
293
+ assert state.failure_reason == JobFailureStatus.TIMEOUT
294
+
295
+
296
+ def test_job_state_from_path_json_memory(tmp_path):
297
+ """Test JobState.from_path reads JSON format with memory reason"""
298
+ failed_file = tmp_path / "test.failed"
299
+ failed_file.write_text(json.dumps({"code": 1, "reason": "memory"}))
300
+
301
+ state = JobStateClass.from_path(tmp_path, "test")
302
+ assert isinstance(state, JobStateError)
303
+ assert state.failure_reason == JobFailureStatus.MEMORY
304
+
305
+
306
+ def test_job_state_from_path_json_dependency(tmp_path):
307
+ """Test JobState.from_path reads JSON format with dependency reason"""
308
+ failed_file = tmp_path / "test.failed"
309
+ failed_file.write_text(json.dumps({"code": 1, "reason": "dependency"}))
310
+
311
+ state = JobStateClass.from_path(tmp_path, "test")
312
+ assert isinstance(state, JobStateError)
313
+ assert state.failure_reason == JobFailureStatus.DEPENDENCY
314
+
315
+
316
+ def test_job_state_from_path_json_failed(tmp_path):
317
+ """Test JobState.from_path reads JSON format with failed reason"""
318
+ failed_file = tmp_path / "test.failed"
319
+ failed_file.write_text(json.dumps({"code": 1, "reason": "failed"}))
320
+
321
+ state = JobStateClass.from_path(tmp_path, "test")
322
+ assert isinstance(state, JobStateError)
323
+ assert state.failure_reason == JobFailureStatus.FAILED
324
+
325
+
326
+ def test_job_state_from_path_json_unknown_reason(tmp_path):
327
+ """Test JobState.from_path handles unknown reason gracefully"""
328
+ failed_file = tmp_path / "test.failed"
329
+ failed_file.write_text(json.dumps({"code": 1, "reason": "unknown_reason"}))
330
+
331
+ state = JobStateClass.from_path(tmp_path, "test")
332
+ assert isinstance(state, JobStateError)
333
+ assert state.failure_reason == JobFailureStatus.FAILED # Falls back to FAILED
334
+
335
+
336
+ def test_job_state_from_path_legacy_integer_nonzero(tmp_path):
337
+ """Test JobState.from_path reads legacy integer format (non-zero = error)"""
338
+ failed_file = tmp_path / "test.failed"
339
+ failed_file.write_text("1")
340
+
341
+ state = JobStateClass.from_path(tmp_path, "test")
342
+ assert isinstance(state, JobStateError)
343
+ assert state.failure_reason == JobFailureStatus.FAILED
344
+
345
+
346
+ def test_job_state_from_path_legacy_integer_zero(tmp_path):
347
+ """Test JobState.from_path reads legacy integer format (zero = done)"""
348
+ failed_file = tmp_path / "test.failed"
349
+ failed_file.write_text("0")
350
+
351
+ state = JobStateClass.from_path(tmp_path, "test")
352
+ assert state == JobState.DONE
353
+
354
+
355
+ def test_job_state_from_path_done_file(tmp_path):
356
+ """Test JobState.from_path reads .done file"""
357
+ done_file = tmp_path / "test.done"
358
+ done_file.touch()
359
+
360
+ state = JobStateClass.from_path(tmp_path, "test")
361
+ assert state == JobState.DONE
362
+
363
+
364
+ def test_job_state_from_path_no_file(tmp_path):
365
+ """Test JobState.from_path returns None when no file exists"""
366
+ state = JobStateClass.from_path(tmp_path, "test")
367
+ assert state is None
368
+
369
+
370
+ def test_job_state_from_path_done_takes_precedence(tmp_path):
371
+ """Test JobState.from_path prefers .done over .failed"""
372
+ done_file = tmp_path / "test.done"
373
+ done_file.touch()
374
+ failed_file = tmp_path / "test.failed"
375
+ failed_file.write_text(json.dumps({"code": 1, "reason": "failed"}))
376
+
377
+ state = JobStateClass.from_path(tmp_path, "test")
378
+ assert state == JobState.DONE
379
+
380
+
381
+ # =============================================================================
382
+ # Tests for GracefulTimeout exception
383
+ # =============================================================================
384
+
385
+
386
+ class GracefulTimeoutTask(ResumableTask):
387
+ """Task that raises GracefulTimeout"""
388
+
389
+ checkpoint: Param[Path]
390
+ should_timeout: Param[bool] = field(ignore_default=True)
391
+
392
+ def execute(self):
393
+ # Count attempts in checkpoint file
394
+ attempt = 1
395
+ if self.checkpoint.exists():
396
+ attempt = int(self.checkpoint.read_text()) + 1
397
+ self.checkpoint.write_text(str(attempt))
398
+
399
+ # Raise GracefulTimeout on first attempt if should_timeout is True
400
+ if self.should_timeout and attempt == 1:
401
+ raise GracefulTimeout("Not enough time for another epoch")
402
+
403
+
404
+ # =============================================================================
405
+ # Tests for remaining_time() method with mock launcher
406
+ # =============================================================================
407
+
408
+
409
+ class MockLauncherWithRemainingTime(DirectLauncher):
410
+ """Mock launcher that provides remaining_time via launcher_info_code()"""
411
+
412
+ def __init__(self, remaining_time_value: float | None):
413
+ super().__init__(LocalConnector())
414
+ self.remaining_time_value = remaining_time_value
415
+
416
+ def launcher_info_code(self) -> str:
417
+ """Generate code to set up MockLauncherInformation with the specified value"""
418
+ if self.remaining_time_value is None:
419
+ return (
420
+ " from experimaestro.tests.test_resumable_task import MockLauncherInformation\n"
421
+ " from experimaestro import taskglobals\n"
422
+ " taskglobals.Env.instance().launcher_info = MockLauncherInformation(None)\n"
423
+ )
424
+ return (
425
+ " from experimaestro.tests.test_resumable_task import MockLauncherInformation\n"
426
+ " from experimaestro import taskglobals\n"
427
+ f" taskglobals.Env.instance().launcher_info = MockLauncherInformation({self.remaining_time_value})\n"
428
+ )
429
+
430
+
431
+ class MockLauncherInformation:
432
+ """Mock launcher info for testing remaining_time()"""
433
+
434
+ def __init__(self, remaining: float | None):
435
+ self._remaining = remaining
436
+
437
+ def remaining_time(self) -> float | None:
438
+ return self._remaining
439
+
440
+
441
+ class RemainingTimeTask(ResumableTask):
442
+ """Task that records the remaining_time() value"""
443
+
444
+ output_file: Param[Path]
445
+
446
+ def execute(self):
447
+ remaining = self.remaining_time()
448
+ self.output_file.write_text(str(remaining) if remaining is not None else "None")
449
+
450
+
451
+ def test_remaining_time_with_mock_launcher():
452
+ """Test remaining_time() works with a mock launcher that provides launcher_info_code()"""
453
+ with TemporaryExperiment("remaining_time", maxwait=10) as xp:
454
+ output_file = xp.workspace.path / "remaining.txt"
455
+ launcher = MockLauncherWithRemainingTime(remaining_time_value=1234.5)
456
+
457
+ task = RemainingTimeTask.C(output_file=output_file).submit(launcher=launcher)
458
+
459
+ state = task.__xpm__.job.wait()
460
+ assert state == JobState.DONE
461
+
462
+ # Verify the task received the remaining time value
463
+ assert output_file.exists()
464
+ assert output_file.read_text() == "1234.5"
465
+
466
+
467
+ def test_remaining_time_none_with_mock_launcher():
468
+ """Test remaining_time() returns None when launcher has no time limit"""
469
+ with TemporaryExperiment("remaining_time_none", maxwait=10) as xp:
470
+ output_file = xp.workspace.path / "remaining.txt"
471
+ launcher = MockLauncherWithRemainingTime(remaining_time_value=None)
472
+
473
+ task = RemainingTimeTask.C(output_file=output_file).submit(launcher=launcher)
474
+
475
+ state = task.__xpm__.job.wait()
476
+ assert state == JobState.DONE
477
+
478
+ # Verify the task received None
479
+ assert output_file.exists()
480
+ assert output_file.read_text() == "None"
@@ -2,11 +2,12 @@ from typing import Optional
2
2
  from experimaestro import (
3
3
  Config,
4
4
  Param,
5
+ Task,
5
6
  state_dict,
6
7
  from_state_dict,
7
8
  )
8
9
  from experimaestro.core.context import SerializationContext
9
- from experimaestro.core.objects import ConfigMixin
10
+ from experimaestro.core.objects import ConfigMixin, ConfigInformation, TaskStub
10
11
 
11
12
 
12
13
  class Object1(Config):
@@ -52,3 +53,142 @@ def test_serializers_types():
52
53
  config = MultiParamObject.C(x={"a": None})
53
54
  config.__xpm__.seal(context)
54
55
  state_dict(context, config)
56
+
57
+
58
+ # --- Tests for partial_loading feature ---
59
+
60
+
61
+ class MyTask(Task):
62
+ """A task for testing partial_loading"""
63
+
64
+ value: Param[int]
65
+
66
+ def execute(self):
67
+ pass
68
+
69
+
70
+ class ConfigWithTask(Config):
71
+ """A config that references a task"""
72
+
73
+ name: Param[str]
74
+
75
+
76
+ class TaskDependentConfig(Config):
77
+ """A config that is only used by a task"""
78
+
79
+ data: Param[str]
80
+
81
+
82
+ class TaskWithDependency(Task):
83
+ """A task that uses another config"""
84
+
85
+ dep: Param[TaskDependentConfig]
86
+
87
+ def execute(self):
88
+ pass
89
+
90
+
91
+ def test_partial_loading_skips_task_reference():
92
+ """Test that partial_loading skips loading task references"""
93
+ context = SerializationContext(save_directory=None)
94
+
95
+ # Create a config with a task
96
+ task = MyTask.C(value=42)
97
+ config = ConfigWithTask.C(name="test")
98
+ config.__xpm__.task = task
99
+
100
+ # Seal both
101
+ task.__xpm__.seal(context)
102
+ config.__xpm__.seal(context)
103
+
104
+ # Serialize
105
+ data = state_dict(context, [task, config])
106
+
107
+ # Load without partial_loading - task should be loaded
108
+ [loaded_task, loaded_config] = from_state_dict(data, partial_loading=False)
109
+ assert isinstance(loaded_config, ConfigWithTask)
110
+ assert isinstance(loaded_config.__xpm__.task, MyTask)
111
+
112
+ # Load with partial_loading - task should be a stub
113
+ [loaded_task2, loaded_config2] = from_state_dict(data, partial_loading=True)
114
+ assert isinstance(loaded_config2, ConfigWithTask)
115
+ assert isinstance(loaded_config2.__xpm__.task, TaskStub)
116
+ # Check that typename contains the task class name
117
+ assert "MyTask" in loaded_config2.__xpm__.task.typename
118
+
119
+
120
+ def test_partial_loading_skips_task_dependencies():
121
+ """Test that partial_loading skips configs only used by tasks"""
122
+ context = SerializationContext(save_directory=None)
123
+
124
+ # Create a config that is only used by a task
125
+ dep = TaskDependentConfig.C(data="some data")
126
+ task = TaskWithDependency.C(dep=dep)
127
+ config = ConfigWithTask.C(name="main")
128
+ config.__xpm__.task = task
129
+
130
+ # Seal all
131
+ dep.__xpm__.seal(context)
132
+ task.__xpm__.seal(context)
133
+ config.__xpm__.seal(context)
134
+
135
+ # Serialize
136
+ data = state_dict(context, [dep, task, config])
137
+ definitions = data["objects"]
138
+
139
+ # Load with partial_loading - both task and its dependency should be stubs
140
+ objects = ConfigInformation.load_objects(
141
+ definitions, as_instance=False, partial_loading=True
142
+ )
143
+
144
+ # The main config should be loaded
145
+ main_obj = objects[definitions[-1]["id"]]
146
+ assert isinstance(main_obj, ConfigWithTask)
147
+
148
+ # The task should be a stub
149
+ task_obj = objects[definitions[1]["id"]]
150
+ assert isinstance(task_obj, TaskStub)
151
+
152
+ # The task dependency should also be a stub
153
+ dep_obj = objects[definitions[0]["id"]]
154
+ assert isinstance(dep_obj, TaskStub)
155
+
156
+
157
+ def test_partial_loading_preserves_shared_configs():
158
+ """Test that configs used by both main object and task are not skipped"""
159
+ context = SerializationContext(save_directory=None)
160
+
161
+ # Create a shared config
162
+ shared = Object1.C()
163
+
164
+ # Create task and main config that both use the shared config
165
+ task = MyTask.C(value=1)
166
+ main = Object2.C(object=shared)
167
+ main.__xpm__.task = task
168
+
169
+ # Seal all
170
+ shared.__xpm__.seal(context)
171
+ task.__xpm__.seal(context)
172
+ main.__xpm__.seal(context)
173
+
174
+ # Serialize
175
+ data = state_dict(context, [shared, task, main])
176
+ definitions = data["objects"]
177
+
178
+ # Load with partial_loading
179
+ objects = ConfigInformation.load_objects(
180
+ definitions, as_instance=False, partial_loading=True
181
+ )
182
+
183
+ # The shared config should be loaded (not a stub) since main uses it
184
+ shared_obj = objects[definitions[0]["id"]]
185
+ assert isinstance(shared_obj, Object1)
186
+ assert not isinstance(shared_obj, TaskStub)
187
+
188
+ # The main config should be loaded
189
+ main_obj = objects[definitions[-1]["id"]]
190
+ assert isinstance(main_obj, Object2)
191
+
192
+ # The task should be a stub
193
+ task_obj = objects[definitions[1]["id"]]
194
+ assert isinstance(task_obj, TaskStub)