nshtrainer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/actsave/_saver.py
@@ -0,0 +1,337 @@
+ import contextlib
+ import fnmatch
+ import tempfile
+ import uuid
+ import weakref
+ from collections.abc import Callable, Mapping
+ from dataclasses import dataclass
+ from functools import wraps
+ from logging import getLogger
+ from pathlib import Path
+ from typing import Generic, TypeAlias, cast, overload
+
+ import numpy as np
+ import torch
+ from lightning_utilities.core.apply_func import apply_to_collection
+ from typing_extensions import ParamSpec, TypeVar, override
+
+ log = getLogger(__name__)
+
+ Value: TypeAlias = int | float | complex | bool | str | np.ndarray | torch.Tensor | None
+ ValueOrLambda = Value | Callable[..., Value]
+
+
+ def _to_numpy(activation: Value) -> np.ndarray:
+     # Make sure it's not `None`
+     if activation is None:
+         raise ValueError("Activation should not be `None`")
+
+     if isinstance(activation, np.ndarray):
+         return activation
+     if isinstance(activation, torch.Tensor):
+         activation = activation.detach()
+         if activation.is_floating_point():
+             # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
+             activation = activation.float()
+         return activation.cpu().numpy()
+     if isinstance(activation, (int, float, complex, str, bool)):
+         return np.array(activation)
+
+     return activation
+
+
+ T = TypeVar("T", infer_variance=True)
+
+
+ # A wrapper around weakref.ref that also accepts primitive types,
+ # to get around errors like:
+ # TypeError: cannot create weak reference to 'int' object
+ class WeakRef(Generic[T]):
+     _ref: Callable[[], T] | None
+
+     def __init__(self, obj: T):
+         try:
+             self._ref = cast(Callable[[], T], weakref.ref(obj))
+         except TypeError as e:
+             if "cannot create weak reference" not in str(e):
+                 raise
+             self._ref = lambda: obj
+
+     def __call__(self) -> T:
+         if self._ref is None:
+             raise RuntimeError("WeakRef is deleted")
+         return self._ref()
+
+     def delete(self):
+         del self._ref
+         self._ref = None
+
+
+ @dataclass
+ class Activation:
+     name: str
+     ref: WeakRef[ValueOrLambda] | None
+     transformed: np.ndarray | None = None
+
+     def __post_init__(self):
+         # Update the `name` to replace `/` with `.`
+         self.name = self.name.replace("/", ".")
+
+     def __call__(self) -> np.ndarray | None:
+         # If we already have a transformed value, return it
+         if self.transformed is not None:
+             return self.transformed
+
+         if self.ref is None:
+             raise RuntimeError("Activation is deleted")
+
+         # If we have a lambda, we need to call it
+         unwrapped_ref = self.ref()
+         activation = unwrapped_ref
+         if callable(unwrapped_ref):
+             activation = unwrapped_ref()
+
+         # If we have a `None`, we return early
+         if activation is None:
+             return None
+
+         activation = apply_to_collection(activation, torch.Tensor, _to_numpy)
+         activation = _to_numpy(activation)
+
+         # Set the transformed value
+         self.transformed = activation
+
+         # Delete the reference
+         self.ref.delete()
+         del self.ref
+         self.ref = None
+
+         return self.transformed
+
+     @classmethod
+     def from_value_or_lambda(cls, name: str, value_or_lambda: ValueOrLambda):
+         return cls(name, WeakRef(value_or_lambda))
+
+     @classmethod
+     def from_dict(cls, d: Mapping[str, ValueOrLambda]):
+         return [cls.from_value_or_lambda(k, v) for k, v in d.items()]
+
+
+ Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
+
+
+ def _ensure_supported():
+     try:
+         import torch.distributed as dist
+
+         if dist.is_initialized() and dist.get_world_size() > 1:
+             raise RuntimeError("Only single GPU is supported at the moment")
+     except ImportError:
+         pass
+
+
+ P = ParamSpec("P")
+
+
+ def _ignore_if_scripting(fn: Callable[P, None]) -> Callable[P, None]:
+     @wraps(fn)
+     def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
+         if torch.jit.is_scripting():
+             return
+
+         _ensure_supported()
+         fn(*args, **kwargs)
+
+     return wrapper
+
+
+ class _Saver:
+     def __init__(
+         self,
+         save_dir: Path,
+         prefixes_fn: Callable[[], list[str]],
+         *,
+         filters: list[str] | None = None,
+     ):
+         # Create a new auto-incremented subdirectory under `save_dir`
+         # (i.e., every activation save context gets its own directory);
+         # the id is the number of existing subdirectories.
+         self._id = sum(1 for subdir in save_dir.glob("*") if subdir.is_dir())
+         save_dir.mkdir(parents=True, exist_ok=True)
+
+         # Add a .activationbase file to the save_dir to indicate that this is an activation base
+         (save_dir / ".activationbase").touch(exist_ok=True)
+
+         self._save_dir = save_dir / f"{self._id:04d}"
+         # Make sure `self._save_dir` does not exist and create it
+         self._save_dir.mkdir(exist_ok=False)
+
+         self._prefixes_fn = prefixes_fn
+         self._filters = filters
+
+     def _save_activation(self, activation: Activation):
+         # If the activation value is `None`, we skip it.
+         if (activation_value := activation()) is None:
+             return
+
+         # Save the activation to self._save_dir / file_name / {id}.npy, where id is an auto-incrementing integer
+         file_name = ".".join(self._prefixes_fn() + [activation.name])
+         path = self._save_dir / file_name
+         path.mkdir(exist_ok=True, parents=True)
+
+         # Get the next id and save the activation
+         id = len(list(path.glob("*.npy")))
+         np.save(path / f"{id:04d}.npy", activation_value)
+
+     @_ignore_if_scripting
+     def save(
+         self,
+         acts: dict[str, ValueOrLambda] | None = None,
+         /,
+         **kwargs: ValueOrLambda,
+     ):
+         kwargs.update(acts or {})
+
+         # Build activations
+         activations = Activation.from_dict(kwargs)
+
+         for activation in activations:
+             # Make sure the name matches at least one filter, if filters are specified
+             if self._filters is not None and all(
+                 not fnmatch.fnmatch(activation.name, f) for f in self._filters
+             ):
+                 continue
+
+             # Save the current activation
+             self._save_activation(activation)
+
+         del activations
+
+
+ class ActSaveProvider:
+     _saver: _Saver | None = None
+     _prefixes: list[str] = []
+
+     def initialize(self, save_dir: Path | None = None):
+         """
+         Initializes the saver with the given save directory.
+
+         Args:
+             save_dir (Path | None): The directory where the saved files will be stored; if None, a temporary directory is used.
+         """
+         if self._saver is None:
+             if save_dir is None:
+                 save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid.uuid4()}"
+                 log.critical(f"No save_dir specified, using {save_dir=}")
+             self._saver = _Saver(
+                 save_dir,
+                 lambda: self._prefixes,
+             )
+
+     @contextlib.contextmanager
+     def enabled(self, save_dir: Path | None = None):
+         """
+         Context manager that enables the actsave functionality for its duration.
+
+         Args:
+             save_dir (Path | None): The directory where the saved files will be stored; if None, a temporary directory is used.
+         """
+         prev = self._saver
+         self.initialize(save_dir)
+         try:
+             yield
+         finally:
+             self._saver = prev
+
+     @override
+     def __init__(self):
+         super().__init__()
+
+         self._saver = None
+         self._prefixes = []
+
+     @contextlib.contextmanager
+     def context(self, label: str):
+         """
+         A context manager that adds a label to the current context.
+
+         Args:
+             label (str): The label for the context.
+         """
+         if torch.jit.is_scripting():
+             yield
+             return
+
+         if self._saver is None:
+             yield
+             return
+
+         _ensure_supported()
+
+         log.debug(f"Entering ActSave context {label}")
+         self._prefixes.append(label)
+         try:
+             yield
+         finally:
+             _ = self._prefixes.pop()
+
+     prefix = context
+
+     @overload
+     def __call__(
+         self,
+         acts: dict[str, ValueOrLambda] | None = None,
+         /,
+         **kwargs: ValueOrLambda,
+     ):
+         """
+         Saves the activations to disk.
+
+         Args:
+             acts (dict[str, ValueOrLambda] | None, optional): A dictionary of activations. Defaults to None.
+             **kwargs (ValueOrLambda): Additional activations, passed as keyword arguments.
+
+         Returns:
+             None
+
+         """
+         ...
+
+     @overload
+     def __call__(self, acts: Callable[[], dict[str, ValueOrLambda]], /):
+         """
+         Saves the activations to disk.
+
+         Args:
+             acts (Callable[[], dict[str, ValueOrLambda]]): A callable that returns
+                 a dictionary of activations.
+
+         Returns:
+             None
+
+         """
+         ...
+
+     def __call__(
+         self,
+         acts: (
+             dict[str, ValueOrLambda] | Callable[[], dict[str, ValueOrLambda]] | None
+         ) = None,
+         /,
+         **kwargs: ValueOrLambda,
+     ):
+         if torch.jit.is_scripting():
+             return
+
+         if self._saver is None:
+             return
+
+         if acts is not None and callable(acts):
+             acts = acts()
+         self._saver.save(acts, **kwargs)
+
+     save = __call__
+
+
+ ActSave = ActSaveProvider()
+ ActivationSaver = ActSave
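For orientation, the hunk above defines the module behind the global `ActSave` singleton: `enabled` opens a save session, `context`/`prefix` pushes a dotted label, and calling `ActSave(...)` writes each value as an `.npy` array. A minimal usage sketch, assuming `ActSave` is re-exported from `nshtrainer.actsave` (that package's `__init__.py` appears in the file list but its body is not shown in this diff); the tensor names and save path are illustrative:

    from pathlib import Path

    import torch

    from nshtrainer.actsave import ActSave  # assumed re-export; see nshtrainer/actsave/__init__.py

    # Enable saving for the duration of the block; a fresh auto-incremented
    # subdirectory (e.g. ./acts/0000/) is created per saver instance.
    with ActSave.enabled(Path("./acts")):
        with ActSave.context("encoder"):  # labels become dotted file-name prefixes
            ActSave(hidden=torch.randn(2, 8))          # -> ./acts/0000/encoder.hidden/0000.npy
            ActSave(logits=lambda: torch.randn(2, 4))  # lambdas are evaluated lazily at save time

Values are converted to NumPy via `_to_numpy` (floating tensors are cast to float32 first, since numpy does not support [b]float16), so everything on disk is a plain `.npy` array.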
nshtrainer/callbacks/__init__.py
@@ -0,0 +1,35 @@
+ from typing import Annotated
+
+ from ..config import Field
+ from .base import CallbackConfigBase as CallbackConfigBase
+ from .early_stopping import EarlyStopping as EarlyStopping
+ from .ema import EMA as EMA
+ from .ema import EMAConfig as EMAConfig
+ from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
+ from .finite_checks import FiniteChecksConfig as FiniteChecksConfig
+ from .gradient_skipping import GradientSkipping as GradientSkipping
+ from .gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
+ from .interval import EpochIntervalCallback as EpochIntervalCallback
+ from .interval import IntervalCallback as IntervalCallback
+ from .interval import StepIntervalCallback as StepIntervalCallback
+ from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+ from .log_epoch import LogEpochCallback as LogEpochCallback
+ from .norm_logging import NormLoggingCallback as NormLoggingCallback
+ from .norm_logging import NormLoggingConfig as NormLoggingConfig
+ from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+ from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
+ from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+ from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
+ from .timer import EpochTimer as EpochTimer
+ from .timer import EpochTimerConfig as EpochTimerConfig
+
+ CallbackConfig = Annotated[
+     ThroughputMonitorConfig
+     | EpochTimerConfig
+     | PrintTableMetricsConfig
+     | FiniteChecksConfig
+     | NormLoggingConfig
+     | GradientSkippingConfig
+     | EMAConfig,
+     Field(discriminator="name"),
+ ]
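The `Annotated[..., Field(discriminator="name")]` union above is a tagged union keyed on each config's `name` field. A hedged sketch of the mechanics, assuming `Field` in `nshtrainer.config` wraps pydantic's (`nshtrainer/config.py` is listed but not shown in this diff); the classes and `name` literals below are hypothetical stand-ins, not the real callback configs:

    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter

    # Hypothetical stand-ins for two callback configs with literal discriminators.
    class EMAConfig(BaseModel):
        name: Literal["ema"] = "ema"
        decay: float = 0.999

    class EpochTimerConfig(BaseModel):
        name: Literal["epoch_timer"] = "epoch_timer"

    CallbackConfig = Annotated[EMAConfig | EpochTimerConfig, Field(discriminator="name")]

    # During validation, the `name` value selects which config class is instantiated.
    config = TypeAdapter(CallbackConfig).validate_python({"name": "ema", "decay": 0.99})
    assert isinstance(config, EMAConfig)

The payoff of a discriminated union is that callback configs can be parsed from plain dicts (e.g. YAML) without trying every member of the union: the `name` tag routes directly to the right class and yields precise validation errors.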