nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -2
- nshtrainer/_directory.py +85 -0
- nshtrainer/callbacks/__init__.py +12 -1
- nshtrainer/callbacks/debug_flag.py +72 -0
- nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer/config.py +67 -0
- nshtrainer/ll/__init__.py +5 -4
- nshtrainer/ll/model.py +7 -0
- nshtrainer/loggers/wandb.py +1 -1
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
- nshtrainer/model/__init__.py +0 -21
- nshtrainer/model/base.py +124 -67
- nshtrainer/model/config.py +7 -1025
- nshtrainer/model/{modules → mixins}/logger.py +13 -16
- nshtrainer/profiler/__init__.py +13 -0
- nshtrainer/profiler/_base.py +29 -0
- nshtrainer/profiler/advanced.py +37 -0
- nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer/profiler/simple.py +36 -0
- nshtrainer/trainer/_config.py +787 -0
- nshtrainer/trainer/trainer.py +16 -17
- nshtrainer/{config → util/config}/__init__.py +1 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
- nshtrainer/model/modules/callback.py +0 -206
- nshtrainer/model/modules/debug.py +0 -42
- nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer/model/modules/shared_parameters.py +0 -72
- /nshtrainer/{config → util/config}/duration.py +0 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
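The file listing above amounts to a reorganization of the package layout: the mixin modules under nshtrainer/model/modules/ are removed in favor of standalone callbacks under nshtrainer/callbacks/, the old nshtrainer/config package moves to nshtrainer/util/config while a new top-level nshtrainer/config.py module is added, trainer configuration moves into nshtrainer/trainer/_config.py, and a new nshtrainer/profiler package appears. A minimal sketch of the import paths implied by these moves (assumed from the file listing only; the public symbols inside each module are not shown in this diff):

# Assumed module paths in nshtrainer 0.32.0, inferred from the file listing above.
import nshtrainer.config                       # new top-level config.py module
import nshtrainer.util.config.duration         # moved from nshtrainer/config/duration.py
import nshtrainer.profiler.simple              # new profiler package
import nshtrainer.callbacks.rlp_sanity_checks  # new callback module
import nshtrainer.model.mixins.logger          # renamed from nshtrainer/model/modules/logger.py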
nshtrainer/trainer/trainer.py
CHANGED
@@ -18,10 +18,8 @@ from typing_extensions import Unpack, assert_never, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
-from
+from ._config import (
     AcceleratorConfigProtocol,
-    BaseConfig,
-    BaseProfilerConfig,
     LightningTrainerKwargs,
     StrategyConfigProtocol,
 )
@@ -29,6 +27,9 @@ from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
 log = logging.getLogger(__name__)
 
 
@@ -58,14 +59,14 @@ def _is_bf16_supported_no_emulation():
 
 class Trainer(LightningTrainer):
     @classmethod
-    def _pre_init(cls, config: BaseConfig):
+    def _pre_init(cls, config: "BaseConfig"):
         if (precision := config.trainer.set_float32_matmul_precision) is not None:
             torch.set_float32_matmul_precision(precision)
 
     @classmethod
     def _update_kwargs(
         cls,
-        config: BaseConfig,
+        config: "BaseConfig",
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
@@ -217,18 +218,16 @@ class Trainer(LightningTrainer):
             gradient_clip_val=grad_clip_config.value,
         )
 
-        if
-
-
-
-
-            raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
-
+        if profiler_config := config.trainer.profiler:
+            if (profiler := profiler_config.create_profiler(config)) is None:
+                log.warning(f"Profiler config {profiler_config=} returned None.")
+            # Make sure that the profiler is an instance of `Profiler`.
+            elif not isinstance(profiler, Profiler):
+                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
             # Otherwise, if the profiler is a string (e.g., "simpe", "advanced", "pytorch"),
             # then we just pass it through.
-
-
+            else:
+                _update_kwargs(profiler=profiler)
 
         if callbacks := resolve_all_callbacks(config):
             _update_kwargs(callbacks=callbacks)
@@ -281,7 +280,7 @@ class Trainer(LightningTrainer):
     @override
     def __init__(
         self,
-        config: BaseConfig,
+        config: "BaseConfig",
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
@@ -424,7 +423,7 @@ class Trainer(LightningTrainer):
         # Save the checkpoint metadata
         metadata_path = None
         lm = self._base_module
-        root_config = cast(BaseConfig, lm.hparams)
+        root_config = cast("BaseConfig", lm.hparams)
         if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             if (
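Two changes stand out in the trainer.py hunks above. First, the BaseConfig and BaseProfilerConfig names are dropped from the runtime import block at the top of the module; BaseConfig is instead imported from ..model.config only under TYPE_CHECKING and the annotations are quoted, which removes the runtime dependency on the model config module. Second, the profiler keyword is now derived from config.trainer.profiler via create_profiler(), with a warning when the config returns None and a ValueError when the result is not a Profiler instance. A minimal, self-contained sketch of the type-checking-only import pattern (module and class names are hypothetical, not nshtrainer's):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime,
    # so it cannot create a circular import.
    from mypackage.config import BaseConfig  # hypothetical path


class Trainer:
    def __init__(self, config: "BaseConfig") -> None:
        # The quoted annotation is resolved lazily by the type checker;
        # at runtime it is just a string.
        self.config = config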
{nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD
CHANGED
@@ -1,11 +1,12 @@
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=flMI50Hj1Ie8c1YMSUQ759AqtNBQLT_zHaV2J9EUmOs,573
 nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
 nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
 nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
 nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28AdZ0,1264
+nshtrainer/_directory.py,sha256=RjnW6vKTeKlz2vQWT3cG0Jje5BkFXA7HpUubDhcSiq4,2993
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
 nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=1SBLpMsx7BzgimO35MwQViYBcbgxlkyvTMz1JKUKK-0,3060
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
@@ -14,6 +15,8 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
+nshtrainer/callbacks/debug_flag.py,sha256=Mo69CtJqPWMlFBvgBEuYls8Vfp5v1QFiyMRTiMStdec,2059
+nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
 nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
@@ -22,15 +25,16 @@ nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvL
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
 nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
 nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
+nshtrainer/callbacks/rlp_sanity_checks.py,sha256=c30G9jAu42QLLIS5LnusdSnI3wqyIHgOUFDRcKESuNI,9935
+nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5AehYDtuJjDtUum4,2922
 nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
 nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
 nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
-nshtrainer/config
-nshtrainer/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
+nshtrainer/config.py,sha256=W6nAmn5Y1GVZto9vkx4v8i5XdikMSdVYDiq7kbDEWAg,5900
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
 nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
-nshtrainer/ll/__init__.py,sha256=
+nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
 nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
 nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
@@ -38,7 +42,7 @@ nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
 nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
 nshtrainer/ll/log.py,sha256=d4BB3TyM8imK65EXOiOeUTF0zFM1ropbe7Vq3DeB0xU,140
 nshtrainer/ll/lr_scheduler.py,sha256=7xjhN6L69BCUzFhcy33NtMtPuCzHiB611zVWFg92lQ0,52
-nshtrainer/ll/model.py,sha256=
+nshtrainer/ll/model.py,sha256=Cw8Vq8IUL6YU1fTUcOIZsXcNJ3XyKgQY4YENIsL9H7c,996
 nshtrainer/ll/nn.py,sha256=8qiRDFwojIxkB7-LtNWk4mLL2tJbaskHYofDsOIHiNg,42
 nshtrainer/ll/optimizer.py,sha256=3T-VZtT73jVvwCNJGDjgGEbzs-1LFTzMQH-SB_58mSo,49
 nshtrainer/ll/runner.py,sha256=B0m5VEhNKIjF1aFmqPkonkQxDoRL2jeHZGsV3zwhSVE,117
@@ -51,44 +55,46 @@ nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfE
 nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
 nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
 nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
-nshtrainer/loggers/wandb.py,sha256=
+nshtrainer/loggers/wandb.py,sha256=8B2BMMzILRSUEiCkmp_fBpcXs69euRKViTiaV__DJZk,5128
 nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
-nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=
+nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=YQm84Sb4SWrofpBwa39DCslJvu2uorjbpWaGWyys1l4,5352
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=
-nshtrainer/model/base.py,sha256=
-nshtrainer/model/config.py,sha256=
-nshtrainer/model/
-nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
-nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
-nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
-nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
-nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
-nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
+nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
+nshtrainer/model/base.py,sha256=1zVY8ybZTzVKhpp7sUC0t360Ut3YmdGxAW5PZAIBSyw,18535
+nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
+nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
 nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
 nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
+nshtrainer/profiler/__init__.py,sha256=RQYkqQBVWuVvfdtAJIk2x5bNsXownklT87Mr_j-uXjw,474
+nshtrainer/profiler/_base.py,sha256=YF5lsJBIl9qts9GLW5Z62JuYeo4SnIArhlFwTGkfTb4,897
+nshtrainer/profiler/advanced.py,sha256=44asloha0aGUW8YwjQt3lm3ve8H-N6mM4QgseUSLT30,1112
+nshtrainer/profiler/pytorch.py,sha256=tGeRvoPP5ulWX2RkfXrQvMBoki1T95dpz5p8mwyon1I,2709
+nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y,957
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
+nshtrainer/trainer/_config.py,sha256=ZIodM5Ek1lpkWFhQ_VfmKR7q1mZFFwtjfx8FH72H8WM,29174
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=iYueHW-m8fHyC8SQuXmpgxq_-GUa7pAJik7rDFPXmy0,17499
 nshtrainer/util/_environment_info.py,sha256=CFUUZYjXhBLWGc0jtPNOaZgYMueUDEHpEaWFA1f3GoY,24213
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
+nshtrainer/util/config/__init__.py,sha256=6iCFLhujhbOi7Q694e--Sq-ethiGoGHShm699GPV8Zg,154
+nshtrainer/util/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
 nshtrainer/util/path.py,sha256=VkpuhR4GaZtSFBVqbGAvfjcrU-PR8xwiGzzwFNOWP9c,2995
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.32.0.dist-info/METADATA,sha256=pe-TVRS0ZmZ9kx5NBQ8-0C6m4ZzaH_MalJZmh31mUNQ,916
+nshtrainer-0.32.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.32.0.dist-info/RECORD,,
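For reference, each RECORD entry has the form path,sha256=<hash>,<size>, where the hash is the URL-safe base64 encoding of the file's SHA-256 digest with trailing padding stripped and the size is the file's length in bytes (per the wheel specification). A small sketch that recomputes such an entry for a file on disk:

# Sketch: recompute a wheel RECORD line for a local file.
import base64
import hashlib
from pathlib import Path

def record_line(path: str) -> str:
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# Example (hypothetical local path): should reproduce the corresponding
# entry above for an unpacked copy of the 0.32.0 wheel.
print(record_line("nshtrainer/_callback.py"))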
nshtrainer/model/modules/callback.py
DELETED
@@ -1,206 +0,0 @@
-import logging
-from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast, final, overload
-
-from lightning.pytorch import Callback, LightningModule
-from lightning.pytorch.callbacks import LambdaCallback
-from typing_extensions import override
-
-from ...util.typing_utils import mixin_base_type
-
-log = logging.getLogger(__name__)
-
-CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
-
-
-class CallbackRegistrarModuleMixin:
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self._nshtrainer_callbacks: list[CallbackFn] = []
-
-    @overload
-    def register_callback(
-        self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
-    ): ...
-
-    @overload
-    def register_callback(
-        self,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ): ...
-
-    def register_callback(
-        self,
-        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ):
-        if callback is None:
-            callback = LambdaCallback(
-                setup=setup,
-                teardown=teardown,
-                on_fit_start=on_fit_start,
-                on_fit_end=on_fit_end,
-                on_sanity_check_start=on_sanity_check_start,
-                on_sanity_check_end=on_sanity_check_end,
-                on_train_batch_start=on_train_batch_start,
-                on_train_batch_end=on_train_batch_end,
-                on_train_epoch_start=on_train_epoch_start,
-                on_train_epoch_end=on_train_epoch_end,
-                on_validation_epoch_start=on_validation_epoch_start,
-                on_validation_epoch_end=on_validation_epoch_end,
-                on_test_epoch_start=on_test_epoch_start,
-                on_test_epoch_end=on_test_epoch_end,
-                on_validation_batch_start=on_validation_batch_start,
-                on_validation_batch_end=on_validation_batch_end,
-                on_test_batch_start=on_test_batch_start,
-                on_test_batch_end=on_test_batch_end,
-                on_train_start=on_train_start,
-                on_train_end=on_train_end,
-                on_validation_start=on_validation_start,
-                on_validation_end=on_validation_end,
-                on_test_start=on_test_start,
-                on_test_end=on_test_end,
-                on_exception=on_exception,
-                on_save_checkpoint=on_save_checkpoint,
-                on_load_checkpoint=on_load_checkpoint,
-                on_before_backward=on_before_backward,
-                on_after_backward=on_after_backward,
-                on_before_optimizer_step=on_before_optimizer_step,
-                on_before_zero_grad=on_before_zero_grad,
-                on_predict_start=on_predict_start,
-                on_predict_end=on_predict_end,
-                on_predict_batch_start=on_predict_batch_start,
-                on_predict_batch_end=on_predict_batch_end,
-                on_predict_epoch_start=on_predict_epoch_start,
-                on_predict_epoch_end=on_predict_epoch_end,
-            )
-
-        if not callable(callback):
-            callback_ = cast(CallbackFn, lambda: callback)
-        else:
-            callback_ = callback
-
-        self._nshtrainer_callbacks.append(callback_)
-
-
-class CallbackModuleMixin(
-    CallbackRegistrarModuleMixin,
-    mixin_base_type(LightningModule),
-):
-    def _gather_all_callbacks(self):
-        modules: list[Any] = []
-        if isinstance(self, CallbackRegistrarModuleMixin):
-            modules.append(self)
-        if (
-            datamodule := getattr(self.trainer, "datamodule", None)
-        ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
-            modules.append(datamodule)
-        modules.extend(
-            module
-            for module in self.children()
-            if isinstance(module, CallbackRegistrarModuleMixin)
-        )
-        for module in modules:
-            yield from module._nshtrainer_callbacks
-
-    @final
-    @override
-    def configure_callbacks(self):
-        callbacks = super().configure_callbacks()
-        if not isinstance(callbacks, Sequence):
-            callbacks = [callbacks]
-
-        callbacks = list(callbacks)
-        for callback_fn in self._gather_all_callbacks():
-            callback_result = callback_fn()
-            if callback_result is None:
-                continue
-
-            if not isinstance(callback_result, Iterable):
-                callback_result = [callback_result]
-
-            for callback in callback_result:
-                log.info(
-                    f"Registering {callback.__class__.__qualname__} callback {callback}"
-                )
-                callbacks.append(callback)
-
-        return callbacks
nshtrainer/model/modules/debug.py
DELETED
@@ -1,42 +0,0 @@
-import logging
-
-import torch
-import torch.distributed
-
-log = logging.getLogger(__name__)
-
-
-class DebugModuleMixin:
-    @torch.jit.unused
-    def breakpoint(self, rank_zero_only: bool = True):
-        if (
-            not rank_zero_only
-            or not torch.distributed.is_initialized()
-            or torch.distributed.get_rank() == 0
-        ):
-            breakpoint()
-
-        if rank_zero_only and torch.distributed.is_initialized():
-            _ = torch.distributed.barrier()
-
-    @torch.jit.unused
-    def ensure_finite(
-        self,
-        tensor: torch.Tensor,
-        name: str | None = None,
-        throw: bool = False,
-    ):
-        name_parts: list[str] = ["Tensor"]
-        if name is not None:
-            name_parts.append(name)
-        name = " ".join(name_parts)
-
-        not_finite = ~torch.isfinite(tensor)
-        if not_finite.any():
-            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
-            if throw:
-                raise RuntimeError(msg)
-            else:
-                log.warning(msg)
-            return False
-        return True
nshtrainer/model/modules/distributed.py
DELETED
@@ -1,70 +0,0 @@
-from typing import Any, Literal, cast
-
-import torch.distributed
-from lightning.pytorch import LightningModule
-from torch.distributed import ReduceOp
-from typing_extensions import TypeVar
-
-from ...util.typing_utils import mixin_base_type
-
-T = TypeVar("T", infer_variance=True)
-
-ReduceOpStr = Literal[
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-]
-VALID_REDUCE_OPS = (
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-)
-
-
-class DistributedMixin(mixin_base_type(LightningModule)):
-    def all_gather_object(
-        self,
-        object: T,
-        group: torch.distributed.ProcessGroup | None = None,
-    ) -> list[T]:
-        if (
-            not torch.distributed.is_available()
-            or not torch.distributed.is_initialized()
-        ):
-            return [object]
-
-        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
-        torch.distributed.all_gather_object(object_list, object, group=group)
-        return object_list
-
-    def barrier(self, name: str | None = None):
-        self.trainer.strategy.barrier(name=name)
-
-    def reduce(
-        self,
-        tensor: torch.Tensor,
-        reduce_op: ReduceOp.RedOpType | ReduceOpStr,
-        group: Any | None = None,
-    ) -> torch.Tensor:
-        if isinstance(reduce_op, str):
-            # validate reduce_op
-            if reduce_op not in VALID_REDUCE_OPS:
-                raise ValueError(
-                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
-                )
-
-        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
nshtrainer/model/modules/profiler.py
DELETED
@@ -1,24 +0,0 @@
-from lightning.pytorch import LightningDataModule, LightningModule
-from lightning.pytorch.profilers import PassThroughProfiler
-
-from ...util.typing_utils import mixin_base_type
-
-
-class ProfilerMixin(mixin_base_type(LightningModule)):
-    @property
-    def profiler(self):
-        if not isinstance(self, (LightningModule, LightningDataModule)):
-            raise TypeError(
-                "`profiler` can only be used on LightningModule or LightningDataModule"
-            )
-
-        if (trainer := self.trainer) is None:
-            raise RuntimeError("trainer is not defined")
-
-        if not hasattr(trainer, "profiler"):
-            raise RuntimeError("trainer does not have profiler")
-
-        if (profiler := getattr(trainer, "profiler")) is None:
-            profiler = PassThroughProfiler()
-
-        return profiler