nntool 1.6.2__cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nntool might be problematic.
- nntool/__init__.py +7 -0
- nntool/experiment/__init__.py +13 -0
- nntool/experiment/config.py +108 -0
- nntool/experiment/utils.py +63 -0
- nntool/plot/__init__.py +10 -0
- nntool/plot/context.py +46 -0
- nntool/plot/csrc/__compile__.cpython-313-x86_64-linux-gnu.so +0 -0
- nntool/plot/csrc/__compile__.py +0 -0
- nntool/plot/csrc/__init__.py +7 -0
- nntool/plot/csrc/latexify.py +129 -0
- nntool/slurm/__init__.py +21 -0
- nntool/slurm/accelerator/__init__.py +0 -0
- nntool/slurm/accelerator/utils.py +37 -0
- nntool/slurm/config.py +208 -0
- nntool/slurm/csrc/__compile__.cpython-313-x86_64-linux-gnu.so +0 -0
- nntool/slurm/csrc/__compile__.py +0 -0
- nntool/slurm/csrc/__init__.py +4 -0
- nntool/slurm/csrc/_slurm.py +476 -0
- nntool/slurm/csrc/_slurm_context.py +47 -0
- nntool/slurm/function.py +209 -0
- nntool/slurm/parser/__init__.py +6 -0
- nntool/slurm/parser/parse.py +22 -0
- nntool/slurm/task.py +294 -0
- nntool/slurm/wrap.py +147 -0
- nntool/utils/__init__.py +12 -0
- nntool/version.py +11 -0
- nntool/wandb/__init__.py +7 -0
- nntool/wandb/config.py +110 -0
- nntool-1.6.2.dist-info/METADATA +39 -0
- nntool-1.6.2.dist-info/RECORD +32 -0
- nntool-1.6.2.dist-info/WHEEL +7 -0
- nntool-1.6.2.dist-info/top_level.txt +1 -0
nntool/__init__.py
ADDED
nntool/experiment/config.py
ADDED
@@ -0,0 +1,108 @@
+import os
+
+from typing import Any, Dict
+from pathlib import Path
+from dataclasses import dataclass
+from .utils import get_output_path, read_toml_file
+
+
+@dataclass
+class BaseExperimentConfig:
+    """
+    Configuration class for setting up an experiment.
+
+    :param config_name: The name of the configuration.
+    :param output_folder: The folder path where the outputs will be saved.
+    :param experiment_name_key: Key of the environment variable holding the experiment name, default is 'EXP_NAME'.
+    :param env_toml_path: Path to the `env.toml` file, default is 'env.toml'.
+    :param append_date_to_path: If True, the current date and time will be appended to the output path, default is True.
+    :param existing_output_path_ok: If True, an existing output path may be reused, default is False.
+    """
+
+    # config name
+    config_name: str
+
+    # the output folder for the outputs
+    output_folder: str
+
+    # key of the environment variable holding the experiment name
+    experiment_name_key: str = "EXP_NAME"
+
+    # the path to the env.toml file
+    env_toml_path: str = "env.toml"
+
+    # append date time to the output path
+    append_date_to_path: bool = True
+
+    # existing output path is ok
+    existing_output_path_ok: bool = False
+
+    def __post_init__(self):
+        # annotations
+        self.experiment_name: str
+        self.project_path: str
+        self.output_path: str
+        self.current_time: str
+        self.env_toml: Dict[str, Any] = self.__prepare_env_toml_dict()
+
+        self.experiment_name = self.__prepare_experiment_name()
+        self.project_path, self.output_path, self.current_time = self.__prepare_experiment_paths()
+
+        # custom post update for the derived class
+        self.set_up_stateful_fields()
+
+    def __prepare_env_toml_dict(self):
+        env_toml_path = Path(self.env_toml_path)
+        if not env_toml_path.exists():
+            raise FileNotFoundError(f"{env_toml_path} does not exist")
+
+        config = read_toml_file(env_toml_path)
+        return config
+
+    def __prepare_experiment_name(self):
+        return os.environ.get(self.experiment_name_key, "default")
+
+    def __prepare_experiment_paths(self):
+        project_path = self.env_toml["project"]["path"]
+
+        output_path, current_time = get_output_path(
+            output_path=os.path.join(self.output_folder, self.config_name, self.experiment_name),
+            append_date=self.append_date_to_path,
+            cache_into_env=False,
+        )
+        output_path = f"{project_path}/{output_path}"
+        return project_path, output_path, current_time
+
+    def get_output_path(self) -> str:
+        """Return the output path prepared for the experiment.
+
+        :return: output path for the experiment
+        """
+        return self.output_path
+
+    def get_current_time(self) -> str:
+        """Return the current time for the experiment.
+
+        :return: current time for the experiment
+        """
+        return self.current_time
+
+    def set_up_stateful_fields(self):
+        """
+        Post-configuration steps for stateful fields such as `output_path` in the derived class.
+        This method should be overridden in the derived class.
+        """
+        pass
+
+    def start(self):
+        """
+        Start the experiment. This will
+        - cache `NNTOOL_OUTPUT_PATH` and `NNTOOL_OUTPUT_PATH_DATE` into environment variables, so that later launched processes inherit them.
+        - create the output path if it does not exist.
+        """
+        os.environ["NNTOOL_OUTPUT_PATH"] = self.get_output_path()
+        os.environ["NNTOOL_OUTPUT_PATH_DATE"] = self.get_current_time()
+
+        # create the output path
+        output_path = Path(self.get_output_path())
+        output_path.mkdir(parents=True, exist_ok=self.existing_output_path_ok)
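For orientation, a minimal usage sketch of `BaseExperimentConfig` as added above (illustrative only: the `TrainConfig` subclass, the `env.toml` contents, and the folder names are hypothetical, and the import path assumes the class is re-exported from `nntool.experiment`):

    # env.toml (assumed contents, not part of this package):
    # [project]
    # path = "/data/my-project"

    from dataclasses import dataclass
    from nntool.experiment import BaseExperimentConfig  # import path assumed

    @dataclass
    class TrainConfig(BaseExperimentConfig):
        lr: float = 1e-3

        def set_up_stateful_fields(self):
            # derived, stateful fields that depend on the prepared output path
            self.ckpt_dir = f"{self.get_output_path()}/checkpoints"

    cfg = TrainConfig(config_name="mnist", output_folder="outputs")
    cfg.start()  # caches NNTOOL_OUTPUT_PATH(_DATE) and creates the output directory
    print(cfg.get_output_path())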
nntool/experiment/utils.py
ADDED
@@ -0,0 +1,63 @@
+import os
+import datetime
+import tomli
+
+
+def get_current_time() -> str:
+    """Get the current time in the format MMDDYYYY/HHMMSS.
+
+    :return: time in the format MMDDYYYY/HHMMSS
+    """
+    # Get the current time
+    current_time = datetime.datetime.now()
+
+    # Format the time (MDY/HMS)
+    formatted_time = current_time.strftime("%m%d%Y/%H%M%S")
+
+    return formatted_time
+
+
+def read_toml_file(file_path: str) -> dict:
+    """Read a TOML file and return the content as a dictionary.
+
+    :param file_path: path to the toml file
+    :return: content of the toml file as a dictionary
+    """
+    with open(file_path, "rb") as f:
+        content = tomli.load(f)
+
+    return content
+
+
+def get_output_path(
+    output_path: str = "./",
+    append_date: bool = True,
+    cache_into_env: bool = True,
+) -> tuple[str, str]:
+    """Get the output path based on the environment variables OUTPUT_PATH and NNTOOL_OUTPUT_PATH.
+    The output path is appended with the current time if append_date is True (e.g. /OUTPUT_PATH/xxx/MMDDYYYY/HHMMSS).
+
+    :param append_date: append a child folder named with the date and time, defaults to True
+    :param cache_into_env: whether to cache the newly created path into the environment, defaults to True
+    :return: (output path, current time)
+    """
+    if "OUTPUT_PATH" in os.environ:
+        output_path = os.environ["OUTPUT_PATH"]
+        current_time = "" if not append_date else get_current_time()
+    elif "NNTOOL_OUTPUT_PATH" in os.environ:
+        # reuse the NNTOOL_OUTPUT_PATH if it is set
+        output_path = os.environ["NNTOOL_OUTPUT_PATH"]
+        current_time = "" if not append_date else os.environ["NNTOOL_OUTPUT_PATH_DATE"]
+    else:
+        current_time = get_current_time()
+        if append_date:
+            output_path = os.path.join(output_path, current_time)
+        print(
+            f"OUTPUT_PATH is not found in environment variables. NNTOOL_OUTPUT_PATH is set using path: {output_path}"
+        )
+
+    if cache_into_env:
+        os.environ["NNTOOL_OUTPUT_PATH"] = output_path
+        os.environ["NNTOOL_OUTPUT_PATH_DATE"] = current_time
+
+    return output_path, current_time
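A short sketch of how `get_output_path` resolves and caches paths when neither `OUTPUT_PATH` nor `NNTOOL_OUTPUT_PATH` is set yet (illustrative; the folder name is arbitrary and the module path is assumed):

    import os
    from nntool.experiment.utils import get_output_path  # module path assumed

    # No OUTPUT_PATH / NNTOOL_OUTPUT_PATH in the environment yet: a dated
    # sub-folder is appended and the result is cached for later processes.
    path, ts = get_output_path("outputs/demo", append_date=True, cache_into_env=True)
    print(path)  # e.g. outputs/demo/MMDDYYYY/HHMMSS
    print(os.environ["NNTOOL_OUTPUT_PATH"], os.environ["NNTOOL_OUTPUT_PATH_DATE"])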
nntool/plot/__init__.py
ADDED
nntool/plot/context.py
ADDED
@@ -0,0 +1,46 @@
+import os
+import matplotlib
+import seaborn as sns
+
+from typing import Union
+from dataclasses import dataclass
+from .csrc.latexify import SIZE_SMALL, latexify, savefig
+
+
+@dataclass
+class latexify_plot:
+    enable: bool = True
+    width_scale_factor: float = 1
+    height_scale_factor: float = 1
+    fig_width: Union[float, None] = None
+    fig_height: Union[float, None] = None
+    font_size: int = SIZE_SMALL
+
+    def __post_init__(self):
+        self.legend_size = 7 if self.enable else None
+
+    def __enter__(self):
+        if self.enable:
+            os.environ["LATEXIFY"] = "1"
+            latexify(
+                width_scale_factor=self.width_scale_factor,
+                height_scale_factor=self.height_scale_factor,
+                fig_width=self.fig_width,
+                fig_height=self.fig_height,
+                font_size=self.font_size,
+            )
+        return self
+
+    def __exit__(self, *args):
+        if self.enable:
+            os.environ.pop("LATEXIFY")
+            matplotlib.rcParams.update(matplotlib.rcParamsDefault)
+
+    def savefig(self, filename, despine: bool = True, fig_dir: str = "tests/plot", **kwargs):
+        if despine:
+            sns.despine()
+        savefig(filename, fig_dir=fig_dir, **kwargs)
+
+
+# This is for backward compatibility
+enable_latexify = latexify_plot
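A usage sketch of the `latexify_plot` context manager defined above (illustrative; it assumes a working LaTeX installation because `latexify` enables `usetex`, the `figures` directory is arbitrary, and the import path assumes the class is re-exported from `nntool.plot`):

    import matplotlib.pyplot as plt
    from nntool.plot import latexify_plot  # import path assumed

    with latexify_plot(width_scale_factor=2, font_size=9) as ctx:
        plt.plot([0, 1, 2], [0, 1, 4], label="demo")
        plt.legend()
        # despines and forwards to the module-level savefig; the file is saved
        # as toy_curve_latexified.pdf because LATEXIFY is set inside the context
        ctx.savefig("toy_curve", fig_dir="figures")
    # on exit, matplotlib rcParams are restored to their defaults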
nntool/plot/csrc/__compile__.cpython-313-x86_64-linux-gnu.so
Binary file
nntool/plot/csrc/__compile__.py
File without changes
nntool/plot/csrc/latexify.py
ADDED
@@ -0,0 +1,129 @@
"""
|
|
2
|
+
This code is borrowed from https://github.com/probml/probml-utils/blob/main/probml_utils/plotting.py
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
DEFAULT_WIDTH = 6.0
|
|
10
|
+
DEFAULT_HEIGHT = 1.5
|
|
11
|
+
SIZE_SMALL = 9 # Caption size in the pml book
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def latexify(
|
|
15
|
+
width_scale_factor=1,
|
|
16
|
+
height_scale_factor=1,
|
|
17
|
+
fig_width=None,
|
|
18
|
+
fig_height=None,
|
|
19
|
+
font_size=SIZE_SMALL,
|
|
20
|
+
):
|
|
21
|
+
f"""
|
|
22
|
+
width_scale_factor: float, DEFAULT_WIDTH will be divided by this number, DEFAULT_WIDTH is page width: {DEFAULT_WIDTH} inches.
|
|
23
|
+
height_scale_factor: float, DEFAULT_HEIGHT will be divided by this number, DEFAULT_HEIGHT is {DEFAULT_HEIGHT} inches.
|
|
24
|
+
fig_width: float, width of the figure in inches (if this is specified, width_scale_factor is ignored)
|
|
25
|
+
fig_height: float, height of the figure in inches (if this is specified, height_scale_factor is ignored)
|
|
26
|
+
font_size: float, font size
|
|
27
|
+
"""
|
|
28
|
+
if "LATEXIFY" not in os.environ:
|
|
29
|
+
warnings.warn("LATEXIFY environment variable not set, not latexifying")
|
|
30
|
+
return
|
|
31
|
+
if fig_width is None:
|
|
32
|
+
fig_width = DEFAULT_WIDTH / width_scale_factor
|
|
33
|
+
if fig_height is None:
|
|
34
|
+
fig_height = DEFAULT_HEIGHT / height_scale_factor
|
|
35
|
+
|
|
36
|
+
# use TrueType fonts so they are embedded
|
|
37
|
+
# https://stackoverflow.com/questions/9054884/how-to-embed-fonts-in-pdfs-produced-by-matplotlib
|
|
38
|
+
# https://jdhao.github.io/2018/01/18/mpl-plotting-notes-201801/
|
|
39
|
+
plt.rcParams["pdf.fonttype"] = 42
|
|
40
|
+
|
|
41
|
+
# Font sizes
|
|
42
|
+
# SIZE_MEDIUM = 14
|
|
43
|
+
# SIZE_LARGE = 24
|
|
44
|
+
# https://stackoverflow.com/a/39566040
|
|
45
|
+
plt.rc("font", size=font_size) # controls default text sizes
|
|
46
|
+
plt.rc("axes", titlesize=font_size) # fontsize of the axes title
|
|
47
|
+
plt.rc("axes", labelsize=font_size) # fontsize of the x and y labels
|
|
48
|
+
plt.rc("xtick", labelsize=font_size) # fontsize of the tick labels
|
|
49
|
+
plt.rc("ytick", labelsize=font_size) # fontsize of the tick labels
|
|
50
|
+
plt.rc("legend", fontsize=font_size) # legend fontsize
|
|
51
|
+
plt.rc("figure", titlesize=font_size) # fontsize of the figure title
|
|
52
|
+
|
|
53
|
+
# latexify: https://nipunbatra.github.io/blog/visualisation/2014/06/02/latexify.html
|
|
54
|
+
plt.rcParams["backend"] = "ps"
|
|
55
|
+
plt.rc("text", usetex=True)
|
|
56
|
+
plt.rc("font", family="serif")
|
|
57
|
+
plt.rc("figure", figsize=(fig_width, fig_height))
|
|
58
|
+
plt.rcParams["text.latex.preamble"] = r"""
|
|
59
|
+
\usepackage{amsmath}
|
|
60
|
+
\usepackage{bm}
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def is_latexify_enabled():
|
|
65
|
+
"""
|
|
66
|
+
returns true if LATEXIFY environment variable is set
|
|
67
|
+
"""
|
|
68
|
+
return "LATEXIFY" in os.environ
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _get_fig_name(fname_full):
|
|
72
|
+
fname_full = fname_full.replace("_latexified", "")
|
|
73
|
+
LATEXIFY = "LATEXIFY" in os.environ
|
|
74
|
+
# extention = "_latexified.pdf" if LATEXIFY else ".png"
|
|
75
|
+
extention = "_latexified.pdf" if LATEXIFY else ".pdf"
|
|
76
|
+
if fname_full[-4:] in [".png", ".pdf", ".jpg"]:
|
|
77
|
+
fname = fname_full[:-4]
|
|
78
|
+
print(
|
|
79
|
+
f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",
|
|
80
|
+
)
|
|
81
|
+
else:
|
|
82
|
+
fname = fname_full
|
|
83
|
+
return fname + extention
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def savefig(
|
|
87
|
+
f_name,
|
|
88
|
+
tight_layout=True,
|
|
89
|
+
tight_bbox=False,
|
|
90
|
+
pad_inches=0.0,
|
|
91
|
+
fig_dir=None,
|
|
92
|
+
*args,
|
|
93
|
+
**kwargs,
|
|
94
|
+
):
|
|
95
|
+
if len(f_name) == 0:
|
|
96
|
+
return
|
|
97
|
+
if "FIG_DIR" not in os.environ and fig_dir is None:
|
|
98
|
+
warnings.warn("set FIG_DIR environment variable or pass fig_dir argument to save figures")
|
|
99
|
+
return
|
|
100
|
+
|
|
101
|
+
fig_dir = fig_dir if fig_dir is not None else os.environ["FIG_DIR"]
|
|
102
|
+
# Auto create the directory if it doesn't exist
|
|
103
|
+
if not os.path.exists(fig_dir):
|
|
104
|
+
os.makedirs(fig_dir)
|
|
105
|
+
|
|
106
|
+
fname_full = os.path.join(fig_dir, f_name)
|
|
107
|
+
fname_full = _get_fig_name(fname_full)
|
|
108
|
+
|
|
109
|
+
print("saving image to {}".format(fname_full))
|
|
110
|
+
if tight_layout:
|
|
111
|
+
plt.tight_layout(pad=pad_inches)
|
|
112
|
+
print("Figure size:", plt.gcf().get_size_inches())
|
|
113
|
+
|
|
114
|
+
if tight_bbox:
|
|
115
|
+
# This changes the size of the figure
|
|
116
|
+
plt.savefig(fname_full, pad_inches=pad_inches, bbox_inches="tight", *args, **kwargs)
|
|
117
|
+
else:
|
|
118
|
+
plt.savefig(fname_full, pad_inches=pad_inches, *args, **kwargs)
|
|
119
|
+
|
|
120
|
+
if "DUAL_SAVE" in os.environ:
|
|
121
|
+
if fname_full.endswith(".pdf"):
|
|
122
|
+
fname_full = fname_full[:-4] + ".png"
|
|
123
|
+
else:
|
|
124
|
+
fname_full = fname_full[:-4] + ".pdf"
|
|
125
|
+
if tight_bbox:
|
|
126
|
+
# This changes the size of the figure
|
|
127
|
+
plt.savefig(fname_full, pad_inches=pad_inches, bbox_inches="tight", *args, **kwargs)
|
|
128
|
+
else:
|
|
129
|
+
plt.savefig(fname_full, pad_inches=pad_inches, *args, **kwargs)
|
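The plotting helpers above are driven entirely by environment variables (`LATEXIFY`, `FIG_DIR`, `DUAL_SAVE`); a short sketch of that flow (illustrative; the paths are arbitrary and a LaTeX installation is assumed for the actual rendering):

    import os
    import matplotlib.pyplot as plt
    from nntool.plot.csrc.latexify import latexify, savefig, is_latexify_enabled

    os.environ["FIG_DIR"] = "/tmp/figs"   # where savefig writes when fig_dir is not passed
    os.environ["LATEXIFY"] = "1"          # latexify() is a no-op without this
    os.environ["DUAL_SAVE"] = "1"         # also write the .png/.pdf counterpart

    latexify(width_scale_factor=2)
    plt.plot([1, 2, 3])
    print(is_latexify_enabled())          # True
    savefig("demo.pdf")                   # writes demo_latexified.pdf and demo_latexified.png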
nntool/slurm/__init__.py
ADDED
@@ -0,0 +1,21 @@
+from .config import SlurmConfig, SlurmArgs
+from .wrap import (
+    slurm_function,
+    slurm_fn,
+    slurm_launcher,
+)
+from .function import SlurmFunction
+from .task import Task, DistributedTaskConfig, PyTorchDistributedTask
+
+
+__all__ = [
+    "SlurmConfig",
+    "SlurmArgs",
+    "SlurmFunction",
+    "slurm_fn",
+    "slurm_function",
+    "slurm_launcher",
+    "Task",
+    "DistributedTaskConfig",
+    "PyTorchDistributedTask",
+]
nntool/slurm/accelerator/__init__.py
File without changes
nntool/slurm/accelerator/utils.py
ADDED
@@ -0,0 +1,37 @@
+import subprocess
+
+
+def nvidia_smi_gpu_memory_stats() -> dict:
+    """
+    Parse the nvidia-smi output and extract the memory used stats.
+    """
+    out_dict = {}
+    try:
+        sp = subprocess.Popen(
+            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            close_fds=True,
+        )
+        out_str = sp.communicate()
+        out_list = out_str[0].decode("utf-8").split("\n")
+        out_dict = {}
+        for item in out_list:
+            if " MiB" in item:
+                gpu_idx, mem_used = item.split(",")
+                gpu_key = f"gpu_{gpu_idx}_mem_used_gb"
+                out_dict[gpu_key] = int(mem_used.strip().split(" ")[0]) / 1024
+    except FileNotFoundError:
+        raise Exception("Failed to find the 'nvidia-smi' executable for printing GPU stats")
+    except subprocess.CalledProcessError as e:
+        raise Exception(f"nvidia-smi returned non zero error code: {e.returncode}")
+
+    return out_dict
+
+
+def nvidia_smi_gpu_memory_stats_str() -> str:
+    """
+    Parse the nvidia-smi output and extract the memory used stats.
+    """
+    stats = nvidia_smi_gpu_memory_stats()
+    return ", ".join([f"{k}: {v:.4f}" for k, v in stats.items()])
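For reference, a sketch of what the helpers above return on a machine where `nvidia-smi` is available (the numbers are made up; keys follow the gpu_{index}_mem_used_gb pattern and values are MiB divided by 1024):

    from nntool.slurm.accelerator.utils import (
        nvidia_smi_gpu_memory_stats,
        nvidia_smi_gpu_memory_stats_str,
    )

    stats = nvidia_smi_gpu_memory_stats()
    # e.g. {"gpu_0_mem_used_gb": 11.25, "gpu_1_mem_used_gb": 0.002}

    print(nvidia_smi_gpu_memory_stats_str())
    # e.g. "gpu_0_mem_used_gb: 11.2500, gpu_1_mem_used_gb: 0.0020"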
nntool/slurm/config.py
ADDED
@@ -0,0 +1,208 @@
+import os
+import sys
+
+from dataclasses import dataclass, field, replace
+from typing import List, Literal, Dict, Optional
+
+
+@dataclass
+class SlurmConfig:
+    """
+    Configuration class for SLURM job submission and execution.
+
+    Args:
+        mode (Literal["debug", "exec", "local", "slurm"]): Running mode for the job. Options are
+            "debug" (default, run in debugging mode, which will involve pdb), "exec" (alias of "local"), "local" (run the job locally), or "slurm" (run the job on a SLURM cluster).
+
+        job_name (str): The name of the SLURM job. Default is 'Job'.
+
+        partition (str): The name of the SLURM partition to use. Default is ''.
+
+        output_parent_path (str): The parent directory path for saving the slurm folder. Default is './'.
+
+        output_folder (str): The folder name where SLURM output files will be stored. Default is 'slurm'.
+
+        node_list (str): A string specifying the nodes to use. Leave blank to use all available nodes. Default is an empty string.
+
+        node_list_exclude (str): A string specifying the nodes to exclude. Leave blank to use all nodes in the node list. Default is an empty string.
+
+        num_of_node (int): The number of nodes to request. Default is 1.
+
+        tasks_per_node (int): The number of tasks to run per node. Default is 1.
+
+        gpus_per_task (int): The number of GPUs to request per task. Default is 0.
+
+        cpus_per_task (int): The number of CPUs to request per task. Default is 1.
+
+        gpus_per_node (int): The number of GPUs to request per node. If this is set, `gpus_per_task` will be ignored. Default is None.
+
+        mem (str): The amount of memory (GB) to request. Leave blank to use the default memory configuration of the node. Default is an empty string.
+
+        timeout_min (int): The time limit for the job in minutes. Default is `sys.maxsize` for effectively no limit.
+
+        stderr_to_stdout (bool): Whether to redirect stderr to stdout. Default is False.
+
+        setup (List[str]): A list of environment variable setup commands. Default is an empty list.
+
+        pack_code (bool): Whether to pack the codebase before submission. Default is False.
+
+        use_packed_code (bool): Whether to use the packed code for execution. Default is False.
+
+        code_root (str): The root directory of the codebase, which will be used by the code packing. Default is the current directory (``.``).
+
+        code_file_suffixes (List[str]): A list of file extensions for code files to be included when packing. Default includes ``.py``, ``.sh``, ``.yaml``, and ``.toml``.
+
+        exclude_code_folders (List[str]): A list of folder names relative to `code_root` that will be excluded from packing. Default excludes 'wandb', 'outputs', and 'datasets'.
+
+        use_distributed_env (bool): Whether to use a distributed environment for the job. Default is False.
+
+        distributed_env_task (Literal["torch"]): The type of distributed environment task to use. Currently, only "torch" is supported. Default is "torch".
+
+        processes_per_task (int): The number of processes to run per task. This value is not used by SLURM itself but is needed to correctly set up the distributed environment. Default is 1.
+
+        distributed_launch_command (str): The command to launch the distributed environment setup, using environment variables like ``{num_processes}``, ``{num_machines}``, ``{machine_rank}``, ``{main_process_ip}``, ``{main_process_port}``. Default is an empty string.
+
+        extra_params_kwargs (Dict[str, str]): Additional parameters for the SLURM job as a dictionary of key-value pairs. Default is an empty dictionary.
+
+        extra_submit_kwargs (Dict[str, str]): Additional submit parameters for the SLURM job as a dictionary of key-value pairs. Default is an empty dictionary.
+
+        extra_task_kwargs (Dict[str, str]): Additional task parameters for the SLURM job as a dictionary of key-value pairs. Default is an empty dictionary.
+    """
+
+    # running mode
+    mode: Literal["debug", "exec", "local", "slurm"] = "debug"
+
+    # slurm job name
+    job_name: str = "Job"
+
+    # slurm partition name
+    partition: str = ""
+
+    # slurm output parent path
+    output_parent_path: str = "./"
+
+    # slurm output folder name
+    output_folder: str = "slurm"
+
+    # node list string (leave blank to use all nodes)
+    node_list: str = ""
+
+    # node list string to be excluded (leave blank to use all nodes in the node list)
+    node_list_exclude: str = ""
+
+    # number of nodes to request
+    num_of_node: int = 1
+
+    # tasks per node
+    tasks_per_node: int = 1
+
+    # number of gpus per task to request
+    gpus_per_task: int = 0
+
+    # number of cpus per task to request
+    cpus_per_task: int = 1
+
+    # number of gpus per node to request (if this is set, gpus_per_task will be ignored)
+    gpus_per_node: Optional[int] = None
+
+    # memory (GB) to request (leave blank to use the default memory configuration of the node)
+    mem: str = ""
+
+    # timeout in minutes
+    timeout_min: int = sys.maxsize
+
+    # whether to redirect stderr to stdout
+    stderr_to_stdout: bool = False
+
+    # environment variable setup commands
+    setup: List[str] = field(default_factory=list)
+
+    # whether to pack code
+    pack_code: bool = False
+
+    # use packed code to run
+    use_packed_code: bool = False
+
+    # code root
+    code_root: str = "."
+
+    # code file extensions
+    code_file_suffixes: list[str] = field(default_factory=lambda: [".py", ".sh", ".yaml", ".toml"])
+
+    # exclude folders (relative to the code root)
+    exclude_code_folders: list[str] = field(
+        default_factory=lambda: ["wandb", "outputs", "datasets"]
+    )
+
+    # whether to use a distributed environment
+    use_distributed_env: bool = False
+
+    # distributed environment task
+    distributed_env_task: Literal["torch"] = "torch"
+
+    # processes per task (this value is not used by slurm, but in the distributed environment)
+    processes_per_task: int = 1
+
+    # distributed launch command (this will be called after the distributed environment is set up)
+    # the following environment variables are available:
+    # num_processes: int
+    # num_machines: int
+    # machine_rank: int
+    # main_process_ip: str
+    # main_process_port: int
+    # use braces to access the environment variables, e.g. {num_processes}
+    distributed_launch_command: str = ""
+
+    # extra slurm job parameters
+    extra_params_kwargs: Dict[str, str] = field(default_factory=dict)
+
+    # extra slurm submit parameters
+    extra_submit_kwargs: Dict[str, str] = field(default_factory=dict)
+
+    # extra slurm task parameters
+    extra_task_kwargs: Dict[str, str] = field(default_factory=dict)
+
+    def _configuration_check(self):
+        # check partition
+        if self.partition == "":
+            raise ValueError("partition must be set")
+
+        # check distributed environment task
+        if self.use_distributed_env and self.distributed_launch_command == "":
+            raise ValueError(
+                "distributed_launch_command must be set when use_distributed_env is True"
+            )
+
+    def __post_init__(self):
+        # check configuration
+        self._configuration_check()
+
+        # normalize the output folder
+        output_folder_suffix = ""
+        if self.mode != "slurm":
+            output_folder_suffix = f"_{self.mode}"
+        if self.output_folder.endswith("slurm"):
+            self.output_folder = f"{self.output_folder}{output_folder_suffix}"
+        else:
+            self.output_folder = os.path.join(self.output_folder, f"slurm{output_folder_suffix}")
+
+        # output path
+        self.output_path: str = os.path.join(self.output_parent_path, self.output_folder)
+
+    def set_output_path(self, output_parent_path: str) -> "SlurmConfig":
+        """Set the output path and date for the slurm job.
+
+        Args:
+            output_parent_path (str): The parent path for the output.
+
+        Returns:
+            SlurmConfig: The updated SlurmConfig instance.
+        """
+        new_config = replace(
+            self,
+            output_parent_path=output_parent_path,
+        )
+        return new_config
+
+
+SlurmArgs = SlurmConfig
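A construction sketch for `SlurmConfig` (illustrative; the partition name, paths, and setup command are placeholders):

    from nntool.slurm import SlurmConfig

    config = SlurmConfig(
        mode="slurm",                      # submit to the cluster instead of running locally
        job_name="train",
        partition="gpu",                   # required: an empty partition raises ValueError
        output_parent_path="outputs/exp1",
        num_of_node=1,
        tasks_per_node=1,
        gpus_per_task=1,
        cpus_per_task=8,
        setup=["export OMP_NUM_THREADS=8"],
    )

    # __post_init__ normalizes the folder and derives the final path:
    # outputs/exp1/slurm (a "_<mode>" suffix is appended for non-slurm modes)
    print(config.output_path)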
nntool/slurm/csrc/__compile__.cpython-313-x86_64-linux-gnu.so
Binary file
nntool/slurm/csrc/__compile__.py
File without changes