experimaestro 1.5.1__py3-none-any.whl → 2.0.0a8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic. Click here for more details.

Files changed (118)
  1. experimaestro/__init__.py +14 -4
  2. experimaestro/__main__.py +3 -423
  3. experimaestro/annotations.py +14 -4
  4. experimaestro/cli/__init__.py +311 -0
  5. experimaestro/{filter.py → cli/filter.py} +23 -9
  6. experimaestro/cli/jobs.py +268 -0
  7. experimaestro/cli/progress.py +269 -0
  8. experimaestro/click.py +0 -35
  9. experimaestro/commandline.py +3 -7
  10. experimaestro/connectors/__init__.py +29 -14
  11. experimaestro/connectors/local.py +19 -10
  12. experimaestro/connectors/ssh.py +27 -8
  13. experimaestro/core/arguments.py +45 -3
  14. experimaestro/core/callbacks.py +52 -0
  15. experimaestro/core/context.py +8 -9
  16. experimaestro/core/identifier.py +310 -0
  17. experimaestro/core/objects/__init__.py +44 -0
  18. experimaestro/core/{objects.py → objects/config.py} +399 -772
  19. experimaestro/core/objects/config_utils.py +58 -0
  20. experimaestro/core/objects/config_walk.py +151 -0
  21. experimaestro/core/objects.pyi +15 -45
  22. experimaestro/core/serialization.py +63 -9
  23. experimaestro/core/serializers.py +1 -8
  24. experimaestro/core/types.py +104 -66
  25. experimaestro/experiments/cli.py +154 -72
  26. experimaestro/experiments/configuration.py +10 -1
  27. experimaestro/generators.py +6 -1
  28. experimaestro/ipc.py +4 -1
  29. experimaestro/launcherfinder/__init__.py +1 -1
  30. experimaestro/launcherfinder/base.py +2 -18
  31. experimaestro/launcherfinder/parser.py +8 -3
  32. experimaestro/launcherfinder/registry.py +52 -140
  33. experimaestro/launcherfinder/specs.py +49 -10
  34. experimaestro/launchers/direct.py +0 -47
  35. experimaestro/launchers/slurm/base.py +54 -14
  36. experimaestro/mkdocs/__init__.py +1 -1
  37. experimaestro/mkdocs/base.py +6 -8
  38. experimaestro/notifications.py +38 -12
  39. experimaestro/progress.py +406 -0
  40. experimaestro/run.py +24 -3
  41. experimaestro/scheduler/__init__.py +18 -1
  42. experimaestro/scheduler/base.py +108 -808
  43. experimaestro/scheduler/dynamic_outputs.py +184 -0
  44. experimaestro/scheduler/experiment.py +387 -0
  45. experimaestro/scheduler/jobs.py +475 -0
  46. experimaestro/scheduler/signal_handler.py +32 -0
  47. experimaestro/scheduler/state.py +75 -0
  48. experimaestro/scheduler/workspace.py +27 -8
  49. experimaestro/scriptbuilder.py +18 -3
  50. experimaestro/server/__init__.py +36 -5
  51. experimaestro/server/data/1815e00441357e01619e.ttf +0 -0
  52. experimaestro/server/data/2463b90d9a316e4e5294.woff2 +0 -0
  53. experimaestro/server/data/2582b0e4bcf85eceead0.ttf +0 -0
  54. experimaestro/server/data/89999bdf5d835c012025.woff2 +0 -0
  55. experimaestro/server/data/914997e1bdfc990d0897.ttf +0 -0
  56. experimaestro/server/data/c210719e60948b211a12.woff2 +0 -0
  57. experimaestro/server/data/index.css +5187 -5068
  58. experimaestro/server/data/index.css.map +1 -1
  59. experimaestro/server/data/index.js +68887 -68064
  60. experimaestro/server/data/index.js.map +1 -1
  61. experimaestro/settings.py +45 -5
  62. experimaestro/sphinx/__init__.py +7 -17
  63. experimaestro/taskglobals.py +7 -2
  64. experimaestro/tests/core/__init__.py +0 -0
  65. experimaestro/tests/core/test_generics.py +206 -0
  66. experimaestro/tests/definitions_types.py +5 -3
  67. experimaestro/tests/launchers/bin/sbatch +34 -7
  68. experimaestro/tests/launchers/bin/srun +5 -0
  69. experimaestro/tests/launchers/common.py +17 -5
  70. experimaestro/tests/launchers/config_slurm/launchers.py +25 -0
  71. experimaestro/tests/restart.py +10 -5
  72. experimaestro/tests/tasks/all.py +23 -10
  73. experimaestro/tests/tasks/foreign.py +2 -4
  74. experimaestro/tests/test_checkers.py +2 -2
  75. experimaestro/tests/test_dependencies.py +11 -17
  76. experimaestro/tests/test_experiment.py +73 -0
  77. experimaestro/tests/test_file_progress.py +425 -0
  78. experimaestro/tests/test_file_progress_integration.py +477 -0
  79. experimaestro/tests/test_findlauncher.py +12 -5
  80. experimaestro/tests/test_forward.py +5 -5
  81. experimaestro/tests/test_generators.py +93 -0
  82. experimaestro/tests/test_identifier.py +182 -158
  83. experimaestro/tests/test_instance.py +19 -27
  84. experimaestro/tests/test_objects.py +13 -20
  85. experimaestro/tests/test_outputs.py +6 -6
  86. experimaestro/tests/test_param.py +68 -30
  87. experimaestro/tests/test_progress.py +4 -4
  88. experimaestro/tests/test_serializers.py +24 -64
  89. experimaestro/tests/test_ssh.py +7 -0
  90. experimaestro/tests/test_tags.py +50 -21
  91. experimaestro/tests/test_tasks.py +42 -51
  92. experimaestro/tests/test_tokens.py +11 -8
  93. experimaestro/tests/test_types.py +24 -21
  94. experimaestro/tests/test_validation.py +67 -110
  95. experimaestro/tests/token_reschedule.py +1 -1
  96. experimaestro/tokens.py +24 -13
  97. experimaestro/tools/diff.py +8 -1
  98. experimaestro/typingutils.py +20 -11
  99. experimaestro/utils/asyncio.py +6 -2
  100. experimaestro/utils/multiprocessing.py +44 -0
  101. experimaestro/utils/resources.py +11 -3
  102. {experimaestro-1.5.1.dist-info → experimaestro-2.0.0a8.dist-info}/METADATA +28 -36
  103. experimaestro-2.0.0a8.dist-info/RECORD +166 -0
  104. {experimaestro-1.5.1.dist-info → experimaestro-2.0.0a8.dist-info}/WHEEL +1 -1
  105. {experimaestro-1.5.1.dist-info → experimaestro-2.0.0a8.dist-info}/entry_points.txt +0 -4
  106. experimaestro/launchers/slurm/cli.py +0 -29
  107. experimaestro/launchers/slurm/configuration.py +0 -597
  108. experimaestro/scheduler/environment.py +0 -94
  109. experimaestro/server/data/016b4a6cdced82ab3aa1.ttf +0 -0
  110. experimaestro/server/data/50701fbb8177c2dde530.ttf +0 -0
  111. experimaestro/server/data/878f31251d960bd6266f.woff2 +0 -0
  112. experimaestro/server/data/b041b1fa4fe241b23445.woff2 +0 -0
  113. experimaestro/server/data/b6879d41b0852f01ed5b.woff2 +0 -0
  114. experimaestro/server/data/d75e3fd1eb12e9bd6655.ttf +0 -0
  115. experimaestro/tests/launchers/config_slurm/launchers.yaml +0 -134
  116. experimaestro/utils/yaml.py +0 -202
  117. experimaestro-1.5.1.dist-info/RECORD +0 -148
  118. {experimaestro-1.5.1.dist-info → experimaestro-2.0.0a8.dist-info/licenses}/LICENSE +0 -0
