nshtrainer 0.30.0__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- nshtrainer/__init__.py +1 -2
- nshtrainer/_directory.py +85 -0
- nshtrainer/callbacks/__init__.py +8 -0
- nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer/config.py +67 -0
- nshtrainer/ll/__init__.py +5 -4
- nshtrainer/ll/model.py +7 -0
- nshtrainer/loggers/wandb.py +1 -1
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +3 -8
- nshtrainer/model/__init__.py +0 -21
- nshtrainer/model/base.py +139 -44
- nshtrainer/model/config.py +7 -1025
- nshtrainer/model/{modules → mixins}/callback.py +2 -2
- nshtrainer/model/{modules → mixins}/logger.py +13 -16
- nshtrainer/profiler/__init__.py +13 -0
- nshtrainer/profiler/_base.py +29 -0
- nshtrainer/profiler/advanced.py +37 -0
- nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer/profiler/simple.py +36 -0
- nshtrainer/trainer/_config.py +778 -0
- nshtrainer/trainer/trainer.py +16 -17
- nshtrainer/{config → util/config}/__init__.py +1 -0
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
- nshtrainer/model/modules/debug.py +0 -42
- nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer/model/modules/shared_parameters.py +0 -72
- /nshtrainer/{config → util/config}/duration.py +0 -0
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
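Taken together, the renames above move the private mixin modules from nshtrainer/model/modules to nshtrainer/model/mixins, the duration config from nshtrainer/config to nshtrainer/util/config, and add a new nshtrainer/profiler package. A minimal compatibility sketch for downstream code that imported the old paths, assuming only the module locations listed above (the helper name and fallback order are illustrative, not part of the package):

import importlib


def _load_duration_module():
    # Try the 0.31.0 location first, then fall back to the 0.30.0 one.
    for path in ("nshtrainer.util.config.duration", "nshtrainer.config.duration"):
        try:
            return importlib.import_module(path)
        except ModuleNotFoundError:
            continue
    raise ModuleNotFoundError("nshtrainer duration config module not found")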
nshtrainer/trainer/trainer.py
CHANGED
@@ -18,10 +18,8 @@ from typing_extensions import Unpack, assert_never, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
-from
+from ._config import (
     AcceleratorConfigProtocol,
-    BaseConfig,
-    BaseProfilerConfig,
     LightningTrainerKwargs,
     StrategyConfigProtocol,
 )
@@ -29,6 +27,9 @@ from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
 log = logging.getLogger(__name__)
 
 
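The hunk above replaces the runtime import of BaseConfig with a TYPE_CHECKING-only import, and the hunks below switch the annotations to the quoted form "BaseConfig" to match. In isolation the pattern looks roughly like this (the module path below is a placeholder, not nshtrainer's):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from mypackage.model.config import BaseConfig  # placeholder module


def _pre_init(config: "BaseConfig") -> None:
    # The quoted annotation is a forward reference, so the config module is
    # not imported at runtime (e.g. to break an import cycle or defer a
    # heavy import).
    ...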
@@ -58,14 +59,14 @@ def _is_bf16_supported_no_emulation():
 
 class Trainer(LightningTrainer):
     @classmethod
-    def _pre_init(cls, config: BaseConfig):
+    def _pre_init(cls, config: "BaseConfig"):
         if (precision := config.trainer.set_float32_matmul_precision) is not None:
             torch.set_float32_matmul_precision(precision)
 
     @classmethod
     def _update_kwargs(
         cls,
-        config: BaseConfig,
+        config: "BaseConfig",
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
@@ -217,18 +218,16 @@ class Trainer(LightningTrainer):
                 gradient_clip_val=grad_clip_config.value,
             )
 
-        if
-
-
-
-
-
-            raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
-
+        if profiler_config := config.trainer.profiler:
+            if (profiler := profiler_config.create_profiler(config)) is None:
+                log.warning(f"Profiler config {profiler_config=} returned None.")
+            # Make sure that the profiler is an instance of `Profiler`.
+            elif not isinstance(profiler, Profiler):
+                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
             # Otherwise, if the profiler is a string (e.g., "simpe", "advanced", "pytorch"),
             # then we just pass it through.
-
-
+            else:
+                _update_kwargs(profiler=profiler)
 
         if callbacks := resolve_all_callbacks(config):
             _update_kwargs(callbacks=callbacks)
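The profiler wiring above follows a guard-then-validate shape. A standalone sketch of the same control flow, with Profiler standing in for lightning.pytorch.profilers.Profiler and an invented config class (only the create_profiler call mirrors the hunk):

import logging

log = logging.getLogger(__name__)


class Profiler:  # stand-in for lightning.pytorch.profilers.Profiler
    pass


class ExampleProfilerConfig:  # illustrative only, not an nshtrainer class
    def create_profiler(self, root_config):
        return Profiler()


def wire_profiler(kwargs: dict, profiler_config, root_config=None) -> dict:
    # Skip wiring entirely when no profiler is configured; warn on a None
    # result; reject anything that is not a Profiler; otherwise pass it on.
    if profiler_config:
        if (profiler := profiler_config.create_profiler(root_config)) is None:
            log.warning(f"Profiler config {profiler_config=} returned None.")
        elif not isinstance(profiler, Profiler):
            raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
        else:
            kwargs["profiler"] = profiler
    return kwargs


trainer_kwargs = wire_profiler({}, ExampleProfilerConfig())  # usage sketch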
@@ -281,7 +280,7 @@ class Trainer(LightningTrainer):
     @override
     def __init__(
         self,
-        config: BaseConfig,
+        config: "BaseConfig",
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
@@ -424,7 +423,7 @@ class Trainer(LightningTrainer):
         # Save the checkpoint metadata
         metadata_path = None
         lm = self._base_module
-        root_config = cast(BaseConfig, lm.hparams)
+        root_config = cast("BaseConfig", lm.hparams)
         if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             if (
{nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/RECORD
CHANGED

@@ -1,11 +1,12 @@
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=flMI50Hj1Ie8c1YMSUQ759AqtNBQLT_zHaV2J9EUmOs,573
 nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
 nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
 nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
 nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28AdZ0,1264
+nshtrainer/_directory.py,sha256=RjnW6vKTeKlz2vQWT3cG0Jje5BkFXA7HpUubDhcSiq4,2993
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
 nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=NpEV8bMU12ClFN2sLKLBDXnuwIHYyZOCNxDZgjrV104,2892
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
@@ -14,6 +15,7 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
+nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
 nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
@@ -22,15 +24,16 @@ nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvL
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
 nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
 nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
+nshtrainer/callbacks/rlp_sanity_checks.py,sha256=c30G9jAu42QLLIS5LnusdSnI3wqyIHgOUFDRcKESuNI,9935
+nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5AehYDtuJjDtUum4,2922
 nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
 nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
 nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
-nshtrainer/config
-nshtrainer/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
+nshtrainer/config.py,sha256=W6nAmn5Y1GVZto9vkx4v8i5XdikMSdVYDiq7kbDEWAg,5900
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
 nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
-nshtrainer/ll/__init__.py,sha256=
+nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
 nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
 nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
@@ -38,7 +41,7 @@ nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
 nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
 nshtrainer/ll/log.py,sha256=d4BB3TyM8imK65EXOiOeUTF0zFM1ropbe7Vq3DeB0xU,140
 nshtrainer/ll/lr_scheduler.py,sha256=7xjhN6L69BCUzFhcy33NtMtPuCzHiB611zVWFg92lQ0,52
-nshtrainer/ll/model.py,sha256=
+nshtrainer/ll/model.py,sha256=Cw8Vq8IUL6YU1fTUcOIZsXcNJ3XyKgQY4YENIsL9H7c,996
 nshtrainer/ll/nn.py,sha256=8qiRDFwojIxkB7-LtNWk4mLL2tJbaskHYofDsOIHiNg,42
 nshtrainer/ll/optimizer.py,sha256=3T-VZtT73jVvwCNJGDjgGEbzs-1LFTzMQH-SB_58mSo,49
 nshtrainer/ll/runner.py,sha256=B0m5VEhNKIjF1aFmqPkonkQxDoRL2jeHZGsV3zwhSVE,117
@@ -51,44 +54,47 @@ nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfE
 nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
 nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
 nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
-nshtrainer/loggers/wandb.py,sha256=
+nshtrainer/loggers/wandb.py,sha256=8B2BMMzILRSUEiCkmp_fBpcXs69euRKViTiaV__DJZk,5128
 nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
-nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=
+nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=YQm84Sb4SWrofpBwa39DCslJvu2uorjbpWaGWyys1l4,5352
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=
-nshtrainer/model/base.py,sha256=
-nshtrainer/model/config.py,sha256=
-nshtrainer/model/
-nshtrainer/model/
-nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
-nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
-nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
-nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
-nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
+nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
+nshtrainer/model/base.py,sha256=hT27FtzwKQiEL0C8RcaTKYXlanfvzTxHOJpHUcWiItk,19891
+nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
+nshtrainer/model/mixins/callback.py,sha256=lh3imlw1H3ESIG4WFA5frooSlWi6-RPUUDRFGRzEg4A,8571
+nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
 nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
 nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
+nshtrainer/profiler/__init__.py,sha256=RQYkqQBVWuVvfdtAJIk2x5bNsXownklT87Mr_j-uXjw,474
+nshtrainer/profiler/_base.py,sha256=YF5lsJBIl9qts9GLW5Z62JuYeo4SnIArhlFwTGkfTb4,897
+nshtrainer/profiler/advanced.py,sha256=44asloha0aGUW8YwjQt3lm3ve8H-N6mM4QgseUSLT30,1112
+nshtrainer/profiler/pytorch.py,sha256=tGeRvoPP5ulWX2RkfXrQvMBoki1T95dpz5p8mwyon1I,2709
+nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y,957
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
+nshtrainer/trainer/_config.py,sha256=2SgO1L5VBfWQ5g7Dg2dTx_vq2_Wo7dTqt2A4GlQaGo0,28673
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=iYueHW-m8fHyC8SQuXmpgxq_-GUa7pAJik7rDFPXmy0,17499
 nshtrainer/util/_environment_info.py,sha256=CFUUZYjXhBLWGc0jtPNOaZgYMueUDEHpEaWFA1f3GoY,24213
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
+nshtrainer/util/config/__init__.py,sha256=6iCFLhujhbOi7Q694e--Sq-ethiGoGHShm699GPV8Zg,154
+nshtrainer/util/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
 nshtrainer/util/path.py,sha256=VkpuhR4GaZtSFBVqbGAvfjcrU-PR8xwiGzzwFNOWP9c,2995
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.31.0.dist-info/METADATA,sha256=99b-8IvPlMmTrjyb5EK1kKsgKj8lWhGw4gZvM5sKyzc,916
+nshtrainer-0.31.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.31.0.dist-info/RECORD,,
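Each RECORD row is the standard wheel triple of path, sha256= plus the URL-safe base64 digest with padding stripped, and file size in bytes. A small sketch that recomputes a row for a file on disk (the helper name is illustrative):

import base64
import hashlib
import pathlib


def record_row(path: pathlib.Path) -> str:
    data = path.read_bytes()
    # URL-safe base64 of the sha256 digest, with trailing "=" padding removed.
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path.as_posix()},sha256={digest.decode()},{len(data)}"


# record_row(pathlib.Path("nshtrainer/_callback.py")) run against an unpacked
# 0.31.0 wheel should reproduce the corresponding row above.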
nshtrainer/model/modules/debug.py
DELETED

@@ -1,42 +0,0 @@
-import logging
-
-import torch
-import torch.distributed
-
-log = logging.getLogger(__name__)
-
-
-class DebugModuleMixin:
-    @torch.jit.unused
-    def breakpoint(self, rank_zero_only: bool = True):
-        if (
-            not rank_zero_only
-            or not torch.distributed.is_initialized()
-            or torch.distributed.get_rank() == 0
-        ):
-            breakpoint()
-
-        if rank_zero_only and torch.distributed.is_initialized():
-            _ = torch.distributed.barrier()
-
-    @torch.jit.unused
-    def ensure_finite(
-        self,
-        tensor: torch.Tensor,
-        name: str | None = None,
-        throw: bool = False,
-    ):
-        name_parts: list[str] = ["Tensor"]
-        if name is not None:
-            name_parts.append(name)
-        name = " ".join(name_parts)
-
-        not_finite = ~torch.isfinite(tensor)
-        if not_finite.any():
-            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
-            if throw:
-                raise RuntimeError(msg)
-            else:
-                log.warning(msg)
-                return False
-        return True
nshtrainer/model/modules/distributed.py
DELETED

@@ -1,70 +0,0 @@
-from typing import Any, Literal, cast
-
-import torch.distributed
-from lightning.pytorch import LightningModule
-from torch.distributed import ReduceOp
-from typing_extensions import TypeVar
-
-from ...util.typing_utils import mixin_base_type
-
-T = TypeVar("T", infer_variance=True)
-
-ReduceOpStr = Literal[
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-]
-VALID_REDUCE_OPS = (
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-)
-
-
-class DistributedMixin(mixin_base_type(LightningModule)):
-    def all_gather_object(
-        self,
-        object: T,
-        group: torch.distributed.ProcessGroup | None = None,
-    ) -> list[T]:
-        if (
-            not torch.distributed.is_available()
-            or not torch.distributed.is_initialized()
-        ):
-            return [object]
-
-        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
-        torch.distributed.all_gather_object(object_list, object, group=group)
-        return object_list
-
-    def barrier(self, name: str | None = None):
-        self.trainer.strategy.barrier(name=name)
-
-    def reduce(
-        self,
-        tensor: torch.Tensor,
-        reduce_op: ReduceOp.RedOpType | ReduceOpStr,
-        group: Any | None = None,
-    ) -> torch.Tensor:
-        if isinstance(reduce_op, str):
-            # validate reduce_op
-            if reduce_op not in VALID_REDUCE_OPS:
-                raise ValueError(
-                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
-                )
-
-        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
nshtrainer/model/modules/profiler.py
DELETED

@@ -1,24 +0,0 @@
-from lightning.pytorch import LightningDataModule, LightningModule
-from lightning.pytorch.profilers import PassThroughProfiler
-
-from ...util.typing_utils import mixin_base_type
-
-
-class ProfilerMixin(mixin_base_type(LightningModule)):
-    @property
-    def profiler(self):
-        if not isinstance(self, (LightningModule, LightningDataModule)):
-            raise TypeError(
-                "`profiler` can only be used on LightningModule or LightningDataModule"
-            )
-
-        if (trainer := self.trainer) is None:
-            raise RuntimeError("trainer is not defined")
-
-        if not hasattr(trainer, "profiler"):
-            raise RuntimeError("trainer does not have profiler")
-
-        if (profiler := getattr(trainer, "profiler")) is None:
-            profiler = PassThroughProfiler()
-
-        return profiler
nshtrainer/model/modules/rlp_sanity_checks.py
DELETED

@@ -1,202 +0,0 @@
-import logging
-from collections.abc import Mapping
-from typing import cast
-
-import torch
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.utilities.types import (
-    LRSchedulerConfigType,
-    LRSchedulerTypeUnion,
-)
-from typing_extensions import Protocol, override, runtime_checkable
-
-from ...util.typing_utils import mixin_base_type
-from ..config import BaseConfig
-from .callback import CallbackModuleMixin
-
-log = logging.getLogger(__name__)
-
-
-def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
-    # If we're in PL's "sanity check" mode, we don't need to run this check
-    if trainer.sanity_checking:
-        return
-
-    config = cast(BaseConfig, pl_module.hparams)
-    if config.trainer.sanity_checking.reduce_lr_on_plateau == "disable":
-        return
-
-    # if no lr schedulers, return
-    if not trainer.lr_scheduler_configs:
-        return
-
-    errors: list[str] = []
-    disable_message = (
-        "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau='disable'` "
-        "to disable this sanity check."
-    )
-
-    for lr_scheduler_config in trainer.lr_scheduler_configs:
-        if not lr_scheduler_config.reduce_on_plateau:
-            continue
-
-        match lr_scheduler_config.interval:
-            case "epoch":
-                # we need to make sure that the trainer runs val every `frequency` epochs
-
-                # If `trainer.check_val_every_n_epoch` is None, then Lightning
-                # will run val every `int(trainer.val_check_interval)` steps.
-                # So, first we need to make sure that `trainer.val_check_interval` is not None first.
-                if trainer.check_val_every_n_epoch is None:
-                    errors.append(
-                        "Trainer is not running validation at epoch intervals "
-                        "(i.e., `trainer.check_val_every_n_epoch` is None) but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                        f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
-                        + disable_message
-                    )
-
-                # Second, we make sure that the trainer runs val at least every `frequency` epochs
-                if (
-                    trainer.check_val_every_n_epoch is not None
-                    and lr_scheduler_config.frequency % trainer.check_val_every_n_epoch
-                    != 0
-                ):
-                    errors.append(
-                        f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
-                        f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
-                        + disable_message
-                    )
-
-            case "step":
-                # In this case, we need to make sure that the trainer runs val at step intervals
-                # that are multiples of `frequency`.
-
-                # First, we make sure that validation is run at step intervals
-                if trainer.check_val_every_n_epoch is not None:
-                    errors.append(
-                        "Trainer is running validation at epoch intervals "
-                        "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                        "Please set `config.trainer.check_val_every_n_epoch=None` "
-                        f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
-                        + disable_message
-                    )
-
-                # Second, we make sure `trainer.val_check_interval` is an integer
-                if not isinstance(trainer.val_check_interval, int):
-                    errors.append(
-                        f"Trainer is not running validation at step intervals "
-                        f"(i.e., `trainer.val_check_interval` is not an integer) but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                        "Please set `config.trainer.val_check_interval=None` "
-                        f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
-                        + disable_message
-                    )
-
-                # Third, we make sure that the trainer runs val at least every `frequency` steps
-                if (
-                    isinstance(trainer.val_check_interval, int)
-                    and trainer.val_check_interval % lr_scheduler_config.frequency != 0
-                ):
-                    errors.append(
-                        f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
-                        "Please set `config.trainer.val_check_interval` "
-                        f"to a multiple of {lr_scheduler_config.frequency}. "
-                        + disable_message
-                    )
-
-            case _:
-                pass
-
-    if not errors:
-        return
-
-    message = (
-        "ReduceLRPlateau sanity checks failed with the following errors:\n"
-        + "\n".join(errors)
-    )
-    match config.trainer.sanity_checking.reduce_lr_on_plateau:
-        case "warn":
-            log.warning(message)
-        case "error":
-            raise ValueError(message)
-        case _:
-            pass
-
-
-@runtime_checkable
-class CustomRLPImplementation(Protocol):
-    __reduce_lr_on_plateau__: bool
-
-
-class RLPSanityCheckModuleMixin(mixin_base_type(CallbackModuleMixin)):
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        global _on_train_start_callback
-        self.register_callback(on_train_start=_on_train_start_callback)
-
-    def reduce_lr_on_plateau_config(
-        self,
-        lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
-    ) -> LRSchedulerConfigType:
-        if (trainer := self._trainer) is None:
-            raise RuntimeError(
-                "Could not determine the frequency of ReduceLRPlateau scheduler "
-                "because `self.trainer` is None."
-            )
-
-        # First, resolve the LR scheduler from the provided config.
-        lr_scheduler_config: LRSchedulerConfigType
-        match lr_scheduler:
-            case Mapping():
-                lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
-            case _:
-                lr_scheduler_config = {"scheduler": lr_scheduler}
-
-        # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
-        if (
-            not isinstance(
-                lr_scheduler_config["scheduler"],
-                torch.optim.lr_scheduler.ReduceLROnPlateau,
-            )
-        ) and (
-            not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
-            or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
-        ):
-            log.warning(
-                "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
-                f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
-                "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
-                "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
-                "If you are using a custom ReduceLRPlateau scheduler implementation, "
-                "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
-                "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
-            )
-
-        # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
-        if trainer.check_val_every_n_epoch is not None:
-            return {
-                "reduce_on_plateau": True,
-                "interval": "epoch",
-                "frequency": trainer.check_val_every_n_epoch,
-                **lr_scheduler_config,
-            }
-
-        # Otherwise, we run val at step intervals.
-        if not isinstance(trainer.val_check_batch, int):
-            raise ValueError(
-                "Could not determine the frequency of ReduceLRPlateau scheduler "
-                f"because {trainer.val_check_batch=} is not an integer."
-            )
-
-        return {
-            "reduce_on_plateau": True,
-            "interval": "step",
-            "frequency": trainer.val_check_batch,
-            **lr_scheduler_config,
-        }
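Per the summary above, a callback counterpart of this module is added at nshtrainer/callbacks/rlp_sanity_checks.py. The rule enforced by the removed epoch-interval branch, reduced to an illustrative helper that is not part of the package:

def rlp_epoch_interval_ok(frequency: int, check_val_every_n_epoch: int | None) -> bool:
    # A ReduceLROnPlateau stepped every `frequency` epochs only sees a fresh
    # validation metric if validation runs on a divisor of that frequency.
    if check_val_every_n_epoch is None:
        return False  # validation is step-based, so the epoch check cannot hold
    return frequency % check_val_every_n_epoch == 0


assert rlp_epoch_interval_ok(4, 2)      # val every 2 epochs, RLP step every 4 epochs
assert not rlp_epoch_interval_ok(3, 2)  # RLP would step between validation runs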
nshtrainer/model/modules/shared_parameters.py
DELETED

@@ -1,72 +0,0 @@
-import logging
-from collections.abc import Sequence
-from typing import cast
-
-import torch.nn as nn
-from lightning.pytorch import LightningModule, Trainer
-from typing_extensions import override
-
-from ...util.typing_utils import mixin_base_type
-from ..config import BaseConfig
-from .callback import CallbackRegistrarModuleMixin
-
-log = logging.getLogger(__name__)
-
-
-def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
-    mapping = {id(p): n for n, p in model.named_parameters()}
-    return [mapping[id(p)] for p in parameters]
-
-
-class SharedParametersModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)):
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.shared_parameters: list[tuple[nn.Parameter, int | float]] = []
-        self._warned_shared_parameters = False
-
-        def on_after_backward(_trainer: Trainer, pl_module: LightningModule):
-            nonlocal self
-
-            config = cast(BaseConfig, pl_module.hparams)
-            if not config.trainer.supports_shared_parameters:
-                return
-
-            log.debug(f"Scaling {len(self.shared_parameters)} shared parameters...")
-            no_grad_parameters: list[nn.Parameter] = []
-            for p, factor in self.shared_parameters:
-                if not hasattr(p, "grad") or p.grad is None:
-                    no_grad_parameters.append(p)
-                    continue
-
-                _ = p.grad.data.div_(factor)
-
-            if no_grad_parameters and not self._warned_shared_parameters:
-                no_grad_parameters_str = ", ".join(
-                    _parameters_to_names(no_grad_parameters, pl_module)
-                )
-                log.warning(
-                    "The following parameters were marked as shared, but had no gradients: "
-                    f"{no_grad_parameters_str}"
-                )
-                self._warned_shared_parameters = True
-
-            log.debug(
-                f"Done scaling shared parameters. (len={len(self.shared_parameters)})"
-            )
-
-        self.register_callback(on_after_backward=on_after_backward)
-
-    def register_shared_parameters(
-        self, parameters: list[tuple[nn.Parameter, int | float]]
-    ):
-        for parameter, factor in parameters:
-            if not isinstance(parameter, nn.Parameter):
-                raise ValueError("Shared parameters must be PyTorch parameters")
-            if not isinstance(factor, (int, float)):
-                raise ValueError("Factor must be an integer or float")
-
-            self.shared_parameters.append((parameter, factor))
-
-        log.info(f"Registered {len(parameters)} shared parameters")
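Likewise, a callback counterpart appears at nshtrainer/callbacks/shared_parameters.py in the summary above. The gradient handling the removed mixin performed, as a standalone sketch (the tied-embedding usage below is made up for illustration):

import torch
import torch.nn as nn


def scale_shared_gradients(shared: list[tuple[nn.Parameter, float]]) -> None:
    # Divide each shared parameter's accumulated gradient by its sharing
    # factor, mirroring the on_after_backward hook above; parameters without
    # gradients are simply skipped.
    for param, factor in shared:
        if param.grad is not None:
            param.grad.div_(factor)


# Usage sketch: tie one Parameter into two modules, backprop, then rescale.
embed = nn.Embedding(10, 4)
decode = nn.Linear(4, 10, bias=False)
decode.weight = embed.weight  # the same Parameter is used in both modules
loss = decode(embed(torch.tensor([1, 2, 3]))).pow(2).sum()
loss.backward()
scale_shared_gradients([(embed.weight, 2.0)])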
/nshtrainer/{config → util/config}/duration.py
File without changes
{nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL
File without changes