nshtrainer 1.0.0b28__py3-none-any.whl → 1.0.0b30__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -0
- nshtrainer/configs/__init__.py +95 -3
- nshtrainer/configs/trainer/__init__.py +103 -3
- nshtrainer/configs/trainer/_config/__init__.py +10 -6
- nshtrainer/configs/trainer/accelerator/__init__.py +25 -0
- nshtrainer/configs/trainer/plugin/__init__.py +98 -0
- nshtrainer/configs/trainer/plugin/base/__init__.py +13 -0
- nshtrainer/configs/trainer/plugin/environment/__init__.py +41 -0
- nshtrainer/configs/trainer/plugin/io/__init__.py +23 -0
- nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +15 -0
- nshtrainer/configs/trainer/plugin/precision/__init__.py +43 -0
- nshtrainer/configs/trainer/strategy/__init__.py +11 -0
- nshtrainer/configs/trainer/trainer/__init__.py +2 -0
- nshtrainer/metrics/_config.py +2 -4
- nshtrainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/_config.py +6 -51
- nshtrainer/trainer/accelerator.py +86 -0
- nshtrainer/trainer/plugin/__init__.py +10 -0
- nshtrainer/trainer/plugin/base.py +33 -0
- nshtrainer/trainer/plugin/environment.py +128 -0
- nshtrainer/trainer/plugin/io.py +62 -0
- nshtrainer/trainer/plugin/layer_sync.py +25 -0
- nshtrainer/trainer/plugin/precision.py +163 -0
- nshtrainer/trainer/strategy.py +51 -0
- nshtrainer/trainer/trainer.py +8 -9
- {nshtrainer-1.0.0b28.dist-info → nshtrainer-1.0.0b30.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b28.dist-info → nshtrainer-1.0.0b30.dist-info}/RECORD +28 -13
- nshtrainer/util/_useful_types.py +0 -316
- {nshtrainer-1.0.0b28.dist-info → nshtrainer-1.0.0b30.dist-info}/WHEEL +0 -0
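
The headline change is structural: the accelerator, strategy, and plugin config classes move out of the monolithic `nshtrainer/trainer/_config.py` into dedicated `trainer/accelerator.py`, `trainer/plugin/`, and `trainer/strategy.py` modules, with matching `configs/` re-export packages. Based on the import changes in the trainer.py diff below, the new canonical import locations look like this:

```python
# New import locations for the trainer config base classes
# (as used by nshtrainer/trainer/trainer.py in 1.0.0b30; see the diff below).
from nshtrainer.trainer.accelerator import AcceleratorConfigBase
from nshtrainer.trainer.plugin import PluginConfigBase
from nshtrainer.trainer.strategy import StrategyConfigBase
```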
nshtrainer/trainer/plugin/precision.py ADDED (+163 lines)

```python
from __future__ import annotations

from typing import Any, Literal

from lightning.pytorch.plugins.precision import Precision
from typing_extensions import override

from ...util.config.dtype import DTypeConfig
from .base import PluginConfigBase, plugin_registry


@plugin_registry.register
class MixedPrecisionPluginConfig(PluginConfigBase):
    name: Literal["mixed_precision"] = "mixed_precision"

    precision: Literal["16-mixed", "bf16-mixed"]
    """Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``)."""

    device: str
    """The device for ``torch.autocast``."""

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.amp import MixedPrecision

        return MixedPrecision(self.precision, self.device)


@plugin_registry.register
class BitsandbytesPluginConfig(PluginConfigBase):
    name: Literal["bitsandbytes_precision"] = "bitsandbytes_precision"

    mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"]
    """The quantization mode to use."""

    dtype: DTypeConfig | None = None
    """The compute dtype to use."""

    ignore_modules: set[str] | None = None
    """The submodules whose Linear layers should not be replaced.

    This might be desirable for numerical stability. The string will be checked
    as a prefix, so a value like "transformer.blocks" will ignore all linear
    layers in all of the transformer blocks.
    """

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.bitsandbytes import (
            BitsandbytesPrecision,
        )

        return BitsandbytesPrecision(
            mode=self.mode,
            dtype=self.dtype.torch_dtype if self.dtype is not None else None,
            ignore_modules=self.ignore_modules,
        )


@plugin_registry.register
class DeepSpeedPluginConfig(PluginConfigBase):
    name: Literal["deepspeed_precision"] = "deepspeed_precision"

    precision: Literal["16-true", "bf16-true", "16-mixed", "bf16-mixed", "32-true"]
    """Full precision (32-true), half precision (16-true, bf16-true) or
    mixed precision (16-mixed, bf16-mixed)."""

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision

        return DeepSpeedPrecision(precision=self.precision)


@plugin_registry.register
class DoublePrecisionPluginConfig(PluginConfigBase):
    name: Literal["double_precision"] = "double_precision"

    precision: Literal["64-true"] = "64-true"
    """Plugin for training with double (``torch.float64``) precision."""

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.double import DoublePrecision

        return DoublePrecision()


@plugin_registry.register
class FSDPPrecisionPluginConfig(PluginConfigBase):
    name: Literal["fsdp_precision"] = "fsdp_precision"

    precision: Literal["16-true", "bf16-true", "16-mixed", "bf16-mixed", "32-true"]
    """Full precision (32-true), half precision (16-true, bf16-true) or
    mixed precision (16-mixed, bf16-mixed)."""

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision

        return FSDPPrecision(precision=self.precision)


@plugin_registry.register
class HalfPrecisionPluginConfig(PluginConfigBase):
    name: Literal["half_precision"] = "half_precision"

    precision: Literal["bf16-true", "16-true"]
    """Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``)."""

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.half import HalfPrecision

        return HalfPrecision(precision=self.precision)


@plugin_registry.register
class TransformerEnginePluginConfig(PluginConfigBase):
    name: Literal["transformer_engine_precision"] = "transformer_engine_precision"

    weights_dtype: DTypeConfig
    """The weights dtype to use."""

    recipe: dict[str, Any] | None = None
    """Recipe for the DelayedScaling configuration in dict format."""

    replace_layers: bool | None = None
    """Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their
    Transformer Engine alternatives."""

    fallback_compute_dtype: DTypeConfig | None = None
    """The compute dtype to use for operations that don't support fp8 autocast.
    Defaults to the same as weights_dtype."""

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.transformer_engine import (
            TransformerEnginePrecision,
        )

        return TransformerEnginePrecision(
            weights_dtype=self.weights_dtype.torch_dtype,
            recipe=self.recipe,
            replace_layers=self.replace_layers,
            fallback_compute_dtype=self.fallback_compute_dtype.torch_dtype
            if self.fallback_compute_dtype
            else None,
        )


@plugin_registry.register
class XLAPluginConfig(PluginConfigBase):
    name: Literal["xla_precision"] = "xla_precision"

    precision: Literal["32-true", "16-true", "bf16-true"]
    """Full precision (32-true) or half precision (16-true, bf16-true)."""

    @override
    def create_plugin(self, trainer_config) -> Precision:
        from lightning.pytorch.plugins.precision.xla import XLAPrecision

        return XLAPrecision(precision=self.precision)
```
nshtrainer/trainer/strategy.py ADDED (+51 lines)

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal

import nshconfig as C
from lightning.pytorch.strategies.strategy import Strategy
from typing_extensions import TypeAliasType

if TYPE_CHECKING:
    from ._config import TrainerConfig

StrategyLiteral = TypeAliasType(
    "StrategyLiteral",
    Literal[
        "auto",
        "ddp",
        "ddp_find_unused_parameters_false",
        "ddp_find_unused_parameters_true",
        "ddp_spawn",
        "ddp_spawn_find_unused_parameters_false",
        "ddp_spawn_find_unused_parameters_true",
        "ddp_fork",
        "ddp_fork_find_unused_parameters_false",
        "ddp_fork_find_unused_parameters_true",
        "ddp_notebook",
        "dp",
        "deepspeed",
        "deepspeed_stage_1",
        "deepspeed_stage_1_offload",
        "deepspeed_stage_2",
        "deepspeed_stage_2_offload",
        "deepspeed_stage_3",
        "deepspeed_stage_3_offload",
        "deepspeed_stage_3_offload_nvme",
        "fsdp",
        "fsdp_cpu_offload",
        "single_xla",
        "xla_fsdp",
        "xla",
        "single_tpu",
    ],
)


class StrategyConfigBase(C.Config, ABC):
    @abstractmethod
    def create_strategy(self, trainer_config: "TrainerConfig") -> Strategy: ...


StrategyConfig = TypeAliasType("StrategyConfig", StrategyConfigBase)
```
nshtrainer/trainer/trainer.py CHANGED

```diff
@@ -22,14 +22,12 @@ from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
 from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
-from ._config import (
-    AcceleratorConfigBase,
-    LightningTrainerKwargs,
-    StrategyConfigBase,
-    TrainerConfig,
-)
+from ._config import LightningTrainerKwargs, TrainerConfig
 from ._runtime_callback import RuntimeTrackerCallback, Stage
+from .accelerator import AcceleratorConfigBase
+from .plugin import PluginConfigBase
 from .signal_connector import _SignalConnector
+from .strategy import StrategyConfigBase
 
 log = logging.getLogger(__name__)
 
@@ -172,12 +170,12 @@ class Trainer(LightningTrainer):
 
         if (accelerator := hparams.accelerator) is not None:
             if isinstance(accelerator, AcceleratorConfigBase):
-                accelerator = accelerator.create_accelerator()
+                accelerator = accelerator.create_accelerator(hparams)
             _update_kwargs(accelerator=accelerator)
 
         if (strategy := hparams.strategy) is not None:
             if isinstance(strategy, StrategyConfigBase):
-                strategy = strategy.create_strategy()
+                strategy = strategy.create_strategy(hparams)
             _update_kwargs(strategy=strategy)
 
         if (precision := hparams.precision) is not None:
@@ -238,7 +236,8 @@ class Trainer(LightningTrainer):
         if plugin_configs := hparams.plugins:
             _update_kwargs(
                 plugins=[
-                    plugin_config.create_plugin()
+                    plugin_config.create_plugin(hparams)
+                    for plugin_config in plugin_configs
                 ]
             )
 
```
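
The substantive change in this diff is the new signature: `create_accelerator`, `create_strategy`, and `create_plugin` now receive the resolved `TrainerConfig` (`hparams`), so config objects can consult the full trainer configuration when building their Lightning counterparts. Any downstream config subclass written against 1.0.0b28 has to accept the extra argument. A sketch of the migration, with `MyPluginConfig` and its `name` invented for illustration, mirroring the registration pattern of the built-in precision configs above:

```python
from typing import Literal

from lightning.pytorch.plugins.precision import Precision
from typing_extensions import override

from nshtrainer.trainer.plugin.base import PluginConfigBase, plugin_registry


@plugin_registry.register
class MyPluginConfig(PluginConfigBase):
    name: Literal["my_plugin"] = "my_plugin"

    @override
    def create_plugin(self, trainer_config) -> Precision:
        # 1.0.0b28 signature was `def create_plugin(self) -> Precision:`;
        # 1.0.0b30 passes the TrainerConfig in as `trainer_config`.
        from lightning.pytorch.plugins.precision.half import HalfPrecision

        return HalfPrecision(precision="bf16-true")
```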
{nshtrainer-1.0.0b28.dist-info → nshtrainer-1.0.0b30.dist-info}/RECORD CHANGED

```diff
@@ -1,5 +1,5 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=52OB7QRlhrTCIdDecpT7yEZyZM1XvYxywhuORn1eKoY,814
 nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
 nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
 nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
@@ -31,7 +31,7 @@ nshtrainer/callbacks/shared_parameters.py,sha256=ggMI1krkqN7sGOrjK_I96IsTMYMXHoV
 nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
 nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
 nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
-nshtrainer/configs/__init__.py,sha256=
+nshtrainer/configs/__init__.py,sha256=eS3naq6EG1vCq28G2nAW1CqYFdsrh6ueBlzX_LazgUw,14159
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
 nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
@@ -81,9 +81,17 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
-nshtrainer/configs/trainer/
+nshtrainer/configs/trainer/__init__.py,sha256=hKMI_2ve5zcsQys2DDQDv7OmshYsIG0uJlCLreVHpF0,7779
+nshtrainer/configs/trainer/_config/__init__.py,sha256=Xw6I_9tUemDbHncpjKHRqye_e1_OyubK_FJcvdcQ0yc,4020
+nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
+nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
+nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
+nshtrainer/configs/trainer/plugin/environment/__init__.py,sha256=3o16x4qRAOvkJH9Vg4-QwsEODDC6aP_OXRnPPkm_xSo,1376
+nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=W6G67JnigB6d3MiwLrbSKgtIZLUccXznp-IXwkK1J4U,743
+nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
+nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
+nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
+nshtrainer/configs/trainer/trainer/__init__.py,sha256=QnuhMQNAa1nSVN2o50_WeKAQG_qkNlkeoq9zTjjwmTI,586
 nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
 nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
 nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
@@ -104,7 +112,7 @@ nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcV
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
-nshtrainer/metrics/_config.py,sha256=
+nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
 nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
@@ -121,13 +129,20 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
 nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
-nshtrainer/trainer/__init__.py,sha256=
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/__init__.py,sha256=ggDHzIUbABezh4BjEwrxyWuXmuDBV-x4jv9gwXgVHU0,250
+nshtrainer/trainer/_config.py,sha256=0GgofvaWf5Vo9REXNJpTvpVVRlFExGTOzcOt4jwJXNk,34129
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
+nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
+nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
+nshtrainer/trainer/plugin/base.py,sha256=9-qUHXGpll_yCylun0899sbmJDpyhD9IQcBtVrJx38I,919
+nshtrainer/trainer/plugin/environment.py,sha256=NW0qbsbvDPe59JGOMgPLq1fj7szLucIV1WRTxCrcjF4,4367
+nshtrainer/trainer/plugin/io.py,sha256=nm6YDCVZAhmPvLaLnw6q4BrK2Gj2wvD5ZLDhj1xneEE,2030
+nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv6DrYFIaXOo,735
+nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
-nshtrainer/trainer/
+nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
+nshtrainer/trainer/trainer.py,sha256=l2kJs27v4IHZnzxExr0zX0sVex0wukgiD2Wn_0wiGJg,20836
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
-nshtrainer/util/_useful_types.py,sha256=7yd1ajSmjwfmZdBPlHVrIG3iXl1-T3n83JI53N8C7as,8080
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
 nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
@@ -138,6 +153,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.
-nshtrainer-1.0.
-nshtrainer-1.0.
+nshtrainer-1.0.0b30.dist-info/METADATA,sha256=zxFm4X5APkZR6E4E8-jzVghTwYEYCJQzCHpCV_8hWzg,988
+nshtrainer-1.0.0b30.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+nshtrainer-1.0.0b30.dist-info/RECORD,,
```
nshtrainer/util/_useful_types.py DELETED (-316 lines)

```python
"""Credit to useful-types from https://github.com/hauntsaninja/useful_types"""

from __future__ import annotations

from collections.abc import Awaitable, Iterable, Iterator, Sequence, Sized
from collections.abc import Set as AbstractSet
from os import PathLike
from typing import Any, TypeVar, overload

from typing_extensions import (
    Buffer,
    Literal,
    Protocol,
    SupportsIndex,
    TypeAlias,
    TypeAliasType,
)

_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)

# For partially known annotations. Usually, fields where type annotations
# haven't been added are left unannotated, but in some situations this
# isn't possible or a type is already partially known. In cases like these,
# use Incomplete instead of Any as a marker. For example, use
# "Incomplete | None" instead of "Any | None".
Incomplete: TypeAlias = Any


class IdentityFunction(Protocol):
    def __call__(self, __x: _T) -> _T: ...


# ====================
# Comparison protocols
# ====================


class SupportsDunderLT(Protocol[_T_contra]):
    def __lt__(self, __other: _T_contra) -> bool: ...


class SupportsDunderGT(Protocol[_T_contra]):
    def __gt__(self, __other: _T_contra) -> bool: ...


class SupportsDunderLE(Protocol[_T_contra]):
    def __le__(self, __other: _T_contra) -> bool: ...


class SupportsDunderGE(Protocol[_T_contra]):
    def __ge__(self, __other: _T_contra) -> bool: ...


class SupportsAllComparisons(
    SupportsDunderLT[Any],
    SupportsDunderGT[Any],
    SupportsDunderLE[Any],
    SupportsDunderGE[Any],
    Protocol,
): ...


SupportsRichComparison = TypeAliasType(
    "SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
)
SupportsRichComparisonT = TypeVar(
    "SupportsRichComparisonT", bound=SupportsRichComparison
)

# ====================
# Dunder protocols
# ====================


class SupportsNext(Protocol[_T_co]):
    def __next__(self) -> _T_co: ...


class SupportsAnext(Protocol[_T_co]):
    def __anext__(self) -> Awaitable[_T_co]: ...


class SupportsAdd(Protocol[_T_contra, _T_co]):
    def __add__(self, __x: _T_contra) -> _T_co: ...


class SupportsRAdd(Protocol[_T_contra, _T_co]):
    def __radd__(self, __x: _T_contra) -> _T_co: ...


class SupportsSub(Protocol[_T_contra, _T_co]):
    def __sub__(self, __x: _T_contra) -> _T_co: ...


class SupportsRSub(Protocol[_T_contra, _T_co]):
    def __rsub__(self, __x: _T_contra) -> _T_co: ...


class SupportsDivMod(Protocol[_T_contra, _T_co]):
    def __divmod__(self, __other: _T_contra) -> _T_co: ...


class SupportsRDivMod(Protocol[_T_contra, _T_co]):
    def __rdivmod__(self, __other: _T_contra) -> _T_co: ...


# This protocol is generic over the iterator type, while Iterable is
# generic over the type that is iterated over.
class SupportsIter(Protocol[_T_co]):
    def __iter__(self) -> _T_co: ...


# This protocol is generic over the iterator type, while AsyncIterable is
# generic over the type that is iterated over.
class SupportsAiter(Protocol[_T_co]):
    def __aiter__(self) -> _T_co: ...


class SupportsLenAndGetItem(Protocol[_T_co]):
    def __len__(self) -> int: ...
    def __getitem__(self, __k: int) -> _T_co: ...


class SupportsTrunc(Protocol):
    def __trunc__(self) -> int: ...


# ====================
# Mapping-like protocols
# ====================


class SupportsItems(Protocol[_KT_co, _VT_co]):
    def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ...


class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
    def keys(self) -> Iterable[_KT]: ...
    def __getitem__(self, __key: _KT) -> _VT_co: ...


class SupportsGetItem(Protocol[_KT_contra, _VT_co]):
    def __contains__(self, __x: Any) -> bool: ...
    def __getitem__(self, __key: _KT_contra) -> _VT_co: ...


class SupportsItemAccess(SupportsGetItem[_KT_contra, _VT], Protocol[_KT_contra, _VT]):
    def __setitem__(self, __key: _KT_contra, __value: _VT) -> None: ...
    def __delitem__(self, __key: _KT_contra) -> None: ...


# ====================
# File handling
# ====================

StrPath: TypeAlias = str | PathLike[str]
BytesPath: TypeAlias = bytes | PathLike[bytes]
StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]

OpenTextModeUpdating: TypeAlias = Literal[
    "r+", "+r", "rt+", "r+t", "+rt", "tr+", "t+r", "+tr",
    "w+", "+w", "wt+", "w+t", "+wt", "tw+", "t+w", "+tw",
    "a+", "+a", "at+", "a+t", "+at", "ta+", "t+a", "+ta",
    "x+", "+x", "xt+", "x+t", "+xt", "tx+", "t+x", "+tx",
]
OpenTextModeWriting: TypeAlias = Literal[
    "w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"
]
OpenTextModeReading: TypeAlias = Literal[
    "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
]
OpenTextMode: TypeAlias = (
    OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
)
OpenBinaryModeUpdating: TypeAlias = Literal[
    "rb+", "r+b", "+rb", "br+", "b+r", "+br",
    "wb+", "w+b", "+wb", "bw+", "b+w", "+bw",
    "ab+", "a+b", "+ab", "ba+", "b+a", "+ba",
    "xb+", "x+b", "+xb", "bx+", "b+x", "+bx",
]
OpenBinaryModeWriting: TypeAlias = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
OpenBinaryModeReading: TypeAlias = Literal[
    "rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"
]
OpenBinaryMode: TypeAlias = (
    OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
)


class HasFileno(Protocol):
    def fileno(self) -> int: ...


FileDescriptor: TypeAlias = int
FileDescriptorLike: TypeAlias = int | HasFileno
FileDescriptorOrPath: TypeAlias = int | StrOrBytesPath


class SupportsRead(Protocol[_T_co]):
    def read(self, __length: int = ...) -> _T_co: ...


class SupportsReadline(Protocol[_T_co]):
    def readline(self, __length: int = ...) -> _T_co: ...


class SupportsNoArgReadline(Protocol[_T_co]):
    def readline(self) -> _T_co: ...


class SupportsWrite(Protocol[_T_contra]):
    def write(self, __s: _T_contra) -> object: ...


# ====================
# Buffer protocols
# ====================

# Unfortunately PEP 688 does not allow us to distinguish read-only
# from writable buffers. We use these aliases for readability for now.
# Perhaps a future extension of the buffer protocol will allow us to
# distinguish these cases in the type system.
ReadOnlyBuffer: TypeAlias = Buffer
# Anything that implements the read-write buffer interface.
WriteableBuffer: TypeAlias = Buffer
# Same as WriteableBuffer, but also includes read-only buffer types (like bytes).
ReadableBuffer: TypeAlias = Buffer


class SliceableBuffer(Buffer, Protocol):
    def __getitem__(self, __slice: slice) -> Sequence[int]: ...


class IndexableBuffer(Buffer, Protocol):
    def __getitem__(self, __i: int) -> int: ...


class SupportsGetItemBuffer(SliceableBuffer, IndexableBuffer, Protocol):
    def __contains__(self, __x: Any) -> bool: ...
    @overload
    def __getitem__(self, __slice: slice) -> Sequence[int]: ...
    @overload
    def __getitem__(self, __i: int) -> int: ...


class SizedBuffer(Sized, Buffer, Protocol): ...


# Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
# This works because str.__contains__ does not accept object (either in typeshed or at runtime)
class SequenceNotStr(Protocol[_T_co]):
    @overload
    def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
    @overload
    def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
    def __contains__(self, value: object, /) -> bool: ...
    def __len__(self) -> int: ...
    def __iter__(self) -> Iterator[_T_co]: ...
    def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
    def count(self, value: Any, /) -> int: ...
    def __reversed__(self) -> Iterator[_T_co]: ...
```
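
Nothing in this release replaces these helpers inside nshtrainer, so any code importing from `nshtrainer.util._useful_types` needs another source for them. One option, assuming the upstream distribution this file credits still exports the same names (worth verifying), is to depend on it directly:

```python
# Assumption: the `useful-types` package from
# https://github.com/hauntsaninja/useful_types exports these helpers.
from useful_types import SequenceNotStr, SupportsRichComparison
```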
{nshtrainer-1.0.0b28.dist-info → nshtrainer-1.0.0b30.dist-info}/WHEEL
File without changes