nntool-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nntool might be problematic.

Files changed (76)
  1. nntool/__init__.py +2 -0
  2. nntool/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  3. nntool/__pycache__/__init__.cpython-311.pyc +0 -0
  4. nntool/experiment/__init__.py +6 -0
  5. nntool/experiment/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  6. nntool/experiment/__pycache__/__init__.cpython-311.pyc +0 -0
  7. nntool/experiment/__pycache__/config.cpython-311.opt-1.pyc +0 -0
  8. nntool/experiment/__pycache__/config.cpython-311.pyc +0 -0
  9. nntool/experiment/__pycache__/utils.cpython-311.opt-1.pyc +0 -0
  10. nntool/experiment/__pycache__/utils.cpython-311.pyc +0 -0
  11. nntool/experiment/config.py +112 -0
  12. nntool/experiment/utils.py +63 -0
  13. nntool/parser/__init__.py +1 -0
  14. nntool/parser/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  15. nntool/parser/__pycache__/__init__.cpython-311.pyc +0 -0
  16. nntool/parser/__pycache__/parse.cpython-311.opt-1.pyc +0 -0
  17. nntool/parser/__pycache__/parse.cpython-311.pyc +0 -0
  18. nntool/parser/parse.py +22 -0
  19. nntool/plot/__init__.py +6 -0
  20. nntool/plot/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  21. nntool/plot/__pycache__/__init__.cpython-311.pyc +0 -0
  22. nntool/plot/__pycache__/context.cpython-311.opt-1.pyc +0 -0
  23. nntool/plot/__pycache__/context.cpython-311.pyc +0 -0
  24. nntool/plot/context.py +48 -0
  25. nntool/plot/csrc/__compile__.cpython-311-x86_64-linux-gnu.so +0 -0
  26. nntool/plot/csrc/__init__.py +3 -0
  27. nntool/plot/csrc/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  28. nntool/plot/csrc/__pycache__/__init__.cpython-311.pyc +0 -0
  29. nntool/slurm/__init__.py +9 -0
  30. nntool/slurm/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  31. nntool/slurm/__pycache__/__init__.cpython-311.pyc +0 -0
  32. nntool/slurm/__pycache__/config.cpython-311.opt-1.pyc +0 -0
  33. nntool/slurm/__pycache__/config.cpython-311.pyc +0 -0
  34. nntool/slurm/__pycache__/function.cpython-311.opt-1.pyc +0 -0
  35. nntool/slurm/__pycache__/function.cpython-311.pyc +0 -0
  36. nntool/slurm/__pycache__/task.cpython-311.opt-1.pyc +0 -0
  37. nntool/slurm/__pycache__/task.cpython-311.pyc +0 -0
  38. nntool/slurm/__pycache__/wrap.cpython-311.opt-1.pyc +0 -0
  39. nntool/slurm/__pycache__/wrap.cpython-311.pyc +0 -0
  40. nntool/slurm/accelerator/__init__.py +0 -0
  41. nntool/slurm/accelerator/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  42. nntool/slurm/accelerator/__pycache__/__init__.cpython-311.pyc +0 -0
  43. nntool/slurm/accelerator/__pycache__/utils.cpython-311.opt-1.pyc +0 -0
  44. nntool/slurm/accelerator/__pycache__/utils.cpython-311.pyc +0 -0
  45. nntool/slurm/accelerator/utils.py +39 -0
  46. nntool/slurm/config.py +182 -0
  47. nntool/slurm/csrc/__compile__.cpython-311-x86_64-linux-gnu.so +0 -0
  48. nntool/slurm/csrc/__init__.py +5 -0
  49. nntool/slurm/csrc/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  50. nntool/slurm/csrc/__pycache__/__init__.cpython-311.pyc +0 -0
  51. nntool/slurm/function.py +173 -0
  52. nntool/slurm/task.py +231 -0
  53. nntool/slurm/wrap.py +210 -0
  54. nntool/train/__init__.py +25 -0
  55. nntool/train/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  56. nntool/train/__pycache__/__init__.cpython-311.pyc +0 -0
  57. nntool/train/__pycache__/trainer.cpython-311.opt-1.pyc +0 -0
  58. nntool/train/__pycache__/trainer.cpython-311.pyc +0 -0
  59. nntool/train/trainer.py +92 -0
  60. nntool/utils/__init__.py +6 -0
  61. nntool/utils/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  62. nntool/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  63. nntool/wandb/__init__.py +1 -0
  64. nntool/wandb/__pycache__/__init__.cpython-311.opt-1.pyc +0 -0
  65. nntool/wandb/__pycache__/__init__.cpython-311.pyc +0 -0
  66. nntool/wandb/__pycache__/config.cpython-311.opt-1.pyc +0 -0
  67. nntool/wandb/__pycache__/config.cpython-311.pyc +0 -0
  68. nntool/wandb/config.py +118 -0
  69. nntool-1.3.0.dist-info/._SOURCES.txt +0 -0
  70. nntool-1.3.0.dist-info/._dependency_links.txt +0 -0
  71. nntool-1.3.0.dist-info/._requires.txt +0 -0
  72. nntool-1.3.0.dist-info/._top_level.txt +0 -0
  73. nntool-1.3.0.dist-info/METADATA +25 -0
  74. nntool-1.3.0.dist-info/RECORD +76 -0
  75. nntool-1.3.0.dist-info/WHEEL +6 -0
  76. nntool-1.3.0.dist-info/top_level.txt +1 -0
nntool/slurm/function.py ADDED
@@ -0,0 +1,173 @@
+ import submitit
+ import copy
+
+ from submitit import Job
+ from typing import Any, Callable, Literal, Tuple, Union, Dict, List, Optional
+ from .config import SlurmConfig
+ from .csrc import _SlurmFunction
+
+
+ class SlurmFunction:
+     """A wrapper for slurm jobs, usable for distributed or non-distributed jobs (controlled by `use_distributed_env` in the slurm dataclass)."""
+
+     def __init__(
+         self,
+         submit_fn: Callable[..., Any],
+         default_submit_fn_args: Optional[Tuple[Any]] = None,
+         default_submit_fn_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         """Wrap a function for submission as a slurm job, either distributed or non-distributed (controlled by `use_distributed_env` in the slurm dataclass).
+
+         :param submit_fn: the function to be submitted to Slurm
+         :param default_submit_fn_args: default positional arguments for submit_fn, defaults to ()
+         :param default_submit_fn_kwargs: default keyword arguments for submit_fn, defaults to {}
+         """
+         self.engine = _SlurmFunction(
+             submit_fn, default_submit_fn_args, default_submit_fn_kwargs
+         )
+
+     def __create_copy(self) -> "SlurmFunction":
+         return copy.copy(self)
+
+     def is_configured(self) -> bool:
+         """Whether the slurm function has been configured.
+
+         :return: True if the slurm function has been configured, False otherwise
+         """
+         return self.engine.is_configured()
+
+     def is_distributed(self) -> bool:
+         """Whether the slurm function is distributed.
+
+         :return: True if the slurm function is distributed, False otherwise
+         """
+         return self.engine.is_distributed()
+
+     def get_executor(
+         self,
+     ) -> submitit.AutoExecutor:
+         """Return the underlying submitit executor."""
+         return self.engine.get_executor()
+
+     def configure(
+         self,
+         slurm_config: SlurmConfig,
+         slurm_params_kwargs: Optional[Dict[str, str]] = None,
+         slurm_submit_kwargs: Optional[Dict[str, str]] = None,
+         slurm_task_kwargs: Optional[Dict[str, str]] = None,
+         system_argv: Optional[List[str]] = None,
+         pack_code_include_fn: Optional[Callable[[str, str], bool]] = None,
+         pack_code_exclude_fn: Optional[Callable[[str, str], bool]] = None,
+     ) -> "SlurmFunction":
+         """Update the slurm configuration for the slurm function, for either a distributed or non-distributed job (controlled by `use_distributed_env` in the slurm dataclass).
+
+         **Exported Distributed Environment Variables**
+
+         - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable indicating that slurm has been set up.
+         - After the set-up, the distributed job is launched and the following variables are exported:
+
+           - ``num_processes``: int
+           - ``num_machines``: int
+           - ``machine_rank``: int
+           - ``main_process_ip``: str
+           - ``main_process_port``: int
+
+         :param slurm_config: SlurmConfig, the slurm configuration dataclass
+         :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to {}
+         :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to {}
+         :param slurm_task_kwargs: extra arguments for setting up the distributed task, defaults to {}
+         :param system_argv: the system arguments for the second launch in the distributed task (by default the current system arguments `sys.argv[1:]` are used), defaults to None
+         :param pack_code_include_fn: predicate ``(path, root) -> bool`` selecting files to include when packing code, defaults to None
+         :param pack_code_exclude_fn: predicate ``(path, root) -> bool`` selecting files to exclude when packing code, defaults to None
+         :return: a new copy with configured slurm parameters
+         """
+         configured_slurm_function = self.__create_copy()
+         configured_slurm_function.engine = self.engine.configure(
+             slurm_config,
+             slurm_params_kwargs,
+             slurm_submit_kwargs,
+             slurm_task_kwargs,
+             system_argv,
+             pack_code_include_fn,
+             pack_code_exclude_fn,
+         )
+         return configured_slurm_function
+
+     def __getitem__(
+         self, slurm_config: Union[Dict[str, Any], Tuple[Any], Any]
+     ) -> "SlurmFunction":
+         """Instantiate the slurm configuration for the slurm function; shorthand for :meth:`configure` with only ``slurm_config``.
+
+         **Exported Distributed Environment Variables**
+
+         - ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable indicating that slurm has been set up.
+         - After the set-up, the distributed job is launched and the following variables are exported:
+
+           - ``num_processes``: int
+           - ``num_machines``: int
+           - ``machine_rank``: int
+           - ``main_process_ip``: str
+           - ``main_process_port``: int
+
+         :param slurm_config: SlurmConfig, the slurm configuration dataclass
+         :return: the wrapped submit function with configured slurm parameters
+         """
+         configured_slurm_function = self.__create_copy()
+         configured_slurm_function.engine = self.engine[slurm_config]
+         return configured_slurm_function
+
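A minimal usage sketch for the class above (the wrapped function is hypothetical; the `SlurmConfig` fields follow the example in `nntool/slurm/wrap.py`):

```python
from nntool.slurm.config import SlurmConfig
from nntool.slurm.function import SlurmFunction

def add(a, b):
    return a + b

fn = SlurmFunction(submit_fn=add)
cfg = SlurmConfig(mode="slurm", partition="PARTITION", job_name="EXAMPLE")

configured = fn.configure(cfg)   # returns a configured copy; `fn` is untouched
job = configured(1, b=2)         # non-blocking in "slurm" mode
print(job.result())              # 3, once the job finishes

# `fn[cfg]` is the __getitem__ shorthand for the same configuration step
job2 = fn[cfg](3, b=4)
```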
+     def __call__(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
+         """Run the submit_fn with the given arguments and keyword arguments. The call is non-blocking in `slurm` mode, while other modes block. If no arguments or keyword arguments are given, the default arguments and keyword arguments are used.
+
+         :raises Exception: if the submit_fn is not set up
+         :return: a Slurm Job or the return value of the submit_fn
+         """
+         return self.engine(*submit_fn_args, **submit_fn_kwargs)
+
+     def submit(self, *submit_fn_args, **submit_fn_kwargs) -> Union[Job, Any]:
+         """An alias for ``__call__``.
+
+         :raises Exception: if the submit_fn is not set up
+         :return: a Slurm Job or the return value of the submit_fn
+         """
+         return self(*submit_fn_args, **submit_fn_kwargs)
+
+     def map_array(self, *submit_fn_args, **submit_fn_kwargs) -> List[Job]:
+         """Run the submit_fn as a job array over the given arguments and keyword arguments. The call is non-blocking in `slurm` mode, while other modes block. If no arguments or keyword arguments are given, the default arguments and keyword arguments are used.
+
+         :raises Exception: if the submit_fn is not set up
+         :return: the list of Slurm Jobs
+         """
+         return self.engine.map_array(*submit_fn_args, **submit_fn_kwargs)
+
+     def on_condition(
+         self,
+         jobs: Union[Job, List[Job], Tuple[Job]],
+         condition: Literal["afterany", "afterok", "afternotok"] = "afterok",
+     ) -> "SlurmFunction":
+         """Mark that this job should run only after the provided slurm jobs have finished. Conditions can be combined by calling this method multiple times.
+
+         :param jobs: dependent jobs
+         :param condition: run condition, defaults to "afterok"
+         :return: a new copy with the dependency attached
+         """
+         configured_slurm_function = self.__create_copy()
+         configured_slurm_function.engine = self.engine.on_condition(jobs, condition)
+         return configured_slurm_function
+
+     def afterok(self, *jobs: Job) -> "SlurmFunction":
+         """Mark that the function should run after all of the provided slurm jobs have completed successfully.
+
+         :return: a new copy with the dependency attached
+         """
+         return self.on_condition(jobs, "afterok")
+
+     def afterany(self, *jobs: Job) -> "SlurmFunction":
+         """Mark that the function should run after the provided slurm jobs have terminated, regardless of exit status.
+
+         :return: a new copy with the dependency attached
+         """
+         return self.on_condition(jobs, "afterany")
+
+     def afternotok(self, *jobs: Job) -> "SlurmFunction":
+         """Mark that the function should run after the provided slurm jobs have terminated in a failed state.
+
+         :return: a new copy with the dependency attached
+         """
+         return self.on_condition(jobs, "afternotok")
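A hedged sketch of the dependency helpers, assuming `preprocess`, `train`, and `cleanup` are already-configured `SlurmFunction` instances:

```python
prep_job = preprocess()                       # submit the first job
train_job = train.afterok(prep_job)()         # queued; starts only if prep succeeds
retry_job = cleanup.afternotok(train_job)()   # queued; starts only if training fails
done_job = cleanup.afterany(prep_job, train_job)()  # starts once both terminate
```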
nntool/slurm/task.py ADDED
@@ -0,0 +1,231 @@
+ import os
+ import shlex
+ import shutil
+ import submitit
+
+ from pathlib import Path
+ from typing import Union, Generator, Callable
+ from dataclasses import dataclass
+ from .config import SlurmConfig
+ from .accelerator.utils import nvidia_smi_gpu_memory_stats_str
+
+ WANDB_DIRS = ("wandb", ".wandb")
+
+
+ def _is_py_or_dockerfile(path: str, root: str = "") -> bool:
+     # `root` is accepted for filter-API compatibility; it is unused here
+     file = os.path.basename(path)
+     return file.endswith(".py") or file.startswith("Dockerfile")
+
+
+ def include_code_files(path: str, root: str, code_ext: list[str]) -> bool:
+     file = os.path.basename(path)
+     return any(file.endswith(ext) for ext in code_ext) or file.startswith("Dockerfile")
+
+
+ def exclude_code_folders(path: str, root: str, code_folders: list[str]) -> bool:
+     return any(
+         os.path.relpath(path, root).startswith(folder + os.sep)
+         for folder in code_folders
+     )
+
+
+ def exclude_wandb_fn(path: str, root: str) -> bool:
+     return any(
+         os.path.relpath(path, root).startswith(wandb_dir + os.sep)
+         for wandb_dir in WANDB_DIRS
+     )
+
+
+ def filtered_dir(
+     root: str,
+     include_fn: Union[Callable[[str, str], bool], Callable[[str], bool]],
+     exclude_fn: Union[Callable[[str, str], bool], Callable[[str], bool]],
+ ) -> Generator[str, None, None]:
+     """Simple generator to walk a directory and yield the files passing the filters."""
+
+     for dirpath, _, files in os.walk(root):
+         for fname in files:
+             file_path = os.path.join(dirpath, fname)
+             if include_fn(file_path, root) and not exclude_fn(file_path, root):
+                 yield file_path
+
+
+ def pack_code_files(
+     root: str,
+     target_root: str,
+     include_fn: Union[
+         Callable[[str, str], bool], Callable[[str], bool]
+     ] = _is_py_or_dockerfile,
+     exclude_fn: Union[
+         Callable[[str, str], bool], Callable[[str], bool]
+     ] = exclude_wandb_fn,
+ ) -> Path:
+     root = os.path.abspath(root)
+     code_root = Path(os.path.abspath(root))
+     code_target = Path(os.path.abspath(target_root)) / "code"
+     if not code_root.exists():
+         raise ValueError(f"Code root {code_root} does not exist.")
+     if not code_target.exists():
+         code_target.mkdir(parents=True)
+
+     for file_path in filtered_dir(root, include_fn, exclude_fn):
+         save_name = os.path.relpath(file_path, root)
+         sub_file_path, file_name = os.path.split(save_name)
+         sub_file_full_path = code_target / sub_file_path
+         if not sub_file_full_path.exists():
+             sub_file_full_path.mkdir(parents=True)
+         shutil.copy(file_path, sub_file_full_path / file_name)
+
+     return code_target
+
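A short sketch of how the helpers above compose (paths are illustrative):

```python
# Snapshot .py/.yaml sources from ./src into ./logs/run-1/code, skipping
# wandb artifacts (the default exclude) plus a vendored third_party folder.
snapshot_dir = pack_code_files(
    root="./src",
    target_root="./logs/run-1",
    include_fn=lambda p, r: include_code_files(p, r, [".py", ".yaml"]),
    exclude_fn=lambda p, r: exclude_wandb_fn(p, r)
    or exclude_code_folders(p, r, ["third_party"]),
)
print(snapshot_dir)  # ./logs/run-1/code
```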
+
+ def reconstruct_command_line(argv):
+     # Quote each argument that needs special handling (like spaces or shell
+     # characters) and join them with spaces to form the command line
+     return " ".join(shlex.quote(arg) for arg in argv)
+
+
+ class Task:
+     def __init__(
+         self, argv: list[str], slurm_config: SlurmConfig, verbose: bool = False
+     ):
+         self.argv = argv
+         self.slurm_config = slurm_config
+         self.verbose = verbose
+
+     def log(self, msg: str):
+         if not self.verbose:
+             return
+
+         print(msg)
+
+     def command(self) -> str:
+         raise NotImplementedError
+
+     def checkpoint(self):
+         print("checkpointing")
+         return submitit.helpers.DelayedSubmission(self)
+
+
+ @dataclass
+ class DistributedTaskConfig:
+     num_processes: Union[int, str] = "$nntool_num_processes"
+     num_machines: Union[int, str] = "$nntool_num_machines"
+     machine_rank: Union[int, str] = "$nntool_machine_rank"
+     main_process_ip: str = "$nntool_main_process_ip"
+     main_process_port: Union[int, str] = "$nntool_main_process_port"
+
+     def export_bash(self, output_folder: str):
+         lines = ["#!/bin/bash"]
+         for k, v in self.__dict__.items():
+             lines.append(f"export nntool_{k}={v}")
+         with open(os.path.join(output_folder, "nntool_distributed_env.sh"), "w") as f:
+             f.write("\n".join(lines))
+
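For reference, a sketch of what `export_bash` writes (the field values here are illustrative; until `dist_set_up` runs, the fields hold the self-referential `$nntool_*` placeholders):

```python
cfg = DistributedTaskConfig(
    num_processes=8,
    num_machines=2,
    machine_rank=0,
    main_process_ip="10.0.0.1",
    main_process_port=29500,
)
cfg.export_bash("/tmp/job")
# /tmp/job/nntool_distributed_env.sh now contains:
#   #!/bin/bash
#   export nntool_num_processes=8
#   export nntool_num_machines=2
#   export nntool_machine_rank=0
#   export nntool_main_process_ip=10.0.0.1
#   export nntool_main_process_port=29500
```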
+
+ class PyTorchDistributedTask(Task):
+     """Ref:
+     https://github.com/huggingface/accelerate/issues/1239
+     https://github.com/yuvalkirstain/PickScore/blob/main/trainer/slurm_scripts/slurm_train.py
+     https://github.com/facebookincubator/submitit/pull/1703
+     """
+
+     def __init__(
+         self,
+         launch_cmd: str,
+         argv: list[str],
+         slurm_config: SlurmConfig,
+         verbose: bool = False,
+         **env_setup_kwargs,
+     ):
+         super().__init__(argv, slurm_config, verbose)
+         self.launch_cmd = launch_cmd
+         self.env_setup_kwargs = env_setup_kwargs
+
+         # to be set up in the dist_set_up method
+         self.dist_args = DistributedTaskConfig()
+         self.dist_env = None
+
+     def dist_set_up(self):
+         self.log("running task on slurm")
+         self.log("exporting PyTorch distributed environment variables")
+
+         # prepare environment variables
+         dist_env = submitit.helpers.TorchDistributedEnvironment().export()
+
+         # other setup
+         env_setup = {
+             # "NCCL_DEBUG": "info",
+             # "CUDA_LAUNCH_BLOCKING": "1",
+         }
+
+         # set CUDA_VISIBLE_DEVICES if slurm has scheduled GPUs; otherwise use
+         # all GPUs (without setting CUDA_VISIBLE_DEVICES)
+         env_setup.update(
+             {"CUDA_VISIBLE_DEVICES": os.environ["SLURM_JOB_GPUS"]}
+             if "SLURM_JOB_GPUS" in os.environ
+             else {}
+         )
+
+         # other environment variables set by the user
+         env_setup.update(self.env_setup_kwargs)
+
+         # update environment variables
+         os.environ.update(**env_setup)
+
+         self.log(nvidia_smi_gpu_memory_stats_str())
+         self.log(f"master: {dist_env.master_addr}:{dist_env.master_port}")
+         self.log(f"rank: {dist_env.rank}")
+         self.log(f"world size: {dist_env.world_size}")
+         self.log(f"local rank: {dist_env.local_rank}")
+         self.log(f"local world size: {dist_env.local_world_size}")
+         self.log(
+             f"local rank {dist_env.local_rank}: CUDA_VISIBLE_DEVICES {os.environ.get('CUDA_VISIBLE_DEVICES', 'all')}"
+         )
+
+         # set distributed arguments
+         num_processes = (
+             self.slurm_config.tasks_per_node
+             * self.slurm_config.processes_per_task
+             * self.slurm_config.num_of_node
+         )
+         machine_rank = dist_env.rank // self.slurm_config.tasks_per_node
+         self.dist_args = DistributedTaskConfig(
+             num_processes=num_processes,
+             num_machines=self.slurm_config.num_of_node,
+             machine_rank=machine_rank,
+             main_process_ip=dist_env.master_addr,
+             main_process_port=dist_env.master_port,
+         )
+         self.dist_env = dist_env
+
+         return self.dist_args, self.dist_env
+
+     def command(self) -> str:
+         cmd = self.launch_cmd.format(**self.dist_args.__dict__)
+         cmd += " " + reconstruct_command_line(self.argv)
+         return cmd
+
+     def __call__(self):
+         # set up distributed environment
+         self.dist_set_up()
+
+         # job environment
+         job_env = submitit.helpers.JobEnvironment()
+
+         # concrete run command
+         cmd = self.command()
+
+         # only the local main process exports the distributed environment
+         # variables (slurm mode) or runs the command (other modes)
+         if self.dist_env.local_rank == 0:
+             print(f"running command: {cmd}")
+             if self.slurm_config.mode == "slurm":
+                 try:
+                     self.dist_args.export_bash(shlex.quote(str(job_env.paths.folder)))
+                 except Exception as e:
+                     print(f"failed to export distributed environment variables: {e}")
+                     return -1
+             else:
+                 return os.system(cmd)
+
+         return 0
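`command()` simply formats `launch_cmd` with the `DistributedTaskConfig` fields and appends the re-quoted `argv`. A hedged sketch of a caller-side template (the `accelerate launch` flags are illustrative, not mandated by this module):

```python
launch_cmd = (
    "accelerate launch"
    " --num_processes {num_processes}"
    " --num_machines {num_machines}"
    " --machine_rank {machine_rank}"
    " --main_process_ip {main_process_ip}"
    " --main_process_port {main_process_port}"
)
task = PyTorchDistributedTask(
    launch_cmd,
    argv=["train.py", "--lr", "3e-4"],  # hypothetical training-script arguments
    slurm_config=slurm_config,          # assumed configured elsewhere
)
# After dist_set_up(), task.command() yields something like:
#   accelerate launch --num_processes 8 ... train.py --lr 3e-4
```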
nntool/slurm/wrap.py ADDED
@@ -0,0 +1,210 @@
+ import sys
+
+ from warnings import warn
+ from typing import Any, Callable, Type, Union, Dict, List
+
+ from .config import SlurmConfig
+ from .function import SlurmFunction
+ from ..parser import parse_from_cli
+
+
+ def slurm_fn(
+     submit_fn: Callable,
+ ) -> SlurmFunction:
+     """A decorator to wrap a function to be run on slurm. A function decorated by this decorator should be launched as shown below. The decorated `submit_fn` becomes non-blocking; to block and get the return value, call ``job.result()``.
+
+     **Example**
+
+     Here's an example of how to use this function:
+
+     .. code-block:: python
+
+         @slurm_fn
+         def run_on_slurm(a, b):
+             return a + b
+
+         slurm_config = SlurmConfig(
+             mode="slurm",
+             partition="PARTITION",
+             job_name="EXAMPLE",
+             tasks_per_node=1,
+             cpus_per_task=8,
+             mem="1GB",
+         )
+         job = run_on_slurm[slurm_config](1, b=2)
+         result = job.result()  # block and get the result
+
+     :param submit_fn: the function to be run on slurm
+     :return: the wrapped slurm function
+     """
+     slurm_fn = SlurmFunction(submit_fn=submit_fn)
+
+     return slurm_fn
+
+
+ def slurm_launcher(
+     ArgsType: Type[Any],
+     parser: Union[str, Callable] = "tyro",
+     slurm_key: str = "slurm",
+     slurm_params_kwargs: dict = {},
+     slurm_submit_kwargs: dict = {},
+     slurm_task_kwargs: dict = {},
+     *extra_args,
+     **extra_kwargs,
+ ) -> Callable[[Callable[..., Any]], SlurmFunction]:
+     """A slurm launcher decorator for distributed or non-distributed jobs (controlled by `use_distributed_env` in the slurm field). This decorator should be used as the program entry point. The decorated function is non-blocking in `slurm` mode, while other modes block.
+
+     **Exported Distributed Environment Variables**
+
+     1. ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable indicating that slurm has been set up.
+     2. After the set-up, the distributed job is launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.
+
+     :param ArgsType: the experiment arguments type, which should be a dataclass (it must have a slurm field named by `slurm_key`)
+     :param parser: the parser for the arguments, defaults to "tyro"
+     :param slurm_key: the key of the slurm field in the ArgsType, defaults to "slurm"
+     :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to {}
+     :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to {}
+     :param slurm_task_kwargs: extra arguments for setting up the distributed task, defaults to {}
+     :param extra_args: extra arguments for the parser
+     :param extra_kwargs: extra keyword arguments for the parser
+     :return: a decorator producing the main entry point
+     """
+     argv = list(sys.argv[1:])
+     args = parse_from_cli(ArgsType, parser, *extra_args, **extra_kwargs)
+
+     # check if args have slurm field
+     if not hasattr(args, slurm_key):
+         raise ValueError(
+             f"ArgsType should have a field named `{slurm_key}` to use the `slurm_launcher` decorator."
+         )
+     slurm_config: SlurmConfig = getattr(args, slurm_key)
+
+     def decorator(
+         submit_fn: Callable[..., Any],
+     ) -> SlurmFunction:
+         return SlurmFunction(
+             submit_fn=submit_fn,
+             default_submit_fn_args=(args,),
+         ).configure(
+             slurm_config,
+             slurm_params_kwargs,
+             slurm_submit_kwargs,
+             slurm_task_kwargs,
+             system_argv=argv,
+         )
+
+     return decorator
+
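A hedged sketch of the intended entry-point pattern (`MyArgs` is hypothetical; assumes `SlurmConfig` is default-constructible):

```python
from dataclasses import dataclass, field

from nntool.slurm.config import SlurmConfig
from nntool.slurm.wrap import slurm_launcher

@dataclass
class MyArgs:
    lr: float = 3e-4
    slurm: SlurmConfig = field(default_factory=SlurmConfig)

@slurm_launcher(MyArgs)
def main(args: MyArgs):
    print(f"training with lr={args.lr}")

if __name__ == "__main__":
    main()  # no args given, so the parsed CLI args are used as defaults
```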
+
+ def slurm_distributed_launcher(
+     ArgsType: Type[Any],
+     parser: Union[str, Callable] = "tyro",
+     slurm_key: str = "slurm",
+     slurm_params_kwargs: dict = {},
+     slurm_submit_kwargs: dict = {},
+     slurm_task_kwargs: dict = {},
+     *extra_args,
+     **extra_kwargs,
+ ) -> Callable[[Callable[..., Any]], SlurmFunction]:
+     """A slurm launcher decorator for distributed jobs. This decorator should be used for distributed jobs only and as the program entry point. The decorated function is non-blocking in `slurm` mode, while other modes block.
+
+     **Exported Distributed Environment Variables**
+
+     1. ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable indicating that slurm has been set up.
+     2. After the set-up, the distributed job is launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.
+
+     :param ArgsType: the experiment arguments type, which should be a dataclass (it must have a slurm field named by `slurm_key`)
+     :param parser: the parser for the arguments, defaults to "tyro"
+     :param slurm_key: the key of the slurm field in the ArgsType, defaults to "slurm"
+     :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to {}
+     :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to {}
+     :param slurm_task_kwargs: extra arguments for setting up the distributed task, defaults to {}
+     :param extra_args: extra arguments for the parser
+     :param extra_kwargs: extra keyword arguments for the parser
+     :return: a decorator producing the main entry point
+     """
+     warn(
+         "`slurm_distributed_launcher` has been deprecated. Please use `slurm_launcher` instead, which supports both distributed and non-distributed jobs (controlled by `use_distributed_env` in the slurm field).",
+         DeprecationWarning,
+         stacklevel=2,
+     )
+     argv = list(sys.argv[1:])
+     args = parse_from_cli(ArgsType, parser, *extra_args, **extra_kwargs)
+
+     # check if args have slurm field
+     if not hasattr(args, slurm_key):
+         raise ValueError(
+             f"ArgsType should have a field named `{slurm_key}` to use the `slurm_distributed_launcher` decorator."
+         )
+     slurm_config: SlurmConfig = getattr(args, slurm_key)
+
+     def decorator(
+         submit_fn: Callable[..., Any],
+     ) -> SlurmFunction:
+         return SlurmFunction(
+             submit_fn=submit_fn,
+             default_submit_fn_args=(args,),
+         ).configure(
+             slurm_config,
+             slurm_params_kwargs,
+             slurm_submit_kwargs,
+             slurm_task_kwargs,
+             system_argv=argv,
+         )
+
+     return decorator
+
+
+ def slurm_function(
+     submit_fn: Callable,
+ ):
+     """A decorator to annotate a function to be run on slurm. A function decorated by this decorator should be launched as shown below.
+
+     .. code-block:: python
+
+         @slurm_function
+         def run_in_slurm(*args, **kwargs):
+             pass
+
+         job = run_in_slurm(slurm_config)(*args, **kwargs)
+
+     The decorated `submit_fn` becomes non-blocking; to block and get the return value, call ``job.result()``.
+     """
+
+     def wrapper(
+         slurm_config: SlurmConfig,
+         slurm_params_kwargs: Dict[str, Any] = {},
+         slurm_submit_kwargs: Dict[str, Any] = {},
+         slurm_task_kwargs: Dict[str, Any] = {},
+         system_argv: Union[List[str], None] = None,
+     ) -> SlurmFunction:
+         """Update the slurm configuration for the slurm function.
+
+         **Exported Distributed Environment Variables**
+
+         1. ``NNTOOL_SLURM_HAS_BEEN_SET_UP`` is a special environment variable indicating that slurm has been set up.
+         2. After the set-up, the distributed job is launched and the following variables are exported: num_processes: int, num_machines: int, machine_rank: int, main_process_ip: str, main_process_port: int.
+
+         :param slurm_config: SlurmConfig, the slurm configuration dataclass
+         :param slurm_params_kwargs: extra slurm arguments for the slurm configuration, defaults to {}
+         :param slurm_submit_kwargs: extra slurm arguments for `srun` or `sbatch`, defaults to {}
+         :param slurm_task_kwargs: extra arguments for setting up the distributed task, defaults to {}
+         :param system_argv: the system arguments for the second launch in the distributed task (by default the current system arguments `sys.argv[1:]` are used), defaults to None
+         :return: the wrapped submit function with configured slurm parameters
+         """
+         warn(
+             "`slurm_function` has been deprecated. Please use `slurm_fn` instead, which supports both distributed and non-distributed jobs (controlled by `use_distributed_env` in the slurm field).",
+             DeprecationWarning,
+             stacklevel=2,
+         )
+         slurm_fn = SlurmFunction(
+             submit_fn=submit_fn,
+         ).configure(
+             slurm_config,
+             slurm_params_kwargs,
+             slurm_submit_kwargs,
+             slurm_task_kwargs,
+             system_argv,
+         )
+         return slurm_fn
+
+     return wrapper
nntool/train/__init__.py ADDED
@@ -0,0 +1,25 @@
+ import importlib.metadata as importlib_metadata
+ import importlib.util
+
+
+ def is_torch_available():
+     package_exists = importlib.util.find_spec("torch") is not None
+
+     # Check we're not importing a "torch" directory somewhere but the actual
+     # library, by trying to grab its metadata
+     if package_exists:
+         try:
+             _ = importlib_metadata.metadata("torch")
+             return True
+         except importlib_metadata.PackageNotFoundError:
+             return False
+     return False
+
+
+ if is_torch_available():
+     from .trainer import BaseTrainer
+ else:
+     # Fall back to a dummy `object` if torch is not available, so that python
+     # succeeds in importing this file.
+     # BaseTrainer abstraction code will never inherit this dummy object, as it
+     # checks whether torch is available.
+     from builtins import object as BaseTrainer
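The effect of this guard, in a hedged sketch: importing `nntool.train` never fails, and downstream code should only rely on the real trainer when torch is present:

```python
from nntool.train import BaseTrainer, is_torch_available

if is_torch_available():
    class MyTrainer(BaseTrainer):  # the real, torch-backed trainer API
        ...
else:
    print("torch not installed; BaseTrainer is a bare `object` placeholder")
```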