experimaestro 2.0.0a8__py3-none-any.whl → 2.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of experimaestro might be problematic. Click here for more details.
- experimaestro/__init__.py +10 -11
- experimaestro/annotations.py +167 -206
- experimaestro/cli/__init__.py +130 -5
- experimaestro/cli/filter.py +42 -74
- experimaestro/cli/jobs.py +157 -106
- experimaestro/cli/refactor.py +249 -0
- experimaestro/click.py +0 -1
- experimaestro/commandline.py +19 -3
- experimaestro/connectors/__init__.py +20 -1
- experimaestro/connectors/local.py +12 -0
- experimaestro/core/arguments.py +182 -46
- experimaestro/core/identifier.py +107 -6
- experimaestro/core/objects/__init__.py +6 -0
- experimaestro/core/objects/config.py +542 -25
- experimaestro/core/objects/config_walk.py +20 -0
- experimaestro/core/serialization.py +91 -34
- experimaestro/core/subparameters.py +164 -0
- experimaestro/core/types.py +175 -38
- experimaestro/exceptions.py +26 -0
- experimaestro/experiments/cli.py +107 -25
- experimaestro/generators.py +50 -9
- experimaestro/huggingface.py +3 -1
- experimaestro/launcherfinder/parser.py +29 -0
- experimaestro/launchers/__init__.py +26 -1
- experimaestro/launchers/direct.py +12 -0
- experimaestro/launchers/slurm/base.py +154 -2
- experimaestro/mkdocs/metaloader.py +0 -1
- experimaestro/mypy.py +452 -7
- experimaestro/notifications.py +63 -13
- experimaestro/progress.py +0 -2
- experimaestro/rpyc.py +0 -1
- experimaestro/run.py +19 -6
- experimaestro/scheduler/base.py +489 -125
- experimaestro/scheduler/dependencies.py +43 -28
- experimaestro/scheduler/dynamic_outputs.py +259 -130
- experimaestro/scheduler/experiment.py +225 -30
- experimaestro/scheduler/interfaces.py +474 -0
- experimaestro/scheduler/jobs.py +216 -206
- experimaestro/scheduler/services.py +186 -12
- experimaestro/scheduler/state_db.py +388 -0
- experimaestro/scheduler/state_provider.py +2345 -0
- experimaestro/scheduler/state_sync.py +834 -0
- experimaestro/scheduler/workspace.py +52 -10
- experimaestro/scriptbuilder.py +7 -0
- experimaestro/server/__init__.py +147 -57
- experimaestro/server/data/index.css +0 -125
- experimaestro/server/data/index.css.map +1 -1
- experimaestro/server/data/index.js +194 -58
- experimaestro/server/data/index.js.map +1 -1
- experimaestro/settings.py +44 -5
- experimaestro/sphinx/__init__.py +3 -3
- experimaestro/taskglobals.py +20 -0
- experimaestro/tests/conftest.py +80 -0
- experimaestro/tests/core/test_generics.py +2 -2
- experimaestro/tests/identifier_stability.json +45 -0
- experimaestro/tests/launchers/bin/sacct +6 -2
- experimaestro/tests/launchers/bin/sbatch +4 -2
- experimaestro/tests/launchers/test_slurm.py +80 -0
- experimaestro/tests/tasks/test_dynamic.py +231 -0
- experimaestro/tests/test_cli_jobs.py +615 -0
- experimaestro/tests/test_deprecated.py +630 -0
- experimaestro/tests/test_environment.py +200 -0
- experimaestro/tests/test_file_progress_integration.py +1 -1
- experimaestro/tests/test_forward.py +3 -3
- experimaestro/tests/test_identifier.py +372 -41
- experimaestro/tests/test_identifier_stability.py +458 -0
- experimaestro/tests/test_instance.py +3 -3
- experimaestro/tests/test_multitoken.py +442 -0
- experimaestro/tests/test_mypy.py +433 -0
- experimaestro/tests/test_objects.py +312 -5
- experimaestro/tests/test_outputs.py +2 -2
- experimaestro/tests/test_param.py +8 -12
- experimaestro/tests/test_partial_paths.py +231 -0
- experimaestro/tests/test_progress.py +0 -48
- experimaestro/tests/test_resumable_task.py +480 -0
- experimaestro/tests/test_serializers.py +141 -1
- experimaestro/tests/test_state_db.py +434 -0
- experimaestro/tests/test_subparameters.py +160 -0
- experimaestro/tests/test_tags.py +136 -0
- experimaestro/tests/test_tasks.py +107 -121
- experimaestro/tests/test_token_locking.py +252 -0
- experimaestro/tests/test_tokens.py +17 -13
- experimaestro/tests/test_types.py +123 -1
- experimaestro/tests/test_workspace_triggers.py +158 -0
- experimaestro/tests/token_reschedule.py +4 -2
- experimaestro/tests/utils.py +2 -2
- experimaestro/tokens.py +154 -57
- experimaestro/tools/diff.py +1 -1
- experimaestro/tui/__init__.py +8 -0
- experimaestro/tui/app.py +2303 -0
- experimaestro/tui/app.tcss +353 -0
- experimaestro/tui/log_viewer.py +228 -0
- experimaestro/utils/__init__.py +23 -0
- experimaestro/utils/environment.py +148 -0
- experimaestro/utils/git.py +129 -0
- experimaestro/utils/resources.py +1 -1
- experimaestro/version.py +34 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/METADATA +68 -38
- experimaestro-2.0.0b4.dist-info/RECORD +181 -0
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/WHEEL +1 -1
- experimaestro-2.0.0b4.dist-info/entry_points.txt +16 -0
- experimaestro/compat.py +0 -6
- experimaestro/core/objects.pyi +0 -221
- experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
- experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
- experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
- experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
- experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
- experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
- experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
- experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
- experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
- experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
- experimaestro-2.0.0a8.dist-info/RECORD +0 -166
- experimaestro-2.0.0a8.dist-info/entry_points.txt +0 -17
- {experimaestro-2.0.0a8.dist-info → experimaestro-2.0.0b4.dist-info}/licenses/LICENSE +0 -0
|
@@ -30,13 +30,14 @@ from typing import (
|
|
|
30
30
|
)
|
|
31
31
|
import sys
|
|
32
32
|
import experimaestro
|
|
33
|
-
from experimaestro.utils import logger
|
|
33
|
+
from experimaestro.utils import logger, get_caller_location
|
|
34
34
|
from experimaestro.core.types import DeprecatedAttribute, ObjectType, TypeVarType
|
|
35
35
|
from ..context import SerializationContext, SerializedPath, SerializedPathLoader
|
|
36
36
|
|
|
37
37
|
if TYPE_CHECKING:
|
|
38
38
|
from ..callbacks import TaskEventListener
|
|
39
39
|
from ..identifier import Identifier
|
|
40
|
+
from ..subparameters import Subparameters
|
|
40
41
|
from experimaestro.scheduler.base import Job
|
|
41
42
|
from experimaestro.scheduler.workspace import RunMode
|
|
42
43
|
from experimaestro.launchers import Launcher
|
|
@@ -55,6 +56,26 @@ T = TypeVar("T", bound="Config")
|
|
|
55
56
|
|
|
56
57
|
|
|
57
58
|
DependentMarker = Callable[["Config"], None]
|
|
59
|
+
"""Type alias for dependency marker functions.
|
|
60
|
+
|
|
61
|
+
A DependentMarker is a callable that marks a configuration as a dependency
|
|
62
|
+
of another configuration. Used in ``task_outputs()`` and dynamic output methods
|
|
63
|
+
to establish task dependencies.
|
|
64
|
+
|
|
65
|
+
Example::
|
|
66
|
+
|
|
67
|
+
class Learn(Task):
|
|
68
|
+
model: Param[Model]
|
|
69
|
+
|
|
70
|
+
def task_outputs(self, dep: DependentMarker):
|
|
71
|
+
return dep(Checkpoint.C(model=self.model, path=self.checkpoint_path))
|
|
72
|
+
|
|
73
|
+
class Validation(Config):
|
|
74
|
+
model: Param[Model]
|
|
75
|
+
|
|
76
|
+
def checkpoint(self, dep: DependentMarker, *, step: int) -> Checkpoint:
|
|
77
|
+
return dep(Checkpoint.C(model=self.model, step=step))
|
|
78
|
+
"""
|
|
58
79
|
|
|
59
80
|
|
|
60
81
|
def updatedependencies(
|
|
@@ -90,9 +111,6 @@ NOT_SET = object()
|
|
|
90
111
|
|
|
91
112
|
@define()
|
|
92
113
|
class WatchedOutput:
|
|
93
|
-
#: The enclosing job
|
|
94
|
-
job: "Job"
|
|
95
|
-
|
|
96
114
|
#: The configuration containing the watched output
|
|
97
115
|
config: "ConfigInformation"
|
|
98
116
|
|
|
@@ -105,6 +123,9 @@ class WatchedOutput:
|
|
|
105
123
|
#: The callback to call (with the output of the previous method)
|
|
106
124
|
callback: Callable
|
|
107
125
|
|
|
126
|
+
#: The enclosing job (set when registered with scheduler)
|
|
127
|
+
job: Optional["Job"] = None
|
|
128
|
+
|
|
108
129
|
|
|
109
130
|
def get_generated_paths(
|
|
110
131
|
v: Union["ConfigMixin", list, dict],
|
|
@@ -142,6 +163,22 @@ def get_generated_paths(
|
|
|
142
163
|
return paths
|
|
143
164
|
|
|
144
165
|
|
|
166
|
+
@define
|
|
167
|
+
class TaskStub:
|
|
168
|
+
"""Stub for a task that was not loaded during partial loading.
|
|
169
|
+
|
|
170
|
+
This is used when loading configurations from disk (e.g., HuggingFace)
|
|
171
|
+
where the task code may have changed or is not available. The stub stores
|
|
172
|
+
the identifier and typename so the information is preserved.
|
|
173
|
+
"""
|
|
174
|
+
|
|
175
|
+
identifier: "Identifier"
|
|
176
|
+
"""The experimaestro identifier of the task"""
|
|
177
|
+
|
|
178
|
+
typename: str
|
|
179
|
+
"""The type name of the task (e.g., 'mymodule.MyTask')"""
|
|
180
|
+
|
|
181
|
+
|
|
145
182
|
class ConfigInformation:
|
|
146
183
|
"""Holds experimaestro information for a config (or task) instance"""
|
|
147
184
|
|
|
@@ -158,7 +195,9 @@ class ConfigInformation:
|
|
|
158
195
|
self.values = {}
|
|
159
196
|
|
|
160
197
|
# Meta-information
|
|
161
|
-
|
|
198
|
+
# Tags are stored as {name: (value, source_location)}
|
|
199
|
+
# where source_location is "file:line" string for error reporting
|
|
200
|
+
self._tags: dict[str, tuple[Any, str]] = {}
|
|
162
201
|
self._initinfo = ""
|
|
163
202
|
|
|
164
203
|
self._taskoutput = None
|
|
@@ -192,6 +231,9 @@ class ConfigInformation:
|
|
|
192
231
|
self._identifier = None
|
|
193
232
|
"""The configuration identifier (cached when sealed)"""
|
|
194
233
|
|
|
234
|
+
self._partial_identifiers: Dict[str, "Identifier"] = {}
|
|
235
|
+
"""Cached partial identifiers (keyed by subparameters name)"""
|
|
236
|
+
|
|
195
237
|
self._validated = False
|
|
196
238
|
self._sealed = False
|
|
197
239
|
self._meta = None
|
|
@@ -331,8 +373,17 @@ class ConfigInformation:
|
|
|
331
373
|
f" (current typevars bindings: {self.concrete_typevars})"
|
|
332
374
|
)
|
|
333
375
|
|
|
334
|
-
def addtag(self, name, value):
|
|
335
|
-
|
|
376
|
+
def addtag(self, name, value, source: str = None):
|
|
377
|
+
"""Add a tag with optional source location for error reporting
|
|
378
|
+
|
|
379
|
+
Args:
|
|
380
|
+
name: Tag name
|
|
381
|
+
value: Tag value
|
|
382
|
+
source: Source location string (file:line). If None, captured from caller.
|
|
383
|
+
"""
|
|
384
|
+
if source is None:
|
|
385
|
+
source = get_caller_location(skip_frames=1)
|
|
386
|
+
self._tags[name] = (value, source)
|
|
336
387
|
|
|
337
388
|
def xpmvalues(self, generated=False):
|
|
338
389
|
"""Returns an iterator over arguments and associated values"""
|
|
@@ -344,11 +395,29 @@ class ConfigInformation:
|
|
|
344
395
|
class TagFinder(ConfigWalk):
|
|
345
396
|
def __init__(self):
|
|
346
397
|
super().__init__(recurse_task=True)
|
|
347
|
-
|
|
398
|
+
# Store {name: (value, source)} for conflict detection
|
|
399
|
+
self.tags_with_source: dict[str, tuple[Any, str]] = {}
|
|
348
400
|
|
|
349
401
|
def postprocess(self, stub, config: Config, values):
|
|
350
|
-
|
|
351
|
-
|
|
402
|
+
for name, (value, source) in config.__xpm__._tags.items():
|
|
403
|
+
if name in self.tags_with_source:
|
|
404
|
+
existing_value, existing_source = self.tags_with_source[name]
|
|
405
|
+
if existing_value != value:
|
|
406
|
+
logger.warning(
|
|
407
|
+
"Tag '%s' has conflicting values: "
|
|
408
|
+
"'%s' (set at %s) vs '%s' (set at %s). "
|
|
409
|
+
"Using the latter value.",
|
|
410
|
+
name,
|
|
411
|
+
existing_value,
|
|
412
|
+
existing_source,
|
|
413
|
+
value,
|
|
414
|
+
source,
|
|
415
|
+
)
|
|
416
|
+
self.tags_with_source[name] = (value, source)
|
|
417
|
+
# Return just the values (without source info)
|
|
418
|
+
return {
|
|
419
|
+
name: value for name, (value, _) in self.tags_with_source.items()
|
|
420
|
+
}
|
|
352
421
|
|
|
353
422
|
return TagFinder()(self.pyobject)
|
|
354
423
|
|
|
@@ -489,6 +558,33 @@ class ConfigInformation:
|
|
|
489
558
|
self._identifier = identifier
|
|
490
559
|
return identifier
|
|
491
560
|
|
|
561
|
+
def get_partial_identifier(self, subparameters: "Subparameters") -> "Identifier":
|
|
562
|
+
"""Get the partial identifier for a given subparameters instance.
|
|
563
|
+
|
|
564
|
+
Partial identifiers exclude certain parameter groups, allowing
|
|
565
|
+
configurations that differ only in those groups to share the same
|
|
566
|
+
partial identifier (and thus the same partial directory).
|
|
567
|
+
|
|
568
|
+
Args:
|
|
569
|
+
subparameters: The Subparameters instance defining which groups
|
|
570
|
+
to include/exclude.
|
|
571
|
+
|
|
572
|
+
Returns:
|
|
573
|
+
The partial identifier for this configuration.
|
|
574
|
+
"""
|
|
575
|
+
from ..identifier import IdentifierComputer
|
|
576
|
+
|
|
577
|
+
name = subparameters.name
|
|
578
|
+
if name in self._partial_identifiers:
|
|
579
|
+
return self._partial_identifiers[name]
|
|
580
|
+
|
|
581
|
+
identifier = IdentifierComputer.compute_partial(self.pyobject, subparameters)
|
|
582
|
+
|
|
583
|
+
if self._sealed:
|
|
584
|
+
self._partial_identifiers[name] = identifier
|
|
585
|
+
|
|
586
|
+
return identifier
|
|
587
|
+
|
|
492
588
|
def dependency(self):
|
|
493
589
|
"""Returns a dependency"""
|
|
494
590
|
from experimaestro.scheduler import JobDependency
|
|
@@ -563,10 +659,20 @@ class ConfigInformation:
|
|
|
563
659
|
|
|
564
660
|
:param method: The method to watch
|
|
565
661
|
:param callback: The callback
|
|
662
|
+
|
|
663
|
+
:raises TypeError: If the task is not a ResumableTask
|
|
566
664
|
"""
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
)
|
|
665
|
+
# Only ResumableTask can have dynamic outputs - regular tasks
|
|
666
|
+
# have their directories cleaned up, losing the output file
|
|
667
|
+
if not isinstance(self.pyobject, ResumableTask):
|
|
668
|
+
raise TypeError(
|
|
669
|
+
f"Only ResumableTask can use watch_output. "
|
|
670
|
+
f"{self.xpmtype} is not a ResumableTask. "
|
|
671
|
+
"Dynamic outputs require the task directory to be preserved "
|
|
672
|
+
"across restarts, which only ResumableTask provides."
|
|
673
|
+
)
|
|
674
|
+
|
|
675
|
+
watched = WatchedOutput(method.__self__, method.__name__, method, callback)
|
|
570
676
|
self.watched_outputs.append(watched)
|
|
571
677
|
if self.job:
|
|
572
678
|
self.job.watch_output(watched)
|
|
@@ -587,6 +693,7 @@ class ConfigInformation:
|
|
|
587
693
|
*,
|
|
588
694
|
run_mode=None,
|
|
589
695
|
init_tasks: List["LightweightTask"] = [],
|
|
696
|
+
max_retries: Optional[int] = None,
|
|
590
697
|
):
|
|
591
698
|
from experimaestro.scheduler import experiment, JobContext
|
|
592
699
|
from experimaestro.scheduler.workspace import RunMode
|
|
@@ -606,7 +713,11 @@ class ConfigInformation:
|
|
|
606
713
|
|
|
607
714
|
# Creates a new job
|
|
608
715
|
self.job = self.xpmtype.task(
|
|
609
|
-
self.pyobject,
|
|
716
|
+
self.pyobject,
|
|
717
|
+
launcher=launcher,
|
|
718
|
+
workspace=workspace,
|
|
719
|
+
run_mode=run_mode,
|
|
720
|
+
max_retries=max_retries,
|
|
610
721
|
)
|
|
611
722
|
|
|
612
723
|
# Validate the object
|
|
@@ -644,7 +755,6 @@ class ConfigInformation:
|
|
|
644
755
|
) or RunMode.NORMAL
|
|
645
756
|
if run_mode == RunMode.NORMAL:
|
|
646
757
|
TaskEventListener.connect(experiment.CURRENT)
|
|
647
|
-
experiment.CURRENT.submit(self.job)
|
|
648
758
|
other = experiment.CURRENT.submit(self.job)
|
|
649
759
|
if other:
|
|
650
760
|
# Our job = previously submitted job
|
|
@@ -844,7 +954,6 @@ class ConfigInformation:
|
|
|
844
954
|
|
|
845
955
|
The format is an array of objects
|
|
846
956
|
{
|
|
847
|
-
"tags: [ LIST_OF_TAGS ],
|
|
848
957
|
"workspace": FOLDERPATH,
|
|
849
958
|
"version": 2,
|
|
850
959
|
"objects": [
|
|
@@ -864,6 +973,10 @@ class ConfigInformation:
|
|
|
864
973
|
|
|
865
974
|
The last object is the one that is serialized
|
|
866
975
|
|
|
976
|
+
Note: Tags are no longer stored in params.json. They are managed by the
|
|
977
|
+
experiment state provider (scoped to job_id, experiment_id, run_id) and
|
|
978
|
+
also stored in experiment state.json for full experiment details.
|
|
979
|
+
|
|
867
980
|
Arguments:
|
|
868
981
|
out {io.TextIOBase} -- The output stream
|
|
869
982
|
context {[type]} -- the command context
|
|
@@ -871,7 +984,6 @@ class ConfigInformation:
|
|
|
871
984
|
json.dump(
|
|
872
985
|
{
|
|
873
986
|
"workspace": str(context.workspace.path.absolute()),
|
|
874
|
-
"tags": {key: value for key, value in self.tags().items()},
|
|
875
987
|
"version": 2,
|
|
876
988
|
"experimaestro": experimaestro.__version__,
|
|
877
989
|
"objects": self.__get_objects__([], context),
|
|
@@ -896,12 +1008,18 @@ class ConfigInformation:
|
|
|
896
1008
|
path: Union[str, Path, SerializedPathLoader],
|
|
897
1009
|
as_instance: bool = False,
|
|
898
1010
|
return_tasks: bool = False,
|
|
1011
|
+
partial_loading: Optional[bool] = None,
|
|
899
1012
|
) -> "Config":
|
|
900
1013
|
"""Deserialize a configuration
|
|
901
1014
|
|
|
902
1015
|
:param path: The filesystem Path to use, or a way to download the
|
|
903
1016
|
information through a function taking two arguments
|
|
904
1017
|
:param as_instance: Return an instance
|
|
1018
|
+
:param return_tasks: Return init tasks instead of executing them
|
|
1019
|
+
:param partial_loading: If True, skip loading task references. If None
|
|
1020
|
+
(default), partial_loading is enabled when as_instance is True.
|
|
1021
|
+
This is useful when loading configurations from disk (e.g.,
|
|
1022
|
+
HuggingFace) where the task code may have changed.
|
|
905
1023
|
:return: a Config object, its instance or a tuple (instance, init_tasks) if return_tasks is True
|
|
906
1024
|
"""
|
|
907
1025
|
# Load
|
|
@@ -926,6 +1044,7 @@ class ConfigInformation:
|
|
|
926
1044
|
as_instance=as_instance,
|
|
927
1045
|
data_loader=data_loader,
|
|
928
1046
|
return_tasks=return_tasks,
|
|
1047
|
+
partial_loading=partial_loading,
|
|
929
1048
|
)
|
|
930
1049
|
|
|
931
1050
|
@staticmethod
|
|
@@ -973,6 +1092,7 @@ class ConfigInformation:
|
|
|
973
1092
|
as_instance=True,
|
|
974
1093
|
save_directory: Optional[Path] = None,
|
|
975
1094
|
discard_id: bool = False,
|
|
1095
|
+
partial_loading: Optional[bool] = None,
|
|
976
1096
|
) -> "ConfigMixin": ...
|
|
977
1097
|
|
|
978
1098
|
@overload
|
|
@@ -983,6 +1103,7 @@ class ConfigInformation:
|
|
|
983
1103
|
return_tasks=True,
|
|
984
1104
|
save_directory: Optional[Path] = None,
|
|
985
1105
|
discard_id: bool = False,
|
|
1106
|
+
partial_loading: Optional[bool] = None,
|
|
986
1107
|
) -> Tuple["Config", List["LightweightTask"]]: ...
|
|
987
1108
|
|
|
988
1109
|
@overload
|
|
@@ -992,23 +1113,115 @@ class ConfigInformation:
|
|
|
992
1113
|
as_instance=False,
|
|
993
1114
|
save_directory: Optional[Path] = None,
|
|
994
1115
|
discard_id: bool = False,
|
|
1116
|
+
partial_loading: Optional[bool] = None,
|
|
995
1117
|
) -> "Config": ...
|
|
996
1118
|
|
|
1119
|
+
@staticmethod
|
|
1120
|
+
def _get_field_refs(value: Any) -> Set[int]:
|
|
1121
|
+
"""Recursively extract object references from a serialized field value"""
|
|
1122
|
+
refs: Set[int] = set()
|
|
1123
|
+
if isinstance(value, dict):
|
|
1124
|
+
if value.get("type") == "python":
|
|
1125
|
+
refs.add(value["value"])
|
|
1126
|
+
else:
|
|
1127
|
+
for v in value.values():
|
|
1128
|
+
refs.update(ConfigInformation._get_field_refs(v))
|
|
1129
|
+
elif isinstance(value, list):
|
|
1130
|
+
for v in value:
|
|
1131
|
+
refs.update(ConfigInformation._get_field_refs(v))
|
|
1132
|
+
return refs
|
|
1133
|
+
|
|
1134
|
+
@staticmethod
|
|
1135
|
+
def _compute_skipped_ids(definitions: List[Dict]) -> Set[int]:
|
|
1136
|
+
"""Compute IDs of objects only reachable through task references.
|
|
1137
|
+
|
|
1138
|
+
When partial_loading is enabled, we skip loading task references and
|
|
1139
|
+
any objects that are only used by those tasks. This method computes
|
|
1140
|
+
which object IDs should be skipped by finding objects reachable from
|
|
1141
|
+
the main object (last definition) without following task references.
|
|
1142
|
+
|
|
1143
|
+
:param definitions: List of object definitions from JSON
|
|
1144
|
+
:return: Set of object IDs to skip loading
|
|
1145
|
+
"""
|
|
1146
|
+
# Build index of definitions by ID
|
|
1147
|
+
def_by_id = {d["id"]: d for d in definitions}
|
|
1148
|
+
|
|
1149
|
+
# Compute reachable objects from main object (last definition)
|
|
1150
|
+
# without going through task references
|
|
1151
|
+
main_defn = definitions[-1]
|
|
1152
|
+
main_id = main_defn["id"]
|
|
1153
|
+
reachable: Set[int] = set()
|
|
1154
|
+
to_visit = [main_id]
|
|
1155
|
+
|
|
1156
|
+
# Also include init-tasks as reachable (needed for as_instance/return_tasks)
|
|
1157
|
+
for init_task_id in main_defn.get("init-tasks", []):
|
|
1158
|
+
to_visit.append(init_task_id)
|
|
1159
|
+
|
|
1160
|
+
while to_visit:
|
|
1161
|
+
obj_id = to_visit.pop()
|
|
1162
|
+
if obj_id in reachable:
|
|
1163
|
+
continue
|
|
1164
|
+
reachable.add(obj_id)
|
|
1165
|
+
|
|
1166
|
+
defn = def_by_id.get(obj_id)
|
|
1167
|
+
if defn is None:
|
|
1168
|
+
continue
|
|
1169
|
+
|
|
1170
|
+
# Add field references (not task reference)
|
|
1171
|
+
for field_value in defn.get("fields", {}).values():
|
|
1172
|
+
for ref_id in ConfigInformation._get_field_refs(field_value):
|
|
1173
|
+
if ref_id not in reachable:
|
|
1174
|
+
to_visit.append(ref_id)
|
|
1175
|
+
|
|
1176
|
+
# Note: we intentionally skip defn["task"] to avoid loading tasks
|
|
1177
|
+
|
|
1178
|
+
# All objects not reachable should be skipped
|
|
1179
|
+
all_ids = {d["id"] for d in definitions}
|
|
1180
|
+
return all_ids - reachable
|
|
1181
|
+
|
|
997
1182
|
@staticmethod
|
|
998
1183
|
def load_objects( # noqa: C901
|
|
999
1184
|
definitions: List[Dict],
|
|
1000
1185
|
as_instance=True,
|
|
1001
1186
|
data_loader: Optional[SerializedPathLoader] = None,
|
|
1002
1187
|
discard_id: bool = False,
|
|
1188
|
+
partial_loading: bool = False,
|
|
1003
1189
|
):
|
|
1004
|
-
"""Load the objects
|
|
1190
|
+
"""Load the objects
|
|
1191
|
+
|
|
1192
|
+
:param definitions: List of object definitions from JSON
|
|
1193
|
+
:param as_instance: Return instances instead of configs
|
|
1194
|
+
:param data_loader: Function to load data files
|
|
1195
|
+
:param discard_id: If True, don't use the stored identifier
|
|
1196
|
+
:param partial_loading: If True, skip loading task references. This is
|
|
1197
|
+
useful when loading configurations from disk (e.g., HuggingFace)
|
|
1198
|
+
where the task code may have changed.
|
|
1199
|
+
"""
|
|
1005
1200
|
o = None
|
|
1006
1201
|
objects = {}
|
|
1007
1202
|
import experimaestro.taskglobals as taskglobals
|
|
1008
1203
|
from ..identifier import Identifier
|
|
1009
1204
|
|
|
1205
|
+
# Compute which objects to skip when partial_loading
|
|
1206
|
+
skipped_ids = (
|
|
1207
|
+
ConfigInformation._compute_skipped_ids(definitions)
|
|
1208
|
+
if partial_loading
|
|
1209
|
+
else set()
|
|
1210
|
+
)
|
|
1211
|
+
|
|
1010
1212
|
# Loop over all the definitions and create objects
|
|
1011
1213
|
for definition in definitions:
|
|
1214
|
+
obj_id = definition["id"]
|
|
1215
|
+
|
|
1216
|
+
# Skip objects that are only reachable through task references
|
|
1217
|
+
if obj_id in skipped_ids:
|
|
1218
|
+
# Create a TaskStub for skipped task objects
|
|
1219
|
+
objects[obj_id] = TaskStub(
|
|
1220
|
+
identifier=Identifier.from_state_dict(definition["identifier"]),
|
|
1221
|
+
typename=definition.get("typename", definition["type"]),
|
|
1222
|
+
)
|
|
1223
|
+
continue
|
|
1224
|
+
|
|
1012
1225
|
module_name = definition["module"]
|
|
1013
1226
|
|
|
1014
1227
|
# Avoids problem when running module
|
|
@@ -1041,12 +1254,18 @@ class ConfigInformation:
|
|
|
1041
1254
|
o = cls.__new__(cls)
|
|
1042
1255
|
else:
|
|
1043
1256
|
o = cls.XPMConfig.__new__(cls.XPMConfig)
|
|
1044
|
-
assert
|
|
1045
|
-
objects[
|
|
1257
|
+
assert obj_id not in objects, "Duplicate id %s" % obj_id
|
|
1258
|
+
objects[obj_id] = o
|
|
1046
1259
|
|
|
1047
1260
|
# Now that objects have been created, fill in the fields
|
|
1048
1261
|
for definition in definitions:
|
|
1049
|
-
|
|
1262
|
+
obj_id = definition["id"]
|
|
1263
|
+
|
|
1264
|
+
# Skip processing skipped objects (they are TaskStubs)
|
|
1265
|
+
if obj_id in skipped_ids:
|
|
1266
|
+
continue
|
|
1267
|
+
|
|
1268
|
+
o = objects[obj_id]
|
|
1050
1269
|
xpmtype = o.__getxpmtype__() # type: ObjectType
|
|
1051
1270
|
|
|
1052
1271
|
# If instance...
|
|
@@ -1136,13 +1355,20 @@ class ConfigInformation:
|
|
|
1136
1355
|
data_loader: Optional[SerializedPathLoader] = None,
|
|
1137
1356
|
discard_id: bool = False,
|
|
1138
1357
|
return_tasks: bool = False,
|
|
1358
|
+
partial_loading: Optional[bool] = None,
|
|
1139
1359
|
):
|
|
1360
|
+
# Determine effective partial_loading: as_instance implies partial_loading
|
|
1361
|
+
effective_partial_loading = (
|
|
1362
|
+
partial_loading if partial_loading is not None else as_instance
|
|
1363
|
+
)
|
|
1364
|
+
|
|
1140
1365
|
# Get the objects
|
|
1141
1366
|
objects = ConfigInformation.load_objects(
|
|
1142
1367
|
definitions,
|
|
1143
1368
|
as_instance=as_instance,
|
|
1144
1369
|
data_loader=data_loader,
|
|
1145
1370
|
discard_id=discard_id,
|
|
1371
|
+
partial_loading=effective_partial_loading,
|
|
1146
1372
|
)
|
|
1147
1373
|
|
|
1148
1374
|
# Get the last one
|
|
@@ -1258,9 +1484,98 @@ class ConfigMixin:
|
|
|
1258
1484
|
"""The __xpm__ object contains all instance specific information about a
|
|
1259
1485
|
configuration/task"""
|
|
1260
1486
|
|
|
1487
|
+
# Set when this instance was created via a deprecated config with replace=True
|
|
1488
|
+
_deprecated_from: "ConfigMixin | None" = None
|
|
1489
|
+
|
|
1490
|
+
def __new__(cls, **kwargs):
|
|
1491
|
+
"""Create a new config instance, handling deprecated replacements."""
|
|
1492
|
+
xpmtype = cls.__xpmtype__
|
|
1493
|
+
|
|
1494
|
+
# Check if this is a deprecated type with replace=True
|
|
1495
|
+
if xpmtype._deprecation is not None and xpmtype._deprecation.replace:
|
|
1496
|
+
# Create the deprecated instance normally
|
|
1497
|
+
instance = object.__new__(cls)
|
|
1498
|
+
# Initialize it
|
|
1499
|
+
ConfigMixin.__init__(instance, **kwargs)
|
|
1500
|
+
# Convert to the new type
|
|
1501
|
+
converted = instance.__convert__()
|
|
1502
|
+
# Mark that this came from a deprecated config
|
|
1503
|
+
converted._deprecated_from = instance
|
|
1504
|
+
return converted
|
|
1505
|
+
|
|
1506
|
+
# Normal creation
|
|
1507
|
+
return object.__new__(cls)
|
|
1508
|
+
|
|
1509
|
+
def __getattribute__(self, name: str):
|
|
1510
|
+
"""Get an attribute, handling XPM arguments specially.
|
|
1511
|
+
|
|
1512
|
+
We use __getattribute__ instead of __getattr__ because default values
|
|
1513
|
+
like `b: Param[X] = None` create class attributes that would prevent
|
|
1514
|
+
__getattr__ from being called.
|
|
1515
|
+
"""
|
|
1516
|
+
# Get __xpm__ without recursion
|
|
1517
|
+
try:
|
|
1518
|
+
xpm = object.__getattribute__(self, "__xpm__")
|
|
1519
|
+
except AttributeError:
|
|
1520
|
+
# During early init, __xpm__ may not exist yet
|
|
1521
|
+
return object.__getattribute__(self, name)
|
|
1522
|
+
|
|
1523
|
+
# Check if this is an XPM argument - parameters take precedence
|
|
1524
|
+
xpmtype = object.__getattribute__(self, "__xpmtype__")
|
|
1525
|
+
if name in xpmtype.arguments:
|
|
1526
|
+
return xpm.get(name)
|
|
1527
|
+
|
|
1528
|
+
# Fall back to normal lookup (methods, etc.)
|
|
1529
|
+
return object.__getattribute__(self, name)
|
|
1530
|
+
|
|
1531
|
+
def __setattr__(self, name: str, value):
|
|
1532
|
+
"""Set an attribute, handling XPM arguments specially."""
|
|
1533
|
+
# Allow setting internal attributes directly
|
|
1534
|
+
if name in ("__xpm__", "_deprecated_from"):
|
|
1535
|
+
object.__setattr__(self, name, value)
|
|
1536
|
+
return
|
|
1537
|
+
|
|
1538
|
+
# Check if we have __xpm__ yet (might not during early init)
|
|
1539
|
+
xpm = self.__dict__.get("__xpm__")
|
|
1540
|
+
if xpm is None:
|
|
1541
|
+
object.__setattr__(self, name, value)
|
|
1542
|
+
return
|
|
1543
|
+
|
|
1544
|
+
# Check if this is an XPM argument
|
|
1545
|
+
xpmtype = self.__xpmtype__
|
|
1546
|
+
if name in xpmtype.arguments:
|
|
1547
|
+
# Handle TaggedValue: extract value and add tag
|
|
1548
|
+
if isinstance(value, TaggedValue):
|
|
1549
|
+
actual_value = value.value
|
|
1550
|
+
source = get_caller_location(skip_frames=1)
|
|
1551
|
+
xpm.addtag(name, actual_value, source=source)
|
|
1552
|
+
xpm.set(name, actual_value)
|
|
1553
|
+
else:
|
|
1554
|
+
xpm.set(name, value)
|
|
1555
|
+
return
|
|
1556
|
+
|
|
1557
|
+
# Check for deprecated replacement warning
|
|
1558
|
+
deprecated_from = self.__dict__.get("_deprecated_from")
|
|
1559
|
+
if deprecated_from is not None:
|
|
1560
|
+
deprecated_xpmtype = deprecated_from.__xpmtype__
|
|
1561
|
+
if name in deprecated_xpmtype.arguments:
|
|
1562
|
+
logger.warning(
|
|
1563
|
+
f"Attribute '{name}' was in deprecated config "
|
|
1564
|
+
f"{deprecated_xpmtype.identifier} but is not in "
|
|
1565
|
+
f"{xpmtype.identifier}. The value is being discarded."
|
|
1566
|
+
)
|
|
1567
|
+
return # Don't set the attribute
|
|
1568
|
+
|
|
1569
|
+
# Normal attribute setting
|
|
1570
|
+
object.__setattr__(self, name, value)
|
|
1571
|
+
|
|
1261
1572
|
def __init__(self, **kwargs):
|
|
1262
1573
|
"""Initialize the configuration with the given parameters"""
|
|
1263
1574
|
|
|
1575
|
+
# Skip if already initialized (can happen with deprecated replace=True)
|
|
1576
|
+
if hasattr(self, "__xpm__"):
|
|
1577
|
+
return
|
|
1578
|
+
|
|
1264
1579
|
# Add configuration
|
|
1265
1580
|
xpmtype = self.__xpmtype__
|
|
1266
1581
|
|
|
@@ -1294,7 +1609,8 @@ class ConfigMixin:
|
|
|
1294
1609
|
# Special case of a tagged value
|
|
1295
1610
|
if isinstance(value, TaggedValue):
|
|
1296
1611
|
value = value.value
|
|
1297
|
-
|
|
1612
|
+
# Use _initinfo as source since tag is set at config creation
|
|
1613
|
+
self.__xpm__.addtag(name, value, source=xpm._initinfo)
|
|
1298
1614
|
|
|
1299
1615
|
# Really set the value
|
|
1300
1616
|
xpm.set(name, value)
|
|
@@ -1312,7 +1628,9 @@ class ConfigMixin:
|
|
|
1312
1628
|
)
|
|
1313
1629
|
|
|
1314
1630
|
def tag(self, name, value):
|
|
1315
|
-
|
|
1631
|
+
# Capture caller's location and pass to addtag
|
|
1632
|
+
source = get_caller_location(skip_frames=1)
|
|
1633
|
+
self.__xpm__.addtag(name, value, source=source)
|
|
1316
1634
|
return self
|
|
1317
1635
|
|
|
1318
1636
|
def __eq__(self, other):
|
|
@@ -1372,16 +1690,22 @@ class ConfigMixin:
|
|
|
1372
1690
|
launcher=None,
|
|
1373
1691
|
run_mode: "RunMode" = None,
|
|
1374
1692
|
init_tasks: List["LightweightTask"] = [],
|
|
1693
|
+
max_retries: Optional[int] = None,
|
|
1375
1694
|
):
|
|
1376
1695
|
"""Submit this task
|
|
1377
1696
|
|
|
1378
1697
|
:param workspace: the workspace, defaults to None
|
|
1379
1698
|
:param launcher: The launcher, defaults to None
|
|
1380
1699
|
:param run_mode: Run mode (if None, uses the workspace default)
|
|
1700
|
+
:param max_retries: Maximum number of retries for resumable tasks that timeout (default: from workspace settings or 3)
|
|
1381
1701
|
:return: an object object
|
|
1382
1702
|
"""
|
|
1383
1703
|
return self.__xpm__.submit(
|
|
1384
|
-
workspace,
|
|
1704
|
+
workspace,
|
|
1705
|
+
launcher,
|
|
1706
|
+
run_mode=run_mode,
|
|
1707
|
+
init_tasks=init_tasks,
|
|
1708
|
+
max_retries=max_retries,
|
|
1385
1709
|
)
|
|
1386
1710
|
|
|
1387
1711
|
def stdout(self):
|
|
@@ -1419,6 +1743,59 @@ class ConfigMixin:
|
|
|
1419
1743
|
# Add other dependencies
|
|
1420
1744
|
self.__xpm__.add_dependencies(*other.__xpm__.dependencies)
|
|
1421
1745
|
|
|
1746
|
+
def __rmatmul__(self, other: "ConfigMixin") -> "ConfigMixin":
    """Right-associative composition operator.

    ``B() @ A(x=1)`` is equivalent to ``B(a=A(x=1))``: for the expression
    ``other @ self``, the unique parameter of ``other`` that accepts
    ``self``'s value type is set to ``self``.

    The inner configuration (``self``) is returned so that chains compose as
    expected. Python groups ``Outer() @ Middle() @ Inner()`` as
    ``(Outer() @ Middle()) @ Inner()``; the first step must therefore yield
    ``Middle()`` so that the second step plugs ``Inner()`` into it, building
    ``Outer(middle=Middle(inner=Inner()))``. As a consequence, the value of
    the whole expression is the innermost configuration.

    Only parameters whose declared type exposes ``value_type`` or is a plain
    class are considered; generic types (anything with ``__origin__``, such
    as ``Optional[X]`` or ``List[X]``) are skipped.

    :param other: The outer configuration that will receive self as a parameter
    :return: self (the inner configuration) to continue the chain, or
        ``NotImplemented`` when ``other`` is not a configuration
    :raises ValueError: if no parameter, or more than one parameter, of
        ``other`` accepts ``self``'s type
    """
    if not isinstance(other, ConfigMixin):
        return NotImplemented

    # Find parameters in 'other' that can accept self's type
    self_type = self.__xpmtype__.value_type
    matching_params = []

    for name, argument in other.__xpmtype__.arguments.items():
        # Get the expected type for this argument
        arg_type = argument.type
        if hasattr(arg_type, "value_type"):
            # It's an ObjectType wrapper
            expected_type = arg_type.value_type
        elif hasattr(arg_type, "__origin__"):
            # Generic type like Optional[X] or List[X]: skipped for now
            continue
        elif isinstance(arg_type, type):
            expected_type = arg_type
        else:
            continue

        # Check if self's type is compatible
        if isinstance(expected_type, type) and issubclass(self_type, expected_type):
            matching_params.append(name)

    if not matching_params:
        raise ValueError(
            f"No parameter in {other.__xpmtype__} accepts type {self_type.__name__}"
        )
    if len(matching_params) > 1:
        raise ValueError(
            f"Ambiguous composition: parameters {matching_params} in "
            f"{other.__xpmtype__} all accept type {self_type.__name__}"
        )

    # Set the parameter on 'other'
    other.__xpm__.set(matching_params[0], self)

    # Return the inner configuration (not 'other'): returning the outer one
    # would make the next '@' in a chain target the outermost configuration
    # again instead of the one that was just plugged in.
    return self
