nntool 2.0.0rc0__py3-none-any.whl

@@ -0,0 +1,546 @@
+ import os
+ import sys
+ import submitit
+
+ from copy import deepcopy
+ from functools import partial
+ from submitit import Job
+ from typing import Any, Callable, Literal, Tuple, Union, Dict, List, Optional
+
+ from ..config import SlurmConfig
+ from ..task import (
+     PyTorchDistributedTask,
+     pack_code_files,
+     include_code_files,
+     exclude_code_folders,
+ )
+ from ._slurm_context import SubmititDistributedCommandContext
+
+
+ class SlurmFunction:
+     def __init__(
+         self,
+         submit_fn: Callable[..., Any],
+         default_submit_fn_args: Optional[Tuple[Any, ...]] = None,
+         default_submit_fn_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         """A slurm function for a slurm job, which can be used for a distributed or
+         non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).
+
+         :param submit_fn: function to be submitted to Slurm
+         :param default_submit_fn_args: default positional args for submit_fn, defaults to ()
+         :param default_submit_fn_kwargs: default keyword args for submit_fn, defaults to {}
+         """
+         self.submit_fn: Callable[..., Any] = submit_fn
+         self.default_submit_fn_args: Tuple[Any, ...] = (
+             tuple() if default_submit_fn_args is None else default_submit_fn_args
+         )
+         self.default_submit_fn_kwargs: Dict[str, Any] = (
+             dict() if default_submit_fn_kwargs is None else default_submit_fn_kwargs
+         )
+         self.__doc__ = self.submit_fn.__doc__
+
+         # the slurm function is configured after calling `configure`
+         self.__configured: bool = False
+         self.__executor: Optional[submitit.AutoExecutor] = None  # set up lazily by `get_executor`
+
+         # annotations only; these attributes are set by `configure` after instantiation
+         self.slurm_config: SlurmConfig
+         self.slurm_params_kwargs: Dict[str, str]
+         self.slurm_submit_kwargs: Dict[str, str]
+         self.slurm_task_kwargs: Dict[str, str]
+         self.system_argv: Optional[List[str]]
+         self.pack_code_include_fn: Callable[[str, str], bool]
+         self.pack_code_exclude_fn: Callable[[str, str], bool]
+
+     def is_configured(self) -> bool:
+         """Whether the slurm function has been configured.
+
+         :return: True if the slurm function has been configured, False otherwise
+         """
+         return self.submit_fn is not None and self.__configured
+
+     def is_distributed(self) -> bool:
+         """Whether the slurm function runs in distributed mode.
+
+         :return: True if the slurm function is distributed, False otherwise
+         """
+         return self.slurm_config.use_distributed_env
+
+     def prepare_executor(self) -> submitit.AutoExecutor:
+         slurm_config = self.slurm_config
+         slurm_parameters_kwargs = self.slurm_params_kwargs
+         slurm_submission_kwargs = self.slurm_submit_kwargs
+
+         # Select the cluster type, based on the submitit library.
+         # Here we add a special mode called `run` for running the job on the local machine,
+         # which is equivalent to the `debug` mode in the submitit library.
+         cluster_dispatch = {
+             "slurm": None,
+             "debug": "debug",
+             "run": "debug",
+             "local": "local",
+         }
+         executor = submitit.AutoExecutor(
+             folder=slurm_config.output_path,
+             cluster=cluster_dispatch.get(slurm_config.mode, slurm_config.mode),
+         )
+
+         if slurm_config.mode in ("slurm", "debug"):
+             # Set additional slurm parameters
+             slurm_additional_parameters = {}
+             if slurm_config.node_list:
+                 slurm_additional_parameters["nodelist"] = slurm_config.node_list
+             if slurm_config.node_list_exclude:
+                 slurm_additional_parameters["exclude"] = slurm_config.node_list_exclude
+             if slurm_config.mem:
+                 slurm_additional_parameters["mem"] = slurm_config.mem
+
+             # Update the slurm additional parameters with the slurm parameters kwargs
+             slurm_additional_parameters.update(slurm_parameters_kwargs)
+
+             executor.update_parameters(
+                 name=slurm_config.job_name,
+                 slurm_partition=slurm_config.partition,
+                 nodes=slurm_config.num_of_node,
+                 tasks_per_node=slurm_config.tasks_per_node,
+                 cpus_per_task=slurm_config.cpus_per_task,
+                 gpus_per_node=(
+                     slurm_config.gpus_per_task * slurm_config.tasks_per_node
+                     if slurm_config.gpus_per_node is None
+                     else slurm_config.gpus_per_node
+                 ),  # gpus cannot be assigned at the task level
+                 timeout_min=slurm_config.timeout_min,
+                 # refer to https://samuelstevens.me/writing/submitit#multi-gpu-training-in-torch
+                 stderr_to_stdout=slurm_config.stderr_to_stdout,
+                 local_setup=slurm_config.setup,
+                 slurm_additional_parameters=slurm_additional_parameters,
+                 **slurm_submission_kwargs,
+             )
+         elif slurm_config.mode in ("local",):
+             # If CUDA_VISIBLE_DEVICES is set by users, we need to pass it to the local job.
+             # Refer to:
+             # 1. https://github.com/facebookincubator/submitit/blob/64119dc669a21d69f46c9d9a3f556ce447d238d3/submitit/local/local.py#L203
+             # 2. https://github.com/facebookincubator/submitit/blob/64119dc669a21d69f46c9d9a3f556ce447d238d3/submitit/local/local.py#L241
+             # if "CUDA_VISIBLE_DEVICES" in os.environ:
+             #     visible_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+             #     visible_gpus = tuple(int(gpu) for gpu in visible_gpus if gpu.isdigit())
+             # else:
+             #     visible_gpus = ()
+
+             executor.update_parameters(
+                 name=slurm_config.job_name,
+                 nodes=slurm_config.num_of_node,
+                 tasks_per_node=slurm_config.tasks_per_node,
+                 cpus_per_task=slurm_config.cpus_per_task,
+                 gpus_per_node=(
+                     slurm_config.gpus_per_task * slurm_config.tasks_per_node
+                     if slurm_config.gpus_per_node is None
+                     else slurm_config.gpus_per_node
+                 ),  # gpus cannot be assigned at the task level
+                 timeout_min=slurm_config.timeout_min,
+                 # refer to https://samuelstevens.me/writing/submitit#multi-gpu-training-in-torch
+                 stderr_to_stdout=slurm_config.stderr_to_stdout,
+                 local_setup=slurm_config.setup,
+             )
+         else:
+             raise ValueError(
+                 f"Unsupported slurm mode: {slurm_config.mode}. Supported modes are: {list(cluster_dispatch.keys())}."
+             )
+         return executor
+
+     def get_executor(
+         self,
+     ) -> submitit.AutoExecutor:
+         if self.__executor is not None:
+             return self.__executor
+         else:
+             executor = self.prepare_executor()
+             self.__executor = executor
+             return executor
+
+     @staticmethod
+     def slurm_has_been_set_up() -> bool:
+         """Check whether slurm has been set up by testing whether `NNTOOL_SLURM_HAS_BEEN_SET_UP`
+         exists in the environment variables; this special environment variable indicates that
+         slurm has been set up.
+
+         :return: True if slurm has been set up, False otherwise
+         """
+         # check whether slurm has been set up
+         has_been_set_up = False
+         if os.environ.get("NNTOOL_SLURM_HAS_BEEN_SET_UP") is not None:
+             has_been_set_up = True
+         return has_been_set_up
+
+     @staticmethod
+     def slurm_packed_code() -> Optional[str]:
+         """Check whether slurm has been set up with packed code by testing whether
+         `NNTOOL_SLURM_PACKED_CODE` exists in the environment variables; this special environment
+         variable indicates that slurm has been set up with packed code.
+
+         :return: the target code root if slurm has been set up with packed code, None otherwise
+         """
+         # check whether slurm has been set up with packed code
+         packed_code_root = None
+         if os.environ.get("NNTOOL_SLURM_PACKED_CODE", None) is not None:
+             packed_code_root = os.environ.get("NNTOOL_SLURM_PACKED_CODE")
+         return packed_code_root
+
+     def __mark_slurm_has_been_set_up(self):
+         os.environ["NNTOOL_SLURM_HAS_BEEN_SET_UP"] = "1"
+
+     def __mark_slurm_packed_code(self, target_code_root: str):
+         os.environ["NNTOOL_SLURM_PACKED_CODE"] = target_code_root
+
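+     # These markers let the second (in-job) launch be detected; a minimal sketch of the
+     # round trip (comment only, so the module is unchanged):
+     #
+     #     os.environ.pop("NNTOOL_SLURM_HAS_BEEN_SET_UP", None)
+     #     assert SlurmFunction.slurm_has_been_set_up() is False
+     #     os.environ["NNTOOL_SLURM_HAS_BEEN_SET_UP"] = "1"
+     #     assert SlurmFunction.slurm_has_been_set_up() is True
+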
+     def __update_slurm_kwargs(
+         self,
+         slurm_params_kwargs: Optional[Dict[str, str]] = None,
+         slurm_submit_kwargs: Optional[Dict[str, str]] = None,
+         slurm_task_kwargs: Optional[Dict[str, str]] = None,
+     ):
+         """Update the slurm kwargs of the slurm function. The slurm parameters, slurm submission
+         parameters, and slurm task parameters are each updated in place, with the given arguments
+         overriding existing entries.
+
+         :param slurm_params_kwargs: extra slurm parameter settings, defaults to None
+         :param slurm_submit_kwargs: extra submission settings, defaults to None
+         :param slurm_task_kwargs: extra task settings, defaults to None
+         """
+         if slurm_params_kwargs:
+             self.slurm_params_kwargs.update(slurm_params_kwargs)
+         if slurm_submit_kwargs:
+             self.slurm_submit_kwargs.update(slurm_submit_kwargs)
+         if slurm_task_kwargs:
+             self.slurm_task_kwargs.update(slurm_task_kwargs)
+
+     def configure(
+         self,
+         slurm_config: SlurmConfig,
+         slurm_params_kwargs: Optional[Dict[str, str]] = None,
+         slurm_submit_kwargs: Optional[Dict[str, str]] = None,
+         slurm_task_kwargs: Optional[Dict[str, str]] = None,
+         system_argv: Optional[List[str]] = None,
+         pack_code_include_fn: Optional[Callable[[str, str], bool]] = None,
+         pack_code_exclude_fn: Optional[Callable[[str, str], bool]] = None,
+     ) -> "SlurmFunction":
+         """Create a configured copy of the slurm function. A slurm function can be used for a
+         distributed or non-distributed job (controlled by `use_distributed_env` in the slurm
+         dataclass).
+
+         **Exported Distributed Environment Variables**
+
+         - `NNTOOL_SLURM_HAS_BEEN_SET_UP` is a special environment variable to indicate that slurm has been set up.
+         - After the set-up, the distributed job will be launched and the following variables are exported:
+             - `num_processes`: int
+             - `num_machines`: int
+             - `machine_rank`: int
+             - `main_process_ip`: str
+             - `main_process_port`: int
+
+         :param slurm_config: the slurm configuration dataclass
+         :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to None
+         :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to None
+         :param slurm_task_kwargs: extra arguments for setting up the distributed task, defaults to None
+         :param system_argv: the system arguments for the second launch of the distributed task (by default the current system arguments `sys.argv[1:]` are used), defaults to None
+         :param pack_code_include_fn: predicate overriding which code files are packed, defaults to None
+         :param pack_code_exclude_fn: predicate overriding which folders are excluded from packing, defaults to None
+         :return: the wrapped submit function with configured slurm parameters
+         """
+         slurm_fn = SlurmFunction(
+             submit_fn=self.submit_fn,
+             default_submit_fn_args=self.default_submit_fn_args,
+             default_submit_fn_kwargs=self.default_submit_fn_kwargs,
+         )
+
+         slurm_fn.slurm_config = slurm_config
+         slurm_fn.slurm_params_kwargs = (
+             {} if slurm_params_kwargs is None else deepcopy(slurm_params_kwargs)
+         )
+         slurm_fn.slurm_submit_kwargs = (
+             {} if slurm_submit_kwargs is None else deepcopy(slurm_submit_kwargs)
+         )
+         slurm_fn.slurm_task_kwargs = (
+             {} if slurm_task_kwargs is None else deepcopy(slurm_task_kwargs)
+         )
+         slurm_fn.system_argv = system_argv
+
+         slurm_fn.__update_slurm_kwargs(
+             slurm_fn.slurm_config.extra_params_kwargs,  # make sure the same parameters are controlled by the config
+             slurm_fn.slurm_config.extra_submit_kwargs,
+             slurm_fn.slurm_config.extra_task_kwargs,
+         )
+
+         slurm_fn.pack_code_include_fn = partial(
+             include_code_files,
+             code_ext=slurm_fn.slurm_config.code_file_suffixes,
+         )
+         slurm_fn.pack_code_exclude_fn = partial(
+             exclude_code_folders,
+             code_folders=slurm_fn.slurm_config.exclude_code_folders,
+         )
+
+         if pack_code_include_fn is not None:
+             slurm_fn.pack_code_include_fn = pack_code_include_fn
+
+         if pack_code_exclude_fn is not None:
+             slurm_fn.pack_code_exclude_fn = pack_code_exclude_fn
+
+         # mark as configured
+         slurm_fn.__configured = True
+         return slurm_fn
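+
+     # A minimal usage sketch (kept as a comment so the module is unchanged; the `train`
+     # function and the `SlurmConfig` field values are hypothetical):
+     #
+     #     def train(step: int) -> None:
+     #         ...
+     #
+     #     slurm_train = SlurmFunction(train)
+     #     configured = slurm_train.configure(
+     #         SlurmConfig(job_name="train", partition="gpu", output_path="outputs"),
+     #     )
+     #     job = configured(step=0)  # submits via submitit in slurm mode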
+
+     def __getitem__(self, slurm_config: Union[Dict[str, Any], Tuple[Any], Any]) -> "SlurmFunction":
+         """Configure the slurm function via subscription. A slurm function can be used for a
+         distributed or non-distributed job (controlled by `use_distributed_env` in the slurm
+         dataclass).
+
+         #### Exported Distributed Environment Variables
+         1. `NNTOOL_SLURM_HAS_BEEN_SET_UP` is a special environment variable to indicate that slurm has been set up.
+         2. After the set-up, the distributed job will be launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.
+
+         :param slurm_config: the slurm configuration dataclass, or a dict/tuple of arguments for `configure`
+         :return: the wrapped submit function with configured slurm parameters
+         """
+         if isinstance(slurm_config, dict):
+             return self.configure(**slurm_config)
+         elif isinstance(slurm_config, (list, tuple)):
+             return self.configure(*slurm_config)
+         else:
+             # try to pass slurm_config as the first positional argument
+             return self.configure(slurm_config)
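+
+     # Subscription is shorthand for `configure`; a sketch under the same assumptions as the
+     # example after `configure` above:
+     #
+     #     configured = slurm_train[config]                      # single positional argument
+     #     configured = slurm_train[(config, {"qos": "high"})]   # tuple -> positional arguments
+     #     configured = slurm_train[{"slurm_config": config}]    # dict -> keyword arguments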
+
+     def __before_submission(self, *args, **kwargs):
+         """The hook function run before submitting the job. It packs the code and scripts into the
+         slurm output folder if `pack_code` is set to True in the slurm configuration. This only
+         takes effect before the first submission.
+
+         :raises Exception: if the slurm function is not configured
+         """
+         if self.slurm_packed_code() is not None:
+             # set the sbatch command to change directory
+             self.slurm_params_kwargs.update({"chdir": self.slurm_packed_code()})
+
+         if self.slurm_has_been_set_up():
+             return
+
+         if not self.is_configured():
+             raise Exception("A `SlurmFunction` should be configured before calling it.")
+
+         # pack the code and scripts into the slurm output folder
+         if self.slurm_config.pack_code:
+             target_code_root = pack_code_files(
+                 self.slurm_config.code_root,
+                 self.slurm_config.output_path,
+                 include_fn=self.pack_code_include_fn,
+                 exclude_fn=self.pack_code_exclude_fn,
+             )
+             # set the sbatch command to change directory for the first launch
+             self.slurm_params_kwargs.update({"chdir": target_code_root})
+             self.__mark_slurm_packed_code(str(target_code_root))
+
+     def __after_submission(
+         self,
+         submit_results: Union[Job, List[Job], Any] = None,
+         *args,
+         **kwargs,
+     ):
+         # wait for the results so that non-slurm modes block until completion
+         if isinstance(submit_results, Job):
+             if self.slurm_config.mode != "slurm":
+                 submit_results.results()
+         elif (
+             isinstance(submit_results, list)
+             and submit_results
+             and isinstance(submit_results[0], Job)
+         ):
+             if self.slurm_config.mode != "slurm":
+                 for job in submit_results:
+                     job.results()
+         else:
+             pass
+
+     @property
+     def _is_valid_mode_for_executor(self) -> bool:
+         return self.slurm_config.mode in ("slurm", "debug", "local")
+
+     @property
+     def _should_be_submitted_to_executor(self) -> bool:
+         """Check whether the slurm function should be submitted to the executor.
+
+         :return: True if the slurm function should be submitted to the executor, False otherwise
+         """
+         # If the slurm function is distributed, it should be submitted to the executor only if
+         # slurm has not been set up yet: the first launch sets up the distributed environment,
+         # and the second launch runs the submit function in that environment directly.
+         if self.is_distributed():
+             return self._is_valid_mode_for_executor and not self.slurm_has_been_set_up()
+         else:
+             return self._is_valid_mode_for_executor
+
+     def __call__(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
+         """Run the submit_fn with the given arguments and keyword arguments. The call is
+         non-blocking in `slurm` mode, while the other modes block. If no arguments or keyword
+         arguments are given, the default arguments and keyword arguments are used.
+
+         :raises Exception: if the submit_fn is not set up
+         :return: the slurm Job or the return value of the submit_fn
+         """
+         if self._should_be_submitted_to_executor:
+             self.__before_submission()
+             submit_strategy = self.__dispatch_submit_strategy("submit")
+             submit_results = submit_strategy(*submit_fn_args, **submit_fn_kwargs)
+             self.__after_submission(submit_results)
+             return submit_results
+         else:
+             return self.submit_fn(*submit_fn_args, **submit_fn_kwargs)
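+
+     # Calling is non-blocking in `slurm` mode and blocking otherwise; a hedged sketch
+     # continuing the hypothetical example above:
+     #
+     #     job = configured(step=0)   # returns a submitit Job immediately in slurm mode
+     #     print(job.job_id)          # slurm job id
+     #     result = job.result()      # blocks until the job finishes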
+
+     def submit(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
+         """An alias of `__call__`.
+
+         :raises Exception: if the submit_fn is not set up
+         :return: the slurm Job or the return value of the submit_fn
+         """
+         return self(*submit_fn_args, **submit_fn_kwargs)
+
+     def map_array(
+         self, *submit_fn_args, **submit_fn_kwargs
+     ) -> Union[Job[Any], List[Job[Any]], Any]:
+         """Run the submit_fn over arrays of arguments as a slurm job array. The call is
+         non-blocking in `slurm` mode. If no arguments or keyword arguments are given, the default
+         arguments and keyword arguments are used.
+
+         :raises Exception: if called for a distributed function or outside slurm mode
+         :return: the slurm Jobs or the return value of the submit_fn
+         """
+         if (
+             self._should_be_submitted_to_executor
+             and not self.is_distributed()
+             and self.slurm_config.mode == "slurm"
+         ):
+             self.__before_submission()
+             submit_strategy = self.__dispatch_submit_strategy("map_array")
+             submit_results = submit_strategy(*submit_fn_args, **submit_fn_kwargs)
+             self.__after_submission(submit_results)
+             return submit_results
+         else:
+             raise Exception(
+                 "The `map_array` method is only supported for non-distributed functions in slurm mode."
+             )
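+
+     # `map_array` mirrors `submitit.AutoExecutor.map_array`: positional iterables are mapped
+     # into one array task per element. A sketch with hypothetical arguments:
+     #
+     #     jobs = configured.map_array([0, 1, 2])     # one slurm array task per step value
+     #     results = [job.result() for job in jobs]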
+
+     def __dispatch_submit_strategy(
+         self,
+         submit_mode: Literal["submit", "map_array"] = "submit",
+         *submit_fn_args,
+         **submit_fn_kwargs,
+     ) -> Callable[..., Union[Job, List[Job], Any]]:
+         if submit_mode == "submit":
+             if self.is_distributed():
+                 return self.__distributed_submit
+             else:
+                 return self.__submit
+         elif submit_mode == "map_array":
+             if self.is_distributed():
+                 raise Exception("Distributed jobs do not support the `map_array` mode.")
+             else:
+                 return self.__submit_map_array
+         else:
+             raise Exception(f"Invalid submit mode: {submit_mode}")
+
+     def on_condition(
+         self,
+         jobs: Union[Job, List[Job], Tuple[Job]],
+         condition: Literal["afterany", "afterok", "afternotok"] = "afterok",
+     ) -> "SlurmFunction":
+         """Mark that this job should be executed only after the provided slurm jobs have finished.
+         Different conditions can be combined by calling this function multiple times.
+
+         :param jobs: dependent jobs
+         :param condition: run condition, defaults to "afterok"
+         :return: the function itself
+         """
+         if not isinstance(jobs, (list, tuple)):
+             jobs = [jobs]
+
+         previous_conditions = self.slurm_params_kwargs.get("dependency", "")
+         append_condition = f"{condition}:{':'.join([job.job_id for job in jobs])}"
+         self.slurm_params_kwargs.update(
+             {
+                 "dependency": (
+                     # slurm separates multiple dependency conditions with commas
+                     f"{previous_conditions},{append_condition}"
+                     if previous_conditions
+                     else append_condition
+                 )
+             }
+         )
+         return self
+
+     def afterok(self, *jobs: Job) -> "SlurmFunction":
+         """Mark that the function should be executed after all of the provided slurm jobs have
+         completed successfully.
+
+         :return: the function itself
+         """
+         return self.on_condition(list(jobs), "afterok")
+
+     def afterany(self, *jobs: Job) -> "SlurmFunction":
+         """Mark that the function should be executed after all of the provided slurm jobs have
+         terminated, whether they succeeded or not.
+
+         :return: the function itself
+         """
+         return self.on_condition(list(jobs), "afterany")
+
+     def afternotok(self, *jobs: Job) -> "SlurmFunction":
+         """Mark that the function should be executed after the provided slurm jobs have failed.
+
+         :return: the function itself
+         """
+         return self.on_condition(list(jobs), "afternotok")
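+
+     # Dependency chaining sketch (the `configured_*` functions are hypothetical configured
+     # `SlurmFunction` instances, in the spirit of the examples above):
+     #
+     #     prep_job = configured_prep()                        # returns a Job in slurm mode
+     #     train_after = configured_train.afterok(prep_job)    # runs only if prep succeeds
+     #     cleanup = configured_cleanup.afterany(prep_job)     # runs once prep terminates
+     #     job = train_after()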
+
+     def __get_submit_args(
+         self,
+         *submit_fn_args,
+         **submit_fn_kwargs,
+     ):
+         # fall back to the defaults when no arguments are given
+         submit_fn_args = self.default_submit_fn_args if not submit_fn_args else submit_fn_args
+         submit_fn_kwargs = (
+             self.default_submit_fn_kwargs if not submit_fn_kwargs else submit_fn_kwargs
+         )
+         return submit_fn_args, submit_fn_kwargs
+
+     def __submit(
+         self,
+         *submit_fn_args,
+         **submit_fn_kwargs,
+     ) -> Job:
+         submit_fn_args, submit_fn_kwargs = self.__get_submit_args(
+             *submit_fn_args, **submit_fn_kwargs
+         )
+         executor = self.get_executor()
+         self.__mark_slurm_has_been_set_up()
+         job = executor.submit(self.submit_fn, *submit_fn_args, **submit_fn_kwargs)
+         return job
+
+     def __submit_map_array(
+         self,
+         *submit_fn_args,
+         **submit_fn_kwargs,
+     ) -> List[Job]:
+         submit_fn_args, submit_fn_kwargs = self.__get_submit_args(
+             *submit_fn_args, **submit_fn_kwargs
+         )
+         executor = self.get_executor()
+         self.__mark_slurm_has_been_set_up()
+         jobs = executor.map_array(self.submit_fn, *submit_fn_args, **submit_fn_kwargs)
+         return jobs
+
+     def __distributed_submit(
+         self,
+         *submit_fn_args,
+         **submit_fn_kwargs,
+     ) -> Job:
+         submit_fn_args, submit_fn_kwargs = self.__get_submit_args(
+             *submit_fn_args, **submit_fn_kwargs
+         )
+
+         # A distributed job in slurm mode is launched twice:
+         # 1. the first launch sets up the distributed environment
+         # 2. the second launch runs the submit function in the distributed environment directly
+         # The task submitted here requests enough resources and sets up the distributed
+         # environment when in slurm mode.
+         if self.slurm_config.distributed_env_task == "torch":
+             task = PyTorchDistributedTask(
+                 self.slurm_config.distributed_launch_command,
+                 (self.system_argv if self.system_argv is not None else list(sys.argv[1:])),
+                 self.slurm_config,
+                 verbose=True,
+                 **self.slurm_task_kwargs,
+             )
+         else:
+             raise ValueError(
+                 f"Unsupported distributed environment task: {self.slurm_config.distributed_env_task}"
+             )
+
+         # Patch the submitit command string to include the task and the second launch command.
+         with SubmititDistributedCommandContext(self.slurm_config, task):
+             executor = self.get_executor()
+             self.__mark_slurm_has_been_set_up()
+             job = executor.submit(task)
+         return job
@@ -0,0 +1,47 @@
+ import shlex
+
+ from submitit import SlurmExecutor
+ from ..config import SlurmConfig
+ from ..task import Task
+
+
+ class SubmititDistributedCommandContext:
+     def __init__(self, config: SlurmConfig, task: Task):
+         self.config = config
+         self.task = task
+         self.is_patched = False
+         self.previous_submitit_command_str = None
+
+ def __enter__(self):
16
+ # monkey patch the submitit command to set up distributed env
17
+ # in distributed training, if two jobs are launched in the same node, the second job will fail
18
+ # but directly use `sbatch`` to submit the second job without any issues
19
+ # this patch is only applied when the mode is `slurm`. otherwise, it will not be patched.
20
+ if self.config.mode == "slurm":
21
+ task_command = self.task.command()
22
+
23
+ def _submitit_command_str(self) -> str:
24
+ return " ".join(
25
+ [
26
+ self.python,
27
+ "-u -m submitit.core._submit",
28
+ shlex.quote(str(self.folder)),
29
+ "\n".join(
30
+ [
31
+ "\n",
32
+ "# nntool command",
33
+ "export NNTOOL_SLURM_HAS_BEEN_SET_UP=1",
34
+ f"source {shlex.quote(str(self.folder))}/nntool_distributed_env.sh",
35
+ f"{task_command}",
36
+ ]
37
+ ),
38
+ ]
39
+ )
40
+
41
+ self.previous_submitit_command_str = SlurmExecutor._submitit_command_str
42
+ SlurmExecutor._submitit_command_str = property(_submitit_command_str)
43
+ self.is_patched = True
44
+
45
+ def __exit__(self, exc_type, exc_value, traceback):
46
+ if self.is_patched:
47
+ SlurmExecutor._submitit_command_str = self.previous_submitit_command_str
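+
+ # Usage sketch: the context manager wraps executor creation and submission (see
+ # `SlurmFunction.__distributed_submit`); `config`, `task`, and `executor` are assumed to be
+ # a configured `SlurmConfig`, a `Task` from this package, and a submitit executor:
+ #
+ #     with SubmititDistributedCommandContext(config, task):
+ #         job = executor.submit(task)   # the sbatch script now sources the distributed env
+ #     # on exit, SlurmExecutor._submitit_command_str is restored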