nntool 2.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ # this is for backward compatibility
2
+ from ..experiment.utils import (
3
+ get_current_time,
4
+ get_output_path,
5
+ read_toml_file,
6
+ )
7
+
8
+ __all__ = [
9
+ "get_current_time",
10
+ "get_output_path",
11
+ "read_toml_file",
12
+ ]
nntool/version.py ADDED
@@ -0,0 +1,11 @@
1
# Individual version components; the public strings below are assembled from them.
_MAJOR = "2"
_MINOR = "0"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "0"
# Suffix appended directly after the patch number. This is mainly for nightly
# builds, which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = "rc"

# "<major>.<minor>"
VERSION_SHORT = ".".join((_MAJOR, _MINOR))
# "<major>.<minor>.<patch><suffix>"
VERSION = f"{VERSION_SHORT}.{_PATCH}{_SUFFIX}"
@@ -0,0 +1,7 @@
1
+ from .config import WandbConfig, init_wandb
2
+
3
+
4
+ __all__ = [
5
+ "WandbConfig",
6
+ "init_wandb",
7
+ ]
nntool/wandb/config.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ import git
3
+ import wandb
4
+ import tomli
5
+ import warnings
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional, Union
9
+ from wandb.sdk.lib.disabled import RunDisabled
10
+ from wandb.sdk.wandb_run import Run
11
+
12
+
13
@dataclass
class WandbConfig:
    """
    Settings used when starting a Weights and Biases (wandb) run.

    :param api_key_config_file: Path to a toml file that stores the wandb API key
        inside a `[wandb]` section. Defaults to an empty string.

    :param dir: Directory in which the wandb logs are saved. Defaults to None.

    :param project: wandb project name. Defaults to an empty string.

    :param entity: wandb user or team name. Defaults to an empty string.

    :param name: wandb run name; leave blank to use the default run name.
        Defaults to an empty string.

    :param notes: Notes or comments attached to the wandb run. Defaults to an empty string.

    :param log_git_hash: Whether to log the current Git hash. Defaults to True.

    :param log_code: Whether to log the current codebase. Defaults to True.

    :param code_root: Root directory of the codebase to be logged. Defaults to the
        current directory (`.`).

    :param code_ext: File extensions of the code files to be logged. Defaults to
        `.py` and `.sh`.
    """

    # toml file holding the wandb API key (under a [wandb] section)
    api_key_config_file: str = ""

    # directory where the wandb logs are saved
    dir: Optional[str] = None

    # wandb project name
    project: str = ""

    # wandb user or team name
    entity: str = ""

    # wandb run name (leave blank to use the default name)
    name: str = ""

    # free-form notes attached to the run
    notes: str = ""

    # whether to log the git hash
    log_git_hash: bool = True

    # whether to log the code
    log_code: bool = True

    # root directory of the code to log
    code_root: str = "."

    # extensions of the code files to log; a factory gives each instance its own list
    code_ext: list[str] = field(default_factory=lambda: [".py", ".sh"])
69
+
70
+
71
def is_wandb_enabled():
    """Report whether `wandb.run` is currently set (i.e. is not None)."""
    current_run = wandb.run
    return current_run is not None
73
+
74
+
75
def init_wandb(args: WandbConfig, run_config: dict) -> Union[Run, RunDisabled, None]:
    """Log in to wandb, start a run, and log the configuration.

    The API key comes from the ``WANDB_API_KEY`` environment variable when it is
    set. Otherwise it is read from the ``[wandb]`` section of
    ``args.api_key_config_file``, whose optional ``project`` / ``entity`` entries
    take precedence over the values in ``args``. When neither source is available
    a warning is emitted and ``wandb.login`` is not called.

    :param args: WandbConfig object; it is read but not modified
    :param run_config: configuration dictionary to be logged
    :return: wandb run object returned by ``wandb.init``
    """
    project, entity = args.project, args.entity
    if "WANDB_API_KEY" in os.environ:
        warnings.warn("WANDB_API_KEY is found in environment variables. Using it.")
        wandb.login(key=os.environ["WANDB_API_KEY"])
    elif args.api_key_config_file:
        with open(args.api_key_config_file, "rb") as config_file:
            config_data = tomli.load(config_file)
        wandb_section = config_data["wandb"]
        wandb.login(key=wandb_section["key"])
        project = wandb_section.get("project", args.project)
        entity = wandb_section.get("entity", args.entity)
    else:
        warnings.warn("WANDB_API_KEY is not found in environment variables or the local key file.")

    # Build the notes in a local variable: writing back to `args.notes` would
    # append the git hash again every time the same config object is reused.
    notes = args.notes
    if args.log_git_hash:
        repo = git.Repo(search_parent_directories=True)
        git_hash = repo.head.object.hexsha
        notes = f"{notes}, git hash: {git_hash}" if notes else f"git hash: {git_hash}"

    wandb_run = wandb.init(
        dir=args.dir,
        project=project,
        entity=entity,
        name=args.name if args.name else None,
        notes=notes,
        config=run_config,
    )
    if not is_wandb_enabled():
        warnings.warn("wandb is not enabled after intialization. Please check `wandb enabled`.")

    # Respect the `log_code` switch, and skip the upload when no run object
    # was returned (the declared return type allows None).
    if args.log_code and wandb_run is not None:
        # str.endswith accepts a tuple of suffixes; an empty tuple matches nothing.
        code_suffixes = tuple(args.code_ext)
        wandb_run.log_code(
            root=args.code_root,
            include_fn=lambda path, root: path.endswith(code_suffixes),
        )
    return wandb_run