@@ -1,27 +1,15 @@
1
- # Launcher registry
1
+ # Configuration registers
2
+
3
+ from contextlib import contextmanager
4
+ from typing import ClassVar, Dict, Optional, Set, Type, Union
2
5
 
3
- from dataclasses import dataclass
4
- import itertools
5
- from types import new_class
6
- from typing import ClassVar, Dict, List, Optional, Set, Type, Union
7
- from experimaestro import Annotated
8
6
  from pathlib import Path
9
7
  import typing
10
- import pkg_resources
11
- import humanfriendly
12
- import yaml
13
- from yaml import Loader, Dumper
8
+ from omegaconf import DictConfig, OmegaConf, SCMode
9
+ from importlib.metadata import entry_points
14
10
  from experimaestro.utils import logger
15
- from experimaestro.utils.yaml import (
16
- Initialize,
17
- YAMLDataClass,
18
- YAMLException,
19
- YAMLList,
20
- add_path_resolvers,
21
- )
22
-
23
- from .base import LauncherConfiguration, ConnectorConfiguration, TokenConfiguration
24
- from .specs import CPUSpecification, CudaSpecification, HostRequirement
11
+ from .base import ConnectorConfiguration, TokenConfiguration
12
+ from .specs import HostRequirement, RequirementUnion
25
13
 
26
14
  if typing.TYPE_CHECKING:
