nntool 2.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nntool/__init__.py +7 -0
- nntool/build_backend.py +24 -0
- nntool/experiment/__init__.py +13 -0
- nntool/experiment/config.py +108 -0
- nntool/experiment/utils.py +63 -0
- nntool/slurm/__init__.py +21 -0
- nntool/slurm/accelerator/__init__.py +0 -0
- nntool/slurm/accelerator/utils.py +37 -0
- nntool/slurm/config.py +208 -0
- nntool/slurm/core/__init__.py +4 -0
- nntool/slurm/core/_slurm.py +546 -0
- nntool/slurm/core/_slurm_context.py +47 -0
- nntool/slurm/function.py +209 -0
- nntool/slurm/parser/__init__.py +6 -0
- nntool/slurm/parser/parse.py +22 -0
- nntool/slurm/task.py +300 -0
- nntool/slurm/wrap.py +148 -0
- nntool/utils/__init__.py +12 -0
- nntool/version.py +11 -0
- nntool/wandb/__init__.py +7 -0
- nntool/wandb/config.py +116 -0
- nntool-2.0.0rc0.dist-info/METADATA +12 -0
- nntool-2.0.0rc0.dist-info/RECORD +25 -0
- nntool-2.0.0rc0.dist-info/WHEEL +5 -0
- nntool-2.0.0rc0.dist-info/top_level.txt +1 -0
nntool/utils/__init__.py
ADDED
nntool/version.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Individual version components; the public version strings are built below.
_MAJOR = "2"
_MINOR = "0"
# On main and in a nightly release the patch number should be one ahead of the
# last released build.
_PATCH = "0"
# Suffix glued directly onto the patch number, e.g. "rc" for a release
# candidate or ".dev$DATE" for nightly builds. See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
# NOTE: a bare "rc" (no number) is normalized by PEP 440 to "rc0".
_SUFFIX = "rc"

