nshtrainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/actsave/_saver.py:

@@ -0,0 +1,337 @@
+import contextlib
+import fnmatch
+import tempfile
+import uuid
+import weakref
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass
+from functools import wraps
+from logging import getLogger
+from pathlib import Path
+from typing import Generic, TypeAlias, cast, overload
+
+import numpy as np
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
+from typing_extensions import ParamSpec, TypeVar, override
+
+log = getLogger(__name__)
+
+Value: TypeAlias = int | float | complex | bool | str | np.ndarray | torch.Tensor | None
+ValueOrLambda = Value | Callable[..., Value]
+
+
+def _to_numpy(activation: Value) -> np.ndarray:
+    # Make sure it's not `None`
+    if activation is None:
+        raise ValueError("Activation should not be `None`")
+
+    if isinstance(activation, np.ndarray):
+        return activation
+    if isinstance(activation, torch.Tensor):
+        activation = activation.detach()
+        if activation.is_floating_point():
+            # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
+            activation = activation.float()
+        return activation.cpu().numpy()
+    if isinstance(activation, (int, float, complex, str, bool)):
+        return np.array(activation)
+
+    return activation
+
+
+T = TypeVar("T", infer_variance=True)
+
+
+# A wrapper around weakref.ref that allows for primitive types,
+# to get around errors like:
+#   TypeError: cannot create weak reference to 'int' object
+class WeakRef(Generic[T]):
+    _ref: Callable[[], T] | None
+
+    def __init__(self, obj: T):
+        try:
+            self._ref = cast(Callable[[], T], weakref.ref(obj))
+        except TypeError as e:
+            if "cannot create weak reference" not in str(e):
+                raise
+            self._ref = lambda: obj
+
+    def __call__(self) -> T:
+        if self._ref is None:
+            raise RuntimeError("WeakRef is deleted")
+        return self._ref()
+
+    def delete(self):
+        del self._ref
+        self._ref = None
+
+
+@dataclass
+class Activation:
+    name: str
+    ref: WeakRef[ValueOrLambda] | None
+    transformed: np.ndarray | None = None
+
+    def __post_init__(self):
+        # Update the `name` to replace `/` with `.`
+        self.name = self.name.replace("/", ".")
+
+    def __call__(self) -> np.ndarray | None:
+        # If we have a transformed value, we return it
+        if self.transformed is not None:
+            return self.transformed
+
+        if self.ref is None:
+            raise RuntimeError("Activation is deleted")
+
+        # If we have a lambda, we need to call it
+        unwrapped_ref = self.ref()
+        activation = unwrapped_ref
+        if callable(unwrapped_ref):
+            activation = unwrapped_ref()
+
+        # If we have a `None`, we return early
+        if activation is None:
+            return None
+
+        activation = apply_to_collection(activation, torch.Tensor, _to_numpy)
+        activation = _to_numpy(activation)
+
+        # Set the transformed value
+        self.transformed = activation
+
+        # Delete the reference
+        self.ref.delete()
+        del self.ref
+        self.ref = None
+
+        return self.transformed
+
+    @classmethod
+    def from_value_or_lambda(cls, name: str, value_or_lambda: ValueOrLambda):
+        return cls(name, WeakRef(value_or_lambda))
+
+    @classmethod
+    def from_dict(cls, d: Mapping[str, ValueOrLambda]):
+        return [cls.from_value_or_lambda(k, v) for k, v in d.items()]
+
+
+Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
+
+
+def _ensure_supported():
+    try:
+        import torch.distributed as dist
+
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            raise RuntimeError("Only single GPU is supported at the moment")
+    except ImportError:
+        pass
+
+
+P = ParamSpec("P")
+
+
+def _ignore_if_scripting(fn: Callable[P, None]) -> Callable[P, None]:
+    @wraps(fn)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
+        if torch.jit.is_scripting():
+            return
+
+        _ensure_supported()
+        fn(*args, **kwargs)
+
+    return wrapper
+
+
+class _Saver:
+    def __init__(
+        self,
+        save_dir: Path,
+        prefixes_fn: Callable[[], list[str]],
+        *,
+        filters: list[str] | None = None,
+    ):
+        # Create a directory under `save_dir` by autoincrementing
+        # (i.e., every activation save context, we create a new directory)
+        # The id = the number of activation subdirectories
+        self._id = sum(1 for subdir in save_dir.glob("*") if subdir.is_dir())
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        # Add a .activationbase file to the save_dir to indicate that this is an activation base
+        (save_dir / ".activationbase").touch(exist_ok=True)
+
+        self._save_dir = save_dir / f"{self._id:04d}"
+        # Make sure `self._save_dir` does not exist and create it
+        self._save_dir.mkdir(exist_ok=False)
+
+        self._prefixes_fn = prefixes_fn
+        self._filters = filters
+
+    def _save_activation(self, activation: Activation):
+        # If the activation value is `None`, we skip it.
+        if (activation_value := activation()) is None:
+            return
+
+        # Save the activation to self._save_dir / name / {id}.npy, where id is an auto-incrementing integer
+        file_name = ".".join(self._prefixes_fn() + [activation.name])
+        path = self._save_dir / file_name
+        path.mkdir(exist_ok=True, parents=True)
+
+        # Get the next id and save the activation
+        id = len(list(path.glob("*.npy")))
+        np.save(path / f"{id:04d}.npy", activation_value)
+
+    @_ignore_if_scripting
+    def save(
+        self,
+        acts: dict[str, ValueOrLambda] | None = None,
+        /,
+        **kwargs: ValueOrLambda,
+    ):
+        kwargs.update(acts or {})
+
+        # Build activations
+        activations = Activation.from_dict(kwargs)
+
+        for activation in activations:
+            # Make sure name matches at least one filter if filters are specified
+            if self._filters is not None and all(
+                not fnmatch.fnmatch(activation.name, f) for f in self._filters
+            ):
+                continue
+
+            # Save the current activation
+            self._save_activation(activation)
+
+        del activations
+
+
+class ActSaveProvider:
+    _saver: _Saver | None = None
+    _prefixes: list[str] = []
+
+    def initialize(self, save_dir: Path | None = None):
+        """
+        Initializes the saver with the given configuration and save directory.
+
+        Args:
+            save_dir (Path): The directory where the saved files will be stored.
+        """
+        if self._saver is None:
+            if save_dir is None:
+                save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid.uuid4()}"
+                log.critical(f"No save_dir specified, using {save_dir=}")
+            self._saver = _Saver(
+                save_dir,
+                lambda: self._prefixes,
+            )
+
+    @contextlib.contextmanager
+    def enabled(self, save_dir: Path | None = None):
+        """
+        Context manager that enables the actsave functionality with the specified configuration.
+
+        Args:
+            save_dir (Path): The directory where the saved files will be stored.
+        """
+        prev = self._saver
+        self.initialize(save_dir)
+        try:
+            yield
+        finally:
+            self._saver = prev
+
+    @override
+    def __init__(self):
+        super().__init__()
+
+        self._saver = None
+        self._prefixes = []
+
+    @contextlib.contextmanager
+    def context(self, label: str):
+        """
+        A context manager that adds a label to the current context.
+
+        Args:
+            label (str): The label for the context.
+        """
+        if torch.jit.is_scripting():
+            yield
+            return
+
+        if self._saver is None:
+            yield
+            return
+
+        _ensure_supported()
+
+        log.debug(f"Entering ActSave context {label}")
+        self._prefixes.append(label)
+        try:
+            yield
+        finally:
+            _ = self._prefixes.pop()
+
+    prefix = context
+
+    @overload
+    def __call__(
+        self,
+        acts: dict[str, ValueOrLambda] | None = None,
+        /,
+        **kwargs: ValueOrLambda,
+    ):
+        """
+        Saves the activations to disk.
+
+        Args:
+            acts (dict[str, ValueOrLambda] | None, optional): A dictionary of acts. Defaults to None.
+            **kwargs (ValueOrLambda): Additional keyword arguments.
+
+        Returns:
+            None
+
+        """
+        ...
+
+    @overload
+    def __call__(self, acts: Callable[[], dict[str, ValueOrLambda]], /):
+        """
+        Saves the activations to disk.
+
+        Args:
+            acts (Callable[[], dict[str, ValueOrLambda]]): A callable that returns a dictionary of acts.
+            **kwargs (ValueOrLambda): Additional keyword arguments.
+
+        Returns:
+            None
+
+        """
+        ...
+
+    def __call__(
+        self,
+        acts: (
+            dict[str, ValueOrLambda] | Callable[[], dict[str, ValueOrLambda]] | None
+        ) = None,
+        /,
+        **kwargs: ValueOrLambda,
+    ):
+        if torch.jit.is_scripting():
+            return
+
+        if self._saver is None:
+            return
+
+        if acts is not None and callable(acts):
+            acts = acts()
+        self._saver.save(acts, **kwargs)
+
+    save = __call__
+
+
+ActSave = ActSaveProvider()
+ActivationSaver = ActSave
nshtrainer/callbacks/__init__.py:

