experimaestro 1.6.1__py3-none-any.whl → 1.15.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. experimaestro/__init__.py +14 -3
  2. experimaestro/annotations.py +13 -3
  3. experimaestro/cli/filter.py +19 -5
  4. experimaestro/cli/jobs.py +12 -5
  5. experimaestro/commandline.py +3 -7
  6. experimaestro/connectors/__init__.py +27 -12
  7. experimaestro/connectors/local.py +19 -10
  8. experimaestro/connectors/ssh.py +1 -1
  9. experimaestro/core/arguments.py +35 -3
  10. experimaestro/core/callbacks.py +52 -0
  11. experimaestro/core/context.py +8 -9
  12. experimaestro/core/identifier.py +301 -0
  13. experimaestro/core/objects/__init__.py +44 -0
  14. experimaestro/core/{objects.py → objects/config.py} +364 -716
  15. experimaestro/core/objects/config_utils.py +58 -0
  16. experimaestro/core/objects/config_walk.py +151 -0
  17. experimaestro/core/objects.pyi +15 -45
  18. experimaestro/core/serialization.py +63 -9
  19. experimaestro/core/serializers.py +1 -8
  20. experimaestro/core/types.py +61 -6
  21. experimaestro/experiments/cli.py +79 -29
  22. experimaestro/experiments/configuration.py +3 -0
  23. experimaestro/generators.py +6 -1
  24. experimaestro/ipc.py +4 -1
  25. experimaestro/launcherfinder/parser.py +8 -3
  26. experimaestro/launcherfinder/registry.py +29 -10
  27. experimaestro/launcherfinder/specs.py +49 -10
  28. experimaestro/launchers/slurm/base.py +51 -13
  29. experimaestro/mkdocs/__init__.py +1 -1
  30. experimaestro/notifications.py +2 -1
  31. experimaestro/run.py +3 -1
  32. experimaestro/scheduler/base.py +114 -6
  33. experimaestro/scheduler/dynamic_outputs.py +184 -0
  34. experimaestro/scheduler/state.py +75 -0
  35. experimaestro/scheduler/workspace.py +2 -1
  36. experimaestro/scriptbuilder.py +13 -2
  37. experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
  38. experimaestro/server/data/1815e00441357e01619e.ttf +0 -0
  39. experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
  40. experimaestro/server/data/2463b90d9a316e4e5294.woff2 +0 -0
  41. experimaestro/server/data/2582b0e4bcf85eceead0.ttf +0 -0
  42. experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
  43. experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
  44. experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
  45. experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
  46. experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
  47. experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
  48. experimaestro/server/data/89999bdf5d835c012025.woff2 +0 -0
  49. experimaestro/server/data/914997e1bdfc990d0897.ttf +0 -0
  50. experimaestro/server/data/c210719e60948b211a12.woff2 +0 -0
  51. experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
  52. experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
  53. experimaestro/server/data/favicon.ico +0 -0
  54. experimaestro/server/data/index.css +22963 -0
  55. experimaestro/server/data/index.css.map +1 -0
  56. experimaestro/server/data/index.html +27 -0
  57. experimaestro/server/data/index.js +101770 -0
  58. experimaestro/server/data/index.js.map +1 -0
  59. experimaestro/server/data/login.html +22 -0
  60. experimaestro/server/data/manifest.json +15 -0
  61. experimaestro/settings.py +2 -2
  62. experimaestro/sphinx/__init__.py +7 -17
  63. experimaestro/taskglobals.py +7 -2
  64. experimaestro/tests/core/__init__.py +0 -0
  65. experimaestro/tests/core/test_generics.py +206 -0
  66. experimaestro/tests/definitions_types.py +5 -3
  67. experimaestro/tests/launchers/bin/sbatch +34 -7
  68. experimaestro/tests/launchers/bin/srun +5 -0
  69. experimaestro/tests/launchers/common.py +16 -4
  70. experimaestro/tests/restart.py +9 -4
  71. experimaestro/tests/tasks/all.py +23 -10
  72. experimaestro/tests/tasks/foreign.py +2 -4
  73. experimaestro/tests/test_dependencies.py +0 -6
  74. experimaestro/tests/test_experiment.py +73 -0
  75. experimaestro/tests/test_findlauncher.py +11 -4
  76. experimaestro/tests/test_forward.py +5 -5
  77. experimaestro/tests/test_generators.py +93 -0
  78. experimaestro/tests/test_identifier.py +114 -99
  79. experimaestro/tests/test_instance.py +6 -21
  80. experimaestro/tests/test_objects.py +20 -4
  81. experimaestro/tests/test_param.py +60 -22
  82. experimaestro/tests/test_serializers.py +24 -64
  83. experimaestro/tests/test_tags.py +5 -11
  84. experimaestro/tests/test_tasks.py +10 -23
  85. experimaestro/tests/test_tokens.py +3 -2
  86. experimaestro/tests/test_types.py +20 -17
  87. experimaestro/tests/test_validation.py +48 -91
  88. experimaestro/tokens.py +16 -5
  89. experimaestro/typingutils.py +8 -8
  90. experimaestro/utils/asyncio.py +6 -2
  91. experimaestro/utils/multiprocessing.py +44 -0
  92. experimaestro/utils/resources.py +7 -3
  93. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info}/METADATA +27 -34
  94. experimaestro-1.15.2.dist-info/RECORD +159 -0
  95. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info}/WHEEL +1 -1
  96. experimaestro-1.6.1.dist-info/RECORD +0 -122
  97. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info}/entry_points.txt +0 -0
  98. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,73 @@
