nshtrainer 1.0.0b28__py3-none-any.whl → 1.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,43 @@
+ from __future__ import annotations
+
+ __codegen__ = True
+
+ from nshtrainer.trainer.plugin.precision import (
+     BitsandbytesPluginConfig as BitsandbytesPluginConfig,
+ )
+ from nshtrainer.trainer.plugin.precision import (
+     DeepSpeedPluginConfig as DeepSpeedPluginConfig,
+ )
+ from nshtrainer.trainer.plugin.precision import (
+     DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
+ )
+ from nshtrainer.trainer.plugin.precision import DTypeConfig as DTypeConfig
+ from nshtrainer.trainer.plugin.precision import (
+     FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
+ )
+ from nshtrainer.trainer.plugin.precision import (
+     HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
+ )
+ from nshtrainer.trainer.plugin.precision import (
+     MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
+ )
+ from nshtrainer.trainer.plugin.precision import PluginConfigBase as PluginConfigBase
+ from nshtrainer.trainer.plugin.precision import (
+     TransformerEnginePluginConfig as TransformerEnginePluginConfig,
+ )
+ from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
+ from nshtrainer.trainer.plugin.precision import plugin_registry as plugin_registry
+
+ __all__ = [
+     "BitsandbytesPluginConfig",
+     "DTypeConfig",
+     "DeepSpeedPluginConfig",
+     "DoublePrecisionPluginConfig",
+     "FSDPPrecisionPluginConfig",
+     "HalfPrecisionPluginConfig",
+     "MixedPrecisionPluginConfig",
+     "PluginConfigBase",
+     "TransformerEnginePluginConfig",
+     "XLAPluginConfig",
+     "plugin_registry",
+ ]
@@ -0,0 +1,11 @@
+ from __future__ import annotations
+
+ __codegen__ = True
+
+ from nshtrainer.trainer.strategy import StrategyConfig as StrategyConfig
+ from nshtrainer.trainer.strategy import StrategyConfigBase as StrategyConfigBase
+
+ __all__ = [
+     "StrategyConfig",
+     "StrategyConfigBase",
+ ]
@@ -4,12 +4,14 @@ __codegen__ = True

  from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
  from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
+ from nshtrainer.trainer.trainer import PluginConfigBase as PluginConfigBase
  from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
  from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig

  __all__ = [
      "AcceleratorConfigBase",
      "EnvironmentConfig",
+     "PluginConfigBase",
      "StrategyConfigBase",
      "TrainerConfig",
  ]
@@ -1,12 +1,10 @@
  from __future__ import annotations

  import builtins
- from typing import Literal
+ from typing import Any, Literal

  import nshconfig as C

- from ..util._useful_types import SupportsRichComparisonT
-

  class MetricConfig(C.Config):
      name: str
@@ -40,5 +38,5 @@ class MetricConfig(C.Config):
      def best(self):
          return builtins.min if self.mode == "min" else builtins.max

-     def is_better(self, a: SupportsRichComparisonT, b: SupportsRichComparisonT) -> bool:
+     def is_better(self, a: Any, b: Any):
          return self.best(a, b) == a
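Note: the hunk above drops the removed `SupportsRichComparisonT` bound in favor of plain `Any` on `is_better`. A minimal sketch of the resulting behavior, assuming a `MetricConfig` can be constructed from just the `name` and `mode` fields referenced in this hunk:

    from nshtrainer.metrics._config import MetricConfig

    # mode="min": best() resolves to builtins.min, so smaller values win.
    val_loss = MetricConfig(name="val/loss", mode="min")
    assert val_loss.is_better(0.1, 0.2)        # min(0.1, 0.2) == 0.1 -> True
    assert not val_loss.is_better(0.3, 0.2)    # min(0.3, 0.2) == 0.2 -> False

    # mode="max": best() resolves to builtins.max, so larger values win.
    val_acc = MetricConfig(name="val/acc", mode="max")
    assert val_acc.is_better(0.9, 0.8)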
@@ -1,4 +1,6 @@
  from __future__ import annotations

  from ._config import TrainerConfig as TrainerConfig
+ from ._config import accelerator_registry as accelerator_registry
+ from ._config import plugin_registry as plugin_registry
  from .trainer import Trainer as Trainer
@@ -5,7 +5,6 @@ import logging
  import os
  import string
  import time
- from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from datetime import timedelta
  from pathlib import Path
@@ -18,14 +17,11 @@ from typing import (

  import nshconfig as C
  import numpy as np
- from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
  from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
  from lightning.pytorch.accelerators import Accelerator
  from lightning.pytorch.callbacks.callback import Callback
  from lightning.pytorch.loggers import Logger
  from lightning.pytorch.plugins import _PLUGIN_INPUT
- from lightning.pytorch.plugins.layer_sync import LayerSync
- from lightning.pytorch.plugins.precision.precision import Precision
  from lightning.pytorch.profilers import Profiler
  from lightning.pytorch.strategies.strategy import Strategy
  from typing_extensions import TypeAliasType, TypedDict, override
@@ -58,6 +54,9 @@ from ..loggers.actsave import ActSaveLoggerConfig
  from ..metrics._config import MetricConfig
  from ..profiler import ProfilerConfig
  from ..util._environment_info import EnvironmentConfig
+ from .accelerator import AcceleratorConfig, AcceleratorLiteral, accelerator_registry
+ from .plugin import PluginConfig, plugin_registry
+ from .strategy import StrategyConfig

  log = logging.getLogger(__name__)

@@ -71,39 +70,6 @@ class GradientClippingConfig(C.Config):
      """Norm type to use for gradient clipping."""


- Plugin = TypeAliasType(
-     "Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
- )
-
-
- class PluginConfigBase(C.Config, ABC):
-     @abstractmethod
-     def create_plugin(self) -> Plugin: ...
-
-
- plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
-
-
- class AcceleratorConfigBase(C.Config, ABC):
-     @abstractmethod
-     def create_accelerator(self) -> Accelerator: ...
-
-
- accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
-
-
- class StrategyConfigBase(C.Config, ABC):
-     @abstractmethod
-     def create_strategy(self) -> Strategy: ...
-
-
- strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
-
-
- AcceleratorLiteral = TypeAliasType(
-     "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
- )
-
  StrategyLiteral = TypeAliasType(
      "StrategyLiteral",
      Literal[
@@ -432,7 +398,6 @@ class SanityCheckingConfig(C.Config):


  @plugin_registry.rebuild_on_registers
- @strategy_registry.rebuild_on_registers
  @accelerator_registry.rebuild_on_registers
  class TrainerConfig(C.Config):
      # region Active Run Configuration
@@ -578,9 +543,7 @@ class TrainerConfig(C.Config):
      Default: ``False``.
      """

-     plugins: (
-         list[Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]] | None
-     ) = None
+     plugins: list[PluginConfig] | None = None
      """
      Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
      Default: ``None``.
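Note: `plugins` is now typed against the registry-resolved `PluginConfig` alias instead of an inline `Annotated[...]` union, so the built-in plugin configs added in this release can be listed directly. A hedged sketch (assumes an existing `TrainerConfig` instance named `config`):

    from nshtrainer.trainer.plugin.io import TorchCheckpointIOPlugin
    from nshtrainer.trainer.plugin.layer_sync import TorchSyncBatchNormPlugin

    # Each entry is discriminated by its `name` field via plugin_registry.
    config.plugins = [
        TorchSyncBatchNormPlugin(),
        TorchCheckpointIOPlugin(),
    ]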
@@ -740,21 +703,13 @@ class TrainerConfig(C.Config):
      Default: ``True``.
      """

-     accelerator: (
-         Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()]
-         | AcceleratorLiteral
-         | None
-     ) = None
+     accelerator: AcceleratorConfig | AcceleratorLiteral | None = None
      """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
      as well as custom accelerator instances.
      Default: ``"auto"``.
      """

-     strategy: (
-         Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()]
-         | StrategyLiteral
-         | None
-     ) = None
+     strategy: StrategyConfig | StrategyLiteral | None = None
      """Supports different training strategies with aliases as well custom strategies.
      Default: ``"auto"``.
      """
@@ -0,0 +1,86 @@
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING, Annotated, Literal
+
+ import nshconfig as C
+ from lightning.pytorch.accelerators import Accelerator
+ from typing_extensions import TypeAliasType, override
+
+ if TYPE_CHECKING:
+     from ._config import TrainerConfig
+
+ AcceleratorLiteral = TypeAliasType(
+     "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
+ )
+
+
+ class AcceleratorConfigBase(C.Config, ABC):
+     @abstractmethod
+     def create_accelerator(self, trainer_config: "TrainerConfig") -> Accelerator: ...
+
+
+ accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
+
+ AcceleratorConfig = TypeAliasType(
+     "AcceleratorConfig",
+     Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
+ )
+
+
+ @accelerator_registry.register
+ class CPUAcceleratorConfig(AcceleratorConfigBase):
+     name: Literal["cpu"] = "cpu"
+
+     """Accelerator for CPU devices."""
+
+     @override
+     def create_accelerator(self, trainer_config) -> Accelerator:
+         from lightning.pytorch.accelerators.cpu import CPUAccelerator
+
+         return CPUAccelerator()
+
+
+ @accelerator_registry.register
+ class CUDAAcceleratorConfig(AcceleratorConfigBase):
+     name: Literal["gpu"] = "gpu"
+
+     """Accelerator for NVIDIA CUDA devices."""
+
+     @override
+     def create_accelerator(self, trainer_config) -> Accelerator:
+         from lightning.pytorch.accelerators.cuda import CUDAAccelerator
+
+         return CUDAAccelerator()
+
+
+ @accelerator_registry.register
+ class MPSAcceleratorConfig(AcceleratorConfigBase):
+     name: Literal["mps"] = "mps"
+
+     """Accelerator for Metal Apple Silicon GPU devices.
+
+     .. warning:: Use of this accelerator beyond import and instantiation is experimental.
+     """
+
+     @override
+     def create_accelerator(self, trainer_config) -> Accelerator:
+         from lightning.pytorch.accelerators.mps import MPSAccelerator
+
+         return MPSAccelerator()
+
+
+ @accelerator_registry.register
+ class XLAAcceleratorConfig(AcceleratorConfigBase):
+     name: Literal["tpu"] = "tpu"
+
+     """Accelerator for XLA devices, normally TPUs.
+
+     .. warning:: Use of this accelerator beyond import and instantiation is experimental.
+     """
+
+     @override
+     def create_accelerator(self, trainer_config) -> Accelerator:
+         from lightning.pytorch.accelerators.xla import XLAAccelerator
+
+         return XLAAccelerator()
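Note: `accelerator_registry` is the extension point for accelerator configs, and `create_accelerator` now receives the owning `TrainerConfig`. A hypothetical, hedged sketch of a third-party registration (the class and its `name` value are illustrative only, not part of this release):

    from typing import Literal

    import torch
    from lightning.pytorch.accelerators import Accelerator

    from nshtrainer.trainer.accelerator import AcceleratorConfigBase, accelerator_registry


    @accelerator_registry.register
    class AutoCUDAAcceleratorConfig(AcceleratorConfigBase):
        name: Literal["auto_cuda"] = "auto_cuda"

        def create_accelerator(self, trainer_config) -> Accelerator:
            # Fall back to CPU when no CUDA device is visible.
            from lightning.pytorch.accelerators.cpu import CPUAccelerator
            from lightning.pytorch.accelerators.cuda import CUDAAccelerator

            return CUDAAccelerator() if torch.cuda.is_available() else CPUAccelerator()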
@@ -0,0 +1,10 @@
+ from __future__ import annotations
+
+ from . import environment as environment
+ from . import io as io
+ from . import layer_sync as layer_sync
+ from . import precision as precision
+ from .base import Plugin as Plugin
+ from .base import PluginConfig as PluginConfig
+ from .base import PluginConfigBase as PluginConfigBase
+ from .base import plugin_registry as plugin_registry
@@ -0,0 +1,33 @@
+ from __future__ import annotations
+
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING, Annotated
+
+ import nshconfig as C
+ from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
+ from lightning.pytorch.plugins.layer_sync import LayerSync
+ from lightning.pytorch.plugins.precision.precision import Precision
+ from typing_extensions import TypeAliasType
+
+ if TYPE_CHECKING:
+     from .._config import TrainerConfig
+ log = logging.getLogger(__name__)
+
+
+ Plugin = TypeAliasType(
+     "Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
+ )
+
+
+ class PluginConfigBase(C.Config, ABC):
+     @abstractmethod
+     def create_plugin(self, trainer_config: "TrainerConfig") -> Plugin: ...
+
+
+ plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
+
+ PluginConfig = TypeAliasType(
+     "PluginConfig",
+     Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
+ )
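Note: this module now carries the `Plugin` alias, `PluginConfigBase`, and `plugin_registry` that were removed from `_config.py` above, and `create_plugin` gains a `trainer_config` parameter. A hypothetical, hedged sketch of a user-defined plugin config (the class and its `name` value are illustrative only):

    from typing import Literal

    from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm

    from nshtrainer.trainer.plugin.base import Plugin, PluginConfigBase, plugin_registry


    @plugin_registry.register
    class VerboseSyncBatchNormPlugin(PluginConfigBase):
        name: Literal["verbose_sync_batchnorm"] = "verbose_sync_batchnorm"

        def create_plugin(self, trainer_config) -> Plugin:
            # trainer_config is available for plugins that need trainer settings;
            # the built-in configs in this release simply ignore it.
            return TorchSyncBatchNorm()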
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+
+ import signal
+ from typing import Any, Literal
+
+ from lightning.pytorch.plugins.environments import ClusterEnvironment
+ from typing_extensions import override
+
+ from ...util.config.dtype import DTypeConfig
+ from .base import PluginConfigBase, plugin_registry
+
+
+ @plugin_registry.register
+ class KubeflowEnvironmentPlugin(PluginConfigBase):
+     name: Literal["kubeflow_environment"] = "kubeflow_environment"
+
+     """Environment for distributed training using the PyTorchJob operator from Kubeflow.
+
+     This environment, unlike others, does not get auto-detected and needs to be passed
+     to the Fabric/Trainer constructor manually.
+     """
+
+     @override
+     def create_plugin(self, trainer_config) -> ClusterEnvironment:
+         from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment
+
+         return KubeflowEnvironment()
+
+
+ @plugin_registry.register
+ class LightningEnvironmentPlugin(PluginConfigBase):
+     name: Literal["lightning_environment"] = "lightning_environment"
+
+     """The default environment used by Lightning for a single node or free cluster (not managed).
+
+     There are two modes the Lightning environment can operate with:
+     1. User launches main process by `python train.py ...` with no additional environment variables.
+        Lightning will spawn new worker processes for distributed training in the current node.
+     2. User launches all processes manually or with utilities like `torch.distributed.launch`.
+        The appropriate environment variables need to be set, and at minimum `LOCAL_RANK`.
+     """
+
+     @override
+     def create_plugin(self, trainer_config) -> ClusterEnvironment:
+         from lightning.fabric.plugins.environments.lightning import LightningEnvironment
+
+         return LightningEnvironment()
+
+
+ @plugin_registry.register
+ class LSFEnvironmentPlugin(PluginConfigBase):
+     name: Literal["lsf_environment"] = "lsf_environment"
+
+     """An environment for running on clusters managed by the LSF resource manager.
+
+     It is expected that any execution using this ClusterEnvironment was executed
+     using the Job Step Manager i.e. `jsrun`.
+     """
+
+     @override
+     def create_plugin(self, trainer_config) -> ClusterEnvironment:
+         from lightning.fabric.plugins.environments.lsf import LSFEnvironment
+
+         return LSFEnvironment()
+
+
+ @plugin_registry.register
+ class MPIEnvironmentPlugin(PluginConfigBase):
+     name: Literal["mpi_environment"] = "mpi_environment"
+
+     """An environment for running on clusters with processes created through MPI.
+
+     Requires the installation of the `mpi4py` package.
+     """
+
+     @override
+     def create_plugin(self, trainer_config) -> ClusterEnvironment:
+         from lightning.fabric.plugins.environments.mpi import MPIEnvironment
+
+         return MPIEnvironment()
+
+
+ @plugin_registry.register
+ class SLURMEnvironmentPlugin(PluginConfigBase):
+     name: Literal["slurm_environment"] = "slurm_environment"
+
+     auto_requeue: bool = True
+     """Whether automatic job resubmission is enabled or not."""
+
+     requeue_signal: signal.Signals | None = None
+     """The signal that SLURM will send to indicate that the job should be requeued."""
+
+     @override
+     def create_plugin(self, trainer_config) -> ClusterEnvironment:
+         from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
+
+         return SLURMEnvironment(
+             auto_requeue=self.auto_requeue,
+             requeue_signal=self.requeue_signal,
+         )
+
+
+ @plugin_registry.register
+ class TorchElasticEnvironmentPlugin(PluginConfigBase):
+     name: Literal["torchelastic_environment"] = "torchelastic_environment"
+
+     """Environment for fault-tolerant and elastic training with torchelastic."""
+
+     @override
+     def create_plugin(self, trainer_config) -> ClusterEnvironment:
+         from lightning.fabric.plugins.environments.torchelastic import (
+             TorchElasticEnvironment,
+         )
+
+         return TorchElasticEnvironment()
+
+
+ @plugin_registry.register
+ class XLAEnvironmentPlugin(PluginConfigBase):
+     name: Literal["xla_environment"] = "xla_environment"
+
+     """Cluster environment for training on a TPU Pod with the PyTorch/XLA library."""
+
+     @override
+     def create_plugin(self, trainer_config) -> ClusterEnvironment:
+         from lightning.fabric.plugins.environments.xla import XLAEnvironment
+
+         return XLAEnvironment()
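Note: each environment plugin wraps the corresponding Lightning `ClusterEnvironment` and is selected by its `name` discriminator. For instance, the SLURM variant exposes its two fields directly (a sketch; `create_plugin` also expects the owning `TrainerConfig`):

    import signal

    from nshtrainer.trainer.plugin.environment import SLURMEnvironmentPlugin

    slurm = SLURMEnvironmentPlugin(
        auto_requeue=True,
        requeue_signal=signal.SIGUSR1,
    )
    # cluster_env = slurm.create_plugin(trainer_config)  # -> Lightning SLURMEnvironment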
@@ -0,0 +1,62 @@
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from lightning.pytorch.plugins.io import CheckpointIO
+ from typing_extensions import override
+
+ from .base import PluginConfig, PluginConfigBase, plugin_registry
+
+
+ @plugin_registry.register
+ class AsyncCheckpointIOPlugin(PluginConfigBase):
+     name: Literal["async_checkpoint"] = "async_checkpoint"
+
+     """Enables saving the checkpoints asynchronously in a thread.
+
+     .. warning:: This is an experimental feature.
+     """
+
+     checkpoint_io: PluginConfig | None = None
+     """A checkpoint IO plugin that is used as the basis for async checkpointing."""
+
+     @override
+     def create_plugin(self, trainer_config) -> CheckpointIO:
+         from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+
+         base_io = (
+             self.checkpoint_io.create_plugin(trainer_config)
+             if self.checkpoint_io
+             else None
+         )
+         if base_io is not None and not isinstance(base_io, CheckpointIO):
+             raise TypeError(
+                 f"Expected `checkpoint_io` to be a `CheckpointIO` instance, but got {type(base_io)}."
+             )
+         return AsyncCheckpointIO(checkpoint_io=base_io)
+
+
+ @plugin_registry.register
+ class TorchCheckpointIOPlugin(PluginConfigBase):
+     name: Literal["torch_checkpoint"] = "torch_checkpoint"
+
+     """CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively."""
+
+     @override
+     def create_plugin(self, trainer_config) -> CheckpointIO:
+         from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
+
+         return TorchCheckpointIO()
+
+
+ @plugin_registry.register
+ class XLACheckpointIOPlugin(PluginConfigBase):
+     name: Literal["xla_checkpoint"] = "xla_checkpoint"
+
+     """CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies."""
+
+     @override
+     def create_plugin(self, trainer_config) -> CheckpointIO:
+         from lightning.fabric.plugins.io.xla import XLACheckpointIO
+
+         return XLACheckpointIO()
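Note: `AsyncCheckpointIOPlugin.checkpoint_io` accepts another registered plugin config, which `create_plugin` resolves and validates as a `CheckpointIO` before wrapping it. A hedged sketch of layering async checkpointing over the plain torch IO plugin:

    from nshtrainer.trainer.plugin.io import (
        AsyncCheckpointIOPlugin,
        TorchCheckpointIOPlugin,
    )

    # Async saving in a background thread, delegating actual IO to torch.save/torch.load.
    async_ckpt = AsyncCheckpointIOPlugin(checkpoint_io=TorchCheckpointIOPlugin())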
@@ -0,0 +1,25 @@
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from lightning.pytorch.plugins.layer_sync import LayerSync
+ from typing_extensions import override
+
+ from .base import PluginConfigBase, plugin_registry
+
+
+ @plugin_registry.register
+ class TorchSyncBatchNormPlugin(PluginConfigBase):
+     name: Literal["torch_sync_batchnorm"] = "torch_sync_batchnorm"
+
+     """A plugin that wraps all batch normalization layers of a model with synchronization
+     logic for multiprocessing.
+
+     This plugin has no effect in single-device operation.
+     """
+
+     @override
+     def create_plugin(self, trainer_config) -> LayerSync:
+         from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm
+
+         return TorchSyncBatchNorm()