experimaestro 1.11.1__py3-none-any.whl → 2.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of experimaestro might be problematic.
- experimaestro/__init__.py +10 -11
- experimaestro/annotations.py +167 -206
- experimaestro/cli/__init__.py +140 -16
- experimaestro/cli/filter.py +42 -74
- experimaestro/cli/jobs.py +157 -106
- experimaestro/cli/progress.py +269 -0
- experimaestro/cli/refactor.py +249 -0
- experimaestro/click.py +0 -1
- experimaestro/commandline.py +19 -3
- experimaestro/connectors/__init__.py +22 -3
- experimaestro/connectors/local.py +12 -0
- experimaestro/core/arguments.py +192 -37
- experimaestro/core/identifier.py +127 -12
- experimaestro/core/objects/__init__.py +6 -0
- experimaestro/core/objects/config.py +702 -285
- experimaestro/core/objects/config_walk.py +24 -6
- experimaestro/core/serialization.py +91 -34
- experimaestro/core/serializers.py +1 -8
- experimaestro/core/subparameters.py +164 -0
- experimaestro/core/types.py +198 -83
- experimaestro/exceptions.py +26 -0
- experimaestro/experiments/cli.py +107 -25
- experimaestro/generators.py +50 -9
- experimaestro/huggingface.py +3 -1
- experimaestro/launcherfinder/parser.py +29 -0
- experimaestro/launcherfinder/registry.py +3 -3
- experimaestro/launchers/__init__.py +26 -1
- experimaestro/launchers/direct.py +12 -0
- experimaestro/launchers/slurm/base.py +154 -2
- experimaestro/mkdocs/base.py +6 -8
- experimaestro/mkdocs/metaloader.py +0 -1
- experimaestro/mypy.py +452 -7
- experimaestro/notifications.py +75 -16
- experimaestro/progress.py +404 -0
- experimaestro/rpyc.py +0 -1
- experimaestro/run.py +19 -6
- experimaestro/scheduler/__init__.py +18 -1
- experimaestro/scheduler/base.py +504 -959
- experimaestro/scheduler/dependencies.py +43 -28
- experimaestro/scheduler/dynamic_outputs.py +259 -130
- experimaestro/scheduler/experiment.py +582 -0
- experimaestro/scheduler/interfaces.py +474 -0
- experimaestro/scheduler/jobs.py +485 -0
- experimaestro/scheduler/services.py +186 -12
- experimaestro/scheduler/signal_handler.py +32 -0
- experimaestro/scheduler/state.py +1 -1
- experimaestro/scheduler/state_db.py +388 -0
- experimaestro/scheduler/state_provider.py +2345 -0
- experimaestro/scheduler/state_sync.py +834 -0
- experimaestro/scheduler/workspace.py +52 -10
- experimaestro/scriptbuilder.py +7 -0
- experimaestro/server/__init__.py +153 -32
- experimaestro/server/data/index.css +0 -125
- experimaestro/server/data/index.css.map +1 -1
- experimaestro/server/data/index.js +194 -58
- experimaestro/server/data/index.js.map +1 -1
- experimaestro/settings.py +47 -6
- experimaestro/sphinx/__init__.py +3 -3
- experimaestro/taskglobals.py +20 -0
- experimaestro/tests/conftest.py +80 -0
- experimaestro/tests/core/test_generics.py +2 -2
- experimaestro/tests/identifier_stability.json +45 -0
- experimaestro/tests/launchers/bin/sacct +6 -2
- experimaestro/tests/launchers/bin/sbatch +4 -2
- experimaestro/tests/launchers/common.py +2 -2
- experimaestro/tests/launchers/test_slurm.py +80 -0
- experimaestro/tests/restart.py +1 -1
- experimaestro/tests/tasks/all.py +7 -0
- experimaestro/tests/tasks/test_dynamic.py +231 -0
- experimaestro/tests/test_checkers.py +2 -2
- experimaestro/tests/test_cli_jobs.py +615 -0
- experimaestro/tests/test_dependencies.py +11 -17
- experimaestro/tests/test_deprecated.py +630 -0
- experimaestro/tests/test_environment.py +200 -0
- experimaestro/tests/test_experiment.py +3 -3
- experimaestro/tests/test_file_progress.py +425 -0
- experimaestro/tests/test_file_progress_integration.py +477 -0
- experimaestro/tests/test_forward.py +3 -3
- experimaestro/tests/test_generators.py +93 -0
- experimaestro/tests/test_identifier.py +520 -169
- experimaestro/tests/test_identifier_stability.py +458 -0
- experimaestro/tests/test_instance.py +16 -21
- experimaestro/tests/test_multitoken.py +442 -0
- experimaestro/tests/test_mypy.py +433 -0
- experimaestro/tests/test_objects.py +314 -30
- experimaestro/tests/test_outputs.py +8 -8
- experimaestro/tests/test_param.py +22 -26
- experimaestro/tests/test_partial_paths.py +231 -0
- experimaestro/tests/test_progress.py +2 -50
- experimaestro/tests/test_resumable_task.py +480 -0
- experimaestro/tests/test_serializers.py +141 -60
- experimaestro/tests/test_state_db.py +434 -0
- experimaestro/tests/test_subparameters.py +160 -0
- experimaestro/tests/test_tags.py +151 -15
- experimaestro/tests/test_tasks.py +137 -160
- experimaestro/tests/test_token_locking.py +252 -0
- experimaestro/tests/test_tokens.py +25 -19
- experimaestro/tests/test_types.py +133 -11
- experimaestro/tests/test_validation.py +19 -19
- experimaestro/tests/test_workspace_triggers.py +158 -0
- experimaestro/tests/token_reschedule.py +5 -3
- experimaestro/tests/utils.py +2 -2
- experimaestro/tokens.py +154 -57
- experimaestro/tools/diff.py +8 -1
- experimaestro/tui/__init__.py +8 -0
- experimaestro/tui/app.py +2303 -0
- experimaestro/tui/app.tcss +353 -0
- experimaestro/tui/log_viewer.py +228 -0
- experimaestro/typingutils.py +11 -2
- experimaestro/utils/__init__.py +23 -0
- experimaestro/utils/environment.py +148 -0
- experimaestro/utils/git.py +129 -0
- experimaestro/utils/resources.py +1 -1
- experimaestro/version.py +34 -0
- {experimaestro-1.11.1.dist-info → experimaestro-2.0.0b4.dist-info}/METADATA +70 -39
- experimaestro-2.0.0b4.dist-info/RECORD +181 -0
- {experimaestro-1.11.1.dist-info → experimaestro-2.0.0b4.dist-info}/WHEEL +1 -1
- experimaestro-2.0.0b4.dist-info/entry_points.txt +16 -0
- experimaestro/compat.py +0 -6
- experimaestro/core/objects.pyi +0 -225
- experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
- experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
- experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
- experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
- experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
- experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
- experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
- experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
- experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
- experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
- experimaestro-1.11.1.dist-info/RECORD +0 -158
- experimaestro-1.11.1.dist-info/entry_points.txt +0 -17
- {experimaestro-1.11.1.dist-info → experimaestro-2.0.0b4.dist-info/licenses}/LICENSE +0 -0
@@ -9,7 +9,6 @@ from experimaestro import taskglobals

 from termcolor import cprint
 from pathlib import Path
-import hashlib
 import logging
 import io
 from enum import Enum
@@ -20,12 +19,10 @@ from typing import (
     Callable,
     ClassVar,
     Dict,
-    Iterator,
     List,
     Optional,
     Set,
     Tuple,
-    Type,
     TypeVar,
     Union,
     overload,
@@ -33,13 +30,14 @@ from typing import (
 )
 import sys
 import experimaestro
-from experimaestro.utils import logger
+from experimaestro.utils import logger, get_caller_location
 from experimaestro.core.types import DeprecatedAttribute, ObjectType, TypeVarType
 from ..context import SerializationContext, SerializedPath, SerializedPathLoader

 if TYPE_CHECKING:
     from ..callbacks import TaskEventListener
     from ..identifier import Identifier
+    from ..subparameters import Subparameters
     from experimaestro.scheduler.base import Job
     from experimaestro.scheduler.workspace import RunMode
     from experimaestro.launchers import Launcher
@@ -49,7 +47,6 @@ from .config_walk import ConfigWalk, ConfigWalkContext
 from .config_utils import (
     getqualattr,
     add_to_path,
-    SealedError,
     TaggedValue,
     ObjectStore,
     classproperty,
@@ -59,6 +56,26 @@ T = TypeVar("T", bound="Config")


 DependentMarker = Callable[["Config"], None]
+"""Type alias for dependency marker functions.
+
+A DependentMarker is a callable that marks a configuration as a dependency
+of another configuration. Used in ``task_outputs()`` and dynamic output methods
+to establish task dependencies.
+
+Example::
+
+    class Learn(Task):
+        model: Param[Model]
+
+        def task_outputs(self, dep: DependentMarker):
+            return dep(Checkpoint.C(model=self.model, path=self.checkpoint_path))
+
+    class Validation(Config):
+        model: Param[Model]
+
+        def checkpoint(self, dep: DependentMarker, *, step: int) -> Checkpoint:
+            return dep(Checkpoint.C(model=self.model, step=step))
+"""


 def updatedependencies(
@@ -94,9 +111,6 @@ NOT_SET = object()

 @define()
 class WatchedOutput:
-    #: The enclosing job
-    job: "Job"
-
     #: The configuration containing the watched output
     config: "ConfigInformation"

@@ -109,6 +123,61 @@ class WatchedOutput:
     #: The callback to call (with the output of the previous method)
     callback: Callable

+    #: The enclosing job (set when registered with scheduler)
+    job: Optional["Job"] = None
+
+
+def get_generated_paths(
+    v: Union["ConfigMixin", list, dict],
+    path: list[str] | None = None,
+    paths: list[str] | None = None,
+) -> list[str]:
+    """Get the list of generated paths, useful to track down those
+
+    :param path: The current path
+    :param paths: The list of generated paths so far, defaults to None
+    :return: The full list of generated paths
+    """
+    paths = [] if paths is None else paths
+    path = [] if path is None else path
+
+    if isinstance(v, list):
+        for ix, element in enumerate(v):
+            get_generated_paths(element, path + [f"[{ix}]"], paths)
+
+    elif isinstance(v, dict):
+        for key, element in v.items():
+            get_generated_paths(element, path + [f"[{key}]"], paths)
+
+    elif isinstance(v, ConfigMixin):
+        for key in v.__xpm__._generated_values:
+            value = v.__xpm__.values[key]
+            if isinstance(value, ConfigMixin) and value.__xpm__._generated_values:
+                path.append(key)
+                get_generated_paths(value, path, paths)
+                path.pop()
+            else:
+                paths.append(".".join(path + [key]))
+    else:
+        raise ValueError(f"Cannot handle type {type(v)}")
+    return paths
+
+
+@define
+class TaskStub:
+    """Stub for a task that was not loaded during partial loading.
+
+    This is used when loading configurations from disk (e.g., HuggingFace)
+    where the task code may have changed or is not available. The stub stores
+    the identifier and typename so the information is preserved.
+    """
+
+    identifier: "Identifier"
+    """The experimaestro identifier of the task"""
+
+    typename: str
+    """The type name of the task (e.g., 'mymodule.MyTask')"""
+

 class ConfigInformation:
     """Holds experimaestro information for a config (or task) instance"""
@@ -122,11 +191,13 @@ class ConfigInformation:
     def __init__(self, pyobject: "ConfigMixin"):
         # The underlying pyobject and XPM type
         self.pyobject = pyobject
-        self.xpmtype = pyobject.__xpmtype__
+        self.xpmtype: "ObjectType" = pyobject.__xpmtype__
         self.values = {}

         # Meta-informations
-
+        # Tags are stored as {name: (value, source_location)}
+        # where source_location is "file:line" string for error reporting
+        self._tags: dict[str, tuple[Any, str]] = {}
         self._initinfo = ""

         self._taskoutput = None
@@ -142,16 +213,13 @@ class ConfigInformation:
         #: True when this configuration was loaded from disk
         self.loaded = False

-        #
+        # Explicitly added dependencies
         self.dependencies = []

         # Concrete type variables resolutions
         # This is used to check typevars coherence
         self.concrete_typevars: Dict[TypeVar, type] = {}

-        # Lightweight tasks
-        self.pre_tasks: List["LightweightTask"] = []
-
         # Initialization tasks
         self.init_tasks: List["LightweightTask"] = []

@@ -160,16 +228,21 @@ class ConfigInformation:

         # Cached information

-        self.
-        """The
+        self._identifier = None
+        """The configuration identifier (cached when sealed)"""

-        self.
-        """
+        self._partial_identifiers: Dict[str, "Identifier"] = {}
+        """Cached partial identifiers (keyed by subparameters name)"""

         self._validated = False
         self._sealed = False
         self._meta = None

+        # This contains the list of generated values (using context) in this
+        # configuration or any sub-configuration, is generated. This prevents
+        # problem when a configuration with generated values is re-used.
+        self._generated_values = []
+
     def set_meta(self, value: Optional[bool]):
         """Sets the meta flag"""
         assert not self._sealed, "Configuration is sealed"
@@ -187,6 +260,31 @@ class ConfigInformation:
             # Not an argument, bypass
             return object.__getattribute__(self.pyobject, name)

+    @staticmethod
+    def is_generated_value(argument, value):
+        if argument.ignore_generated:
+            return False
+
+        if value is None:
+            return False
+
+        if isinstance(value, (int, str, float, bool, Enum, Path)):
+            return False
+
+        if isinstance(value, ConfigMixin):
+            return value.__xpm__._generated_values and value.__xpm__.task is None
+
+        if isinstance(value, list):
+            return any(ConfigInformation.is_generated_value(argument, x) for x in value)
+
+        if isinstance(value, dict):
+            return any(
+                ConfigInformation.is_generated_value(argument, x)
+                for x in value.values()
+            )
+
+        return False
+
     def set(self, k, v, bypass=False):
         from experimaestro.generators import Generator

@@ -198,9 +296,21 @@ class ConfigInformation:
         if self._sealed and not bypass:
             raise AttributeError(f"Object is read-only (trying to set {k})")

+        if not isinstance(v, ConfigMixin) and isinstance(v, Config):
+            raise AttributeError(
+                "Configuration (and not objects) should be used. Consider using .C(...)"
+            )
+
         try:
             argument = self.xpmtype.arguments.get(k, None)
             if argument:
+                if ConfigInformation.is_generated_value(argument, v):
+                    raise AttributeError(
+                        f"Cannot set {k} to a configuration with generated values. "
+                        "Here is the list of paths to help you: "
+                        f"""{', '.join(get_generated_paths(v, [k]))}"""
+                    )
+
                 if not bypass and (
                     (isinstance(argument.generator, Generator)) or argument.constant
                 ):
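The guard added above rejects plain (non-configuration) instances as parameter values. A minimal, hypothetical sketch of what this means in user code — the `Model` and `Learner` classes are illustrative only; the `.C(...)` constructor is the `C` alias introduced later in this file:

    from experimaestro import Config, Param

    class Model(Config):
        hidden: Param[int] = 64

    class Learner(Config):
        model: Param[Model]

    # Building the graph with configurations is accepted
    learner = Learner.C(model=Model.C(hidden=128))

    # Passing a plain (non-configuration) object as a parameter now raises
    # AttributeError: "Configuration (and not objects) should be used.
    # Consider using .C(...)"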
@@ -263,8 +373,17 @@ class ConfigInformation:
                 f" (current typevars bindings: {self.concrete_typevars})"
             )

-    def addtag(self, name, value):
-
+    def addtag(self, name, value, source: str = None):
+        """Add a tag with optional source location for error reporting
+
+        Args:
+            name: Tag name
+            value: Tag value
+            source: Source location string (file:line). If None, captured from caller.
+        """
+        if source is None:
+            source = get_caller_location(skip_frames=1)
+        self._tags[name] = (value, source)

     def xpmvalues(self, generated=False):
         """Returns an iterarator over arguments and associated values"""
@@ -276,11 +395,29 @@ class ConfigInformation:
         class TagFinder(ConfigWalk):
             def __init__(self):
                 super().__init__(recurse_task=True)
-
+                # Store {name: (value, source)} for conflict detection
+                self.tags_with_source: dict[str, tuple[Any, str]] = {}

             def postprocess(self, stub, config: Config, values):
-
-
+                for name, (value, source) in config.__xpm__._tags.items():
+                    if name in self.tags_with_source:
+                        existing_value, existing_source = self.tags_with_source[name]
+                        if existing_value != value:
+                            logger.warning(
+                                "Tag '%s' has conflicting values: "
+                                "'%s' (set at %s) vs '%s' (set at %s). "
+                                "Using the latter value.",
+                                name,
+                                existing_value,
+                                existing_source,
+                                value,
+                                source,
+                            )
+                    self.tags_with_source[name] = (value, source)
+                # Return just the values (without source info)
+                return {
+                    name: value for name, (value, _) in self.tags_with_source.items()
+                }

         return TagFinder()(self.pyobject)

@@ -302,10 +439,6 @@ class ConfigInformation:
                     % (k, self.xpmtype, self._initinfo)
                 )

-        # Validate pre-tasks
-        for pre_task in self.pre_tasks:
-            pre_task.__xpm__.validate()
-
         # Validate init tasks
         for init_task in self.init_tasks:
             init_task.__xpm__.validate()
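A sketch (assumed user-facing behaviour, class names hypothetical) of the tag source tracking added above: each tag now records the file:line where it was set, and collecting tags over a configuration graph warns when the same tag carries different values:

    model = Model.C(hidden=128).tag("size", "small")
    learner = Learner.C(model=model).tag("size", "large")

    # Aggregating tags over the graph (here via the config information object,
    # assumed call) logs a warning of the form quoted in the diff:
    #   Tag 'size' has conflicting values: 'small' (set at <file:line>)
    #   vs 'large' (set at <file:line>). Using the latter value.
    tags = learner.__xpm__.tags()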
@@ -326,12 +459,21 @@ class ConfigInformation:
         Arguments:
             - context: the generation context
         """
+        if generated_keys := [
+            k
+            for k, v in self.values.items()
+            if ConfigInformation.is_generated_value(self.xpmtype.arguments[k], v)
+        ]:
+            raise AttributeError(
+                "Cannot seal a configuration with generated values:"
+                f"""{",".join(generated_keys)} in {context.currentpath}"""
+            )

         class Sealer(ConfigWalk):
-            def preprocess(self, config:
+            def preprocess(self, config: ConfigMixin):
                 return not config.__xpm__._sealed, config

-            def postprocess(self, stub, config:
+            def postprocess(self, stub, config: ConfigMixin, values):
                 # Generate values
                 from experimaestro.generators import Generator

@@ -344,22 +486,42 @@ class ConfigInformation:
                                     continue
                                 value = argument.generator()
                             else:
+                                # Generate a value
                                 sig = inspect.signature(argument.generator)
                                 if len(sig.parameters) == 0:
                                     value = argument.generator()
                                 elif len(sig.parameters) == 2:
+                                    # Only in that case do we need to flag this configuration
+                                    # as containing generated values
+                                    if not argument.ignore_generated:
+                                        config.__xpm__._generated_values.append(k)
+                                    else:
+                                        logging.warning("Ignoring %s", k)
                                     value = argument.generator(self.context, config)
                                 else:
                                     assert (
                                         False
                                     ), "generator has either two parameters (context and config), or none"
                             config.__xpm__.set(k, value, bypass=True)
+                        else:
+                            value = config.__xpm__.values.get(k)
                     except Exception:
                         logger.error(
                             "While setting %s of %s", argument.name, config.__xpmtype__
                         )
                         raise

+                    # Propagate the generated value flag
+                    if (
+                        value is not None
+                        and isinstance(value, ConfigMixin)
+                        and value.__xpm__._generated_values
+                    ):
+                        if not argument.ignore_generated:
+                            config.__xpm__._generated_values.append(k)
+                        else:
+                            logging.warning("Ignoring %s", k)
+
                 config.__xpm__._sealed = True

         Sealer(context, recurse_task=True)(self.pyobject)
@@ -372,90 +534,56 @@ class ConfigInformation:
         context = ConfigWalkContext()

         class Unsealer(ConfigWalk):
-            def preprocess(self, config:
+            def preprocess(self, config: ConfigMixin):
                 return config.__xpm__._sealed, config

-            def postprocess(self, stub, config:
+            def postprocess(self, stub, config: ConfigMixin, values):
                 config.__xpm__._sealed = False
                 config.__xpm__._identifier = None

         Unsealer(context, recurse_task=True)(self.pyobject)

-
-
-        pre_tasks: Dict[int, "Config"] = {}
-
-        class PreTaskCollect(ConfigWalk):
-            def preprocess(self, config: Config):
-                # Do not cross tasks
-                return not isinstance(config.__xpm__, Task), config
-
-            def postprocess(self, stub, config: Config, values):
-                pre_tasks.update(
-                    {id(pre_task): pre_task for pre_task in config.__xpm__.pre_tasks}
-                )
-
-        PreTaskCollect(context, recurse_task=True)(self.pyobject)
-        return pre_tasks.values()
-
-    def identifiers(self, only_raw: bool):
+    @property
+    def identifier(self):
         """Computes the unique identifier"""
-        from ..identifier import IdentifierComputer
-
-        raw_identifier = self._raw_identifier
-        full_identifier = self._full_identifier
+        from ..identifier import IdentifierComputer

         # Computes raw identifier if needed
-        if
-
-            raw_identifier = IdentifierComputer.compute(self.pyobject)
-            if self._sealed:
-                self._raw_identifier = raw_identifier
-
-        if only_raw:
-            return raw_identifier, full_identifier
-
-        # OK, let's compute the full identifier
-        if full_identifier is None or not self._sealed:
-            # Compute the full identifier by including the pre-tasks
-            hasher = hashlib.sha256()
-            hasher.update(raw_identifier.all)
-            pre_tasks_ids = [
-                pre_task.__xpm__.raw_identifier.all
-                for pre_task in self.collect_pre_tasks()
-            ]
-            for task_id in sorted(pre_tasks_ids):
-                hasher.update(task_id)
+        if self._identifier is not None:
+            return self._identifier

-
-
-
-
-
+        # Get the main identifier
+        identifier = IdentifierComputer.compute(self.pyobject)
+        if self._sealed:
+            self._identifier = identifier
+        return identifier

-
-
+    def get_partial_identifier(self, subparameters: "Subparameters") -> "Identifier":
+        """Get the partial identifier for a given subparameters instance.

-
-
-
+        Partial identifiers exclude certain parameter groups, allowing
+        configurations that differ only in those groups to share the same
+        partial identifier (and thus the same partial directory).

-
+        Args:
+            subparameters: The Subparameters instance defining which groups
+                to include/exclude.

-
-
-        """
-
-        return raw_identifier
+        Returns:
+            The partial identifier for this configuration.
+        """
+        from ..identifier import IdentifierComputer

-
-
-
-
-
+        name = subparameters.name
+        if name in self._partial_identifiers:
+            return self._partial_identifiers[name]
+
+        identifier = IdentifierComputer.compute_partial(self.pyobject, subparameters)
+
+        if self._sealed:
+            self._partial_identifiers[name] = identifier

-
-        """Deprecated: use full_identifier"""
+        return identifier

     def dependency(self):
         """Returns a dependency"""
@@ -470,12 +598,6 @@ class ConfigInformation:
         path: List[str],
         taskids: Set[int],
     ):
-        # Add pre-tasks
-        for pre_task in self.pre_tasks:
-            pre_task.__xpm__.updatedependencies(
-                dependencies, path + ["__pre_tasks__"], taskids
-            )
-
         # Add initialization tasks
         for init_task in self.init_tasks:
             init_task.__xpm__.updatedependencies(
@@ -537,10 +659,20 @@ class ConfigInformation:

         :param method: The method to watch
         :param callback: The callback
+
+        :raises TypeError: If the task is not a ResumableTask
         """
-
-
-        )
+        # Only ResumableTask can have dynamic outputs - regular tasks
+        # have their directories cleaned up, losing the output file
+        if not isinstance(self.pyobject, ResumableTask):
+            raise TypeError(
+                f"Only ResumableTask can use watch_output. "
+                f"{self.xpmtype} is not a ResumableTask. "
+                "Dynamic outputs require the task directory to be preserved "
+                "across restarts, which only ResumableTask provides."
+            )
+
+        watched = WatchedOutput(method.__self__, method.__name__, method, callback)
         self.watched_outputs.append(watched)
         if self.job:
             self.job.watch_output(watched)
@@ -561,6 +693,7 @@ class ConfigInformation:
         *,
         run_mode=None,
         init_tasks: List["LightweightTask"] = [],
+        max_retries: Optional[int] = None,
     ):
         from experimaestro.scheduler import experiment, JobContext
         from experimaestro.scheduler.workspace import RunMode
@@ -580,7 +713,11 @@ class ConfigInformation:

         # Creates a new job
         self.job = self.xpmtype.task(
-            self.pyobject,
+            self.pyobject,
+            launcher=launcher,
+            workspace=workspace,
+            run_mode=run_mode,
+            max_retries=max_retries,
         )

         # Validate the object
@@ -620,8 +757,8 @@ class ConfigInformation:
             TaskEventListener.connect(experiment.CURRENT)
             other = experiment.CURRENT.submit(self.job)
             if other:
-                #
-
+                # Our job = previously submitted job
+                self.job = other
             else:
                 # Show a warning
                 if run_mode == RunMode.GENERATE_ONLY:
@@ -657,13 +794,6 @@ class ConfigInformation:

             print(file=sys.stderr)  # noqa: T201

-        # Handle an output configuration # FIXME: remove
-        def mark_output(config: "Config"):
-            """Sets a dependency on the job"""
-            assert not isinstance(config, Task), "Cannot set a dependency on a task"
-            config.__xpm__.task = self.pyobject
-            return config
-
         # Mark this configuration also
         self.task = self.pyobject

@@ -677,6 +807,9 @@ class ConfigInformation:
     def mark_output(self, config: "Config"):
         """Sets a dependency on the job"""
         assert not isinstance(config, Task), "Cannot set a dependency on a task"
+        assert isinstance(
+            config, ConfigMixin
+        ), "Only configurations can be marked as dependent on a task"
         config.__xpm__.task = self.pyobject
         return config

@@ -752,9 +885,6 @@ class ConfigInformation:
         if self.task is not None and self.task is not self:
            ConfigInformation.__collect_objects__(self.task, objects, context)

-        # Serialize pre-tasks
-        ConfigInformation.__collect_objects__(self.pre_tasks, objects, context)
-
         # Serialize initialization tasks
         ConfigInformation.__collect_objects__(self.init_tasks, objects, context)

@@ -762,14 +892,12 @@ class ConfigInformation:
         state_dict = {
             "id": id(self.pyobject),
             "module": self.xpmtype._module,
-            "type": self.xpmtype.
+            "type": self.xpmtype.value_type.__qualname__,
             "typename": self.xpmtype.name(),
             "identifier": self.identifier.state_dict(),
         }

         # Add pre/init tasks
-        if self.pre_tasks:
-            state_dict["pre-tasks"] = [id(pre_task) for pre_task in self.pre_tasks]
         if self.init_tasks:
             state_dict["init-tasks"] = [id(init_task) for init_task in self.init_tasks]

@@ -826,7 +954,6 @@ class ConfigInformation:

         The format is an array of objects
         {
-            "tags: [ LIST_OF_TAGS ],
             "workspace": FOLDERPATH,
             "version": 2,
             "objects": [
@@ -846,6 +973,10 @@ class ConfigInformation:

         The last object is the one that is serialized

+        Note: Tags are no longer stored in params.json. They are managed by the
+        experiment state provider (scoped to job_id, experiment_id, run_id) and
+        also stored in experiment state.json for full experiment details.
+
         Arguments:
             out {io.TextIOBase} -- The output stream
             context {[type]} -- the command context
@@ -853,8 +984,8 @@ class ConfigInformation:
         json.dump(
             {
                 "workspace": str(context.workspace.path.absolute()),
-                "tags": {key: value for key, value in self.tags().items()},
                 "version": 2,
+                "experimaestro": experimaestro.__version__,
                 "objects": self.__get_objects__([], context),
             },
             out,
@@ -877,12 +1008,18 @@ class ConfigInformation:
         path: Union[str, Path, SerializedPathLoader],
         as_instance: bool = False,
         return_tasks: bool = False,
+        partial_loading: Optional[bool] = None,
     ) -> "Config":
         """Deserialize a configuration

         :param path: The filesystem Path to use, or a way to download the
             information through a function taking two arguments
         :param as_instance: Return an instance
+        :param return_tasks: Return init tasks instead of executing them
+        :param partial_loading: If True, skip loading task references. If None
+            (default), partial_loading is enabled when as_instance is True.
+            This is useful when loading configurations from disk (e.g.,
+            HuggingFace) where the task code may have changed.
         :return: a Config object, its instance or a tuple (instance, init_tasks) is return_tasks is True
         """
         # Load
@@ -907,6 +1044,7 @@ class ConfigInformation:
             as_instance=as_instance,
             data_loader=data_loader,
             return_tasks=return_tasks,
+            partial_loading=partial_loading,
         )

     @staticmethod
@@ -949,34 +1087,97 @@ class ConfigInformation:

     @overload
     @staticmethod
-    def fromParameters(
+    def fromParameters(  # noqa: E704
         definitions: List[Dict],
         as_instance=True,
         save_directory: Optional[Path] = None,
         discard_id: bool = False,
-
-
+        partial_loading: Optional[bool] = None,
+    ) -> "ConfigMixin": ...

     @overload
     @staticmethod
-    def fromParameters(
+    def fromParameters(  # noqa: E704
         definitions: List[Dict],
         as_instance=False,
         return_tasks=True,
         save_directory: Optional[Path] = None,
         discard_id: bool = False,
-
-
+        partial_loading: Optional[bool] = None,
+    ) -> Tuple["Config", List["LightweightTask"]]: ...

     @overload
     @staticmethod
-    def fromParameters(
+    def fromParameters(  # noqa: E704
         definitions: List[Dict],
         as_instance=False,
         save_directory: Optional[Path] = None,
         discard_id: bool = False,
-
-
+        partial_loading: Optional[bool] = None,
+    ) -> "Config": ...
+
+    @staticmethod
+    def _get_field_refs(value: Any) -> Set[int]:
+        """Recursively extract object references from a serialized field value"""
+        refs: Set[int] = set()
+        if isinstance(value, dict):
+            if value.get("type") == "python":
+                refs.add(value["value"])
+            else:
+                for v in value.values():
+                    refs.update(ConfigInformation._get_field_refs(v))
+        elif isinstance(value, list):
+            for v in value:
+                refs.update(ConfigInformation._get_field_refs(v))
+        return refs
+
+    @staticmethod
+    def _compute_skipped_ids(definitions: List[Dict]) -> Set[int]:
+        """Compute IDs of objects only reachable through task references.
+
+        When partial_loading is enabled, we skip loading task references and
+        any objects that are only used by those tasks. This method computes
+        which object IDs should be skipped by finding objects reachable from
+        the main object (last definition) without following task references.
+
+        :param definitions: List of object definitions from JSON
+        :return: Set of object IDs to skip loading
+        """
+        # Build index of definitions by ID
+        def_by_id = {d["id"]: d for d in definitions}
+
+        # Compute reachable objects from main object (last definition)
+        # without going through task references
+        main_defn = definitions[-1]
+        main_id = main_defn["id"]
+        reachable: Set[int] = set()
+        to_visit = [main_id]
+
+        # Also include init-tasks as reachable (needed for as_instance/return_tasks)
+        for init_task_id in main_defn.get("init-tasks", []):
+            to_visit.append(init_task_id)
+
+        while to_visit:
+            obj_id = to_visit.pop()
+            if obj_id in reachable:
+                continue
+            reachable.add(obj_id)
+
+            defn = def_by_id.get(obj_id)
+            if defn is None:
+                continue
+
+            # Add field references (not task reference)
+            for field_value in defn.get("fields", {}).values():
+                for ref_id in ConfigInformation._get_field_refs(field_value):
+                    if ref_id not in reachable:
+                        to_visit.append(ref_id)
+
+            # Note: we intentionally skip defn["task"] to avoid loading tasks
+
+        # All objects not reachable should be skipped
+        all_ids = {d["id"] for d in definitions}
+        return all_ids - reachable

     @staticmethod
     def load_objects(  # noqa: C901
@@ -984,15 +1185,43 @@ class ConfigInformation:
         as_instance=True,
         data_loader: Optional[SerializedPathLoader] = None,
         discard_id: bool = False,
+        partial_loading: bool = False,
     ):
-        """Load the objects
+        """Load the objects
+
+        :param definitions: List of object definitions from JSON
+        :param as_instance: Return instances instead of configs
+        :param data_loader: Function to load data files
+        :param discard_id: If True, don't use the stored identifier
+        :param partial_loading: If True, skip loading task references. This is
+            useful when loading configurations from disk (e.g., HuggingFace)
+            where the task code may have changed.
+        """
         o = None
         objects = {}
         import experimaestro.taskglobals as taskglobals
         from ..identifier import Identifier

+        # Compute which objects to skip when partial_loading
+        skipped_ids = (
+            ConfigInformation._compute_skipped_ids(definitions)
+            if partial_loading
+            else set()
+        )
+
         # Loop over all the definitions and create objects
         for definition in definitions:
+            obj_id = definition["id"]
+
+            # Skip objects that are only reachable through task references
+            if obj_id in skipped_ids:
+                # Create a TaskStub for skipped task objects
+                objects[obj_id] = TaskStub(
+                    identifier=Identifier.from_state_dict(definition["identifier"]),
+                    typename=definition.get("typename", definition["type"]),
+                )
+                continue
+
             module_name = definition["module"]

             # Avoids problem when runing module
@@ -1022,15 +1251,21 @@ class ConfigInformation:

             # Creates an object (or a config)
             if as_instance:
-                o = cls.
+                o = cls.__new__(cls)
             else:
                 o = cls.XPMConfig.__new__(cls.XPMConfig)
-            assert
-            objects[
+            assert obj_id not in objects, "Duplicate id %s" % obj_id
+            objects[obj_id] = o

         # Now that objects have been created, fill in the fields
         for definition in definitions:
-
+            obj_id = definition["id"]
+
+            # Skip processing skipped objects (they are TaskStubs)
+            if obj_id in skipped_ids:
+                continue
+
+            o = objects[obj_id]
             xpmtype = o.__getxpmtype__()  # type: ObjectType

             # If instance...
@@ -1101,12 +1336,6 @@ class ConfigInformation:
                 o.__post_init__()

             else:
-                # Sets pre-tasks
-                o.__xpm__.pre_tasks = [
-                    objects[pre_task_id]
-                    for pre_task_id in definition.get("pre-tasks", [])
-                ]
-
                 if task_id := definition.get("task", None):
                     o.__xpm__.task = objects[task_id]

@@ -1126,13 +1355,20 @@ class ConfigInformation:
         data_loader: Optional[SerializedPathLoader] = None,
         discard_id: bool = False,
         return_tasks: bool = False,
+        partial_loading: Optional[bool] = None,
     ):
+        # Determine effective partial_loading: as_instance implies partial_loading
+        effective_partial_loading = (
+            partial_loading if partial_loading is not None else as_instance
+        )
+
         # Get the objects
         objects = ConfigInformation.load_objects(
             definitions,
             as_instance=as_instance,
             data_loader=data_loader,
             discard_id=discard_id,
+            partial_loading=effective_partial_loading,
         )

         # Get the last one
@@ -1140,15 +1376,6 @@ class ConfigInformation:

         # Run pre-task (or returns them)
         if as_instance or return_tasks:
-            # Collect pre-tasks (just once)
-            completed_pretasks = set()
-            pre_tasks = []
-            for definition in definitions:
-                for pre_task_id in definition.get("pre-tasks", []):
-                    if pre_task_id not in completed_pretasks:
-                        completed_pretasks.add(pre_task_id)
-                        pre_tasks.append(objects[pre_task_id])
-
             # Collect init tasks
             init_tasks = []
             for init_task_id in definitions[-1].get("init-tasks", []):
@@ -1156,14 +1383,11 @@ class ConfigInformation:
                 init_tasks.append(init_task)

             if as_instance:
-                for pre_task in pre_tasks:
-                    logger.info("Executing pre-task %s", type(pre_task))
-                    pre_task.execute()
                 for init_task in init_tasks:
                     logger.info("Executing init task %s", type(init_task))
                     init_task.execute()
             else:
-                return o,
+                return o, init_tasks

         return o

@@ -1171,7 +1395,6 @@ class ConfigInformation:
         def __init__(self, context: ConfigWalkContext, *, objects: ObjectStore = None):
             super().__init__(context)
             self.objects = ObjectStore() if objects is None else objects
-            self.pre_tasks = {}

         def preprocess(self, config: "Config"):
             if self.objects.is_constructed(id(config)):
@@ -1183,7 +1406,7 @@ class ConfigInformation:

             if o is None:
                 # Creates an object (and not a config)
-                o = config.
+                o = config.__xpmtype__.value_type()

                 # Store in cache
                 self.objects.add_stub(id(config), o)
@@ -1198,10 +1421,6 @@ class ConfigInformation:
             # Call __post_init__
             stub.__post_init__()

-            # Gather pre-tasks
-            for pre_task in config.__xpm__.pre_tasks:
-                self.pre_tasks[id(pre_task)] = self.stub(pre_task)
-
             self.objects.set_constructed(id(config))
             return stub

@@ -1215,10 +1434,6 @@ class ConfigInformation:
         processor = ConfigInformation.FromPython(context, objects=objects)
         last_object = processor(self.pyobject)

-        # Execute pre-tasks
-        for pre_task in processor.pre_tasks.values():
-            pre_task.execute()
-
         return last_object

     def add_dependencies(self, *dependencies):
@@ -1242,6 +1457,9 @@ def clone(v):
     if isinstance(v, Enum):
         return v

+    if isinstance(v, tuple):
+        return tuple(clone(x) for x in v)
+
     if isinstance(v, Config):
         # Create a new instance
         kwargs = {
@@ -1260,10 +1478,104 @@ class ConfigMixin:
     """Class for configuration objects"""

     __xpmtype__: ObjectType
+    """The associated XPM type"""
+
+    __xpm__: ConfigInformation
+    """The __xpm__ object contains all instance specific information about a
+    configuration/task"""
+
+    # Set when this instance was created via a deprecated config with replace=True
+    _deprecated_from: "ConfigMixin | None" = None
+
+    def __new__(cls, **kwargs):
+        """Create a new config instance, handling deprecated replacements."""
+        xpmtype = cls.__xpmtype__
+
+        # Check if this is a deprecated type with replace=True
+        if xpmtype._deprecation is not None and xpmtype._deprecation.replace:
+            # Create the deprecated instance normally
+            instance = object.__new__(cls)
+            # Initialize it
+            ConfigMixin.__init__(instance, **kwargs)
+            # Convert to the new type
+            converted = instance.__convert__()
+            # Mark that this came from a deprecated config
+            converted._deprecated_from = instance
+            return converted
+
+        # Normal creation
+        return object.__new__(cls)
+
+    def __getattribute__(self, name: str):
+        """Get an attribute, handling XPM arguments specially.
+
+        We use __getattribute__ instead of __getattr__ because default values
+        like `b: Param[X] = None` create class attributes that would prevent
+        __getattr__ from being called.
+        """
+        # Get __xpm__ without recursion
+        try:
+            xpm = object.__getattribute__(self, "__xpm__")
+        except AttributeError:
+            # During early init, __xpm__ may not exist yet
+            return object.__getattribute__(self, name)
+
+        # Check if this is an XPM argument - parameters take precedence
+        xpmtype = object.__getattribute__(self, "__xpmtype__")
+        if name in xpmtype.arguments:
+            return xpm.get(name)
+
+        # Fall back to normal lookup (methods, etc.)
+        return object.__getattribute__(self, name)
+
+    def __setattr__(self, name: str, value):
+        """Set an attribute, handling XPM arguments specially."""
+        # Allow setting internal attributes directly
+        if name in ("__xpm__", "_deprecated_from"):
+            object.__setattr__(self, name, value)
+            return
+
+        # Check if we have __xpm__ yet (might not during early init)
+        xpm = self.__dict__.get("__xpm__")
+        if xpm is None:
+            object.__setattr__(self, name, value)
+            return
+
+        # Check if this is an XPM argument
+        xpmtype = self.__xpmtype__
+        if name in xpmtype.arguments:
+            # Handle TaggedValue: extract value and add tag
+            if isinstance(value, TaggedValue):
+                actual_value = value.value
+                source = get_caller_location(skip_frames=1)
+                xpm.addtag(name, actual_value, source=source)
+                xpm.set(name, actual_value)
+            else:
+                xpm.set(name, value)
+            return
+
+        # Check for deprecated replacement warning
+        deprecated_from = self.__dict__.get("_deprecated_from")
+        if deprecated_from is not None:
+            deprecated_xpmtype = deprecated_from.__xpmtype__
+            if name in deprecated_xpmtype.arguments:
+                logger.warning(
+                    f"Attribute '{name}' was in deprecated config "
+                    f"{deprecated_xpmtype.identifier} but is not in "
+                    f"{xpmtype.identifier}. The value is being discarded."
+                )
+                return  # Don't set the attribute
+
+        # Normal attribute setting
+        object.__setattr__(self, name, value)

     def __init__(self, **kwargs):
         """Initialize the configuration with the given parameters"""

+        # Skip if already initialized (can happen with deprecated replace=True)
+        if hasattr(self, "__xpm__"):
+            return
+
         # Add configuration
         xpmtype = self.__xpmtype__

@@ -1297,7 +1609,8 @@ class ConfigMixin:
             # Special case of a tagged value
             if isinstance(value, TaggedValue):
                 value = value.value
-
+                # Use _initinfo as source since tag is set at config creation
+                self.__xpm__.addtag(name, value, source=xpm._initinfo)

             # Really set the value
             xpm.set(name, value)
@@ -1310,12 +1623,14 @@ class ConfigMixin:
             [f"{key}={value}" for key, value in self.__xpm__.values.items()]
         )
         return (
-            f"{self.__xpmtype__.
-            f"{self.__xpmtype__.
+            f"{self.__xpmtype__.value_type.__module__}."
+            f"{self.__xpmtype__.value_type.__qualname__}({params})"
         )

     def tag(self, name, value):
-
+        # Capture caller's location and pass to addtag
+        source = get_caller_location(skip_frames=1)
+        self.__xpm__.addtag(name, value, source=source)
         return self

     def __eq__(self, other):
@@ -1340,9 +1655,20 @@ class ConfigMixin:
         return self

     def instance(
-        self,
+        self,
+        context: ConfigWalkContext = None,
+        *,
+        objects: ObjectStore = None,
+        keep: bool = True,
     ) -> T:
-        """Return an instance with the current values
+        """Return an instance with the current values
+
+        :param context: The context when computing the instance
+        :param objects: The previously built objects (so that we avoid
+            re-creating instances of past configurations)
+        :param keep: register a configuration in the __config__ field of the
+            instance
+        """
         if context is None:
             from experimaestro.xpmutils import EmptyContext

@@ -1351,7 +1677,11 @@ class ConfigMixin:
         assert isinstance(
             context, ConfigWalkContext
         ), f"{context.__class__} is not an instance of ConfigWalkContext"
-
+
+        instance = self.__xpm__.fromConfig(context, objects=objects)  # type: ignore
+        if keep:
+            object.__setattr__(instance, "__config__", self)
+        return instance

     def submit(
         self,
@@ -1360,16 +1690,22 @@ class ConfigMixin:
         launcher=None,
         run_mode: "RunMode" = None,
         init_tasks: List["LightweightTask"] = [],
+        max_retries: Optional[int] = None,
     ):
         """Submit this task

         :param workspace: the workspace, defaults to None
         :param launcher: The launcher, defaults to None
         :param run_mode: Run mode (if None, uses the workspace default)
+        :param max_retries: Maximum number of retries for resumable tasks that timeout (default: from workspace settings or 3)
         :return: an object object
         """
         return self.__xpm__.submit(
-            workspace,
+            workspace,
+            launcher,
+            run_mode=run_mode,
+            init_tasks=init_tasks,
+            max_retries=max_retries,
         )

     def stdout(self):
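A short usage sketch of the extended submit() signature above; the task class, workspace and launcher objects are placeholders, and only the keyword arguments come from the diff itself:

    task = MyTask.C(epochs=10)
    result = task.submit(
        workspace,              # experiment workspace (placeholder)
        launcher=gpu_launcher,  # optional launcher (placeholder)
        max_retries=3,          # new in 2.0: retry budget for resumable tasks that time out
    )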
@@ -1396,29 +1732,7 @@ class ConfigMixin:
         attributes)"""
         return clone(self)

-    def
-        assert all(
-            [isinstance(task, LightweightTask) for task in tasks]
-        ), "One of the pre-tasks are not lightweight tasks"
-        if self.__xpm__._sealed:
-            raise SealedError("Cannot add pre-tasks to a sealed configuration")
-        self.__xpm__.pre_tasks.extend(tasks)
-        return self
-
-    def add_pretasks_from(self, *configs: "Config"):
-        assert all(
-            [isinstance(config, ConfigMixin) for config in configs]
-        ), "One of the parameters is not a configuration object"
-        for config in configs:
-            self.add_pretasks(*config.__xpm__.pre_tasks)
-        return self
-
-    @property
-    def pre_tasks(self) -> List["LightweightTask"]:
-        """Access pre-tasks"""
-        return self.__xpm__.pre_tasks
-
-    def copy_dependencies(self, other: "Config"):
+    def copy_dependencies(self, other: "ConfigMixin"):
         """Add all the dependencies from other configuration"""

         # Add task dependency
@@ -1429,6 +1743,59 @@ class ConfigMixin:
|
|
|
1429
1743
|
# Add other dependencies
|
|
1430
1744
|
self.__xpm__.add_dependencies(*other.__xpm__.dependencies)
|
|
1431
1745
|
|
|
1746
|
+
def __rmatmul__(self, other: "ConfigMixin") -> "ConfigMixin":
|
|
1747
|
+
"""Right-associative composition operator: B() @ A(x=1) is equivalent to B(a=A(x=1))
|
|
1748
|
+
|
|
1749
|
+
For expression `other @ self`, finds the unique parameter in `other` that
|
|
1750
|
+
accepts `self`'s type and sets it. Returns `self` to enable right-associative
|
|
1751
|
+
chaining: `Outer() @ Middle() @ Inner()` builds Outer(middle=Middle(inner=Inner())).
|
|
1752
|
+
|
|
1753
|
+
The chain is evaluated left-to-right by Python, but returns the inner config
|
|
1754
|
+
so each step adds the current result as a parameter to the next outer config.
|
|
1755
|
+
|
|
1756
|
+
:param other: The outer configuration that will receive self as a parameter
|
|
1757
|
+
:return: self (the inner configuration) to continue the chain
|
|
1758
|
+
"""
|
|
1759
|
+
if not isinstance(other, ConfigMixin):
|
|
1760
|
+
return NotImplemented
|
|
1761
|
+
|
|
1762
|
+
# Find parameters in 'other' that can accept self's type
|
|
1763
|
+
self_type = self.__xpmtype__.value_type
|
|
1764
|
+
matching_params = []
|
|
1765
|
+
|
|
1766
|
+
for name, argument in other.__xpmtype__.arguments.items():
|
|
1767
|
+
# Get the expected type for this argument
|
|
1768
|
+
arg_type = argument.type
|
|
1769
|
+
if hasattr(arg_type, "value_type"):
|
|
1770
|
+
# It's an ObjectType wrapper
|
|
1771
|
+
expected_type = arg_type.value_type
|
|
1772
|
+
elif hasattr(arg_type, "__origin__"):
|
|
1773
|
+
# Generic type like Optional[X] or List[X]
|
|
1774
|
+
continue # Skip complex types for now
|
|
1775
|
+
elif isinstance(arg_type, type):
|
|
1776
|
+
expected_type = arg_type
|
|
1777
|
+
else:
|
|
1778
|
+
continue
|
|
1779
|
+
|
|
1780
|
+
# Check if self's type is compatible
|
|
1781
|
+
if isinstance(expected_type, type) and issubclass(self_type, expected_type):
|
|
1782
|
+
matching_params.append(name)
|
|
1783
|
+
|
|
1784
|
+
if len(matching_params) == 0:
|
|
1785
|
+
raise ValueError(
|
|
1786
|
+
f"No parameter in {other.__xpmtype__} accepts type {self_type.__name__}"
|
|
1787
|
+
)
|
|
1788
|
+
if len(matching_params) > 1:
|
|
1789
|
+
raise ValueError(
|
|
1790
|
+
f"Ambiguous composition: parameters {matching_params} in "
|
|
1791
|
+
f"{other.__xpmtype__} all accept type {self_type.__name__}"
|
|
1792
|
+
)
|
|
1793
|
+
|
|
1794
|
+
# Set the parameter on 'other'
|
|
1795
|
+
param_name = matching_params[0]
|
|
1796
|
+
other.__xpm__.set(param_name, self)
|
|
1797
|
+
return other
|
|
1798
|
+
|
|
1432
1799
|
|
|
1433
1800
|
class Config:
|
|
1434
1801
|
"""Base type for all objects in python interface"""
|
|
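A minimal usage sketch of the `__rmatmul__` composition operator added in the hunk above. The `Model` and `Trainer` classes are purely illustrative, and the sketch assumes `Config`, `Param` and the new `.C` constructor are importable from `experimaestro`. Note that the implementation shown returns `other` (the outer configuration), even though its docstring describes returning the inner one, so the expression below evaluates to the trainer configuration.

```python
from experimaestro import Config, Param


class Model(Config):          # hypothetical inner configuration
    hidden: Param[int] = 64


class Trainer(Config):        # hypothetical outer configuration
    model: Param[Model]       # the single parameter whose type accepts a Model


# __rmatmul__ looks up the unique Trainer parameter accepting a Model and sets it,
# so this should be equivalent to Trainer.C(model=Model.C(hidden=128)).
trainer = Trainer.C() @ Model.C(hidden=128)
```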
@@ -1441,51 +1808,78 @@ class Config:
     """The object type holds all the information about a specific subclass
     experimaestro metadata"""
 
-    __xpm__: ConfigInformation
-    """The __xpm__ object contains all instance specific information about a
-    configuration/task"""
-
     @classproperty
     def XPMConfig(cls):
         if issubclass(cls, ConfigMixin):
             return cls
-        return cls.__getxpmtype__().
+        return cls.__getxpmtype__().config_type
+
+    @classproperty
+    def C(cls):
+        """Alias for XPMConfig"""
+        return cls.XPMConfig
 
     @classproperty
     def XPMValue(cls):
-        """
-        if issubclass(cls, ConfigMixin):
-            return cls.__xpmtype__.objecttype
+        """Get the value class for this configuration.
 
-
-
-
-
+        Returns the explicitly registered value class, or the base config class
+        if no value class was registered.
+        """
+        return cls.__getxpmtype__().value_type
 
-
-
-
-        if issubclass(s, Config) and (s is not Config)
-    ) or (XPMValue,)
+    @classmethod
+    def value_class(cls):
+        """Decorator to register an external value class for this configuration.
 
-
-
-
-        value_cls.__module__ = cls.__module__
+        This allows declaring a separate class that will be used when creating
+        instances, which is useful to avoid initializing resources (e.g., PyTorch)
+        when only configuring.
 
-
+        .. code-block:: python
 
-
+            class Model(Config):
+                hidden_size: Param[int]
 
-
-
-
-
+            @Model.value_class()
+            class TorchModel(Model, nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.layer = nn.Linear(self.hidden_size, self.hidden_size)
 
-
-
-        """
-
+        The value class must be a subclass of the configuration class
+        and a subclass of parent configuration value classes (if any).
+        """
+
+        def decorator(value_class: type) -> type:
+            xpmtype = cls.__getxpmtype__()
+
+            # Check that value class is a subclass of the config class
+            if not issubclass(value_class, cls):
+                raise TypeError(
+                    f"Value class {value_class.__name__} must be a subclass of "
+                    f"{cls.__name__}"
+                )
+
+            # Check that value class inherits from parent value classes
+            for base in cls.__bases__:
+                if base is Config or not issubclass(base, Config):
+                    continue
+                parent_xpmtype = base.__getxpmtype__()
+                # Check if parent has an explicit value type (different from original)
+                if parent_xpmtype.value_type is not parent_xpmtype._original_type:
+                    parent_value = parent_xpmtype.value_type
+                    if not issubclass(value_class, parent_value):
+                        raise TypeError(
+                            f"Value class {value_class.__name__} must be a subclass of "
+                            f"parent value class {parent_value.__name__}"
+                        )
+
+            # Register the value class
+            xpmtype.set_value_type(value_class)
+            return value_class
+
+        return decorator
 
     @classmethod
     def __getxpmtype__(cls) -> "ObjectType":
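The `value_class` decorator above already ships a PyTorch-flavoured example in its docstring; the sketch below restates the same registration flow without any framework dependency. The class names are illustrative, and the closing comments follow what the `XPMConfig`, `C` and `XPMValue` properties in this hunk indicate.

```python
from experimaestro import Config, Param


class Model(Config):
    hidden_size: Param[int] = 16


@Model.value_class()
class EagerModel(Model):
    """Hypothetical value class: only built when an instance is actually needed."""

    def __init__(self):
        super().__init__()
        # Heavy initialisation (weights, GPU buffers, ...) belongs here,
        # not in the configuration object.
        self.weights = [0.0] * self.hidden_size


cfg = Model.C(hidden_size=32)  # cheap: configuration only, EagerModel.__init__ not run
# Per the properties above, Model.XPMValue should now resolve to EagerModel,
# while Model.C / Model.XPMConfig keep returning the configuration class.
```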
@@ -1503,46 +1897,6 @@ class Config:
             raise
         return xpmtype
 
-    def __new__(cls: Type[T], *args, **kwargs) -> T:
-        """Returns an instance of a ConfigMixin (for compatibility, use XPMConfig
-        or C if possible)
-
-        :deprecated: Use Config.C or Config.XPMConfig to construct a new
-        configuration, and Config.V (or Config.XPMValue) for a new value
-        """
-        # If this is an XPMValue, just return a new instance
-        from experimaestro.core.types import XPMValue
-
-        if issubclass(cls, XPMValue):
-            return object.__new__(cls)
-
-        # If this is the XPMConfig, just return a new instance
-        # __init__ will be called
-        if issubclass(cls, ConfigMixin):
-            return object.__new__(cls)
-
-        # Log a deprecation warning for this way of creating a configuration
-        caller = inspect.getframeinfo(inspect.stack()[1][0])
-        logger.warning(
-            "Creating a configuration using Config.__new__ is deprecated, and will be removed in a future version. "
-            "Use Config.C or Config.XPMConfig to create a new configuration. "
-            "Issue created at %s:%s",
-            str(Path(caller.filename).absolute()),
-            caller.lineno,
-        )
-
-        # otherwise, we use the configuration type
-        o: ConfigMixin = object.__new__(cls.__getxpmtype__().configtype)
-        try:
-            o.__init__(*args, **kwargs)
-        except Exception:
-            logger.error(
-                "Init error in %s:%s"
-                % (str(Path(caller.filename).absolute()), caller.lineno)
-            )
-            raise
-        return o
-
     def __validate__(self):
         """Validate the values"""
         pass
@@ -1557,17 +1911,7 @@
         return self.__xpm__.__json__()
 
     def __identifier__(self) -> "Identifier":
-        return self.__xpm__.
-
-    def add_pretasks(self, *tasks: "LightweightTask"):
-        """Add pre-tasks"""
-        raise AssertionError("This method can only be used during configuration")
-
-    def add_pretasks_from(self, *configs: "Config"):
-        """Add pre-tasks from the listed configurations"""
-        raise AssertionError(
-            "The 'add_pretasks_from' can only be used during configuration"
-        )
+        return self.__xpm__.identifier
 
     def copy_dependencies(self, other: "Config"):
         """Add pre-tasks from the listed configurations"""
@@ -1575,11 +1919,6 @@
             "The 'copy_dependencies' method can only be used during configuration"
         )
 
-    @property
-    def pre_tasks(self) -> List["LightweightTask"]:
-        """Access pre-tasks"""
-        raise AssertionError("Pre-tasks can be accessed only during configuration")
-
     def register_task_output(self, method, *args, **kwargs):
         # Determine the path for this...
         path = taskglobals.Env.instance().xpm_path / "task-outputs.jsonl"
@@ -1597,6 +1936,40 @@
             fp.flush()
 
 
+class InstanceConfig(Config):
+    """Base class for configurations where instance identity matters.
+
+    When a Config class derives from InstanceConfig instead of Config,
+    instances are distinguished based on their object identity when used
+    in containers. This enables distinguishing between shared and separate
+    instances even when all parameters are identical.
+
+    Example:
+        >>> class SubModel(InstanceConfig):
+        ...     value: Param[int] = 100
+        >>> class MainModel(Config):
+        ...     m1: Param[SubModel]
+        ...     m2: Param[SubModel]
+        >>>
+        >>> sm1 = SubModel.C()
+        >>> sm2 = SubModel.C()  # Same params, different instance
+        >>>
+        >>> # Shared instance (same object used twice)
+        >>> shared = MainModel.C(m1=sm1, m2=sm1)
+        >>>
+        >>> # Separate instances (different objects)
+        >>> separate = MainModel.C(m1=sm1, m2=sm2)
+        >>>
+        >>> # Different identifiers: shared vs separate
+        >>> shared.__identifier__() != separate.__identifier__()
+
+    The instance order is determined by the traversal order during
+    identifier computation, ensuring reproducibility.
+    """
+
+    pass
+
+
 class LightweightTask(Config):
     """A task that can be run before or after a real task to modify its behaviour"""
 
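As a counterpoint to the doctest above: with a plain `Config`, object identity is ignored, so sharing one sub-configuration or building two with identical parameters should yield the same identifier. The class names below are illustrative, and the behaviour stated in the comments follows the InstanceConfig documentation in this hunk rather than anything verified here.

```python
from experimaestro import Config, Param


class Sub(Config):              # plain Config: instance identity is ignored
    value: Param[int] = 100


class Main(Config):
    a: Param[Sub]
    b: Param[Sub]


s = Sub.C()
shared = Main.C(a=s, b=s)                # one object used twice
separate = Main.C(a=Sub.C(), b=Sub.C())  # two objects, identical parameters

# With plain Config both identifiers should coincide; deriving Sub from
# InstanceConfig (documented above) is what makes them differ.
```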
@@ -1659,11 +2032,55 @@ def copyconfig(config: Config, **kwargs):
 
 
 def setmeta(config: Config, flag: bool):
-    """
+    """Force a configuration to be treated as a meta-parameter.
+
+    When a configuration is marked as meta, it is excluded from the
+    identifier computation of its parent configuration.
+
+    Example::
+
+        class Ensemble(Config):
+            model1: Param[Model]
+            model2: Param[Model]
+
+        # Mark model2 as meta - it won't affect the ensemble's identifier
+        model2 = setmeta(Model.C(...), True)
+        ensemble = Ensemble.C(model1=model1, model2=model2)
+
+    :param config: The configuration to mark
+    :param flag: True to mark as meta, False to include in identifier
+    :return: The same configuration (for chaining)
+    """
     config.__xpm__.set_meta(flag)
     return config
 
 
+class ResumableTask(Task):
+    """Base class for resumable/checkpointable tasks
+
+    Resumable tasks can be restarted if they are stopped by a time limit
+    (e.g., SLURM job timeout). The task directory and dynamic outputs are
+    preserved across restarts to allow checkpoint recovery.
+    """
+
+    def remaining_time(self) -> Optional[float]:
+        """Returns the remaining time in seconds before the job times out.
+
+        This is useful for checkpointing before hitting a time limit
+        (e.g., SLURM walltime).
+
+        Returns:
+            The remaining time in seconds, or None if:
+            - There is no time limit
+            - The launcher doesn't support querying remaining time
+            - The task is not running
+        """
+        launcher_info = taskglobals.Env.instance().launcher_info
+        if launcher_info is None:
+            return None
+        return launcher_info.remaining_time()
+
+
 def cache(fn, name: str):
     def __call__(config, *args, **kwargs):
         import experimaestro.taskglobals as taskglobals
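A sketch of how a task might combine `ResumableTask.remaining_time()` with a checkpoint file. The `Training` task, its `epochs` parameter, the 300-second margin and the checkpoint layout are all illustrative; the sketch assumes `ResumableTask` is exported alongside `Param`, that tasks implement the usual `execute()` entry point, and that they run inside their preserved task directory.

```python
import json
from pathlib import Path

from experimaestro import Param, ResumableTask  # assuming ResumableTask is exported


class Training(ResumableTask):
    epochs: Param[int] = 100

    def execute(self):
        # Checkpoint file lives in the task directory, which is kept across restarts
        ckpt = Path("checkpoint.json")
        start = json.loads(ckpt.read_text())["epoch"] if ckpt.exists() else 0

        for epoch in range(start, self.epochs):
            ...  # one epoch of real work goes here

            ckpt.write_text(json.dumps({"epoch": epoch + 1}))
            remaining = self.remaining_time()
            # Stop early when close to the walltime; a restart resumes from epoch + 1.
            if remaining is not None and remaining < 300:
                return
```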