nshtrainer 0.34.1__tar.gz → 0.35.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/data/balanced_batch_sampler.py +16 -1
  4. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/model/base.py +2 -0
  5. nshtrainer-0.35.0/src/nshtrainer/model/mixins/callback.py +74 -0
  6. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/README.md +0 -0
  7. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/__init__.py +0 -0
  8. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/_callback.py +0 -0
  9. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  10. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  11. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  12. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/_directory.py +0 -0
  13. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  14. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/_hf_hub.py +0 -0
  15. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  16. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  17. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  18. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/base.py +0 -0
  19. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  20. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  21. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  22. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  23. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  24. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  25. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  26. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  27. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/ema.py +0 -0
  28. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  29. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  30. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/interval.py +0 -0
  31. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  32. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  33. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  34. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  35. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  36. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  37. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/timer.py +0 -0
  38. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  39. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/config.py +0 -0
  40. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/data/__init__.py +0 -0
  41. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/data/transform.py +0 -0
  42. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/__init__.py +0 -0
  43. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/_experimental.py +0 -0
  44. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/actsave.py +0 -0
  45. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/callbacks.py +0 -0
  46. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/config.py +0 -0
  47. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/data.py +0 -0
  48. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/log.py +0 -0
  49. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  50. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/model.py +0 -0
  51. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/nn.py +0 -0
  52. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/optimizer.py +0 -0
  53. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/runner.py +0 -0
  54. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/snapshot.py +0 -0
  55. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/snoop.py +0 -0
  56. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/trainer.py +0 -0
  57. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/typecheck.py +0 -0
  58. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/ll/util.py +0 -0
  59. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/loggers/__init__.py +0 -0
  60. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/loggers/_base.py +0 -0
  61. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/loggers/csv.py +0 -0
  62. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  63. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/loggers/wandb.py +0 -0
  64. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  65. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  66. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  67. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  68. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/metrics/__init__.py +0 -0
  69. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/metrics/_config.py +0 -0
  70. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/model/__init__.py +0 -0
  71. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/model/config.py +0 -0
  72. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/model/mixins/logger.py +0 -0
  73. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/nn/__init__.py +0 -0
  74. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/nn/mlp.py +0 -0
  75. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/nn/module_dict.py +0 -0
  76. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/nn/module_list.py +0 -0
  77. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  78. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/optimizer.py +0 -0
  79. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/profiler/__init__.py +0 -0
  80. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/profiler/_base.py +0 -0
  81. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/profiler/advanced.py +0 -0
  82. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/profiler/pytorch.py +0 -0
  83. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/profiler/simple.py +0 -0
  84. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/runner.py +0 -0
  85. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  86. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/trainer/__init__.py +0 -0
  87. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/trainer/_config.py +0 -0
  88. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  89. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  90. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  91. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/trainer/trainer.py +0 -0
  92. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/_environment_info.py +0 -0
  93. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/_useful_types.py +0 -0
  94. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/bf16.py +0 -0
  95. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/config/__init__.py +0 -0
  96. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/config/dtype.py +0 -0
  97. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/config/duration.py +0 -0
  98. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/environment.py +0 -0
  99. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/path.py +0 -0
  100. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/seed.py +0 -0
  101. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/slurm.py +0 -0
  102. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/typed.py +0 -0
  103. {nshtrainer-0.34.1 → nshtrainer-0.35.0}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.34.1
3
+ Version: 0.35.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.34.1"
3
+ version = "0.35.0"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -81,6 +81,22 @@ class BalancedBatchSampler(BatchSampler):
81
81
  ):
82
82
  super().__init__(sampler, batch_size, drop_last=drop_last)
83
83
 
84
+ # Validate the dataset
85
+ dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
86
+ # Dataset must either implement `data_sizes`, or we need to provide a custom
87
+ # implementation of the dataset sizes function.
88
+ if isinstance(dataset, DatasetWithSizes):
89
+ log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
90
+
91
+ elif self._data_sizes_fn is not None:
92
+ log.critical("BalancedBatchSampler: Using custom data_sizes_fn")
93
+ else:
94
+ raise ValueError(
95
+ "Dataset must implement the `data_sizes` method, "
96
+ "or a custom data_sizes_fn must be provided "
97
+ "to the BalancedBatchSampler."
98
+ )
99
+
84
100
  self._device = device
