nshtrainer 1.0.0b29__py3-none-any.whl → 1.0.0b30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -0
- nshtrainer/configs/__init__.py +95 -3
- nshtrainer/configs/trainer/__init__.py +103 -3
- nshtrainer/configs/trainer/_config/__init__.py +10 -6
- nshtrainer/configs/trainer/accelerator/__init__.py +25 -0
- nshtrainer/configs/trainer/plugin/__init__.py +98 -0
- nshtrainer/configs/trainer/plugin/base/__init__.py +13 -0
- nshtrainer/configs/trainer/plugin/environment/__init__.py +41 -0
- nshtrainer/configs/trainer/plugin/io/__init__.py +23 -0
- nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +15 -0
- nshtrainer/configs/trainer/plugin/precision/__init__.py +43 -0
- nshtrainer/configs/trainer/strategy/__init__.py +11 -0
- nshtrainer/configs/trainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/__init__.py +2 -0
- nshtrainer/trainer/_config.py +3 -47
- nshtrainer/trainer/accelerator.py +86 -0
- nshtrainer/trainer/plugin/__init__.py +10 -0
- nshtrainer/trainer/plugin/base.py +33 -0
- nshtrainer/trainer/plugin/environment.py +128 -0
- nshtrainer/trainer/plugin/io.py +62 -0
- nshtrainer/trainer/plugin/layer_sync.py +25 -0
- nshtrainer/trainer/plugin/precision.py +163 -0
- nshtrainer/trainer/strategy.py +51 -0
- nshtrainer/trainer/trainer.py +8 -9
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b30.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b30.dist-info}/RECORD +27 -11
- {nshtrainer-1.0.0b29.dist-info → nshtrainer-1.0.0b30.dist-info}/WHEEL +0 -0
@@ -0,0 +1,51 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
5
|
+
|
6
|
+
import nshconfig as C
|
7
|
+
from lightning.pytorch.strategies.strategy import Strategy
|
8
|
+
from typing_extensions import TypeAliasType
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from ._config import TrainerConfig
|
12
|
+
|
13
|
+
StrategyLiteral = TypeAliasType(
|
14
|
+
"StrategyLiteral",
|
15
|
+
Literal[
|
16
|
+
"auto",
|
17
|
+
"ddp",
|
18
|
+
"ddp_find_unused_parameters_false",
|
19
|
+
"ddp_find_unused_parameters_true",
|
20
|
+
"ddp_spawn",
|
21
|
+
"ddp_spawn_find_unused_parameters_false",
|
22
|
+
"ddp_spawn_find_unused_parameters_true",
|
23
|
+
"ddp_fork",
|
24
|
+
"ddp_fork_find_unused_parameters_false",
|
25
|
+
"ddp_fork_find_unused_parameters_true",
|
26
|
+
"ddp_notebook",
|
27
|
+
"dp",
|
28
|
+
"deepspeed",
|
29
|
+
"deepspeed_stage_1",
|
30
|
+
"deepspeed_stage_1_offload",
|
31
|
+
"deepspeed_stage_2",
|
32
|
+
"deepspeed_stage_2_offload",
|
33
|
+
"deepspeed_stage_3",
|
34
|
+
"deepspeed_stage_3_offload",
|
35
|
+
"deepspeed_stage_3_offload_nvme",
|
36
|
+
"fsdp",
|
37
|
+
"fsdp_cpu_offload",
|
38
|
+
"single_xla",
|
39
|
+
"xla_fsdp",
|
40
|
+
"xla",
|
41
|
+
"single_tpu",
|
42
|
+
],
|
43
|
+
)
|
44
|
+
|
45
|
+
|
46
|
+
class StrategyConfigBase(C.Config, ABC):
|
47
|
+
@abstractmethod
|
48
|
+
def create_strategy(self, trainer_config: "TrainerConfig") -> Strategy: ...
|
49
|
+
|
50
|
+
|
51
|
+
StrategyConfig = TypeAliasType("StrategyConfig", StrategyConfigBase)
|
nshtrainer/trainer/trainer.py
CHANGED
@@ -22,14 +22,12 @@ from .._checkpoint.metadata import _write_checkpoint_metadata
|
|
22
22
|
from ..callbacks.base import resolve_all_callbacks
|
23
23
|
from ..util._environment_info import EnvironmentConfig
|
24
24
|
from ..util.bf16 import is_bf16_supported_no_emulation
|
25
|
-
from ._config import (
|
26
|
-
AcceleratorConfigBase,
|
27
|
-
LightningTrainerKwargs,
|
28
|
-
StrategyConfigBase,
|
29
|
-
TrainerConfig,
|
30
|
-
)
|
25
|
+
from ._config import LightningTrainerKwargs, TrainerConfig
|
31
26
|
from ._runtime_callback import RuntimeTrackerCallback, Stage
|
27
|
+
from .accelerator import AcceleratorConfigBase
|
28
|
+
from .plugin import PluginConfigBase
|
32
29
|
from .signal_connector import _SignalConnector
|
30
|
+
from .strategy import StrategyConfigBase
|
33
31
|
|
34
32
|
log = logging.getLogger(__name__)
|
35
33
|
|
@@ -172,12 +170,12 @@ class Trainer(LightningTrainer):
|
|
172
170
|
|
173
171
|
if (accelerator := hparams.accelerator) is not None:
|
174
172
|
if isinstance(accelerator, AcceleratorConfigBase):
|
175
|
-
accelerator = accelerator.create_accelerator()
|
173
|
+
accelerator = accelerator.create_accelerator(hparams)
|
176
174
|
_update_kwargs(accelerator=accelerator)
|
177
175
|
|
178
176
|
if (strategy := hparams.strategy) is not None:
|
179
177
|
if isinstance(strategy, StrategyConfigBase):
|
180
|
-
strategy = strategy.create_strategy()
|
178
|
+
strategy = strategy.create_strategy(hparams)
|
181
179
|
_update_kwargs(strategy=strategy)
|
182
180
|
|
183
181
|
if (precision := hparams.precision) is not None:
|
@@ -238,7 +236,8 @@ class Trainer(LightningTrainer):
|
|
238
236
|
if plugin_configs := hparams.plugins:
|
239
237
|
_update_kwargs(
|
240
238
|
plugins=[
|
241
|
-
plugin_config.create_plugin()
|
239
|
+
plugin_config.create_plugin(hparams)
|
240
|
+
for plugin_config in plugin_configs
|
242
241
|
]
|
243
242
|
)
|
244
243
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
|
-
nshtrainer/__init__.py,sha256=
|
2
|
+
nshtrainer/__init__.py,sha256=52OB7QRlhrTCIdDecpT7yEZyZM1XvYxywhuORn1eKoY,814
|
3
3
|
nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
|
@@ -31,7 +31,7 @@ nshtrainer/callbacks/shared_parameters.py,sha256=ggMI1krkqN7sGOrjK_I96IsTMYMXHoV
|
|
31
31
|
nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
|
34
|
-
nshtrainer/configs/__init__.py,sha256=
|
34
|
+
nshtrainer/configs/__init__.py,sha256=eS3naq6EG1vCq28G2nAW1CqYFdsrh6ueBlzX_LazgUw,14159
|
35
35
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
36
36
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
37
37
|
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
@@ -81,9 +81,17 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
|
|
81
81
|
nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
|
82
82
|
nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
|
83
83
|
nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
|
84
|
-
nshtrainer/configs/trainer/__init__.py,sha256=
|
85
|
-
nshtrainer/configs/trainer/_config/__init__.py,sha256=
|
86
|
-
nshtrainer/configs/trainer/
|
84
|
+
nshtrainer/configs/trainer/__init__.py,sha256=hKMI_2ve5zcsQys2DDQDv7OmshYsIG0uJlCLreVHpF0,7779
|
85
|
+
nshtrainer/configs/trainer/_config/__init__.py,sha256=Xw6I_9tUemDbHncpjKHRqye_e1_OyubK_FJcvdcQ0yc,4020
|
86
|
+
nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
|
87
|
+
nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
|
88
|
+
nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
|
89
|
+
nshtrainer/configs/trainer/plugin/environment/__init__.py,sha256=3o16x4qRAOvkJH9Vg4-QwsEODDC6aP_OXRnPPkm_xSo,1376
|
90
|
+
nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=W6G67JnigB6d3MiwLrbSKgtIZLUccXznp-IXwkK1J4U,743
|
91
|
+
nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
|
92
|
+
nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
|
93
|
+
nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
|
94
|
+
nshtrainer/configs/trainer/trainer/__init__.py,sha256=QnuhMQNAa1nSVN2o50_WeKAQG_qkNlkeoq9zTjjwmTI,586
|
87
95
|
nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
|
88
96
|
nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
|
89
97
|
nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
|
@@ -121,11 +129,19 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
|
|
121
129
|
nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
|
122
130
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
123
131
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
124
|
-
nshtrainer/trainer/__init__.py,sha256=
|
125
|
-
nshtrainer/trainer/_config.py,sha256=
|
132
|
+
nshtrainer/trainer/__init__.py,sha256=ggDHzIUbABezh4BjEwrxyWuXmuDBV-x4jv9gwXgVHU0,250
|
133
|
+
nshtrainer/trainer/_config.py,sha256=0GgofvaWf5Vo9REXNJpTvpVVRlFExGTOzcOt4jwJXNk,34129
|
126
134
|
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
135
|
+
nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
|
136
|
+
nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
|
137
|
+
nshtrainer/trainer/plugin/base.py,sha256=9-qUHXGpll_yCylun0899sbmJDpyhD9IQcBtVrJx38I,919
|
138
|
+
nshtrainer/trainer/plugin/environment.py,sha256=NW0qbsbvDPe59JGOMgPLq1fj7szLucIV1WRTxCrcjF4,4367
|
139
|
+
nshtrainer/trainer/plugin/io.py,sha256=nm6YDCVZAhmPvLaLnw6q4BrK2Gj2wvD5ZLDhj1xneEE,2030
|
140
|
+
nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv6DrYFIaXOo,735
|
141
|
+
nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
|
127
142
|
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
128
|
-
nshtrainer/trainer/
|
143
|
+
nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
|
144
|
+
nshtrainer/trainer/trainer.py,sha256=l2kJs27v4IHZnzxExr0zX0sVex0wukgiD2Wn_0wiGJg,20836
|
129
145
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
130
146
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
131
147
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
@@ -137,6 +153,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
137
153
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
138
154
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
139
155
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
140
|
-
nshtrainer-1.0.
|
141
|
-
nshtrainer-1.0.
|
142
|
-
nshtrainer-1.0.
|
156
|
+
nshtrainer-1.0.0b30.dist-info/METADATA,sha256=zxFm4X5APkZR6E4E8-jzVghTwYEYCJQzCHpCV_8hWzg,988
|
157
|
+
nshtrainer-1.0.0b30.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
158
|
+
nshtrainer-1.0.0b30.dist-info/RECORD,,
|
File without changes
|