nshtrainer 1.0.0b53__py3-none-any.whl → 1.0.0b55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/configs/__init__.py +2 -0
- nshtrainer/configs/lr_scheduler/__init__.py +2 -0
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +2 -0
- nshtrainer/configs/nn/__init__.py +4 -0
- nshtrainer/configs/nn/rng/__init__.py +9 -0
- nshtrainer/nn/__init__.py +2 -1
- nshtrainer/nn/rng.py +23 -0
- nshtrainer/trainer/trainer.py +15 -1
- {nshtrainer-1.0.0b53.dist-info → nshtrainer-1.0.0b55.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b53.dist-info → nshtrainer-1.0.0b55.dist-info}/RECORD +11 -9
- {nshtrainer-1.0.0b53.dist-info → nshtrainer-1.0.0b55.dist-info}/WHEEL +0 -0
nshtrainer/configs/__init__.py
CHANGED
@@ -85,6 +85,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
|
85
85
|
from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
|
86
86
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
87
87
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
88
|
+
from nshtrainer.nn import RNGConfig as RNGConfig
|
88
89
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
89
90
|
from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
|
90
91
|
from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
|
@@ -306,6 +307,7 @@ __all__ = [
|
|
306
307
|
"ProfilerConfig",
|
307
308
|
"PyTorchProfilerConfig",
|
308
309
|
"RLPSanityChecksCallbackConfig",
|
310
|
+
"RNGConfig",
|
309
311
|
"ReLUNonlinearityConfig",
|
310
312
|
"ReduceLROnPlateauConfig",
|
311
313
|
"SLURMEnvironmentPlugin",
|
@@ -12,6 +12,7 @@ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_r
|
|
12
12
|
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
13
13
|
DurationConfig as DurationConfig,
|
14
14
|
)
|
15
|
+
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
|
15
16
|
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
|
16
17
|
|
17
18
|
from . import base as base
|
@@ -20,6 +21,7 @@ from . import reduce_lr_on_plateau as reduce_lr_on_plateau
|
|
20
21
|
|
21
22
|
__all__ = [
|
22
23
|
"DurationConfig",
|
24
|
+
"EpochsConfig",
|
23
25
|
"LRSchedulerConfig",
|
24
26
|
"LRSchedulerConfigBase",
|
25
27
|
"LinearWarmupCosineDecayLRSchedulerConfig",
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__codegen__ = True
|
4
4
|
|
5
|
+
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
|
5
6
|
from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
|
6
7
|
LRSchedulerConfigBase as LRSchedulerConfigBase,
|
7
8
|
)
|
@@ -14,6 +15,7 @@ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
|
|
14
15
|
)
|
15
16
|
|
16
17
|
__all__ = [
|
18
|
+
"EpochsConfig",
|
17
19
|
"LRSchedulerConfigBase",
|
18
20
|
"MetricConfig",
|
19
21
|
"ReduceLROnPlateauConfig",
|
@@ -11,6 +11,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
|
11
11
|
from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
|
12
12
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
13
13
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
14
|
+
from nshtrainer.nn import RNGConfig as RNGConfig
|
14
15
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
15
16
|
from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
|
16
17
|
from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
|
@@ -25,6 +26,7 @@ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_reg
|
|
25
26
|
|
26
27
|
from . import mlp as mlp
|
27
28
|
from . import nonlinearity as nonlinearity
|
29
|
+
from . import rng as rng
|
28
30
|
|
29
31
|
__all__ = [
|
30
32
|
"ELUNonlinearityConfig",
|
@@ -35,6 +37,7 @@ __all__ = [
|
|
35
37
|
"NonlinearityConfig",
|
36
38
|
"NonlinearityConfigBase",
|
37
39
|
"PReLUConfig",
|
40
|
+
"RNGConfig",
|
38
41
|
"ReLUNonlinearityConfig",
|
39
42
|
"SiLUNonlinearityConfig",
|
40
43
|
"SigmoidNonlinearityConfig",
|
@@ -47,4 +50,5 @@ __all__ = [
|
|
47
50
|
"mlp",
|
48
51
|
"nonlinearity",
|
49
52
|
"nonlinearity_registry",
|
53
|
+
"rng",
|
50
54
|
]
|
nshtrainer/nn/__init__.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
3
3
|
from .mlp import MLP as MLP
|
4
4
|
from .mlp import MLPConfig as MLPConfig
|
5
5
|
from .mlp import ResidualSequential as ResidualSequential
|
6
|
-
from .mlp import custom_seed_context as custom_seed_context
|
7
6
|
from .module_dict import TypedModuleDict as TypedModuleDict
|
8
7
|
from .module_list import TypedModuleList as TypedModuleList
|
9
8
|
from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
|
@@ -21,3 +20,5 @@ from .nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConf
|
|
21
20
|
from .nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
|
22
21
|
from .nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
|
23
22
|
from .nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
|
23
|
+
from .rng import RNGConfig as RNGConfig
|
24
|
+
from .rng import rng_context as rng_context
|
nshtrainer/nn/rng.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
|
5
|
+
import nshconfig as C
|
6
|
+
import torch
|
7
|
+
|
8
|
+
|
9
|
+
@contextlib.contextmanager
def rng_context(config: RNGConfig | None):
    """Run a ``with`` block under a temporarily reseeded PyTorch global RNG.

    Args:
        config: Seed settings. ``None`` makes the context a no-op: the body
            runs against (and may advance) the current global RNG state.

    When ``config`` is given, the CPU generator and the generator of every
    visible CUDA device are snapshotted with ``torch.random.fork_rng``. All
    generators are then reseeded via ``torch.manual_seed(config.seed)``. The
    snapshot is restored when the block exits, whether normally or through an
    exception, so code outside the block is unaffected.

    NOTE(review): ``fork_rng`` is only asked to save CPU + CUDA state, while
    ``torch.manual_seed`` may also reseed other backends (e.g. MPS); such
    backends would keep the new seed after exit. TODO confirm whether that
    matters for supported platforms.
    """
    if config is None:
        # Nothing to isolate: just run the body with the ambient RNG state.
        yield
        return

    # Snapshot every CUDA device (empty range on CPU-only machines) so the
    # reseed below does not leak past the end of the block.
    cuda_devices = range(torch.cuda.device_count())
    with torch.random.fork_rng(devices=cuda_devices):
        torch.manual_seed(config.seed)
        yield
|
19
|
+
|
20
|
+
|
21
|
+
class RNGConfig(C.Config):
    """Configuration for a seeded RNG scope.

    ``rng_context`` in this module reads ``seed`` from this config and passes
    it to ``torch.manual_seed`` inside a forked RNG state.
    """

    # Integer seed passed to torch.manual_seed by rng_context. The attribute
    # docstring below is presumably picked up by nshconfig as the field
    # description — confirm before changing its form.
    seed: int
    """Random seed to use for initialization."""
|
nshtrainer/trainer/trainer.py
CHANGED
@@ -457,7 +457,21 @@ class Trainer(LightningTrainer):
|
|
457
457
|
):
|
458
458
|
filepath = Path(filepath)
|
459
459
|
|
460
|
-
|
460
|
+
if self.model is None:
|
461
|
+
raise AttributeError(
|
462
|
+
"Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
|
463
|
+
" `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
|
464
|
+
)
|
465
|
+
with self.profiler.profile("save_checkpoint"): # type: ignore
|
466
|
+
checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
|
467
|
+
# Update the checkpoint for the trainer hyperparameters
|
468
|
+
checkpoint[self.CHECKPOINT_HYPER_PARAMS_KEY] = self.hparams.model_dump(
|
469
|
+
mode="json"
|
470
|
+
)
|
471
|
+
self.strategy.save_checkpoint(
|
472
|
+
checkpoint, filepath, storage_options=storage_options
|
473
|
+
)
|
474
|
+
self.strategy.barrier("Trainer.save_checkpoint")
|
461
475
|
|
462
476
|
# Save the checkpoint metadata
|
463
477
|
metadata_path = None
|
@@ -32,7 +32,7 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
|
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
|
34
34
|
nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
|
35
|
-
nshtrainer/configs/__init__.py,sha256
|
35
|
+
nshtrainer/configs/__init__.py,sha256=-rGk9pnRnuz4yKvACGOpY3nkrWnHholqZGk7UP2Vkrc,14716
|
36
36
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
37
37
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
38
38
|
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
@@ -67,15 +67,16 @@ nshtrainer/configs/loggers/base/__init__.py,sha256=HLUfEDbjaAXqzsFmQbjdciIWzR1st
|
|
67
67
|
nshtrainer/configs/loggers/csv/__init__.py,sha256=gawaDX92JObGSmBqYpfNHWMHBwVOofS694W-1Y2GWDU,353
|
68
68
|
nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=phzm-TnBkdkibTgoOxIIcAliqL3zU8gSNK61Mwxs1CM,410
|
69
69
|
nshtrainer/configs/loggers/wandb/__init__.py,sha256=TDcD5WZSKenc2mgIXhwz2l96l8P_Ur3N5CzEol5AKGw,746
|
70
|
-
nshtrainer/configs/lr_scheduler/__init__.py,sha256=
|
70
|
+
nshtrainer/configs/lr_scheduler/__init__.py,sha256=PvH2d8QEC3TsC3_svcUbxeQEMMzIK_In0_Bp9xntSms,1243
|
71
71
|
nshtrainer/configs/lr_scheduler/base/__init__.py,sha256=6Cx8r4rdxeSYxc_z0o7drKCblGJU_zzqrOoYlWYR5qY,305
|
72
72
|
nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=5ZMLDO9VL6SNU6pF-62lDnpmqix3_Ol9DdEwiuOPYlA,675
|
73
|
-
nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=
|
73
|
+
nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=DBwZV590I5qwyOS5M43YhUzgYy1-AjzkM5aEnTA6XdI,715
|
74
74
|
nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
|
75
75
|
nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
|
76
|
-
nshtrainer/configs/nn/__init__.py,sha256=
|
76
|
+
nshtrainer/configs/nn/__init__.py,sha256=Ms2gIqbRxNVm6GHKCddCJTTqMwUPifjjHD_fCfJpcMQ,2174
|
77
77
|
nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
|
78
78
|
nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
|
79
|
+
nshtrainer/configs/nn/rng/__init__.py,sha256=4iC6vwxbfNeXyvpwZ1Z5Kcy-he4cu7mg3UpLD-RLrHc,141
|
79
80
|
nshtrainer/configs/optimizer/__init__.py,sha256=itIDIHQvGm50eZ7JLyNElahnNUMPJ__4PMmTjc0RQ6o,444
|
80
81
|
nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
|
81
82
|
nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
|
@@ -119,11 +120,12 @@ nshtrainer/model/base.py,sha256=bZMNap0rkxRbAbu2BOHV_6YS2iZZnvy6wVSMOXGa_ZM,8680
|
|
119
120
|
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
120
121
|
nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
|
121
122
|
nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
|
122
|
-
nshtrainer/nn/__init__.py,sha256=
|
123
|
+
nshtrainer/nn/__init__.py,sha256=Vd246v2N9tBQ8XxmTquWzj5lAmeSnngrjpYOfp4LTXM,1499
|
123
124
|
nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
|
124
125
|
nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
|
125
126
|
nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
|
126
127
|
nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
|
128
|
+
nshtrainer/nn/rng.py,sha256=IJGvX9v8qBkfgBrMlNU2aj-MbYTPoncFyJzvPkzCQpM,512
|
127
129
|
nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
|
128
130
|
nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
|
129
131
|
nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
|
@@ -142,7 +144,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
|
|
142
144
|
nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
|
143
145
|
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
144
146
|
nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
|
145
|
-
nshtrainer/trainer/trainer.py,sha256=
|
147
|
+
nshtrainer/trainer/trainer.py,sha256=Lo3vUo3ooTAjaX2fUYPFSMv5FP7sWfVov0QbA-T5hZ8,21113
|
146
148
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
147
149
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
148
150
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
@@ -154,6 +156,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
154
156
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
155
157
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
156
158
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
157
|
-
nshtrainer-1.0.
|
158
|
-
nshtrainer-1.0.
|
159
|
-
nshtrainer-1.0.
|
159
|
+
nshtrainer-1.0.0b55.dist-info/METADATA,sha256=3JHfGqw8kR8FKZT-E--W3H_rS5M-QQHH9U2EooKyO70,988
|
160
|
+
nshtrainer-1.0.0b55.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
161
|
+
nshtrainer-1.0.0b55.dist-info/RECORD,,
|
File without changes
|