85
101
  self._data_sizes_fn = data_sizes_fn
86
102
 
@@ -97,7 +113,6 @@ class BalancedBatchSampler(BatchSampler):
97
113
  # Dataset must either implement `data_sizes`, or we need to provide a custom
98
114
  # implementation of the dataset sizes function.
99
115
  if isinstance(dataset, DatasetWithSizes):
100
- log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
101
116
  return dataset.data_sizes(indices)
102
117
 
103
118
  if (data_sizes_fn := self._data_sizes_fn) is not None:
@@ -15,6 +15,7 @@ from typing_extensions import Self, TypeVar, override
15
15
  from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
16
16
  from ..util._environment_info import EnvironmentConfig
17
17
  from .config import BaseConfig
18
+ from .mixins.callback import CallbackModuleMixin
18
19
  from .mixins.logger import LoggerLightningModuleMixin
19
20
 
20
21
  log = logging.getLogger(__name__)
@@ -53,6 +54,7 @@ VALID_REDUCE_OPS = (
53
54
  class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
54
55
  _RLPSanityCheckModuleMixin,
55
56
  LoggerLightningModuleMixin,
57
+ CallbackModuleMixin,
56
58
  LightningModule,
57
59
  ABC,
58
60
  Generic[THparams],
@@ -0,0 +1,74 @@
1
+ import logging
2
+ from collections.abc import Callable, Iterable, Sequence
3
+ from typing import Any, TypeAlias, cast, final
4
+
5
+ from lightning.pytorch import Callback, LightningModule
6
+ from typing_extensions import override
7
+
8
+ from ...util.typing_utils import mixin_base_type
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+ CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
13
+
14
+
15
+ class CallbackRegistrarModuleMixin:
16
+ @override
17
+ def __init__(self, *args, **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+
20
+ self._nshtrainer_callbacks: list[CallbackFn] = []
21
+
22
+ def register_callback(
23
+ self,
24
+ callback: Callback | Iterable[Callback] | CallbackFn | None = None,
25
+ ):
26
+ if not callable(callback):
27
+ callback_ = cast(CallbackFn, lambda: callback)
28
+ else:
29
+ callback_ = callback
30
+
31
+ self._nshtrainer_callbacks.append(callback_)
32
+
33
+
34
+ class CallbackModuleMixin(
35
+ CallbackRegistrarModuleMixin, mixin_base_type(LightningModule)
36
+ ):
37
+ def _gather_all_callbacks(self):
38
+ modules: list[Any] = []
39
+ if isinstance(self, CallbackRegistrarModuleMixin):
40
+ modules.append(self)
41
+ if (
42
+ datamodule := getattr(self.trainer, "datamodule", None)
43
+ ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin):
44
+ modules.append(datamodule)
45
+ modules.extend(
46
+ module
47
+ for module in self.children()
48
+ if isinstance(module, CallbackRegistrarModuleMixin)
49
+ )
50
+ for module in modules:
51
+ yield from module._nshtrainer_callbacks
52
+
53
+ @final
54
+ @override
55
+ def configure_callbacks(self):
56
+ callbacks = super().configure_callbacks()
57
+ if not isinstance(callbacks, Sequence):
58
+ callbacks = [callbacks]
59
+
60
+ callbacks = list(callbacks)
61
+ for callback_fn in self._gather_all_callbacks():
62
+ if (callback_result := callback_fn()) is None:
63
+ continue
64
+
65
+ if not isinstance(callback_result, Iterable):
66
+ callback_result = [callback_result]
67
+
68
+ for callback in callback_result:
69
+ log.info(
70
+ f"Registering {callback.__class__.__qualname__} callback {callback}"
71
+ )
72
+ callbacks.append(callback)
73
+
74
+ return callbacks
File without changes