|
|
1798
|
+
|
|
1422
1799
|
|
|
1423
1800
|
class Config:
|
|
1424
1801
|
"""Base type for all objects in python interface"""
|
|
@@ -1442,6 +1819,68 @@ class Config:
|
|
|
1442
1819
|
"""Alias for XPMConfig"""
|
|
1443
1820
|
return cls.XPMConfig
|
|
1444
1821
|
|
|
1822
|
+
@classproperty
def XPMValue(cls):
    """Value class associated with this configuration.

    This is the ``value_type`` of the configuration's experimaestro type:
    the explicitly registered value class, or the base config class if no
    value class was registered.
    """
    xpmtype = cls.__getxpmtype__()
    return xpmtype.value_type
|
|
1830
|
+
|
|
1831
|
+
@classmethod
def value_class(cls):
    """Return a decorator registering an external value class for this configuration.

    A separate value class is used when creating instances, which avoids
    initializing heavy resources (e.g., PyTorch) when only configuring.

    .. code-block:: python

        class Model(Config):
            hidden_size: Param[int]

        @Model.value_class()
        class TorchModel(Model, nn.Module):
            def __init__(self):
                super().__init__()
                self.layer = nn.Linear(self.hidden_size, self.hidden_size)

    The decorated class must derive from the configuration class and from
    the value class of every parent configuration that registered one;
    otherwise the decorator raises ``TypeError``.
    """

    def decorator(value_class: type) -> type:
        # Resolve the configuration type first, before any validation
        own_type = cls.__getxpmtype__()

        # The value class has to specialize the configuration class itself
        if not issubclass(value_class, cls):
            raise TypeError(
                f"Value class {value_class.__name__} must be a subclass of "
                f"{cls.__name__}"
            )

        # ... and the explicit value class of each parent configuration
        for parent in cls.__bases__:
            if parent is not Config and issubclass(parent, Config):
                parent_type = parent.__getxpmtype__()
                explicit_value = parent_type.value_type

                # A value type identical to the original type means that
                # nothing was explicitly registered for this parent
                if explicit_value is parent_type._original_type:
                    continue

                if not issubclass(value_class, explicit_value):
                    raise TypeError(
                        f"Value class {value_class.__name__} must be a subclass of "
                        f"parent value class {explicit_value.__name__}"
                    )

        # All checks passed: register and hand the class back unchanged
        own_type.set_value_type(value_class)
        return value_class

    return decorator
