experimaestro 1.6.1__py3-none-any.whl → 1.15.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. experimaestro/__init__.py +14 -3
  2. experimaestro/annotations.py +13 -3
  3. experimaestro/cli/filter.py +19 -5
  4. experimaestro/cli/jobs.py +12 -5
  5. experimaestro/commandline.py +3 -7
  6. experimaestro/connectors/__init__.py +27 -12
  7. experimaestro/connectors/local.py +19 -10
  8. experimaestro/connectors/ssh.py +1 -1
  9. experimaestro/core/arguments.py +35 -3
  10. experimaestro/core/callbacks.py +52 -0
  11. experimaestro/core/context.py +8 -9
  12. experimaestro/core/identifier.py +301 -0
  13. experimaestro/core/objects/__init__.py +44 -0
  14. experimaestro/core/{objects.py → objects/config.py} +364 -716
  15. experimaestro/core/objects/config_utils.py +58 -0
  16. experimaestro/core/objects/config_walk.py +151 -0
  17. experimaestro/core/objects.pyi +15 -45
  18. experimaestro/core/serialization.py +63 -9
  19. experimaestro/core/serializers.py +1 -8
  20. experimaestro/core/types.py +61 -6
  21. experimaestro/experiments/cli.py +79 -29
  22. experimaestro/experiments/configuration.py +3 -0
  23. experimaestro/generators.py +6 -1
  24. experimaestro/ipc.py +4 -1
  25. experimaestro/launcherfinder/parser.py +8 -3
  26. experimaestro/launcherfinder/registry.py +29 -10
  27. experimaestro/launcherfinder/specs.py +49 -10
  28. experimaestro/launchers/slurm/base.py +51 -13
  29. experimaestro/mkdocs/__init__.py +1 -1
  30. experimaestro/notifications.py +2 -1
  31. experimaestro/run.py +3 -1
  32. experimaestro/scheduler/base.py +114 -6
  33. experimaestro/scheduler/dynamic_outputs.py +184 -0
  34. experimaestro/scheduler/state.py +75 -0
  35. experimaestro/scheduler/workspace.py +2 -1
  36. experimaestro/scriptbuilder.py +13 -2
  37. experimaestro/server/data/0c35d18bf06992036b69.woff2 +0 -0
  38. experimaestro/server/data/1815e00441357e01619e.ttf +0 -0
  39. experimaestro/server/data/219aa9140e099e6c72ed.woff2 +0 -0
  40. experimaestro/server/data/2463b90d9a316e4e5294.woff2 +0 -0
  41. experimaestro/server/data/2582b0e4bcf85eceead0.ttf +0 -0
  42. experimaestro/server/data/3a4004a46a653d4b2166.woff +0 -0
  43. experimaestro/server/data/3baa5b8f3469222b822d.woff +0 -0
  44. experimaestro/server/data/4d73cb90e394b34b7670.woff +0 -0
  45. experimaestro/server/data/4ef4218c522f1eb6b5b1.woff2 +0 -0
  46. experimaestro/server/data/5d681e2edae8c60630db.woff +0 -0
  47. experimaestro/server/data/6f420cf17cc0d7676fad.woff2 +0 -0
  48. experimaestro/server/data/89999bdf5d835c012025.woff2 +0 -0
  49. experimaestro/server/data/914997e1bdfc990d0897.ttf +0 -0
  50. experimaestro/server/data/c210719e60948b211a12.woff2 +0 -0
  51. experimaestro/server/data/c380809fd3677d7d6903.woff2 +0 -0
  52. experimaestro/server/data/f882956fd323fd322f31.woff +0 -0
  53. experimaestro/server/data/favicon.ico +0 -0
  54. experimaestro/server/data/index.css +22963 -0
  55. experimaestro/server/data/index.css.map +1 -0
  56. experimaestro/server/data/index.html +27 -0
  57. experimaestro/server/data/index.js +101770 -0
  58. experimaestro/server/data/index.js.map +1 -0
  59. experimaestro/server/data/login.html +22 -0
  60. experimaestro/server/data/manifest.json +15 -0
  61. experimaestro/settings.py +2 -2
  62. experimaestro/sphinx/__init__.py +7 -17
  63. experimaestro/taskglobals.py +7 -2
  64. experimaestro/tests/core/__init__.py +0 -0
  65. experimaestro/tests/core/test_generics.py +206 -0
  66. experimaestro/tests/definitions_types.py +5 -3
  67. experimaestro/tests/launchers/bin/sbatch +34 -7
  68. experimaestro/tests/launchers/bin/srun +5 -0
  69. experimaestro/tests/launchers/common.py +16 -4
  70. experimaestro/tests/restart.py +9 -4
  71. experimaestro/tests/tasks/all.py +23 -10
  72. experimaestro/tests/tasks/foreign.py +2 -4
  73. experimaestro/tests/test_dependencies.py +0 -6
  74. experimaestro/tests/test_experiment.py +73 -0
  75. experimaestro/tests/test_findlauncher.py +11 -4
  76. experimaestro/tests/test_forward.py +5 -5
  77. experimaestro/tests/test_generators.py +93 -0
  78. experimaestro/tests/test_identifier.py +114 -99
  79. experimaestro/tests/test_instance.py +6 -21
  80. experimaestro/tests/test_objects.py +20 -4
  81. experimaestro/tests/test_param.py +60 -22
  82. experimaestro/tests/test_serializers.py +24 -64
  83. experimaestro/tests/test_tags.py +5 -11
  84. experimaestro/tests/test_tasks.py +10 -23
  85. experimaestro/tests/test_tokens.py +3 -2
  86. experimaestro/tests/test_types.py +20 -17
  87. experimaestro/tests/test_validation.py +48 -91
  88. experimaestro/tokens.py +16 -5
  89. experimaestro/typingutils.py +8 -8
  90. experimaestro/utils/asyncio.py +6 -2
  91. experimaestro/utils/multiprocessing.py +44 -0
  92. experimaestro/utils/resources.py +7 -3
  93. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info}/METADATA +27 -34
  94. experimaestro-1.15.2.dist-info/RECORD +159 -0
  95. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info}/WHEEL +1 -1
  96. experimaestro-1.6.1.dist-info/RECORD +0 -122
  97. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info}/entry_points.txt +0 -0
  98. {experimaestro-1.6.1.dist-info → experimaestro-1.15.2.dist-info/licenses}/LICENSE +0 -0