1
+ from experimaestro import Task, Param, get_experiment, tag
2
+ from experimaestro.tests.utils import TemporaryDirectory, TemporaryExperiment
3
+
4
+
5
+ class TaskA(Task):
6
+ def execute(self):
7
+ pass
8
+
9
+
10
+ class TaskB(Task):
11
+ task_a: Param[TaskA]
12
+ x: Param[int]
13
+
14
+ def execute(self):
15
+ pass
16
+
17
+
18
+ # xp = get_experiment(id="my-xp-1")
19
+
20
+ # # Returns a list of tasks which were submitted and successful
21
+ # tasks = xp.get_tasks(myxps.evaluation.Evaluation, status=Job.DONE)
22
+
23
+ # for task in tasks:
24
+ # # Look at the tags
25
+ # print(task.tags)
26
+
27
+ # # Get some information
28
+ # print("Task ran in {task.workdir}")
29
+
30
+ # # Look at the parent jobs
31
+ # print(task.depends_on)
32
+
33
+ # # Look at the dependant
34
+ # print(task.dependents)
35
+
36
+
37
+ def test_experiment_history():
38
+ """Test retrieving experiment history"""
39
+ with TemporaryDirectory() as workdir:
40
+ with TemporaryExperiment("experiment", workdir=workdir):
41
+ task_a = TaskA().submit()
42
+ TaskB(task_a=task_a, x=tag(1)).submit()
43
+
44
+ # Look at the experiment
45
+ xp = get_experiment("experiment", workdir=workdir)
46
+
47
+ (task_a_info,) = xp.get_jobs(TaskA)
48
+ (task_b_info,) = xp.get_jobs(TaskB)
49
+ assert task_b_info.tags == {"x": 1}
50
+ assert task_b_info.depends_on == [task_a_info]
51
+
52
+
53
+ class FlagHandler:
54
+ def __init__(self):
55
+ self.flag = False
56
+
57
+ def set(self):
58
+ self.flag = True
59
+
60
+ def is_set(self):
61
+ return self.flag
62
+
63
+
64
+ def test_experiment_events():
65
+ """Test handlers"""
66
+
67
+ flag = FlagHandler()
68
+ with TemporaryExperiment("experiment"):
69
+ task_a = TaskA()
70
+ task_a.submit()
71
+ task_a.on_completed(flag.set)
72
+
73
+ assert flag.is_set()
@@ -3,9 +3,9 @@ from experimaestro.launcherfinder.specs import (
3
3
  CPUSpecification,
4
4
  CudaSpecification,
5
5
  HostSpecification,
6
+ RequirementUnion,
6
7
  cpu,
7
8
  cuda_gpu,
8
- HostSimpleRequirement,
9
9
  )
