nshtrainer 1.0.0b53__py3-none-any.whl → 1.0.0b55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -85,6 +85,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
85
85
  from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
86
86
  from nshtrainer.nn import PReLUConfig as PReLUConfig
87
87
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
88
+ from nshtrainer.nn import RNGConfig as RNGConfig
88
89
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
89
90
  from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
90
91
  from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
@@ -306,6 +307,7 @@ __all__ = [
306
307
  "ProfilerConfig",
307
308
  "PyTorchProfilerConfig",
308
309
  "RLPSanityChecksCallbackConfig",
310
+ "RNGConfig",
309
311
  "ReLUNonlinearityConfig",
310
312
  "ReduceLROnPlateauConfig",
311
313
  "SLURMEnvironmentPlugin",
@@ -12,6 +12,7 @@ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_r
12
12
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
13
13
  DurationConfig as DurationConfig,
14
14
  )
15
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
15
16
  from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
16
17
 
17
18
  from . import base as base
@@ -20,6 +21,7 @@ from . import reduce_lr_on_plateau as reduce_lr_on_plateau
20
21
 
21
22
  __all__ = [
22
23
  "DurationConfig",
24
+ "EpochsConfig",
23
25
  "LRSchedulerConfig",
24
26
  "LRSchedulerConfigBase",
25
27
  "LinearWarmupCosineDecayLRSchedulerConfig",
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
5
6
  from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
6
7
  LRSchedulerConfigBase as LRSchedulerConfigBase,
7
8
  )
@@ -14,6 +15,7 @@ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
14
15
  )
15
16
 
