nshtrainer 0.30.1__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +8 -0
  4. nshtrainer/callbacks/directory_setup.py +85 -0
  5. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  6. nshtrainer/callbacks/shared_parameters.py +87 -0
  7. nshtrainer/config.py +67 -0
  8. nshtrainer/ll/__init__.py +5 -4
  9. nshtrainer/ll/model.py +7 -0
  10. nshtrainer/loggers/wandb.py +1 -1
  11. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  12. nshtrainer/model/__init__.py +0 -21
  13. nshtrainer/model/base.py +139 -44
  14. nshtrainer/model/config.py +7 -1025
  15. nshtrainer/model/{modules → mixins}/callback.py +2 -2
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +778 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/debug.py +0 -42
  28. nshtrainer/model/modules/distributed.py +0 -70
  29. nshtrainer/model/modules/profiler.py +0 -24
  30. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  31. nshtrainer/model/modules/shared_parameters.py +0 -72
  32. nshtrainer/{config → util/config}/duration.py +0 -0
  33. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
@@ -18,10 +18,8 @@ from typing_extensions import Unpack, assert_never, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
-from ..model.config import (
+from ._config import (
     AcceleratorConfigProtocol,
-    BaseConfig,
-    BaseProfilerConfig,
     LightningTrainerKwargs,
     StrategyConfigProtocol,
 )
@@ -29,6 +27,9 @@ from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
 log = logging.getLogger(__name__)
 
 
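Context for this and the following hunks: moving the `BaseConfig` import under `if TYPE_CHECKING:` removes the runtime import edge from `trainer/trainer.py` to `model/config.py` (avoiding an import cycle), while the quoted `"BaseConfig"` annotations keep type checkers happy. A minimal generic sketch of the pattern, with `mypackage.config` as a hypothetical stand-in module:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers, never at runtime, so this
        # import cannot participate in an import cycle.
        from mypackage.config import BaseConfig  # hypothetical module


    def pre_init(config: "BaseConfig") -> None:
        # The quoted annotation is resolved lazily; BaseConfig need not be
        # importable when this function is defined.
        ...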
@@ -58,14 +59,14 @@ def _is_bf16_supported_no_emulation():
 
 class Trainer(LightningTrainer):
     @classmethod
-    def _pre_init(cls, config: BaseConfig):
+    def _pre_init(cls, config: "BaseConfig"):
         if (precision := config.trainer.set_float32_matmul_precision) is not None:
             torch.set_float32_matmul_precision(precision)
 
     @classmethod
     def _update_kwargs(
         cls,
-        config: BaseConfig,
+        config: "BaseConfig",
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
@@ -217,18 +218,16 @@ class Trainer(LightningTrainer):
                 gradient_clip_val=grad_clip_config.value,
             )
 
-        if profiler := config.trainer.profiler:
-            # If the profiler is an ProfilerConfig instance, then we instantiate it.
-            if isinstance(profiler, BaseProfilerConfig):
-                profiler = profiler.create_profiler(config)
-            # Make sure that the profiler is an instance of `Profiler`.
-            if not isinstance(profiler, Profiler):
-                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
-
+        if profiler_config := config.trainer.profiler:
+            if (profiler := profiler_config.create_profiler(config)) is None:
+                log.warning(f"Profiler config {profiler_config=} returned None.")
+            # Make sure that the profiler is an instance of `Profiler`.
+            elif not isinstance(profiler, Profiler):
+                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
             # Otherwise, if the profiler is a string (e.g., "simpe", "advanced", "pytorch"),
             # then we just pass it through.
-            # kwargs["profiler"] = profiler
-            _update_kwargs(profiler=profiler)
+            else:
+                _update_kwargs(profiler=profiler)
 
         if callbacks := resolve_all_callbacks(config):
             _update_kwargs(callbacks=callbacks)
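The behavioral change in this hunk: `config.trainer.profiler` is no longer either a ready-made `Profiler` or a config object; it is treated uniformly as a config whose `create_profiler(config)` may return `None`, which is now logged and skipped rather than raising. A rough sketch of the implied contract (names here are illustrative, not nshtrainer's exact API):

    from typing import Protocol

    from lightning.pytorch.profilers import Profiler, SimpleProfiler


    class ProfilerConfigLike(Protocol):
        # Assumed shape of the profiler-config interface, inferred from the hunk:
        def create_profiler(self, root_config) -> Profiler | None: ...


    class AlwaysSimpleProfilerConfig:
        """Illustrative config that always yields Lightning's SimpleProfiler."""

        def create_profiler(self, root_config) -> Profiler | None:
            return SimpleProfiler()  # returning None would disable profiling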
@@ -281,7 +280,7 @@ class Trainer(LightningTrainer):
     @override
     def __init__(
         self,
-        config: BaseConfig,
+        config: "BaseConfig",
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
@@ -424,7 +423,7 @@ class Trainer(LightningTrainer):
         # Save the checkpoint metadata
         metadata_path = None
         lm = self._base_module
-        root_config = cast(BaseConfig, lm.hparams)
+        root_config = cast("BaseConfig", lm.hparams)
         if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             if (
@@ -1,3 +1,4 @@
+from . import duration as duration
 from .duration import Duration as Duration
 from .duration import Epochs as Epochs
 from .duration import Steps as Steps
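The added line is the explicit-re-export idiom: `from . import duration as duration` binds the `duration` submodule as an attribute of the package, and the redundant alias marks the re-export as intentional for strict type checkers (mypy's `--no-implicit-reexport`, pyright). Assuming nshtrainer 0.31.0 is installed:

    import nshtrainer.util.config as config

    config.duration  # submodule now reachable without importing it separately
    config.Epochs    # names re-exported from .duration still resolve as before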
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.30.1
+Version: 0.31.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -1,11 +1,12 @@
-nshtrainer/__init__.py,sha256=sUb2yNdkHHhrKWCeWA5QKIA1Xx3jkO1QGD5Pa-HvgbA,614
+nshtrainer/__init__.py,sha256=flMI50Hj1Ie8c1YMSUQ759AqtNBQLT_zHaV2J9EUmOs,573
 nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
 nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
 nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
 nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28AdZ0,1264
+nshtrainer/_directory.py,sha256=RjnW6vKTeKlz2vQWT3cG0Jje5BkFXA7HpUubDhcSiq4,2993
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
 nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
-nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
+nshtrainer/callbacks/__init__.py,sha256=NpEV8bMU12ClFN2sLKLBDXnuwIHYyZOCNxDZgjrV104,2892
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
@@ -14,6 +15,7 @@ nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
+nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
 nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
 nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
 nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
@@ -22,15 +24,16 @@ nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvL
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
 nshtrainer/callbacks/norm_logging.py,sha256=T2psu8mYsw9iahPKT6aUPjkGrZ4TIzm6_UUUmE09GJs,6274
 nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
+nshtrainer/callbacks/rlp_sanity_checks.py,sha256=c30G9jAu42QLLIS5LnusdSnI3wqyIHgOUFDRcKESuNI,9935
+nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5AehYDtuJjDtUum4,2922
 nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
 nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
 nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
-nshtrainer/config/__init__.py,sha256=v9RtlM1Pqj_4fCDfskgxEtiGtbWH3Tj7lqNsKCDQ4gk,119
-nshtrainer/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
+nshtrainer/config.py,sha256=W6nAmn5Y1GVZto9vkx4v8i5XdikMSdVYDiq7kbDEWAg,5900
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
 nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
-nshtrainer/ll/__init__.py,sha256=6UTt2apSD8tOZw3M7hyd-33v4RKSpNNATlWFbW4cNnU,2523
+nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
 nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
 nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
@@ -38,7 +41,7 @@ nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
 nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
 nshtrainer/ll/log.py,sha256=d4BB3TyM8imK65EXOiOeUTF0zFM1ropbe7Vq3DeB0xU,140
 nshtrainer/ll/lr_scheduler.py,sha256=7xjhN6L69BCUzFhcy33NtMtPuCzHiB611zVWFg92lQ0,52
-nshtrainer/ll/model.py,sha256=cxFQfFc-2mAYBGwDpP8m5tjQBs7M47cZ6JoPXksPaoI,473
+nshtrainer/ll/model.py,sha256=Cw8Vq8IUL6YU1fTUcOIZsXcNJ3XyKgQY4YENIsL9H7c,996
 nshtrainer/ll/nn.py,sha256=8qiRDFwojIxkB7-LtNWk4mLL2tJbaskHYofDsOIHiNg,42
 nshtrainer/ll/optimizer.py,sha256=3T-VZtT73jVvwCNJGDjgGEbzs-1LFTzMQH-SB_58mSo,49
 nshtrainer/ll/runner.py,sha256=B0m5VEhNKIjF1aFmqPkonkQxDoRL2jeHZGsV3zwhSVE,117
@@ -51,44 +54,47 @@ nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfE
 nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
 nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
 nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
-nshtrainer/loggers/wandb.py,sha256=FPwbf618AYmuPzHdhd1ZFhJ8qDjwTUiSe7cm7g3KCyM,5112
+nshtrainer/loggers/wandb.py,sha256=8B2BMMzILRSUEiCkmp_fBpcXs69euRKViTiaV__DJZk,5128
 nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
-nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=Fyontbfu4k2932xZenE63QL4CrVGWANXdTeq63dUko0,5347
+nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=YQm84Sb4SWrofpBwa39DCslJvu2uorjbpWaGWyys1l4,5352
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
 nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
-nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
-nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
-nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
-nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
-nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
-nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
-nshtrainer/model/modules/logger.py,sha256=CJWSmNT8SV5GLtfml-qGYenqRPXcNOMsJRGEavAd8Hw,5464
-nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
-nshtrainer/model/modules/rlp_sanity_checks.py,sha256=I_ralr2ThQ-D_FkVQTwbdXLLlgHJEr7-s01I5wSDjps,8893
-nshtrainer/model/modules/shared_parameters.py,sha256=ZiRKkZXr6RwdwLCdZCJPl3dXe7bnT8Z9yTeRK5bXBGk,2687
+nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
+nshtrainer/model/base.py,sha256=hT27FtzwKQiEL0C8RcaTKYXlanfvzTxHOJpHUcWiItk,19891
+nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
+nshtrainer/model/mixins/callback.py,sha256=lh3imlw1H3ESIG4WFA5frooSlWi6-RPUUDRFGRzEg4A,8571
+nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
 nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
 nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
 nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
+nshtrainer/profiler/__init__.py,sha256=RQYkqQBVWuVvfdtAJIk2x5bNsXownklT87Mr_j-uXjw,474
+nshtrainer/profiler/_base.py,sha256=YF5lsJBIl9qts9GLW5Z62JuYeo4SnIArhlFwTGkfTb4,897
+nshtrainer/profiler/advanced.py,sha256=44asloha0aGUW8YwjQt3lm3ve8H-N6mM4QgseUSLT30,1112
+nshtrainer/profiler/pytorch.py,sha256=tGeRvoPP5ulWX2RkfXrQvMBoki1T95dpz5p8mwyon1I,2709
+nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y,957
 nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
 nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
+nshtrainer/trainer/_config.py,sha256=2SgO1L5VBfWQ5g7Dg2dTx_vq2_Wo7dTqt2A4GlQaGo0,28673
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=L4nYXq6Gts2sS9CQGenwEcvMET4L5vO5c60KM5Hm8Do,17544
+nshtrainer/trainer/trainer.py,sha256=iYueHW-m8fHyC8SQuXmpgxq_-GUa7pAJik7rDFPXmy0,17499
 nshtrainer/util/_environment_info.py,sha256=CFUUZYjXhBLWGc0jtPNOaZgYMueUDEHpEaWFA1f3GoY,24213
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
+nshtrainer/util/config/__init__.py,sha256=6iCFLhujhbOi7Q694e--Sq-ethiGoGHShm699GPV8Zg,154
+nshtrainer/util/config/duration.py,sha256=f_obz0eorkktI3HzAuIawABDkvuL4lDqCxcPb3UW7Q4,692
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
 nshtrainer/util/path.py,sha256=VkpuhR4GaZtSFBVqbGAvfjcrU-PR8xwiGzzwFNOWP9c,2995
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.30.1.dist-info/METADATA,sha256=LV0wQlmotpfC3qO76dFVCbS26bEl-9YMiTetEeqVQsU,916
-nshtrainer-0.30.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.30.1.dist-info/RECORD,,
+nshtrainer-0.31.0.dist-info/METADATA,sha256=99b-8IvPlMmTrjyb5EK1kKsgKj8lWhGw4gZvM5sKyzc,916
+nshtrainer-0.31.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.31.0.dist-info/RECORD,,
@@ -1,42 +0,0 @@
-import logging
-
-import torch
-import torch.distributed
-
-log = logging.getLogger(__name__)
-
-
-class DebugModuleMixin:
-    @torch.jit.unused
-    def breakpoint(self, rank_zero_only: bool = True):
-        if (
-            not rank_zero_only
-            or not torch.distributed.is_initialized()
-            or torch.distributed.get_rank() == 0
-        ):
-            breakpoint()
-
-        if rank_zero_only and torch.distributed.is_initialized():
-            _ = torch.distributed.barrier()
-
-    @torch.jit.unused
-    def ensure_finite(
-        self,
-        tensor: torch.Tensor,
-        name: str | None = None,
-        throw: bool = False,
-    ):
-        name_parts: list[str] = ["Tensor"]
-        if name is not None:
-            name_parts.append(name)
-        name = " ".join(name_parts)
-
-        not_finite = ~torch.isfinite(tensor)
-        if not_finite.any():
-            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
-            if throw:
-                raise RuntimeError(msg)
-            else:
-                log.warning(msg)
-            return False
-        return True
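For reference, the removed `DebugModuleMixin` required no particular base class, so it could be exercised standalone. An illustrative snippet (paste the class from the hunk above alongside it; `Debuggable` is a hypothetical name):

    import torch


    class Debuggable(DebugModuleMixin):
        pass


    d = Debuggable()
    ok = d.ensure_finite(torch.tensor([1.0, float("nan")]), name="loss")
    assert ok is False  # logs a warning; pass throw=True to raise instead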
@@ -1,70 +0,0 @@
-from typing import Any, Literal, cast
-
-import torch.distributed
-from lightning.pytorch import LightningModule
-from torch.distributed import ReduceOp
-from typing_extensions import TypeVar
-
-from ...util.typing_utils import mixin_base_type
-
-T = TypeVar("T", infer_variance=True)
-
-ReduceOpStr = Literal[
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-]
-VALID_REDUCE_OPS = (
-    "avg",
-    "mean",
-    "band",
-    "bor",
-    "bxor",
-    "max",
-    "min",
-    "premul_sum",
-    "product",
-    "sum",
-)
-
-
-class DistributedMixin(mixin_base_type(LightningModule)):
-    def all_gather_object(
-        self,
-        object: T,
-        group: torch.distributed.ProcessGroup | None = None,
-    ) -> list[T]:
-        if (
-            not torch.distributed.is_available()
-            or not torch.distributed.is_initialized()
-        ):
-            return [object]
-
-        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
-        torch.distributed.all_gather_object(object_list, object, group=group)
-        return object_list
-
-    def barrier(self, name: str | None = None):
-        self.trainer.strategy.barrier(name=name)
-
-    def reduce(
-        self,
-        tensor: torch.Tensor,
-        reduce_op: ReduceOp.RedOpType | ReduceOpStr,
-        group: Any | None = None,
-    ) -> torch.Tensor:
-        if isinstance(reduce_op, str):
-            # validate reduce_op
-            if reduce_op not in VALID_REDUCE_OPS:
-                raise ValueError(
-                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
-                )
-
-        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
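As used in practice, these helpers wrapped the trainer's strategy so user code stayed backend-agnostic. An illustrative pre-0.31.0 composition (`MyModel` and its method are hypothetical):

    class MyModel(DistributedMixin, LightningModule):
        def summarize_loss(self, loss):
            # Gather a picklable object from every rank; degrades to a
            # one-element list when torch.distributed is not initialized.
            losses = self.all_gather_object(loss.item())
            # String ops are validated against VALID_REDUCE_OPS before the
            # strategy-level reduction runs.
            mean_loss = self.reduce(loss, reduce_op="mean")
            self.barrier("after-loss-summary")
            return losses, mean_loss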
@@ -1,24 +0,0 @@
-from lightning.pytorch import LightningDataModule, LightningModule
-from lightning.pytorch.profilers import PassThroughProfiler
-
-from ...util.typing_utils import mixin_base_type
-
-
-class ProfilerMixin(mixin_base_type(LightningModule)):
-    @property
-    def profiler(self):
-        if not isinstance(self, (LightningModule, LightningDataModule)):
-            raise TypeError(
-                "`profiler` can only be used on LightningModule or LightningDataModule"
-            )
-
-        if (trainer := self.trainer) is None:
-            raise RuntimeError("trainer is not defined")
-
-        if not hasattr(trainer, "profiler"):
-            raise RuntimeError("trainer does not have profiler")
-
-        if (profiler := getattr(trainer, "profiler")) is None:
-            profiler = PassThroughProfiler()
-
-        return profiler
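The property's fallback to `PassThroughProfiler` meant instrumentation stayed safe when profiling was disabled. An illustrative pre-0.31.0 usage (hypothetical model; `Profiler.profile` is Lightning's standard context manager):

    class MyModel(ProfilerMixin, LightningModule):
        def training_step(self, batch, batch_idx):
            # With no profiler configured, this is a no-op thanks to the
            # PassThroughProfiler fallback.
            with self.profiler.profile("forward"):
                out = self(batch)
            return out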
@@ -1,202 +0,0 @@
-import logging
-from collections.abc import Mapping
-from typing import cast
-
-import torch
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.utilities.types import (
-    LRSchedulerConfigType,
-    LRSchedulerTypeUnion,
-)
-from typing_extensions import Protocol, override, runtime_checkable
-
-from ...util.typing_utils import mixin_base_type
-from ..config import BaseConfig
-from .callback import CallbackModuleMixin
-
-log = logging.getLogger(__name__)
-
-
-def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
-    # If we're in PL's "sanity check" mode, we don't need to run this check
-    if trainer.sanity_checking:
-        return
-
-    config = cast(BaseConfig, pl_module.hparams)
-    if config.trainer.sanity_checking.reduce_lr_on_plateau == "disable":
-        return
-
-    # if no lr schedulers, return
-    if not trainer.lr_scheduler_configs:
-        return
-
-    errors: list[str] = []
-    disable_message = (
-        "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau='disable'` "
-        "to disable this sanity check."
-    )
-
-    for lr_scheduler_config in trainer.lr_scheduler_configs:
-        if not lr_scheduler_config.reduce_on_plateau:
-            continue
-
-        match lr_scheduler_config.interval:
-            case "epoch":
-                # we need to make sure that the trainer runs val every `frequency` epochs
-
-                # If `trainer.check_val_every_n_epoch` is None, then Lightning
-                # will run val every `int(trainer.val_check_interval)` steps.
-                # So, first we need to make sure that `trainer.val_check_interval` is not None first.
-                if trainer.check_val_every_n_epoch is None:
-                    errors.append(
-                        "Trainer is not running validation at epoch intervals "
-                        "(i.e., `trainer.check_val_every_n_epoch` is None) but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                        f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
-                        + disable_message
-                    )
-
-                # Second, we make sure that the trainer runs val at least every `frequency` epochs
-                if (
-                    trainer.check_val_every_n_epoch is not None
-                    and lr_scheduler_config.frequency % trainer.check_val_every_n_epoch
-                    != 0
-                ):
-                    errors.append(
-                        f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
-                        f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
-                        + disable_message
-                    )
-
-            case "step":
-                # In this case, we need to make sure that the trainer runs val at step intervals
-                # that are multiples of `frequency`.
-
-                # First, we make sure that validation is run at step intervals
-                if trainer.check_val_every_n_epoch is not None:
-                    errors.append(
-                        "Trainer is running validation at epoch intervals "
-                        "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                        "Please set `config.trainer.check_val_every_n_epoch=None` "
-                        f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
-                        + disable_message
-                    )
-
-                # Second, we make sure `trainer.val_check_interval` is an integer
-                if not isinstance(trainer.val_check_interval, int):
-                    errors.append(
-                        f"Trainer is not running validation at step intervals "
-                        f"(i.e., `trainer.val_check_interval` is not an integer) but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                        "Please set `config.trainer.val_check_interval=None` "
-                        f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
-                        + disable_message
-                    )
-
-                # Third, we make sure that the trainer runs val at least every `frequency` steps
-                if (
-                    isinstance(trainer.val_check_interval, int)
-                    and trainer.val_check_interval % lr_scheduler_config.frequency != 0
-                ):
-                    errors.append(
-                        f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
-                        f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
-                        "Please set `config.trainer.val_check_interval` "
-                        f"to a multiple of {lr_scheduler_config.frequency}. "
-                        + disable_message
-                    )
-
-            case _:
-                pass
-
-    if not errors:
-        return
-
-    message = (
-        "ReduceLRPlateau sanity checks failed with the following errors:\n"
-        + "\n".join(errors)
-    )
-    match config.trainer.sanity_checking.reduce_lr_on_plateau:
-        case "warn":
-            log.warning(message)
-        case "error":
-            raise ValueError(message)
-        case _:
-            pass
-
-
-@runtime_checkable
-class CustomRLPImplementation(Protocol):
-    __reduce_lr_on_plateau__: bool
-
-
-class RLPSanityCheckModuleMixin(mixin_base_type(CallbackModuleMixin)):
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        global _on_train_start_callback
-        self.register_callback(on_train_start=_on_train_start_callback)
-
-    def reduce_lr_on_plateau_config(
-        self,
-        lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
-    ) -> LRSchedulerConfigType:
-        if (trainer := self._trainer) is None:
-            raise RuntimeError(
-                "Could not determine the frequency of ReduceLRPlateau scheduler "
-                "because `self.trainer` is None."
-            )
-
-        # First, resolve the LR scheduler from the provided config.
-        lr_scheduler_config: LRSchedulerConfigType
-        match lr_scheduler:
-            case Mapping():
-                lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
-            case _:
-                lr_scheduler_config = {"scheduler": lr_scheduler}
-
-        # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
-        if (
-            not isinstance(
-                lr_scheduler_config["scheduler"],
-                torch.optim.lr_scheduler.ReduceLROnPlateau,
-            )
-        ) and (
-            not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
-            or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
-        ):
-            log.warning(
-                "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
-                f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
-                "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
-                "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
-                "If you are using a custom ReduceLRPlateau scheduler implementation, "
-                "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
-                "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
-            )
-
-        # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
-        if trainer.check_val_every_n_epoch is not None:
-            return {
-                "reduce_on_plateau": True,
-                "interval": "epoch",
-                "frequency": trainer.check_val_every_n_epoch,
-                **lr_scheduler_config,
-            }
-
-        # Otherwise, we run val at step intervals.
-        if not isinstance(trainer.val_check_batch, int):
-            raise ValueError(
-                "Could not determine the frequency of ReduceLRPlateau scheduler "
-                f"because {trainer.val_check_batch=} is not an integer."
-            )
-
-        return {
-            "reduce_on_plateau": True,
-            "interval": "step",
-            "frequency": trainer.val_check_batch,
-            **lr_scheduler_config,
-        }
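The mixin's main entry point was `reduce_lr_on_plateau_config`, which filled in the scheduler's `interval`/`frequency` from the trainer's validation cadence so the scheduler only steps when validation metrics actually exist. An illustrative pre-0.31.0 `configure_optimizers` (hypothetical model; the equivalent checks now ship in `nshtrainer/callbacks/rlp_sanity_checks.py` per the file list above):

    import torch


    class MyModel(RLPSanityCheckModuleMixin, LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
            return {
                "optimizer": optimizer,
                # interval/frequency are derived from check_val_every_n_epoch
                # or val_check_batch, matching the logic shown above.
                "lr_scheduler": self.reduce_lr_on_plateau_config(
                    {"scheduler": scheduler, "monitor": "val/loss"}
                ),
            }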
@@ -1,72 +0,0 @@
-import logging
-from collections.abc import Sequence
-from typing import cast
-
-import torch.nn as nn
-from lightning.pytorch import LightningModule, Trainer
-from typing_extensions import override
-
-from ...util.typing_utils import mixin_base_type
-from ..config import BaseConfig
-from .callback import CallbackRegistrarModuleMixin
-
-log = logging.getLogger(__name__)
-
-
-def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
-    mapping = {id(p): n for n, p in model.named_parameters()}
-    return [mapping[id(p)] for p in parameters]
-
-
-class SharedParametersModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)):
-    @override
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.shared_parameters: list[tuple[nn.Parameter, int | float]] = []
-        self._warned_shared_parameters = False
-
-        def on_after_backward(_trainer: Trainer, pl_module: LightningModule):
-            nonlocal self
-
-            config = cast(BaseConfig, pl_module.hparams)
-            if not config.trainer.supports_shared_parameters:
-                return
-
-            log.debug(f"Scaling {len(self.shared_parameters)} shared parameters...")
-            no_grad_parameters: list[nn.Parameter] = []
-            for p, factor in self.shared_parameters:
-                if not hasattr(p, "grad") or p.grad is None:
-                    no_grad_parameters.append(p)
-                    continue
-
-                _ = p.grad.data.div_(factor)
-
-            if no_grad_parameters and not self._warned_shared_parameters:
-                no_grad_parameters_str = ", ".join(
-                    _parameters_to_names(no_grad_parameters, pl_module)
-                )
-                log.warning(
-                    "The following parameters were marked as shared, but had no gradients: "
-                    f"{no_grad_parameters_str}"
-                )
-                self._warned_shared_parameters = True
-
-            log.debug(
-                f"Done scaling shared parameters. (len={len(self.shared_parameters)})"
-            )
-
-        self.register_callback(on_after_backward=on_after_backward)
-
-    def register_shared_parameters(
-        self, parameters: list[tuple[nn.Parameter, int | float]]
-    ):
-        for parameter, factor in parameters:
-            if not isinstance(parameter, nn.Parameter):
-                raise ValueError("Shared parameters must be PyTorch parameters")
-            if not isinstance(factor, (int, float)):
-                raise ValueError("Factor must be an integer or float")
-
-            self.shared_parameters.append((parameter, factor))
-
-        log.info(f"Registered {len(parameters)} shared parameters")
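Typical use was weight tying: register each tied parameter with the number of modules sharing it, and the post-backward hook divides the accumulated gradient by that factor. An illustrative pre-0.31.0 sketch (hypothetical model; the replacement lives in `nshtrainer/callbacks/shared_parameters.py` per the file list above):

    import torch.nn as nn


    class TiedModel(SharedParametersModuleMixin, LightningModule):
        def __init__(self, config):
            super().__init__(config)
            self.embed = nn.Embedding(1000, 64)
            self.head = nn.Linear(64, 1000, bias=False)
            self.head.weight = self.embed.weight  # tied: gradient accumulates twice
            # on_after_backward will run p.grad.div_(2) for this parameter.
            self.register_shared_parameters([(self.embed.weight, 2)])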