nshtrainer 1.0.0b29__py3-none-any.whl → 1.0.0b30__py3-none-any.whl

This diff shows the changes between publicly available package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.trainer.plugin.precision import (
+    BitsandbytesPluginConfig as BitsandbytesPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    DeepSpeedPluginConfig as DeepSpeedPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import DTypeConfig as DTypeConfig
+from nshtrainer.trainer.plugin.precision import (
+    FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import PluginConfigBase as PluginConfigBase
+from nshtrainer.trainer.plugin.precision import (
+    TransformerEnginePluginConfig as TransformerEnginePluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
+from nshtrainer.trainer.plugin.precision import plugin_registry as plugin_registry
+
+__all__ = [
+    "BitsandbytesPluginConfig",
+    "DTypeConfig",
+    "DeepSpeedPluginConfig",
+    "DoublePrecisionPluginConfig",
+    "FSDPPrecisionPluginConfig",
+    "HalfPrecisionPluginConfig",
+    "MixedPrecisionPluginConfig",
+    "PluginConfigBase",
+    "TransformerEnginePluginConfig",
+    "XLAPluginConfig",
+    "plugin_registry",
+]
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.trainer.strategy import StrategyConfig as StrategyConfig
+from nshtrainer.trainer.strategy import StrategyConfigBase as StrategyConfigBase
+
+__all__ = [
+    "StrategyConfig",
+    "StrategyConfigBase",
+]
@@ -4,12 +4,14 @@ __codegen__ = True
 
 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
+from nshtrainer.trainer.trainer import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
 
 __all__ = [
     "AcceleratorConfigBase",
     "EnvironmentConfig",
+    "PluginConfigBase",
     "StrategyConfigBase",
     "TrainerConfig",
 ]
@@ -1,4 +1,6 @@
 from __future__ import annotations
 
 from ._config import TrainerConfig as TrainerConfig
+from ._config import accelerator_registry as accelerator_registry
+from ._config import plugin_registry as plugin_registry
 from .trainer import Trainer as Trainer
@@ -5,7 +5,6 @@ import logging
 import os
 import string
 import time
-from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from datetime import timedelta
 from pathlib import Path
@@ -18,14 +17,11 @@ from typing import (
 
 import nshconfig as C
 import numpy as np
-from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
 from lightning.pytorch.accelerators import Accelerator
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.plugins import _PLUGIN_INPUT
-from lightning.pytorch.plugins.layer_sync import LayerSync
-from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypeAliasType, TypedDict, override
@@ -58,6 +54,9 @@ from ..loggers.actsave import ActSaveLoggerConfig
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
 from ..util._environment_info import EnvironmentConfig
+from .accelerator import AcceleratorConfig, AcceleratorLiteral, accelerator_registry
+from .plugin import PluginConfig, plugin_registry
+from .strategy import StrategyConfig
 
 log = logging.getLogger(__name__)
 
@@ -71,37 +70,6 @@ class GradientClippingConfig(C.Config):
     """Norm type to use for gradient clipping."""
 
 
-Plugin = TypeAliasType(
-    "Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
-)
-
-
-class PluginConfigBase(C.Config, ABC):
-    @abstractmethod
-    def create_plugin(self) -> Plugin: ...
-
-
-plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
-PluginConfig = TypeAliasType(
-    "PluginConfig", Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]
-)
-
-AcceleratorLiteral = TypeAliasType(
-    "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
-)
-
-
-class AcceleratorConfigBase(C.Config, ABC):
-    @abstractmethod
-    def create_accelerator(self) -> Accelerator: ...
-
-
-accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
-AcceleratorConfig = TypeAliasType(
-    "AcceleratorConfig",
-    Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
-)
-
 StrategyLiteral = TypeAliasType(
     "StrategyLiteral",
     Literal[
@@ -135,17 +103,6 @@ StrategyLiteral = TypeAliasType(
 )
 
 
-class StrategyConfigBase(C.Config, ABC):
-    @abstractmethod
-    def create_strategy(self) -> Strategy: ...
-
-
-strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
-StrategyConfig = TypeAliasType(
-    "StrategyConfig",
-    Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()],
-)
-
 CheckpointCallbackConfig = TypeAliasType(
     "CheckpointCallbackConfig",
     Annotated[
@@ -441,7 +398,6 @@ class SanityCheckingConfig(C.Config):
 
 
 @plugin_registry.rebuild_on_registers
-@strategy_registry.rebuild_on_registers
 @accelerator_registry.rebuild_on_registers
 class TrainerConfig(C.Config):
     # region Active Run Configuration
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Annotated, Literal
+
+import nshconfig as C
+from lightning.pytorch.accelerators import Accelerator
+from typing_extensions import TypeAliasType, override
+
+if TYPE_CHECKING:
+    from ._config import TrainerConfig
+
+AcceleratorLiteral = TypeAliasType(
+    "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
+)
+
+
+class AcceleratorConfigBase(C.Config, ABC):
+    @abstractmethod
+    def create_accelerator(self, trainer_config: "TrainerConfig") -> Accelerator: ...
+
+
+accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
+
+AcceleratorConfig = TypeAliasType(
+    "AcceleratorConfig",
+    Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
+)
+
+
+@accelerator_registry.register
+class CPUAcceleratorConfig(AcceleratorConfigBase):
+    name: Literal["cpu"] = "cpu"
+
+    """Accelerator for CPU devices."""
+
+    @override
+    def create_accelerator(self, trainer_config) -> Accelerator:
+        from lightning.pytorch.accelerators.cpu import CPUAccelerator
+
+        return CPUAccelerator()
+
+
+@accelerator_registry.register
+class CUDAAcceleratorConfig(AcceleratorConfigBase):
+    name: Literal["gpu"] = "gpu"
+
+    """Accelerator for NVIDIA CUDA devices."""
+
+    @override
+    def create_accelerator(self, trainer_config) -> Accelerator:
+        from lightning.pytorch.accelerators.cuda import CUDAAccelerator
+
+        return CUDAAccelerator()
+
+
+@accelerator_registry.register
+class MPSAcceleratorConfig(AcceleratorConfigBase):
+    name: Literal["mps"] = "mps"
+
+    """Accelerator for Metal Apple Silicon GPU devices.
+
+    .. warning:: Use of this accelerator beyond import and instantiation is experimental.
+    """
+
+    @override
+    def create_accelerator(self, trainer_config) -> Accelerator:
+        from lightning.pytorch.accelerators.mps import MPSAccelerator
+
+        return MPSAccelerator()
+
+
+@accelerator_registry.register
+class XLAAcceleratorConfig(AcceleratorConfigBase):
+    name: Literal["tpu"] = "tpu"
+
+    """Accelerator for XLA devices, normally TPUs.
+
+    .. warning:: Use of this accelerator beyond import and instantiation is experimental.
+    """
+
+    @override
+    def create_accelerator(self, trainer_config) -> Accelerator:
+        from lightning.pytorch.accelerators.xla import XLAAccelerator
+
+        return XLAAccelerator()
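
The new accelerator module ties each config class to the registry through the @accelerator_registry.register decorator and a "name" Literal that serves as the discriminator. As a hypothetical sketch (not part of this diff), a downstream project could register its own accelerator config the same way; the class and discriminator name below are illustrative assumptions:

from typing import Literal

from lightning.pytorch.accelerators import Accelerator
from typing_extensions import override

from nshtrainer.trainer.accelerator import AcceleratorConfigBase, accelerator_registry


@accelerator_registry.register
class MyCPUAcceleratorConfig(AcceleratorConfigBase):
    # Hypothetical example: "name" is the discriminator value the registry resolves on.
    name: Literal["my_cpu"] = "my_cpu"

    @override
    def create_accelerator(self, trainer_config) -> Accelerator:
        # Delegates to Lightning's stock CPUAccelerator; a real config could carry options.
        from lightning.pytorch.accelerators.cpu import CPUAccelerator

        return CPUAccelerator()
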
@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+from . import environment as environment
+from . import io as io
+from . import layer_sync as layer_sync
+from . import precision as precision
+from .base import Plugin as Plugin
+from .base import PluginConfig as PluginConfig
+from .base import PluginConfigBase as PluginConfigBase
+from .base import plugin_registry as plugin_registry
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Annotated
+
+import nshconfig as C
+from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
+from lightning.pytorch.plugins.layer_sync import LayerSync
+from lightning.pytorch.plugins.precision.precision import Precision
+from typing_extensions import TypeAliasType
+
+if TYPE_CHECKING:
+    from .._config import TrainerConfig
+log = logging.getLogger(__name__)
+
+
+Plugin = TypeAliasType(
+    "Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
+)
+
+
+class PluginConfigBase(C.Config, ABC):
+    @abstractmethod
+    def create_plugin(self, trainer_config: "TrainerConfig") -> Plugin: ...
+
+
+plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
+
+PluginConfig = TypeAliasType(
+    "PluginConfig",
+    Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
+)
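
PluginConfigBase.create_plugin now receives the owning TrainerConfig, and concrete configs attach themselves to plugin_registry under a unique "name" discriminator. A minimal hypothetical sketch (not part of this diff) of a third-party plugin config following this contract; the class and name are illustrative assumptions:

from typing import Literal

from lightning.pytorch.plugins.layer_sync import LayerSync
from typing_extensions import override

from nshtrainer.trainer.plugin.base import PluginConfigBase, plugin_registry


@plugin_registry.register
class MySyncBatchNormPlugin(PluginConfigBase):
    # Hypothetical example: the discriminator value must be unique within the registry.
    name: Literal["my_sync_batchnorm"] = "my_sync_batchnorm"

    @override
    def create_plugin(self, trainer_config) -> LayerSync:
        # Returns one of the Plugin union members (here a LayerSync implementation).
        from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm

        return TorchSyncBatchNorm()
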
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+import signal
+from typing import Any, Literal
+
+from lightning.pytorch.plugins.environments import ClusterEnvironment
+from typing_extensions import override
+
+from ...util.config.dtype import DTypeConfig
+from .base import PluginConfigBase, plugin_registry
+
+
+@plugin_registry.register
+class KubeflowEnvironmentPlugin(PluginConfigBase):
+    name: Literal["kubeflow_environment"] = "kubeflow_environment"
+
+    """Environment for distributed training using the PyTorchJob operator from Kubeflow.
+
+    This environment, unlike others, does not get auto-detected and needs to be passed
+    to the Fabric/Trainer constructor manually.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment
+
+        return KubeflowEnvironment()
+
+
+@plugin_registry.register
+class LightningEnvironmentPlugin(PluginConfigBase):
+    name: Literal["lightning_environment"] = "lightning_environment"
+
+    """The default environment used by Lightning for a single node or free cluster (not managed).
+
+    There are two modes the Lightning environment can operate with:
+    1. User launches main process by `python train.py ...` with no additional environment variables.
+       Lightning will spawn new worker processes for distributed training in the current node.
+    2. User launches all processes manually or with utilities like `torch.distributed.launch`.
+       The appropriate environment variables need to be set, and at minimum `LOCAL_RANK`.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.lightning import LightningEnvironment
+
+        return LightningEnvironment()
+
+
+@plugin_registry.register
+class LSFEnvironmentPlugin(PluginConfigBase):
+    name: Literal["lsf_environment"] = "lsf_environment"
+
+    """An environment for running on clusters managed by the LSF resource manager.
+
+    It is expected that any execution using this ClusterEnvironment was executed
+    using the Job Step Manager i.e. `jsrun`.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.lsf import LSFEnvironment
+
+        return LSFEnvironment()
+
+
+@plugin_registry.register
+class MPIEnvironmentPlugin(PluginConfigBase):
+    name: Literal["mpi_environment"] = "mpi_environment"
+
+    """An environment for running on clusters with processes created through MPI.
+
+    Requires the installation of the `mpi4py` package.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.mpi import MPIEnvironment
+
+        return MPIEnvironment()
+
+
+@plugin_registry.register
+class SLURMEnvironmentPlugin(PluginConfigBase):
+    name: Literal["slurm_environment"] = "slurm_environment"
+
+    auto_requeue: bool = True
+    """Whether automatic job resubmission is enabled or not."""
+
+    requeue_signal: signal.Signals | None = None
+    """The signal that SLURM will send to indicate that the job should be requeued."""
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
+
+        return SLURMEnvironment(
+            auto_requeue=self.auto_requeue,
+            requeue_signal=self.requeue_signal,
+        )
+
+
+@plugin_registry.register
+class TorchElasticEnvironmentPlugin(PluginConfigBase):
+    name: Literal["torchelastic_environment"] = "torchelastic_environment"
+
+    """Environment for fault-tolerant and elastic training with torchelastic."""
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.torchelastic import (
+            TorchElasticEnvironment,
+        )
+
+        return TorchElasticEnvironment()
+
+
+@plugin_registry.register
+class XLAEnvironmentPlugin(PluginConfigBase):
+    name: Literal["xla_environment"] = "xla_environment"
+
+    """Cluster environment for training on a TPU Pod with the PyTorch/XLA library."""
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.xla import XLAEnvironment
+
+        return XLAEnvironment()
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.plugins.io import CheckpointIO
+from typing_extensions import override
+
+from .base import PluginConfig, PluginConfigBase, plugin_registry
+
+
+@plugin_registry.register
+class AsyncCheckpointIOPlugin(PluginConfigBase):
+    name: Literal["async_checkpoint"] = "async_checkpoint"
+
+    """Enables saving the checkpoints asynchronously in a thread.
+
+    .. warning:: This is an experimental feature.
+    """
+
+    checkpoint_io: PluginConfig | None = None
+    """A checkpoint IO plugin that is used as the basis for async checkpointing."""
+
+    @override
+    def create_plugin(self, trainer_config) -> CheckpointIO:
+        from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+
+        base_io = (
+            self.checkpoint_io.create_plugin(trainer_config)
+            if self.checkpoint_io
+            else None
+        )
+        if base_io is not None and not isinstance(base_io, CheckpointIO):
+            raise TypeError(
+                f"Expected `checkpoint_io` to be a `CheckpointIO` instance, but got {type(base_io)}."
+            )
+        return AsyncCheckpointIO(checkpoint_io=base_io)
+
+
+@plugin_registry.register
+class TorchCheckpointIOPlugin(PluginConfigBase):
+    name: Literal["torch_checkpoint"] = "torch_checkpoint"
+
+    """CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively."""
+
+    @override
+    def create_plugin(self, trainer_config) -> CheckpointIO:
+        from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
+
+        return TorchCheckpointIO()
+
+
+@plugin_registry.register
+class XLACheckpointIOPlugin(PluginConfigBase):
+    name: Literal["xla_checkpoint"] = "xla_checkpoint"
+
+    """CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies."""
+
+    @override
+    def create_plugin(self, trainer_config) -> CheckpointIO:
+        from lightning.fabric.plugins.io.xla import XLACheckpointIO
+
+        return XLACheckpointIO()
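
Because AsyncCheckpointIOPlugin.checkpoint_io is itself a PluginConfig, checkpoint-IO configs compose. An illustrative sketch (not part of this diff) of wrapping the torch-based IO in the async plugin, assuming the usual keyword construction of these config classes:

from nshtrainer.trainer.plugin.io import AsyncCheckpointIOPlugin, TorchCheckpointIOPlugin

# Nest one plugin config inside another; create_plugin() later unwraps the inner config
# and hands the resulting CheckpointIO to Lightning's AsyncCheckpointIO.
async_io = AsyncCheckpointIOPlugin(checkpoint_io=TorchCheckpointIOPlugin())
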
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.plugins.layer_sync import LayerSync
+from typing_extensions import override
+
+from .base import PluginConfigBase, plugin_registry
+
+
+@plugin_registry.register
+class TorchSyncBatchNormPlugin(PluginConfigBase):
+    name: Literal["torch_sync_batchnorm"] = "torch_sync_batchnorm"
+
+    """A plugin that wraps all batch normalization layers of a model with synchronization
+    logic for multiprocessing.
+
+    This plugin has no effect in single-device operation.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> LayerSync:
+        from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm
+
+        return TorchSyncBatchNorm()
@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from lightning.pytorch.plugins.precision import Precision
+from typing_extensions import override
+
+from ...util.config.dtype import DTypeConfig
+from .base import PluginConfigBase, plugin_registry
+
+
+@plugin_registry.register
+class MixedPrecisionPluginConfig(PluginConfigBase):
+    name: Literal["mixed_precision"] = "mixed_precision"
+
+    precision: Literal["16-mixed", "bf16-mixed"]
+    """Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``)."""
+
+    device: str
+    """The device for ``torch.autocast``."""
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.amp import MixedPrecision
+
+        return MixedPrecision(self.precision, self.device)
+
+
+@plugin_registry.register
+class BitsandbytesPluginConfig(PluginConfigBase):
+    name: Literal["bitsandbytes_precision"] = "bitsandbytes_precision"
+
+    mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"]
+    """The quantization mode to use."""
+
+    dtype: DTypeConfig | None = None
+    """The compute dtype to use."""
+
+    ignore_modules: set[str] | None = None
+    """The submodules whose Linear layers should not be replaced.
+
+    This might be desirable for numerical stability. The string will be checked
+    as a prefix, so a value like "transformer.blocks" will ignore all linear
+    layers in all of the transformer blocks.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.bitsandbytes import (
+            BitsandbytesPrecision,
+        )
+
+        return BitsandbytesPrecision(
+            mode=self.mode,
+            dtype=self.dtype.torch_dtype if self.dtype is not None else None,
+            ignore_modules=self.ignore_modules,
+        )
+
+
+@plugin_registry.register
+class DeepSpeedPluginConfig(PluginConfigBase):
+    name: Literal["deepspeed_precision"] = "deepspeed_precision"
+
+    precision: Literal["16-true", "bf16-true", "16-mixed", "bf16-mixed", "32-true"]
+    """Full precision (32-true), half precision (16-true, bf16-true) or
+    mixed precision (16-mixed, bf16-mixed)."""
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
+
+        return DeepSpeedPrecision(precision=self.precision)
+
+
+@plugin_registry.register
+class DoublePrecisionPluginConfig(PluginConfigBase):
+    name: Literal["double_precision"] = "double_precision"
+
+    precision: Literal["64-true"] = "64-true"
+    """Plugin for training with double (``torch.float64``) precision."""
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.double import DoublePrecision
+
+        return DoublePrecision()
+
+
+@plugin_registry.register
+class FSDPPrecisionPluginConfig(PluginConfigBase):
+    name: Literal["fsdp_precision"] = "fsdp_precision"
+
+    precision: Literal["16-true", "bf16-true", "16-mixed", "bf16-mixed", "32-true"]
+    """Full precision (32-true), half precision (16-true, bf16-true) or
+    mixed precision (16-mixed, bf16-mixed)."""
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
+
+        return FSDPPrecision(precision=self.precision)
+
+
+@plugin_registry.register
+class HalfPrecisionPluginConfig(PluginConfigBase):
+    name: Literal["half_precision"] = "half_precision"
+
+    precision: Literal["bf16-true", "16-true"]
+    """Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``)."""
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.half import HalfPrecision
+
+        return HalfPrecision(precision=self.precision)
+
+
+@plugin_registry.register
+class TransformerEnginePluginConfig(PluginConfigBase):
+    name: Literal["transformer_engine_precision"] = "transformer_engine_precision"
+
+    weights_dtype: DTypeConfig
+    """The weights dtype to use."""
+
+    recipe: dict[str, Any] | None = None
+    """Recipe for the DelayedScaling configuration in dict format."""
+
+    replace_layers: bool | None = None
+    """Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their
+    Transformer Engine alternatives."""
+
+    fallback_compute_dtype: DTypeConfig | None = None
+    """The compute dtype to use for operations that don't support fp8 autocast.
+    Defaults to the same as weights_dtype."""
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.transformer_engine import (
+            TransformerEnginePrecision,
+        )
+
+        return TransformerEnginePrecision(
+            weights_dtype=self.weights_dtype.torch_dtype,
+            recipe=self.recipe,
+            replace_layers=self.replace_layers,
+            fallback_compute_dtype=self.fallback_compute_dtype.torch_dtype
+            if self.fallback_compute_dtype
+            else None,
+        )
+
+
+@plugin_registry.register
+class XLAPluginConfig(PluginConfigBase):
+    name: Literal["xla_precision"] = "xla_precision"
+
+    precision: Literal["32-true", "16-true", "bf16-true"]
+    """Full precision (32-true) or half precision (16-true, bf16-true)."""
+
+    @override
+    def create_plugin(self, trainer_config) -> Precision:
+        from lightning.pytorch.plugins.precision.xla import XLAPrecision
+
+        return XLAPrecision(precision=self.precision)
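
Each precision config maps one-to-one onto a Lightning precision plugin. An illustrative sketch (not part of this diff); passing None for the trainer config is an assumption made only because MixedPrecision does not consult it in the code above, and in normal use nshtrainer supplies the owning TrainerConfig itself:

from nshtrainer.trainer.plugin.precision import MixedPrecisionPluginConfig

config = MixedPrecisionPluginConfig(precision="16-mixed", device="cuda")
# None stands in for the TrainerConfig argument purely for brevity in this sketch.
precision_plugin = config.create_plugin(None)  # -> a Lightning MixedPrecision instance
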