27
15
  from experimaestro.launchers import Launcher
@@ -32,80 +20,30 @@ class LauncherNotFoundError(Exception):
32
20
  pass
33
21
 
34
22
 
35
- @dataclass
36
- class GPU(YAMLDataClass):
37
- """Represents a GPU"""
38
-
39
- model: str
40
- count: int
41
- memory: Annotated[int, Initialize(humanfriendly.parse_size)]
42
-
43
- def to_spec(self):
44
- return [CudaSpecification(self.memory, self.model) for _ in range(self.count)]
45
-
46
-
47
- class GPUList(YAMLList[GPU]):
48
- """Represents a list of GPUs"""
49
-
50
- def __repr__(self):
51
- return f"GPUs({super().__repr__()})"
52
-
53
- def to_spec(self) -> List[CudaSpecification]:
54
- return list(itertools.chain(*[gpu.to_spec() for gpu in self]))
55
-
56
-
57
- @dataclass
58
- class CPU(YAMLDataClass):
59
- """Represents a CPU"""
60
-
61
- memory: Annotated[int, Initialize(humanfriendly.parse_size)] = 0
62
- cores: int = 1
63
-
64
- def to_spec(self):
65
- return CPUSpecification(self.memory, self.cores)
66
-
67
-
68
- @dataclass
69
- class Host(YAMLDataClass):
70
- name: str
71
- gpus: List[GPU]
72
- launchers: List[str]
73
-
74
-
75
- Launchers = Dict[str, List[LauncherConfiguration]]
76
23
  Connectors = Dict[str, Dict[str, ConnectorConfiguration]]
77
24
  Tokens = Dict[str, Dict[str, TokenConfiguration]]
78
25
 
79
26
 
80
- def new_loader(name: str) -> Type[Loader]:
81
- return new_class("LauncherLoader", (yaml.FullLoader,)) # type: ignore
82
-
83
-
84
- def load_yaml(loader_cls: Type[Loader], path: Path):
27
+ def load_yaml(schema, path: Path):
85
28
  if not path.is_file():
86
- return None
29
+ return {}
87
30
 
88
- logger.warning(
89
- "Using YAML file to configure launchers is deprecated. Please remove %s using launchers.py",
90
- path,
91
- )
92
31
  logger.debug("Loading %s", path)
93
32
  with path.open("rt") as fp:
94
- loader = loader_cls(fp)
95
- try:
96
- return loader.get_single_data()
97
- finally:
98
- loader.dispose()
33
+ cfg = OmegaConf.load(fp)
34
+ return OmegaConf.to_container(
35
+ OmegaConf.merge(cfg, schema), structured_config_mode=SCMode.INSTANTIATE
36
+ )
99
37
 
100
38
 
101
- def unknown_error(loader: Loader, node):
102
- raise YAMLException(
103
- "",
104
- node.start_mark.name,
105
- node.start_mark.line,
106
- node.start_mark.column,
107
- f"No handler defined for key {node}",
108
- )
39
+ @contextmanager
40
+ def ensure_enter(fp):
41
+ """Behaves as a resource, whether it is one or not"""
42
+ if hasattr(fp, "__enter__"):
43
+ with fp as _fp:
44
+ yield _fp
45
+ else:
46
+ yield fp
109
47
 
110
48
 
111
49
  class LauncherRegistry:
@@ -132,72 +70,47 @@ class LauncherRegistry:
132
70
  LauncherRegistry.CURRENT_CONFIG_DIR = config_dir
133
71
 
134
72
  def __init__(self, basepath: Path):
135
- self.LauncherLoader: Type[Loader] = new_loader("LauncherLoader")
136
- self.ConnectorLoader: Type[Loader] = new_loader("ConnectorLoader")
137
- self.TokenLoader: Type[Loader] = new_loader("TokenLoader")
138
- self.Dumper: Type[Dumper] = new_class("CustomDumper", (Dumper,), {})
73
+ self.connectors_schema = DictConfig({})
74
+ self.tokens_schema = DictConfig({})
139
75
  self.find_launcher_fn = None
140
76
 
141
- # Add safeguards
142
- add_path_resolvers(
143
- self.LauncherLoader,
144
- [],
145
- Dict[str, LauncherConfiguration],
146
- dumper=self.Dumper,
147
- )
148
-
149
77
  # Use entry points for connectors and launchers
150
- for entry_point in pkg_resources.iter_entry_points("experimaestro.connectors"):
151
- entry_point.load().init_registry(self)
152
-
153
- for entry_point in pkg_resources.iter_entry_points("experimaestro.launchers"):
78
+ for entry_point in entry_points(group="experimaestro.connectors"):
154
79
  entry_point.load().init_registry(self)
