nshtrainer 0.34.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -81,6 +81,22 @@ class BalancedBatchSampler(BatchSampler):
81
81
  ):
82
82
  super().__init__(sampler, batch_size, drop_last=drop_last)
83
83
 
84
+ # Validate the dataset
85
+ dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
86
+ # Dataset must either implement `data_sizes`, or we need to provide a custom
87
+ # implementation of the dataset sizes function.
88
+ if isinstance(dataset, DatasetWithSizes):
89
+ log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
90
+
91
+ elif self._data_sizes_fn is not None:
92
+ log.critical("BalancedBatchSampler: Using custom data_sizes_fn")
93
+ else:
94
+ raise ValueError(
95
+ "Dataset must implement the `data_sizes` method, "
96
+ "or a custom data_sizes_fn must be provided "
97
+ "to the BalancedBatchSampler."
98
+ )
99
+
84
100
  self._device = device
85
101
  self._data_sizes_fn = data_sizes_fn
86
102
 
@@ -97,7 +113,6 @@ class BalancedBatchSampler(BatchSampler):
97
113
  # Dataset must either implement `data_sizes`, or we need to provide a custom
98
114
  # implementation of the dataset sizes function.
99
115
  if isinstance(dataset, DatasetWithSizes):
100
- log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
101
116
  return dataset.data_sizes(indices)
102
117
 
103
118
  if (data_sizes_fn := self._data_sizes_fn) is not None:
nshtrainer/model/base.py CHANGED
@@ -15,6 +15,7 @@ from typing_extensions import Self, TypeVar, override
15
15
  from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
16
16
  from ..util._environment_info import EnvironmentConfig
17
17
  from .config import BaseConfig
18
+ from .mixins.callback import CallbackModuleMixin
18
19
  from .mixins.logger import LoggerLightningModuleMixin
19
20
 
20
21
  log = logging.getLogger(__name__)
@@ -53,6 +54,7 @@ VALID_REDUCE_OPS = (
53
54
  class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
54
55
  _RLPSanityCheckModuleMixin,
55
56
  LoggerLightningModuleMixin,
57
+ CallbackModuleMixin,
56
58
  LightningModule,
57
59
  ABC,
58
60
  Generic[THparams],
@@ -0,0 +1,74 @@
1
+ import logging
2
+ from collections.abc import Callable, Iterable, Sequence
3
+ from typing import Any, TypeAlias, cast, final
4
+
5
+ from lightning.pytorch import Callback, LightningModule
6
+ from typing_extensions import override
7
+
8
+ from ...util.typing_utils import mixin_base_type
9
+
10
log = logging.getLogger(__name__)

# A callback factory: a zero-argument callable that returns a single
# `Callback`, an iterable of `Callback`s, or `None` (meaning "nothing to add").
CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
13
+
14
+
15
class CallbackRegistrarModuleMixin:
    """Mixin that lets a module accumulate callback factories for later collection."""

    @override
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Factories registered through `register_callback`; each one yields
        # zero or more `Callback` instances when invoked.
        self._nshtrainer_callbacks: list[CallbackFn] = []

    def register_callback(
        self,
        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
    ):
        """Register a callback, an iterable of callbacks, or a factory on this module.

        Non-callable values are wrapped in a zero-argument factory so that
        collection can treat every entry uniformly.
        """
        fn = callback if callable(callback) else cast(CallbackFn, lambda: callback)
        self._nshtrainer_callbacks.append(fn)
32
+
33
+
34
class CallbackModuleMixin(
    CallbackRegistrarModuleMixin, mixin_base_type(LightningModule)
):
    """LightningModule mixin that resolves callbacks registered on itself,
    on the attached datamodule, and on child submodules."""

    def _gather_all_callbacks(self):
        # Collect every object that may hold registered callback factories:
        # this module itself, the trainer's datamodule (if it participates in
        # registration), and any child submodule that does.
        sources: list[Any] = []
        if isinstance(self, CallbackRegistrarModuleMixin):
            sources.append(self)
        datamodule = getattr(self.trainer, "datamodule", None)
        if datamodule is not None and isinstance(
            datamodule, CallbackRegistrarModuleMixin
        ):
            sources.append(datamodule)
        for child in self.children():
            if isinstance(child, CallbackRegistrarModuleMixin):
                sources.append(child)
        for source in sources:
            yield from source._nshtrainer_callbacks

    @final
    @override
    def configure_callbacks(self):
        """Append all registered callbacks to whatever the parent returns."""
        resolved = super().configure_callbacks()
        # Lightning allows returning a single callback; normalize to a list.
        resolved = list(resolved) if isinstance(resolved, Sequence) else [resolved]

        for factory in self._gather_all_callbacks():
            produced = factory()
            if produced is None:
                continue
            # A factory may return one callback or an iterable of them.
            if not isinstance(produced, Iterable):
                produced = [produced]
            for cb in produced:
                log.info(
                    f"Registering {cb.__class__.__qualname__} callback {cb}"
                )
                resolved.append(cb)

        return resolved
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.34.1
3
+ Version: 0.35.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -32,7 +32,7 @@ nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50
32
32
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
33
33
  nshtrainer/config.py,sha256=6U7B-kCIMrfEnF_y92RuBm1WfASW7k05Zsm2uHBzRrk,8205
34
34
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
35
- nshtrainer/data/balanced_batch_sampler.py,sha256=WAjhbO9EsZ_UadhdW3obBsjvEDMc2V-irpjegqIb7AI,4791
35
+ nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
36
36
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
37
37
  nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
38
38
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -63,8 +63,9 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-
63
63
  nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
64
64
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
65
65
  nshtrainer/model/__init__.py,sha256=2i_VEy6u_Y1LUGKljHXWeekvhnUcanZM2QyaaBM1Bmw,261
66
- nshtrainer/model/base.py,sha256=1zVY8ybZTzVKhpp7sUC0t360Ut3YmdGxAW5PZAIBSyw,18535
66
+ nshtrainer/model/base.py,sha256=NasbYZJBuEly6Hm9t9HVZk-CUHmy4T7p1v-Ye981XA4,18609
67
67
  nshtrainer/model/config.py,sha256=Q4Wong6w3cp_Sq7s8iZdABKF-LZBbSCFn_TQPYkhkrI,6572
68
+ nshtrainer/model/mixins/callback.py,sha256=lvX9Q2ErETXmGFd79CscSAOJAlTWq-mwMKVC0d0uH1c,2324
68
69
  nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlotMEX8,5291
69
70
  nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
70
71
  nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
@@ -97,6 +98,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
97
98
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
98
99
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
99
100
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
100
- nshtrainer-0.34.1.dist-info/METADATA,sha256=c_iXv-CQLl6kig2u3lmrP4EDSOdMvq8L2WewYiFL-8Q,916
101
- nshtrainer-0.34.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
102
- nshtrainer-0.34.1.dist-info/RECORD,,
101
+ nshtrainer-0.35.0.dist-info/METADATA,sha256=NBZegh-RUfnkVt_ERUPdH7fdCFZriQZXoMskq_8HB60,916
102
+ nshtrainer-0.35.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
103
+ nshtrainer-0.35.0.dist-info/RECORD,,