16
17
  __all__ = [
18
+ "EpochsConfig",
17
19
  "LRSchedulerConfigBase",
18
20
  "MetricConfig",
19
21
  "ReduceLROnPlateauConfig",
@@ -11,6 +11,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
11
11
  from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
12
12
  from nshtrainer.nn import PReLUConfig as PReLUConfig
13
13
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
14
+ from nshtrainer.nn import RNGConfig as RNGConfig
14
15
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
15
16
  from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
16
17
  from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
@@ -25,6 +26,7 @@ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_reg
25
26
 
26
27
  from . import mlp as mlp
27
28
  from . import nonlinearity as nonlinearity
29
+ from . import rng as rng
28
30
 
29
31
  __all__ = [
30
32
  "ELUNonlinearityConfig",
@@ -35,6 +37,7 @@ __all__ = [
35
37
  "NonlinearityConfig",
36
38
  "NonlinearityConfigBase",
37
39
  "PReLUConfig",
40
+ "RNGConfig",
38
41
  "ReLUNonlinearityConfig",
39
42
  "SiLUNonlinearityConfig",
40
43
  "SigmoidNonlinearityConfig",
@@ -47,4 +50,5 @@ __all__ = [
47
50
  "mlp",
48
51
  "nonlinearity",
49
52
  "nonlinearity_registry",
53
+ "rng",
50
54
  ]
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.nn.rng import RNGConfig as RNGConfig
6
+
7
+ __all__ = [
8
+ "RNGConfig",
9
+ ]
nshtrainer/nn/__init__.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  from .mlp import MLP as MLP
4
4
  from .mlp import MLPConfig as MLPConfig
5
5
  from .mlp import ResidualSequential as ResidualSequential
6
- from .mlp import custom_seed_context as custom_seed_context
7
6
  from .module_dict import TypedModuleDict as TypedModuleDict
8
7
  from .module_list import TypedModuleList as TypedModuleList
9
8
  from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
@@ -21,3 +20,5 @@ from .nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConf
21
20
  from .nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
22
21
  from .nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
23
22
  from .nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
23
+ from .rng import RNGConfig as RNGConfig
24
+ from .rng import rng_context as rng_context
nshtrainer/nn/rng.py ADDED
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+
5
+ import nshconfig as C
6
+ import torch
7
+
8
+
9
+ @contextlib.contextmanager
10
+ def rng_context(config: RNGConfig | None):
11
+ with contextlib.ExitStack() as stack:
12
+ if config is not None:
13
+ stack.enter_context(
14
+ torch.random.fork_rng(devices=range(torch.cuda.device_count()))
15
+ )
16
+ torch.manual_seed(config.seed)
17
+
18
+ yield
19
+
20
+
21
+ class RNGConfig(C.Config):
22
+ seed: int
23
+ """Random seed to use for initialization."""
@@ -457,7 +457,21 @@ class Trainer(LightningTrainer):
457
457
  ):
458
458
  filepath = Path(filepath)
459
459
 
460
- super().save_checkpoint(filepath, weights_only, storage_options)
460
+ if self.model is None:
461
+ raise AttributeError(
462
+ "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
463
+ " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
464
+ )
465
+ with self.profiler.profile("save_checkpoint"): # type: ignore
466
+ checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
467
+ # Update the checkpoint for the trainer hyperparameters
468
+ checkpoint[self.CHECKPOINT_HYPER_PARAMS_KEY] = self.hparams.model_dump(
469
+ mode="json"
470
+ )
471
+ self.strategy.save_checkpoint(
472
+ checkpoint, filepath, storage_options=storage_options
473
+ )
474
+ self.strategy.barrier("Trainer.save_checkpoint")
461
475
 
462
476
  # Save the checkpoint metadata
463
477
  metadata_path = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.0.0b53
3
+ Version: 1.0.0b55
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -32,7 +32,7 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
32
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
33
33
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
34
34
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
35
- nshtrainer/configs/__init__.py,sha256=0BzCgE1iEJ0Ywmy__mqJZipLQtwZVdz6XK-gHbkA7GY,14650
35
+ nshtrainer/configs/__init__.py,sha256=-rGk9pnRnuz4yKvACGOpY3nkrWnHholqZGk7UP2Vkrc,14716
36
36
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
37
37
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
38
38
  nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
@@ -67,15 +67,16 @@ nshtrainer/configs/loggers/base/__init__.py,sha256=HLUfEDbjaAXqzsFmQbjdciIWzR1st
67
67
  nshtrainer/configs/loggers/csv/__init__.py,sha256=gawaDX92JObGSmBqYpfNHWMHBwVOofS694W-1Y2GWDU,353
68
68
  nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=phzm-TnBkdkibTgoOxIIcAliqL3zU8gSNK61Mwxs1CM,410
69
69
  nshtrainer/configs/loggers/wandb/__init__.py,sha256=TDcD5WZSKenc2mgIXhwz2l96l8P_Ur3N5CzEol5AKGw,746
70
- nshtrainer/configs/lr_scheduler/__init__.py,sha256=xtiUx0isxA82-uXMn4-KmPnDCfbUkpAnd2_pFupAAKQ,1137
70
+ nshtrainer/configs/lr_scheduler/__init__.py,sha256=PvH2d8QEC3TsC3_svcUbxeQEMMzIK_In0_Bp9xntSms,1243
71
71
  nshtrainer/configs/lr_scheduler/base/__init__.py,sha256=6Cx8r4rdxeSYxc_z0o7drKCblGJU_zzqrOoYlWYR5qY,305
72
72
  nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=5ZMLDO9VL6SNU6pF-62lDnpmqix3_Ol9DdEwiuOPYlA,675
73
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=w-vq8UbRGPX8DZVWCMC5eIrbvVc_guxjj7Du9AaeKCw,609
73
+ nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=DBwZV590I5qwyOS5M43YhUzgYy1-AjzkM5aEnTA6XdI,715
74
74
  nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
75
75
  nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
76
- nshtrainer/configs/nn/__init__.py,sha256=tkFG2Hb0oL_AmWP3_0WkDN2zI5PkVfrgwXhaAII7CZw,2072
76
+ nshtrainer/configs/nn/__init__.py,sha256=Ms2gIqbRxNVm6GHKCddCJTTqMwUPifjjHD_fCfJpcMQ,2174
77
77
  nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
78
78
  nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
79
+ nshtrainer/configs/nn/rng/__init__.py,sha256=4iC6vwxbfNeXyvpwZ1Z5Kcy-he4cu7mg3UpLD-RLrHc,141
79
80
  nshtrainer/configs/optimizer/__init__.py,sha256=itIDIHQvGm50eZ7JLyNElahnNUMPJ__4PMmTjc0RQ6o,444
80
81
  nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
81
82
  nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
@@ -119,11 +120,12 @@ nshtrainer/model/base.py,sha256=bZMNap0rkxRbAbu2BOHV_6YS2iZZnvy6wVSMOXGa_ZM,8680
119
120
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
120
121
  nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
121
122
  nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
122
- nshtrainer/nn/__init__.py,sha256=5Gg3nieGSC5_dXaI9KUVUUbM13hHexH9831m4hcf6no,1475
123
+ nshtrainer/nn/__init__.py,sha256=Vd246v2N9tBQ8XxmTquWzj5lAmeSnngrjpYOfp4LTXM,1499
123
124
  nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
124
125
  nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
125
126
  nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
126
127
  nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
128
+ nshtrainer/nn/rng.py,sha256=IJGvX9v8qBkfgBrMlNU2aj-MbYTPoncFyJzvPkzCQpM,512
127
129
  nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
128
130
  nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
129
131
  nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -142,7 +144,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
142
144
  nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
143
145
  nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
144
146
  nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
145
- nshtrainer/trainer/trainer.py,sha256=QEK-0bcw1y5Cconi99PYFXr0MElUGgGYMZ_SlcJUQ1k,20364
147
+ nshtrainer/trainer/trainer.py,sha256=Lo3vUo3ooTAjaX2fUYPFSMv5FP7sWfVov0QbA-T5hZ8,21113
146
148
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
147
149
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
148
150
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -154,6 +156,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
154
156
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
155
157
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
156
158
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
157
- nshtrainer-1.0.0b53.dist-info/METADATA,sha256=xzPw_zTWp8gLpzKhvcAHTv8KbZCvxwMU_MIzMiq8j78,988
158
- nshtrainer-1.0.0b53.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
159
- nshtrainer-1.0.0b53.dist-info/RECORD,,
159
+ nshtrainer-1.0.0b55.dist-info/METADATA,sha256=3JHfGqw8kR8FKZT-E--W3H_rS5M-QQHH9U2EooKyO70,988
160
+ nshtrainer-1.0.0b55.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
161
+ nshtrainer-1.0.0b55.dist-info/RECORD,,