experimaestro 0.22.0__py2.py3-none-any.whl → 0.24.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of experimaestro might be problematic. Click here for more details.
- experimaestro/__init__.py +7 -5
- experimaestro/__main__.py +3 -3
- experimaestro/commandline.py +0 -8
- experimaestro/core/objects.py +218 -164
- experimaestro/core/objects.pyi +25 -11
- experimaestro/core/serializers.py +52 -0
- experimaestro/core/types.py +44 -2
- experimaestro/generators.py +7 -6
- experimaestro/huggingface.py +2 -2
- experimaestro/launchers/__init__.py +19 -7
- experimaestro/scheduler/base.py +21 -3
- experimaestro/server/__init__.py +10 -2
- experimaestro/tests/test_identifier.py +33 -6
- experimaestro/tests/test_instance.py +18 -15
- experimaestro/tests/test_outputs.py +2 -40
- experimaestro/tests/test_progress.py +7 -9
- experimaestro/tests/test_serializers.py +54 -0
- experimaestro/utils/jobs.py +2 -2
- experimaestro/version.py +2 -2
- experimaestro/xpmutils.py +3 -3
- {experimaestro-0.22.0.dist-info → experimaestro-0.24.0.dist-info}/METADATA +26 -3
- {experimaestro-0.22.0.dist-info → experimaestro-0.24.0.dist-info}/RECORD +26 -25
- experimaestro/tests/test_serialization.py +0 -45
- {experimaestro-0.22.0.dist-info → experimaestro-0.24.0.dist-info}/LICENSE +0 -0
- {experimaestro-0.22.0.dist-info → experimaestro-0.24.0.dist-info}/WHEEL +0 -0
- {experimaestro-0.22.0.dist-info → experimaestro-0.24.0.dist-info}/entry_points.txt +0 -0
- {experimaestro-0.22.0.dist-info → experimaestro-0.24.0.dist-info}/top_level.txt +0 -0
experimaestro/core/objects.pyi
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
from abc import ABC
|
|
1
2
|
import typing_extensions
|
|
2
3
|
|
|
3
4
|
from experimaestro.core.types import ObjectType
|
|
4
5
|
import experimaestro
|
|
5
6
|
import io
|
|
7
|
+
from experimaestro.launchers import Launcher
|
|
6
8
|
from experimaestro.scheduler.base import Job
|
|
7
9
|
|
|
8
10
|
from experimaestro.scheduler.workspace import RunMode
|
|
@@ -20,7 +22,18 @@ from experimaestro.core.types import (
|
|
|
20
22
|
from experimaestro.utils import logger as logger
|
|
21
23
|
from functools import cached_property as cached_property
|
|
22
24
|
from pathlib import Path
|
|
23
|
-
from typing import
|
|
25
|
+
from typing import (
|
|
26
|
+
Any,
|
|
27
|
+
Callable,
|
|
28
|
+
ClassVar,
|
|
29
|
+
Dict,
|
|
30
|
+
List,
|
|
31
|
+
Optional,
|
|
32
|
+
Set,
|
|
33
|
+
TypeVar,
|
|
34
|
+
Union,
|
|
35
|
+
overload,
|
|
36
|
+
)
|
|
24
37
|
|
|
25
38
|
T = TypeVar("T", bound="Config")
|
|
26
39
|
|
|
@@ -62,7 +75,7 @@ class TaggedValue:
|
|
|
62
75
|
|
|
63
76
|
def add_to_path(p) -> Generator[None, None, None]: ...
|
|
64
77
|
|
|
65
|
-
class
|
|
78
|
+
class ConfigWalkContext:
|
|
66
79
|
@property
|
|
67
80
|
def path(self) -> None: ...
|
|
68
81
|
def __init__(self) -> None: ...
|
|
@@ -83,9 +96,9 @@ class ConfigProcessing:
|
|
|
83
96
|
def map(self, k: str): ...
|
|
84
97
|
def __call__(self, x): ...
|
|
85
98
|
|
|
86
|
-
class
|
|
99
|
+
class ConfigWalk(ConfigProcessing):
|
|
87
100
|
context: Incomplete
|
|
88
|
-
def __init__(self, context:
|
|
101
|
+
def __init__(self, context: ConfigWalkContext) -> None: ...
|
|
89
102
|
def list(self, i: int): ...
|
|
90
103
|
def map(self, k: str): ...
|
|
91
104
|
|
|
@@ -108,7 +121,7 @@ class ConfigInformation:
|
|
|
108
121
|
def xpmvalues(self, generated: bool = ...) -> Generator[Incomplete, None, None]: ...
|
|
109
122
|
def tags(self): ...
|
|
110
123
|
def validate(self) -> None: ...
|
|
111
|
-
def seal(self, context:
|
|
124
|
+
def seal(self, context: ConfigWalkContext): ...
|
|
112
125
|
@property
|
|
113
126
|
def identifier(self) -> Identifier: ...
|
|
114
127
|
def dependency(self): ...
|
|
@@ -141,13 +154,13 @@ class ConfigInformation:
|
|
|
141
154
|
save_directory: Optional[Path] = ...,
|
|
142
155
|
) -> Config: ...
|
|
143
156
|
|
|
144
|
-
class FromPython(
|
|
157
|
+
class FromPython(ConfigWalk):
|
|
145
158
|
objects: Incomplete
|
|
146
|
-
def __init__(self, context:
|
|
159
|
+
def __init__(self, context: ConfigWalkContext) -> None: ...
|
|
147
160
|
def preprocess(self, config: Config): ...
|
|
148
161
|
def postprocess(self, config: Config, values: Dict[str, Any]): ...
|
|
149
162
|
|
|
150
|
-
def fromConfig(self, context:
|
|
163
|
+
def fromConfig(self, context: ConfigWalkContext): ...
|
|
151
164
|
def add_dependencies(self, *dependencies) -> None: ...
|
|
152
165
|
|
|
153
166
|
def clone(v): ...
|
|
@@ -162,13 +175,14 @@ class TypeConfig:
|
|
|
162
175
|
def __arguments__(self): ...
|
|
163
176
|
def tags(self): ...
|
|
164
177
|
def add_dependencies(self, *dependencies): ...
|
|
165
|
-
def instance(self, context:
|
|
178
|
+
def instance(self, context: ConfigWalkContext = ...) -> T: ...
|
|
166
179
|
def submit(
|
|
167
180
|
self,
|
|
168
181
|
*,
|
|
169
182
|
workspace: Incomplete | None = ...,
|
|
170
183
|
launcher: Incomplete | None = ...,
|
|
171
|
-
run_mode: RunMode =
|
|
184
|
+
run_mode: RunMode = ...,
|
|
185
|
+
pre: List[Task] = []
|
|
172
186
|
): ...
|
|
173
187
|
def stdout(self): ...
|
|
174
188
|
def stderr(self): ...
|
|
@@ -183,7 +197,7 @@ class Config:
|
|
|
183
197
|
__xpmtype__: ClassVar[ObjectType]
|
|
184
198
|
__xpm__: ConfigInformation
|
|
185
199
|
@classmethod
|
|
186
|
-
def __getxpmtype__(cls): ...
|
|
200
|
+
def __getxpmtype__(cls) -> ObjectType: ...
|
|
187
201
|
def __getnewargs_ex__(self): ...
|
|
188
202
|
@classmethod
|
|
189
203
|
def c(cls, **kwargs) -> T: ...
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, TypeVar
|
|
4
|
+
|
|
5
|
+
from experimaestro import Param
|
|
6
|
+
|
|
7
|
+
from .objects import Config, Proxy
|
|
8
|
+
from .arguments import DataPath
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SerializedConfig(Config, Proxy, ABC):
    """A serializable configuration

    This can be used to define a loading mechanism when instantiating the
    configuration: subclasses implement :meth:`initialize`, and
    :meth:`__unwrap__` gives access to the wrapped configuration.
    """

    config: Param[Config]
    """The configuration that will be serialized"""

    registered: List[Config]
    """(execution only) List of configurations that use this serialized config"""

    def __post_init__(self):
        """Run the parent post-initialization, then start with no registered user"""
        super().__post_init__()
        self.registered = []

    def register(self, config: Config):
        """Record ``config`` in the ``registered`` list

        :param config: the configuration to append to ``registered``
        """
        self.registered.append(config)

    def __unwrap__(self):
        """Return the wrapped configuration (the ``config`` parameter)"""
        return self.config

    @abstractmethod
    def initialize(self):
        """Initialize the object

        This might imply loading saved data (e.g. learned models)
        """
        ...
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class PathBasedSerializedConfig(SerializedConfig):
    """A path based serialized configuration

    The serialized data is located through the ``path`` field declared below.
    """

    # NOTE(review): the upstream docstring ended with the unfinished sentence
    # "The most common case it to have"; the intended wording should be
    # confirmed with the authors.
    path: DataPath[Path]
    """Path containing the data"""
|
experimaestro/core/types.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
1
2
|
import inspect
|
|
2
3
|
import sys
|
|
3
|
-
from typing import Union, Dict, Iterator, List
|
|
4
|
+
from typing import Set, Union, Dict, Iterator, List
|
|
4
5
|
from collections import ChainMap
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
import typing
|
|
@@ -17,6 +18,11 @@ if sys.version_info.major == 3 and sys.version_info.minor < 9:
|
|
|
17
18
|
else:
|
|
18
19
|
from typing import _AnnotatedAlias, get_type_hints
|
|
19
20
|
|
|
21
|
+
if typing.TYPE_CHECKING:
|
|
22
|
+
from experimaestro.scheduler.base import Job
|
|
23
|
+
from experimaestro.launchers import Launcher
|
|
24
|
+
from experimaestro.core.objects import Config
|
|
25
|
+
|
|
20
26
|
|
|
21
27
|
class Identifier:
|
|
22
28
|
def __init__(self, name: str):
|
|
@@ -143,7 +149,42 @@ class DeprecatedAttribute:
|
|
|
143
149
|
self.fn(instance, value)
|
|
144
150
|
|
|
145
151
|
|
|
152
|
+
class SubmitHook(ABC):
    """Hook called before the job is submitted to the scheduler

    This allows modifying e.g. the run environment.

    Hooks are compared and hashed by value: two hooks are equal when they are
    instances of exactly the same class and their ``__spec__`` values are
    equal. This lets a set of hooks hold a single copy of equivalent hooks.
    """

    @abstractmethod
    def __call__(self, job: "Job", launcher: "Launcher"):
        """Apply the hook; it receives the job and the launcher"""
        ...

    @abstractmethod
    def __spec__(self):
        """Returns an identifier tuple for hashing/equality"""
        ...

    def _spec_value(self):
        """Return the identifier value used by ``__eq__`` and ``__hash__``

        ``__spec__`` is declared as a method, so it has to be *called*:
        comparing or hashing the bound method itself depends on the identity
        of the instance, which made equivalent hooks always differ. A subclass
        exposing ``__spec__`` as a property (plain value) is accepted as well.
        """
        spec = self.__spec__
        return spec() if callable(spec) else spec

    def __eq__(self, other):
        """Equal when ``other`` has exactly the same class and the same spec"""
        if other.__class__ is not self.__class__:
            return False
        return self._spec_value() == other._spec_value()

    def __hash__(self):
        """Hash consistent with ``__eq__`` (class and spec value)"""
        return hash((self.__class__, self._spec_value()))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def submit_hook_decorator(hook: SubmitHook):
    """Build a class decorator that attaches ``hook`` to a configuration class

    The returned decorator adds ``hook`` to the ``submit_hooks`` collection of
    the object returned by ``cls.__getxpmtype__()`` and gives the class back
    unchanged.
    """

    def decorator(cls: typing.Type["Config"]):
        xpmtype = cls.__getxpmtype__()
        xpmtype.submit_hooks.add(hook)
        return cls

    return decorator
|
|
182
|
+
|
|
183
|
+
|
|
146
184
|
class ObjectType(Type):
|
|
185
|
+
submit_hooks: Set[SubmitHook]
|
|
186
|
+
"""Hooks associated with this configuration"""
|
|
187
|
+
|
|
147
188
|
"""ObjectType contains class-level information about
|
|
148
189
|
experimaestro configurations and tasks
|
|
149
190
|
|
|
@@ -167,6 +208,7 @@ class ObjectType(Type):
|
|
|
167
208
|
self.taskcommandfactory = None
|
|
168
209
|
self.task = None
|
|
169
210
|
self._title = None
|
|
211
|
+
self.submit_hooks = set()
|
|
170
212
|
|
|
171
213
|
# Get the identifier
|
|
172
214
|
if identifier is None and hasattr(tp, "__xpmid__"):
|
|
@@ -219,7 +261,7 @@ class ObjectType(Type):
|
|
|
219
261
|
self.configtype.__module__ = tp.__module__
|
|
220
262
|
|
|
221
263
|
# Create the type-specific object class
|
|
222
|
-
# (now, the same as basetype -
|
|
264
|
+
# (now, the same as basetype - but in the future, remove references)
|
|
223
265
|
self.objecttype = self.basetype # type: type
|
|
224
266
|
self.basetype._ = self.configtype
|
|
225
267
|
|
experimaestro/generators.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Callable,
|
|
3
|
+
from typing import Callable, Union
|
|
4
4
|
from experimaestro.core.arguments import ArgumentOptions, TypeAnnotation
|
|
5
|
-
from experimaestro.core.objects import
|
|
5
|
+
from experimaestro.core.objects import ConfigWalkContext, Config
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class Generator:
|
|
9
9
|
"""Base class for all generators"""
|
|
10
10
|
|
|
11
11
|
def isoutput(self):
|
|
12
|
-
"""Returns True if this generator is a task output (e.g. generates a
|
|
12
|
+
"""Returns True if this generator is a task output (e.g. generates a
|
|
13
|
+
path within the job folder)"""
|
|
13
14
|
return False
|
|
14
15
|
|
|
15
16
|
|
|
@@ -17,11 +18,11 @@ class PathGenerator(Generator):
|
|
|
17
18
|
"""Generates a path"""
|
|
18
19
|
|
|
19
20
|
def __init__(
|
|
20
|
-
self, path: Union[str, Path, Callable[[
|
|
21
|
+
self, path: Union[str, Path, Callable[[ConfigWalkContext, Config], Path]]
|
|
21
22
|
):
|
|
22
23
|
self.path = path
|
|
23
24
|
|
|
24
|
-
def __call__(self, context:
|
|
25
|
+
def __call__(self, context: ConfigWalkContext, config: Config):
|
|
25
26
|
if inspect.isfunction(self.path):
|
|
26
27
|
path = context.currentpath() / self.path(context, config) # type: Path
|
|
27
28
|
else:
|
|
@@ -34,7 +35,7 @@ class PathGenerator(Generator):
|
|
|
34
35
|
|
|
35
36
|
|
|
36
37
|
class pathgenerator(TypeAnnotation):
|
|
37
|
-
def __init__(self, value: Union[str, Callable[[
|
|
38
|
+
def __init__(self, value: Union[str, Callable[[ConfigWalkContext, Config], str]]):
|
|
38
39
|
self.value = value
|
|
39
40
|
|
|
40
41
|
def annotate(self, options: ArgumentOptions):
|
experimaestro/huggingface.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
2
|
from typing import Optional, Union
|
|
3
|
-
from experimaestro import Config,
|
|
3
|
+
from experimaestro import Config, ConfigWrapper
|
|
4
4
|
from experimaestro.core.context import SerializedPath
|
|
5
5
|
from experimaestro.core.objects import ConfigInformation
|
|
6
6
|
from huggingface_hub import ModelHubMixin, hf_hub_download, snapshot_download
|
|
@@ -11,7 +11,7 @@ class ExperimaestroHFHub(ModelHubMixin):
|
|
|
11
11
|
"""Defines models that can be uploaded/downloaded from the Hub"""
|
|
12
12
|
|
|
13
13
|
def __init__(
|
|
14
|
-
self, config: Union[Config,
|
|
14
|
+
self, config: Union[Config, ConfigWrapper], variant: Optional[str] = None
|
|
15
15
|
):
|
|
16
16
|
self.config = config if isinstance(config, Config) else config.__unwrap__()
|
|
17
17
|
self.variant = variant
|
|
@@ -1,25 +1,36 @@
|
|
|
1
1
|
from pathlib import Path, PosixPath
|
|
2
|
-
from typing import Callable, Dict, List, Optional
|
|
2
|
+
from typing import Callable, Dict, List, Optional
|
|
3
3
|
from experimaestro.commandline import AbstractCommand, Job, CommandLineJob
|
|
4
4
|
from experimaestro.connectors import Connector
|
|
5
5
|
from experimaestro.connectors.local import ProcessBuilder, LocalConnector
|
|
6
6
|
from experimaestro.connectors.ssh import SshPath, SshConnector
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
7
8
|
|
|
8
9
|
|
|
9
|
-
class ScriptBuilder:
|
|
10
|
-
lockfiles: List[Path]
|
|
11
|
-
command: "AbstractCommand"
|
|
10
|
+
class ScriptBuilder(ABC):
|
|
12
11
|
"""A script builder is responsible for generating the script
|
|
13
12
|
used to launch a command line job"""
|
|
14
13
|
|
|
14
|
+
lockfiles: List[Path]
|
|
15
|
+
"""The files that must be locked before starting the job"""
|
|
16
|
+
|
|
17
|
+
command: "AbstractCommand"
|
|
18
|
+
"""Command to be run"""
|
|
19
|
+
|
|
20
|
+
@abstractmethod
|
|
15
21
|
def write(self, job: CommandLineJob) -> Path:
|
|
16
|
-
|
|
22
|
+
"""Write the command line job
|
|
23
|
+
|
|
24
|
+
:param job: The job to be written
|
|
25
|
+
"""
|
|
26
|
+
...
|
|
17
27
|
|
|
18
28
|
|
|
19
29
|
SubmitListener = Callable[[Job], None]
|
|
30
|
+
"""Listen to job submissions"""
|
|
20
31
|
|
|
21
32
|
|
|
22
|
-
class Launcher:
|
|
33
|
+
class Launcher(ABC):
|
|
23
34
|
"""A launcher"""
|
|
24
35
|
|
|
25
36
|
submit_listeners: List[SubmitListener]
|
|
@@ -36,9 +47,10 @@ class Launcher:
|
|
|
36
47
|
def setNotificationURL(self, url: Optional[str]):
|
|
37
48
|
self.notificationURL = url
|
|
38
49
|
|
|
50
|
+
@abstractmethod
|
|
39
51
|
def scriptbuilder(self) -> ScriptBuilder:
|
|
40
52
|
"""Returns a script builder"""
|
|
41
|
-
|
|
53
|
+
...
|
|
42
54
|
|
|
43
55
|
def addListener(self, listener: SubmitListener):
|
|
44
56
|
self.submit_listeners.append(listener)
|
experimaestro/scheduler/base.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from collections import defaultdict
|
|
1
|
+
from collections import ChainMap, defaultdict
|
|
2
2
|
import os
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from shutil import rmtree
|
|
@@ -14,7 +14,7 @@ from experimaestro.scheduler.services import Service
|
|
|
14
14
|
from experimaestro.settings import get_settings
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
from experimaestro.core.objects import Config,
|
|
17
|
+
from experimaestro.core.objects import Config, ConfigWalkContext
|
|
18
18
|
from experimaestro.utils import logger
|
|
19
19
|
from experimaestro.locking import Locks, LockError, Lock
|
|
20
20
|
from experimaestro.tokens import ProcessCounterToken
|
|
@@ -151,6 +151,23 @@ class Job(Resource):
|
|
|
151
151
|
assert self._future, "Cannot wait a not submitted job"
|
|
152
152
|
return self._future.result()
|
|
153
153
|
|
|
154
|
+
@property
def environ(self):
    """Layered view of the environment of this job

    Lookups go through the following mappings, the first match wins:

    1. the job-level mapping (created empty here)
    2. the launcher environment (an empty mapping when no launcher is set)
    3. the workspace environment
    """
    job_layer = {}
    if self.launcher:
        launcher_layer = self.launcher.environ
    else:
        launcher_layer = {}
    workspace_layer = self.workspace.environment.environ
    return ChainMap(job_layer, launcher_layer, workspace_layer)
|
|
170
|
+
|
|
154
171
|
@property
|
|
155
172
|
def progress(self):
|
|
156
173
|
return self._progress
|
|
@@ -292,7 +309,7 @@ class Job(Resource):
|
|
|
292
309
|
return self._future
|
|
293
310
|
|
|
294
311
|
|
|
295
|
-
class JobContext(
|
|
312
|
+
class JobContext(ConfigWalkContext):
|
|
296
313
|
def __init__(self, job: Job):
|
|
297
314
|
super().__init__()
|
|
298
315
|
self.job = job
|
|
@@ -839,6 +856,7 @@ class experiment:
|
|
|
839
856
|
def __enter__(self):
|
|
840
857
|
logger.info("Locking experiment %s", self.xplockpath)
|
|
841
858
|
self.xplock = self.workspace.connector.lock(self.xplockpath, 0).__enter__()
|
|
859
|
+
logger.info("Experiment locked")
|
|
842
860
|
|
|
843
861
|
# Move old jobs into "jobs.bak"
|
|
844
862
|
if self.workspace.run_mode == RunMode.NORMAL:
|
experimaestro/server/__init__.py
CHANGED
|
@@ -143,10 +143,15 @@ def proxy_response(base_url: str, request: Request, path: str):
|
|
|
143
143
|
|
|
144
144
|
|
|
145
145
|
def start_app(server: "Server"):
|
|
146
|
+
logging.debug("Starting Flask server...")
|
|
146
147
|
app = Flask("experimaestro")
|
|
147
|
-
|
|
148
|
+
|
|
149
|
+
logging.debug("Starting Flask server (SocketIO)...")
|
|
150
|
+
socketio = SocketIO(app, path="/api", async_mode="gevent")
|
|
148
151
|
listener = Listener(server.scheduler, socketio)
|
|
149
152
|
|
|
153
|
+
logging.debug("Starting Flask server (setting up socketio)...")
|
|
154
|
+
|
|
150
155
|
@socketio.on("connect")
|
|
151
156
|
def handle_connect():
|
|
152
157
|
if server.token != request.cookies.get("experimaestro_token", None):
|
|
@@ -183,6 +188,8 @@ def start_app(server: "Server"):
|
|
|
183
188
|
if process is not None:
|
|
184
189
|
process.kill()
|
|
185
190
|
|
|
191
|
+
logging.debug("Starting Flask server (setting up routes)...")
|
|
192
|
+
|
|
186
193
|
@app.route("/services/<path:path>", methods=["GET", "POST"])
|
|
187
194
|
def route_service(path):
|
|
188
195
|
service, *path = path.split("/", 1)
|
|
@@ -256,8 +263,9 @@ def start_app(server: "Server"):
|
|
|
256
263
|
|
|
257
264
|
# Start the app
|
|
258
265
|
if server.port is None or server.port == 0:
|
|
266
|
+
logging.info("Searching for an available port")
|
|
259
267
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
260
|
-
sock.bind(("
|
|
268
|
+
sock.bind(("", 0))
|
|
261
269
|
server.port = sock.getsockname()[1]
|
|
262
270
|
sock.close()
|
|
263
271
|
|
|
@@ -17,7 +17,8 @@ from experimaestro import (
|
|
|
17
17
|
Annotated,
|
|
18
18
|
Task,
|
|
19
19
|
)
|
|
20
|
-
from experimaestro.core.objects import ConfigInformation,
|
|
20
|
+
from experimaestro.core.objects import ConfigInformation, ConfigWrapper, setmeta
|
|
21
|
+
from experimaestro.core.serializers import SerializedConfig
|
|
21
22
|
from experimaestro.scheduler.workspace import RunMode
|
|
22
23
|
|
|
23
24
|
|
|
@@ -56,7 +57,7 @@ class Values:
|
|
|
56
57
|
|
|
57
58
|
|
|
58
59
|
def getidentifier(x):
    """Return the full identifier (``identifier.all``) of ``x``

    The value is read from ``x.__xpm__``. The previous ``isinstance`` test
    had two branches returning this exact same expression, so it has been
    collapsed; wrapped and plain configurations are handled identically.
    """
    return x.__xpm__.identifier.all
|
|
62
63
|
|
|
@@ -387,6 +388,32 @@ def test_identifier_meta_default_array():
|
|
|
387
388
|
)
|
|
388
389
|
|
|
389
390
|
|
|
391
|
+
# --- Check ConfigWrapper
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class Model(Config):
|
|
395
|
+
def __post_init__(self):
|
|
396
|
+
self.initialized = False
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class Trainer(Config):
|
|
400
|
+
model: Param[Model]
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
class SerializedModel(SerializedConfig):
|
|
404
|
+
def initialize(self):
|
|
405
|
+
self.config.initialized = True
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def test_identifier_serialized_config():
|
|
409
|
+
trainer1 = Trainer(model=Model())
|
|
410
|
+
trainer2 = Trainer(model=SerializedModel(config=Model()))
|
|
411
|
+
assert_notequal(trainer1, trainer2)
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
# --- Check configuration reloads
|
|
415
|
+
|
|
416
|
+
|
|
390
417
|
def check_reload(config):
|
|
391
418
|
old_identifier = config.__xpm__.identifier.all
|
|
392
419
|
|
|
@@ -412,14 +439,14 @@ def test_identifier_reload_config():
|
|
|
412
439
|
check_reload(IdentifierReloadConfig(id="123"))
|
|
413
440
|
|
|
414
441
|
|
|
415
|
-
class
|
|
442
|
+
class IdentifierReloadConfigWrapper(Task):
|
|
416
443
|
id: Param[str]
|
|
417
444
|
|
|
418
445
|
def taskoutputs(self):
|
|
419
446
|
return IdentifierReloadConfig(id=self.id)
|
|
420
447
|
|
|
421
448
|
|
|
422
|
-
class
|
|
449
|
+
class IdentifierReloadConfigWrapperDerived(Config):
|
|
423
450
|
task: Param[IdentifierReloadConfig]
|
|
424
451
|
|
|
425
452
|
|
|
@@ -427,8 +454,8 @@ def test_identifier_reload_taskoutput():
|
|
|
427
454
|
"""When using a task output, the identifier should not be different"""
|
|
428
455
|
|
|
429
456
|
# Creates the configuration
|
|
430
|
-
task =
|
|
431
|
-
config =
|
|
457
|
+
task = IdentifierReloadConfigWrapper(id="123").submit(run_mode=RunMode.DRY_RUN)
|
|
458
|
+
config = IdentifierReloadConfigWrapperDerived(task=task)
|
|
432
459
|
check_reload(config)
|
|
433
460
|
|
|
434
461
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
from experimaestro import config, Param, Config
|
|
3
3
|
from experimaestro.core.objects import TypeConfig
|
|
4
|
+
from experimaestro.core.serializers import SerializedConfig
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
@config()
|
|
@@ -31,27 +32,29 @@ def test_simple_instance():
|
|
|
31
32
|
assert isinstance(b.a, A.__xpmtype__.basetype)
|
|
32
33
|
|
|
33
34
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
35
|
+
class Model(Config):
|
|
36
|
+
def __post_init__(self):
|
|
37
|
+
self.initialized = False
|
|
38
|
+
|
|
38
39
|
|
|
40
|
+
class Trainer(Config):
|
|
41
|
+
model: Param[Model]
|
|
39
42
|
|
|
40
|
-
class TestSerialization:
|
|
41
|
-
"""Test that a config can be serialized during execution"""
|
|
42
43
|
|
|
43
|
-
|
|
44
|
-
|
|
44
|
+
class SerializedModel(SerializedConfig):
|
|
45
|
+
def initialize(self):
|
|
46
|
+
self.config.initialized = True
|
|
45
47
|
|
|
46
|
-
a = SerializedConfig(x=2).instance()
|
|
47
|
-
assert not isinstance(a, TypeConfig)
|
|
48
|
-
assert isinstance(a, SerializedConfig)
|
|
49
48
|
|
|
50
|
-
|
|
49
|
+
def test_instance_serialized():
|
|
50
|
+
model = SerializedModel(config=Model())
|
|
51
|
+
trainer = Trainer(model=model)
|
|
52
|
+
instance = trainer.instance()
|
|
51
53
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
54
|
+
assert isinstance(
|
|
55
|
+
instance.model, Model
|
|
56
|
+
), f"The model is not a Model but a {type(instance.model).__qualname__}"
|
|
57
|
+
assert instance.model.initialized, "The model was not initialized"
|
|
55
58
|
|
|
56
59
|
|
|
57
60
|
class ConfigWithOptional(Config):
|
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
"""Test for task outputs"""
|
|
2
2
|
|
|
3
|
-
from experimaestro import Config, Task, Param
|
|
4
|
-
from experimaestro.core.objects import SerializedConfig, Serialized, TaskOutput
|
|
3
|
+
from experimaestro import Config, Task, Param, ConfigWrapper
|
|
5
4
|
from experimaestro.scheduler.workspace import RunMode
|
|
6
|
-
from experimaestro.tests.utils import TemporaryExperiment
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
class B(Config):
|
|
@@ -14,19 +12,12 @@ class A(Config):
|
|
|
14
12
|
b: Param[B]
|
|
15
13
|
|
|
16
14
|
|
|
17
|
-
class LoaderA(Serialized):
|
|
18
|
-
@staticmethod
|
|
19
|
-
def fromJSON(x) -> A:
|
|
20
|
-
return A(b=B(x=x)).instance()
|
|
21
|
-
|
|
22
|
-
|
|
23
15
|
class Main(Task):
|
|
24
16
|
a: Param[A]
|
|
25
17
|
|
|
26
18
|
def taskoutputs(self):
|
|
27
19
|
return self.a, {
|
|
28
20
|
"a": self.a,
|
|
29
|
-
"serialized": SerializedConfig(self.a, LoaderA(self.a.b.x)),
|
|
30
21
|
}
|
|
31
22
|
|
|
32
23
|
def execute(self):
|
|
@@ -44,8 +35,7 @@ def test_output_taskoutput():
|
|
|
44
35
|
a = A(b=B())
|
|
45
36
|
output, ioutput = Main(a=a).submit(run_mode=RunMode.DRY_RUN)
|
|
46
37
|
|
|
47
|
-
assert isinstance(
|
|
48
|
-
assert isinstance(output, TaskOutput), "outputs should be task proxies"
|
|
38
|
+
assert isinstance(output, ConfigWrapper), "outputs should be task proxies"
|
|
49
39
|
|
|
50
40
|
# Direct
|
|
51
41
|
Main(a=output)
|
|
@@ -58,31 +48,3 @@ def test_output_taskoutput():
|
|
|
58
48
|
|
|
59
49
|
# Now, submits
|
|
60
50
|
Main(a=output).submit(run_mode=RunMode.DRY_RUN)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def test_output_serialization():
|
|
64
|
-
"""Test output serialization"""
|
|
65
|
-
|
|
66
|
-
with TemporaryExperiment("output_serialization", maxwait=5) as xp:
|
|
67
|
-
a = A(b=B(x=2))
|
|
68
|
-
|
|
69
|
-
main0 = Main(a=a)
|
|
70
|
-
output, ioutput = main0.submit()
|
|
71
|
-
|
|
72
|
-
# Direct
|
|
73
|
-
serialized_a = ioutput["serialized"]
|
|
74
|
-
main1 = Main(a=serialized_a)
|
|
75
|
-
main1.submit()
|
|
76
|
-
|
|
77
|
-
# Indirect (via attribute)
|
|
78
|
-
serialized_a = ioutput["serialized"]
|
|
79
|
-
main2 = Main(a=A(b=serialized_a.b))
|
|
80
|
-
main2.submit()
|
|
81
|
-
|
|
82
|
-
xp.wait()
|
|
83
|
-
|
|
84
|
-
for main in (main1, main2):
|
|
85
|
-
assert main.__xpm__.job.stdout.read_text().strip() == "2"
|
|
86
|
-
assert len(main.__xpm__.job.dependencies) == 1
|
|
87
|
-
dep = next(iter(main.__xpm__.job.dependencies))
|
|
88
|
-
assert dep.origin.config is main0
|