nshtrainer 1.0.0b29__py3-none-any.whl → 1.0.0b31__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- nshtrainer/__init__.py +2 -0
- nshtrainer/configs/__init__.py +95 -3
- nshtrainer/configs/trainer/__init__.py +103 -3
- nshtrainer/configs/trainer/_config/__init__.py +10 -6
- nshtrainer/configs/trainer/accelerator/__init__.py +25 -0
- nshtrainer/configs/trainer/plugin/__init__.py +98 -0
- nshtrainer/configs/trainer/plugin/base/__init__.py +13 -0
- nshtrainer/configs/trainer/plugin/environment/__init__.py +41 -0
- nshtrainer/configs/trainer/plugin/io/__init__.py +23 -0
- nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +15 -0
- nshtrainer/configs/trainer/plugin/precision/__init__.py +43 -0
- nshtrainer/configs/trainer/strategy/__init__.py +11 -0
- nshtrainer/configs/trainer/trainer/__init__.py +2 -0
- nshtrainer/data/datamodule.py +2 -0
- nshtrainer/model/base.py +2 -0
- nshtrainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/_config.py +3 -47
- nshtrainer/trainer/accelerator.py +86 -0
- nshtrainer/trainer/plugin/__init__.py +10 -0
- nshtrainer/trainer/plugin/base.py +33 -0
- nshtrainer/trainer/plugin/environment.py +128 -0
- nshtrainer/trainer/plugin/io.py +62 -0
- nshtrainer/trainer/plugin/layer_sync.py +25 -0
- nshtrainer/trainer/plugin/precision.py +163 -0
- nshtrainer/trainer/strategy.py +51 -0
- nshtrainer/trainer/trainer.py +8 -9
- nshtrainer/util/hparams.py +17 -0
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/RECORD +30 -13
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b31.dist-info}/WHEEL +0 -0
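The common thread across these files is that the plugin, accelerator, and strategy configuration classes previously defined inline in nshtrainer/trainer/_config.py now live in dedicated modules (trainer/accelerator.py, trainer/strategy.py, and the new trainer/plugin/ package). A rough sketch of the new import paths, assembled from the hunks below (exact re-exports may differ):

from nshtrainer.trainer.accelerator import CUDAAcceleratorConfig, accelerator_registry
from nshtrainer.trainer.plugin import PluginConfig, plugin_registry
from nshtrainer.trainer.plugin.environment import SLURMEnvironmentPlugin
from nshtrainer.trainer.plugin.io import AsyncCheckpointIOPlugin, TorchCheckpointIOPlugin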
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.trainer.plugin.precision import (
+    BitsandbytesPluginConfig as BitsandbytesPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    DeepSpeedPluginConfig as DeepSpeedPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    DoublePrecisionPluginConfig as DoublePrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import DTypeConfig as DTypeConfig
+from nshtrainer.trainer.plugin.precision import (
+    FSDPPrecisionPluginConfig as FSDPPrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    HalfPrecisionPluginConfig as HalfPrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import (
+    MixedPrecisionPluginConfig as MixedPrecisionPluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import PluginConfigBase as PluginConfigBase
+from nshtrainer.trainer.plugin.precision import (
+    TransformerEnginePluginConfig as TransformerEnginePluginConfig,
+)
+from nshtrainer.trainer.plugin.precision import XLAPluginConfig as XLAPluginConfig
+from nshtrainer.trainer.plugin.precision import plugin_registry as plugin_registry
+
+__all__ = [
+    "BitsandbytesPluginConfig",
+    "DTypeConfig",
+    "DeepSpeedPluginConfig",
+    "DoublePrecisionPluginConfig",
+    "FSDPPrecisionPluginConfig",
+    "HalfPrecisionPluginConfig",
+    "MixedPrecisionPluginConfig",
+    "PluginConfigBase",
+    "TransformerEnginePluginConfig",
+    "XLAPluginConfig",
+    "plugin_registry",
+]
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.trainer.strategy import StrategyConfig as StrategyConfig
+from nshtrainer.trainer.strategy import StrategyConfigBase as StrategyConfigBase
+
+__all__ = [
+    "StrategyConfig",
+    "StrategyConfigBase",
+]
@@ -4,12 +4,14 @@ __codegen__ = True

 from nshtrainer.trainer.trainer import AcceleratorConfigBase as AcceleratorConfigBase
 from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
+from nshtrainer.trainer.trainer import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.trainer import StrategyConfigBase as StrategyConfigBase
 from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig

 __all__ = [
     "AcceleratorConfigBase",
     "EnvironmentConfig",
+    "PluginConfigBase",
     "StrategyConfigBase",
     "TrainerConfig",
 ]
nshtrainer/data/datamodule.py
CHANGED
@@ -12,11 +12,13 @@ from typing_extensions import Never, TypeVar, deprecated, override

 from ..model.mixins.callback import CallbackRegistrarModuleMixin
 from ..model.mixins.debug import _DebugModuleMixin
+from ..util.hparams import HyperparamsMixin

 THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)


 class LightningDataModuleBase(
+    HyperparamsMixin,
     _DebugModuleMixin,
     CallbackRegistrarModuleMixin,
     LightningDataModule,
nshtrainer/model/base.py
CHANGED
@@ -16,6 +16,7 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override

 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
+from ..util.hparams import HyperparamsMixin
 from .mixins.callback import CallbackModuleMixin
 from .mixins.debug import _DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin
@@ -54,6 +55,7 @@ VALID_REDUCE_OPS = (


 class LightningModuleBase(
+    HyperparamsMixin,
     _DebugModuleMixin,
     _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
nshtrainer/trainer/__init__.py
CHANGED
nshtrainer/trainer/_config.py
CHANGED
@@ -5,7 +5,6 @@ import logging
 import os
 import string
 import time
-from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
 from datetime import timedelta
 from pathlib import Path
@@ -18,14 +17,11 @@ from typing import (

 import nshconfig as C
 import numpy as np
-from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
 from lightning.pytorch.accelerators import Accelerator
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.plugins import _PLUGIN_INPUT
-from lightning.pytorch.plugins.layer_sync import LayerSync
-from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypeAliasType, TypedDict, override
@@ -58,6 +54,9 @@ from ..loggers.actsave import ActSaveLoggerConfig
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
 from ..util._environment_info import EnvironmentConfig
+from .accelerator import AcceleratorConfig, AcceleratorLiteral, accelerator_registry
+from .plugin import PluginConfig, plugin_registry
+from .strategy import StrategyConfig

 log = logging.getLogger(__name__)

@@ -71,37 +70,6 @@ class GradientClippingConfig(C.Config):
     """Norm type to use for gradient clipping."""


-Plugin = TypeAliasType(
-    "Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
-)
-
-
-class PluginConfigBase(C.Config, ABC):
-    @abstractmethod
-    def create_plugin(self) -> Plugin: ...
-
-
-plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
-PluginConfig = TypeAliasType(
-    "PluginConfig", Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]
-)
-
-AcceleratorLiteral = TypeAliasType(
-    "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
-)
-
-
-class AcceleratorConfigBase(C.Config, ABC):
-    @abstractmethod
-    def create_accelerator(self) -> Accelerator: ...
-
-
-accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
-AcceleratorConfig = TypeAliasType(
-    "AcceleratorConfig",
-    Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
-)
-
 StrategyLiteral = TypeAliasType(
     "StrategyLiteral",
     Literal[
@@ -135,17 +103,6 @@ StrategyLiteral = TypeAliasType(
 )


-class StrategyConfigBase(C.Config, ABC):
-    @abstractmethod
-    def create_strategy(self) -> Strategy: ...
-
-
-strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
-StrategyConfig = TypeAliasType(
-    "StrategyConfig",
-    Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()],
-)
-
 CheckpointCallbackConfig = TypeAliasType(
     "CheckpointCallbackConfig",
     Annotated[
|
|
441
398
|
|
442
399
|
|
443
400
|
@plugin_registry.rebuild_on_registers
|
444
|
-
@strategy_registry.rebuild_on_registers
|
445
401
|
@accelerator_registry.rebuild_on_registers
|
446
402
|
class TrainerConfig(C.Config):
|
447
403
|
# region Active Run Configuration
|
@@ -0,0 +1,86 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import TYPE_CHECKING, Annotated, Literal
|
5
|
+
|
6
|
+
import nshconfig as C
|
7
|
+
from lightning.pytorch.accelerators import Accelerator
|
8
|
+
from typing_extensions import TypeAliasType, override
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from ._config import TrainerConfig
|
12
|
+
|
13
|
+
AcceleratorLiteral = TypeAliasType(
|
14
|
+
"AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
|
15
|
+
)
|
16
|
+
|
17
|
+
|
18
|
+
class AcceleratorConfigBase(C.Config, ABC):
|
19
|
+
@abstractmethod
|
20
|
+
def create_accelerator(self, trainer_config: "TrainerConfig") -> Accelerator: ...
|
21
|
+
|
22
|
+
|
23
|
+
accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
|
24
|
+
|
25
|
+
AcceleratorConfig = TypeAliasType(
|
26
|
+
"AcceleratorConfig",
|
27
|
+
Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
|
28
|
+
)
|
29
|
+
|
30
|
+
|
31
|
+
@accelerator_registry.register
|
32
|
+
class CPUAcceleratorConfig(AcceleratorConfigBase):
|
33
|
+
name: Literal["cpu"] = "cpu"
|
34
|
+
|
35
|
+
"""Accelerator for CPU devices."""
|
36
|
+
|
37
|
+
@override
|
38
|
+
def create_accelerator(self, trainer_config) -> Accelerator:
|
39
|
+
from lightning.pytorch.accelerators.cpu import CPUAccelerator
|
40
|
+
|
41
|
+
return CPUAccelerator()
|
42
|
+
|
43
|
+
|
44
|
+
@accelerator_registry.register
|
45
|
+
class CUDAAcceleratorConfig(AcceleratorConfigBase):
|
46
|
+
name: Literal["gpu"] = "gpu"
|
47
|
+
|
48
|
+
"""Accelerator for NVIDIA CUDA devices."""
|
49
|
+
|
50
|
+
@override
|
51
|
+
def create_accelerator(self, trainer_config) -> Accelerator:
|
52
|
+
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
|
53
|
+
|
54
|
+
return CUDAAccelerator()
|
55
|
+
|
56
|
+
|
57
|
+
@accelerator_registry.register
|
58
|
+
class MPSAcceleratorConfig(AcceleratorConfigBase):
|
59
|
+
name: Literal["mps"] = "mps"
|
60
|
+
|
61
|
+
"""Accelerator for Metal Apple Silicon GPU devices.
|
62
|
+
|
63
|
+
.. warning:: Use of this accelerator beyond import and instantiation is experimental.
|
64
|
+
"""
|
65
|
+
|
66
|
+
@override
|
67
|
+
def create_accelerator(self, trainer_config) -> Accelerator:
|
68
|
+
from lightning.pytorch.accelerators.mps import MPSAccelerator
|
69
|
+
|
70
|
+
return MPSAccelerator()
|
71
|
+
|
72
|
+
|
73
|
+
@accelerator_registry.register
|
74
|
+
class XLAAcceleratorConfig(AcceleratorConfigBase):
|
75
|
+
name: Literal["tpu"] = "tpu"
|
76
|
+
|
77
|
+
"""Accelerator for XLA devices, normally TPUs.
|
78
|
+
|
79
|
+
.. warning:: Use of this accelerator beyond import and instantiation is experimental.
|
80
|
+
"""
|
81
|
+
|
82
|
+
@override
|
83
|
+
def create_accelerator(self, trainer_config) -> Accelerator:
|
84
|
+
from lightning.pytorch.accelerators.xla import XLAAccelerator
|
85
|
+
|
86
|
+
return XLAAccelerator()
|
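All four built-in accelerator configs above follow the same pattern: a `name` discriminator, registration via `@accelerator_registry.register`, and a `create_accelerator(trainer_config)` factory that defers the Lightning import. As an illustrative sketch of extending that pattern (the class name and "my_cpu" discriminator below are hypothetical, not part of the package):

from typing import Literal

from lightning.pytorch.accelerators import Accelerator
from typing_extensions import override

from nshtrainer.trainer.accelerator import AcceleratorConfigBase, accelerator_registry


@accelerator_registry.register
class MyCPUAcceleratorConfig(AcceleratorConfigBase):
    # "my_cpu" is a made-up discriminator value for illustration only.
    name: Literal["my_cpu"] = "my_cpu"

    @override
    def create_accelerator(self, trainer_config) -> Accelerator:
        # Defer the heavy Lightning import, mirroring the built-in configs.
        from lightning.pytorch.accelerators.cpu import CPUAccelerator

        return CPUAccelerator()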
@@ -0,0 +1,10 @@
+from __future__ import annotations
+
+from . import environment as environment
+from . import io as io
+from . import layer_sync as layer_sync
+from . import precision as precision
+from .base import Plugin as Plugin
+from .base import PluginConfig as PluginConfig
+from .base import PluginConfigBase as PluginConfigBase
+from .base import plugin_registry as plugin_registry
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Annotated
+
+import nshconfig as C
+from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
+from lightning.pytorch.plugins.layer_sync import LayerSync
+from lightning.pytorch.plugins.precision.precision import Precision
+from typing_extensions import TypeAliasType
+
+if TYPE_CHECKING:
+    from .._config import TrainerConfig
+log = logging.getLogger(__name__)
+
+
+Plugin = TypeAliasType(
+    "Plugin", Precision | ClusterEnvironment | CheckpointIO | LayerSync
+)
+
+
+class PluginConfigBase(C.Config, ABC):
+    @abstractmethod
+    def create_plugin(self, trainer_config: "TrainerConfig") -> Plugin: ...
+
+
+plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
+
+PluginConfig = TypeAliasType(
+    "PluginConfig",
+    Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
+)
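The base module above is the new home of the `Plugin` union, `PluginConfigBase`, `PluginConfig`, and `plugin_registry`; compared with the removed `_config.py` version, `create_plugin` now also receives the `TrainerConfig`. A minimal sketch of how downstream code might hold an arbitrary registered plugin config through the `PluginConfig` alias (the class and field names here are illustrative, not from the package):

import nshconfig as C

from nshtrainer.trainer.plugin import PluginConfig


class MyJobConfig(C.Config):
    # Any registered plugin config can be stored here; its "name" field is the
    # registry discriminator.
    checkpoint_plugin: PluginConfig | None = None


# Materializing the Lightning plugin later mirrors the built-in configs:
# my_job_config.checkpoint_plugin.create_plugin(trainer_config)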
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+import signal
+from typing import Any, Literal
+
+from lightning.pytorch.plugins.environments import ClusterEnvironment
+from typing_extensions import override
+
+from ...util.config.dtype import DTypeConfig
+from .base import PluginConfigBase, plugin_registry
+
+
+@plugin_registry.register
+class KubeflowEnvironmentPlugin(PluginConfigBase):
+    name: Literal["kubeflow_environment"] = "kubeflow_environment"
+
+    """Environment for distributed training using the PyTorchJob operator from Kubeflow.
+
+    This environment, unlike others, does not get auto-detected and needs to be passed
+    to the Fabric/Trainer constructor manually.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment
+
+        return KubeflowEnvironment()
+
+
+@plugin_registry.register
+class LightningEnvironmentPlugin(PluginConfigBase):
+    name: Literal["lightning_environment"] = "lightning_environment"
+
+    """The default environment used by Lightning for a single node or free cluster (not managed).
+
+    There are two modes the Lightning environment can operate with:
+    1. User launches main process by `python train.py ...` with no additional environment variables.
+       Lightning will spawn new worker processes for distributed training in the current node.
+    2. User launches all processes manually or with utilities like `torch.distributed.launch`.
+       The appropriate environment variables need to be set, and at minimum `LOCAL_RANK`.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.lightning import LightningEnvironment
+
+        return LightningEnvironment()
+
+
+@plugin_registry.register
+class LSFEnvironmentPlugin(PluginConfigBase):
+    name: Literal["lsf_environment"] = "lsf_environment"
+
+    """An environment for running on clusters managed by the LSF resource manager.
+
+    It is expected that any execution using this ClusterEnvironment was executed
+    using the Job Step Manager i.e. `jsrun`.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.lsf import LSFEnvironment
+
+        return LSFEnvironment()
+
+
+@plugin_registry.register
+class MPIEnvironmentPlugin(PluginConfigBase):
+    name: Literal["mpi_environment"] = "mpi_environment"
+
+    """An environment for running on clusters with processes created through MPI.
+
+    Requires the installation of the `mpi4py` package.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.mpi import MPIEnvironment
+
+        return MPIEnvironment()
+
+
+@plugin_registry.register
+class SLURMEnvironmentPlugin(PluginConfigBase):
+    name: Literal["slurm_environment"] = "slurm_environment"
+
+    auto_requeue: bool = True
+    """Whether automatic job resubmission is enabled or not."""
+
+    requeue_signal: signal.Signals | None = None
+    """The signal that SLURM will send to indicate that the job should be requeued."""
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
+
+        return SLURMEnvironment(
+            auto_requeue=self.auto_requeue,
+            requeue_signal=self.requeue_signal,
+        )
+
+
+@plugin_registry.register
+class TorchElasticEnvironmentPlugin(PluginConfigBase):
+    name: Literal["torchelastic_environment"] = "torchelastic_environment"
+
+    """Environment for fault-tolerant and elastic training with torchelastic."""
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.torchelastic import (
+            TorchElasticEnvironment,
+        )
+
+        return TorchElasticEnvironment()
+
+
+@plugin_registry.register
+class XLAEnvironmentPlugin(PluginConfigBase):
+    name: Literal["xla_environment"] = "xla_environment"
+
+    """Cluster environment for training on a TPU Pod with the PyTorch/XLA library."""
+
+    @override
+    def create_plugin(self, trainer_config) -> ClusterEnvironment:
+        from lightning.fabric.plugins.environments.xla import XLAEnvironment
+
+        return XLAEnvironment()
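Of the environment plugins above, only the SLURM config carries options of its own. A brief, hypothetical construction sketch (the signal choice is illustrative):

import signal

from nshtrainer.trainer.plugin.environment import SLURMEnvironmentPlugin

# Requeue automatically and have SLURM signal the job with SIGUSR1 beforehand.
slurm = SLURMEnvironmentPlugin(auto_requeue=True, requeue_signal=signal.SIGUSR1)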
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.plugins.io import CheckpointIO
+from typing_extensions import override
+
+from .base import PluginConfig, PluginConfigBase, plugin_registry
+
+
+@plugin_registry.register
+class AsyncCheckpointIOPlugin(PluginConfigBase):
+    name: Literal["async_checkpoint"] = "async_checkpoint"
+
+    """Enables saving the checkpoints asynchronously in a thread.
+
+    .. warning:: This is an experimental feature.
+    """
+
+    checkpoint_io: PluginConfig | None = None
+    """A checkpoint IO plugin that is used as the basis for async checkpointing."""
+
+    @override
+    def create_plugin(self, trainer_config) -> CheckpointIO:
+        from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+
+        base_io = (
+            self.checkpoint_io.create_plugin(trainer_config)
+            if self.checkpoint_io
+            else None
+        )
+        if base_io is not None and not isinstance(base_io, CheckpointIO):
+            raise TypeError(
+                f"Expected `checkpoint_io` to be a `CheckpointIO` instance, but got {type(base_io)}."
+            )
+        return AsyncCheckpointIO(checkpoint_io=base_io)
+
+
+@plugin_registry.register
+class TorchCheckpointIOPlugin(PluginConfigBase):
+    name: Literal["torch_checkpoint"] = "torch_checkpoint"
+
+    """CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively."""
+
+    @override
+    def create_plugin(self, trainer_config) -> CheckpointIO:
+        from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
+
+        return TorchCheckpointIO()
+
+
+@plugin_registry.register
+class XLACheckpointIOPlugin(PluginConfigBase):
+    name: Literal["xla_checkpoint"] = "xla_checkpoint"
+
+    """CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies."""
+
+    @override
+    def create_plugin(self, trainer_config) -> CheckpointIO:
+        from lightning.fabric.plugins.io.xla import XLACheckpointIO
+
+        return XLACheckpointIO()
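Because `AsyncCheckpointIOPlugin.checkpoint_io` accepts any registered plugin config, the checkpoint IO configs above compose; a short, hypothetical sketch:

from nshtrainer.trainer.plugin.io import AsyncCheckpointIOPlugin, TorchCheckpointIOPlugin

# Wrap plain torch checkpoint IO so saves run in a background thread; create_plugin
# raises TypeError if the wrapped config does not produce a CheckpointIO.
async_ckpt = AsyncCheckpointIOPlugin(checkpoint_io=TorchCheckpointIOPlugin())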
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.plugins.layer_sync import LayerSync
+from typing_extensions import override
+
+from .base import PluginConfigBase, plugin_registry
+
+
+@plugin_registry.register
+class TorchSyncBatchNormPlugin(PluginConfigBase):
+    name: Literal["torch_sync_batchnorm"] = "torch_sync_batchnorm"
+
+    """A plugin that wraps all batch normalization layers of a model with synchronization
+    logic for multiprocessing.
+
+    This plugin has no effect in single-device operation.
+    """
+
+    @override
+    def create_plugin(self, trainer_config) -> LayerSync:
+        from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm
+
+        return TorchSyncBatchNorm()