nshtrainer-1.0.0b54-py3-none-any.whl → nshtrainer-1.0.0b56-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry.
nshtrainer/configs/__init__.py CHANGED
@@ -85,6 +85,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
 from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
 from nshtrainer.nn import PReLUConfig as PReLUConfig
 from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+from nshtrainer.nn import RNGConfig as RNGConfig
 from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
 from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
 from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
@@ -306,6 +307,7 @@ __all__ = [
     "ProfilerConfig",
     "PyTorchProfilerConfig",
     "RLPSanityChecksCallbackConfig",
+    "RNGConfig",
     "ReLUNonlinearityConfig",
     "ReduceLROnPlateauConfig",
     "SLURMEnvironmentPlugin",
nshtrainer/configs/lr_scheduler/__init__.py CHANGED
@@ -12,6 +12,7 @@ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_registry
 from nshtrainer.lr_scheduler.linear_warmup_cosine import (
     DurationConfig as DurationConfig,
 )
+from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
 from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
 
 from . import base as base
@@ -20,6 +21,7 @@ from . import reduce_lr_on_plateau as reduce_lr_on_plateau
 
 __all__ = [
     "DurationConfig",
+    "EpochsConfig",
     "LRSchedulerConfig",
     "LRSchedulerConfigBase",
     "LinearWarmupCosineDecayLRSchedulerConfig",
nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 __codegen__ = True
 
+from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
 from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
     LRSchedulerConfigBase as LRSchedulerConfigBase,
 )
@@ -14,6 +15,7 @@ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
 )
 
 __all__ = [
+    "EpochsConfig",
    "LRSchedulerConfigBase",
     "MetricConfig",
     "ReduceLROnPlateauConfig",
nshtrainer/configs/nn/__init__.py CHANGED
@@ -11,6 +11,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
 from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
 from nshtrainer.nn import PReLUConfig as PReLUConfig
 from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+from nshtrainer.nn import RNGConfig as RNGConfig
 from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
 from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
 from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
@@ -25,6 +26,7 @@ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
 
 from . import mlp as mlp
 from . import nonlinearity as nonlinearity
+from . import rng as rng
 
 __all__ = [
     "ELUNonlinearityConfig",
@@ -35,6 +37,7 @@ __all__ = [
     "NonlinearityConfig",
     "NonlinearityConfigBase",
     "PReLUConfig",
+    "RNGConfig",
     "ReLUNonlinearityConfig",
     "SiLUNonlinearityConfig",
     "SigmoidNonlinearityConfig",
@@ -47,4 +50,5 @@ __all__ = [
     "mlp",
     "nonlinearity",
     "nonlinearity_registry",
+    "rng",
 ]
nshtrainer/configs/nn/rng/__init__.py ADDED
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.nn.rng import RNGConfig as RNGConfig
+
+__all__ = [
+    "RNGConfig",
+]
nshtrainer/configs/trainer/plugin/base/__init__.py CHANGED
@@ -2,12 +2,10 @@ from __future__ import annotations
 
 __codegen__ = True
 
-from nshtrainer.trainer.plugin.base import PluginConfig as PluginConfig
 from nshtrainer.trainer.plugin.base import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.plugin.base import plugin_registry as plugin_registry
 
 __all__ = [
-    "PluginConfig",
     "PluginConfigBase",
     "plugin_registry",
 ]
nshtrainer/configs/trainer/plugin/environment/__init__.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 __codegen__ = True
 
-from nshtrainer.trainer.plugin.environment import DTypeConfig as DTypeConfig
 from nshtrainer.trainer.plugin.environment import (
     KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
 )
@@ -28,7 +27,6 @@ from nshtrainer.trainer.plugin.environment import (
 from nshtrainer.trainer.plugin.environment import plugin_registry as plugin_registry
 
 __all__ = [
-    "DTypeConfig",
     "KubeflowEnvironmentPlugin",
     "LSFEnvironmentPlugin",
     "LightningEnvironmentPlugin",
nshtrainer/configs/trainer/plugin/io/__init__.py CHANGED
@@ -5,7 +5,6 @@ __codegen__ = True
 from nshtrainer.trainer.plugin.io import (
     AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
 )
-from nshtrainer.trainer.plugin.io import PluginConfig as PluginConfig
 from nshtrainer.trainer.plugin.io import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.plugin.io import (
     TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
@@ -15,7 +14,6 @@ from nshtrainer.trainer.plugin.io import plugin_registry as plugin_registry
 
 __all__ = [
     "AsyncCheckpointIOPlugin",
-    "PluginConfig",
     "PluginConfigBase",
     "TorchCheckpointIOPlugin",
     "XLACheckpointIOPlugin",
nshtrainer/nn/__init__.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
 from .mlp import MLP as MLP
 from .mlp import MLPConfig as MLPConfig
 from .mlp import ResidualSequential as ResidualSequential
-from .mlp import custom_seed_context as custom_seed_context
 from .module_dict import TypedModuleDict as TypedModuleDict
 from .module_list import TypedModuleList as TypedModuleList
 from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
@@ -21,3 +20,5 @@ from .nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
 from .nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
 from .nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
 from .nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
+from .rng import RNGConfig as RNGConfig
+from .rng import rng_context as rng_context
nshtrainer/nn/rng.py ADDED
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import contextlib
+
+import nshconfig as C
+import torch
+
+
+@contextlib.contextmanager
+def rng_context(config: RNGConfig | None):
+    with contextlib.ExitStack() as stack:
+        if config is not None:
+            stack.enter_context(
+                torch.random.fork_rng(devices=range(torch.cuda.device_count()))
+            )
+            torch.manual_seed(config.seed)
+
+        yield
+
+
+class RNGConfig(C.Config):
+    seed: int
+    """Random seed to use for initialization."""
nshtrainer/trainer/plugin/__init__.py CHANGED
@@ -1,10 +1,18 @@
 from __future__ import annotations
 
+from typing import Annotated
+
+from typing_extensions import TypeAliasType
+
 from . import environment as environment
 from . import io as io
 from . import layer_sync as layer_sync
 from . import precision as precision
 from .base import Plugin as Plugin
-from .base import PluginConfig as PluginConfig
 from .base import PluginConfigBase as PluginConfigBase
 from .base import plugin_registry as plugin_registry
+
+PluginConfig = TypeAliasType(
+    "PluginConfig",
+    Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
+)
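
PluginConfig therefore remains importable from nshtrainer.trainer.plugin; only its definition moved out of base.py. A short consumption sketch (field values are illustrative): any registered PluginConfigBase subclass validates as a PluginConfig, discriminated by its name field.

from nshtrainer.trainer.plugin import PluginConfig
from nshtrainer.trainer.plugin.environment import SLURMEnvironmentPlugin

config: PluginConfig = SLURMEnvironmentPlugin(auto_requeue=False)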
nshtrainer/trainer/plugin/base.py CHANGED
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-import logging
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Annotated
+from typing import TYPE_CHECKING
 
 import nshconfig as C
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
@@ -12,7 +11,6 @@ from typing_extensions import TypeAliasType
 if TYPE_CHECKING:
     from .._config import TrainerConfig
 
-log = logging.getLogger(__name__)
 
 
 Plugin = TypeAliasType(
@@ -26,8 +24,3 @@ class PluginConfigBase(C.Config, ABC):
 
 
 plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
-
-PluginConfig = TypeAliasType(
-    "PluginConfig",
-    Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
-)
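
With the alias gone, base.py now defines only the registry machinery. A sketch of registering a custom plugin config against it, following the same pattern the built-in plugins below use (the plugin name and body here are hypothetical):

from typing import Literal

from typing_extensions import override

from nshtrainer.trainer.plugin.base import PluginConfigBase, plugin_registry


@plugin_registry.register
class MyEnvironmentPlugin(PluginConfigBase):
    """Hypothetical user-defined environment plugin."""

    name: Literal["my_environment"] = "my_environment"

    @override
    def create_plugin(self, trainer_config):
        from lightning.fabric.plugins.environments.lightning import LightningEnvironment

        return LightningEnvironment()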
nshtrainer/trainer/plugin/environment.py CHANGED
@@ -1,27 +1,25 @@
 from __future__ import annotations
 
 import signal
-from typing import Any, Literal
+from typing import Literal
 
-from lightning.pytorch.plugins.environments import ClusterEnvironment
-from typing_extensions import override
+from typing_extensions import TypeAliasType, override
 
-from ...util.config.dtype import DTypeConfig
 from .base import PluginConfigBase, plugin_registry
 
 
 @plugin_registry.register
 class KubeflowEnvironmentPlugin(PluginConfigBase):
-    name: Literal["kubeflow_environment"] = "kubeflow_environment"
-
     """Environment for distributed training using the PyTorchJob operator from Kubeflow.
 
     This environment, unlike others, does not get auto-detected and needs to be passed
     to the Fabric/Trainer constructor manually.
     """
 
+    name: Literal["kubeflow_environment"] = "kubeflow_environment"
+
     @override
-    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment
 
         return KubeflowEnvironment()
@@ -29,8 +27,6 @@ class KubeflowEnvironmentPlugin(PluginConfigBase):
 
 @plugin_registry.register
 class LightningEnvironmentPlugin(PluginConfigBase):
-    name: Literal["lightning_environment"] = "lightning_environment"
-
     """The default environment used by Lightning for a single node or free cluster (not managed).
 
     There are two modes the Lightning environment can operate with:
@@ -40,8 +36,10 @@ class LightningEnvironmentPlugin(PluginConfigBase):
     The appropriate environment variables need to be set, and at minimum `LOCAL_RANK`.
     """
 
+    name: Literal["lightning_environment"] = "lightning_environment"
+
     @override
-    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.lightning import LightningEnvironment
 
         return LightningEnvironment()
@@ -49,16 +47,16 @@ class LightningEnvironmentPlugin(PluginConfigBase):
 
 @plugin_registry.register
 class LSFEnvironmentPlugin(PluginConfigBase):
-    name: Literal["lsf_environment"] = "lsf_environment"
-
     """An environment for running on clusters managed by the LSF resource manager.
 
     It is expected that any execution using this ClusterEnvironment was executed
     using the Job Step Manager i.e. `jsrun`.
     """
 
+    name: Literal["lsf_environment"] = "lsf_environment"
+
     @override
-    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 
         return LSFEnvironment()
@@ -66,48 +64,108 @@
 
 @plugin_registry.register
 class MPIEnvironmentPlugin(PluginConfigBase):
-    name: Literal["mpi_environment"] = "mpi_environment"
-
     """An environment for running on clusters with processes created through MPI.
 
     Requires the installation of the `mpi4py` package.
     """
 
+    name: Literal["mpi_environment"] = "mpi_environment"
+
     @override
-    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.mpi import MPIEnvironment
 
         return MPIEnvironment()
 
 
+SignalAlias = TypeAliasType(
+    "SignalAlias",
+    Literal[
+        "SIGABRT",
+        "SIGFPE",
+        "SIGILL",
+        "SIGINT",
+        "SIGSEGV",
+        "SIGTERM",
+        "SIGBREAK",
+        "CTRL_C_EVENT",
+        "CTRL_BREAK_EVENT",
+        "SIGALRM",
+        "SIGBUS",
+        "SIGCHLD",
+        "SIGCONT",
+        "SIGHUP",
+        "SIGIO",
+        "SIGIOT",
+        "SIGKILL",
+        "SIGPIPE",
+        "SIGPROF",
+        "SIGQUIT",
+        "SIGSTOP",
+        "SIGSYS",
+        "SIGTRAP",
+        "SIGTSTP",
+        "SIGTTIN",
+        "SIGTTOU",
+        "SIGURG",
+        "SIGUSR1",
+        "SIGUSR2",
+        "SIGVTALRM",
+        "SIGWINCH",
+        "SIGXCPU",
+        "SIGXFSZ",
+        "SIGEMT",
+        "SIGINFO",
+        "SIGCLD",
+        "SIGPOLL",
+        "SIGPWR",
+        "SIGRTMAX",
+        "SIGRTMIN",
+        "SIGSTKFLT",
+    ],
+)
+
+
 @plugin_registry.register
 class SLURMEnvironmentPlugin(PluginConfigBase):
+    """An environment for running on clusters managed by the SLURM resource manager."""
+
     name: Literal["slurm_environment"] = "slurm_environment"
 
     auto_requeue: bool = True
     """Whether automatic job resubmission is enabled or not."""
 
-    requeue_signal: signal.Signals | None = None
+    requeue_signal: SignalAlias | None = None
     """The signal that SLURM will send to indicate that the job should be requeued."""
 
     @override
-    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 
+        requeue_signal = None
+        if self.requeue_signal is not None:
+            try:
+                requeue_signal = signal.Signals[self.requeue_signal]
+            except KeyError:
+                raise ValueError(
+                    f"Invalid signal name: {self.requeue_signal}. "
+                    "Please provide a valid signal name from the signal module."
+                )
+
         return SLURMEnvironment(
             auto_requeue=self.auto_requeue,
-            requeue_signal=self.requeue_signal,
+            requeue_signal=requeue_signal,
         )
 
 
 @plugin_registry.register
 class TorchElasticEnvironmentPlugin(PluginConfigBase):
-    name: Literal["torchelastic_environment"] = "torchelastic_environment"
-
     """Environment for fault-tolerant and elastic training with torchelastic."""
 
+    name: Literal["torchelastic_environment"] = "torchelastic_environment"
+
     @override
-    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.torchelastic import (
             TorchElasticEnvironment,
         )
@@ -117,12 +175,12 @@ class TorchElasticEnvironmentPlugin(PluginConfigBase):
 
 @plugin_registry.register
 class XLAEnvironmentPlugin(PluginConfigBase):
-    name: Literal["xla_environment"] = "xla_environment"
-
     """Cluster environment for training on a TPU Pod with the PyTorch/XLA library."""
 
+    name: Literal["xla_environment"] = "xla_environment"
+
     @override
-    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.xla import XLAEnvironment
 
         return XLAEnvironment()
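
The net effect of the SignalAlias change: SLURMEnvironmentPlugin now stores the requeue signal as a plain name and resolves it through the stdlib signal module inside create_plugin. A short sketch:

import signal

from nshtrainer.trainer.plugin.environment import SLURMEnvironmentPlugin

plugin_config = SLURMEnvironmentPlugin(auto_requeue=True, requeue_signal="SIGUSR1")

# The equivalent stdlib lookup performed inside create_plugin:
assert signal.Signals["SIGUSR1"] is signal.SIGUSR1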
nshtrainer/trainer/plugin/io.py CHANGED
@@ -2,27 +2,52 @@ from __future__ import annotations
 
 from typing import Literal
 
-from lightning.pytorch.plugins.io import CheckpointIO
 from typing_extensions import override
 
-from .base import PluginConfig, PluginConfigBase, plugin_registry
+from .base import PluginConfigBase, plugin_registry
 
 
 @plugin_registry.register
-class AsyncCheckpointIOPlugin(PluginConfigBase):
-    name: Literal["async_checkpoint"] = "async_checkpoint"
+class TorchCheckpointIOPlugin(PluginConfigBase):
+    """CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively."""
+
+    name: Literal["torch_checkpoint"] = "torch_checkpoint"
+
+    @override
+    def create_plugin(self, trainer_config):
+        from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
+
+        return TorchCheckpointIO()
+
+
+@plugin_registry.register
+class XLACheckpointIOPlugin(PluginConfigBase):
+    """CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies."""
 
+    name: Literal["xla_checkpoint"] = "xla_checkpoint"
+
+    @override
+    def create_plugin(self, trainer_config):
+        from lightning.fabric.plugins.io.xla import XLACheckpointIO
+
+        return XLACheckpointIO()
+
+
+@plugin_registry.register
+class AsyncCheckpointIOPlugin(PluginConfigBase):
     """Enables saving the checkpoints asynchronously in a thread.
 
     .. warning:: This is an experimental feature.
     """
 
-    checkpoint_io: PluginConfig | None = None
+    name: Literal["async_checkpoint"] = "async_checkpoint"
+
+    checkpoint_io: TorchCheckpointIOPlugin | None = None
     """A checkpoint IO plugin that is used as the basis for async checkpointing."""
 
     @override
-    def create_plugin(self, trainer_config) -> CheckpointIO:
-        from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+    def create_plugin(self, trainer_config):
+        from lightning.pytorch.plugins.io import AsyncCheckpointIO, CheckpointIO
 
         base_io = (
             self.checkpoint_io.create_plugin(trainer_config)
@@ -34,29 +59,3 @@ class AsyncCheckpointIOPlugin(PluginConfigBase):
                 f"Expected `checkpoint_io` to be a `CheckpointIO` instance, but got {type(base_io)}."
             )
         return AsyncCheckpointIO(checkpoint_io=base_io)
-
-
-@plugin_registry.register
-class TorchCheckpointIOPlugin(PluginConfigBase):
-    name: Literal["torch_checkpoint"] = "torch_checkpoint"
-
-    """CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively."""
-
-    @override
-    def create_plugin(self, trainer_config) -> CheckpointIO:
-        from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
-
-        return TorchCheckpointIO()
-
-
-@plugin_registry.register
-class XLACheckpointIOPlugin(PluginConfigBase):
-    name: Literal["xla_checkpoint"] = "xla_checkpoint"
-
-    """CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies."""
-
-    @override
-    def create_plugin(self, trainer_config) -> CheckpointIO:
-        from lightning.fabric.plugins.io.xla import XLACheckpointIO
-
-        return XLACheckpointIO()
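
For configuration, AsyncCheckpointIOPlugin's checkpoint_io field now takes the concrete TorchCheckpointIOPlugin rather than any PluginConfig. A minimal sketch:

from nshtrainer.trainer.plugin.io import (
    AsyncCheckpointIOPlugin,
    TorchCheckpointIOPlugin,
)

async_io = AsyncCheckpointIOPlugin(checkpoint_io=TorchCheckpointIOPlugin())
# create_plugin would wrap the torch-based CheckpointIO in Lightning's AsyncCheckpointIO.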
nshtrainer/trainer/plugin/layer_sync.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from typing import Literal
 
-from lightning.pytorch.plugins.layer_sync import LayerSync
 from typing_extensions import override
 
 from .base import PluginConfigBase, plugin_registry
@@ -10,16 +9,16 @@ from .base import PluginConfigBase, plugin_registry
 
 @plugin_registry.register
 class TorchSyncBatchNormPlugin(PluginConfigBase):
-    name: Literal["torch_sync_batchnorm"] = "torch_sync_batchnorm"
-
     """A plugin that wraps all batch normalization layers of a model with synchronization
     logic for multiprocessing.
 
     This plugin has no effect in single-device operation.
     """
 
+    name: Literal["torch_sync_batchnorm"] = "torch_sync_batchnorm"
+
     @override
-    def create_plugin(self, trainer_config) -> LayerSync:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm
 
         return TorchSyncBatchNorm()
nshtrainer/trainer/plugin/precision.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from typing import Any, Literal
 
-from lightning.pytorch.plugins.precision import Precision
 from typing_extensions import override
 
 from ...util.config.dtype import DTypeConfig
@@ -20,7 +19,7 @@ class MixedPrecisionPluginConfig(PluginConfigBase):
     """The device for ``torch.autocast``."""
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.amp import MixedPrecision
 
         return MixedPrecision(self.precision, self.device)
@@ -45,7 +44,7 @@ class BitsandbytesPluginConfig(PluginConfigBase):
     """
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.bitsandbytes import (
             BitsandbytesPrecision,
         )
@@ -66,7 +65,7 @@ class DeepSpeedPluginConfig(PluginConfigBase):
     mixed precision (16-mixed, bf16-mixed)."""
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
 
         return DeepSpeedPrecision(precision=self.precision)
@@ -80,7 +79,7 @@ class DoublePrecisionPluginConfig(PluginConfigBase):
     """Plugin for training with double (``torch.float64``) precision."""
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.double import DoublePrecision
 
         return DoublePrecision()
@@ -95,7 +94,7 @@ class FSDPPrecisionPluginConfig(PluginConfigBase):
     mixed precision (16-mixed, bf16-mixed)."""
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
 
         return FSDPPrecision(precision=self.precision)
@@ -109,7 +108,7 @@ class HalfPrecisionPluginConfig(PluginConfigBase):
     """Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``)."""
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.half import HalfPrecision
 
         return HalfPrecision(precision=self.precision)
@@ -134,7 +133,7 @@ class TransformerEnginePluginConfig(PluginConfigBase):
     Defaults to the same as weights_dtype."""
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.transformer_engine import (
             TransformerEnginePrecision,
         )
@@ -157,7 +156,7 @@ class XLAPluginConfig(PluginConfigBase):
     """Full precision (32-true) or half precision (16-true, bf16-true)."""
 
     @override
-    def create_plugin(self, trainer_config) -> Precision:
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.xla import XLAPrecision
 
         return XLAPrecision(precision=self.precision)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b54
+Version: 1.0.0b56
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -32,7 +32,7 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
 nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
 nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
-nshtrainer/configs/__init__.py,sha256=0BzCgE1iEJ0Ywmy__mqJZipLQtwZVdz6XK-gHbkA7GY,14650
+nshtrainer/configs/__init__.py,sha256=-rGk9pnRnuz4yKvACGOpY3nkrWnHholqZGk7UP2Vkrc,14716
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
 nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
@@ -67,15 +67,16 @@ nshtrainer/configs/loggers/base/__init__.py,sha256=HLUfEDbjaAXqzsFmQbjdciIWzR1st
 nshtrainer/configs/loggers/csv/__init__.py,sha256=gawaDX92JObGSmBqYpfNHWMHBwVOofS694W-1Y2GWDU,353
 nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=phzm-TnBkdkibTgoOxIIcAliqL3zU8gSNK61Mwxs1CM,410
 nshtrainer/configs/loggers/wandb/__init__.py,sha256=TDcD5WZSKenc2mgIXhwz2l96l8P_Ur3N5CzEol5AKGw,746
-nshtrainer/configs/lr_scheduler/__init__.py,sha256=xtiUx0isxA82-uXMn4-KmPnDCfbUkpAnd2_pFupAAKQ,1137
+nshtrainer/configs/lr_scheduler/__init__.py,sha256=PvH2d8QEC3TsC3_svcUbxeQEMMzIK_In0_Bp9xntSms,1243
 nshtrainer/configs/lr_scheduler/base/__init__.py,sha256=6Cx8r4rdxeSYxc_z0o7drKCblGJU_zzqrOoYlWYR5qY,305
 nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=5ZMLDO9VL6SNU6pF-62lDnpmqix3_Ol9DdEwiuOPYlA,675
-nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=w-vq8UbRGPX8DZVWCMC5eIrbvVc_guxjj7Du9AaeKCw,609
+nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=DBwZV590I5qwyOS5M43YhUzgYy1-AjzkM5aEnTA6XdI,715
 nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
 nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
-nshtrainer/configs/nn/__init__.py,sha256=tkFG2Hb0oL_AmWP3_0WkDN2zI5PkVfrgwXhaAII7CZw,2072
+nshtrainer/configs/nn/__init__.py,sha256=Ms2gIqbRxNVm6GHKCddCJTTqMwUPifjjHD_fCfJpcMQ,2174
 nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
 nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
+nshtrainer/configs/nn/rng/__init__.py,sha256=4iC6vwxbfNeXyvpwZ1Z5Kcy-he4cu7mg3UpLD-RLrHc,141
 nshtrainer/configs/optimizer/__init__.py,sha256=itIDIHQvGm50eZ7JLyNElahnNUMPJ__4PMmTjc0RQ6o,444
 nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
 nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
@@ -86,9 +87,9 @@ nshtrainer/configs/trainer/__init__.py,sha256=a8pzGVid52abAVARPbgjaN566H1ZM44FH_
 nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
-nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
-nshtrainer/configs/trainer/plugin/environment/__init__.py,sha256=3o16x4qRAOvkJH9Vg4-QwsEODDC6aP_OXRnPPkm_xSo,1376
-nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=W6G67JnigB6d3MiwLrbSKgtIZLUccXznp-IXwkK1J4U,743
+nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
+nshtrainer/configs/trainer/plugin/environment/__init__.py,sha256=geTaOfbFep3DVDgSsRrufM7q6MIZur5QG1_47ngAL0I,1280
+nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=AtGUuE0M16dTpX0q9NqvJiE4qU1j07N0RLtd-JFzWuc,653
 nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
 nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
 nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
@@ -119,11 +120,12 @@ nshtrainer/model/base.py,sha256=bZMNap0rkxRbAbu2BOHV_6YS2iZZnvy6wVSMOXGa_ZM,8680
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
 nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
 nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
-nshtrainer/nn/__init__.py,sha256=5Gg3nieGSC5_dXaI9KUVUUbM13hHexH9831m4hcf6no,1475
+nshtrainer/nn/__init__.py,sha256=Vd246v2N9tBQ8XxmTquWzj5lAmeSnngrjpYOfp4LTXM,1499
 nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
+nshtrainer/nn/rng.py,sha256=IJGvX9v8qBkfgBrMlNU2aj-MbYTPoncFyJzvPkzCQpM,512
 nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
@@ -134,12 +136,12 @@ nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zs
 nshtrainer/trainer/_config.py,sha256=s-_XoLc9mbNAdroRJyOKd3dLTyrFLQkPyGJkKDmBYf8,33267
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
 nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
-nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
-nshtrainer/trainer/plugin/base.py,sha256=9-qUHXGpll_yCylun0899sbmJDpyhD9IQcBtVrJx38I,919
-nshtrainer/trainer/plugin/environment.py,sha256=NW0qbsbvDPe59JGOMgPLq1fj7szLucIV1WRTxCrcjF4,4367
-nshtrainer/trainer/plugin/io.py,sha256=nm6YDCVZAhmPvLaLnw6q4BrK2Gj2wvD5ZLDhj1xneEE,2030
-nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv6DrYFIaXOo,735
-nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
+nshtrainer/trainer/plugin/__init__.py,sha256=q_q98MYNaZ2VE_tqGqYlQjQnlaF4NE1FUqVVbj0EK7k,517
+nshtrainer/trainer/plugin/base.py,sha256=76ct2TTHLpPr5MO8B9CIkoCOo-dFImzqAll8cIdC0cg,736
+nshtrainer/trainer/plugin/environment.py,sha256=SSXRWHjyFUA6oFx3duD_ZwhM59pWUjR1_UzHz02NI2c,5440
+nshtrainer/trainer/plugin/io.py,sha256=OmFSKLloMypletjaUr_Ptg6LS0ljqTVIp2o4Hm3eZoE,1926
+nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLcOfPXnvH29s,663
+nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
 nshtrainer/trainer/trainer.py,sha256=Lo3vUo3ooTAjaX2fUYPFSMv5FP7sWfVov0QbA-T5hZ8,21113
@@ -154,6 +156,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b54.dist-info/METADATA,sha256=oSfrN2tgKgkZJwGbZVNbLULQcVMxh_wb02u7Hrujfn4,988
-nshtrainer-1.0.0b54.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
-nshtrainer-1.0.0b54.dist-info/RECORD,,
+nshtrainer-1.0.0b56.dist-info/METADATA,sha256=KevxOAySDQL3f1OO7dgEsNUeX322ostuFo6i7egHzRU,988
+nshtrainer-1.0.0b56.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b56.dist-info/RECORD,,