@@ -0,0 +1,12 @@
1
+ Metadata-Version: 2.4
2
+ Name: nntool
3
+ Version: 2.0.0rc0
4
+ Summary: NNTool is a package built on top of submitit designed to provide simple abstractions to conduct experiments on Slurm for machine learning research.
5
+ Author-email: Junhao Liu <junhaoliu17@gmail.com>
6
+ Project-URL: Homepage, https://github.com/jhliu17/nntool
7
+ Keywords: deep learning,neural network,research
8
+ Requires-Python: >=3.10
9
+ Requires-Dist: tyro
10
+ Requires-Dist: submitit>=1.5.4
11
+ Requires-Dist: wandb>=0.15.0
12
+ Requires-Dist: tomli>=2.0.1
@@ -0,0 +1,25 @@
1
+ nntool/__init__.py,sha256=ISEQYjns0E37EbFA__hp6bJJES5t_3Y_ZgtwZVKP8ko,117
2
+ nntool/build_backend.py,sha256=kGAtEbEyjxq8YHXNun95Dl6hQ0dI6o4VVdP3RAmL24E,759
3
+ nntool/version.py,sha256=GeySxzv0TjgCbYsd3GB6Xo-twF0ZgnVguEaHdn9wgmI,409
4
+ nntool/experiment/__init__.py,sha256=GIr3ehBQSTTiQyxn7WDiqna7SQVO5YvAG7WbMl7gg8s,239
5
+ nntool/experiment/config.py,sha256=6cjk9pBdKjs9-Syawbjw0RtgfQybnoJkRpxM-g993O0,3807
6
+ nntool/experiment/utils.py,sha256=ajrjYNltyNwNfmFciygCvbcRBtn8Gc7-dVR07Q8xbKo,2119
7
+ nntool/slurm/__init__.py,sha256=YHFIpTTVDR-xmj7xDmfsRfa25UNThO4qISVS01NtiPc,429
8
+ nntool/slurm/config.py,sha256=wDsdHhcc5tNR9nyLWG5SwcvKJoqo84dedz_3v0WcqdM,8260
9
+ nntool/slurm/function.py,sha256=_awI1g8hEaOQxBbRMmO7jJwqBXtYHPTaaZ9g1XK7JsM,8865
10
+ nntool/slurm/task.py,sha256=3rEgEXKL9q3VnDRtxRH6LzqaKjXDyiCOxe3aV6mgSeQ,11132
11
+ nntool/slurm/wrap.py,sha256=ZtHDqJD6Uk2rwumdiewz6Yl6tVtau5tZ5ZCo9VF_JJw,5291
12
+ nntool/slurm/accelerator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
+ nntool/slurm/accelerator/utils.py,sha256=VuGfPGjcanDXmDHtZh7U_jEx0k45aBtAJAsUJs_LoRc,1276
14
+ nntool/slurm/core/__init__.py,sha256=920Q5lxUFOpngHN0ZAXRjgb4GqPXSyJAhlnQk_hbKYM,79
15
+ nntool/slurm/core/_slurm.py,sha256=1tZ_TDFFsolsjTXbFsBTJUd7pBRSXfVbJrBrwUukfdE,24660
16
+ nntool/slurm/core/_slurm_context.py,sha256=oW9R9fwa5IAnA8LtbLkyn_wPf-54aDMRoF8vhyhUQN8,1912
17
+ nntool/slurm/parser/__init__.py,sha256=yxIckzha7nSlsdjPYKrP9ozneeQFtBdziTPJeVPXAFY,72
18
+ nntool/slurm/parser/parse.py,sha256=IqlcqQfd7GoeTGJkMOrb2_-x0s49Zz_2fA3Fsk_6ysQ,552
19
+ nntool/utils/__init__.py,sha256=YnGpWfnDUhxOR1-wx0-tjKOdHr9cP3_B242xoIT5Kww,219
20
+ nntool/wandb/__init__.py,sha256=lExtgvdu-E60vOyqlNIc7cFF565U4fGFDHFKrZyMxtg,97
21
+ nntool/wandb/config.py,sha256=TzHZmatV8bGl_kixWgJx0MI_DiXjscEESXlbYsokIos,3712
22
+ nntool-2.0.0rc0.dist-info/METADATA,sha256=czd8bC5BJGXcmEqouT27A9H5Vjqpnlvb0ECmsuvafn8,496
23
+ nntool-2.0.0rc0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
24
+ nntool-2.0.0rc0.dist-info/top_level.txt,sha256=hXO7zSAI-o9plYeUkCiKbkDpQCfTDIfjivyoSEqT28Y,7
25
+ nntool-2.0.0rc0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ nntool