155
80
 
156
- for entry_point in pkg_resources.iter_entry_points("experimaestro.tokens"):
81
+ for entry_point in entry_points(group="experimaestro.tokens"):
157
82
  entry_point.load().init_registry(self)
158
83
 
159
84
  # Register the find launcher function if it exists
160
85
  launchers_py = basepath / "launchers.py"
86
+ print(f"basepath {launchers_py}")
161
87
  if launchers_py.is_file():
162
88
  logger.info("Loading %s", launchers_py)
163
89
 
164
90
  from importlib import util
165
91
 
166
- spec = util.spec_from_file_location("xpm_launchers_conf", launchers_py)
167
- module = util.module_from_spec(spec)
168
- spec.loader.exec_module(module)
92
+ with ensure_enter(launchers_py.__fspath__()) as fp:
93
+ spec = util.spec_from_file_location("xpm_launchers_conf", fp)
94
+ module = util.module_from_spec(spec)
95
+ spec.loader.exec_module(module)
169
96
 
170
97
  self.find_launcher_fn = getattr(module, "find_launcher", None)
171
98
  if self.find_launcher_fn is None:
172
- logger.warn("No find_launcher() function was found in %s", launchers_py)
99
+ logger.warning(
100
+ "No find_launcher() function was found in %s", launchers_py
101
+ )
173
102
 
174
103
  # Read the configuration file
