experimaestro 2.0.0a8-py3-none-any.whl → 2.0.0b4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic. Click here for more details.

Files changed (116)
  1. experimaestro/__init__.py +10 -11
  2. experimaestro/annotations.py +167 -206
  3. experimaestro/cli/__init__.py +130 -5
  4. experimaestro/cli/filter.py +42 -74
  5. experimaestro/cli/jobs.py +157 -106
  6. experimaestro/cli/refactor.py +249 -0
  7. experimaestro/click.py +0 -1
  8. experimaestro/commandline.py +19 -3
  9. experimaestro/connectors/__init__.py +20 -1
  10. experimaestro/connectors/local.py +12 -0
  11. experimaestro/core/arguments.py +182 -46
  12. experimaestro/core/identifier.py +107 -6
  13. experimaestro/core/objects/__init__.py +6 -0
  14. experimaestro/core/objects/config.py +542 -25
  15. experimaestro/core/objects/config_walk.py +20 -0
  16. experimaestro/core/serialization.py +91 -34
  17. experimaestro/core/subparameters.py +164 -0
  18. experimaestro/core/types.py +175 -38
  19. experimaestro/exceptions.py +26 -0
  20. experimaestro/experiments/cli.py +107 -25
  21. experimaestro/generators.py +50 -9
  22. experimaestro/huggingface.py +3 -1
  23. experimaestro/launcherfinder/parser.py +29 -0
  24. experimaestro/launchers/__init__.py +26 -1
  25. experimaestro/launchers/direct.py +12 -0
  26. experimaestro/launchers/slurm/base.py +154 -2
  27. experimaestro/mkdocs/metaloader.py +0 -1
  28. experimaestro/mypy.py +452 -7
  29. experimaestro/notifications.py +63 -13
  30. experimaestro/progress.py +0 -2
  31. experimaestro/rpyc.py +0 -1
  32. experimaestro/run.py +19 -6
  33. experimaestro/scheduler/base.py +489 -125
  34. experimaestro/scheduler/dependencies.py +43 -28
  35. experimaestro/scheduler/dynamic_outputs.py +259 -130
  36. experimaestro/scheduler/experiment.py +225 -30
  37. experimaestro/scheduler/interfaces.py +474 -0
  38. experimaestro/scheduler/jobs.py +216 -206
  39. experimaestro/scheduler/services.py +186 -12
  40. experimaestro/scheduler/state_db.py +388 -0
  41. experimaestro/scheduler/state_provider.py +2345 -0
  42. experimaestro/scheduler/state_sync.py +834 -0
  43. experimaestro/scheduler/workspace.py +52 -10
  44. experimaestro/scriptbuilder.py +7 -0
  45. experimaestro/server/__init__.py +147 -57
  46. experimaestro/server/data/index.css +0 -125
  47. experimaestro/server/data/index.css.map +1 -1
  48. experimaestro/server/data/index.js +194 -58
  49. experimaestro/server/data/index.js.map +1 -1
  50. experimaestro/settings.py +44 -5
  51. experimaestro/sphinx/__init__.py +3 -3
  52. experimaestro/taskglobals.py +20 -0
  53. experimaestro/tests/conftest.py +80 -0
  54. experimaestro/tests/core/test_generics.py +2 -2
  55. experimaestro/tests/identifier_stability.json +45 -0
  56. experimaestro/tests/launchers/bin/sacct +6 -2
  57. experimaestro/tests/launchers/bin/sbatch +4 -2
  58. experimaestro/tests/launchers/test_slurm.py +80 -0
  59. experimaestro/tests/tasks/test_dynamic.py +231 -0
  60. experimaestro/tests/test_cli_jobs.py +615 -0
  61. experimaestro/tests/test_deprecated.py +630 -0
  62. experimaestro/tests/test_environment.py +200 -0
  63. experimaestro/tests/test_file_progress_integration.py +1 -1
  64. experimaestro/tests/test_forward.py +3 -3
  65. experimaestro/tests/test_identifier.py +372 -41
  66. experimaestro/tests/test_identifier_stability.py +458 -0
  67. experimaestro/tests/test_instance.py +3 -3
  68. experimaestro/tests/test_multitoken.py +442 -0
  69. experimaestro/tests/test_mypy.py +433 -0
  70. experimaestro/tests/test_objects.py +312 -5
  71. experimaestro/tests/test_outputs.py +2 -2
  72. experimaestro/tests/test_param.py +8 -12
  73. experimaestro/tests/test_partial_paths.py +231 -0
  74. experimaestro/tests/test_progress.py +0 -48
  75. experimaestro/tests/test_resumable_task.py +480 -0
  76. experimaestro/tests/test_serializers.py +141 -1
  77. experimaestro/tests/test_state_db.py +434 -0
  78. experimaestro/tests/test_subparameters.py +160 -0
  79. experimaestro/tests/test_tags.py +136 -0
  80. experimaestro/tests/test_tasks.py +107 -121
  81. experimaestro/tests/test_token_locking.py +252 -0
  82. experimaestro/tests/test_tokens.py +17 -13
  83. experimaestro/tests/test_types.py +123 -1
  84. experimaestro/tests/test_workspace_triggers.py +158 -0
  85. experimaestro/tests/token_reschedule.py +4 -2
  86. experimaestro/tests/utils.py +2 -2
  87. experimaestro/tokens.py +154 -57
  88. experimaestro/tools/diff.py +1 -1
  89. experimaestro/tui/__init__.py +8 -0
  90. experimaestro/tui/app.py +2303 -0
  91. experimaestro/tui/app.tcss +353 -0
  92. experimaestro/tui/log_viewer.py +228 -0
  93. experimaestro/utils/__init__.py +23 -0
  94. experimaestro/utils/environment.py +148 -0
  95. experimaestro/utils/git.py +129 -0
  96. experimaestro/utils/resources.py +1 -1
  97. experimaestro/version.py +34 -0
  98. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/METADATA +68 -38
  99. experimaestro-2.0.0b4.dist-info/RECORD +181 -0
  100. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/WHEEL +1 -1
  101. experimaestro-2.0.0b4.dist-info/entry_points.txt +16 -0
  102. experimaestro/compat.py +0 -6
  103. experimaestro/core/objects.pyi +0 -221
  104. experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
  105. experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
  106. experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
  107. experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
  108. experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
  109. experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
  110. experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
  111. experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
  112. experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
  113. experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
  114. experimaestro-2.0.0a8.dist-info/RECORD +0 -166
  115. experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
  116. {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,434 @@
1
+ """Tests for workspace-level database models"""
2
+
3
+ import pytest
4
+ from pathlib import Path
5
+ from experimaestro.scheduler.state_db import (
6
+ ExperimentModel,
7
+ ExperimentRunModel,
8
+ JobModel,
9
+ JobTagModel,
10
+ ServiceModel,
11
+ WorkspaceSyncMetadata,
12
+ initialize_workspace_database,
13
+ close_workspace_database,
14
+ ALL_MODELS,
15
+ )
16
+ from experimaestro.scheduler.state_sync import sync_workspace_from_disk
17
+ from experimaestro import Task, Param
18
+ from experimaestro.tests.utils import TemporaryExperiment
19
+
20
+
21
def test_database_initialization(tmp_path: Path):
    """Test that workspace database is initialized correctly.

    Checks that every workspace table exists after initialization and that the
    ``WorkspaceSyncMetadata`` row with id ``"workspace"`` is created with a
    5-minute sync interval.
    """
    db_path = tmp_path / "workspace.db"

    # Initialize database
    db = initialize_workspace_database(db_path, read_only=False)
    try:
        # NOTE(review): models are queried without bind_ctx here, presumably
        # relying on initialize_workspace_database binding them — confirm.

        # Verify all tables were created (name the model on failure)
        for model in (
            ExperimentModel,
            ExperimentRunModel,
            JobModel,
            JobTagModel,
            ServiceModel,
            WorkspaceSyncMetadata,
        ):
            assert model.table_exists(), f"missing table for {model.__name__}"

        # Verify WorkspaceSyncMetadata was initialized
        metadata = WorkspaceSyncMetadata.get_or_none(
            WorkspaceSyncMetadata.id == "workspace"
        )
        assert metadata is not None
        assert metadata.sync_interval_minutes == 5
    finally:
        # Cleanup, even when an assertion above fails
        close_workspace_database(db)
45
+
46
+
47
def test_experiment_and_run_models(tmp_path: Path):
    """Test creating experiments and runs.

    Creates one experiment with one run, points the experiment's
    ``current_run_id`` at that run, and reads both rows back.
    """
    db_path = tmp_path / "workspace.db"
    db = initialize_workspace_database(db_path, read_only=False)

    try:
        with db.bind_ctx(ALL_MODELS):
            # Create an experiment
            ExperimentModel.create(experiment_id="test_exp")

            # Create a run for this experiment
            ExperimentRunModel.create(
                experiment_id="test_exp", run_id="run_001", status="active"
            )

            # Update experiment to point to this run
            ExperimentModel.update(current_run_id="run_001").where(
                ExperimentModel.experiment_id == "test_exp"
            ).execute()

            # Verify
            retrieved_exp = ExperimentModel.get(
                ExperimentModel.experiment_id == "test_exp"
            )
            assert retrieved_exp.current_run_id == "run_001"

            retrieved_run = ExperimentRunModel.get(
                (ExperimentRunModel.experiment_id == "test_exp")
                & (ExperimentRunModel.run_id == "run_001")
            )
            assert retrieved_run.status == "active"
    finally:
        # Release the SQLite connection even if an assertion above fails
        close_workspace_database(db)
77
+
78
+
79
def test_job_model_with_composite_key(tmp_path: Path):
    """Test job model with composite primary key (job_id, experiment_id, run_id).

    The same ``job_id`` is stored for two runs of one experiment. Both rows
    must coexist, and updating one must leave the other untouched.
    """
    db_path = tmp_path / "workspace.db"
    db = initialize_workspace_database(db_path, read_only=False)

    def job_key(run_id: str):
        # Full composite-key predicate for job_abc / exp1 in the given run
        return (
            (JobModel.job_id == "job_abc")
            & (JobModel.experiment_id == "exp1")
            & (JobModel.run_id == run_id)
        )

    try:
        with db.bind_ctx(ALL_MODELS):
            # Create experiment and run first
            ExperimentModel.create(experiment_id="exp1", current_run_id="run1")
            ExperimentRunModel.create(experiment_id="exp1", run_id="run1")

            # Create a job
            JobModel.create(
                job_id="job_abc",
                experiment_id="exp1",
                run_id="run1",
                task_id="MyTask",
                locator="",
                state="running",
                submitted_time=1234567890.0,
            )

            # Same job in different run
            ExperimentRunModel.create(experiment_id="exp1", run_id="run2")
            JobModel.create(
                job_id="job_abc",
                experiment_id="exp1",
                run_id="run2",
                task_id="MyTask",
                locator="",
                state="done",
                submitted_time=1234567891.0,
            )

            # Both should exist independently
            jobs = list(JobModel.select().where(JobModel.job_id == "job_abc"))
            assert len(jobs) == 2

            # Can update one without affecting the other. The new state must
            # differ from run2's ("done"); otherwise an update leaking into
            # run2 would go unnoticed.
            JobModel.update(state="error").where(job_key("run1")).execute()

            assert JobModel.get(job_key("run1")).state == "error"
            assert JobModel.get(job_key("run2")).state == "done"
    finally:
        close_workspace_database(db)
131
+
132
+
133
def test_job_tags_model(tmp_path: Path):
    """Test run-scoped job tags (fixes GH #128).

    The same job id exists in two runs of one experiment. Each run gets a
    different value for the ``env`` tag, and each run must see only its own.
    """
    db_path = tmp_path / "workspace.db"
    db = initialize_workspace_database(db_path, read_only=False)

    # run_id -> value of the "env" tag attached to job1 in that run
    env_per_run = (("run1", "production"), ("run2", "testing"))

    try:
        with db.bind_ctx(ALL_MODELS):
            # Create experiment and runs
            ExperimentModel.create(experiment_id="exp1")
            for run_id, _ in env_per_run:
                ExperimentRunModel.create(experiment_id="exp1", run_id=run_id)

            # Create job in both runs
            for run_id, _ in env_per_run:
                JobModel.create(
                    job_id="job1",
                    experiment_id="exp1",
                    run_id=run_id,
                    task_id="Task",
                    locator="",
                    state="done",
                )

            # Add different tags to same job in different runs
            for run_id, env in env_per_run:
                JobTagModel.create(
                    job_id="job1",
                    experiment_id="exp1",
                    run_id=run_id,
                    tag_key="env",
                    tag_value=env,
                )

            # Verify tags are independent per run
            for run_id, env in env_per_run:
                run_tags = list(
                    JobTagModel.select().where(
                        (JobTagModel.job_id == "job1")
                        & (JobTagModel.experiment_id == "exp1")
                        & (JobTagModel.run_id == run_id)
                    )
                )
                assert len(run_tags) == 1
                assert run_tags[0].tag_value == env
    finally:
        close_workspace_database(db)
201
+
202
+
203
def test_multiple_experiments_same_workspace(tmp_path: Path):
    """Test that multiple experiments can coexist in same workspace database.

    Two experiments each get a run named ``run1`` and one job. Filtering jobs
    by ``experiment_id`` must return only that experiment's job.
    """
    db_path = tmp_path / "workspace.db"
    db = initialize_workspace_database(db_path, read_only=False)

    # experiment_id -> (job_id, job state)
    layout = (("exp1", "job1", "done"), ("exp2", "job2", "running"))

    try:
        with db.bind_ctx(ALL_MODELS):
            # Create two experiments
            for experiment_id, _, _ in layout:
                ExperimentModel.create(experiment_id=experiment_id)

            # Create runs for each (same run id in both experiments)
            for experiment_id, _, _ in layout:
                ExperimentRunModel.create(experiment_id=experiment_id, run_id="run1")

            # Create jobs for each experiment
            for experiment_id, job_id, state in layout:
                JobModel.create(
                    job_id=job_id,
                    experiment_id=experiment_id,
                    run_id="run1",
                    task_id="Task",
                    locator="",
                    state=state,
                )

            # Query jobs for specific experiment
            for experiment_id, job_id, _ in layout:
                exp_jobs = list(
                    JobModel.select().where(JobModel.experiment_id == experiment_id)
                )
                assert len(exp_jobs) == 1
                assert exp_jobs[0].job_id == job_id
    finally:
        close_workspace_database(db)
245
+
246
+
247
def test_read_only_mode(tmp_path: Path):
    """Test that read-only mode prevents writes.

    Seeds one experiment through a writable connection, then reopens the same
    file read-only. Reads must succeed; an insert must be rejected by SQLite
    and must leave no row behind.
    """
    db_path = tmp_path / "workspace.db"

    # Create database with write mode
    db_write = initialize_workspace_database(db_path, read_only=False)
    try:
        with db_write.bind_ctx(ALL_MODELS):
            ExperimentModel.create(experiment_id="exp1")
    finally:
        close_workspace_database(db_write)

    # Open in read-only mode
    db_read = initialize_workspace_database(db_path, read_only=True)
    try:
        with db_read.bind_ctx(ALL_MODELS):
            # Can read
            exp = ExperimentModel.get(ExperimentModel.experiment_id == "exp1")
            assert exp.experiment_id == "exp1"

            # Cannot write: SQLite rejects it with OperationalError
            # ("attempt to write a readonly database"). Matching the message
            # stops unrelated failures (e.g. a bad model field) from being
            # counted as a pass, which a bare `Exception` would allow.
            with pytest.raises(Exception, match="readonly"):
                ExperimentModel.create(experiment_id="exp2")

            # The rejected write must not have left a row behind
            assert (
                ExperimentModel.get_or_none(ExperimentModel.experiment_id == "exp2")
                is None
            )
    finally:
        close_workspace_database(db_read)
270
+
271
+
272
def test_upsert_on_conflict(tmp_path: Path):
    """Test that on_conflict works for updating existing records.

    Inserts a job, then inserts the same composite key again with an
    ``on_conflict`` clause. Exactly one row must remain, carrying the updated
    state.
    """
    db_path = tmp_path / "workspace.db"
    db = initialize_workspace_database(db_path, read_only=False)

    job_row = dict(
        job_id="job1",
        experiment_id="exp1",
        run_id="run1",
        task_id="Task",
        locator="",
    )

    try:
        with db.bind_ctx(ALL_MODELS):
            # Create experiment and run
            ExperimentModel.create(experiment_id="exp1")
            ExperimentRunModel.create(experiment_id="exp1", run_id="run1")

            # Create job
            JobModel.insert(**job_row, state="running").execute()

            # Upsert with different state (disk wins)
            JobModel.insert(**job_row, state="done").on_conflict(
                conflict_target=[JobModel.job_id, JobModel.experiment_id, JobModel.run_id],
                update={JobModel.state: "done"},
            ).execute()

            # Verify state was updated
            job = JobModel.get(
                (JobModel.job_id == "job1")
                & (JobModel.experiment_id == "exp1")
                & (JobModel.run_id == "run1")
            )
            assert job.state == "done"

            # Only one job should exist
            assert JobModel.select().where(JobModel.job_id == "job1").count() == 1
    finally:
        close_workspace_database(db)
317
+
318
+
319
+ # Define test task
320
class SimpleTask(Task):
    """Minimal task used by the recovery test; records its parameter on disk."""

    value: Param[int]

    def execute(self):
        # Marker file whose presence shows the task actually ran
        marker = self.__taskdir__ / "output.txt"
        marker.write_text(f"value={self.value}")
326
+
327
+
328
def test_database_recovery_from_disk(tmp_path: Path):
    """Test recovering database from jobs.jsonl and disk state.

    Runs an experiment with two tagged tasks and snapshots the jobs and tags
    stored in the workspace database. It then deletes the database file,
    rebuilds it with ``sync_workspace_from_disk``, and checks that the same
    jobs and tags come back.
    """

    def tags_of(job) -> dict:
        # Tags are run-scoped: select on the full (job, experiment, run) key.
        # Must be called while the models are bound to a database.
        return {
            tag.tag_key: tag.tag_value
            for tag in JobTagModel.select().where(
                (JobTagModel.job_id == job.job_id)
                & (JobTagModel.experiment_id == job.experiment_id)
                & (JobTagModel.run_id == job.run_id)
            )
        }

    # Step 1: Run experiment with tasks and tags
    workdir = tmp_path / "workspace"
    workdir.mkdir()

    with TemporaryExperiment("recovery", maxwait=0, workdir=workdir):
        # Submit first task with tags
        task1 = SimpleTask.C(value=42).tag("priority", "high").tag("env", "test")
        task1.submit()

        # Submit second task with different tags
        task2 = SimpleTask.C(value=100).tag("priority", "low").tag("env", "prod")
        task2.submit()

    # Step 2: Verify jobs.jsonl was created
    jobs_jsonl_path = workdir / "xp" / "recovery" / "jobs.jsonl"
    assert jobs_jsonl_path.exists(), "jobs.jsonl should have been created"

    # Verify database exists
    workspace_db_path = workdir / ".experimaestro" / "workspace.db"
    assert workspace_db_path.exists()

    # Get workspace state provider and access database
    from experimaestro.scheduler.state_provider import WorkspaceStateProvider

    provider = WorkspaceStateProvider.get_instance(
        workdir, read_only=False, sync_on_start=False
    )
    try:
        with provider.workspace_db.bind_ctx(ALL_MODELS):
            # Get original state
            original_jobs = list(JobModel.select())

            # If no jobs in DB yet, sync from disk first
            if not original_jobs:
                sync_workspace_from_disk(workdir, write_mode=True, force=True)
                original_jobs = list(JobModel.select())

            assert len(original_jobs) == 2

            original_job_ids = {job.job_id for job in original_jobs}
            assert len(original_job_ids) == 2

            # Get original tags
            original_tags = {job.job_id: tags_of(job) for job in original_jobs}
    finally:
        # Close provider to cleanup (also when an assertion fails)
        provider.close()

    assert len(original_tags) == 2
    # Verify we have tags for both jobs
    for tags in original_tags.values():
        assert "priority" in tags
        assert "env" in tags

    # Step 3: Delete the database. Also drop any SQLite WAL / shared-memory
    # sidecar files so the recreated database cannot be paired with a stale
    # journal. They are normally removed when the last connection closes.
    workspace_db_path.unlink()
    for suffix in ("-wal", "-shm"):
        sidecar = workspace_db_path.with_name(workspace_db_path.name + suffix)
        if sidecar.exists():
            sidecar.unlink()
    assert not workspace_db_path.exists()

    # Step 4: Recover from disk by syncing - get new provider instance
    provider2 = WorkspaceStateProvider.get_instance(
        workdir, read_only=False, sync_on_start=False
    )
    try:
        sync_workspace_from_disk(
            workdir, write_mode=True, force=True, sync_interval_minutes=0
        )

        # Step 5: Verify recovered state matches original
        with provider2.workspace_db.bind_ctx(ALL_MODELS):
            # Check jobs were recovered
            recovered_jobs = list(JobModel.select())
            assert len(recovered_jobs) == 2

            recovered_job_ids = {job.job_id for job in recovered_jobs}
            assert recovered_job_ids == original_job_ids

            # Check tags were recovered
            recovered_tags = {job.job_id: tags_of(job) for job in recovered_jobs}
    finally:
        # This provider was previously never closed
        provider2.close()

    # Same job ids on both sides (checked above), so dict equality compares
    # the tags of every job
    assert recovered_tags == original_tags
@@ -0,0 +1,160 @@
1
+ """Tests for subparameters (partial identifier computation)"""
2
+
3
+ from experimaestro import (
4
+ Config,
5
+ Task,
6
+ Param,
7
+ field,
8
+ subparameters,
9
+ param_group,
10
+ ParameterGroup,
11
+ Subparameters,
12
+ )
13
+
14
+
15
# Parameter groups shared by every test in this module ("iter" then "model")
iter_group, model_group = (param_group(name) for name in ("iter", "model"))
18
+
19
+
20
class TestSubparametersBasic:
    """Unit tests for ``param_group`` and ``subparameters`` on their own"""

    def test_param_group_creation(self):
        """A parameter group is a ParameterGroup carrying its name"""
        created = param_group("test")
        assert isinstance(created, ParameterGroup)
        assert created.name == "test"

    def test_param_group_hashable(self):
        """Parameter groups compare by name and can be stored in sets"""
        first, second = param_group("test"), param_group("test")

        # Same name: equal, and a set keeps a single entry
        assert first == second
        assert len({first, second}) == 1

        # Different name: distinct, and a set keeps both
        other = param_group("other")
        assert first != other
        assert len({first, other}) == 2

    def test_subparameters_creation(self):
        """``subparameters`` returns a Subparameters listing its excluded groups"""
        created = subparameters(exclude_groups=[iter_group])
        assert isinstance(created, Subparameters)
        assert iter_group in created.exclude_groups

    def test_subparameters_is_excluded(self):
        """Only parameters belonging to an excluded group are excluded"""
        sp = subparameters(exclude_groups=[iter_group])
        cases = (({iter_group}, True), ({model_group}, False), (set(), False))
        for groups, expected in cases:
            assert sp.is_excluded(groups) is expected

    def test_subparameters_include_overrides_exclude(self):
        """A group that is both excluded and included ends up included"""
        sp = subparameters(
            exclude_groups=[iter_group, model_group],
            include_groups=[iter_group],
        )
        # iter_group appears in both lists; inclusion takes precedence
        for groups, expected in (({iter_group}, False), ({model_group}, True)):
            assert sp.is_excluded(groups) is expected

    def test_subparameters_exclude_all(self):
        """``exclude_all`` excludes everything except explicitly included groups"""
        sp = subparameters(exclude_all=True, include_groups=[model_group])
        cases = (({iter_group}, True), ({model_group}, False), (set(), True))
        for groups, expected in cases:
            assert sp.is_excluded(groups) is expected

    def test_subparameters_exclude_no_group(self):
        """``exclude_no_group`` excludes only parameters without any group"""
        sp = subparameters(exclude_no_group=True)
        for groups, expected in ((set(), True), ({iter_group}, False)):
            assert sp.is_excluded(groups) is expected
79
+
80
+
81
class TestPartialIdentifiers:
    """Partial identifiers: identifiers computed while ignoring excluded groups"""

    @staticmethod
    def _partial_ids(sp, *configs):
        # One partial identifier per configuration, computed in argument order
        return [cfg.__xpm__.get_partial_identifier(sp) for cfg in configs]

    def test_field_groups(self):
        """``field(groups=...)`` is recorded on the matching Argument"""

        class MyConfig(Config):
            x: Param[int] = field(groups=[iter_group])
            y: Param[float]

        objtype = MyConfig.__getxpmtype__()
        objtype.__initialize__()
        args = objtype.arguments

        assert iter_group in args["x"].groups
        assert len(args["y"].groups) == 0

    def test_subparameters_collected_in_objecttype(self):
        """Subparameters declared on a task are registered on its ObjectType"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            x: Param[int]

        objtype = MyTask.__getxpmtype__()
        objtype.__initialize__()
        registered = objtype._subparameters

        assert "checkpoints" in registered
        assert registered["checkpoints"].name == "checkpoints"

    def test_partial_identifier_same_when_excluded_differs(self):
        """Changing only an excluded parameter keeps the partial identifier"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            max_iter: Param[int] = field(groups=[iter_group])
            learning_rate: Param[float]

        first = MyTask.C(max_iter=100, learning_rate=0.1)
        second = MyTask.C(max_iter=200, learning_rate=0.1)

        # The full identifier does take max_iter into account...
        assert first.__xpm__.identifier != second.__xpm__.identifier

        # ...but the partial one ignores it
        pid_first, pid_second = self._partial_ids(MyTask.checkpoints, first, second)
        assert pid_first == pid_second

    def test_partial_identifier_differs_when_included_differs(self):
        """Changing a non-excluded parameter changes the partial identifier"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            max_iter: Param[int] = field(groups=[iter_group])
            learning_rate: Param[float]

        first = MyTask.C(max_iter=100, learning_rate=0.1)
        second = MyTask.C(max_iter=100, learning_rate=0.2)

        # learning_rate has no group, so it still contributes
        pid_first, pid_second = self._partial_ids(MyTask.checkpoints, first, second)
        assert pid_first != pid_second

    def test_partial_identifier_with_multiple_groups(self):
        """Being in one excluded group is enough, even with other groups too"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            # x belongs to both an excluded and a non-excluded group
            x: Param[int] = field(groups=[iter_group, model_group])
            y: Param[float]

        first = MyTask.C(x=1, y=0.1)
        second = MyTask.C(x=2, y=0.1)

        # x sits in iter_group, which is excluded, so its value is ignored
        pid_first, pid_second = self._partial_ids(MyTask.checkpoints, first, second)
        assert pid_first == pid_second