experimaestro 2.0.0a8__py3-none-any.whl → 2.0.0b8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of experimaestro might be problematic. Click here for more details.
- experimaestro/__init__.py +10 -11
- experimaestro/annotations.py +167 -206
- experimaestro/cli/__init__.py +278 -7
- experimaestro/cli/filter.py +42 -74
- experimaestro/cli/jobs.py +157 -106
- experimaestro/cli/refactor.py +249 -0
- experimaestro/click.py +0 -1
- experimaestro/commandline.py +19 -3
- experimaestro/connectors/__init__.py +20 -1
- experimaestro/connectors/local.py +12 -0
- experimaestro/core/arguments.py +182 -46
- experimaestro/core/identifier.py +107 -6
- experimaestro/core/objects/__init__.py +6 -0
- experimaestro/core/objects/config.py +542 -25
- experimaestro/core/objects/config_walk.py +20 -0
- experimaestro/core/serialization.py +91 -34
- experimaestro/core/subparameters.py +164 -0
- experimaestro/core/types.py +175 -38
- experimaestro/exceptions.py +26 -0
- experimaestro/experiments/cli.py +111 -25
- experimaestro/generators.py +50 -9
- experimaestro/huggingface.py +3 -1
- experimaestro/launcherfinder/parser.py +29 -0
- experimaestro/launchers/__init__.py +26 -1
- experimaestro/launchers/direct.py +12 -0
- experimaestro/launchers/slurm/base.py +154 -2
- experimaestro/mkdocs/metaloader.py +0 -1
- experimaestro/mypy.py +452 -7
- experimaestro/notifications.py +63 -13
- experimaestro/progress.py +0 -2
- experimaestro/rpyc.py +0 -1
- experimaestro/run.py +19 -6
- experimaestro/scheduler/base.py +510 -125
- experimaestro/scheduler/dependencies.py +43 -28
- experimaestro/scheduler/dynamic_outputs.py +259 -130
- experimaestro/scheduler/experiment.py +256 -31
- experimaestro/scheduler/interfaces.py +501 -0
- experimaestro/scheduler/jobs.py +216 -206
- experimaestro/scheduler/remote/__init__.py +31 -0
- experimaestro/scheduler/remote/client.py +874 -0
- experimaestro/scheduler/remote/protocol.py +467 -0
- experimaestro/scheduler/remote/server.py +423 -0
- experimaestro/scheduler/remote/sync.py +144 -0
- experimaestro/scheduler/services.py +323 -23
- experimaestro/scheduler/state_db.py +437 -0
- experimaestro/scheduler/state_provider.py +2766 -0
- experimaestro/scheduler/state_sync.py +891 -0
- experimaestro/scheduler/workspace.py +52 -10
- experimaestro/scriptbuilder.py +7 -0
- experimaestro/server/__init__.py +147 -57
- experimaestro/server/data/index.css +0 -125
- experimaestro/server/data/index.css.map +1 -1
- experimaestro/server/data/index.js +194 -58
- experimaestro/server/data/index.js.map +1 -1
- experimaestro/settings.py +44 -5
- experimaestro/sphinx/__init__.py +3 -3
- experimaestro/taskglobals.py +20 -0
- experimaestro/tests/conftest.py +80 -0
- experimaestro/tests/core/test_generics.py +2 -2
- experimaestro/tests/identifier_stability.json +45 -0
- experimaestro/tests/launchers/bin/sacct +6 -2
- experimaestro/tests/launchers/bin/sbatch +4 -2
- experimaestro/tests/launchers/test_slurm.py +80 -0
- experimaestro/tests/tasks/test_dynamic.py +231 -0
- experimaestro/tests/test_cli_jobs.py +615 -0
- experimaestro/tests/test_deprecated.py +630 -0
- experimaestro/tests/test_environment.py +200 -0
- experimaestro/tests/test_file_progress_integration.py +1 -1
- experimaestro/tests/test_forward.py +3 -3
- experimaestro/tests/test_identifier.py +372 -41
- experimaestro/tests/test_identifier_stability.py +458 -0
- experimaestro/tests/test_instance.py +3 -3
- experimaestro/tests/test_multitoken.py +442 -0
- experimaestro/tests/test_mypy.py +433 -0
- experimaestro/tests/test_objects.py +312 -5
- experimaestro/tests/test_outputs.py +2 -2
- experimaestro/tests/test_param.py +8 -12
- experimaestro/tests/test_partial_paths.py +231 -0
- experimaestro/tests/test_progress.py +0 -48
- experimaestro/tests/test_remote_state.py +671 -0
- experimaestro/tests/test_resumable_task.py +480 -0
- experimaestro/tests/test_serializers.py +141 -1
- experimaestro/tests/test_state_db.py +434 -0
- experimaestro/tests/test_subparameters.py +160 -0
- experimaestro/tests/test_tags.py +136 -0
- experimaestro/tests/test_tasks.py +107 -121
- experimaestro/tests/test_token_locking.py +252 -0
- experimaestro/tests/test_tokens.py +17 -13
- experimaestro/tests/test_types.py +123 -1
- experimaestro/tests/test_workspace_triggers.py +158 -0
- experimaestro/tests/token_reschedule.py +4 -2
- experimaestro/tests/utils.py +2 -2
- experimaestro/tokens.py +154 -57
- experimaestro/tools/diff.py +1 -1
- experimaestro/tui/__init__.py +8 -0
- experimaestro/tui/app.py +2395 -0
- experimaestro/tui/app.tcss +353 -0
- experimaestro/tui/log_viewer.py +228 -0
- experimaestro/utils/__init__.py +23 -0
- experimaestro/utils/environment.py +148 -0
- experimaestro/utils/git.py +129 -0
- experimaestro/utils/resources.py +1 -1
- experimaestro/version.py +34 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/METADATA +68 -38
- experimaestro-2.0.0b8.dist-info/RECORD +187 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/WHEEL +1 -1
- experimaestro-2.0.0b8.dist-info/entry_points.txt +16 -0
- experimaestro/compat.py +0 -6
- experimaestro/core/objects.pyi +0 -221
- experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
- experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
- experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
- experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
- experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
- experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
- experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
- experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
- experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
- experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
- experimaestro-2.0.0a8.dist-info/RECORD +0 -166
- experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,434 @@
|
|
|
1
|
+
"""Tests for workspace-level database models"""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from experimaestro.scheduler.state_db import (
|
|
6
|
+
ExperimentModel,
|
|
7
|
+
ExperimentRunModel,
|
|
8
|
+
JobModel,
|
|
9
|
+
JobTagModel,
|
|
10
|
+
ServiceModel,
|
|
11
|
+
WorkspaceSyncMetadata,
|
|
12
|
+
initialize_workspace_database,
|
|
13
|
+
close_workspace_database,
|
|
14
|
+
ALL_MODELS,
|
|
15
|
+
)
|
|
16
|
+
from experimaestro.scheduler.state_sync import sync_workspace_from_disk
|
|
17
|
+
from experimaestro import Task, Param
|
|
18
|
+
from experimaestro.tests.utils import TemporaryExperiment
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def test_database_initialization(tmp_path: Path):
    """A freshly initialized workspace database has all of its tables and
    the seeded "workspace" sync-metadata row."""
    db, _ = initialize_workspace_database(tmp_path / "workspace.db", read_only=False)

    # Every model must be backed by a table once initialization returns
    for model in (
        ExperimentModel,
        ExperimentRunModel,
        JobModel,
        JobTagModel,
        ServiceModel,
        WorkspaceSyncMetadata,
    ):
        assert model.table_exists()

    # Initialization must also create the "workspace" metadata row; the
    # expected sync interval for a new database is 5 minutes
    sync_meta = WorkspaceSyncMetadata.get_or_none(
        WorkspaceSyncMetadata.id == "workspace"
    )
    assert sync_meta is not None
    assert sync_meta.sync_interval_minutes == 5

    close_workspace_database(db)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_experiment_and_run_models(tmp_path: Path):
    """An experiment can be created, given a run, and pointed at that run."""
    db, _ = initialize_workspace_database(tmp_path / "workspace.db", read_only=False)

    exp_id, run_id = "test_exp", "run_001"

    with db.bind_ctx(ALL_MODELS):
        ExperimentModel.create(experiment_id=exp_id)
        ExperimentRunModel.create(experiment_id=exp_id, run_id=run_id, status="active")

        # Make the new run the experiment's current run
        (
            ExperimentModel.update(current_run_id=run_id)
            .where(ExperimentModel.experiment_id == exp_id)
            .execute()
        )

        experiment = ExperimentModel.get(ExperimentModel.experiment_id == exp_id)
        assert experiment.current_run_id == run_id

        run = ExperimentRunModel.get(
            (ExperimentRunModel.experiment_id == exp_id)
            & (ExperimentRunModel.run_id == run_id)
        )
        assert run.status == "active"

    close_workspace_database(db)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_job_model_with_composite_key(tmp_path: Path):
    """Test job model with composite primary key (job_id, experiment_id, run_id)

    The same ``job_id`` is stored once per run of an experiment. The two rows
    must be independent: an UPDATE keyed on the full triple has to touch
    exactly one of them.
    """
    db_path = tmp_path / "workspace.db"
    db, _ = initialize_workspace_database(db_path, read_only=False)

    def job_key(run_id):
        # Predicate selecting job_abc of exp1 in the given run only
        return (
            (JobModel.job_id == "job_abc")
            & (JobModel.experiment_id == "exp1")
            & (JobModel.run_id == run_id)
        )

    with db.bind_ctx(ALL_MODELS):
        # Create experiment and run first
        ExperimentModel.create(experiment_id="exp1", current_run_id="run1")
        ExperimentRunModel.create(experiment_id="exp1", run_id="run1")

        # Create a job
        JobModel.create(
            job_id="job_abc",
            experiment_id="exp1",
            run_id="run1",
            task_id="MyTask",
            locator="",
            state="running",
            submitted_time=1234567890.0,
        )

        # Same job in different run
        ExperimentRunModel.create(experiment_id="exp1", run_id="run2")
        JobModel.create(
            job_id="job_abc",
            experiment_id="exp1",
            run_id="run2",
            task_id="MyTask",
            locator="",
            state="done",
            submitted_time=1234567891.0,
        )

        # Both should exist independently
        jobs = list(JobModel.select().where(JobModel.job_id == "job_abc"))
        assert len(jobs) == 2

        # Update only the run1 row. The run2 row is already "done", so
        # comparing states alone could never reveal an UPDATE that leaked
        # into run2; the affected-row count does.
        updated = JobModel.update(state="done").where(job_key("run1")).execute()
        assert updated == 1

        assert JobModel.get(job_key("run1")).state == "done"
        # The run2 row is still there as its own record
        assert JobModel.get(job_key("run2")).state == "done"

    close_workspace_database(db)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def test_job_tags_model(tmp_path: Path):
    """Job tags are scoped per run: the same job in two runs keeps separate
    tags (fixes GH #128)."""
    db, _ = initialize_workspace_database(tmp_path / "workspace.db", read_only=False)

    # run id -> value of the "env" tag attached to job1 in that run
    env_by_run = {"run1": "production", "run2": "testing"}

    with db.bind_ctx(ALL_MODELS):
        ExperimentModel.create(experiment_id="exp1")
        for run_id in env_by_run:
            ExperimentRunModel.create(experiment_id="exp1", run_id=run_id)

        # One job id, registered in each run
        for run_id in env_by_run:
            JobModel.create(
                job_id="job1",
                experiment_id="exp1",
                run_id=run_id,
                task_id="Task",
                locator="",
                state="done",
            )

        # A different "env" value for the job in each run
        for run_id, env in env_by_run.items():
            JobTagModel.create(
                job_id="job1",
                experiment_id="exp1",
                run_id=run_id,
                tag_key="env",
                tag_value=env,
            )

        # Each run sees exactly its own tag
        for run_id, env in env_by_run.items():
            tags = list(
                JobTagModel.select().where(
                    (JobTagModel.job_id == "job1")
                    & (JobTagModel.experiment_id == "exp1")
                    & (JobTagModel.run_id == run_id)
                )
            )
            assert len(tags) == 1
            assert tags[0].tag_value == env

    close_workspace_database(db)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_multiple_experiments_same_workspace(tmp_path: Path):
    """Two experiments share one workspace database without mixing their jobs."""
    db, _ = initialize_workspace_database(tmp_path / "workspace.db", read_only=False)

    # experiment id -> (job id, job state) of its single job in run1
    layout = {"exp1": ("job1", "done"), "exp2": ("job2", "running")}

    with db.bind_ctx(ALL_MODELS):
        for exp_id in layout:
            ExperimentModel.create(experiment_id=exp_id)
        for exp_id in layout:
            ExperimentRunModel.create(experiment_id=exp_id, run_id="run1")

        for exp_id, (job_id, state) in layout.items():
            JobModel.create(
                job_id=job_id,
                experiment_id=exp_id,
                run_id="run1",
                task_id="Task",
                locator="",
                state=state,
            )

        # Filtering by experiment returns only that experiment's job
        for exp_id, (job_id, _state) in layout.items():
            jobs = list(JobModel.select().where(JobModel.experiment_id == exp_id))
            assert len(jobs) == 1
            assert jobs[0].job_id == job_id

    close_workspace_database(db)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def test_read_only_mode(tmp_path: Path):
    """A database reopened with ``read_only=True`` serves reads but rejects
    writes."""
    db_path = tmp_path / "workspace.db"

    # Seed one experiment through a writable connection, then release it
    writer, _ = initialize_workspace_database(db_path, read_only=False)
    with writer.bind_ctx(ALL_MODELS):
        ExperimentModel.create(experiment_id="exp1")
    close_workspace_database(writer)

    reader, _ = initialize_workspace_database(db_path, read_only=True)
    with reader.bind_ctx(ALL_MODELS):
        # Reading the seeded row works
        seeded = ExperimentModel.get(ExperimentModel.experiment_id == "exp1")
        assert seeded.experiment_id == "exp1"

        # Writing must fail. The exception type depends on how read-only
        # mode is enforced (presumably an OperationalError from SQLite), so
        # any exception is accepted here.
        with pytest.raises(Exception):
            ExperimentModel.create(experiment_id="exp2", workdir_path="/tmp/exp2")

    close_workspace_database(reader)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def test_upsert_on_conflict(tmp_path: Path):
    """Re-inserting an existing (job_id, experiment_id, run_id) with
    ``on_conflict`` updates the row in place instead of duplicating it."""
    db, _ = initialize_workspace_database(tmp_path / "workspace.db", read_only=False)

    # Columns shared by both insertions; only ``state`` differs
    job_row = dict(
        job_id="job1",
        experiment_id="exp1",
        run_id="run1",
        task_id="Task",
        locator="",
    )

    with db.bind_ctx(ALL_MODELS):
        ExperimentModel.create(experiment_id="exp1")
        ExperimentRunModel.create(experiment_id="exp1", run_id="run1")

        # Initial insertion: the job is running
        JobModel.insert(**job_row, state="running").execute()

        # Second insertion collides on the composite key; the disk state wins
        JobModel.insert(**job_row, state="done").on_conflict(
            conflict_target=[JobModel.job_id, JobModel.experiment_id, JobModel.run_id],
            update={JobModel.state: "done"},
        ).execute()

        stored = JobModel.get(
            (JobModel.job_id == "job1")
            & (JobModel.experiment_id == "exp1")
            & (JobModel.run_id == "run1")
        )
        assert stored.state == "done"

        # The upsert must not have produced a second row
        assert JobModel.select().where(JobModel.job_id == "job1").count() == 1

    close_workspace_database(db)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# Define test task