175
- launchers: Launchers = (
176
- load_yaml(self.LauncherLoader, basepath / "launchers.yaml") or {}
177
- )
178
- self.launchers = sorted(
179
- itertools.chain(*launchers.values()), key=lambda launcher: -launcher.weight
180
- )
181
-
182
- self.connectors: Connectors = (
183
- load_yaml(self.ConnectorLoader, basepath / "connectors.yaml") or {}
184
- )
185
- self.tokens: Tokens = (
186
- load_yaml(self.TokenLoader, basepath / "tokens.yaml") or {}
104
+ self.connectors = load_yaml(
105
+ self.connectors_schema, basepath / "connectors.yaml"
187
106
  )
107
+ self.tokens = load_yaml(self.tokens_schema, basepath / "tokens.yaml")
188
108
 
189
- def register_launcher(self, identifier: str, cls: Type[YAMLDataClass]):
190
- add_path_resolvers(
191
- self.LauncherLoader, [identifier, None], cls, dumper=self.Dumper
192
- )
193
-
194
- def register_connector(self, identifier: str, cls: Type[YAMLDataClass]):
195
- add_path_resolvers(
196
- self.ConnectorLoader, [identifier, None], cls, dumper=self.Dumper
197
- )
109
+ def register_connector(self, identifier: str, cls: Type):
110
+ self.connectors_schema.merge_with({identifier: cls})
198
111
 
199
- def register_token(self, identifier: str, cls: Type[YAMLDataClass]):
200
- add_path_resolvers(self.TokenLoader, [identifier], cls, dumper=self.Dumper)
112
+ def register_token(self, identifier: str, cls: Type):
113
+ self.tokens_schema.merge_with({identifier: cls})
201
114
 
202
115
  def getToken(self, identifier: str) -> "Token":
203
116
  for tokens in self.tokens.values():
@@ -227,7 +140,7 @@ class LauncherRegistry:
227
140
  tags: Restrict the launchers to those containing one of the specified tags
228
141
  """
229
142
 
230
- if len(self.launchers) == 0 and self.find_launcher_fn is None:
143
+ if self.find_launcher_fn is None:
231
144
  logger.info("No launchers.yaml file: using local host ")
232
145
  from experimaestro.launchers.direct import DirectLauncher
233
146
  from experimaestro.connectors.local import LocalConnector
@@ -237,25 +150,24 @@ class LauncherRegistry:
237
150
  # Parse specs
238
151
  from .parser import parse
239
152
 
240
- specs = []
153
+ specs = RequirementUnion()
241
154
  for spec in input_specs:
242
155
  if isinstance(spec, str):
243
- specs.extend(parse(spec))
156
+ specs.add(parse(spec))
244
157
  else:
245
- specs.append(spec)
158
+ specs.add(spec)
246
159
 
247
160
  # Use launcher function
161
+ from experimaestro.launchers import Launcher
162
+
248
163
  if self.find_launcher_fn is not None:
249
- for spec in specs:
250
- if launcher := self.find_launcher_fn(*specs, tags):
164
+ for spec in specs.requirements:
165
+ if launcher := self.find_launcher_fn(spec, tags):
166
+ assert isinstance(
167
+ launcher, Launcher
168
+ ), "f{self.find_launcher_fn} did not return a Launcher but {type(launcher)}"
251
169
  return launcher
252
170
 
253
- # We have registered launchers
254
- for spec in specs:
255
- for handler in self.launchers:
256
- if (not tags) or any((tag in tags) for tag in handler.tags):
257
- if launcher := handler.get(self, spec):
258
- return launcher
259
171
  return None
260
172
 
261
173
 
@@ -1,4 +1,6 @@
1
+ from abc import ABC, abstractmethod
1
2
  import logging
3
+ import math
2
4
  from attr import Factory
3
5
  from attrs import define
4
6
  from copy import copy, deepcopy
@@ -20,9 +22,6 @@ class CudaSpecification:
20
22
  min_memory: int = 0
21
23
  """Minimum request memory (in bytes)"""
22
24
 
23
- def __lt__(self, other: "CudaSpecification"):
24
- return self.memory < other.memory
25
-
26
25
  def match(self, spec: "CudaSpecification"):
27
26
  """Returns True if the specification matches this host"""
28
27
  return (self.memory >= spec.memory) and (self.min_memory <= spec.memory)
@@ -30,7 +29,7 @@ class CudaSpecification:
30
29
  def __repr__(self):
31
30
  return (
32
31
  f"CUDA({self.model} "
33
- f"{format_size(self.memory)}/{format_size(self.min_memory)})"
32
+ f"max={format_size(self.memory, binary=True)}/min={format_size(self.min_memory, binary=True)})"
34
33
  )
35
34
 
36
35
 
@@ -48,8 +47,15 @@ class CPUSpecification:
48
47
  cpu_per_gpu: int = 0
49
48
  """Number of CPU per GPU (0 if not defined)"""
50
49
 
51
- def __lt__(self, other: "CPUSpecification"):
52
- return self.memory < other.memory and self.cores < other.cores
50
+ def __repr__(self):
51
+ return (
52
+ f"CPU("
53
+ f"mem={format_size(self.memory, binary=True)}, cores={self.cores}"
54
+ ")"
55
+ )
56
+
57
+ def match(self, other: "CPUSpecification"):
58
+ return (self.memory >= other.memory) and (self.cores >= other.cores)
53
59
 
54
60
  def total_memory(self, gpus: int = 0):
55
61
  return max(
@@ -91,10 +97,11 @@ class MatchRequirement:
91
97
  requirement: "HostSimpleRequirement"
92
98
 
93
99
 
94
- class HostRequirement:
100
+ class HostRequirement(ABC):
95
101
  """A requirement must be a disjunction of host requirements"""
96
102
 
97
103
  requirements: List["HostSimpleRequirement"]
104
+ """List of requirements (by order of priority)"""
98
105
 
99
106
  def __init__(self) -> None:
100
107
  self.requirements = []
@@ -105,6 +112,12 @@ class HostRequirement:
105
112
  def match(self, host: HostSpecification) -> Optional[MatchRequirement]:
106
113
  raise NotImplementedError()
107
114
 
115
+ @abstractmethod
116
+ def multiply_duration(self, coefficient: float) -> "HostRequirement":
117
+ """Returns a new HostRequirement with a duration multiplied by the
118
+ provided coefficient"""
119
+ ...
120
+
108
121
 
109
122
  class RequirementUnion(HostRequirement):
110
123
  """Ordered list of simple host requirements -- the first one is the priority"""
@@ -114,6 +127,16 @@ class RequirementUnion(HostRequirement):
114
127
  def __init__(self, *requirements: "HostSimpleRequirement"):
115
128
  self.requirements = list(requirements)
116
129
 
130
+ def add(self, requirement: "HostRequirement"):
131
+ match requirement:
132
+ case HostSimpleRequirement():
133
+ self.requirements.extend(*requirement.requirements)
134
+ case RequirementUnion():
135
+ self.requirements.append(requirement)
136
+ case _:
137
+ raise RuntimeError("Cannot handle type %s", type(requirement))
138
+ return self
139
+
117
140
  def match(self, host: HostSpecification) -> Optional[MatchRequirement]:
118
141
  """Returns the matched requirement (if any)"""
119
142
 
@@ -128,12 +151,23 @@ class RequirementUnion(HostRequirement):
128
151
 
129
152
  return argmax
130
153
 
154
+ def multiply_duration(self, coefficient: float) -> "RequirementUnion":
155
+ return RequirementUnion(
156
+ *[r.multiply_duration(coefficient) for r in self.requirements]
157
+ )
158
+
159
+ def __repr__(self):
160
+ return " | ".join(repr(r) for r in self.requirements)
161
+
131
162
 
132
163
  class HostSimpleRequirement(HostRequirement):
133
164
  """Simple host requirement"""
134
165
 
135
166
  cuda_gpus: List["CudaSpecification"]
167
+ """Specification for CUDA gpus"""
168
+
136
169
  cpu: "CPUSpecification"
170
+ """Specification for CPU"""
137
171
 
138
172
  duration: int
139
173
  """Requested duration (in seconds)"""
@@ -141,6 +175,11 @@ class HostSimpleRequirement(HostRequirement):
141
175
  def __repr__(self):
142
176
  return f"Req(cpu={self.cpu}, cuda={self.cuda_gpus}, duration={self.duration})"
143
177
 
178
+ def multiply_duration(self, coefficient: float) -> "HostSimpleRequirement":
179
+ r = HostSimpleRequirement(self)
180
+ r.duration = math.ceil(self.duration * coefficient)
181
+ return r
182
+
144
183
  def __init__(self, *reqs: "HostSimpleRequirement"):
145
184
  self.cuda_gpus = []
146
185
  self.cpu = CPUSpecification(0, 0)
@@ -158,7 +197,7 @@ class HostSimpleRequirement(HostRequirement):
158
197
  self.cpu.cores = max(req.cpu.cores, self.cpu.cores)
159
198
  self.duration = max(req.duration, self.duration)
160
199
  self.cuda_gpus.extend(req.cuda_gpus)
161
- self.cuda_gpus.sort()
200
+ self.cuda_gpus.sort(key=lambda cuda: -cuda.memory)
162
201
 
163
202
  def match(self, host: HostSpecification) -> Optional[MatchRequirement]:
164
203
  if self.cuda_gpus:
@@ -182,7 +221,7 @@ class HostSimpleRequirement(HostRequirement):
182
221
  )
183
222
  return None
184
223
 
185
- if host.cpu < self.cpu:
224
+ if not host.cpu.match(self.cpu):
186
225
  return None
187
226
 
188
227
  if host.max_duration > 0 and self.duration > host.max_duration:
@@ -197,7 +236,7 @@ class HostSimpleRequirement(HostRequirement):
197
236
  _self = deepcopy(self)
198
237
  for _ in range(count - 1):
199
238
  _self.cuda_gpus.extend(self.cuda_gpus)
200
- _self.cuda_gpus.sort()
239
+ _self.cuda_gpus.sort(key=lambda cuda: -cuda.memory)
201
240
 
202
241
  return _self
203
242
 
@@ -1,15 +1,3 @@
1
- from dataclasses import dataclass, field
2
- from functools import cached_property
3
- from typing import Dict, List, Optional
4
- from experimaestro.launcherfinder import (
5
- LauncherConfiguration,
6
- LauncherRegistry,
7
- HostRequirement,
8
- )
9
- from experimaestro.launcherfinder.registry import CPU, GPUList, YAMLDataClass
10
- from experimaestro.launcherfinder.specs import (
11
- HostSpecification,
12
- )
13
1
  from experimaestro.scriptbuilder import PythonScriptBuilder
14
2
  from . import Launcher
15
3
 
@@ -18,40 +6,5 @@ class DirectLauncher(Launcher):
18
6
  def scriptbuilder(self):
19
7
  return PythonScriptBuilder()
20
8
 
21
- @staticmethod
22
- def init_registry(registry: LauncherRegistry):
23
- registry.register_launcher("local", DirectLauncherConfiguration)
24
-
25
9
  def __str__(self):
26
10
  return f"DirectLauncher({self.connector})"
27
-
28
-
29
- @dataclass
30
- class DirectLauncherConfiguration(YAMLDataClass, LauncherConfiguration):
31
- connector: str = "connector"
32
- cpu: CPU = field(default_factory=CPU)
33
- gpus: GPUList = field(default_factory=GPUList)
34
- tokens: Optional[Dict[str, int]] = None
35
- tags: List[str] = field(default_factory=lambda: [])
36
- weight: int = 0
37
- disable: bool = False
38
-
39
- @cached_property
40
- def spec(self) -> HostSpecification:
41
- return HostSpecification(cpu=self.cpu.to_spec(), cuda=self.gpus.to_spec())
42
-
43
- def get(
44
- self, registry: LauncherRegistry, requirement: "HostRequirement"
45
- ) -> Optional[Launcher]:
46
- if requirement.match(self.spec):
47
- launcher = DirectLauncher(connector=registry.getConnector(self.connector))
48
- if self.tokens:
49
- for token_identifier, count in self.tokens.items():
50
- token = registry.getToken(token_identifier)
51
- # TODO: handle the case where this is not a CounterToken
52
- launcher.addListener(
53
- lambda job: job.dependencies.add(token.dependency(count))
54
- )
55
- return launcher
56
-
57
- return None
@@ -11,6 +11,7 @@ from typing import (
11
11
  )
12
12
  from experimaestro.connectors.local import LocalConnector
13
13
  import re
14
+ from shlex import quote as shquote
14
15
  from contextlib import contextmanager
15
16
  from dataclasses import dataclass
16
17
  from experimaestro.launcherfinder.registry import (
@@ -86,7 +87,7 @@ class SlurmProcessWatcher(threading.Thread):
86
87
  WATCHERS: Dict[Tuple[Tuple[str, Any]], "SlurmProcessWatcher"] = {}
87
88
 
88
89
  def __init__(self, launcher: "SlurmLauncher"):
89
- super().__init__()
90
+ super().__init__(daemon=True)
90
91
  self.launcher = launcher
91
92
  self.count = 1
92
93
  self.jobs: Dict[str, SlurmJobState] = {}
@@ -183,10 +184,10 @@ class BatchSlurmProcess(Process):
183
184
  if state and state.finished():
184
185
  return 0 if state.slurm_state == "COMPLETED" else 1
185
186
 
186
- async def aio_state(self):
187
+ async def aio_state(self, timeout: float | None = None) -> ProcessState:
187
188
  def check():
188
189
  with SlurmProcessWatcher.get(self.launcher) as watcher:
189
- jobinfo = watcher.getjob(self.jobid)
190
+ jobinfo = watcher.getjob(self.jobid, timeout=timeout)
190
191
  return jobinfo.state if jobinfo else ProcessState.SCHEDULED
191
192
 
192
193
  return await asyncThreadcheck("slurm.aio_isrunning", check)
@@ -211,7 +212,7 @@ class BatchSlurmProcess(Process):
211
212
 
212
213
  # Checks that the process is running
213
214
  with SlurmProcessWatcher.get(launcher) as watcher:
214
- logger.info("Checking SLURM job %s", process.jobid)
215
+ logger.debug("Checking SLURM job %s", process.jobid)
215
216
  jobinfo = watcher.getjob(process.jobid, timeout=0.1)
216
217
  if jobinfo and jobinfo.state.running:
217
218
  logger.debug(
@@ -235,15 +236,15 @@ class SlurmProcessBuilder(ProcessBuilder):
235
236
  super().__init__()
236
237
  self.launcher = launcher
237
238
 
238
- def start(self) -> BatchSlurmProcess:
239
+ def start(self, task_mode: bool = False) -> BatchSlurmProcess:
239
240
  """Start the process"""
240
241
  builder = self.launcher.connector.processbuilder()
241
- builder.workingDirectory = self.workingDirectory
242
242
  builder.environ = self.launcher.launcherenv
243
243
  builder.detach = False
244
244
 
245
245
  if not self.detach:
246
246
  # Simplest case: we wait for the output
247
+ builder.workingDirectory = self.workingDirectory
247
248
  builder.command = [f"{self.launcher.binpath}/srun"]
248
249
  builder.command.extend(self.launcher.options.args())
249
250
  builder.command.extend(self.command)
@@ -255,14 +256,22 @@ class SlurmProcessBuilder(ProcessBuilder):
255
256
  return builder.start()
256
257
 
257
258
  builder.command = [f"{self.launcher.binpath}/sbatch", "--parsable"]
258
- builder.command.extend(self.launcher.options.args())
259
259
 
260
- addstream(builder.command, "-e", self.stderr)
261
- addstream(builder.command, "-o", self.stdout)
262
- addstream(builder.command, "-i", self.stdin)
260
+ if not task_mode:
261
+ # Use command line parameters when not running a task
262
+ builder.command.extend(self.launcher.options.args())
263
+
264
+ if self.workingDirectory:
265
+ workdir = self.launcher.connector.resolve(self.workingDirectory)
266
+ builder.command.append(f"--chdir={workdir}")
267
+ addstream(builder.command, "-e", self.stderr)
268
+ addstream(builder.command, "-o", self.stdout)
269
+ addstream(builder.command, "-i", self.stdin)
263
270
 
264
271
  builder.command.extend(self.command)
265
- logger.info("slurm sbatch command: %s", builder.command)
272
+ logger.info(
273
+ "slurm sbatch command: %s", " ".join(f'"{s}"' for s in builder.command)
274
+ )
266
275
  handler = OutputCaptureHandler()
267
276
  builder.stdout = Redirect.pipe(handler)
268
277
  builder.stderr = Redirect.inherit()
@@ -425,12 +434,43 @@ class SlurmLauncher(Launcher):
425
434
 
426
435
  We assume *nix, but should be changed to PythonScriptBuilder when working
427
436
  """
428
- builder = PythonScriptBuilder()
429
- builder.processtype = "slurm"
430
- return builder
437
+ return SlurmScriptBuilder(self)
431
438
 
432
439
  def processbuilder(self) -> SlurmProcessBuilder:
433
440
  """Returns the process builder for this launcher
434
441
 
435
442
  By default, returns the associated connector builder"""
436
443
  return SlurmProcessBuilder(self)
444
+
445
+
446
+ class SlurmScriptBuilder(PythonScriptBuilder):
447
+ def __init__(self, launcher: SlurmLauncher, pythonpath=None):
448
+ super().__init__(pythonpath)
449
+ self.launcher = launcher
450
+ self.processtype = "slurm"
451
+
452
+ def write(self, job):
453
+ py_path = super().write(job)
454
+ main_path = py_path.parent
455
+
456
+ def relpath(path: Path):
457
+ return shquote(self.launcher.connector.resolve(path, main_path))
458
+
459
+ # Writes the sbatch shell script containing all the options
460
+ sh_path = job.jobpath / ("%s.sh" % job.name)
461
+ with sh_path.open("wt") as out:
462
+ out.write("""#!/bin/sh\n\n""")
463
+
464
+ workdir = self.launcher.connector.resolve(main_path)
465
+ out.write(f"#SBATCH --chdir={shquote(workdir)}\n")
466
+ out.write(f"""#SBATCH --error={relpath(job.stderr)}\n""")
467
+ out.write(f"""#SBATCH --output={relpath(job.stdout)}\n""")
468
+
469
+ for arg in self.launcher.options.args():
470
+ out.write(f"""#SBATCH {arg}\n""")
471
+
472
+ # We finish by the call to srun
473
+ out.write(f"""\nsrun ./{relpath(py_path)}\n\n""")
474
+
475
+ self.launcher.connector.setExecutable(sh_path, True)
476
+ return sh_path
@@ -1 +1 @@
1
- # from .base import Documentation
1
+ from .base import Documentation # noqa: F401
@@ -4,12 +4,11 @@ See https://www.mkdocs.org/user-guide/plugins/ for plugin API documentation
4
4
  """
5
5
 
6
6
  from collections import defaultdict
7
- import functools
8
7
  import re
9
8
  from experimaestro.mkdocs.annotations import shoulddocument
10
9
  import requests
11
10
  from urllib.parse import urljoin
12
- from experimaestro.core.types import ObjectType, Type
11
+ from experimaestro.core.types import ObjectType
13
12
  import mkdocs
14
13
  from pathlib import Path
15
14
  from typing import Dict, Iterator, List, Optional, Set, Tuple, Type as TypingType
@@ -76,7 +75,7 @@ class ObjectLatticeNode:
76
75
  return f"node({self.objecttype.identifier})"
77
76
 
78
77
  def isAncestor(self, other):
79
- return issubclass(self.objecttype.configtype, other.objecttype.configtype)
78
+ return issubclass(self.objecttype.config_type, other.objecttype.config_type)
80
79
 
81
80
  def _addChild(self, child: "ObjectLatticeNode"):
82
81
  child.parents.add(self)
@@ -321,7 +320,7 @@ class Documentation(mkdocs.plugins.BasePlugin):
321
320
 
322
321
  for node in self.lattice.iter_all():
323
322
  if node.objecttype is not None:
324
- member = node.objecttype.basetype
323
+ member = node.objecttype.value_type
325
324
  qname = f"{member.__module__}.{member.__qualname__}"
326
325
  path = self.type2path[qname]
327
326
 
@@ -354,7 +353,7 @@ class Documentation(mkdocs.plugins.BasePlugin):
354
353
  # Now, sort according to descendant/ascendant relationship or name
355
354
  nodes = set()
356
355
  for _node in cfgs:
357
- if issubclass(_node.objecttype.configtype, xpmtype.configtype):
356
+ if issubclass(_node.objecttype.config_type, xpmtype.config_type):
358
357
  nodes.add(_node)
359
358
 
360
359
  # Removes so they are not generated twice
@@ -443,11 +442,10 @@ class Documentation(mkdocs.plugins.BasePlugin):
443
442
  lines.append("\n\n")
444
443
 
445
444
  for name, argument in xpminfo.arguments.items():
446
-
447
445
  if isinstance(argument.type, ObjectType):
448
- basetype = argument.type.basetype
446
+ value_type = argument.type.value_type
449
447
  typestr = self.getlink(
450
- page.url, f"{basetype.__module__}.{basetype.__qualname__}"
448
+ page.url, f"{value_type.__module__}.{value_type.__qualname__}"
451
449
  )
452
450
  else:
453
451
  typestr = argument.type.name()