nntool-2.0.0rc0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,209 @@
1
+ import submitit
2
+ import copy
3
+
4
+ from submitit import Job
5
+ from typing import Any, Callable, Literal, Tuple, Union, Dict, List, Optional
6
+ from .config import SlurmConfig
7
+ from .core import SlurmBackend
8
+
9
+
10
+ class SlurmFunction:
11
+ """The function for the slurm job, which can be used for distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass)."""
12
+
13
+ def __init__(
14
+ self,
15
+ submit_fn: Callable[..., Any],
16
+ default_submit_fn_args: Optional[Tuple[Any]] = None,
17
+ default_submit_fn_kwargs: Optional[Dict[str, Any]] = None,
18
+ ) -> None:
+ """Wrap `submit_fn` as a slurm function, which can be used for distributed or non-distributed jobs (controlled by `use_distributed_env` in the slurm dataclass).
+
+ Args:
+ submit_fn: the function to be submitted to Slurm
+ default_submit_fn_args: default positional arguments for submit_fn, defaults to None
+ default_submit_fn_kwargs: default keyword arguments for submit_fn, defaults to None
+ """
29
+ self.engine = SlurmBackend(submit_fn, default_submit_fn_args, default_submit_fn_kwargs)
30
+
31
+ def __create_copy(self) -> "SlurmFunction":
32
+ return copy.copy(self)
33
+
34
+ def is_configured(self) -> bool:
35
+ """Whether the slurm function has been configured.
36
+
37
+ Returns:
38
+ True if the slurm function has been configured, False otherwise
39
+ """
40
+ return self.engine.is_configured()
41
+
42
+ def is_distributed(self) -> bool:
43
+ """Whether the slurm function is distributed.
44
+
45
+ Returns:
46
+ True if the slurm function is distributed, False otherwise
47
+ """
48
+ return self.engine.is_distributed()
49
+
50
+ def get_executor(
51
+ self,
52
+ ) -> submitit.AutoExecutor:
53
+ return self.engine.get_executor()
54
+
55
+ def configure(
56
+ self,
57
+ slurm_config: SlurmConfig,
58
+ slurm_params_kwargs: Optional[Dict[str, str]] = None,
59
+ slurm_submit_kwargs: Optional[Dict[str, str]] = None,
60
+ slurm_task_kwargs: Optional[Dict[str, str]] = None,
61
+ system_argv: Optional[List[str]] = None,
62
+ pack_code_include_fn: Optional[Callable[[str, str], bool]] = None,
63
+ pack_code_exclude_fn: Optional[Callable[[str, str], bool]] = None,
64
+ ) -> "SlurmFunction":
+ """Update the slurm configuration for this slurm function, which can be used for distributed or non-distributed jobs (controlled by `use_distributed_env` in the slurm dataclass).
+
+ **Exported Distributed Environment Variables**
+
+ - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable indicating that Slurm has been set up.
+ - After setup, the distributed job is launched and the following variables are exported:
+ - ``num_processes``: int
+ - ``num_machines``: int
+ - ``machine_rank``: int
+ - ``main_process_ip``: str
+ - ``main_process_port``: int
+
+ Args:
+ slurm_config: the slurm configuration dataclass
+ slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to None
+ slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to None
+ slurm_task_kwargs: extra arguments for setting up the distributed task, defaults to None
+ system_argv: the system arguments for the second launch in the distributed task (by default the current system arguments `sys.argv[1:]` are used), defaults to None
+ pack_code_include_fn: predicate selecting which files to include when packing code, defaults to None
+ pack_code_exclude_fn: predicate selecting which files to exclude when packing code, defaults to None
+
+ Returns:
+ a new copy with the configured slurm parameters
+ """
87
+ configured_slurm_function = self.__create_copy()
88
+ configured_slurm_function.engine = self.engine.configure(
89
+ slurm_config,
90
+ slurm_params_kwargs,
91
+ slurm_submit_kwargs,
92
+ slurm_task_kwargs,
93
+ system_argv,
94
+ pack_code_include_fn,
95
+ pack_code_exclude_fn,
96
+ )
97
+ return configured_slurm_function
98
+
99
+ def __getitem__(self, slurm_config: Union[Dict[str, Any], Tuple[Any], Any]) -> "SlurmFunction":
+ """Instantiate the slurm configuration for this slurm function via subscription, which can be used for distributed or non-distributed jobs (controlled by `use_distributed_env` in the slurm dataclass).
+
+ **Exported Distributed Environment Variables**
+
+ - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable indicating that Slurm has been set up.
+ - After setup, the distributed job is launched and the following variables are exported:
+ - ``num_processes``: int
+ - ``num_machines``: int
+ - ``machine_rank``: int
+ - ``main_process_ip``: str
+ - ``main_process_port``: int
+
+ Args:
+ slurm_config: the slurm configuration, typically a SlurmConfig dataclass
+
+ Returns:
+ the wrapped submit function with the configured slurm parameters
+ """
118
+ configured_slurm_function = self.__create_copy()
119
+ configured_slurm_function.engine = self.engine[slurm_config]
120
+ return configured_slurm_function
121
+
122
+ def __call__(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
123
+ """Run the submit_fn with the given arguments and keyword arguments. The function is non-blocking in the mode of `slurm`, while other modes cause blocking. If there is no given arguments or keyword arguments, the default arguments and keyword arguments will be used.
124
+
125
+ Args:
126
+ submit_fn_args: arguments for the submit_fn
127
+ submit_fn_kwargs: keyword arguments for the submit_fn
128
+
129
+ Returns:
130
+ a Slurm Job or the return value of the submit_fn, depending on the submit mode
131
+
132
+ Raises:
133
+ Exception: if the submit_fn is not set up
134
+ """
135
+ return self.engine(*submit_fn_args, **submit_fn_kwargs)
136
+
137
+ def submit(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
138
+ """An alias function to ``__call__``.
139
+
140
+ Args:
141
+ submit_fn_args: arguments for the submit_fn
142
+ submit_fn_kwargs: keyword arguments for the submit_fn
143
+
144
+ Raises:
145
+ Exception: if the submit_fn is not set up
146
+
147
+ Returns:
148
+ Slurm Job or the return value of the submit_fn
149
+ """
150
+ return self(*submit_fn_args, **submit_fn_kwargs)
151
+
152
+ def map_array(
153
+ self, *submit_fn_args, **submit_fn_kwargs
154
+ ) -> Union[Job[Any], List[Job[Any]], Any]:
155
+ """Run the submit_fn with the given arguments and keyword arguments. The function is non-blocking in the mode of `slurm`, while other modes cause blocking. If there is no given arguments or keyword arguments, the default arguments and keyword arguments will be used.
156
+
157
+ Args:
158
+ submit_fn_args: arguments for the submit_fn
159
+ submit_fn_kwargs: keyword arguments for the submit_fn
160
+
161
+ Raises:
162
+ Exception: if the submit_fn is not set up
163
+
164
+ Returns:
165
+ a Slurm Job, a list of Slurm Jobs, or the return value(s) of the submit_fn, depending on the submit mode
166
+ """
167
+ return self.engine.map_array(*submit_fn_args, **submit_fn_kwargs)
168
+
169
+ def on_condition(
170
+ self,
171
+ jobs: Union[Job, List[Job], Tuple[Job]],
172
+ condition: Literal["afterany", "afterok", "afternotok"] = "afterok",
173
+ ) -> "SlurmFunction":
174
+ """Mark this job should be executed after the provided slurm jobs have been done. This function allows combining different conditions by multiple calling.
175
+
176
+ Args:
177
+ jobs: dependent jobs
178
+ condition: run condition, defaults to "afterok"
179
+
180
+ Returns:
181
+ a new slurm function with the configured dependency
182
+ """
183
+ configured_slurm_function = self.__create_copy()
184
+ configured_slurm_function.engine = self.engine.on_condition(jobs, condition)
185
+ return configured_slurm_function
186
+
187
+ def afterok(self, *jobs: Job) -> "SlurmFunction":
188
+ """Mark the function should be executed after the provided slurm jobs have been done.
189
+
190
+ Returns:
191
+ the new slurm function with the condition
192
+ """
193
+ return self.on_condition(list(jobs), "afterok")
194
+
195
+ def afterany(self, *jobs: Job) -> "SlurmFunction":
196
+ """Mark the function should be executed after any one of the provided slurm jobs has been done.
197
+
198
+ Returns:
199
+ the new slurm function with the condition
200
+ """
201
+ return self.on_condition(list(jobs), "afterany")
202
+
203
+ def afternotok(self, *jobs: Job) -> "SlurmFunction":
204
+ """Mark the function should be executed after any one of the provided slurm jobs has been failed.
205
+
206
+ Returns:
207
+ the new slurm function with the condition
208
+ """
209
+ return self.on_condition(list(jobs), "afternotok")
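A minimal usage sketch of the wrapper above, mirroring the doctest in nntool/slurm/wrap.py. The module paths follow the files in this diff; `train` and the partition/job names are placeholders:

    from nntool.slurm.config import SlurmConfig
    from nntool.slurm.wrap import slurm_fn

    @slurm_fn
    def train(steps: int, lr: float = 1e-3) -> float:
        # the actual work; runs on the allocated node in slurm mode
        return steps * lr

    config = SlurmConfig(
        mode="slurm",
        partition="PARTITION",
        job_name="EXAMPLE",
        tasks_per_node=1,
        cpus_per_task=8,
        mem="1GB",
    )
    job = train[config](1000, lr=3e-4)      # non-blocking in slurm mode, returns a submitit Job
    eval_fn = train[config].afterok(job)    # new SlurmFunction that waits for `job` to succeed
    eval_job = eval_fn(10)
    print(job.result())                     # block and fetch the return value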
@@ -0,0 +1,6 @@
1
+ from .parse import parse_from_cli
2
+
3
+
4
+ __all__ = [
5
+ "parse_from_cli",
6
+ ]
@@ -0,0 +1,22 @@
1
+ import tyro
2
+
3
+ from typing import Union, Callable, Type, Any
4
+
5
+
6
+ def parse_from_cli(
7
+ ArgsType: Type[Any],
8
+ parser: Union[str, Callable] = "tyro",
9
+ *args,
10
+ **kwargs,
11
+ ) -> Any:
12
+ """Parse `ArgsType` from the command line with the built-in `tyro` parser or a custom parser callable."""
13
+ parser_fn = None
14
+ if isinstance(parser, str):
15
+ if parser == "tyro":
16
+ parser_fn = tyro.cli
17
+ else:
18
+ raise ValueError(f"Parser `{parser}` is not supported.")
19
+ else:
20
+ parser_fn = parser
21
+ parsed_args = parser_fn(ArgsType, *args, **kwargs)
22
+ return parsed_args
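A minimal sketch of calling `parse_from_cli` with the default `tyro` parser; the `Config` dataclass is hypothetical, and the import path is inferred from nntool/slurm/wrap.py:

    from dataclasses import dataclass
    from nntool.slurm.parser import parse_from_cli

    @dataclass
    class Config:          # hypothetical experiment arguments
        lr: float = 1e-3
        epochs: int = 10

    # equivalent to tyro.cli(Config); any callable with the same
    # (ArgsType, *args, **kwargs) signature can be passed instead of "tyro"
    cfg = parse_from_cli(Config, parser="tyro")
    print(cfg.lr, cfg.epochs)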
nntool/slurm/task.py ADDED
@@ -0,0 +1,300 @@
1
+ import os
2
+ import shlex
3
+ import shutil
4
+ import submitit
5
+ import subprocess
6
+
7
+ from pathlib import Path
8
+ from typing import Union, Generator, Callable
9
+ from dataclasses import dataclass
10
+ from .config import SlurmConfig
11
+ from .accelerator.utils import nvidia_smi_gpu_memory_stats_str
12
+
13
+ WANDB_DIRS = ("wandb", ".wandb")
14
+
15
+
16
+ def _is_py_or_dockerfile(path: str, root: str) -> bool:
17
+ file = os.path.basename(path)
18
+ return file.endswith(".py") or file.startswith("Dockerfile")
19
+
20
+
21
+ def include_code_files(path: str, root: str, code_ext: list[str]):
22
+ file = os.path.basename(path)
23
+ return any(file.endswith(ext) for ext in code_ext) or file.startswith("Dockerfile")
24
+
25
+
26
+ def exclude_code_folders(path: str, root: str, code_folders: list[str]) -> bool:
27
+ return any(
28
+ os.path.relpath(path, root).startswith(code_folder + os.sep)
29
+ for code_folder in code_folders
30
+ )
31
+
32
+
33
+ def exclude_wandb_fn(path: str, root: str) -> bool:
34
+ return any(
35
+ os.path.relpath(path, root).startswith(wandb_dir + os.sep) for wandb_dir in WANDB_DIRS
36
+ )
37
+
38
+
39
+ def filtered_dir(
40
+ root: str,
41
+ include_fn: Callable[[str, str], bool],
42
+ exclude_fn: Callable[[str, str], bool],
43
+ ) -> Generator[str, None, None]:
44
+ """Simple generator to walk a directory."""
45
+
46
+ for dirpath, _, files in os.walk(root):
47
+ for fname in files:
48
+ file_path = os.path.join(dirpath, fname)
49
+ if include_fn(file_path, root) and not exclude_fn(file_path, root):
50
+ yield file_path
51
+
52
+
53
+ def pack_code_files(
54
+ root: str,
55
+ target_root: str,
56
+ include_fn: Callable[[str, str], bool] = _is_py_or_dockerfile,
57
+ exclude_fn: Callable[[str, str], bool] = exclude_wandb_fn,
58
+ ):
59
+ root = os.path.abspath(root)
60
+ code_root = Path(os.path.abspath(root))
61
+ code_target = Path(os.path.abspath(target_root)) / "code"
62
+ if not code_root.exists():
63
+ raise ValueError(f"Code root {code_root} does not exist.")
64
+ if not code_target.exists():
65
+ code_target.mkdir(parents=True)
66
+
67
+ for file_path in filtered_dir(root, include_fn, exclude_fn):
68
+ save_name = os.path.relpath(file_path, root)
69
+ sub_file_path, file_name = os.path.split(save_name)
70
+ sub_file_full_path = code_target / sub_file_path
71
+ if not sub_file_full_path.exists():
72
+ sub_file_full_path.mkdir(parents=True)
73
+ shutil.copy(file_path, sub_file_full_path / file_name)
74
+
75
+ return code_target
76
+
77
+
78
+ def reconstruct_command_line(argv):
79
+ # Quote each argument that needs special handling (like spaces or shell characters)
80
+ # and join them with spaces to form the command line
81
+ return " ".join(shlex.quote(arg) for arg in argv)
82
+
83
+
84
+ class Task:
85
+ """The base class for all tasks that will be run on Slurm. Especially useful for
86
+ distributed tasks that need to set up the distributed environment variables.
87
+
88
+ Args:
89
+ argv (list[str]): the command line arguments to run the task. This will be passed to the command method to reconstruct the command line.
90
+ slurm_config (SlurmConfig): the Slurm configuration to use for the task.
91
+ verbose (bool, optional): whether to print verbose output. Defaults to False.
92
+ """
93
+
94
+ def __init__(self, argv: list[str], slurm_config: SlurmConfig, verbose: bool = False):
95
+ self.argv = argv
96
+ self.slurm_config = slurm_config
97
+ self.verbose = verbose
98
+
99
+ def log(self, msg: str):
100
+ """Log a message to the console if verbose is enabled.
101
+
102
+ Args:
103
+ msg (str): the message to log.
104
+ """
105
+ if not self.verbose:
106
+ return
107
+ print(msg)
108
+
109
+ def command(self) -> str:
110
+ """Return the command to run the task. This method should be implemented by
111
+ subclasses to return the actual command line to run the task.
112
+
113
+ Raises:
114
+ NotImplementedError: If the method is not implemented by the subclass.
115
+
116
+ Returns:
117
+ str: the command to run the task.
118
+ """
119
+ raise NotImplementedError
120
+
121
+ def checkpoint(self):
122
+ """Return a checkpoint for the task. This is used to save the state of the task."""
123
+ return submitit.helpers.DelayedSubmission(self)
124
+
125
+
126
+ @dataclass
127
+ class DistributedTaskConfig:
128
+ """Configuration for distributed tasks. This is used to set up the distributed environment
129
+ variables for PyTorch distributed training.
130
+
131
+ Args:
132
+ num_processes (int): The total number of processes to run across all machines.
133
+ num_machines (int): The number of machines to run the task on.
134
+ machine_rank (int): The rank of the current machine in the distributed setup.
135
+ main_process_ip (str): The IP address of the main process (rank 0)
136
+ in the distributed setup.
137
+ main_process_port (int): The port of the main process (rank 0)
138
+ in the distributed setup.
139
+ """
140
+
141
+ # The number of processes to run in total across all machines.
142
+ num_processes: Union[int, str] = "$nntool_num_processes"
143
+
144
+ # The number of machines to run the task on.
145
+ num_machines: Union[int, str] = "$nntool_num_machines"
146
+
147
+ # The rank of the current machine in the distributed setup.
148
+ machine_rank: Union[int, str] = "$nntool_machine_rank"
149
+
150
+ # The IP address of the main process (rank 0) in the distributed setup.
151
+ main_process_ip: str = "$nntool_main_process_ip"
152
+
153
+ # The port of the main process (rank 0) in the distributed setup.
154
+ main_process_port: Union[int, str] = "$nntool_main_process_port"
155
+
156
+ def export_bash(self, output_folder: str):
157
+ """Export the distributed environment variables to a bash script.
158
+ This script can be sourced to set the environment variables for the distributed task.
159
+
160
+ Args:
161
+ output_folder (str): the folder to save the bash script to.
162
+ """
163
+ lines = ["#!/bin/bash"]
164
+ for k, v in self.__dict__.items():
165
+ lines.append(f"export nntool_{k}={v}")
166
+ with open(os.path.join(output_folder, "nntool_distributed_env.sh"), "w") as f:
167
+ f.write("\n".join(lines))
168
+
169
+
170
+ class PyTorchDistributedTask(Task):
171
+ """A task that runs on Slurm and sets up the PyTorch distributed environment variables. It runs the command locally
172
+ if in other modes.
173
+
174
+ Args:
175
+ launch_cmd (str): The command to launch the task.
176
+ argv (list[str]): The command line arguments for the task.
177
+ slurm_config (SlurmConfig): The Slurm configuration to use for the task.
178
+ verbose (bool, optional): whether to print verbose output. Defaults to False.
179
+
180
+ References:
181
+ https://github.com/huggingface/accelerate/issues/1239
182
+ https://github.com/yuvalkirstain/PickScore/blob/main/trainer/slurm_scripts/slurm_train.py
183
+ https://github.com/facebookincubator/submitit/pull/1703
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ launch_cmd: str,
189
+ argv: list[str],
190
+ slurm_config: SlurmConfig,
191
+ verbose: bool = False,
192
+ **env_setup_kwargs,
193
+ ):
194
+ super().__init__(argv, slurm_config, verbose)
195
+ self.launch_cmd = launch_cmd
196
+ self.env_setup_kwargs = env_setup_kwargs
197
+
198
+ # to be set up in the dist_set_up method
199
+ self.dist_args: DistributedTaskConfig = DistributedTaskConfig()
200
+ self.dist_env: Union[None, submitit.helpers.TorchDistributedEnvironment] = None
201
+
202
+ def set_up_dist_env(self):
203
+ """Set up the distributed environment variables for PyTorch distributed training."""
204
+
205
+ self.log("running task on slurm")
206
+ self.log("exporting PyTorch distributed environment variables")
207
+
208
+ # prepare environment variables
209
+ dist_env = submitit.helpers.TorchDistributedEnvironment().export(
210
+ set_cuda_visible_devices=False
211
+ )
212
+
213
+ # other setup
214
+ env_setup = {}
215
+
216
+ # set CUDA visible devices if Slurm has scheduled GPUs; otherwise use all GPUs (without
217
+ # setting CUDA_VISIBLE_DEVICES)
218
+ if self.slurm_config.mode == "slurm":
219
+ env_setup.update(
220
+ {"CUDA_VISIBLE_DEVICES": os.environ["SLURM_JOB_GPUS"]}
221
+ if "SLURM_JOB_GPUS" in os.environ
222
+ else {}
223
+ )
224
+
225
+ # other environment variables set by the user
226
+ env_setup.update(self.env_setup_kwargs)
227
+ self.log(f"Env setup: {env_setup}")
228
+
229
+ # update environment variables
230
+ os.environ.update(**env_setup)
231
+
232
+ self.log(nvidia_smi_gpu_memory_stats_str())
233
+ self.log(f"Master: {dist_env.master_addr}:{dist_env.master_port}")
234
+ self.log(f"Rank: {dist_env.rank}")
235
+ self.log(f"World size: {dist_env.world_size}")
236
+ self.log(f"Local rank: {dist_env.local_rank}")
237
+ self.log(f"Local world size: {dist_env.local_world_size}")
238
+ self.log(
239
+ f"Local rank {dist_env.local_rank}: CUDA_VISIBLE_DEVICES {os.environ.get('CUDA_VISIBLE_DEVICES', 'all')}"
240
+ )
241
+
242
+ # set distributed arguments
243
+ num_processes = (
244
+ self.slurm_config.tasks_per_node
245
+ * self.slurm_config.processes_per_task
246
+ * self.slurm_config.num_of_node
247
+ )
248
+ machine_rank = dist_env.rank // self.slurm_config.tasks_per_node
249
+ self.dist_args = DistributedTaskConfig(
250
+ num_processes=num_processes,
251
+ num_machines=self.slurm_config.num_of_node,
252
+ machine_rank=machine_rank,
253
+ main_process_ip=dist_env.master_addr,
254
+ main_process_port=dist_env.master_port,
255
+ )
256
+ self.dist_env = dist_env
257
+
258
+ return self.dist_args, self.dist_env
259
+
260
+ def command(self) -> str:
261
+ """Return the command to run the task. This method should be implemented by
262
+ subclasses to return the actual command line to run the task.
263
+
264
+ Returns:
265
+ str: the command to run the task.
266
+ """
267
+ cmd = self.launch_cmd.format(**self.dist_args.__dict__)
268
+ cmd += " " + reconstruct_command_line(self.argv)
269
+ return cmd
270
+
271
+ def __call__(self):
272
+ # Set up distributed environment
273
+ self.set_up_dist_env()
274
+
275
+ # Job environment
276
+ job_env = submitit.helpers.JobEnvironment()
277
+
278
+ # Concrete run command
279
+ cmd = self.command()
280
+
281
+ # Only the global rank 0 process runs the command; in slurm mode it first exports the distributed environment variables
282
+ if self.dist_env.rank == 0:
283
+ print(f"running command: {cmd}")
284
+ if self.slurm_config.mode == "slurm":
285
+ try:
286
+ # Export distributed environment variables to a bash script
287
+ # so that the function can be launched after the job is scheduled
288
+ self.dist_args.export_bash(shlex.quote(str(job_env.paths.folder)))
289
+ except Exception as e:
290
+ print(f"failed to export distributed environment variables: {e}")
291
+ return -1
292
+ elif self.slurm_config.mode == "local":
293
+ cmd_list = shlex.split(cmd)
294
+ return subprocess.Popen(cmd_list)
295
+ else:
296
+ # If not on slurm mode, we can just run the command directly
297
+ # This is useful for local testing or when running on a single machine
298
+ return os.system(cmd)
299
+
300
+ return 0
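A sketch of a `launch_cmd` template that `PyTorchDistributedTask.command()` could format. The placeholder names come from `DistributedTaskConfig` above; the `accelerate launch` command line itself is only an assumed example of a launcher:

    from nntool.slurm.task import DistributedTaskConfig  # path per this diff

    # assumed launcher; only the {placeholders} are taken from DistributedTaskConfig
    launch_cmd = (
        "accelerate launch"
        " --num_processes {num_processes}"
        " --num_machines {num_machines}"
        " --machine_rank {machine_rank}"
        " --main_process_ip {main_process_ip}"
        " --main_process_port {main_process_port}"
    )
    # Before set_up_dist_env() has run, the placeholders expand to the
    # "$nntool_*" shell variables from DistributedTaskConfig's defaults;
    # afterwards they expand to the concrete values for the job, and
    # command() appends the reconstructed script arguments.
    print(launch_cmd.format(**DistributedTaskConfig().__dict__))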
nntool/slurm/wrap.py ADDED
@@ -0,0 +1,148 @@
1
+ import sys
2
+
3
+ from warnings import warn
4
+ from typing import Any, Callable, Type, Union, Dict, List
5
+
6
+ from .config import SlurmConfig
7
+ from .function import SlurmFunction
8
+ from .parser import parse_from_cli
9
+
10
+
11
+ def slurm_fn(
12
+ submit_fn: Callable,
13
+ ) -> SlurmFunction:
14
+ """A decorator to wrap a function to be run on slurm. The function decorated by this decorator should be launched on the way below. The decorated function `submit_fn` is non-blocking now. To block and get the return value, you can call ``job.result()``.
15
+
16
+ Args:
17
+ submit_fn: the function to be run on slurm
18
+
19
+ Returns:
20
+ the function to be run on slurm
21
+
22
+ Example:
23
+ >>> @slurm_fn
24
+ ... def run_on_slurm(a, b):
25
+ ... return a + b
26
+ >>> slurm_config = SlurmConfig(
27
+ ... mode="slurm",
28
+ ... partition="PARTITION",
29
+ ... job_name="EXAMPLE",
30
+ ... tasks_per_node=1,
31
+ ... cpus_per_task=8,
32
+ ... mem="1GB",
33
+ ... )
34
+ >>> job = run_on_slurm[slurm_config](1, b=2)
35
+ >>> result = job.result() # block and get the result
36
+ """
37
+ slurm_fn = SlurmFunction(submit_fn=submit_fn)
38
+
39
+ return slurm_fn
40
+
41
+
42
+ def slurm_launcher(
43
+ ArgsType: Type[Any],
44
+ parser: Union[str, Callable] = "tyro",
45
+ slurm_key: str = "slurm",
46
+ slurm_params_kwargs: dict = {},
47
+ slurm_submit_kwargs: dict = {},
48
+ slurm_task_kwargs: dict = {},
49
+ *extra_args,
50
+ **extra_kwargs,
51
+ ) -> Callable[[Callable[..., Any]], SlurmFunction]:
+ """A slurm launcher decorator for distributed or non-distributed jobs (controlled by `use_distributed_env` in the slurm field). This decorator should be used as the program entry point. The decorated function is non-blocking in `slurm` mode, while other modes block.
+
+ Args:
+ ArgsType: the experiment arguments type, which should be a dataclass (it
+ must have a slurm field named by `slurm_key`)
+ parser: the parser for the arguments, defaults to "tyro"
+ slurm_key: the key of the slurm field in the ArgsType, defaults to "slurm"
+ slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to {}
+ slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to {}
+ slurm_task_kwargs: extra arguments for setting up the distributed task, defaults to {}
+ extra_args: extra arguments for the parser
+ extra_kwargs: extra keyword arguments for the parser
+
+ Returns:
+ a decorator that wraps the main entry function into a SlurmFunction
+
+ Exported Distributed Environment Variables:
+ 1. NNTOOL_SLURM_HAS_BEEN_SET_UP is a special environment variable to indicate that Slurm has been set up.
+ 2. After setup, the distributed job is launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.
+ """
73
+ argv = list(sys.argv[1:])
74
+ args = parse_from_cli(ArgsType, parser, *extra_args, **extra_kwargs)
75
+
76
+ # check if args have slurm field
77
+ if not hasattr(args, slurm_key):
78
+ raise ValueError(
79
+ f"ArgsType should have a field named `{slurm_key}` to use `slurm_launcher` decorator."
80
+ )
81
+ slurm_config: SlurmConfig = getattr(args, slurm_key)
82
+
83
+ def decorator(
84
+ submit_fn: Callable[..., Any],
85
+ ) -> SlurmFunction:
86
+ return SlurmFunction(
87
+ submit_fn=submit_fn,
88
+ default_submit_fn_args=(args,),
89
+ ).configure(
90
+ slurm_config,
91
+ slurm_params_kwargs,
92
+ slurm_submit_kwargs,
93
+ slurm_task_kwargs,
94
+ system_argv=argv,
95
+ )
96
+
97
+ return decorator
98
+
99
+
100
+ def slurm_function(
101
+ submit_fn: Callable,
102
+ ):
103
+ """
104
+ A decorator to annotate a function to be run on Slurm. The decorated function should be launched as shown below.
105
+
106
+ Deprecated:
107
+ This function is deprecated and will be removed in future versions. Please use `slurm_fn` instead.
108
+
109
+ Example:
110
+ >>> @slurm_function
111
+ ... def run_on_slurm(a, b):
112
+ ... return a + b
113
+ >>> slurm_config = SlurmConfig(
114
+ ... mode="slurm",
115
+ ... partition="PARTITION",
116
+ ... job_name="EXAMPLE",
117
+ ... tasks_per_node=1,
118
+ ... cpus_per_task=8,
119
+ ... mem="1GB",
120
+ ... )
121
+ >>> job = run_on_slurm(slurm_config)(1, b=2)
122
+ >>> result = job.result() # block and get the result
123
+ """
124
+
125
+ def wrapper(
126
+ slurm_config: SlurmConfig,
127
+ slurm_params_kwargs: Dict[str, Any] = {},
128
+ slurm_submit_kwargs: Dict[str, Any] = {},
129
+ slurm_task_kwargs: Dict[str, Any] = {},
130
+ system_argv: Union[List[str], None] = None,
131
+ ) -> SlurmFunction:
132
+ warn(
133
+ "`slurm_function` has been deprecated. Please use `slurm_fn` instead.",
134
+ DeprecationWarning,
135
+ stacklevel=2,
136
+ )
137
+ slurm_fn = SlurmFunction(
138
+ submit_fn=submit_fn,
139
+ ).configure(
140
+ slurm_config,
141
+ slurm_params_kwargs,
142
+ slurm_submit_kwargs,
143
+ slurm_task_kwargs,
144
+ system_argv,
145
+ )
146
+ return slurm_fn
147
+
148
+ return wrapper
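A minimal sketch of wiring `slurm_launcher` as the program entry point. `ExpArgs` and `main` are hypothetical, the module paths follow this diff, and the example assumes `SlurmConfig` can be constructed with default values:

    from dataclasses import dataclass, field
    from nntool.slurm.config import SlurmConfig
    from nntool.slurm.wrap import slurm_launcher

    @dataclass
    class ExpArgs:                       # hypothetical experiment arguments
        lr: float = 1e-3
        # the slurm field required by slurm_launcher (assumes SlurmConfig has defaults)
        slurm: SlurmConfig = field(default_factory=SlurmConfig)

    @slurm_launcher(ExpArgs)
    def main(args: ExpArgs):
        print("training with lr =", args.lr)

    if __name__ == "__main__":
        job = main()   # CLI-parsed args are passed automatically as the default argument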