|
|
320
|
+
class SimpleTask(Task):
    # Minimal task used by the recovery test; it records its ``value``
    value: Param[int]

    def execute(self):
        # Completion marker in the task directory, e.g. "value=42"
        marker = self.__taskdir__ / "output.txt"
        marker.write_text(f"value={self.value}")
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def test_database_recovery_from_disk(tmp_path: Path):
    """Test recovering database from jobs.jsonl and disk state

    Runs an experiment with two tagged tasks, records the jobs and tags found
    in the workspace database, deletes the database file, re-syncs from disk
    and checks that the same jobs and tags come back.
    """
    from experimaestro.scheduler.state_provider import WorkspaceStateProvider

    def read_job_tags(job):
        # Tags of ``job`` restricted to its own (experiment, run)
        return {
            tag.tag_key: tag.tag_value
            for tag in JobTagModel.select().where(
                (JobTagModel.job_id == job.job_id)
                & (JobTagModel.experiment_id == job.experiment_id)
                & (JobTagModel.run_id == job.run_id)
            )
        }

    # Step 1: Run experiment with tasks and tags
    workdir = tmp_path / "workspace"
    workdir.mkdir()

    with TemporaryExperiment("recovery", maxwait=0, workdir=workdir):
        SimpleTask.C(value=42).tag("priority", "high").tag("env", "test").submit()
        SimpleTask.C(value=100).tag("priority", "low").tag("env", "prod").submit()

    # Step 2: Verify jobs.jsonl and the workspace database were created
    jobs_jsonl_path = workdir / "xp" / "recovery" / "jobs.jsonl"
    assert jobs_jsonl_path.exists(), "jobs.jsonl should have been created"

    workspace_db_path = workdir / ".experimaestro" / "workspace.db"
    assert workspace_db_path.exists()

    provider = WorkspaceStateProvider.get_instance(
        workdir, read_only=False, sync_on_start=False
    )
    try:
        with provider.workspace_db.bind_ctx(ALL_MODELS):
            original_jobs = list(JobModel.select())

            # If no jobs in DB yet, sync from disk first
            if not original_jobs:
                sync_workspace_from_disk(workdir, write_mode=True, force=True)
                original_jobs = list(JobModel.select())

            assert len(original_jobs) == 2
            original_job_ids = {job.job_id for job in original_jobs}
            assert len(original_job_ids) == 2

            original_tags = {job.job_id: read_job_tags(job) for job in original_jobs}
            assert len(original_tags) == 2
            # Both jobs carry both tag keys
            for tags in original_tags.values():
                assert "priority" in tags
                assert "env" in tags
    finally:
        # Close even when an assertion above fails
        provider.close()

    # Step 3: Delete the database
    workspace_db_path.unlink()
    assert not workspace_db_path.exists()

    # Step 4: Recover from disk by syncing, using a new provider instance
    provider2 = WorkspaceStateProvider.get_instance(
        workdir, read_only=False, sync_on_start=False
    )
    try:
        sync_workspace_from_disk(
            workdir, write_mode=True, force=True, sync_interval_minutes=0
        )

        # Step 5: Verify recovered state matches original
        with provider2.workspace_db.bind_ctx(ALL_MODELS):
            recovered_jobs = list(JobModel.select())
            assert len(recovered_jobs) == 2
            assert {job.job_id for job in recovered_jobs} == original_job_ids

            recovered_tags = {
                job.job_id: read_job_tags(job) for job in recovered_jobs
            }
            assert len(recovered_tags) == 2

            for job_id in original_job_ids:
                assert job_id in recovered_tags
                assert recovered_tags[job_id] == original_tags[job_id]
    finally:
        # provider2 was previously never closed. Release it so its database
        # connection (and any instance get_instance may keep around) does
        # not leak into later tests.
        provider2.close()
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Tests for subparameters (partial identifier computation)"""
|
|
2
|
+
|
|
3
|
+
from experimaestro import (
|
|
4
|
+
Config,
|
|
5
|
+
Task,
|
|
6
|
+
Param,
|
|
7
|
+
field,
|
|
8
|
+
subparameters,
|
|
9
|
+
param_group,
|
|
10
|
+
ParameterGroup,
|
|
11
|
+
Subparameters,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Parameter groups shared by every test below, created once at import time
iter_group, model_group = param_group("iter"), param_group("model")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestSubparametersBasic:
    """Behaviour of ``param_group`` and ``subparameters`` on their own"""

    def test_param_group_creation(self):
        """``param_group`` returns a ParameterGroup carrying the given name"""
        created = param_group("test")
        assert isinstance(created, ParameterGroup)
        assert created.name == "test"

    def test_param_group_hashable(self):
        """Groups compare and hash by name (value equality)"""
        first = param_group("test")
        second = param_group("test")
        # Equal names give equal groups, which collapse inside a set
        assert first == second
        assert len({first, second}) == 1

        # A different name gives a distinct group
        other = param_group("other")
        assert first != other
        assert len({first, other}) == 2

    def test_subparameters_creation(self):
        """``subparameters`` builds a Subparameters holding the excluded groups"""
        spec = subparameters(exclude_groups=[iter_group])
        assert isinstance(spec, Subparameters)
        assert iter_group in spec.exclude_groups

    def test_subparameters_is_excluded(self):
        """``is_excluded`` flags group sets containing an excluded group"""
        spec = subparameters(exclude_groups=[iter_group])
        cases = [({iter_group}, True), ({model_group}, False), (set(), False)]
        for groups, excluded in cases:
            assert spec.is_excluded(groups) is excluded

    def test_subparameters_include_overrides_exclude(self):
        """``include_groups`` takes precedence over ``exclude_groups``"""
        spec = subparameters(
            exclude_groups=[iter_group, model_group], include_groups=[iter_group]
        )
        # iter_group is listed in both, and inclusion wins
        assert spec.is_excluded({iter_group}) is False
        assert spec.is_excluded({model_group}) is True

    def test_subparameters_exclude_all(self):
        """``exclude_all`` drops everything that is not explicitly included"""
        spec = subparameters(exclude_all=True, include_groups=[model_group])
        cases = [({iter_group}, True), ({model_group}, False), (set(), True)]
        for groups, excluded in cases:
            assert spec.is_excluded(groups) is excluded

    def test_subparameters_exclude_no_group(self):
        """``exclude_no_group`` excludes group-less parameters, keeps grouped ones"""
        spec = subparameters(exclude_no_group=True)
        assert spec.is_excluded(set()) is True
        assert spec.is_excluded({iter_group}) is False
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class TestPartialIdentifiers:
    """Identifiers restricted by a ``subparameters`` specification"""

    def test_field_groups(self):
        """``field(groups=...)`` is recorded on the type's Argument"""

        class MyConfig(Config):
            x: Param[int] = field(groups=[iter_group])
            y: Param[float]

        objecttype = MyConfig.__getxpmtype__()
        objecttype.__initialize__()

        # Only x was declared with a group
        assert iter_group in objecttype.arguments["x"].groups
        assert len(objecttype.arguments["y"].groups) == 0

    def test_subparameters_collected_in_objecttype(self):
        """Subparameters declared on a task are gathered by its ObjectType"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            x: Param[int]

        objecttype = MyTask.__getxpmtype__()
        objecttype.__initialize__()

        collected = objecttype._subparameters
        assert "checkpoints" in collected
        # The attribute name becomes the subparameters' name
        assert collected["checkpoints"].name == "checkpoints"

    def test_partial_identifier_same_when_excluded_differs(self):
        """Changing only an excluded parameter keeps the partial identifier"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            max_iter: Param[int] = field(groups=[iter_group])
            learning_rate: Param[float]

        short_run = MyTask.C(max_iter=100, learning_rate=0.1)
        long_run = MyTask.C(max_iter=200, learning_rate=0.1)

        # The full identifiers still tell the two configurations apart...
        assert short_run.__xpm__.identifier != long_run.__xpm__.identifier

        # ...but max_iter belongs to the excluded group, so the partial
        # identifiers coincide
        partial = [
            cfg.__xpm__.get_partial_identifier(MyTask.checkpoints)
            for cfg in (short_run, long_run)
        ]
        assert partial[0] == partial[1]

    def test_partial_identifier_differs_when_included_differs(self):
        """Changing a non-excluded parameter changes the partial identifier"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            max_iter: Param[int] = field(groups=[iter_group])
            learning_rate: Param[float]

        slow = MyTask.C(max_iter=100, learning_rate=0.1)
        fast = MyTask.C(max_iter=100, learning_rate=0.2)

        # learning_rate has no group, so it stays part of the partial identifier
        pid_slow, pid_fast = (
            cfg.__xpm__.get_partial_identifier(MyTask.checkpoints)
            for cfg in (slow, fast)
        )
        assert pid_slow != pid_fast

    def test_partial_identifier_with_multiple_groups(self):
        """A parameter is excluded when one of its several groups is excluded"""

        class MyTask(Task):
            checkpoints = subparameters(exclude_groups=[iter_group])
            # x belongs to both groups; only iter_group is excluded
            x: Param[int] = field(groups=[iter_group, model_group])
            y: Param[float]

        first = MyTask.C(x=1, y=0.1)
        second = MyTask.C(x=2, y=0.1)

        # x sits in the excluded iter_group, so changing it does not affect
        # the partial identifier
        pid_first = first.__xpm__.get_partial_identifier(MyTask.checkpoints)
        pid_second = second.__xpm__.get_partial_identifier(MyTask.checkpoints)
        assert pid_first == pid_second
|