nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff shows the content of two publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as published.
Files changed (34)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +12 -1
  4. nshtrainer/callbacks/debug_flag.py +72 -0
  5. nshtrainer/callbacks/directory_setup.py +85 -0
  6. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  7. nshtrainer/callbacks/shared_parameters.py +87 -0
  8. nshtrainer/config.py +67 -0
  9. nshtrainer/ll/__init__.py +5 -4
  10. nshtrainer/ll/model.py +7 -0
  11. nshtrainer/loggers/wandb.py +1 -1
  12. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  13. nshtrainer/model/__init__.py +0 -21
  14. nshtrainer/model/base.py +124 -67
  15. nshtrainer/model/config.py +7 -1025
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +787 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/callback.py +0 -206
  28. nshtrainer/model/modules/debug.py +0 -42
  29. nshtrainer/model/modules/distributed.py +0 -70
  30. nshtrainer/model/modules/profiler.py +0 -24
  31. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  32. nshtrainer/model/modules/shared_parameters.py +0 -72
  33. nshtrainer/{config → util/config}/duration.py +0 -0
  34. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
nshtrainer/trainer/trainer.py

@@ -18,10 +18,8 @@ from typing_extensions import Unpack, assert_never, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
-from ..model.config import (
+from ._config import (
     AcceleratorConfigProtocol,
-    BaseConfig,
-    BaseProfilerConfig,
     LightningTrainerKwargs,
     StrategyConfigProtocol,
 )
@@ -29,6 +27,9 @@ from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
 log = logging.getLogger(__name__)
 
 
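Note: the two hunks above replace a runtime import of BaseConfig with a TYPE_CHECKING-only import plus quoted annotations (see the `"BaseConfig"` strings below), which breaks the runtime import cycle between the trainer and model-config modules. A minimal, self-contained sketch of the pattern (Decimal stands in for BaseConfig; not nshtrainer code):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers; skipped at runtime, so no
        # circular import can occur through this line.
        from decimal import Decimal  # stand-in for a config class

    def scale(value: "Decimal") -> float:
        # The quoted annotation is never evaluated at runtime.
        return float(value)

    print(scale.__annotations__)  # {'value': 'Decimal', 'return': <class 'float'>}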
@@ -58,14 +59,14 @@ def _is_bf16_supported_no_emulation():
 
 
 class Trainer(LightningTrainer):
     @classmethod
-    def _pre_init(cls, config: BaseConfig):
+    def _pre_init(cls, config: "BaseConfig"):
         if (precision := config.trainer.set_float32_matmul_precision) is not None:
             torch.set_float32_matmul_precision(precision)
 
     @classmethod
     def _update_kwargs(
         cls,
-        config: BaseConfig,
+        config: "BaseConfig",
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
@@ -217,18 +218,16 @@ class Trainer(LightningTrainer):
                 gradient_clip_val=grad_clip_config.value,
             )
 
-        if profiler := config.trainer.profiler:
-            # If the profiler is an ProfilerConfig instance, then we instantiate it.
-            if isinstance(profiler, BaseProfilerConfig):
-                profiler = profiler.create_profiler(config)
-            # Make sure that the profiler is an instance of `Profiler`.
-            if not isinstance(profiler, Profiler):
-                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
-
+        if profiler_config := config.trainer.profiler:
+            if (profiler := profiler_config.create_profiler(config)) is None:
+                log.warning(f"Profiler config {profiler_config=} returned None.")
+            # Make sure that the profiler is an instance of `Profiler`.
+            elif not isinstance(profiler, Profiler):
+                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
             # Otherwise, if the profiler is a string (e.g., "simpe", "advanced", "pytorch"),
             # then we just pass it through.
-            # kwargs["profiler"] = profiler
-            _update_kwargs(profiler=profiler)
+            else:
+                _update_kwargs(profiler=profiler)
 
         if callbacks := resolve_all_callbacks(config):
             _update_kwargs(callbacks=callbacks)
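Note: the rewritten block assumes every configured profiler object exposes a create_profiler(config) method that returns either None (skipped with a warning) or a lightning.pytorch.profilers.Profiler. A hypothetical sketch of a config satisfying that contract (MyProfilerConfig and its field are illustrative, not nshtrainer API):

    from lightning.pytorch.profilers import Profiler, SimpleProfiler

    class MyProfilerConfig:
        def __init__(self, enabled: bool = True):
            self.enabled = enabled

        def create_profiler(self, root_config) -> Profiler | None:
            # Returning None triggers the `is None` warning branch above.
            if not self.enabled:
                return None
            return SimpleProfiler()

    profiler = MyProfilerConfig().create_profiler(root_config=None)
    assert isinstance(profiler, Profiler)  # would then be passed through to the Trainer kwargs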
@@ -281,7 +280,7 @@ class Trainer(LightningTrainer):
     @override
     def __init__(
         self,
-        config: BaseConfig,
+        config: "BaseConfig",
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
@@ -424,7 +423,7 @@ class Trainer(LightningTrainer):
         # Save the checkpoint metadata
         metadata_path = None
         lm = self._base_module
-        root_config = cast(BaseConfig, lm.hparams)
+        root_config = cast("BaseConfig", lm.hparams)
         if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             if (
nshtrainer/{config → util/config}/__init__.py

@@ -1,3 +1,4 @@
+from . import duration as duration
 from .duration import Duration as Duration
 from .duration import Epochs as Epochs
 from .duration import Steps as Steps
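Note: the redundant `X as X` alias added above is the standard convention for marking names as explicit re-exports under strict type checkers (pyright, mypy --strict). An illustrative sketch, not nshtrainer code:

    # mypkg/__init__.py (hypothetical package)
    from . import duration as duration          # re-export the submodule itself
    from .duration import Duration as Duration  # re-export a class

    # Elsewhere, `from mypkg import duration` now type-checks cleanly; without
    # the alias, strict checkers treat the name as private to __init__.py.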
{nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.30.1
+Version: 0.32.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD

@@ -1,11 +1,12 @@
-nshtrainer/__init__.py,sha256=sUb2yNdkHHhrKWCeWA5QKIA1Xx3jkO1QGD5Pa-HvgbA,614
+nshtrainer/__init__.py,sha256=flMI50Hj1Ie8c1YMSUQ759AqtNBQLT_zHaV2J9EUmOs,573
 nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
 nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
 nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
 nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28AdZ0,1264
+nshtrainer/_directory.py,sha256=RjnW6vKTeKlz2vQWT3cG0Jje5BkFXA7HpUubDhcSiq4,2993
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
 nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
-nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
+nshtrainer/callbacks/__init__.py,sha256=1SBLpMsx7BzgimO35MwQViYBcbgxlkyvTMz1JKUKK-0,3060
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
@@ -14,6 +15,8 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
+nshtrainer/callbacks/debug_flag.py,sha256=Mo69CtJqPWMlFBvgBEuYls8Vfp5v1QFiyMRTiMStdec,2059
+nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
 nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
@@ -22,15 +25,16 @@ nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvL
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
 nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
 nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
+nshtrainer/callbacks/rlp_sanity_checks.py,sha256=c30G9jAu42QLLIS5LnusdSnI3wqyIHgOUFDRcKESuNI,9935
+nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5AehYDtuJjDtUum4,2922
 nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
 nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
 nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
-nshtrainer/config/__init__.py,sha256=v9RtlM1Pqj_4fCDfskgxEtiGtbWH3Tj7lqNsKCDQ4gk,119
-nshtrainer/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
+nshtrainer/config.py,sha256=W6nAmn5Y1GVZto9vkx4v8i5XdikMSdVYDiq7kbDEWAg,5900
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
 nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
-nshtrainer/ll/__init__.py,sha256=6UTt2apSD8tOZw3M7hyd-33v4RKSpNNATlWFbW4cNnU,2523
+nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
 nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
 nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
@@ -38,7 +42,7 @@ nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
 nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
 nshtrainer/ll/log.py,sha256=d4BB3TyM8imK65EXOiOeUTF0zFM1ropbe7Vq3DeB0xU,140
 nshtrainer/ll/lr_scheduler.py,sha256=7xjhN6L69BCUzFhcy33NtMtPuCzHiB611zVWFg92lQ0,52
-nshtrainer/ll/model.py,sha256=cxFQfFc-2mAYBGwDpP8m5tjQBs7M47cZ6JoPXksPaoI,473
+nshtrainer/ll/model.py,sha256=Cw8Vq8IUL6YU1fTUcOIZsXcNJ3XyKgQY4YENIsL9H7c,996
 nshtrainer/ll/nn.py,sha256=8qiRDFwojIxkB7-LtNWk4mLL2tJbaskHYofDsOIHiNg,42
 nshtrainer/ll/optimizer.py,sha256=3T-VZtT73jVvwCNJGDjgGEbzs-1LFTzMQH-SB_58mSo,49
 nshtrainer/ll/runner.py,sha256=B0m5VEhNKIjF1aFmqPkonkQxDoRL2jeHZGsV3zwhSVE,117
@@ -51,44 +55,46 @@ nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfE
 nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
 nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
 nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
-nshtrainer/loggers/wandb.py,sha256=FPwbf618AYmuPzHdhd1ZFhJ8qDjwTUiSe7cm7g3KCyM,5112
+nshtrainer/loggers/wandb.py,sha256=8B2BMMzILRSUEiCkmp_fBpcXs69euRKViTiaV__DJZk,5128
 nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
-nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=Fyontbfu4k2932xZenE63QL4CrVGWANXdTeq63dUko0,5347
+nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=YQm84Sb4SWrofpBwa39DCslJvu2uorjbpWaGWyys1l4,5352
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
-nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
-nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
-nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
-nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
-nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
-nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
-nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
-nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
-nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
+nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
+nshtrainer/model/base.py,sha256=1zVY8ybZTzVKhpp7sUC0t360Ut3YmdGxAW5PZAIBSyw,18535
+nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
+nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
 nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
 nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
+nshtrainer/profiler/__init__.py,sha256=RQYkqQBVWuVvfdtAJIk2x5bNsXownklT87Mr_j-uXjw,474
+nshtrainer/profiler/_base.py,sha256=YF5lsJBIl9qts9GLW5Z62JuYeo4SnIArhlFwTGkfTb4,897
+nshtrainer/profiler/advanced.py,sha256=44asloha0aGUW8YwjQt3lm3ve8H-N6mM4QgseUSLT30,1112
+nshtrainer/profiler/pytorch.py,sha256=tGeRvoPP5ulWX2RkfXrQvMBoki1T95dpz5p8mwyon1I,2709
+nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y,957
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
+nshtrainer/trainer/_config.py,sha256=ZIodM5Ek1lpkWFhQ_VfmKR7q1mZFFwtjfx8FH72H8WM,29174
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=L4nYXq6Gts2sS9CQGenwEcvMET4L5vO5c60KM5Hm8Do,17544
+nshtrainer/trainer/trainer.py,sha256=iYueHW-m8fHyC8SQuXmpgxq_-GUa7pAJik7rDFPXmy0,17499
 nshtrainer/util/_environment_info.py,sha256=CFUUZYjXhBLWGc0jtPNOaZgYMueUDEHpEaWFA1f3GoY,24213
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
+nshtrainer/util/config/__init__.py,sha256=6iCFLhujhbOi7Q694e--Sq-ethiGoGHShm699GPV8Zg,154
+nshtrainer/util/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
 nshtrainer/util/path.py,sha256=VkpuhR4GaZtSFBVqbGAvfjcrU-PR8xwiGzzwFNOWP9c,2995
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.30.1.dist-info/METADATA,sha256=LV0wQlmotpfC3qO76dFVCbS26bEl-9YMiTetEeqVQsU,916
-nshtrainer-0.30.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.30.1.dist-info/RECORD,,
+nshtrainer-0.32.0.dist-info/METADATA,sha256=pe-TVRS0ZmZ9kx5NBQ8-0C6m4ZzaH_MalJZmh31mUNQ,916
+nshtrainer-0.32.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.32.0.dist-info/RECORD,,
nshtrainer/model/modules/callback.py

@@ -1,206 +0,0 @@
-import logging
-from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast, final, overload
-
-from lightning.pytorch import Callback, LightningModule
-from lightning.pytorch.callbacks import LambdaCallback
-from typing_extensions import override
-
-from ...util.typing_utils import mixin_base_type
-
-log = logging.getLogger(__name__)
-
-CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
-
-
-class CallbackRegistrarModuleMixin:
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self._nshtrainer_callbacks: list[CallbackFn] = []
-
-    @overload
-    def register_callback(
-        self, callback: Callback | Iterable[Callback] | CallbackFn | None = None, /
-    ): ...
-
-    @overload
-    def register_callback(
-        self,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ): ...
-
-    def register_callback(
-        self,
-        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
-        /,
-        *,
-        setup: Callable | None = None,
-        teardown: Callable | None = None,
-        on_fit_start: Callable | None = None,
-        on_fit_end: Callable | None = None,
-        on_sanity_check_start: Callable | None = None,
-        on_sanity_check_end: Callable | None = None,
-        on_train_batch_start: Callable | None = None,
-        on_train_batch_end: Callable | None = None,
-        on_train_epoch_start: Callable | None = None,
-        on_train_epoch_end: Callable | None = None,
-        on_validation_epoch_start: Callable | None = None,
-        on_validation_epoch_end: Callable | None = None,
-        on_test_epoch_start: Callable | None = None,
-        on_test_epoch_end: Callable | None = None,
-        on_validation_batch_start: Callable | None = None,
-        on_validation_batch_end: Callable | None = None,
-        on_test_batch_start: Callable | None = None,
-        on_test_batch_end: Callable | None = None,
-        on_train_start: Callable | None = None,
-        on_train_end: Callable | None = None,
-        on_validation_start: Callable | None = None,
-        on_validation_end: Callable | None = None,
-        on_test_start: Callable | None = None,
-        on_test_end: Callable | None = None,
-        on_exception: Callable | None = None,
-        on_save_checkpoint: Callable | None = None,
-        on_load_checkpoint: Callable | None = None,
-        on_before_backward: Callable | None = None,
-        on_after_backward: Callable | None = None,
-        on_before_optimizer_step: Callable | None = None,
-        on_before_zero_grad: Callable | None = None,
-        on_predict_start: Callable | None = None,
-        on_predict_end: Callable | None = None,
-        on_predict_batch_start: Callable | None = None,
-        on_predict_batch_end: Callable | None = None,
-        on_predict_epoch_start: Callable | None = None,
-        on_predict_epoch_end: Callable | None = None,
-    ):
-        if callback is None:
-            callback = LambdaCallback(
-                setup=setup,
-                teardown=teardown,
-                on_fit_start=on_fit_start,
-                on_fit_end=on_fit_end,
-                on_sanity_check_start=on_sanity_check_start,
-                on_sanity_check_end=on_sanity_check_end,
-                on_train_batch_start=on_train_batch_start,
-                on_train_batch_end=on_train_batch_end,
-                on_train_epoch_start=on_train_epoch_start,
-                on_train_epoch_end=on_train_epoch_end,
-                on_validation_epoch_start=on_validation_epoch_start,
-                on_validation_epoch_end=on_validation_epoch_end,
-                on_test_epoch_start=on_test_epoch_start,
-                on_test_epoch_end=on_test_epoch_end,
-                on_validation_batch_start=on_validation_batch_start,
-                on_validation_batch_end=on_validation_batch_end,
-                on_test_batch_start=on_test_batch_start,
-                on_test_batch_end=on_test_batch_end,
-                on_train_start=on_train_start,
-                on_train_end=on_train_end,
-                on_validation_start=on_validation_start,
-                on_validation_end=on_validation_end,
-                on_test_start=on_test_start,
-                on_test_end=on_test_end,
-                on_exception=on_exception,
-                on_save_checkpoint=on_save_checkpoint,
-                on_load_checkpoint=on_load_checkpoint,
-                on_before_backward=on_before_backward,
-                on_after_backward=on_after_backward,
-                on_before_optimizer_step=on_before_optimizer_step,
-                on_before_zero_grad=on_before_zero_grad,
-                on_predict_start=on_predict_start,
-                on_predict_end=on_predict_end,
-                on_predict_batch_start=on_predict_batch_start,
-                on_predict_batch_end=on_predict_batch_end,
-                on_predict_epoch_start=on_predict_epoch_start,
-                on_predict_epoch_end=on_predict_epoch_end,
-            )
-
-        if not callable(callback):
-            callback_ = cast(CallbackFn, lambda: callback)
-        else:
-            callback_ = callback
-
-        self._nshtrainer_callbacks.append(callback_)
-
-
-class CallbackModuleMixin(
-    CallbackRegistrarModuleMixin,
-    mixin_base_type(LightningModule),
-):
-    def _gather_all_callbacks(self):
-        modules: list[Any] = []
-        if isinstance(self, CallbackRegistrarModuleMixin):
-            modules.append(self)
-        if (
-            datamodule := getattr(self.trainer, "datamodule", None)
-        ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
-            modules.append(datamodule)
-        modules.extend(
-            module
-            for module in self.children()
-            if isinstance(module, CallbackRegistrarModuleMixin)
-        )
-        for module in modules:
-            yield from module._nshtrainer_callbacks
-
-    @final
-    @override
-    def configure_callbacks(self):
-        callbacks = super().configure_callbacks()
-        if not isinstance(callbacks, Sequence):
-            callbacks = [callbacks]
-
-        callbacks = list(callbacks)
-        for callback_fn in self._gather_all_callbacks():
-            callback_result = callback_fn()
-            if callback_result is None:
-                continue
-
-            if not isinstance(callback_result, Iterable):
-                callback_result = [callback_result]
-
-            for callback in callback_result:
-                log.info(
-                    f"Registering {callback.__class__.__qualname__} callback {callback}"
-                )
-                callbacks.append(callback)
-
-        return callbacks
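Note: this mixin is removed in 0.32.0. For reference, a sketch of how it was used, based solely on the deleted code above (MyModule is illustrative, not nshtrainer code):

    from lightning.pytorch import LightningModule
    from lightning.pytorch.callbacks import EarlyStopping

    class MyModule(CallbackModuleMixin, LightningModule):  # mixin from the deleted file
        def __init__(self):
            super().__init__()
            # Register a concrete Callback instance...
            self.register_callback(EarlyStopping(monitor="val/loss"))
            # ...or individual hooks, which get wrapped in a LambdaCallback and
            # collected by configure_callbacks() at fit time.
            self.register_callback(on_train_start=lambda trainer, pl_module: None)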
nshtrainer/model/modules/debug.py

@@ -1,42 +0,0 @@
-import logging
-
-import torch
-import torch.distributed
-
-log = logging.getLogger(__name__)
-
-
-class DebugModuleMixin:
-    @torch.jit.unused
-    def breakpoint(self, rank_zero_only: bool = True):
-        if (
-            not rank_zero_only
-            or not torch.distributed.is_initialized()
-            or torch.distributed.get_rank() == 0
-        ):
-            breakpoint()
-
-        if rank_zero_only and torch.distributed.is_initialized():
-            _ = torch.distributed.barrier()
-
-    @torch.jit.unused
-    def ensure_finite(
-        self,
-        tensor: torch.Tensor,
-        name: str | None = None,
-        throw: bool = False,
-    ):
-        name_parts: list[str] = ["Tensor"]
-        if name is not None:
-            name_parts.append(name)
-        name = " ".join(name_parts)
-
-        not_finite = ~torch.isfinite(tensor)
-        if not_finite.any():
-            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
-            if throw:
-                raise RuntimeError(msg)
-            else:
-                log.warning(msg)
-            return False
-        return True
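For reference, a small usage sketch of the removed ensure_finite helper (assumes the DebugModuleMixin definition above is in scope; Demo is illustrative):

    import torch

    class Demo(DebugModuleMixin):
        pass

    t = torch.tensor([1.0, float("nan"), float("inf")])
    ok = Demo().ensure_finite(t, name="logits")
    # Logs "Tensor logits has 2/3 non-finite values." and returns False.
    print(ok)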
nshtrainer/model/modules/distributed.py

@@ -1,70 +0,0 @@
-from typing import Any, Literal, cast
-
-import torch.distributed
-from lightning.pytorch import LightningModule
-from torch.distributed import ReduceOp
-from typing_extensions import TypeVar
-
-from ...util.typing_utils import mixin_base_type
-
-T = TypeVar("T", infer_variance=True)
-
-ReduceOpStr = Literal[
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-]
-VALID_REDUCE_OPS = (
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-)
-
-
-class DistributedMixin(mixin_base_type(LightningModule)):
-    def all_gather_object(
-        self,
-        object: T,
-        group: torch.distributed.ProcessGroup | None = None,
-    ) -> list[T]:
-        if (
-            not torch.distributed.is_available()
-            or not torch.distributed.is_initialized()
-        ):
-            return [object]
-
-        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
-        torch.distributed.all_gather_object(object_list, object, group=group)
-        return object_list
-
-    def barrier(self, name: str | None = None):
-        self.trainer.strategy.barrier(name=name)
-
-    def reduce(
-        self,
-        tensor: torch.Tensor,
-        reduce_op: ReduceOp.RedOpType | ReduceOpStr,
-        group: Any | None = None,
-    ) -> torch.Tensor:
-        if isinstance(reduce_op, str):
-            # validate reduce_op
-            if reduce_op not in VALID_REDUCE_OPS:
-                raise ValueError(
-                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
-                )
-
-        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
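For reference: outside an initialized process group, the removed all_gather_object simply wrapped its argument in a single-element list. A sketch mirroring that guard (illustrative, using torch.distributed directly rather than the mixin):

    import torch.distributed as dist

    obj = {"rank": 0, "loss": 1.23}
    if not dist.is_available() or not dist.is_initialized():
        gathered = [obj]  # single-process fallback, as in the deleted method
    else:
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, obj)
    print(gathered)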
nshtrainer/model/modules/profiler.py

@@ -1,24 +0,0 @@
-from lightning.pytorch import LightningDataModule, LightningModule
-from lightning.pytorch.profilers import PassThroughProfiler
-
-from ...util.typing_utils import mixin_base_type
-
-
-class ProfilerMixin(mixin_base_type(LightningModule)):
-    @property
-    def profiler(self):
-        if not isinstance(self, (LightningModule, LightningDataModule)):
-            raise TypeError(
-                "`profiler` can only be used on LightningModule or LightningDataModule"
-            )
-
-        if (trainer := self.trainer) is None:
-            raise RuntimeError("trainer is not defined")
-
-        if not hasattr(trainer, "profiler"):
-            raise RuntimeError("trainer does not have profiler")
-
-        if (profiler := getattr(trainer, "profiler")) is None:
-            profiler = PassThroughProfiler()
-
-        return profiler
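For reference, the removed property returned a lightning.pytorch.profilers profiler, whose base class exposes a profile(action_name) context manager; PassThroughProfiler was the no-op fallback. A short sketch:

    from lightning.pytorch.profilers import PassThroughProfiler

    profiler = PassThroughProfiler()       # what the removed property fell back to
    with profiler.profile("forward"):      # context manager from the Profiler base class
        result = sum(range(1000))          # stand-in for real work
    print(result)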