|
|
1883
|
+
|
|
1445
1884
|
@classmethod
|
|
1446
1885
|
def __getxpmtype__(cls) -> "ObjectType":
|
|
1447
1886
|
"""Get (and create if necessary) the Object type associated
|
|
@@ -1497,6 +1936,40 @@ class Config:
|
|
|
1497
1936
|
fp.flush()
|
|
1498
1937
|
|
|
1499
1938
|
|
|
1939
|
+
class InstanceConfig(Config):
    """Configuration base class for which object identity is significant.

    Deriving from InstanceConfig rather than Config makes instances
    distinguishable by object identity when they are used in containers:
    a shared instance and two separate instances are told apart even when
    all their parameters are identical.

    Example:
        >>> class SubModel(InstanceConfig):
        ...     value: Param[int] = 100
        >>> class MainModel(Config):
        ...     m1: Param[SubModel]
        ...     m2: Param[SubModel]
        >>>
        >>> sm1 = SubModel.C()
        >>> sm2 = SubModel.C()  # Same params, different instance
        >>>
        >>> # Shared instance (same object used twice)
        >>> shared = MainModel.C(m1=sm1, m2=sm1)
        >>>
        >>> # Separate instances (different objects)
        >>> separate = MainModel.C(m1=sm1, m2=sm2)
        >>>
        >>> # Different identifiers: shared vs separate
        >>> shared.__identifier__() != separate.__identifier__()
        True

    The instance order is determined by the traversal order during
    identifier computation, ensuring reproducibility.
    """
|
|
1971
|
+
|
|
1972
|
+
|
|
1500
1973
|
class LightweightTask(Config):
|
|
1501
1974
|
"""A task that can be run before or after a real task to modify its behaviour"""
|
|
1502
1975
|
|
|
@@ -1559,11 +2032,55 @@ def copyconfig(config: Config, **kwargs):
|
|
|
1559
2032
|
|
|
1560
2033
|
|
|
1561
2034
|
def setmeta(config: Config, flag: bool):
    """Set or clear the meta-parameter flag of a configuration.

    A configuration marked as meta is excluded from the identifier
    computation of its parent configuration.

    Example::

        class Ensemble(Config):
            model1: Param[Model]
            model2: Param[Model]

        # Mark model2 as meta - it won't affect the ensemble's identifier
        model2 = setmeta(Model.C(...), True)
        ensemble = Ensemble.C(model1=model1, model2=model2)

    :param config: The configuration to mark
    :param flag: True to mark as meta, False to include in identifier
    :return: The same configuration (for chaining)
    """
    xpm_info = config.__xpm__
    xpm_info.set_meta(flag)
    return config
|
|
1565
2056
|
|
|
1566
2057
|
|
|
2058
|
+
class ResumableTask(Task):
    """Base class for tasks that can be resumed from a checkpoint.

    A resumable task can be restarted after being stopped by a time limit
    (e.g., SLURM job timeout). The task directory and dynamic outputs are
    preserved across restarts to allow checkpoint recovery.
    """

    def remaining_time(self) -> Optional[float]:
        """Return the number of seconds left before the job times out.

        Useful to write a checkpoint before hitting a time limit
        (e.g., SLURM walltime).

        Returns:
            The remaining time in seconds, or None if:
            - There is no time limit
            - The launcher doesn't support querying remaining time
            - The task is not running
        """
        # No launcher information in the task environment: nothing to query
        info = taskglobals.Env.instance().launcher_info
        return None if info is None else info.remaining_time()
|
|
2082
|
+
|
|
2083
|
+
|
|
1567
2084
|
def cache(fn, name: str):
|
|
1568
2085
|
def __call__(config, *args, **kwargs):
|
|
1569
2086
|
import experimaestro.taskglobals as taskglobals
|