10
10
  from experimaestro.launcherfinder import parse
11
11
  from humanfriendly import parse_size, parse_timespan
@@ -39,6 +39,11 @@ def test_findlauncher_specs():
39
39
  assert m is not None
40
40
  assert m.requirement is req2
41
41
 
42
+ # Multiply
43
+ req2 = req.multiply_duration(2)
44
+ for i in range(2):
45
+ assert req2.requirements[i].duration == req.requirements[i].duration * 2
46
+
42
47
 
43
48
  def test_findlauncher_specs_gpu_mem():
44
49
  host = HostSpecification(
@@ -60,8 +65,10 @@ def test_findlauncher_specs_gpu_mem():
60
65
 
61
66
 
62
67
  def test_findlauncher_parse():
63
- (r,) = parse("""duration=4 d & cuda(mem=4G) * 2 & cpu(mem=400M, cores=4)""")
64
- assert isinstance(r, HostSimpleRequirement)
68
+ r = parse("""duration=4 d & cuda(mem=4G) * 2 & cpu(mem=400M, cores=4)""")
69
+ assert isinstance(r, RequirementUnion)
70
+
71
+ r = r.requirements[0]
65
72
 
66
73
  assert len(r.cuda_gpus) == 2
67
74
  assert r.cuda_gpus[0].memory == parse_size("4G")
@@ -79,7 +86,7 @@ def slurm_constraint_split(constraint: str):
79
86
 
80
87
 
81
88
  def test_findlauncher_slurm():
82
- path = ResourcePathWrapper.create(f"{__package__ }.launchers", "config_slurm")
89
+ path = ResourcePathWrapper.create(f"{__package__}.launchers", "config_slurm")
83
90
 
84
91
  assert (path / "launchers.py").is_file()
85
92
 
@@ -1,12 +1,12 @@
1
- from experimaestro import argument, Config
1
+ from experimaestro import Param, Config
2
2
  from experimaestro.click import forwardoption
3
3
  import click
4
4
 
5
5
 
6
6
  def test_main():
7
- @argument("epochs", type=int, default=100, help="Number of learning epochs")
8
7
  class MyModel(Config):
9
- pass
8
+ epochs: Param[int] = 100
9
+ """Number of learning epochs"""
10
10
 
11
11
  @forwardoption.epochs(MyModel)
12
12
  @click.command()
@@ -18,9 +18,9 @@ def test_main():
18
18
 
19
19
 
20
20
  def test_rename():
21
- @argument("epochs", type=int, default=100, help="Number of learning epochs")
22
21
  class MyModel(Config):
23
- pass
22
+ epochs: Param[int] = 100
23
+ """Number of learning epochs"""
24
24
 
25
25
  @forwardoption.epochs(MyModel, "my-epochs")
26
26
  @click.command()
@@ -0,0 +1,93 @@
1
+ from experimaestro import Config, Task, Param, Meta, Path, field, PathGenerator
2
+ from experimaestro.scheduler.workspace import Workspace
3
+ from experimaestro.settings import Settings, WorkspaceSettings
4
+ import pytest
5
+ from experimaestro.scheduler import RunMode
6
+
7
+
8
+ class Validation(Config):
9
+ best_checkpoint: Meta[Path] = field(default_factory=PathGenerator("index"))
10
+
11
+
12
+ class Learner(Task):
13
+ validation: Param[Validation]
14
+ x: Param[int]
15
+
16
+ @staticmethod
17
+ def create(x: int, validation: Param[Validation]):
18
+ return Learner.C(x=x, validation=validation)
19
+
20
+
21
+ class LearnerList(Task):
22
+ validation: Param[list[Validation]]
23
+ x: Param[int]
24
+
25
+ @staticmethod
26
+ def create(x: int, validation: Param[Validation]):
27
+ return LearnerList.C(x=x, validation=[validation])
28
+
29
+
30
+ class LearnerDict(Task):
31
+ validation: Param[dict[str, Validation]]
32
+ x: Param[int]
33
+
34
+ @staticmethod
35
+ def create(x: int, validation: Param[Validation]):
36
+ return LearnerDict.C(x=x, validation={"key": validation})
37
+
38
+
39
+ class ModuleLoader(Task):
40
+ validation: Param[Validation] = field(ignore_generated=True)
41
+
42
+
43
+ @pytest.mark.parametrize("cls", [Learner, LearnerDict, LearnerList])
44
+ def test_generators_reuse_on_submit(cls):
45
+ # We have one way to select the best model
46
+ validation = Validation.C()
47
+
48
+ workspace = Workspace(
49
+ Settings(),
50
+ WorkspaceSettings("test_generators_reuse", path=Path("/tmp")),
51
+ run_mode=RunMode.DRY_RUN,
52
+ )
53
+
54
+ # OK, the path is generated depending on Learner with x=1
55
+ cls.create(1, validation).submit(workspace=workspace)
56
+
57
+ with pytest.raises((AttributeError)):
58
+ # Here we have a problem...
59
+ # the path is still the previous one
60
+ cls.create(2, validation).submit(workspace=workspace)
61
+
62
+
63
+ @pytest.mark.parametrize("cls", [Learner, LearnerDict, LearnerList])
64
+ def test_generators_delayed_submit(cls):
65
+ workspace = Workspace(
66
+ Settings(),
67
+ WorkspaceSettings("test_generators_simple", path=Path("/tmp")),
68
+ run_mode=RunMode.DRY_RUN,
69
+ )
70
+ validation = Validation.C()
71
+ task1 = cls.create(1, validation)
72
+ task2 = cls.create(2, validation)
73
+ task1.submit(workspace=workspace)
74
+ with pytest.raises((AttributeError)):
75
+ task2.submit(workspace=workspace)
76
+
77
+
78
+ @pytest.mark.parametrize("cls", [Learner, LearnerDict, LearnerList])
79
+ def test_generators_reuse_on_set(cls):
80
+ workspace = Workspace(
81
+ Settings(),
82
+ WorkspaceSettings("test_generators_simple", path=Path("/tmp")),
83
+ run_mode=RunMode.DRY_RUN,
84
+ )
85
+ validation = Validation.C()
86
+ cls.create(1, validation).submit(workspace=workspace)
87
+ with pytest.raises((AttributeError)):
88
+ # We should not be able to *create* a second task with the same validation,
89
+ # even without submitting it
90
+ cls.create(2, validation)
91
+
92
+ # This should run OK
93
+ ModuleLoader.C(validation=validation)
@@ -4,16 +4,14 @@ import json
4
4
  from pathlib import Path
5
5
  from typing import Dict, List, Optional
6
6
  from experimaestro import (
7
- config,
8
7
  Param,
9
- param,
10
8
  deprecate,
11
9
  Config,
12
10
  Constant,
13
11
  Meta,
14
12
  Option,
15
- pathgenerator,
16
- Annotated,
13
+ PathGenerator,
14
+ field,
17
15
  Task,
18
16
  LightweightTask,
19
17
  )
@@ -40,6 +38,11 @@ class C(Config):
40
38
  b: Param[int]
41
39
 
42
40
 
41
+ class CField(Config):
42
+ a: Param[int] = field(default_factory=lambda: 1)
43
+ b: Param[int]
44
+
45
+
43
46
  class D(Config):
44
47
  a: Param[A]
45
48
 
@@ -48,8 +51,7 @@ class Float(Config):
48
51
  value: Param[float]
49
52
 
50
53
 
51
- @config()
52
- class Values:
54
+ class Values(Config):
53
55
  value1: Param[float]
54
56
  value2: Param[float]
55
57
 
@@ -66,50 +68,54 @@ def assert_notequal(a, b, message=""):
66
68
  assert getidentifier(a) != getidentifier(b), message
67
69
 
68
70
 
69
- def test_int():
70
- assert_equal(A(a=1), A(a=1))
71
+ def test_identifier_int():
72
+ assert_equal(A.C(a=1), A.C(a=1))
73
+
74
+
75
+ def test_identifier_different_type():
76
+ assert_notequal(A.C(a=1), B.C(a=1))
71
77
 
72
78
 
73
- def test_different_type():
74
- assert_notequal(A(a=1), B(a=1))
79
+ def test_identifier_order():
80
+ assert_equal(Values.C(value1=1, value2=2), Values.C(value2=2, value1=1))
75
81
 
76
82
 
77
- def test_order():
78
- assert_equal(Values(value1=1, value2=2), Values(value2=2, value1=1))
83
+ def test_identifier_default():
84
+ assert_equal(C.C(a=1, b=2), C.C(b=2))
79
85
 
80
86
 
81
- def test_default():
82
- assert_equal(C(a=1, b=2), C(b=2))
87
+ def test_identifier_default_field():
88
+ assert_equal(CField(a=1, b=2), CField(b=2))
83
89
 
84
90
 
85
- def test_inner_eq():
86
- assert_equal(D(a=A(a=1)), D(a=A(a=1)))
91
+ def test_identifier_inner_eq():
92
+ assert_equal(D.C(a=A.C(a=1)), D.C(a=A.C(a=1)))
87
93
 
88
94
 
89
- def test_float():
90
- assert_equal(Float(value=1), Float(value=1))
95
+ def test_identifier_float():
96
+ assert_equal(Float.C(value=1), Float.C(value=1))
91
97
 
92
98
 
93
- def test_float2():
94
- assert_equal(Float(value=1.0), Float(value=1))
99
+ def test_identifier_float2():
100
+ assert_equal(Float.C(value=1.0), Float.C(value=1))
95
101
 
96
102
 
97
103
  # --- Argument name
98
104
 
99
105
 
100
- def test_name():
106
+ def test_identifier_name():
101
107
  """The identifier fully determines the hash code"""
102
108
 
103
- @config("test.identifier.argumentname")
104
- class Config0:
109
+ class Config0(Config):
110
+ __xpmid__ = "test.identifier.argumentname"
105
111
  a: Param[int]
106
112
 
107
- @config("test.identifier.argumentname")
108
- class Config1:
113
+ class Config1(Config):
114
+ __xpmid__ = "test.identifier.argumentname"
109
115
  b: Param[int]
110
116
 
111
- @config("test.identifier.argumentname")
112
- class Config3:
117
+ class Config3(Config):
118
+ __xpmid__ = "test.identifier.argumentname"
113
119
  a: Param[int]
114
120
 
115
121
  assert_notequal(Config0(a=2), Config1(b=2))
@@ -119,9 +125,9 @@ def test_name():
119
125
  # --- Test option
120
126
 
121
127
 
122
- def test_option():
123
- @config("test.identifier.option")
124
- class OptionConfig:
128
+ def test_identifier_option():
129
+ class OptionConfig(Config):
130
+ __xpmid__ = "test.identifier.option"
125
131
  a: Param[int]
126
132
  b: Option[int] = 1
127
133
 
@@ -152,13 +158,12 @@ def test_identifier_dict():
152
158
  # --- Ignore paths
153
159
 
154
160
 
155
- @config()
156
- class TypeWithPath:
161
+ class TypeWithPath(Config):
157
162
  a: Param[int]
158
163
  path: Param[Path]
159
164
 
160
165
 
161
- def test_path():
166
+ def test_identifier_path():
162
167
  """Path should be ignored"""
163
168
  assert_equal(TypeWithPath(a=1, path="/a/b"), TypeWithPath(a=1, path="/c/d"))
164
169
  assert_notequal(TypeWithPath(a=2, path="/a/b"), TypeWithPath(a=1, path="/c/d"))
@@ -167,23 +172,23 @@ def test_path():
167
172
  # --- Test with added arguments
168
173
 
169
174
 
170
- def test_pathoption():
175
+ def test_identifier_pathoption():
171
176
  """Path arguments should be ignored"""
172
177
 
173
- @config("pathoption_test")
174
- class A_with_path:
178
+ class A_with_path(Config):
179
+ __xpmid__ = "pathoption_test"
175
180
  a: Param[int]
176
- path: Annotated[Path, pathgenerator("path")]
181
+ path: Meta[Path] = field(default_factory=PathGenerator("path"))
177
182
 
178
- @config("pathoption_test")
179
- class A_without_path:
183
+ class A_without_path(Config):
184
+ __xpmid__ = "pathoption_test"
180
185
  a: Param[int]
181
186
 
182
187
  assert_equal(A_with_path(a=1), A_without_path(a=1))
183
188
 
184
189
 
185
190
  def test_identifier_enum():
186
- """Path arguments should be ignored"""
191
+ """test enum parameters"""
187
192
  from enum import Enum
188
193
 
189
194
  class EnumParam(Enum):
@@ -214,25 +219,24 @@ def test_identifier_addnone():
214
219
  assert_notequal(A_with_b(b=B(x=1)), A())
215
220
 
216
221
 
217
- def test_defaultnew():
222
+ def test_identifier_defaultnew():
218
223
  """Path arguments should be ignored"""
219
224
 
220
- @param("b", type=int, default=1)
221
- @param(name="a", type=int)
222
- @config("defaultnew")
223
- class A_with_b:
224
- pass
225
+ class A_with_b(Config):
226
+ __xpmid__ = "defaultnew"
225
227
 
226
- @param(name="a", type=int)
227
- @config("defaultnew")
228
- class A:
229
- pass
228
+ a: Param[int]
229
+ b: Param[int] = 1
230
+
231
+ class A(Config):
232
+ __xpmid__ = "defaultnew"
233
+ a: Param[int]
230
234
 
231
235
  assert_equal(A_with_b(a=1, b=1), A(a=1))
232
236
  assert_equal(A_with_b(a=1), A(a=1))
233
237
 
234
238
 
235
- def test_taskconfigidentifier():
239
+ def test_identifier_taskconfigidentifier():
236
240
  """Test whether the embedded task arguments make the configuration different"""
237
241
 
238
242
  class MyConfig(Config):
@@ -254,21 +258,21 @@ def test_taskconfigidentifier():
254
258
  )
