nshtrainer-0.39.0-py3-none-any.whl → nshtrainer-0.40.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -0
- nshtrainer/data/__init__.py +1 -0
- nshtrainer/data/datamodule.py +5 -0
- nshtrainer/nn/module_list.py +2 -2
- nshtrainer/runner.py +44 -63
- {nshtrainer-0.39.0.dist-info → nshtrainer-0.40.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.39.0.dist-info → nshtrainer-0.40.0.dist-info}/RECORD +8 -7
- {nshtrainer-0.39.0.dist-info → nshtrainer-0.40.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ from . import model as model
|
|
|
8
8
|
from . import nn as nn
|
|
9
9
|
from . import optimizer as optimizer
|
|
10
10
|
from . import profiler as profiler
|
|
11
|
+
from .data import LightningDataModuleBase as LightningDataModuleBase
|
|
11
12
|
from .metrics import MetricConfig as MetricConfig
|
|
12
13
|
from .model import BaseConfig as BaseConfig
|
|
13
14
|
from .model import LightningModuleBase as LightningModuleBase
|
nshtrainer/data/__init__.py
CHANGED
nshtrainer/nn/module_list.py
CHANGED
|
@@ -12,10 +12,10 @@ class TypedModuleList(nn.ModuleList, Generic[TModule]):
|
|
|
12
12
|
super().__init__(modules)
|
|
13
13
|
|
|
14
14
|
@overload
|
|
15
|
-
def __getitem__(self, idx:
|
|
15
|
+
def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...
|
|
16
16
|
|
|
17
17
|
@overload
|
|
18
|
-
def __getitem__(self, idx:
|
|
18
|
+
def __getitem__(self, idx: int) -> TModule: ...
|
|
19
19
|
|
|
20
20
|
@override
|
|
21
21
|
def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]":
|
nshtrainer/runner.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
import copy
|
|
2
|
-
import
|
|
2
|
+
import logging
|
|
3
3
|
from collections.abc import Callable, Iterable, Mapping, Sequence
|
|
4
|
+
from pathlib import Path
|
|
4
5
|
from typing import Generic
|
|
5
6
|
|
|
6
|
-
|
|
7
|
-
from nshrunner import Runner as _Runner
|
|
7
|
+
import nshrunner as nr
|
|
8
8
|
from nshrunner._submit import screen
|
|
9
|
-
from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
|
|
9
|
+
from typing_extensions import TypeVar, TypeVarTuple, Unpack, deprecated, override
|
|
10
10
|
|
|
11
11
|
from .model.config import BaseConfig
|
|
12
12
|
|
|
@@ -15,34 +15,27 @@ TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
|
|
|
15
15
|
TReturn = TypeVar("TReturn", infer_variance=True)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
@deprecated("Use nshrunner.Runner instead.")
|
|
18
19
|
class Runner(
|
|
19
|
-
|
|
20
|
+
nr.Runner[TReturn, TConfig, Unpack[TArguments]],
|
|
20
21
|
Generic[TReturn, TConfig, Unpack[TArguments]],
|
|
21
22
|
):
|
|
22
23
|
@override
|
|
23
|
-
|
|
24
|
-
def default_validate_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> None:
|
|
25
|
-
super().default_validate_fn(config, *args)
|
|
26
|
-
|
|
27
|
-
@override
|
|
28
|
-
@classmethod
|
|
29
|
-
def default_info_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> RunInfo:
|
|
30
|
-
run_info = super().default_info_fn(config, *args)
|
|
31
|
-
return {
|
|
32
|
-
**run_info,
|
|
33
|
-
"id": config.id,
|
|
34
|
-
"base_dir": config.directory.project_root,
|
|
35
|
-
}
|
|
36
|
-
|
|
37
|
-
def _fast_dev_run_transform(
|
|
24
|
+
def __init__(
|
|
38
25
|
self,
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
n_batches: int,
|
|
26
|
+
run_fn: Callable[[TConfig, Unpack[TArguments]], TReturn],
|
|
27
|
+
config: nr.RunnerConfig | None = None,
|
|
42
28
|
):
|
|
43
|
-
config
|
|
44
|
-
|
|
45
|
-
|
|
29
|
+
if config is None:
|
|
30
|
+
working_dir = Path.cwd() / "nshrunner"
|
|
31
|
+
working_dir.mkdir(exist_ok=True)
|
|
32
|
+
|
|
33
|
+
logging.warning(
|
|
34
|
+
f"`config` is not provided. Using default working directory of {working_dir}."
|
|
35
|
+
)
|
|
36
|
+
config = nr.RunnerConfig(working_dir=working_dir)
|
|
37
|
+
|
|
38
|
+
super().__init__(run_fn, config)
|
|
46
39
|
|
|
47
40
|
def fast_dev_run(
|
|
48
41
|
self,
|
|
@@ -50,20 +43,15 @@ class Runner(
|
|
|
50
43
|
n_batches: int = 1,
|
|
51
44
|
*,
|
|
52
45
|
env: Mapping[str, str] | None = None,
|
|
53
|
-
transforms: list[
|
|
54
|
-
Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
|
|
55
|
-
]
|
|
56
|
-
| None = None,
|
|
57
46
|
):
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
)
|
|
47
|
+
runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
|
|
48
|
+
for args in runs:
|
|
49
|
+
config = copy.deepcopy(args[0])
|
|
50
|
+
config.trainer.fast_dev_run = n_batches
|
|
51
|
+
runs_updated.append((config, *args[1:]))
|
|
52
|
+
del runs
|
|
53
|
+
|
|
54
|
+
return self.local(runs_updated, env=env)
|
|
67
55
|
|
|
68
56
|
def fast_dev_run_generator(
|
|
69
57
|
self,
|
|
@@ -71,20 +59,15 @@ class Runner(
|
|
|
71
59
|
n_batches: int = 1,
|
|
72
60
|
*,
|
|
73
61
|
env: Mapping[str, str] | None = None,
|
|
74
|
-
transforms: list[
|
|
75
|
-
Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
|
|
76
|
-
]
|
|
77
|
-
| None = None,
|
|
78
62
|
):
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
)
|
|
63
|
+
runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
|
|
64
|
+
for args in runs:
|
|
65
|
+
config = copy.deepcopy(args[0])
|
|
66
|
+
config.trainer.fast_dev_run = n_batches
|
|
67
|
+
runs_updated.append((config, *args[1:]))
|
|
68
|
+
del runs
|
|
69
|
+
|
|
70
|
+
return self.local_generator(runs_updated, env=env)
|
|
88
71
|
|
|
89
72
|
def fast_dev_run_session(
|
|
90
73
|
self,
|
|
@@ -92,27 +75,25 @@ class Runner(
|
|
|
92
75
|
options: screen.ScreenJobKwargs = {},
|
|
93
76
|
n_batches: int = 1,
|
|
94
77
|
*,
|
|
95
|
-
snapshot: Snapshot,
|
|
78
|
+
snapshot: nr.Snapshot,
|
|
96
79
|
setup_commands: Sequence[str] | None = None,
|
|
97
80
|
env: Mapping[str, str] | None = None,
|
|
98
|
-
transforms: list[
|
|
99
|
-
Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
|
|
100
|
-
]
|
|
101
|
-
| None = None,
|
|
102
81
|
activate_venv: bool = True,
|
|
103
82
|
print_command: bool = True,
|
|
104
83
|
):
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
84
|
+
runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
|
|
85
|
+
for args in runs:
|
|
86
|
+
config = copy.deepcopy(args[0])
|
|
87
|
+
config.trainer.fast_dev_run = n_batches
|
|
88
|
+
runs_updated.append((config, *args[1:]))
|
|
89
|
+
del runs
|
|
90
|
+
|
|
109
91
|
return self.session(
|
|
110
|
-
|
|
92
|
+
runs_updated,
|
|
111
93
|
options,
|
|
112
94
|
snapshot=snapshot,
|
|
113
95
|
setup_commands=setup_commands,
|
|
114
96
|
env=env,
|
|
115
|
-
transforms=transforms,
|
|
116
97
|
activate_venv=activate_venv,
|
|
117
98
|
print_command=print_command,
|
|
118
99
|
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
nshtrainer/__init__.py,sha256=
|
|
1
|
+
nshtrainer/__init__.py,sha256=8hx3uqzroMYOQM7S4U3Eznw7MHF3YDqUkO1MDQeC4cM,642
|
|
2
2
|
nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
|
|
3
3
|
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
|
|
@@ -32,8 +32,9 @@ nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50
|
|
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=OWG4UkL2SfW6oj6AGRXeBJsZmgsqeHLW2Fj8Jm4ga3I,2298
|
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
|
|
34
34
|
nshtrainer/config.py,sha256=pZyRZOkBRR7eBFRiHpHjQFNEFjaX9tYZIAqZtvKi6cA,8312
|
|
35
|
-
nshtrainer/data/__init__.py,sha256=
|
|
35
|
+
nshtrainer/data/__init__.py,sha256=fnigdEtNcB8DJqz4MHvcI6UwOZICFzWItskY78lLTMA,224
|
|
36
36
|
nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
|
|
37
|
+
nshtrainer/data/datamodule.py,sha256=Dk06LdcbK13sCLQWacdCoXU_CHWZxoKEZ_xpAg7nZQg,113
|
|
37
38
|
nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
|
|
38
39
|
nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
|
|
39
40
|
nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
|
|
@@ -71,7 +72,7 @@ nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlot
|
|
|
71
72
|
nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
|
|
72
73
|
nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
|
|
73
74
|
nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
|
|
74
|
-
nshtrainer/nn/module_list.py,sha256=
|
|
75
|
+
nshtrainer/nn/module_list.py,sha256=bbR62bnQM4No6HtmBllijOhLwDtWwmcPcO2DklPbsbc,1719
|
|
75
76
|
nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
|
|
76
77
|
nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
|
|
77
78
|
nshtrainer/profiler/__init__.py,sha256=cJ_wAm8j3Bz6cKgNQ_9gQLZz9nddKW53VE81UKg8l8g,480
|
|
@@ -79,7 +80,7 @@ nshtrainer/profiler/_base.py,sha256=YF5lsJBIl9qts9GLW5Z62JuYeo4SnIArhlFwTGkfTb4,
|
|
|
79
80
|
nshtrainer/profiler/advanced.py,sha256=44asloha0aGUW8YwjQt3lm3ve8H-N6mM4QgseUSLT30,1112
|
|
80
81
|
nshtrainer/profiler/pytorch.py,sha256=tGeRvoPP5ulWX2RkfXrQvMBoki1T95dpz5p8mwyon1I,2709
|
|
81
82
|
nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y,957
|
|
82
|
-
nshtrainer/runner.py,sha256=
|
|
83
|
+
nshtrainer/runner.py,sha256=llW5PThXfzlu8ym1waFVzED1RGI91OHqN3j02sepoZo,3153
|
|
83
84
|
nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
|
|
84
85
|
nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
|
|
85
86
|
nshtrainer/trainer/_config.py,sha256=YqpGb4RodkUg87TVE5WBSc4CQkUF0z3qDRdil1HRxoM,29198
|
|
@@ -99,6 +100,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
99
100
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
100
101
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
101
102
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
102
|
-
nshtrainer-0.
|
|
103
|
-
nshtrainer-0.
|
|
104
|
-
nshtrainer-0.
|
|
103
|
+
nshtrainer-0.40.0.dist-info/METADATA,sha256=E8dGtOSZ6jdCxugNpIutAeUyd7CcYSRH9cBchBDY7UE,916
|
|
104
|
+
nshtrainer-0.40.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
105
|
+
nshtrainer-0.40.0.dist-info/RECORD,,
|
|
File without changes
|