nntool 1.3.0__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nntool might be problematic. Click here for more details.
- nntool/__init__.py +2 -0
- nntool/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/experiment/__init__.py +6 -0
- nntool/experiment/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/experiment/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/experiment/__pycache__/config.cpython-312.opt-1.pyc +0 -0
- nntool/experiment/__pycache__/config.cpython-312.pyc +0 -0
- nntool/experiment/__pycache__/utils.cpython-312.opt-1.pyc +0 -0
- nntool/experiment/__pycache__/utils.cpython-312.pyc +0 -0
- nntool/experiment/config.py +112 -0
- nntool/experiment/utils.py +63 -0
- nntool/parser/__init__.py +1 -0
- nntool/parser/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/parser/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/parser/__pycache__/parse.cpython-312.opt-1.pyc +0 -0
- nntool/parser/__pycache__/parse.cpython-312.pyc +0 -0
- nntool/parser/parse.py +22 -0
- nntool/plot/__init__.py +6 -0
- nntool/plot/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/plot/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/plot/__pycache__/context.cpython-312.opt-1.pyc +0 -0
- nntool/plot/__pycache__/context.cpython-312.pyc +0 -0
- nntool/plot/context.py +48 -0
- nntool/plot/csrc/__compile__.cpython-312-x86_64-linux-gnu.so +0 -0
- nntool/plot/csrc/__init__.py +3 -0
- nntool/plot/csrc/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/plot/csrc/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/slurm/__init__.py +9 -0
- nntool/slurm/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/slurm/__pycache__/config.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/__pycache__/config.cpython-312.pyc +0 -0
- nntool/slurm/__pycache__/function.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/__pycache__/function.cpython-312.pyc +0 -0
- nntool/slurm/__pycache__/task.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/__pycache__/task.cpython-312.pyc +0 -0
- nntool/slurm/__pycache__/wrap.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/__pycache__/wrap.cpython-312.pyc +0 -0
- nntool/slurm/accelerator/__init__.py +0 -0
- nntool/slurm/accelerator/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/accelerator/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/slurm/accelerator/__pycache__/utils.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/accelerator/__pycache__/utils.cpython-312.pyc +0 -0
- nntool/slurm/accelerator/utils.py +39 -0
- nntool/slurm/config.py +182 -0
- nntool/slurm/csrc/__compile__.cpython-312-x86_64-linux-gnu.so +0 -0
- nntool/slurm/csrc/__init__.py +5 -0
- nntool/slurm/csrc/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/slurm/csrc/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/slurm/function.py +173 -0
- nntool/slurm/task.py +231 -0
- nntool/slurm/wrap.py +210 -0
- nntool/train/__init__.py +25 -0
- nntool/train/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/train/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/train/__pycache__/trainer.cpython-312.opt-1.pyc +0 -0
- nntool/train/__pycache__/trainer.cpython-312.pyc +0 -0
- nntool/train/trainer.py +92 -0
- nntool/utils/__init__.py +6 -0
- nntool/utils/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/wandb/__init__.py +1 -0
- nntool/wandb/__pycache__/__init__.cpython-312.opt-1.pyc +0 -0
- nntool/wandb/__pycache__/__init__.cpython-312.pyc +0 -0
- nntool/wandb/__pycache__/config.cpython-312.opt-1.pyc +0 -0
- nntool/wandb/__pycache__/config.cpython-312.pyc +0 -0
- nntool/wandb/config.py +118 -0
- nntool-1.3.0.dist-info/._SOURCES.txt +0 -0
- nntool-1.3.0.dist-info/._dependency_links.txt +0 -0
- nntool-1.3.0.dist-info/._requires.txt +0 -0
- nntool-1.3.0.dist-info/._top_level.txt +0 -0
- nntool-1.3.0.dist-info/METADATA +25 -0
- nntool-1.3.0.dist-info/RECORD +76 -0
- nntool-1.3.0.dist-info/WHEEL +6 -0
- nntool-1.3.0.dist-info/top_level.txt +1 -0
nntool/slurm/function.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
import submitit
|
|
2
|
+
import copy
|
|
3
|
+
|
|
4
|
+
from submitit import Job
|
|
5
|
+
from typing import Any, Callable, Literal, Tuple, Union, Dict, List, Optional
|
|
6
|
+
from .config import SlurmConfig
|
|
7
|
+
from .csrc import _SlurmFunction
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SlurmFunction:
    """The function for the slurm job, which can be used for distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass)."""

    def __init__(
        self,
        submit_fn: Callable[..., Any],
        default_submit_fn_args: Optional[Tuple[Any]] = None,
        default_submit_fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """A slurm function for the slurm job, which can be used for distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).

        :param submit_fn: function to be submitted to Slurm
        :param default_submit_fn_args: default positional args for submit_fn, used when the call supplies none, defaults to None
        :param default_submit_fn_kwargs: default keyword args for submit_fn, used when the call supplies none, defaults to None
        """
        # All real work is delegated to the compiled engine implementation.
        self.engine = _SlurmFunction(
            submit_fn, default_submit_fn_args, default_submit_fn_kwargs
        )

    def __create_copy(self) -> "SlurmFunction":
        # Shallow copy so configuration methods can return a new wrapper
        # without mutating this one; the copy's engine is replaced afterwards.
        return copy.copy(self)

    def is_configured(self) -> bool:
        """Whether the slurm function has been configured.

        :return: True if the slurm function has been configured, False otherwise
        """
        return self.engine.is_configured()

    def is_distributed(self) -> bool:
        """Whether the slurm function is distributed.

        :return: True if the slurm function is distributed, False otherwise
        """
        return self.engine.is_distributed()

    def get_executor(
        self,
    ) -> submitit.AutoExecutor:
        """Return the underlying submitit executor used by the configured engine."""
        return self.engine.get_executor()

    def configure(
        self,
        slurm_config: SlurmConfig,
        slurm_params_kwargs: Optional[Dict[str, str]] = None,
        slurm_submit_kwargs: Optional[Dict[str, str]] = None,
        slurm_task_kwargs: Optional[Dict[str, str]] = None,
        system_argv: Optional[List[str]] = None,
        pack_code_include_fn: Optional[Callable[[str, str], bool]] = None,
        pack_code_exclude_fn: Optional[Callable[[str, str], bool]] = None,
    ) -> "SlurmFunction":
        """Update the slurm configuration for the slurm function. A slurm function for the slurm job, which can be used for distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).

        **Exported Distributed Environment Variables**

        - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable to indicate that the slurm has been set up.
        - After the set up, the distributed job will be launched and the following variables are exported:
            - ``num_processes``: int
            - ``num_machines``: int
            - ``machine_rank``: int
            - ``main_process_ip``: str
            - ``main_process_port``: int

        :param slurm_config: SlurmConfig, the slurm configuration dataclass
        :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to None
        :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to None
        :param slurm_task_kwargs: extra arguments for the setting of distributed task, defaults to None
        :param system_argv: the system arguments for the second launch in the distributed task (by default it will use the current system arguments `sys.argv[1:]`), defaults to None
        :param pack_code_include_fn: predicate ``(path, root) -> bool`` selecting files to pack with the job, defaults to None
        :param pack_code_exclude_fn: predicate ``(path, root) -> bool`` rejecting files from packing, defaults to None
        :return: a new copy with configured slurm parameters
        """
        configured_slurm_function = self.__create_copy()
        configured_slurm_function.engine = self.engine.configure(
            slurm_config,
            slurm_params_kwargs,
            slurm_submit_kwargs,
            slurm_task_kwargs,
            system_argv,
            pack_code_include_fn,
            pack_code_exclude_fn,
        )
        return configured_slurm_function

    def __getitem__(
        self, slurm_config: Union[Dict[str, Any], Tuple[Any], Any]
    ) -> "SlurmFunction":
        """Instantiate the slurm configuration for the slurm function. A slurm function for the slurm job, which can be used for distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).

        **Exported Distributed Environment Variables**

        - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable to indicate that the slurm has been set up.
        - After the set up, the distributed job will be launched and the following variables are exported:
            - ``num_processes``: int
            - ``num_machines``: int
            - ``machine_rank``: int
            - ``main_process_ip``: str
            - ``main_process_port``: int

        :param slurm_config: SlurmConfig, the slurm configuration dataclass
        :return: the wrapped submit function with configured slurm parameters
        """
        configured_slurm_function = self.__create_copy()
        configured_slurm_function.engine = self.engine[slurm_config]
        return configured_slurm_function

    def __call__(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
        """Run the submit_fn with the given arguments and keyword arguments. The function is non-blocking in the mode of `slurm`, while other modes cause blocking. If there is no given arguments or keyword arguments, the default arguments and keyword arguments will be used.

        :raises Exception: if the submit_fn is not set up
        :return: Slurm Job or the return value of the submit_fn
        """
        return self.engine(*submit_fn_args, **submit_fn_kwargs)

    def submit(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
        """An alias function to ``__call__``.

        :raises Exception: if the submit_fn is not set up
        :return: Slurm Job or the return value of the submit_fn
        """
        return self(*submit_fn_args, **submit_fn_kwargs)

    def map_array(self, *submit_fn_args, **submit_fn_kwargs) -> List[Job]:
        """Submit the submit_fn over the given argument collections as a job array via the underlying engine.

        :raises Exception: if the submit_fn is not set up
        :return: the list of Slurm jobs, one per array element
        """
        return self.engine.map_array(*submit_fn_args, **submit_fn_kwargs)

    def on_condition(
        self,
        jobs: Union[Job, List[Job], Tuple[Job]],
        condition: Literal["afterany", "afterok", "afternotok"] = "afterok",
    ) -> "SlurmFunction":
        """Mark this job should be executed after the provided slurm jobs have been done. This function allows combining different conditions by multiple calling.

        :param jobs: dependent jobs
        :param condition: run condition, defaults to "afterok"
        :return: a new copy of the function with the dependency recorded
        """
        configured_slurm_function = self.__create_copy()
        configured_slurm_function.engine = self.engine.on_condition(jobs, condition)
        return configured_slurm_function

    def afterok(self, *jobs: Tuple[Job]) -> "SlurmFunction":
        """Mark the function should be executed after the provided slurm jobs have been done.

        :return: a new copy of the function with the dependency recorded
        """
        return self.on_condition(jobs, "afterok")

    def afterany(self, *jobs: Tuple[Job]) -> "SlurmFunction":
        """Mark the function should be executed after any one of the provided slurm jobs has been done.

        :return: a new copy of the function with the dependency recorded
        """
        return self.on_condition(jobs, "afterany")

    def afternotok(self, *jobs: Tuple[Job]) -> "SlurmFunction":
        """Mark the function should be executed after any one of the provided slurm jobs has been failed.

        :return: a new copy of the function with the dependency recorded
        """
        return self.on_condition(jobs, "afternotok")
|
nntool/slurm/task.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shlex
|
|
3
|
+
import shutil
|
|
4
|
+
import submitit
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Union, Generator, Callable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from .config import SlurmConfig
|
|
10
|
+
from .accelerator.utils import nvidia_smi_gpu_memory_stats_str
|
|
11
|
+
|
|
12
|
+
# Directory names created by Weights & Biases runs; used by exclude_wandb_fn
# to keep experiment logs out of the packed code tree.
WANDB_DIRS = ("wandb", ".wandb")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _is_py_or_dockerfile(path: str) -> bool:
|
|
16
|
+
file = os.path.basename(path)
|
|
17
|
+
return file.endswith(".py") or file.startswith("Dockerfile")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def include_code_files(path: str, root: str, code_ext: list[str]):
    """Include filter: keep files whose name ends with one of *code_ext*, plus Dockerfiles."""
    filename = os.path.basename(path)
    if filename.startswith("Dockerfile"):
        return True
    for ext in code_ext:
        if filename.endswith(ext):
            return True
    return False
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def exclude_code_folders(path: str, root: str, code_folders: list[str]) -> bool:
    """Exclude filter: True when *path* lies inside one of *code_folders*.

    :param path: file path to test
    :param root: directory the folder names are relative to
    :param code_folders: folder names (relative to *root*) whose contents are excluded
    :return: True if the file is under one of the excluded folders
    """
    # Compute the relative path once instead of once per folder; the original
    # also shadowed the ``code_folders`` parameter with its loop variable.
    rel_path = os.path.relpath(path, root)
    return any(rel_path.startswith(folder + os.sep) for folder in code_folders)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def exclude_wandb_fn(path: str, root: str) -> bool:
    """Default exclude filter: True when *path* is inside a Weights & Biases run directory."""
    rel_path = os.path.relpath(path, root)
    for wandb_dir in WANDB_DIRS:
        if rel_path.startswith(wandb_dir + os.sep):
            return True
    return False
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def filtered_dir(
    root: str,
    include_fn: Union[Callable[[str, str], bool], Callable[[str], bool]],
    exclude_fn: Union[Callable[[str, str], bool], Callable[[str], bool]],
) -> Generator[str, None, None]:
    """Yield each file under *root* accepted by *include_fn* and not rejected by *exclude_fn*."""

    for current_dir, _, filenames in os.walk(root):
        for filename in filenames:
            candidate = os.path.join(current_dir, filename)
            if not include_fn(candidate, root):
                continue
            if exclude_fn(candidate, root):
                continue
            yield candidate
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def pack_code_files(
    root: str,
    target_root: str,
    include_fn: Union[
        Callable[[str, str], bool], Callable[[str], bool]
    ] = _is_py_or_dockerfile,
    exclude_fn: Union[
        Callable[[str, str], bool], Callable[[str], bool]
    ] = exclude_wandb_fn,
) -> Path:
    """Copy the filtered files under *root* into ``<target_root>/code``, preserving layout.

    :param root: source directory to scan
    :param target_root: destination; files land in its ``code`` subdirectory
    :param include_fn: predicate ``(path, root) -> bool`` selecting files to copy
    :param exclude_fn: predicate ``(path, root) -> bool`` rejecting files
    :raises ValueError: if *root* does not exist
    :return: the ``Path`` of the created ``code`` directory
    """
    # Resolve once; the original called os.path.abspath(root) twice.
    root = os.path.abspath(root)
    code_root = Path(root)
    code_target = Path(os.path.abspath(target_root)) / "code"
    if not code_root.exists():
        raise ValueError(f"Code root {code_root} does not exist.")
    # exist_ok avoids an exists()/mkdir race when several tasks pack concurrently.
    code_target.mkdir(parents=True, exist_ok=True)

    for file_path in filtered_dir(root, include_fn, exclude_fn):
        rel_name = os.path.relpath(file_path, root)
        sub_dir, file_name = os.path.split(rel_name)
        dest_dir = code_target / sub_dir
        dest_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(file_path, dest_dir / file_name)

    return code_target
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def reconstruct_command_line(argv):
    """Rebuild a shell command line from an argv list, quoting arguments that need it."""
    quoted_args = [shlex.quote(arg) for arg in argv]
    return " ".join(quoted_args)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class Task:
    """Base payload for a submitit job: holds the argv to relaunch and the slurm config."""

    def __init__(
        self, argv: list[str], slurm_config: SlurmConfig, verbose: bool = False
    ):
        self.argv = argv
        self.slurm_config = slurm_config
        self.verbose = verbose

    def log(self, msg: str):
        """Print *msg* only when verbose mode is enabled."""
        if self.verbose:
            print(msg)

    def command(self) -> str:
        """Build the shell command to run; concrete subclasses must implement this."""
        raise NotImplementedError

    def checkpoint(self):
        """Requeue this task on preemption/timeout (submitit checkpointing protocol)."""
        print("checkpointing")
        return submitit.helpers.DelayedSubmission(self)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@dataclass
class DistributedTaskConfig:
    """Distributed launch arguments; defaults are shell-variable placeholders filled at set-up."""

    num_processes: Union[int, str] = "$nntool_num_processes"
    num_machines: Union[int, str] = "$nntool_num_machines"
    machine_rank: Union[int, str] = "$nntool_machine_rank"
    main_process_ip: str = "$nntool_main_process_ip"
    main_process_port: Union[int, str] = "$nntool_main_process_port"

    def export_bash(self, output_folder: str):
        """Write ``nntool_distributed_env.sh`` exporting each field as an ``nntool_``-prefixed variable."""
        script_lines = ["#!/bin/bash"]
        script_lines.extend(
            f"export nntool_{name}={value}" for name, value in self.__dict__.items()
        )
        script_path = os.path.join(output_folder, "nntool_distributed_env.sh")
        with open(script_path, "w") as script:
            script.write("\n".join(script_lines))
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class PyTorchDistributedTask(Task):
    """Submitit task that sets up a PyTorch distributed environment, then re-runs the
    program via a formatted launch command.

    Ref:
    https://github.com/huggingface/accelerate/issues/1239
    https://github.com/yuvalkirstain/PickScore/blob/main/trainer/slurm_scripts/slurm_train.py
    https://github.com/facebookincubator/submitit/pull/1703
    """

    def __init__(
        self,
        launch_cmd: str,
        argv: list[str],
        slurm_config: SlurmConfig,
        verbose: bool = False,
        **env_setup_kwargs,
    ):
        # launch_cmd is a format template; command() fills it with the fields of
        # DistributedTaskConfig (num_processes, machine_rank, ...).
        super().__init__(argv, slurm_config, verbose)
        self.launch_cmd = launch_cmd
        # Extra environment variables to export before launching.
        self.env_setup_kwargs = env_setup_kwargs

        # to be set up in the dist_set_up method
        self.dist_args = DistributedTaskConfig()
        self.dist_env = None

    def dist_set_up(self):
        """Export the torch distributed environment and compute the launch arguments.

        Returns the ``(dist_args, dist_env)`` pair it also stores on ``self``.
        """
        self.log("running task on slurm")
        self.log("exporting PyTorch distributed environment variables")

        # prepare environment variables (MASTER_ADDR/PORT, RANK, ... exported by submitit)
        dist_env = submitit.helpers.TorchDistributedEnvironment().export()

        # other setup
        env_setup = {
            # "NCCL_DEBUG": "info",
            # "CUDA_LAUNCH_BLOCKING": "1",
        }

        # set CUDA visible devices if slurm has scheduled GPUs otherwise use all GPUs (without setting
        # CUDA_VISIBLE_DEVICES)
        env_setup.update(
            {"CUDA_VISIBLE_DEVICES": os.environ["SLURM_JOB_GPUS"]}
            if "SLURM_JOB_GPUS" in os.environ
            else {}
        )

        # other environment variables set by the user (may override the defaults above)
        env_setup.update(self.env_setup_kwargs)

        # update environment variables
        os.environ.update(**env_setup)

        self.log(nvidia_smi_gpu_memory_stats_str())
        self.log(f"master: {dist_env.master_addr}:{dist_env.master_port}")
        self.log(f"rank: {dist_env.rank}")
        self.log(f"world size: {dist_env.world_size}")
        self.log(f"local rank: {dist_env.local_rank}")
        self.log(f"local world size: {dist_env.local_world_size}")
        self.log(
            f"local rank {dist_env.local_rank}: CUDA_VISIBLE_DEVICES {os.environ.get('CUDA_VISIBLE_DEVICES', 'all')}"
        )

        # set distributed arguments
        # NOTE(review): assumes num_processes = tasks_per_node * processes_per_task
        # * num_of_node matches the slurm allocation — confirm against SlurmConfig.
        num_processes = (
            self.slurm_config.tasks_per_node
            * self.slurm_config.processes_per_task
            * self.slurm_config.num_of_node
        )
        # Node index derived from the global rank, one block of tasks per node.
        machine_rank = dist_env.rank // self.slurm_config.tasks_per_node
        self.dist_args = DistributedTaskConfig(
            num_processes=num_processes,
            num_machines=self.slurm_config.num_of_node,
            machine_rank=machine_rank,
            main_process_ip=dist_env.master_addr,
            main_process_port=dist_env.master_port,
        )
        self.dist_env = dist_env

        return self.dist_args, self.dist_env

    def command(self) -> str:
        """Return the launch command template filled with the distributed args plus the original argv."""
        cmd = self.launch_cmd.format(**self.dist_args.__dict__)
        cmd += " " + reconstruct_command_line(self.argv)
        return cmd

    def __call__(self):
        """Entry point executed by submitit on each task; returns a process-style exit code."""
        # set up distributed environment
        self.dist_set_up()

        # job environment
        job_env = submitit.helpers.JobEnvironment()

        # concrete run command
        cmd = self.command()

        # export distributed environment variables
        # Only the first process on each node acts; other local ranks return 0 below.
        if self.dist_env.local_rank == 0:
            print(f"running command: {cmd}")
            if self.slurm_config.mode == "slurm":
                # NOTE(review): in "slurm" mode the command is NOT executed here —
                # only the env script is written; presumably a later step sources
                # it and runs the job. Also, shlex.quote() of the folder path is
                # then used as a literal directory name in export_bash — verify
                # this behaves for paths containing spaces.
                try:
                    self.dist_args.export_bash(shlex.quote(str(job_env.paths.folder)))
                except Exception as e:
                    print(f"failed to export distributed environment variables: {e}")
                    return -1
            else:
                return os.system(cmd)

        return 0
|
nntool/slurm/wrap.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
from warnings import warn
|
|
4
|
+
from typing import Any, Callable, Type, Union, Dict, List
|
|
5
|
+
|
|
6
|
+
from .config import SlurmConfig
|
|
7
|
+
from .function import SlurmFunction
|
|
8
|
+
from ..parser import parse_from_cli
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def slurm_fn(
    submit_fn: Callable,
) -> SlurmFunction:
    """A decorator to wrap a function to be run on slurm. The function decorated by this decorator should be launched on the way below. The decorated function `submit_fn` is non-blocking now. To block and get the return value, you can call ``job.result()``.

    **Example**

    Here's an example of how to use this function:

    .. code-block:: python

        @slurm_fn
        def run_on_slurm(a, b):
            return a + b

        slurm_config = SlurmConfig(
            mode="slurm",
            partition="PARTITION",
            job_name="EXAMPLE",
            tasks_per_node=1,
            cpus_per_task=8,
            mem="1GB",
        )
        job = run_on_slurm[slurm_config](1, b=2)
        result = job.result()  # block and get the result

    :param submit_fn: the function to be run on slurm
    :return: the function to be run on slurm
    """
    return SlurmFunction(submit_fn=submit_fn)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def slurm_launcher(
    ArgsType: Type[Any],
    parser: Union[str, Callable] = "tyro",
    slurm_key: str = "slurm",
    slurm_params_kwargs: Union[dict, None] = None,
    slurm_submit_kwargs: Union[dict, None] = None,
    slurm_task_kwargs: Union[dict, None] = None,
    *extra_args,
    **extra_kwargs,
) -> Callable[[Callable[..., Any]], SlurmFunction]:
    """A slurm launcher decorator for distributed or non-distributed job (controlled by `use_distributed_env` in slurm field). This decorator should be used as the program entry. The decorated function is non-blocking in the mode of `slurm`, while other modes cause blocking.

    #### Exported Distributed Environment Variables
    1. NNTOOL_SLURM_HAS_BEEN_SET_UP is a special environment variable to indicate that the slurm has been set up.
    2. After the set up, the distributed job will be launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.

    :param ArgsType: the experiment arguments type, which should be a dataclass (it
        must have a slurm field defined by `slurm_key`)
    :param parser: the parser for the arguments, defaults to "tyro"
    :param slurm_key: the key of the slurm field in the ArgsType, defaults to "slurm"
    :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to None
    :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to None
    :param slurm_task_kwargs: extra arguments for the setting of distributed task, defaults to None
    :param extra_args: extra arguments for the parser
    :param extra_kwargs: extra keyword arguments for the parser
    :raises ValueError: if ArgsType has no field named `slurm_key`
    :return: decorator function with main entry
    """
    # Avoid mutable default arguments: coerce None to a fresh dict per call.
    slurm_params_kwargs = {} if slurm_params_kwargs is None else slurm_params_kwargs
    slurm_submit_kwargs = {} if slurm_submit_kwargs is None else slurm_submit_kwargs
    slurm_task_kwargs = {} if slurm_task_kwargs is None else slurm_task_kwargs

    # Capture argv before parsing so the distributed relaunch sees the same CLI.
    argv = list(sys.argv[1:])
    args = parse_from_cli(ArgsType, parser, *extra_args, **extra_kwargs)

    # check if args have slurm field
    if not hasattr(args, slurm_key):
        raise ValueError(
            f"ArgsType should have a field named `{slurm_key}` to use `slurm_launcher` decorator."
        )
    slurm_config: SlurmConfig = getattr(args, slurm_key)

    def decorator(
        submit_fn: Callable[..., Any],
    ) -> SlurmFunction:
        # The parsed args become the default (only) positional argument.
        return SlurmFunction(
            submit_fn=submit_fn,
            default_submit_fn_args=(args,),
        ).configure(
            slurm_config,
            slurm_params_kwargs,
            slurm_submit_kwargs,
            slurm_task_kwargs,
            system_argv=argv,
        )

    return decorator
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def slurm_distributed_launcher(
    ArgsType: Type[Any],
    parser: Union[str, Callable] = "tyro",
    slurm_key: str = "slurm",
    slurm_params_kwargs: Union[dict, None] = None,
    slurm_submit_kwargs: Union[dict, None] = None,
    slurm_task_kwargs: Union[dict, None] = None,
    *extra_args,
    **extra_kwargs,
) -> Callable[[Callable[..., Any]], SlurmFunction]:
    """A slurm launcher decorator for the distributed job. This decorator should be used for the distributed job only and as the program entry. The decorated function is non-blocking in the mode of `slurm`, while other modes cause blocking.

    .. deprecated:: use :func:`slurm_launcher` instead.

    #### Exported Distributed Environment Variables
    1. NNTOOL_SLURM_HAS_BEEN_SET_UP is a special environment variable to indicate that the slurm has been set up.
    2. After the set up, the distributed job will be launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.

    :param ArgsType: the experiment arguments type, which should be a dataclass (it
        must have a slurm field defined by `slurm_key`)
    :param parser: the parser for the arguments, defaults to "tyro"
    :param slurm_key: the key of the slurm field in the ArgsType, defaults to "slurm"
    :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to None
    :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to None
    :param slurm_task_kwargs: extra arguments for the setting of distributed task, defaults to None
    :param extra_args: extra arguments for the parser
    :param extra_kwargs: extra keyword arguments for the parser
    :raises ValueError: if ArgsType has no field named `slurm_key`
    :return: decorator function with main entry
    """
    warn(
        "`slurm_distributed_launcher` has been deprecated. Please use `slurm_launcher` instead, which supports both distributed and non-distributed job (controlled by `use_distributed_env` in slurm field).",
        DeprecationWarning,
        stacklevel=2,
    )
    # Avoid mutable default arguments: coerce None to a fresh dict per call.
    slurm_params_kwargs = {} if slurm_params_kwargs is None else slurm_params_kwargs
    slurm_submit_kwargs = {} if slurm_submit_kwargs is None else slurm_submit_kwargs
    slurm_task_kwargs = {} if slurm_task_kwargs is None else slurm_task_kwargs

    # Capture argv before parsing so the distributed relaunch sees the same CLI.
    argv = list(sys.argv[1:])
    args = parse_from_cli(ArgsType, parser, *extra_args, **extra_kwargs)

    # check if args have slurm field
    if not hasattr(args, slurm_key):
        raise ValueError(
            f"ArgsType should have a field named `{slurm_key}` to use `slurm_distributed_launcher` decorator."
        )
    slurm_config: SlurmConfig = getattr(args, slurm_key)

    def decorator(
        submit_fn: Callable[..., Any],
    ) -> SlurmFunction:
        # The parsed args become the default (only) positional argument.
        return SlurmFunction(
            submit_fn=submit_fn,
            default_submit_fn_args=(args,),
        ).configure(
            slurm_config,
            slurm_params_kwargs,
            slurm_submit_kwargs,
            slurm_task_kwargs,
            system_argv=argv,
        )

    return decorator
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def slurm_function(
    submit_fn: Callable,
):
    """A decorator to annotate a function to be run in slurm. The function decorated by this decorator should be launched in the way below.

    .. deprecated:: use :func:`slurm_fn` instead.

    ```
    @slurm_function
    def run_in_slurm(*args, **kwargs):
        pass

    job = run_in_slurm(slurm_config)(*args, **kwargs)
    ```
    The decorated function `submit_fn` is non-blocking now. To block and get the return value, you can call `job.result()`.
    """

    def wrapper(
        slurm_config: SlurmConfig,
        slurm_params_kwargs: Union[Dict[str, Any], None] = None,
        slurm_submit_kwargs: Union[Dict[str, Any], None] = None,
        slurm_task_kwargs: Union[Dict[str, Any], None] = None,
        system_argv: Union[List[str], None] = None,
    ) -> SlurmFunction:
        """Update the slurm configuration for the slurm function.

        #### Exported Distributed Environment Variables
        1. NNTOOL_SLURM_HAS_BEEN_SET_UP is a special environment variable to indicate that the slurm has been set up.
        2. After the set up, the distributed job will be launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.

        :param slurm_config: SlurmConfig, the slurm configuration dataclass
        :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to None
        :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to None
        :param slurm_task_kwargs: extra arguments for the setting of distributed task, defaults to None
        :param system_argv: the system arguments for the second launch in the distributed task (by default it will use the current system arguments `sys.argv[1:]`), defaults to None
        :return: the wrapped submit function with configured slurm parameters
        """
        warn(
            "`slurm_function` has been deprecated. Please use `slurm_fn` instead, which supports both distributed and non-distributed job (controlled by `use_distributed_env` in slurm field).",
            DeprecationWarning,
            stacklevel=2,
        )
        # Avoid mutable default arguments: coerce None to a fresh dict per call.
        # (Renamed local from `slurm_fn`, which shadowed the module-level decorator.)
        configured_fn = SlurmFunction(
            submit_fn=submit_fn,
        ).configure(
            slurm_config,
            {} if slurm_params_kwargs is None else slurm_params_kwargs,
            {} if slurm_submit_kwargs is None else slurm_submit_kwargs,
            {} if slurm_task_kwargs is None else slurm_task_kwargs,
            system_argv,
        )
        return configured_fn

    return wrapper
|
nntool/train/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import importlib.metadata as importlib_metadata
|
|
2
|
+
import importlib.util
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def is_torch_available() -> bool:
    """Return True when the real ``torch`` library is installed and importable.

    The original implicitly returned ``None`` when no spec was found; this
    always returns a bool.

    :return: True if torch is installed with importable metadata, False otherwise
    """
    if importlib.util.find_spec("torch") is None:
        return False

    # Check we're not importing a "torch" directory somewhere but the actual
    # library, by trying to grab its distribution metadata.
    try:
        importlib_metadata.metadata("torch")
    except importlib_metadata.PackageNotFoundError:
        return False
    return True
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Expose BaseTrainer only when torch is importable; otherwise alias it to a
# harmless placeholder so importing this package never fails.
if is_torch_available():
    from .trainer import BaseTrainer
else:
    # Inherits from a dummy `object` if torch is not available, so that python
    # succeeds to import this file.
    # BaseTrainer abstraction code will never inherit this dummy object as it checks if
    # torch is available.
    from builtins import object as BaseTrainer
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|