@@ -0,0 +1,35 @@
+from typing import Annotated
+
+from ..config import Field
+from .base import CallbackConfigBase as CallbackConfigBase
+from .early_stopping import EarlyStopping as EarlyStopping
+from .ema import EMA as EMA
+from .ema import EMAConfig as EMAConfig
+from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
+from .finite_checks import FiniteChecksConfig as FiniteChecksConfig
+from .gradient_skipping import GradientSkipping as GradientSkipping
+from .gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
+from .interval import EpochIntervalCallback as EpochIntervalCallback
+from .interval import IntervalCallback as IntervalCallback
+from .interval import StepIntervalCallback as StepIntervalCallback
+from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+from .log_epoch import LogEpochCallback as LogEpochCallback
+from .norm_logging import NormLoggingCallback as NormLoggingCallback
+from .norm_logging import NormLoggingConfig as NormLoggingConfig
+from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
+from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
+from .timer import EpochTimer as EpochTimer
+from .timer import EpochTimerConfig as EpochTimerConfig
+
+CallbackConfig = Annotated[
+    ThroughputMonitorConfig
+    | EpochTimerConfig
+    | PrintTableMetricsConfig
+    | FiniteChecksConfig
+    | NormLoggingConfig
+    | GradientSkippingConfig
+    | EMAConfig,
+    Field(discriminator="name"),
+]