255
259
 
256
260
 
257
- def test_constant():
261
+ def test_identifier_constant():
258
262
  """Test if constants are taken into account for signature computation"""
259
263
 
260
- @config("test.constant")
261
- class A1:
264
+ class A1(Config):
265
+ __xpmid__ = "test.constant"
262
266
  version: Constant[int] = 1
263
267
 
264
- @config("test.constant")
265
- class A1bis:
268
+ class A1bis(Config):
269
+ __xpmid__ = "test.constant"
266
270
  version: Constant[int] = 1
267
271
 
268
272
  assert_equal(A1(), A1bis())
269
273
 
270
- @config("test.constant")
271
- class A2:
274
+ class A2(Config):
275
+ __xpmid__ = "test.constant"
272
276
  version: Constant[int] = 2
273
277
 
274
278
  assert_notequal(A1(), A2())
@@ -382,36 +386,6 @@ def test_identifier_meta_default_array():
382
386
  )
383
387
 
384
388
 
385
- def test_identifier_pre_task():
386
- class MyConfig(Config):
387
- pass
388
-
389
- class IdentifierPreLightTask(LightweightTask):
390
- pass
391
-
392
- class IdentifierPreTask(Task):
393
- x: Param[MyConfig]
394
-
395
- task = IdentifierPreTask(x=MyConfig()).submit(run_mode=RunMode.DRY_RUN)
396
- task_with_pre = (
397
- IdentifierPreTask(x=MyConfig())
398
- .add_pretasks(IdentifierPreLightTask())
399
- .submit(run_mode=RunMode.DRY_RUN)
400
- )
401
- task_with_pre_2 = (
402
- IdentifierPreTask(x=MyConfig())
403
- .add_pretasks(IdentifierPreLightTask())
404
- .submit(run_mode=RunMode.DRY_RUN)
405
- )
406
- task_with_pre_3 = IdentifierPreTask(
407
- x=MyConfig().add_pretasks(IdentifierPreLightTask())
408
- ).submit(run_mode=RunMode.DRY_RUN)
409
-
410
- assert_notequal(task, task_with_pre, "No pre-task")
411
- assert_equal(task_with_pre, task_with_pre_2, "Same parameters")
412
- assert_equal(task_with_pre, task_with_pre_3, "Pre-tasks are order-less")
413
-
414
-
415
389
  def test_identifier_init_task():