experimaestro/experiments/cli.py CHANGED
@@ -1,4 +1,4 @@
1
- import imp
1
+ import datetime
2
2
  import importlib
3
3
  import inspect
4
4
  import json
@@ -59,7 +59,8 @@ class ExperimentCallable(Protocol):
59
59
  class ConfigurationLoader:
60
60
  def __init__(self):
61
61
  self.yamls = []
62
- self.pythonpath = set()
62
+ self.python_path = set()
63
+ self.yaml_module_file: None | Path = None
63
64
 
64
65
  def load(self, yaml_file: Path):
65
66
  """Loads a YAML file, and parents one if they exist"""
@@ -68,6 +69,16 @@ class ConfigurationLoader:
68
69
 
69
70
  with yaml_file.open("rt") as fp:
70
71
  _data = yaml.full_load(fp)
72
+
73
+ if "file" in _data:
74
+ path = Path(_data["file"])
75
+ if not path.is_absolute():
76
+ _data["file"] = str((yaml_file.parent / path).resolve())
77
+
78
+ if "module" in _data:
79
+ # Keeps track of the YAML file where the module was defined
80
+ self.yaml_module_file = yaml_file
81
+
71
82
  if parent := _data.get("parent", None):
72
83
  self.load(yaml_file.parent / parent)
73
84
 
@@ -76,9 +87,9 @@ class ConfigurationLoader:
76
87
  for path in _data.get("pythonpath", []):
77
88
  path = Path(path)
78
89
  if path.is_absolute():
79
- self.pythonpath.add(path.resolve())
90
+ self.python_path.add(path.resolve())
80
91
  else:
81
- self.pythonpath.add((yaml_file.parent / path).resolve())
92
+ self.python_path.add((yaml_file.parent / path).resolve())
82
93
 
83
94
 
84
95
  @click.option("--debug", is_flag=True, help="Print debug information")
@@ -116,7 +127,10 @@ class ConfigurationLoader:
116
127
  help="Port for monitoring (can be defined in the settings.yaml file)",
117
128
  )
118
129
  @click.option(
119
- "--file", "xp_file", help="The file containing the main experimental code"
130
+ "--file",
131
+ "xp_file",
132
+ type=Path,
133
+ help="The file containing the main experimental code",
120
134
  )
