nntool 2.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nntool/__init__.py +7 -0
- nntool/build_backend.py +24 -0
- nntool/experiment/__init__.py +13 -0
- nntool/experiment/config.py +108 -0
- nntool/experiment/utils.py +63 -0
- nntool/slurm/__init__.py +21 -0
- nntool/slurm/accelerator/__init__.py +0 -0
- nntool/slurm/accelerator/utils.py +37 -0
- nntool/slurm/config.py +208 -0
- nntool/slurm/core/__init__.py +4 -0
- nntool/slurm/core/_slurm.py +546 -0
- nntool/slurm/core/_slurm_context.py +47 -0
- nntool/slurm/function.py +209 -0
- nntool/slurm/parser/__init__.py +6 -0
- nntool/slurm/parser/parse.py +22 -0
- nntool/slurm/task.py +300 -0
- nntool/slurm/wrap.py +148 -0
- nntool/utils/__init__.py +12 -0
- nntool/version.py +11 -0
- nntool/wandb/__init__.py +7 -0
- nntool/wandb/config.py +116 -0
- nntool-2.0.0rc0.dist-info/METADATA +12 -0
- nntool-2.0.0rc0.dist-info/RECORD +25 -0
- nntool-2.0.0rc0.dist-info/WHEEL +5 -0
- nntool-2.0.0rc0.dist-info/top_level.txt +1 -0
nntool/slurm/function.py
ADDED
@@ -0,0 +1,209 @@
import submitit
import copy

from submitit import Job
from typing import Any, Callable, Literal, Tuple, Union, Dict, List, Optional
from .config import SlurmConfig
from .core import SlurmBackend


class SlurmFunction:
    """The function for the slurm job, which can be used for a distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass)."""

    def __init__(
        self,
        submit_fn: Callable[..., Any],
        default_submit_fn_args: Optional[Tuple[Any]] = None,
        default_submit_fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """A slurm function for the slurm job, which can be used for a distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).

        Args:
            submit_fn: function to be submitted to Slurm
            default_submit_fn_args: default args for submit_fn, defaults to ()
            default_submit_fn_kwargs: default keyword args for submit_fn, defaults to {}

        Returns:
            the wrapped submit function with configured slurm parameters
        """
        self.engine = SlurmBackend(submit_fn, default_submit_fn_args, default_submit_fn_kwargs)

    def __create_copy(self) -> "SlurmFunction":
        return copy.copy(self)

    def is_configured(self) -> bool:
        """Whether the slurm function has been configured.

        Returns:
            True if the slurm function has been configured, False otherwise
        """
        return self.engine.is_configured()

    def is_distributed(self) -> bool:
        """Whether the slurm function is distributed.

        Returns:
            True if the slurm function is distributed, False otherwise
        """
        return self.engine.is_distributed()

    def get_executor(self) -> submitit.AutoExecutor:
        return self.engine.get_executor()

    def configure(
        self,
        slurm_config: SlurmConfig,
        slurm_params_kwargs: Optional[Dict[str, str]] = None,
        slurm_submit_kwargs: Optional[Dict[str, str]] = None,
        slurm_task_kwargs: Optional[Dict[str, str]] = None,
        system_argv: Optional[List[str]] = None,
        pack_code_include_fn: Optional[Callable[[str, str], bool]] = None,
        pack_code_exclude_fn: Optional[Callable[[str, str], bool]] = None,
    ) -> "SlurmFunction":
        """Update the slurm configuration for the slurm function, which can be used for a distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).

        **Exported Distributed Environment Variables**

        - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable to indicate that slurm has been set up.
        - After the setup, the distributed job will be launched and the following variables are exported:
            - ``num_processes``: int
            - ``num_machines``: int
            - ``machine_rank``: int
            - ``main_process_ip``: str
            - ``main_process_port``: int

        Args:
            slurm_config: the slurm configuration dataclass
            slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to {}
            slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to {}
            slurm_task_kwargs: extra arguments for the setting of the distributed task, defaults to {}
            system_argv: the system arguments for the second launch in the distributed task (by default the current system arguments `sys.argv[1:]` are used), defaults to None
            pack_code_include_fn: predicate `(path, root) -> bool` selecting files to pack with the job code, defaults to None
            pack_code_exclude_fn: predicate `(path, root) -> bool` excluding files from the packed code, defaults to None

        Returns:
            a new copy with configured slurm parameters
        """
        configured_slurm_function = self.__create_copy()
        configured_slurm_function.engine = self.engine.configure(
            slurm_config,
            slurm_params_kwargs,
            slurm_submit_kwargs,
            slurm_task_kwargs,
            system_argv,
            pack_code_include_fn,
            pack_code_exclude_fn,
        )
        return configured_slurm_function

    def __getitem__(self, slurm_config: Union[Dict[str, Any], Tuple[Any], Any]) -> "SlurmFunction":
        """Instantiate the slurm configuration for the slurm function, which can be used for a distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).

        **Exported Distributed Environment Variables**

        - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable to indicate that slurm has been set up.
        - After the setup, the distributed job will be launched and the following variables are exported:
            - ``num_processes``: int
            - ``num_machines``: int
            - ``machine_rank``: int
            - ``main_process_ip``: str
            - ``main_process_port``: int

        Args:
            slurm_config: the slurm configuration dataclass

        Returns:
            the wrapped submit function with configured slurm parameters
        """
        configured_slurm_function = self.__create_copy()
        configured_slurm_function.engine = self.engine[slurm_config]
        return configured_slurm_function

    def __call__(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
        """Run the submit_fn with the given arguments and keyword arguments. The call is non-blocking in `slurm` mode, while other modes block. If no arguments or keyword arguments are given, the default arguments and keyword arguments are used.

        Args:
            submit_fn_args: arguments for the submit_fn
            submit_fn_kwargs: keyword arguments for the submit_fn

        Returns:
            Slurm Job or the return value of the submit_fn, depending on the submit mode

        Raises:
            Exception: if the submit_fn is not set up
        """
        return self.engine(*submit_fn_args, **submit_fn_kwargs)

    def submit(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
        """An alias for ``__call__``.

        Args:
            submit_fn_args: arguments for the submit_fn
            submit_fn_kwargs: keyword arguments for the submit_fn

        Raises:
            Exception: if the submit_fn is not set up

        Returns:
            Slurm Job or the return value of the submit_fn
        """
        return self(*submit_fn_args, **submit_fn_kwargs)

    def map_array(
        self, *submit_fn_args, **submit_fn_kwargs
    ) -> Union[Job[Any], List[Job[Any]], Any]:
        """Map the submit_fn over the given arrays of arguments and keyword arguments. The call is non-blocking in `slurm` mode, while other modes block. If no arguments or keyword arguments are given, the default arguments and keyword arguments are used.

        Args:
            submit_fn_args: arguments for the submit_fn
            submit_fn_kwargs: keyword arguments for the submit_fn

        Raises:
            Exception: if the submit_fn is not set up

        Returns:
            Slurm Job(s) or the return value of the submit_fn
        """
        return self.engine.map_array(*submit_fn_args, **submit_fn_kwargs)

    def on_condition(
        self,
        jobs: Union[Job, List[Job], Tuple[Job]],
        condition: Literal["afterany", "afterok", "afternotok"] = "afterok",
    ) -> "SlurmFunction":
        """Mark that this function should be executed after the provided slurm jobs have finished. Conditions can be combined by calling this method multiple times.

        Args:
            jobs: dependent jobs
            condition: run condition, defaults to "afterok"

        Returns:
            a new slurm function with the condition
        """
        configured_slurm_function = self.__create_copy()
        configured_slurm_function.engine = self.engine.on_condition(jobs, condition)
        return configured_slurm_function

    def afterok(self, *jobs: Job) -> "SlurmFunction":
        """Mark that the function should be executed after the provided slurm jobs have completed successfully.

        Returns:
            the new slurm function with the condition
        """
        return self.on_condition(list(jobs), "afterok")

    def afterany(self, *jobs: Job) -> "SlurmFunction":
        """Mark that the function should be executed after the provided slurm jobs have terminated, regardless of exit status.

        Returns:
            the new slurm function with the condition
        """
        return self.on_condition(list(jobs), "afterany")

    def afternotok(self, *jobs: Job) -> "SlurmFunction":
        """Mark that the function should be executed after the provided slurm jobs have failed.

        Returns:
            the new slurm function with the condition
        """
        return self.on_condition(list(jobs), "afternotok")
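Putting the API above together, here is a minimal usage sketch (illustrative, not part of the wheel): it assumes `SlurmConfig` and the `slurm_fn` decorator from nntool/slurm/wrap.py below are re-exported by `nntool.slurm`, and the partition and job name are placeholders.

# Sketch: chaining configured SlurmFunctions with job dependencies.
# Assumes SlurmConfig/slurm_fn are importable from nntool.slurm.
from nntool.slurm import SlurmConfig, slurm_fn

@slurm_fn
def preprocess(path: str) -> str:
    return path

@slurm_fn
def train(path: str) -> None:
    ...

config = SlurmConfig(mode="slurm", partition="PARTITION", job_name="EXAMPLE")

prep_job = preprocess[config]("data/")                 # non-blocking submission
train_job = train[config].afterok(prep_job)("data/")   # starts only if preprocess succeeds
train_job.result()                                     # block until training finishes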
nntool/slurm/parser/parse.py
ADDED
@@ -0,0 +1,22 @@
import tyro

from typing import Union, Callable, Type, Any


def parse_from_cli(
    ArgsType: Type[Any],
    parser: Union[str, Callable] = "tyro",
    *args,
    **kwargs,
) -> Any:
    # parse with the built-in parser or a custom parser function
    parser_fn = None
    if isinstance(parser, str):
        if parser == "tyro":
            parser_fn = tyro.cli
        else:
            raise ValueError(f"Parser `{parser}` is not supported.")
    else:
        parser_fn = parser
    args: ArgsType = parser_fn(ArgsType, *args, **kwargs)
    return args
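For reference, a short sketch of both parsing paths (`TrainArgs` is a hypothetical experiment config; `tyro` is the only built-in parser name):

from dataclasses import dataclass

@dataclass
class TrainArgs:  # hypothetical experiment config
    lr: float = 3e-4
    epochs: int = 10

# built-in parser: tyro reads the process CLI, e.g. `python train.py --lr 1e-3`
args = parse_from_cli(TrainArgs)

# custom parser: any callable accepting (ArgsType, *args, **kwargs) works
args = parse_from_cli(TrainArgs, parser=lambda cls: cls())  # skip the CLI, use defaults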
nntool/slurm/task.py
ADDED
@@ -0,0 +1,300 @@
import os
import shlex
import shutil
import submitit
import subprocess

from pathlib import Path
from typing import Union, Generator, Callable
from dataclasses import dataclass
from .config import SlurmConfig
from .accelerator.utils import nvidia_smi_gpu_memory_stats_str

WANDB_DIRS = ("wandb", ".wandb")


def _is_py_or_dockerfile(path: str, root: str) -> bool:
    file = os.path.basename(path)
    return file.endswith(".py") or file.startswith("Dockerfile")


def include_code_files(path: str, root: str, code_ext: list[str]):
    file = os.path.basename(path)
    return any(file.endswith(ext) for ext in code_ext) or file.startswith("Dockerfile")


def exclude_code_folders(path: str, root: str, code_folders: list[str]):
    return any(
        os.path.relpath(path, root).startswith(code_folder + os.sep)
        for code_folder in code_folders
    )


def exclude_wandb_fn(path: str, root: str) -> bool:
    return any(
        os.path.relpath(path, root).startswith(wandb_dir + os.sep) for wandb_dir in WANDB_DIRS
    )


def filtered_dir(
    root: str,
    include_fn: Callable[[str, str], bool],
    exclude_fn: Callable[[str, str], bool],
) -> Generator[str, None, None]:
    """Simple generator to walk a directory."""

    for dirpath, _, files in os.walk(root):
        for fname in files:
            file_path = os.path.join(dirpath, fname)
            if include_fn(file_path, root) and not exclude_fn(file_path, root):
                yield file_path


def pack_code_files(
    root: str,
    target_root: str,
    include_fn: Callable[[str, str], bool] = _is_py_or_dockerfile,
    exclude_fn: Callable[[str, str], bool] = exclude_wandb_fn,
):
    root = os.path.abspath(root)
    code_root = Path(root)
    code_target = Path(os.path.abspath(target_root)) / "code"
    if not code_root.exists():
        raise ValueError(f"Code root {code_root} does not exist.")
    if not code_target.exists():
        code_target.mkdir(parents=True)

    for file_path in filtered_dir(root, include_fn, exclude_fn):
        save_name = os.path.relpath(file_path, root)
        sub_file_path, file_name = os.path.split(save_name)
        sub_file_full_path = code_target / sub_file_path
        if not sub_file_full_path.exists():
            sub_file_full_path.mkdir(parents=True)
        shutil.copy(file_path, sub_file_full_path / file_name)

    return code_target


def reconstruct_command_line(argv):
    # Quote each argument that needs special handling (like spaces or shell characters)
    # and join them with spaces to form the command line
    return " ".join(shlex.quote(arg) for arg in argv)


class Task:
    """The base class for all tasks that will be run on Slurm. Especially useful for
    distributed tasks that need to set up the distributed environment variables.

    Args:
        argv (list[str]): the command line arguments to run the task. This is passed to the command method to reconstruct the command line.
        slurm_config (SlurmConfig): the Slurm configuration to use for the task.
        verbose (bool, optional): whether to print verbose output. Defaults to False.
    """

    def __init__(self, argv: list[str], slurm_config: SlurmConfig, verbose: bool = False):
        self.argv = argv
        self.slurm_config = slurm_config
        self.verbose = verbose

    def log(self, msg: str):
        """Log a message to the console if verbose is enabled.

        Args:
            msg (str): the message to log.
        """
        if not self.verbose:
            return
        print(msg)

    def command(self) -> str:
        """Return the command to run the task. This method should be implemented by
        subclasses to return the actual command line to run the task.

        Raises:
            NotImplementedError: If the method is not implemented by the subclass.

        Returns:
            str: the command to run the task.
        """
        raise NotImplementedError

    def checkpoint(self):
        """Return a checkpoint for the task. This is used to save the state of the task."""
        return submitit.helpers.DelayedSubmission(self)


@dataclass
class DistributedTaskConfig:
    """Configuration for distributed tasks. This is used to set up the distributed environment
    variables for PyTorch distributed training.

    Args:
        num_processes (int): The total number of processes to run across all machines.
        num_machines (int): The number of machines to run the task on.
        machine_rank (int): The rank of the current machine in the distributed setup.
        main_process_ip (str): The IP address of the main process (rank 0)
            in the distributed setup.
        main_process_port (int): The port of the main process (rank 0)
            in the distributed setup.
    """

    # The total number of processes to run across all machines.
    num_processes: Union[int, str] = "$nntool_num_processes"

    # The number of machines to run the task on.
    num_machines: Union[int, str] = "$nntool_num_machines"

    # The rank of the current machine in the distributed setup.
    machine_rank: Union[int, str] = "$nntool_machine_rank"

    # The IP address of the main process (rank 0) in the distributed setup.
    main_process_ip: str = "$nntool_main_process_ip"

    # The port of the main process (rank 0) in the distributed setup.
    main_process_port: Union[int, str] = "$nntool_main_process_port"

    def export_bash(self, output_folder: str):
        """Export the distributed environment variables to a bash script.
        This script can be sourced to set the environment variables for the distributed task.

        Args:
            output_folder (str): the folder to save the bash script to.
        """
        lines = ["#!/bin/bash"]
        for k, v in self.__dict__.items():
            lines.append(f"export nntool_{k}={v}")
        with open(os.path.join(output_folder, "nntool_distributed_env.sh"), "w") as f:
            f.write("\n".join(lines))


class PyTorchDistributedTask(Task):
    """A task that runs on Slurm and sets up the PyTorch distributed environment variables. It runs the command locally
    in other modes.

    Args:
        launch_cmd (str): The command to launch the task.
        argv (list[str]): The command line arguments for the task.
        slurm_config (SlurmConfig): The Slurm configuration to use for the task.
        verbose (bool, optional): whether to print verbose output. Defaults to False.

    References:
        https://github.com/huggingface/accelerate/issues/1239
        https://github.com/yuvalkirstain/PickScore/blob/main/trainer/slurm_scripts/slurm_train.py
        https://github.com/facebookincubator/submitit/pull/1703
    """

    def __init__(
        self,
        launch_cmd: str,
        argv: list[str],
        slurm_config: SlurmConfig,
        verbose: bool = False,
        **env_setup_kwargs,
    ):
        super().__init__(argv, slurm_config, verbose)
        self.launch_cmd = launch_cmd
        self.env_setup_kwargs = env_setup_kwargs

        # to be set up in the set_up_dist_env method
        self.dist_args: DistributedTaskConfig = DistributedTaskConfig()
        self.dist_env: Union[None, submitit.helpers.TorchDistributedEnvironment] = None

    def set_up_dist_env(self):
        """Set up the distributed environment variables for PyTorch distributed training."""

        self.log("running task on slurm")
        self.log("exporting PyTorch distributed environment variables")

        # prepare environment variables
        dist_env = submitit.helpers.TorchDistributedEnvironment().export(
            set_cuda_visible_devices=False
        )

        # other setup
        env_setup = {}

        # set CUDA visible devices if slurm has scheduled GPUs, otherwise use all GPUs
        # (without setting CUDA_VISIBLE_DEVICES)
        if self.slurm_config.mode == "slurm":
            env_setup.update(
                {"CUDA_VISIBLE_DEVICES": os.environ["SLURM_JOB_GPUS"]}
                if "SLURM_JOB_GPUS" in os.environ
                else {}
            )

        # other environment variables set by the user
        env_setup.update(self.env_setup_kwargs)
        self.log(f"Env setup: {env_setup}")

        # update environment variables
        os.environ.update(**env_setup)

        self.log(nvidia_smi_gpu_memory_stats_str())
        self.log(f"Master: {dist_env.master_addr}:{dist_env.master_port}")
        self.log(f"Rank: {dist_env.rank}")
        self.log(f"World size: {dist_env.world_size}")
        self.log(f"Local rank: {dist_env.local_rank}")
        self.log(f"Local world size: {dist_env.local_world_size}")
        self.log(
            f"Local rank {dist_env.local_rank}: CUDA_VISIBLE_DEVICES {os.environ.get('CUDA_VISIBLE_DEVICES', 'all')}"
        )

        # set distributed arguments
        num_processes = (
            self.slurm_config.tasks_per_node
            * self.slurm_config.processes_per_task
            * self.slurm_config.num_of_node
        )
        machine_rank = dist_env.rank // self.slurm_config.tasks_per_node
        self.dist_args = DistributedTaskConfig(
            num_processes=num_processes,
            num_machines=self.slurm_config.num_of_node,
            machine_rank=machine_rank,
            main_process_ip=dist_env.master_addr,
            main_process_port=dist_env.master_port,
        )
        self.dist_env = dist_env

        return self.dist_args, self.dist_env

    def command(self) -> str:
        """Build the command to run the task from the launch command template and the
        reconstructed command line arguments.

        Returns:
            str: the command to run the task.
        """
        cmd = self.launch_cmd.format(**self.dist_args.__dict__)
        cmd += " " + reconstruct_command_line(self.argv)
        return cmd

    def __call__(self):
        # Set up distributed environment
        self.set_up_dist_env()

        # Job environment
        job_env = submitit.helpers.JobEnvironment()

        # Concrete run command
        cmd = self.command()

        # Only the global rank 0 process exports the environment variables or runs the command
        if self.dist_env.rank == 0:
            print(f"running command: {cmd}")
            if self.slurm_config.mode == "slurm":
                try:
                    # Export distributed environment variables to a bash script;
                    # the fn will be launched after the job is scheduled
                    self.dist_args.export_bash(shlex.quote(str(job_env.paths.folder)))
                except Exception as e:
                    print(f"failed to export distributed environment variables: {e}")
                    return -1
            elif self.slurm_config.mode == "local":
                cmd_list = shlex.split(cmd)
                return subprocess.Popen(cmd_list)
            else:
                # If not in slurm mode, we can just run the command directly.
                # This is useful for local testing or when running on a single machine.
                return os.system(cmd)

        return 0
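To make the `launch_cmd` contract concrete, here is an illustrative sketch (not taken from the wheel): placeholder names in the template must match the `DistributedTaskConfig` field names, because `command()` fills them with `self.launch_cmd.format(**self.dist_args.__dict__)`. The `accelerate launch` flags and the `slurm_config` instance are assumptions for illustration.

# Hypothetical launch template for a PyTorchDistributedTask (sketch only).
launch_cmd = (
    "accelerate launch"
    " --num_processes {num_processes}"
    " --num_machines {num_machines}"
    " --machine_rank {machine_rank}"
    " --main_process_ip {main_process_ip}"
    " --main_process_port {main_process_port}"
)
task = PyTorchDistributedTask(
    launch_cmd,
    argv=["train.py", "--lr", "1e-3"],  # appended via reconstruct_command_line
    slurm_config=slurm_config,          # an nntool SlurmConfig instance (assumed)
    verbose=True,
)
# Inside a scheduled job, task() sets up the distributed environment and, on rank 0,
# either exports the variables to nntool_distributed_env.sh (slurm mode) or runs the
# formatted command directly (other modes).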
nntool/slurm/wrap.py
ADDED
@@ -0,0 +1,148 @@
import sys

from warnings import warn
from typing import Any, Callable, Type, Union, Dict, List

from .config import SlurmConfig
from .function import SlurmFunction
from .parser import parse_from_cli


def slurm_fn(
    submit_fn: Callable,
) -> SlurmFunction:
    """A decorator to wrap a function to be run on slurm. A function decorated by this decorator should be launched in the way shown below. The decorated `submit_fn` is now non-blocking. To block and get the return value, call ``job.result()``.

    Args:
        submit_fn: the function to be run on slurm

    Returns:
        the function to be run on slurm

    Example:
        >>> @slurm_fn
        ... def run_on_slurm(a, b):
        ...     return a + b
        >>> slurm_config = SlurmConfig(
        ...     mode="slurm",
        ...     partition="PARTITION",
        ...     job_name="EXAMPLE",
        ...     tasks_per_node=1,
        ...     cpus_per_task=8,
        ...     mem="1GB",
        ... )
        >>> job = run_on_slurm[slurm_config](1, b=2)
        >>> result = job.result()  # block and get the result
    """
    slurm_fn = SlurmFunction(submit_fn=submit_fn)

    return slurm_fn


def slurm_launcher(
    ArgsType: Type[Any],
    parser: Union[str, Callable] = "tyro",
    slurm_key: str = "slurm",
    slurm_params_kwargs: dict = {},
    slurm_submit_kwargs: dict = {},
    slurm_task_kwargs: dict = {},
    *extra_args,
    **extra_kwargs,
) -> Callable[[Callable[..., Any]], SlurmFunction]:
    """A slurm launcher decorator for distributed or non-distributed jobs (controlled by `use_distributed_env` in the slurm field). This decorator should be used as the program entry. The decorated function is non-blocking in `slurm` mode, while other modes block.

    Args:
        ArgsType: the experiment arguments type, which should be a dataclass (it
            must have a slurm field named by `slurm_key`)
        parser: the parser for the arguments, defaults to "tyro"
        slurm_key: the key of the slurm field in the ArgsType, defaults to "slurm"
        slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to {}
        slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to {}
        slurm_task_kwargs: extra arguments for the setting of the distributed task, defaults to {}
        extra_args: extra arguments for the parser
        extra_kwargs: extra keyword arguments for the parser

    Returns:
        decorator function with main entry

    Exported Distributed Environment Variables:
        1. NNTOOL_SLURM_HAS_BEEN_SET_UP is a special environment variable to indicate that slurm has been set up.
        2. After the setup, the distributed job will be launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.
    """
    argv = list(sys.argv[1:])
    args = parse_from_cli(ArgsType, parser, *extra_args, **extra_kwargs)

    # check if args have a slurm field
    if not hasattr(args, slurm_key):
        raise ValueError(
            f"ArgsType should have a field named `{slurm_key}` to use the `slurm_launcher` decorator."
        )
    slurm_config: SlurmConfig = getattr(args, slurm_key)

    def decorator(
        submit_fn: Callable[..., Any],
    ) -> SlurmFunction:
        return SlurmFunction(
            submit_fn=submit_fn,
            default_submit_fn_args=(args,),
        ).configure(
            slurm_config,
            slurm_params_kwargs,
            slurm_submit_kwargs,
            slurm_task_kwargs,
            system_argv=argv,
        )

    return decorator


def slurm_function(
    submit_fn: Callable,
):
    """
    A decorator to annotate a function to be run on slurm. A function decorated by this decorator should be launched in the way shown below.

    Deprecated:
        This function is deprecated and will be removed in future versions. Please use `slurm_fn` instead.

    Example:
        >>> @slurm_function
        ... def run_on_slurm(a, b):
        ...     return a + b
        >>> slurm_config = SlurmConfig(
        ...     mode="slurm",
        ...     partition="PARTITION",
        ...     job_name="EXAMPLE",
        ...     tasks_per_node=1,
        ...     cpus_per_task=8,
        ...     mem="1GB",
        ... )
        >>> job = run_on_slurm(slurm_config)(1, b=2)
        >>> result = job.result()  # block and get the result
    """

    def wrapper(
        slurm_config: SlurmConfig,
        slurm_params_kwargs: Dict[str, Any] = {},
        slurm_submit_kwargs: Dict[str, Any] = {},
        slurm_task_kwargs: Dict[str, Any] = {},
        system_argv: Union[List[str], None] = None,
    ) -> SlurmFunction:
        warn(
            "`slurm_function` has been deprecated. Please use `slurm_fn` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        slurm_fn = SlurmFunction(
            submit_fn=submit_fn,
        ).configure(
            slurm_config,
            slurm_params_kwargs,
            slurm_submit_kwargs,
            slurm_task_kwargs,
            system_argv,
        )
        return slurm_fn

    return wrapper
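Since `slurm_launcher` itself carries no doctest, here is a sketch of it as a program entry point (illustrative, not part of the wheel): it assumes `SlurmConfig` is default-constructible and re-exported from `nntool.slurm`, and that the dataclass field is named after `slurm_key` ("slurm" by default).

from dataclasses import dataclass, field
from nntool.slurm import SlurmConfig, slurm_launcher

@dataclass
class Args:
    lr: float = 3e-4
    # the field name must match slurm_key ("slurm" by default)
    slurm: SlurmConfig = field(default_factory=SlurmConfig)  # assumes defaults exist

@slurm_launcher(Args)
def main(args: Args) -> None:
    print(f"training with lr={args.lr}")

if __name__ == "__main__":
    # the parsed args are bound as default_submit_fn_args,
    # so the call needs no explicit arguments
    main()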