416
390
  class MyConfig(Config):
417
391
  pass
@@ -422,26 +396,67 @@ def test_identifier_init_task():
422
396
  class IdentifierInitTask2(Task):
423
397
  pass
424
398
 
425
- class IdentierTask(Task):
399
+ class IdentifierTask(Task):
426
400
  x: Param[MyConfig]
427
401
 
428
- task = IdentierTask(x=MyConfig()).submit(run_mode=RunMode.DRY_RUN)
429
- task_with_pre = IdentierTask(x=MyConfig()).submit(
402
+ task = IdentifierTask.C(x=MyConfig.C()).submit(run_mode=RunMode.DRY_RUN)
403
+ task_with_pre = IdentifierTask.C(x=MyConfig.C()).submit(
430
404
  run_mode=RunMode.DRY_RUN,
431
405
  init_tasks=[IdentifierInitTask(), IdentifierInitTask2()],
432
406
  )
433
- task_with_pre_2 = IdentierTask(x=MyConfig()).submit(
407
+ task_with_pre_2 = IdentifierTask.C(x=MyConfig.C()).submit(
434
408
  run_mode=RunMode.DRY_RUN,
435
409
  init_tasks=[IdentifierInitTask(), IdentifierInitTask2()],
436
410
  )
437
- task_with_pre_3 = IdentierTask(x=MyConfig()).submit(
411
+ task_with_pre_3 = IdentifierTask.C(x=MyConfig.C()).submit(
438
412
  run_mode=RunMode.DRY_RUN,
439
413
  init_tasks=[IdentifierInitTask2(), IdentifierInitTask()],
440
414
  )
441
415
 
442
- assert_notequal(task, task_with_pre, "No pre-task")
416
+ assert_notequal(task, task_with_pre, "Should be different with init-task")
443
417
  assert_equal(task_with_pre, task_with_pre_2, "Same parameters")
444
- assert_notequal(task_with_pre, task_with_pre_3, "Same parameters")
418
+ assert_notequal(task_with_pre, task_with_pre_3, "Other parameters")
419
+
420
+
421
+ def test_identifier_init_task_dep():
422
+ class Loader(LightweightTask):
423
+ param1: Param[float]
424
+
425
+ def execute(self):
426
+ pass
427
+
428
+ class FirstTask(Task):
429
+ def task_outputs(self, dep):
430
+ return dep(Loader.C(param1=1))
431
+
432
+ def execute(self):
433
+ pass
434
+
435
+ class SecondTask(Task):
436
+ param3: Param[int]
437
+
438
+ def execute(self):
439
+ pass
440
+
441
+ # Two identical tasks
442
+ task_a_1 = FirstTask.C()
443
+ task_a_2 = FirstTask.C()
444
+ assert_equal(task_a_1, task_a_2)
445
+
446
+ # We process them with two different init tasks
447
+ loader_1 = task_a_1.submit(
448
+ init_tasks=[Loader.C(param1=0.5)], run_mode=RunMode.DRY_RUN
449
+ )
450
+ loader_2 = task_a_2.submit(
451
+ init_tasks=[Loader.C(param1=5)], run_mode=RunMode.DRY_RUN
452
+ )
453
+ assert_notequal(loader_1, loader_2)
454
+
455
+ # Now, we process
456
+ c_1 = SecondTask.C(param3=2).submit(init_tasks=[loader_1], run_mode=RunMode.DRY_RUN)
457
+
458
+ c_2 = SecondTask.C(param3=2).submit(init_tasks=[loader_2], run_mode=RunMode.DRY_RUN)
459
+ assert_notequal(c_1, c_2)
445
460
 
446
461
 
447
462
  # --- Check configuration reloads
@@ -459,7 +474,7 @@ def check_reload(config):
459
474
  new_config = ConfigInformation.fromParameters(
460
475
  data, as_instance=False, discard_id=True
461
476
  )
462
- assert new_config.__xpm__._full_identifier is None
477
+ assert new_config.__xpm__._identifier is None
463
478
  new_identifier = new_config.__xpm__.identifier.all
464
479
 
465
480
  assert new_identifier == old_identifier
@@ -542,9 +557,9 @@ class LoopC(Config):
542
557
 
543
558
 
544
559
  def test_identifier_loop():
545
- c = LoopC()
546
- b = LoopB(param_c=c)
547
- a = LoopA(param_b=b)
560
+ c = LoopC.C()
561
+ b = LoopB.C(param_c=c)
562
+ a = LoopA.C(param_b=b)
548
563
  c.param_a = a
549
564
  c.param_b = b
550
565
 
@@ -1,21 +1,18 @@
1
1
  from typing import Optional
2
- from experimaestro import config, Param, Config
3
- from experimaestro.core.objects import TypeConfig
2
+ from experimaestro import Param, Config
3
+ from experimaestro.core.objects import ConfigMixin
4
4
  from experimaestro.core.serializers import SerializationLWTask
5
5
 
6
6
 
7
- @config()
8
- class A:
7
+ class A(Config):
9
8
  x: Param[int] = 1
10
9
 
11
10
 
12
- @config()
13
11
  class A1(A):
14
12
  pass
15
13
 
16
14
 
17
- @config()
18
- class B:
15
+ class B(Config):
19
16
  a: Param[A]
20
17
 
21
18
 
@@ -24,10 +21,10 @@ def test_simple_instance():
24
21
  b = B(a=a)
25
22
  b = b.instance()
26
23
 
27
- assert not isinstance(b, TypeConfig)
24
+ assert not isinstance(b, ConfigMixin)
28
25
  assert isinstance(b, B.__xpmtype__.objecttype)
29
26
 
30
- assert not isinstance(b.a, TypeConfig)
27
+ assert not isinstance(b.a, ConfigMixin)
31
28
  assert isinstance(b.a, A1.__xpmtype__.objecttype)
32
29
  assert isinstance(b.a, A.__xpmtype__.basetype)
33
30
 
@@ -49,18 +46,6 @@ class LoadModel(SerializationLWTask):
49
46
  self.value.initialized = True
50
47
 
51
48
 
52
- def test_instance_serialized():
53
- model = Model()
54
- model.add_pretasks(LoadModel(value=model))
55
- trainer = Evaluator(model=model)
56
- instance = trainer.instance()
57
-
58
- assert isinstance(
59
- instance.model, Model
60
- ), f"The model is not a Model but a {type(instance.model).__qualname__}"
61
- assert instance.model.initialized, "The model was not initialized"
62
-
63
-
64
49
  class ConfigWithOptional(Config):
65
50
  x: Param[int] = 1
66
51
  y: Param[Optional[int]]