121
135
  @click.option(
122
136
  "--module-name", "module_name", help="Module containing the experimental code"
@@ -181,7 +195,7 @@ def experiments_cli( # noqa: C901
181
195
  configuration.merge_with(OmegaConf.from_dotlist(extra_conf))
182
196
 
183
197
  # --- Get the XP file
184
- pythonpath = list(conf_loader.pythonpath)
198
+ python_path = list(conf_loader.python_path)
185
199
  if module_name is None:
186
200
  module_name = configuration.get("module", None)
187
201
 
@@ -192,9 +206,11 @@ def experiments_cli( # noqa: C901
192
206
  not module_name
193
207
  ), "Module name and experiment file are mutually exclusive options"
194
208
  xp_file = Path(xp_file)
195
- if not pythonpath:
196
- pythonpath.append(xp_file.parent)
197
- logging.info("Using python path: %s", ", ".join(str(s) for s in pythonpath))
209
+ if not python_path:
210
+ python_path.append(xp_file.parent.absolute())
211
+ logging.info(
212
+ "Using python path: %s", ", ".join(str(s) for s in python_path)
213
+ )
198
214
 
199
215
  assert (
200
216
  module_name or xp_file
@@ -209,40 +225,59 @@ def experiments_cli( # noqa: C901
209
225
  # --- Finds the "run" function
210
226
 
211
227
  # Modifies the Python path
212
- for path in pythonpath:
228
+ for path in python_path:
213
229
  sys.path.append(str(path))
214
230
 
231
+ # --- Adds automatically the experiment module if not found
232
+ if module_name and conf_loader.yaml_module_file:
233
+ try:
234
+ importlib.import_module(module_name)
235
+ except ModuleNotFoundError:
236
+ # Try to setup a path
237
+ path = conf_loader.yaml_module_file.resolve()
238
+ for _ in range(len(module_name.split("."))):
239
+ path = path.parent
240
+
241
+ logging.info("Appending %s to python path", path)
242
+ sys.path.append(str(path))
243
+ python_path.append(path)
244
+
215
245
  if xp_file:
216
246
  if not xp_file.exists() and xp_file.suffix != ".py":
217
247
  xp_file = xp_file.with_suffix(".py")
218
248
  xp_file: Path = Path(yaml_file).parent / xp_file
219
- with open(xp_file) as src:
220
- module_name = xp_file.with_suffix("").name
221
- mod = imp.load_module(
222
- module_name,
223
- src,
224
- str(xp_file.absolute()),
225
- (".py", "r", imp.PY_SOURCE),
226
- )
249
+ module_name = xp_file.with_suffix("").name
250
+ spec = importlib.util.spec_from_file_location(
251
+ module_name, str(xp_file.absolute())
252
+ )
253
+ mod = importlib.util.module_from_spec(spec)
254
+ spec.loader.exec_module(mod)
227
255
  else:
228
256
  # Module
229
- mod = importlib.import_module(module_name)
257
+ try:
258
+ mod = importlib.import_module(module_name)
259
+ except ModuleNotFoundError as e:
260
+ logging.error("Module not found: %s with python path %s", e, sys.path)
261
+ raise
230
262
 
231
263
  helper = getattr(mod, "run", None)
232
264
 
233
265
  # --- ... and runs it
234
266
  if helper is None:
235
- raise ValueError(f"Could not find run function in {xp_file}")
267
+ raise click.ClickException(
268
+ f"Could not find run function in {xp_file if xp_file else module_name}"
269
+ )
236
270
 
237
271
  if not isinstance(helper, ExperimentHelper):
238
272
  helper = ExperimentHelper(helper)
239
273
 
240
274
  parameters = inspect.signature(helper.callable).parameters
241
275
  list_parameters = list(parameters.values())
242
- assert len(list_parameters) == 2, (
243
- "Callable function should only "
244
- f"have two arguments (got {len(list_parameters)})"
245
- )
276
+ if len(list_parameters) != 2:
277
+ raise click.ClickException(
278
+ f"run in {xp_file if xp_file else module_name} function should only "
279
+ f"have two arguments (got {len(list_parameters)}), "
280
+ )
246
281
 
247
282
  schema = list_parameters[1].annotation
248
283
  omegaconf_schema = OmegaConf.structured(schema())
@@ -252,36 +287,51 @@ def experiments_cli( # noqa: C901
252
287
  configuration = OmegaConf.merge(omegaconf_schema, configuration)
253
288
  except omegaconf.errors.ConfigKeyError as e:
254
289
  cprint(f"Error in configuration:\n\n{e}", "red", file=sys.stderr)
255
- sys.exit(1)
290
+ raise click.ClickException("Error in configuration")
256
291
 
257
292
  if show:
258
293
  print(json.dumps(OmegaConf.to_container(configuration))) # noqa: T201
259
294
  sys.exit(0)
260
295
 
261
296
  # Move to an object container
262
- configuration = OmegaConf.to_container(
297
+ xp_configuration: ConfigurationBase = OmegaConf.to_container(
263
298
  configuration, structured_config_mode=SCMode.INSTANTIATE
264
299
  )
265
300
 
266
301
  # Define the workspace
267
302
  ws_env = find_workspace(workdir=workdir, workspace=workspace)
303
+
268
304
  workdir = ws_env.path
269
305
 
270
- logging.info("Using working directory %s", str(workdir.resolve()))
306
+ # --- Sets up the experiment ID
271
307
 
272
308
  # --- Runs the experiment
309
+ if xp_configuration.add_timestamp:
310
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M")
311
+ experiment_id = f"""{xp_configuration.id}-{timestamp}"""
312
+ else:
313
+ experiment_id = xp_configuration.id
314
+
315
+ logging.info(
316
+ "Running experiment %s working directory %s",
317
+ experiment_id,
318
+ str(workdir.resolve()),
319
+ )
273
320
  with experiment(
274
- ws_env, configuration.id, host=host, port=port, run_mode=run_mode
321
+ ws_env, experiment_id, host=host, port=port, run_mode=run_mode
275
322
  ) as xp:
276
323
  # Set up the environment
277
324
  # (1) global settings (2) workspace settings and (3) command line settings
278
325
  for key, value in env:
279
326
  xp.setenv(key, value)
280
327
 
328
+ # Sets the python path
329
+ xp.workspace.python_path.extend(python_path)
330
+
281
331
  try:
282
332
  # Run the experiment
283
333
  helper.xp = xp
284
- helper.run(list(args), configuration)
334
+ helper.run(list(args), xp_configuration)
285
335
 
286
336
  # ... and wait
287
337
  xp.wait()
experimaestro/experiments/configuration.py CHANGED
@@ -51,3 +51,6 @@ class ConfigurationBase:
51
51
 
52
52
  description: str = ""
53
53
  """Description of the experiment"""
54
+
55
+ add_timestamp: bool = False
56
+ """Adds a timestamp YYYY_MM_DD-HH_MM to the experiment ID"""
experimaestro/generators.py CHANGED
@@ -1,11 +1,12 @@
1
1
  import inspect
2
2
  from pathlib import Path
3
+ from abc import ABC, abstractmethod
3
4
  from typing import Callable, Union
4
5
  from experimaestro.core.arguments import ArgumentOptions, TypeAnnotation
5
6
  from experimaestro.core.objects import ConfigWalkContext, Config
6
7
 
7
8
 
8
- class Generator:
9
+ class Generator(ABC):
9
10
  """Base class for all generators"""
10
11
 
11
12
  def isoutput(self):
@@ -13,6 +14,10 @@ class Generator:
13
14
  path within the job folder)"""
14
15
  return False
15
16
 
17
+ @abstractmethod
18
+ def __call__(self, context: ConfigWalkContext, config: Config):
19
+ ...
20
+
16
21
 
17
22
  class PathGenerator(Generator):
18
23
  """Generates a path"""
experimaestro/ipc.py CHANGED
@@ -7,6 +7,7 @@ import sys
7
7
  import logging
8
8
  from .utils import logger
9
9
  from watchdog.observers import Observer
10
+ from watchdog.observers.api import ObservedWatch
10
11
  from watchdog.events import FileSystemEventHandler
11
12
 
12
13
 
@@ -20,7 +21,9 @@ class IPCom:
20
21
  self.observer.start()
21
22
  self.pid = os.getpid()
22
23
 
23
- def fswatch(self, watcher: FileSystemEventHandler, path: Path, recursive=False):
24
+ def fswatch(
25
+ self, watcher: FileSystemEventHandler, path: Path, recursive=False
26
+ ) -> ObservedWatch:
24
27
  if not self.observer.is_alive():
25
28
  logging.error("Observer is not alive")
26
29
 
experimaestro/launcherfinder/parser.py CHANGED
@@ -23,7 +23,7 @@ class SuppressStrMatch(StrMatch):
23
23
 
24
24
 
25
25
  def mem_spec():
26
- return "mem", "=", RegExMatch(r"\d+(G|M)?")
26
+ return "mem", "=", RegExMatch(r"\d+(GiB|MiB|G|M)?")
27
27
 
28
28
 
29
29
  def cores_spec():
@@ -51,7 +51,12 @@ def cpu():
51
51
 
52
52
 
53
53
  def duration():
54
- return "duration", "=", RegExMatch(r"\d+"), RegExMatch(r"h(ours)?|d(ays)?")
54
+ return (
55
+ "duration",
56
+ "=",
57
+ RegExMatch(r"\d+"),
58
+ RegExMatch(r"h(ours?)?|d(ays?)?|m(ins?)?"),
59
+ )
55
60
 
56
61
 
57
62
  def one_spec():
@@ -67,7 +72,7 @@ def grammar():
67
72
 
68
73
  class Visitor(PTNodeVisitor):
69
74
  def visit_grammar(self, node, children):
70
- return [child for child in children]
75
+ return specs.RequirementUnion(*[child for child in children])
71
76
 
72
77
  def visit_one_spec(self, node, children):
73
78
  return reduce(lambda x, el: x & el, children)
experimaestro/launcherfinder/registry.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Configuration registers
2
2
 
3
+ from contextlib import contextmanager
3
4
  from typing import ClassVar, Dict, Optional, Set, Type, Union
4
5
 
5
6
  from pathlib import Path
@@ -7,9 +8,8 @@ import typing
7
8
  from omegaconf import DictConfig, OmegaConf, SCMode
8
9
  import pkg_resources
9
10
  from experimaestro.utils import logger
10
-
11
11
  from .base import ConnectorConfiguration, TokenConfiguration
12
- from .specs import HostRequirement
12
+ from .specs import HostRequirement, RequirementUnion
13
13
 
14
14
  if typing.TYPE_CHECKING:
15
15
  from experimaestro.launchers import Launcher
@@ -36,6 +36,16 @@ def load_yaml(schema, path: Path):
36
36
  )
37
37
 
38
38
 
39
+ @contextmanager
40
+ def ensure_enter(fp):
41
+ """Behaves as a resource, whether it is one or not"""
42
+ if hasattr(fp, "__enter__"):
43
+ with fp as _fp:
44
+ yield _fp
45
+ else:
46
+ yield fp
47
+
48
+
39
49
  class LauncherRegistry:
40
50
  INSTANCES: ClassVar[Dict[Path, "LauncherRegistry"]] = {}
41
51
  CURRENT_CONFIG_DIR: ClassVar[Optional[Path]] = None
@@ -73,18 +83,22 @@ class LauncherRegistry:
73
83
 
74
84
  # Register the find launcher function if it exists
75
85
  launchers_py = basepath / "launchers.py"
86
+ print(f"basepath {launchers_py}")
76
87
  if launchers_py.is_file():
77
88
  logger.info("Loading %s", launchers_py)
78
89
 
79
90
  from importlib import util
80
91
 
81
- spec = util.spec_from_file_location("xpm_launchers_conf", launchers_py)
82
- module = util.module_from_spec(spec)
83
- spec.loader.exec_module(module)
92
+ with ensure_enter(launchers_py.__fspath__()) as fp:
93
+ spec = util.spec_from_file_location("xpm_launchers_conf", fp)
94
+ module = util.module_from_spec(spec)
95
+ spec.loader.exec_module(module)
84
96
 
85
97
  self.find_launcher_fn = getattr(module, "find_launcher", None)
86
98
  if self.find_launcher_fn is None:
87
- logger.warn("No find_launcher() function was found in %s", launchers_py)
99
+ logger.warning(
100
+ "No find_launcher() function was found in %s", launchers_py
101
+ )
88
102
 
89
103
  # Read the configuration file
90
104
  self.connectors = load_yaml(
@@ -136,17 +150,22 @@ class LauncherRegistry:
136
150
  # Parse specs
137
151
  from .parser import parse
138
152
 
139
- specs = []
153
+ specs = RequirementUnion()
140
154
  for spec in input_specs:
141
155
  if isinstance(spec, str):
142
- specs.extend(parse(spec))
156
+ specs.add(parse(spec))
143
157
  else:
144
- specs.append(spec)
158
+ specs.add(spec)
145
159
 
146
160
  # Use launcher function
161
+ from experimaestro.launchers import Launcher
162
+
147
163
  if self.find_launcher_fn is not None:
148
- for spec in specs:
164
+ for spec in specs.requirements:
149
165
  if launcher := self.find_launcher_fn(spec, tags):
166
+ assert isinstance(
167
+ launcher, Launcher
168
+ ), "f{self.find_launcher_fn} did not return a Launcher but {type(launcher)}"
150
169
  return launcher
151
170
 
152
171
  return None
experimaestro/launcherfinder/specs.py CHANGED
@@ -1,4 +1,6 @@
1
+ from abc import ABC, abstractmethod
1
2
  import logging
3
+ import math
2
4
  from attr import Factory
3
5
  from attrs import define
4
6
  from copy import copy, deepcopy
@@ -20,9 +22,6 @@ class CudaSpecification:
20
22
  min_memory: int = 0
21
23
  """Minimum request memory (in bytes)"""
22
24
 
23
- def __lt__(self, other: "CudaSpecification"):
24
- return self.memory < other.memory
25
-
26
25
  def match(self, spec: "CudaSpecification"):
27
26
  """Returns True if the specification matches this host"""
28
27
  return (self.memory >= spec.memory) and (self.min_memory <= spec.memory)
@@ -30,7 +29,7 @@ class CudaSpecification:
30
29
  def __repr__(self):
31
30
  return (
32
31
  f"CUDA({self.model} "
33
- f"{format_size(self.memory)}/{format_size(self.min_memory)})"
32
+ f"max={format_size(self.memory, binary=True)}/min={format_size(self.min_memory, binary=True)})"
34
33
  )
35
34
 
36
35
 
@@ -48,8 +47,15 @@ class CPUSpecification:
48
47
  cpu_per_gpu: int = 0
49
48
  """Number of CPU per GPU (0 if not defined)"""
50
49
 
51
- def __lt__(self, other: "CPUSpecification"):
52
- return self.memory < other.memory and self.cores < other.cores
50
+ def __repr__(self):
51
+ return (
52
+ f"CPU("
53
+ f"mem={format_size(self.memory, binary=True)}, cores={self.cores}"
54
+ ")"
55
+ )
56
+
57
+ def match(self, other: "CPUSpecification"):
58
+ return (self.memory >= other.memory) and (self.cores >= other.cores)
53
59
 
54
60
  def total_memory(self, gpus: int = 0):
55
61
  return max(
@@ -91,10 +97,11 @@ class MatchRequirement:
91
97
  requirement: "HostSimpleRequirement"
92
98
 
93
99
 
94
- class HostRequirement:
100
+ class HostRequirement(ABC):
95
101
  """A requirement must be a disjunction of host requirements"""
96
102
 
97
103
  requirements: List["HostSimpleRequirement"]
104
+ """List of requirements (by order of priority)"""
98
105
 
99
106
  def __init__(self) -> None:
100
107
  self.requirements = []
@@ -105,6 +112,12 @@ class HostRequirement:
105
112
  def match(self, host: HostSpecification) -> Optional[MatchRequirement]:
106
113
  raise NotImplementedError()
107
114
 
115
+ @abstractmethod
116
+ def multiply_duration(self, coefficient: float) -> "HostRequirement":
117
+ """Returns a new HostRequirement with a duration multiplied by the
118
+ provided coefficient"""
119
+ ...
120
+
108
121
 
109
122
  class RequirementUnion(HostRequirement):
110
123
  """Ordered list of simple host requirements -- the first one is the priority"""
@@ -114,6 +127,16 @@ class RequirementUnion(HostRequirement):
114
127
  def __init__(self, *requirements: "HostSimpleRequirement"):
115
128
  self.requirements = list(requirements)
116
129
 
130
+ def add(self, requirement: "HostRequirement"):
131
+ match requirement:
132
+ case HostSimpleRequirement():
133
+ self.requirements.extend(*requirement.requirements)
134
+ case RequirementUnion():
135
+ self.requirements.append(requirement)
136
+ case _:
137
+ raise RuntimeError("Cannot handle type %s", type(requirement))
138
+ return self
139
+
117
140
  def match(self, host: HostSpecification) -> Optional[MatchRequirement]:
118
141
  """Returns the matched requirement (if any)"""
119
142
 
@@ -128,12 +151,23 @@ class RequirementUnion(HostRequirement):
128
151
 
129
152
  return argmax
130
153
 
154
+ def multiply_duration(self, coefficient: float) -> "RequirementUnion":
155
+ return RequirementUnion(
156
+ *[r.multiply_duration(coefficient) for r in self.requirements]
157
+ )
158
+
159
+ def __repr__(self):
160
+ return " | ".join(repr(r) for r in self.requirements)
161
+
131
162
 
132
163
  class HostSimpleRequirement(HostRequirement):
133
164
  """Simple host requirement"""
134
165
 
135
166
  cuda_gpus: List["CudaSpecification"]
167
+ """Specification for CUDA gpus"""
168
+
136
169
  cpu: "CPUSpecification"
170
+ """Specification for CPU"""
137
171
 
138
172
  duration: int
139
173
  """Requested duration (in seconds)"""
@@ -141,6 +175,11 @@ class HostSimpleRequirement(HostRequirement):
141
175
  def __repr__(self):
142
176
  return f"Req(cpu={self.cpu}, cuda={self.cuda_gpus}, duration={self.duration})"
143
177
 
178
+ def multiply_duration(self, coefficient: float) -> "HostSimpleRequirement":
179
+ r = HostSimpleRequirement(self)
180
+ r.duration = math.ceil(self.duration * coefficient)
181
+ return r
182
+
144
183
  def __init__(self, *reqs: "HostSimpleRequirement"):
145
184
  self.cuda_gpus = []
146
185
  self.cpu = CPUSpecification(0, 0)
@@ -158,7 +197,7 @@ class HostSimpleRequirement(HostRequirement):
158
197
  self.cpu.cores = max(req.cpu.cores, self.cpu.cores)
159
198
  self.duration = max(req.duration, self.duration)
160
199
  self.cuda_gpus.extend(req.cuda_gpus)
161
- self.cuda_gpus.sort()
200
+ self.cuda_gpus.sort(key=lambda cuda: -cuda.memory)
162
201
 
163
202
  def match(self, host: HostSpecification) -> Optional[MatchRequirement]:
164
203
  if self.cuda_gpus:
@@ -182,7 +221,7 @@ class HostSimpleRequirement(HostRequirement):
182
221
  )
183
222
  return None
184
223
 
185
- if host.cpu < self.cpu:
224
+ if not host.cpu.match(self.cpu):
186
225
  return None
187
226
 
188
227
  if host.max_duration > 0 and self.duration > host.max_duration:
@@ -197,7 +236,7 @@ class HostSimpleRequirement(HostRequirement):
197
236
  _self = deepcopy(self)
198
237
  for _ in range(count - 1):
199
238
  _self.cuda_gpus.extend(self.cuda_gpus)
200
- _self.cuda_gpus.sort()
239
+ _self.cuda_gpus.sort(key=lambda cuda: -cuda.memory)
201
240
 
202
241
  return _self
203
242
 
experimaestro/launchers/slurm/base.py CHANGED
@@ -11,6 +11,7 @@ from typing import (
11
11
  )
12
12
  from experimaestro.connectors.local import LocalConnector
13
13
  import re
14
+ from shlex import quote as shquote
14
15
  from contextlib import contextmanager
15
16
  from dataclasses import dataclass
16
17
  from experimaestro.launcherfinder.registry import (
@@ -86,7 +87,7 @@ class SlurmProcessWatcher(threading.Thread):
86
87
  WATCHERS: Dict[Tuple[Tuple[str, Any]], "SlurmProcessWatcher"] = {}
87
88
 
88
89
  def __init__(self, launcher: "SlurmLauncher"):
89
- super().__init__()
90
+ super().__init__(daemon=True)
90
91
  self.launcher = launcher
91
92
  self.count = 1
92
93
  self.jobs: Dict[str, SlurmJobState] = {}
@@ -183,10 +184,10 @@ class BatchSlurmProcess(Process):
183
184
  if state and state.finished():
184
185
  return 0 if state.slurm_state == "COMPLETED" else 1
185
186
 
186
- async def aio_state(self):
187
+ async def aio_state(self, timeout: float | None = None) -> ProcessState:
187
188
  def check():
188
189
  with SlurmProcessWatcher.get(self.launcher) as watcher:
189
- jobinfo = watcher.getjob(self.jobid)
190
+ jobinfo = watcher.getjob(self.jobid, timeout=timeout)
190
191
  return jobinfo.state if jobinfo else ProcessState.SCHEDULED
191
192
 
192
193
  return await asyncThreadcheck("slurm.aio_isrunning", check)
@@ -211,7 +212,7 @@ class BatchSlurmProcess(Process):
211
212
 
212
213
  # Checks that the process is running
213
214
  with SlurmProcessWatcher.get(launcher) as watcher:
214
- logger.info("Checking SLURM job %s", process.jobid)
215
+ logger.debug("Checking SLURM job %s", process.jobid)
215
216
  jobinfo = watcher.getjob(process.jobid, timeout=0.1)
216
217
  if jobinfo and jobinfo.state.running:
217
218
  logger.debug(
@@ -235,15 +236,15 @@ class SlurmProcessBuilder(ProcessBuilder):
235
236
  super().__init__()
236
237
  self.launcher = launcher
237
238
 
238
- def start(self) -> BatchSlurmProcess:
239
+ def start(self, task_mode: bool = False) -> BatchSlurmProcess:
239
240
  """Start the process"""
240
241
  builder = self.launcher.connector.processbuilder()
241
- builder.workingDirectory = self.workingDirectory
242
242
  builder.environ = self.launcher.launcherenv
243
243
  builder.detach = False
244
244
 
245
245
  if not self.detach:
246
246
  # Simplest case: we wait for the output
247
+ builder.workingDirectory = self.workingDirectory
247
248
  builder.command = [f"{self.launcher.binpath}/srun"]
248
249
  builder.command.extend(self.launcher.options.args())
249
250
  builder.command.extend(self.command)
@@ -255,11 +256,17 @@ class SlurmProcessBuilder(ProcessBuilder):
255
256
  return builder.start()
256
257
 
257
258
  builder.command = [f"{self.launcher.binpath}/sbatch", "--parsable"]
258
- builder.command.extend(self.launcher.options.args())
259
259
 
260
- addstream(builder.command, "-e", self.stderr)
261
- addstream(builder.command, "-o", self.stdout)
262
- addstream(builder.command, "-i", self.stdin)
260
+ if not task_mode:
261
+ # Use command line parameters when not running a task
262
+ builder.command.extend(self.launcher.options.args())
263
+
264
+ if self.workingDirectory:
265
+ workdir = self.launcher.connector.resolve(self.workingDirectory)
266
+ builder.command.append(f"--chdir={workdir}")
267
+ addstream(builder.command, "-e", self.stderr)
268
+ addstream(builder.command, "-o", self.stdout)
269
+ addstream(builder.command, "-i", self.stdin)
263
270
 
264
271
  builder.command.extend(self.command)
265
272
  logger.info(
@@ -427,12 +434,43 @@ class SlurmLauncher(Launcher):
427
434
 
428
435
  We assume *nix, but should be changed to PythonScriptBuilder when working
429
436
  """
430
- builder = PythonScriptBuilder()
431
- builder.processtype = "slurm"
432
- return builder
437
+ return SlurmScriptBuilder(self)
433
438
 
434
439
  def processbuilder(self) -> SlurmProcessBuilder:
435
440
  """Returns the process builder for this launcher
436
441
 
437
442
  By default, returns the associated connector builder"""
438
443
  return SlurmProcessBuilder(self)
444
+
445
+
446
+ class SlurmScriptBuilder(PythonScriptBuilder):
447
+ def __init__(self, launcher: SlurmLauncher, pythonpath=None):
448
+ super().__init__(pythonpath)
449
+ self.launcher = launcher
450
+ self.processtype = "slurm"
451
+
452
+ def write(self, job):
453
+ py_path = super().write(job)
454
+ main_path = py_path.parent
455
+
456
+ def relpath(path: Path):
457
+ return shquote(self.launcher.connector.resolve(path, main_path))
458
+
459
+ # Writes the sbatch shell script containing all the options
460
+ sh_path = job.jobpath / ("%s.sh" % job.name)
461
+ with sh_path.open("wt") as out:
462
+ out.write("""#!/bin/sh\n\n""")
463
+
464
+ workdir = self.launcher.connector.resolve(main_path)
465
+ out.write(f"#SBATCH --chdir={shquote(workdir)}\n")
466
+ out.write(f"""#SBATCH --error={relpath(job.stderr)}\n""")
467
+ out.write(f"""#SBATCH --output={relpath(job.stdout)}\n""")
468
+
469
+ for arg in self.launcher.options.args():
470
+ out.write(f"""#SBATCH {arg}\n""")
471
+
472
+ # We finish by the call to srun
473
+ out.write(f"""\nsrun ./{relpath(py_path)}\n\n""")
474
+
475
+ self.launcher.connector.setExecutable(sh_path, True)
476
+ return sh_path
experimaestro/mkdocs/__init__.py CHANGED
@@ -1 +1 @@
1
- # from .base import Documentation
1
+ from .base import Documentation # noqa: F401
experimaestro/notifications.py CHANGED
@@ -78,7 +78,6 @@ class Reporter(threading.Thread):
78
78
 
79
79
  self.progress_threshold = 0.01
80
80
  self.cv = threading.Condition()
81
- self.start()
82
81
 
83
82
  def stop(self):
84
83
  self.stopping = True
@@ -110,6 +109,7 @@ class Reporter(threading.Thread):
110
109
  return any(level.modified(self) for level in self.levels)
111
110
 
112
111
  def check_urls(self):
112
+ """Check whether we have new schedulers to notify"""
113
113
  mtime = os.path.getmtime(self.path)
114
114
  if mtime > self.lastcheck:
115
115
  for f in self.path.iterdir():
@@ -222,6 +222,7 @@ class Reporter(threading.Thread):
222
222
  taskpath = TaskEnv.instance().taskpath
223
223
  assert taskpath is not None, "Task path is not defined"
224
224
  Reporter.INSTANCE = Reporter(taskpath)
225
+ Reporter.INSTANCE.start()
225
226
  return Reporter.INSTANCE
226
227
 
227
228