# "MAJOR.MINOR", e.g. "2.0".
VERSION_SHORT = ".".join((_MAJOR, _MINOR))
# "MAJOR.MINOR.PATCH" followed by the suffix, e.g. "2.0.0rc".
VERSION = ".".join((VERSION_SHORT, _PATCH)) + _SUFFIX
|
nntool/wandb/__init__.py
ADDED
nntool/wandb/config.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import git
|
|
3
|
+
import wandb
|
|
4
|
+
import tomli
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Optional, Union
|
|
9
|
+
from wandb.sdk.lib.disabled import RunDisabled
|
|
10
|
+
from wandb.sdk.wandb_run import Run
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class WandbConfig:
    """
    Settings for the Weights and Biases (wandb) integration.

    :param api_key_config_file: Path to a toml file whose `[wandb]` section holds the
        wandb API key. Empty string (the default) means no key file is used.

    :param dir: Directory in which wandb logs are saved. Default is None.

    :param project: wandb project name. Default is an empty string.

    :param entity: wandb user or team name. Default is an empty string.

    :param name: Name of the wandb run; leave empty to use the default run name.
        Default is an empty string.

    :param notes: Free-form notes attached to the wandb run. Default is an empty string.

    :param log_git_hash: Whether the current Git commit hash should be logged. Default is True.

    :param log_code: Whether the current codebase should be logged. Default is True.

    :param code_root: Root directory of the code to be logged. Default is the current
        directory (`.`).

    :param code_ext: File extensions of the code files to be logged. Default is
        `[".py", ".sh"]`.
    """

    # Path to a toml key file containing a [wandb] section with the API key.
    api_key_config_file: str = ""

    # Directory for wandb's local logs.
    dir: Optional[str] = None

    # wandb project name.
    project: str = ""

    # wandb user or team name.
    entity: str = ""

    # wandb run name (leave blank to use the default name).
    name: str = ""

    # Notes attached to the wandb run.
    notes: str = ""

    # Whether to record the current git commit hash.
    log_git_hash: bool = True

    # Whether to log the codebase.
    log_code: bool = True

    # Root directory of the code to log.
    code_root: str = "."

    # Extensions of the code files to log. `.copy` of a template list hands each
    # instance its own fresh list, so instances never share mutable state.
    code_ext: list[str] = field(default_factory=[".py", ".sh"].copy)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def is_wandb_enabled():
    """Return whether ``wandb.run`` currently holds a run object (i.e. is not None)."""
    current_run = wandb.run
    return current_run is not None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def init_wandb(args: WandbConfig, run_config: dict) -> Union[Run, RunDisabled, None]:
    """Initialize wandb and log the run configuration.

    Credentials are resolved in this order:

    1. the ``WANDB_API_KEY`` environment variable;
    2. the ``[wandb]`` section of ``args.api_key_config_file`` (toml, with a ``key``
       field and optional ``project`` / ``entity`` overrides);
    3. otherwise a warning is emitted and ``wandb.init`` is called without an
       explicit ``wandb.login``.

    When ``args.log_git_hash`` is set, the current commit hash is appended to the
    run notes. If no git repository or commit can be found, a warning is emitted
    instead of failing. ``args`` itself is not modified.

    When ``args.log_code`` is set, files under ``args.code_root`` whose path ends
    with one of ``args.code_ext`` are logged via ``Run.log_code``.

    :param args: WandbConfig object.
    :param run_config: configuration dictionary to be logged.
    :return: the object returned by ``wandb.init``.
    :raises KeyError: if the key file has no ``[wandb]`` section or no ``key`` field.
    """
    project, entity = args.project, args.entity
    if "WANDB_API_KEY" in os.environ:
        warnings.warn("WANDB_API_KEY is found in environment variables. Using it.")
        wandb.login(key=os.environ["WANDB_API_KEY"])
    elif args.api_key_config_file:
        with open(args.api_key_config_file, "rb") as config_file:
            config_data = tomli.load(config_file)
        wandb_section = config_data["wandb"]
        wandb.login(key=wandb_section["key"])
        # The key file may override the project/entity given in `args`.
        project = wandb_section.get("project", args.project)
        entity = wandb_section.get("entity", args.entity)
    else:
        warnings.warn("WANDB_API_KEY is not found in environment variables or the local key file.")

    # Build the notes locally: mutating `args.notes` would leak into the caller's
    # config and append the hash again on every subsequent call.
    notes = args.notes
    if args.log_git_hash:
        try:
            repo = git.Repo(search_parent_directories=True)
            git_hash = repo.head.object.hexsha
        except (git.InvalidGitRepositoryError, git.NoSuchPathError, ValueError) as err:
            # Not inside a git repo, or the repo has no commits yet: skip the hash
            # rather than aborting the whole run setup.
            warnings.warn(f"Could not determine the current git hash ({err}); not logging it.")
        else:
            git_note = f"git hash: {git_hash}"
            notes = f"{notes}, {git_note}" if notes else git_note

    wandb_run = wandb.init(
        dir=args.dir,
        project=project,
        entity=entity,
        name=args.name or None,  # empty name -> let wandb pick its default
        notes=notes,
        config=run_config,
    )
    if not is_wandb_enabled():
        warnings.warn("wandb is not enabled after intialization. Please check `wandb enabled`.")

    # Honour the `log_code` flag. The return annotation admits None, which has
    # no `log_code` method, so guard against it.
    if args.log_code and wandb_run is not None:
        code_ext = tuple(args.code_ext)  # str.endswith accepts a tuple of suffixes
        wandb_run.log_code(
            root=args.code_root,
            include_fn=lambda path, root: path.endswith(code_ext),
        )
    return wandb_run
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: nntool
|
|
3
|
+
Version: 2.0.0rc0
|
|
4
|
+
Summary: NNTool is a package built on top of submitit designed to provide simple abstractions to conduct experiments on Slurm for machine learning research.
|
|
5
|
+
Author-email: Junhao Liu <junhaoliu17@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/jhliu17/nntool
|
|
7
|
+
Keywords: deep learning,neural network,research
|
|
8
|
+
Requires-Python: >=3.10
|
|
9
|
+
Requires-Dist: tyro
|
|
10
|
+
Requires-Dist: submitit>=1.5.4
|
|
11
|
+
Requires-Dist: wandb>=0.15.0
|
|
12
|
+
Requires-Dist: tomli>=2.0.1
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
nntool/__init__.py,sha256=ISEQYjns0E37EbFA__hp6bJJES5t_3Y_ZgtwZVKP8ko,117
|
|
2
|
+
nntool/build_backend.py,sha256=kGAtEbEyjxq8YHXNun95Dl6hQ0dI6o4VVdP3RAmL24E,759
|
|
3
|
+
nntool/version.py,sha256=GeySxzv0TjgCbYsd3GB6Xo-twF0ZgnVguEaHdn9wgmI,409
|
|
4
|
+
nntool/experiment/__init__.py,sha256=GIr3ehBQSTTiQyxn7WDiqna7SQVO5YvAG7WbMl7gg8s,239
|
|
5
|
+
nntool/experiment/config.py,sha256=6cjk9pBdKjs9-Syawbjw0RtgfQybnoJkRpxM-g993O0,3807
|
|
6
|
+
nntool/experiment/utils.py,sha256=ajrjYNltyNwNfmFciygCvbcRBtn8Gc7-dVR07Q8xbKo,2119
|
|
7
|
+
nntool/slurm/__init__.py,sha256=YHFIpTTVDR-xmj7xDmfsRfa25UNThO4qISVS01NtiPc,429
|
|
8
|
+
nntool/slurm/config.py,sha256=wDsdHhcc5tNR9nyLWG5SwcvKJoqo84dedz_3v0WcqdM,8260
|
|
9
|
+
nntool/slurm/function.py,sha256=_awI1g8hEaOQxBbRMmO7jJwqBXtYHPTaaZ9g1XK7JsM,8865
|
|
10
|
+
nntool/slurm/task.py,sha256=3rEgEXKL9q3VnDRtxRH6LzqaKjXDyiCOxe3aV6mgSeQ,11132
|
|
11
|
+
nntool/slurm/wrap.py,sha256=ZtHDqJD6Uk2rwumdiewz6Yl6tVtau5tZ5ZCo9VF_JJw,5291
|
|
12
|
+
nntool/slurm/accelerator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
+
nntool/slurm/accelerator/utils.py,sha256=VuGfPGjcanDXmDHtZh7U_jEx0k45aBtAJAsUJs_LoRc,1276
|
|
14
|
+
nntool/slurm/core/__init__.py,sha256=920Q5lxUFOpngHN0ZAXRjgb4GqPXSyJAhlnQk_hbKYM,79
|
|
15
|
+
nntool/slurm/core/_slurm.py,sha256=1tZ_TDFFsolsjTXbFsBTJUd7pBRSXfVbJrBrwUukfdE,24660
|
|
16
|
+
nntool/slurm/core/_slurm_context.py,sha256=oW9R9fwa5IAnA8LtbLkyn_wPf-54aDMRoF8vhyhUQN8,1912
|
|
17
|
+
nntool/slurm/parser/__init__.py,sha256=yxIckzha7nSlsdjPYKrP9ozneeQFtBdziTPJeVPXAFY,72
|
|
18
|
+
nntool/slurm/parser/parse.py,sha256=IqlcqQfd7GoeTGJkMOrb2_-x0s49Zz_2fA3Fsk_6ysQ,552
|
|
19
|
+
nntool/utils/__init__.py,sha256=YnGpWfnDUhxOR1-wx0-tjKOdHr9cP3_B242xoIT5Kww,219
|
|
20
|
+
nntool/wandb/__init__.py,sha256=lExtgvdu-E60vOyqlNIc7cFF565U4fGFDHFKrZyMxtg,97
|
|
21
|
+
nntool/wandb/config.py,sha256=TzHZmatV8bGl_kixWgJx0MI_DiXjscEESXlbYsokIos,3712
|
|
22
|
+
nntool-2.0.0rc0.dist-info/METADATA,sha256=czd8bC5BJGXcmEqouT27A9H5Vjqpnlvb0ECmsuvafn8,496
|
|
23
|
+
nntool-2.0.0rc0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
24
|
+
nntool-2.0.0rc0.dist-info/top_level.txt,sha256=hXO7zSAI-o9plYeUkCiKbkDpQCfTDIfjivyoSEqT28Y,7
|
|
25
|
+
nntool-2.0.0rc0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
nntool
|