nshtrainer 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/PKG-INFO +2 -4
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/pyproject.toml +2 -4
- nshtrainer-0.4.0/src/nshtrainer/_snoop.py +1 -0
- nshtrainer-0.4.0/src/nshtrainer/actsave/__init__.py +3 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/actsave/_callback.py +1 -2
- nshtrainer-0.4.0/src/nshtrainer/typecheck.py +1 -0
- nshtrainer-0.2.0/src/nshtrainer/_snoop.py +0 -216
- nshtrainer-0.2.0/src/nshtrainer/actsave/__init__.py +0 -7
- nshtrainer-0.2.0/src/nshtrainer/actsave/_loader.py +0 -144
- nshtrainer-0.2.0/src/nshtrainer/actsave/_saver.py +0 -337
- nshtrainer-0.2.0/src/nshtrainer/typecheck.py +0 -145
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/README.md +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/config.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/scripts/check_env.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/util/typing_utils.py +0 -0
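
Taken together, the listing shows one refactor: the debugging utilities that nshtrainer 0.2.0 vendored (`_snoop`, `actsave`'s saver/loader, and `typecheck`) are deleted and replaced by thin re-exports from the new `nshutils` dependency. A minimal sketch of the resulting import surface, inferred from the hunks below (not an authoritative statement of nshutils's API):

```python
# Sketch of the 0.4.0 surface implied by the hunks below: the old module
# paths survive as one-line wildcard shims over nshutils.
import nshtrainer._snoop      # re-exports nshutils.snoop
import nshtrainer.typecheck   # re-exports nshutils.typecheck
# nshtrainer.actsave now imports ActSave from nshutils.actsave (see _callback.py)
```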
{nshtrainer-0.2.0 → nshtrainer-0.4.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.2.0
+Version: 0.4.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -9,15 +9,13 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Dist: beartype (>=0.18.5,<0.19.0)
-Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
 Requires-Dist: lightning
 Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
 Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
 Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
 Requires-Dist: nshrunner (>=0.5.4,<0.6.0)
+Requires-Dist: nshutils (>=0.3.0,<0.4.0)
 Requires-Dist: numpy
-Requires-Dist: pysnooper
 Requires-Dist: pytorch-lightning
 Requires-Dist: rich
 Requires-Dist: torch
{nshtrainer-0.2.0 → nshtrainer-0.4.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.2.0"
+version = "0.4.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -15,13 +15,11 @@ lightning = "*"
 pytorch-lightning = "*"
 torchmetrics = "*"
 numpy = "*"
-jaxtyping = "^0.2.33"
-beartype = "^0.18.5"
 lovely-numpy = "^0.2.13"
 lovely-tensors = "^0.1.16"
-pysnooper = "*"
 wrapt = "*"
 rich = "*"
+nshutils = "^0.3.0"
 
 
 [tool.poetry.group.dev.dependencies]
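
The dependency change mirrors the code change: jaxtyping, beartype, and pysnooper disappear as direct dependencies (presumably pulled in transitively by nshutils, which now owns that functionality). Poetry's caret constraint `^0.3.0` is what PKG-INFO exports as the range shown above; a quick sanity check with the `packaging` library (assumed installed) illustrates the equivalence:

```python
# Poetry's nshutils = "^0.3.0" expands to the PKG-INFO range >=0.3.0,<0.4.0.
from packaging.specifiers import SpecifierSet

caret = SpecifierSet(">=0.3.0,<0.4.0")   # what ^0.3.0 means
assert "0.3.5" in caret                  # patch/minor-compatible versions allowed
assert "0.4.0" not in caret              # next minor is excluded
```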
nshtrainer-0.4.0/src/nshtrainer/_snoop.py
ADDED

@@ -0,0 +1 @@
+from nshutils.snoop import *  # type: ignore # noqa: F403
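
The wildcard shim keeps 0.2.0 call sites working. A hedged sketch, assuming `nshutils.snoop` exports a `snoop` constructor with the same contract as the deleted module reproduced further down (`snoop(...)` as a context manager, `snoop.disable()` to turn it off):

```python
from nshtrainer._snoop import snoop  # resolved via the shim; assumes nshutils.snoop exports `snoop`

def training_step(model, batch):
    with snoop():  # traces executed lines; tensors rendered via lovely-tensors
        loss = model(batch).mean()
    return loss
```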
{nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/actsave/_callback.py

@@ -3,10 +3,9 @@ from typing import TYPE_CHECKING, Literal, cast
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
+from nshutils.actsave import ActSave
 from typing_extensions import TypeAlias, override
 
-from ._saver import ActSave
-
 if TYPE_CHECKING:
     from ..model.config import BaseConfig
 
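
For the import swap to be a drop-in change, `nshutils.actsave.ActSave` has to keep the surface of the deleted `_saver.py` reproduced later in this diff. A sketch of the call pattern the callback relied on (an assumption about nshutils, not a confirmed API):

```python
from pathlib import Path
import torch
from nshutils.actsave import ActSave  # assumed API-compatible with the old _saver.py

logits = torch.randn(4, 10)
with ActSave.enabled(Path("./activations")):   # enable saving for this scope
    with ActSave.context("validation"):        # saved names get a "validation." prefix
        ActSave(logits=lambda: logits)         # values or lambdas; lambdas resolve lazily
```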
nshtrainer-0.4.0/src/nshtrainer/typecheck.py
ADDED

@@ -0,0 +1 @@
+from nshutils.typecheck import *  # type: ignore # noqa: F403
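
As with `_snoop`, this shim preserves the 0.2.0 import path. A usage sketch, assuming `nshutils.typecheck` keeps the `tassert` helper and jaxtyping re-exports of the deleted module shown below:

```python
import torch
from nshtrainer.typecheck import Float, Tensor, tassert  # re-exported via nshutils.typecheck

x = torch.randn(8, 32)
# Runtime shape/dtype assertion, as in the 0.2.0 module deleted below:
tassert(Float[Tensor, "batch dim"], x)
```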
nshtrainer-0.2.0/src/nshtrainer/_snoop.py
DELETED

@@ -1,216 +0,0 @@
-import contextlib
-from typing import Any, Protocol, cast
-
-from typing_extensions import TypeVar
-
-T = TypeVar("T", infer_variance=True)
-
-
-class SnoopConstructor(Protocol):
-    def __call__(self, *args, **kwargs) -> contextlib.AbstractContextManager: ...
-
-    def disable(self) -> contextlib.AbstractContextManager: ...
-
-
-try:
-    import warnings
-    from contextlib import nullcontext
-
-    import lovely_numpy as lo
-    import lovely_tensors as lt
-    import numpy
-    import pysnooper
-    import pysnooper.utils
-    import torch
-    from pkg_resources import DistributionNotFound, get_distribution
-
-    FLOATING_POINTS = set()
-    for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
-        if hasattr(torch, i):  # older version of PyTorch do not have complex dtypes
-            FLOATING_POINTS.add(getattr(torch, i))
-
-    try:
-        __version__ = get_distribution(__name__).version
-    except DistributionNotFound:
-        # package is not installed
-        pass
-
-    def default_format(x):
-        try:
-            formatted = str(lt.lovely(x))
-            return formatted
-        except BaseException:
-            return str(x.shape)
-
-    def default_numpy_format(x):
-        return str(lo.lovely(x))
-
-    class TorchSnooper(pysnooper.tracer.Tracer):
-        def __init__(
-            self,
-            *args,
-            tensor_format=default_format,
-            numpy_format=default_numpy_format,
-            **kwargs,
-        ):
-            self.orig_custom_repr = (
-                kwargs["custom_repr"] if "custom_repr" in kwargs else ()
-            )
-            custom_repr = (lambda x: True, self.compute_repr)
-            kwargs["custom_repr"] = (custom_repr,)
-            super(TorchSnooper, self).__init__(*args, **kwargs)
-            self.tensor_format = tensor_format
-            self.numpy_format = numpy_format
-
-        @staticmethod
-        def is_return_types(x):
-            return type(x).__module__ == "torch.return_types"
-
-        def return_types_repr(self, x):
-            if type(x).__name__ in {
-                "max",
-                "min",
-                "median",
-                "mode",
-                "sort",
-                "topk",
-                "kthvalue",
-            }:
-                return (
-                    type(x).__name__
-                    + "(values="
-                    + self.tensor_format(x.values)
-                    + ", indices="
-                    + self.tensor_format(x.indices)
-                    + ")"
-                )
-            if type(x).__name__ == "svd":
-                return (
-                    "svd(U="
-                    + self.tensor_format(x.U)
-                    + ", S="
-                    + self.tensor_format(x.S)
-                    + ", V="
-                    + self.tensor_format(x.V)
-                    + ")"
-                )
-            if type(x).__name__ == "slogdet":
-                return (
-                    "slogdet(sign="
-                    + self.tensor_format(x.sign)
-                    + ", logabsdet="
-                    + self.tensor_format(x.logabsdet)
-                    + ")"
-                )
-            if type(x).__name__ == "qr":
-                return (
-                    "qr(Q="
-                    + self.tensor_format(x.Q)
-                    + ", R="
-                    + self.tensor_format(x.R)
-                    + ")"
-                )
-            if type(x).__name__ == "solve":
-                return (
-                    "solve(solution="
-                    + self.tensor_format(x.solution)
-                    + ", LU="
-                    + self.tensor_format(x.LU)
-                    + ")"
-                )
-            if type(x).__name__ == "geqrf":
-                return (
-                    "geqrf(a="
-                    + self.tensor_format(x.a)
-                    + ", tau="
-                    + self.tensor_format(x.tau)
-                    + ")"
-                )
-            if type(x).__name__ in {"symeig", "eig"}:
-                return (
-                    type(x).__name__
-                    + "(eigenvalues="
-                    + self.tensor_format(x.eigenvalues)
-                    + ", eigenvectors="
-                    + self.tensor_format(x.eigenvectors)
-                    + ")"
-                )
-            if type(x).__name__ == "triangular_solve":
-                return (
-                    "triangular_solve(solution="
-                    + self.tensor_format(x.solution)
-                    + ", cloned_coefficient="
-                    + self.tensor_format(x.cloned_coefficient)
-                    + ")"
-                )
-            if type(x).__name__ == "gels":
-                return (
-                    "gels(solution="
-                    + self.tensor_format(x.solution)
-                    + ", QR="
-                    + self.tensor_format(x.QR)
-                    + ")"
-                )
-            warnings.warn("Unknown return_types encountered, open a bug report!")
-
-        def compute_repr(self, x):
-            orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
-            if torch.is_tensor(x):
-                return self.tensor_format(x)
-            if isinstance(x, numpy.ndarray):
-                return self.numpy_format(x)
-            if self.is_return_types(x):
-                return self.return_types_repr(x)
-            if orig_repr_func is not repr:
-                return orig_repr_func(x)
-            if isinstance(x, (list, tuple)):
-                content = ""
-                for i in x:
-                    if content != "":
-                        content += ", "
-                    content += self.compute_repr(i)
-                if isinstance(x, tuple) and len(x) == 1:
-                    content += ","
-                if isinstance(x, tuple):
-                    return "(" + content + ")"
-                return "[" + content + "]"
-            if isinstance(x, dict):
-                content = ""
-                for k, v in x.items():
-                    if content != "":
-                        content += ", "
-                    content += self.compute_repr(k) + ": " + self.compute_repr(v)
-                return "{" + content + "}"
-            return repr(x)
-
-    class _Snoop:
-        disable = nullcontext
-        __call__ = TorchSnooper
-
-    snoop: SnoopConstructor = cast(Any, _Snoop())
-
-except ImportError:
-    import warnings
-    from contextlib import nullcontext
-
-    from typing_extensions import override
-
-    _has_warned = False
-
-    class _snoop_cls(nullcontext):
-        @classmethod
-        def disable(cls):
-            return nullcontext()
-
-        @override
-        def __enter__(self):
-            global _has_warned
-            if not _has_warned:
-                warnings.warn(
-                    "snoop is not installed, please install it to enable snoop"
-                )
-                _has_warned = True
-
-            return super().__enter__()
-
-    snoop: SnoopConstructor = cast(Any, _snoop_cls)
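
For reference, this is how the deleted tracer was typically used in 0.2.0 (pysnooper tracers work both as decorators and as context managers):

```python
from nshtrainer._snoop import snoop  # the TorchSnooper-backed constructor above

@snoop()
def forward(x):
    y = x * 2          # each executed line is logged; tensor locals are
    return y.sum()     # rendered via lovely-tensors instead of raw repr()
```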
nshtrainer-0.2.0/src/nshtrainer/actsave/__init__.py
DELETED

@@ -1,7 +0,0 @@
-from ._callback import ActSaveCallback as ActSaveCallback
-from ._loader import ActivationLoader as ActivationLoader
-from ._loader import ActLoad as ActLoad
-from ._saver import Activation as Activation
-from ._saver import ActivationSaver as ActivationSaver
-from ._saver import ActSave as ActSave
-from ._saver import Transform as Transform
nshtrainer-0.2.0/src/nshtrainer/actsave/_loader.py
DELETED

@@ -1,144 +0,0 @@
-import pprint
-from dataclasses import dataclass, field
-from functools import cached_property
-from logging import getLogger
-from pathlib import Path
-from typing import cast, overload
-
-import numpy as np
-from typing_extensions import TypeVar, override
-
-log = getLogger(__name__)
-
-T = TypeVar("T", infer_variance=True)
-
-
-@dataclass
-class LoadedActivation:
-    base_dir: Path = field(repr=False)
-    name: str
-    num_activations: int = field(init=False)
-    activation_files: list[Path] = field(init=False, repr=False)
-
-    def __post_init__(self):
-        if not self.activation_dir.exists():
-            raise ValueError(f"Activation dir {self.activation_dir} does not exist")
-
-        # The number of activations = the # of .npy files in the activation dir
-        self.activation_files = list(self.activation_dir.glob("*.npy"))
-        # Sort the activation files by the numerical index in the filename
-        self.activation_files.sort(key=lambda p: int(p.stem))
-        self.num_activations = len(self.activation_files)
-
-    @property
-    def activation_dir(self) -> Path:
-        return self.base_dir / self.name
-
-    def _load_activation(self, item: int):
-        activation_path = self.activation_files[item]
-        if not activation_path.exists():
-            raise ValueError(f"Activation {activation_path} does not exist")
-        return cast(np.ndarray, np.load(activation_path, allow_pickle=True))
-
-    @overload
-    def __getitem__(self, item: int) -> np.ndarray: ...
-
-    @overload
-    def __getitem__(self, item: slice | list[int]) -> list[np.ndarray]: ...
-
-    def __getitem__(
-        self, item: int | slice | list[int]
-    ) -> np.ndarray | list[np.ndarray]:
-        if isinstance(item, int):
-            return self._load_activation(item)
-        elif isinstance(item, slice):
-            return [
-                self._load_activation(i)
-                for i in range(*item.indices(self.num_activations))
-            ]
-        elif isinstance(item, list):
-            return [self._load_activation(i) for i in item]
-        else:
-            raise TypeError(f"Invalid type {type(item)} for item {item}")
-
-    def __iter__(self):
-        return iter(self[i] for i in range(self.num_activations))
-
-    def __len__(self):
-        return self.num_activations
-
-    def all_activations(self):
-        return [self[i] for i in range(self.num_activations)]
-
-    @override
-    def __repr__(self):
-        return f"<LoadedActivation {self.name} ({self.num_activations} activations)>"
-
-
-class ActLoad:
-    @classmethod
-    def all_versions(cls, dir: str | Path):
-        dir = Path(dir)
-
-        # If the dir is not an activation base directory, we return None
-        if not (dir / ".activationbase").exists():
-            return None
-
-        # The contents of `dir` should be directories, each of which is a version.
-        return [
-            (subdir, int(subdir.name)) for subdir in dir.iterdir() if subdir.is_dir()
-        ]
-
-    @classmethod
-    def is_valid_activation_base(cls, dir: str | Path):
-        return cls.all_versions(dir) is not None
-
-    @classmethod
-    def from_latest_version(cls, dir: str | Path):
-        # The contents of `dir` should be directories, each of which is a version
-        # We need to find the latest version
-        if (all_versions := cls.all_versions(dir)) is None:
-            raise ValueError(f"{dir} is not an activation base directory")
-
-        path, _ = max(all_versions, key=lambda p: p[1])
-        return cls(path)
-
-    def __init__(self, dir: Path):
-        self._dir = dir
-
-    def activation(self, name: str):
-        return LoadedActivation(self._dir, name)
-
-    @cached_property
-    def activations(self):
-        dirs = list(self._dir.iterdir())
-        # Sort the dirs by the last modified time
-        dirs.sort(key=lambda p: p.stat().st_mtime)
-
-        return {p.name: LoadedActivation(self._dir, p.name) for p in dirs}
-
-    def __iter__(self):
-        return iter(self.activations.values())
-
-    def __getitem__(self, item: str):
-        return self.activations[item]
-
-    def __len__(self):
-        return len(self.activations)
-
-    @override
-    def __repr__(self):
-        acts_str = pprint.pformat(
-            {
-                name: f"<{activation.num_activations} activations>"
-                for name, activation in self.activations.items()
-            }
-        )
-        acts_str = acts_str.replace("'<", "<").replace(">'", ">")
-        return f"ActLoad({acts_str})"
-
-    def get(self, name: str, /, default: T) -> LoadedActivation | T:
-        return self.activations.get(name, default)
-
-
-ActivationLoader = ActLoad
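
Reference usage of the deleted loader, under its 0.2.0 import path (the activation name "validation.logits" is illustrative):

```python
from nshtrainer.actsave import ActLoad  # 0.2.0 path; removed in 0.4.0

acts = ActLoad.from_latest_version("path/to/save_dir")  # picks highest-numbered version dir
act = acts["validation.logits"]    # a LoadedActivation; keys use "." separators
first = act[0]                     # lazily np.load()s the first .npy file
print(len(act), act.all_activations())
```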
nshtrainer-0.2.0/src/nshtrainer/actsave/_saver.py
DELETED

@@ -1,337 +0,0 @@
-import contextlib
-import fnmatch
-import tempfile
-import uuid
-import weakref
-from collections.abc import Callable, Mapping
-from dataclasses import dataclass
-from functools import wraps
-from logging import getLogger
-from pathlib import Path
-from typing import Generic, TypeAlias, cast, overload
-
-import numpy as np
-import torch
-from lightning_utilities.core.apply_func import apply_to_collection
-from typing_extensions import ParamSpec, TypeVar, override
-
-log = getLogger(__name__)
-
-Value: TypeAlias = int | float | complex | bool | str | np.ndarray | torch.Tensor | None
-ValueOrLambda = Value | Callable[..., Value]
-
-
-def _to_numpy(activation: Value) -> np.ndarray:
-    # Make sure it's not `None`
-    if activation is None:
-        raise ValueError("Activation should not be `None`")
-
-    if isinstance(activation, np.ndarray):
-        return activation
-    if isinstance(activation, torch.Tensor):
-        activation = activation.detach()
-        if activation.is_floating_point():
-            # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
-            activation = activation.float()
-        return activation.cpu().numpy()
-    if isinstance(activation, (int, float, complex, str, bool)):
-        return np.array(activation)
-
-    return activation
-
-
-T = TypeVar("T", infer_variance=True)
-
-
-# A wrapper around weakref.ref that allows for primitive types
-# To get around errors like:
-# TypeError: cannot create weak reference to 'int' object
-class WeakRef(Generic[T]):
-    _ref: Callable[[], T] | None
-
-    def __init__(self, obj: T):
-        try:
-            self._ref = cast(Callable[[], T], weakref.ref(obj))
-        except TypeError as e:
-            if "cannot create weak reference" not in str(e):
-                raise
-            self._ref = lambda: obj
-
-    def __call__(self) -> T:
-        if self._ref is None:
-            raise RuntimeError("WeakRef is deleted")
-        return self._ref()
-
-    def delete(self):
-        del self._ref
-        self._ref = None
-
-
-@dataclass
-class Activation:
-    name: str
-    ref: WeakRef[ValueOrLambda] | None
-    transformed: np.ndarray | None = None
-
-    def __post_init__(self):
-        # Update the `name` to replace `/` with `.`
-        self.name = self.name.replace("/", ".")
-
-    def __call__(self) -> np.ndarray | None:
-        # If we have a transformed value, we return it
-        if self.transformed is not None:
-            return self.transformed
-
-        if self.ref is None:
-            raise RuntimeError("Activation is deleted")
-
-        # If we have a lambda, we need to call it
-        unrwapped_ref = self.ref()
-        activation = unrwapped_ref
-        if callable(unrwapped_ref):
-            activation = unrwapped_ref()
-
-        # If we have a `None`, we return early
-        if activation is None:
-            return None
-
-        activation = apply_to_collection(activation, torch.Tensor, _to_numpy)
-        activation = _to_numpy(activation)
-
-        # Set the transformed value
-        self.transformed = activation
-
-        # Delete the reference
-        self.ref.delete()
-        del self.ref
-        self.ref = None
-
-        return self.transformed
-
-    @classmethod
-    def from_value_or_lambda(cls, name: str, value_or_lambda: ValueOrLambda):
-        return cls(name, WeakRef(value_or_lambda))
-
-    @classmethod
-    def from_dict(cls, d: Mapping[str, ValueOrLambda]):
-        return [cls.from_value_or_lambda(k, v) for k, v in d.items()]
-
-
-Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
-
-
-def _ensure_supported():
-    try:
-        import torch.distributed as dist
-
-        if dist.is_initialized() and dist.get_world_size() > 1:
-            raise RuntimeError("Only single GPU is supported at the moment")
-    except ImportError:
-        pass
-
-
-P = ParamSpec("P")
-
-
-def _ignore_if_scripting(fn: Callable[P, None]) -> Callable[P, None]:
-    @wraps(fn)
-    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
-        if torch.jit.is_scripting():
-            return
-
-        _ensure_supported()
-        fn(*args, **kwargs)
-
-    return wrapper
-
-
-class _Saver:
-    def __init__(
-        self,
-        save_dir: Path,
-        prefixes_fn: Callable[[], list[str]],
-        *,
-        filters: list[str] | None = None,
-    ):
-        # Create a directory under `save_dir` by autoincrementing
-        # (i.e., every activation save context, we create a new directory)
-        # The id = the number of activation subdirectories
-        self._id = sum(1 for subdir in save_dir.glob("*") if subdir.is_dir())
-        save_dir.mkdir(parents=True, exist_ok=True)
-
-        # Add a .activationbase file to the save_dir to indicate that this is an activation base
-        (save_dir / ".activationbase").touch(exist_ok=True)
-
-        self._save_dir = save_dir / f"{self._id:04d}"
-        # Make sure `self._save_dir` does not exist and create it
-        self._save_dir.mkdir(exist_ok=False)
-
-        self._prefixes_fn = prefixes_fn
-        self._filters = filters
-
-    def _save_activation(self, activation: Activation):
-        # If the activation value is `None`, we skip it.
-        if (activation_value := activation()) is None:
-            return
-
-        # Save the activation to self._save_dir / name / {id}.npz, where id is an auto-incrementing integer
-        file_name = ".".join(self._prefixes_fn() + [activation.name])
-        path = self._save_dir / file_name
-        path.mkdir(exist_ok=True, parents=True)
-
-        # Get the next id and save the activation
-        id = len(list(path.glob("*.npy")))
-        np.save(path / f"{id:04d}.npy", activation_value)
-
-    @_ignore_if_scripting
-    def save(
-        self,
-        acts: dict[str, ValueOrLambda] | None = None,
-        /,
-        **kwargs: ValueOrLambda,
-    ):
-        kwargs.update(acts or {})
-
-        # Build activations
-        activations = Activation.from_dict(kwargs)
-
-        for activation in activations:
-            # Make sure name matches at least one filter if filters are specified
-            if self._filters is not None and all(
-                not fnmatch.fnmatch(activation.name, f) for f in self._filters
-            ):
-                continue
-
-            # Save the current activation
-            self._save_activation(activation)
-
-        del activations
-
-
-class ActSaveProvider:
-    _saver: _Saver | None = None
-    _prefixes: list[str] = []
-
-    def initialize(self, save_dir: Path | None = None):
-        """
-        Initializes the saver with the given configuration and save directory.
-
-        Args:
-            save_dir (Path): The directory where the saved files will be stored.
-        """
-        if self._saver is None:
-            if save_dir is None:
-                save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid.uuid4()}"
-                log.critical(f"No save_dir specified, using {save_dir=}")
-            self._saver = _Saver(
-                save_dir,
-                lambda: self._prefixes,
-            )
-
-    @contextlib.contextmanager
-    def enabled(self, save_dir: Path | None = None):
-        """
-        Context manager that enables the actsave functionality with the specified configuration.
-
-        Args:
-            save_dir (Path): The directory where the saved files will be stored.
-        """
-        prev = self._saver
-        self.initialize(save_dir)
-        try:
-            yield
-        finally:
-            self._saver = prev
-
-    @override
-    def __init__(self):
-        super().__init__()
-
-        self._saver = None
-        self._prefixes = []
-
-    @contextlib.contextmanager
-    def context(self, label: str):
-        """
-        A context manager that adds a label to the current context.
-
-        Args:
-            label (str): The label for the context.
-        """
-        if torch.jit.is_scripting():
-            yield
-            return
-
-        if self._saver is None:
-            yield
-            return
-
-        _ensure_supported()
-
-        log.debug(f"Entering ActSave context {label}")
-        self._prefixes.append(label)
-        try:
-            yield
-        finally:
-            _ = self._prefixes.pop()
-
-    prefix = context
-
-    @overload
-    def __call__(
-        self,
-        acts: dict[str, ValueOrLambda] | None = None,
-        /,
-        **kwargs: ValueOrLambda,
-    ):
-        """
-        Saves the activations to disk.
-
-        Args:
-            acts (dict[str, ValueOrLambda] | None, optional): A dictionary of acts. Defaults to None.
-            **kwargs (ValueOrLambda): Additional keyword arguments.
-
-        Returns:
-            None
-
-        """
-        ...
-
-    @overload
-    def __call__(self, acts: Callable[[], dict[str, ValueOrLambda]], /):
-        """
-        Saves the activations to disk.
-
-        Args:
-            acts (Callable[[], dict[str, ValueOrLambda]]): A callable that returns a dictionary of acts.
-            **kwargs (ValueOrLambda): Additional keyword arguments.
-
-        Returns:
-            None
-
-        """
-        ...
-
-    def __call__(
-        self,
-        acts: (
-            dict[str, ValueOrLambda] | Callable[[], dict[str, ValueOrLambda]] | None
-        ) = None,
-        /,
-        **kwargs: ValueOrLambda,
-    ):
-        if torch.jit.is_scripting():
-            return
-
-        if self._saver is None:
-            return
-
-        if acts is not None and callable(acts):
-            acts = acts()
-        self._saver.save(acts, **kwargs)
-
-    save = __call__
-
-
-ActSave = ActSaveProvider()
-ActivationSaver = ActSave
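
The on-disk layout produced by the deleted `_Saver`, reconstructed from the code above (the concrete names are illustrative), which is also the layout `ActLoad` expects:

```python
# save_dir/
#   .activationbase          # marker file checked by ActLoad.all_versions
#   0000/                    # one zero-padded dir per _Saver instance ("version")
#     validation.logits/     # "<prefixes>.<name>", with "/" replaced by "."
#       0000.npy             # one auto-incremented .npy per saved value
#       0001.npy
```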
nshtrainer-0.2.0/src/nshtrainer/typecheck.py
DELETED

@@ -1,145 +0,0 @@
-import os
-from collections.abc import Sequence
-from logging import getLogger
-from typing import Any
-
-import numpy as np
-import torch
-from jaxtyping import BFloat16 as BFloat16
-from jaxtyping import Bool as Bool
-from jaxtyping import Complex as Complex
-from jaxtyping import Complex64 as Complex64
-from jaxtyping import Complex128 as Complex128
-from jaxtyping import Float as Float
-from jaxtyping import Float16 as Float16
-from jaxtyping import Float32 as Float32
-from jaxtyping import Float64 as Float64
-from jaxtyping import Inexact as Inexact
-from jaxtyping import Int as Int
-from jaxtyping import Int4 as Int4
-from jaxtyping import Int8 as Int8
-from jaxtyping import Int16 as Int16
-from jaxtyping import Int32 as Int32
-from jaxtyping import Int64 as Int64
-from jaxtyping import Integer as Integer
-from jaxtyping import Key as Key
-from jaxtyping import Num as Num
-from jaxtyping import Real as Real
-from jaxtyping import Shaped as Shaped
-from jaxtyping import UInt as UInt
-from jaxtyping import UInt4 as UInt4
-from jaxtyping import UInt8 as UInt8
-from jaxtyping import UInt16 as UInt16
-from jaxtyping import UInt32 as UInt32
-from jaxtyping import UInt64 as UInt64
-from jaxtyping._storage import get_shape_memo, shape_str
-from torch import Tensor as Tensor
-from torch.nn.parameter import Parameter as Parameter
-from typing_extensions import TypeVar
-
-log = getLogger(__name__)
-
-DISABLE_ENV_KEY = "LL_DISABLE_TYPECHECKING"
-
-
-def typecheck_modules(modules: Sequence[str]):
-    """
-    Typecheck the given modules using `jaxtyping`.
-
-    Args:
-        modules: Modules to typecheck.
-    """
-    # If `DISABLE_ENV_KEY` is set and the environment variable is set, skip
-    # typechecking.
-    if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
-        log.critical(
-            f"Type checking is disabled due to the environment variable {DISABLE_ENV_KEY}."
-        )
-        return
-
-    # Install the jaxtyping import hook for this module.
-    from jaxtyping import install_import_hook
-
-    install_import_hook(modules, "beartype.beartype")
-
-    log.critical(f"Type checking the following modules: {modules}")
-
-
-def typecheck_this_module(additional_modules: Sequence[str] = ()):
-    """
-    Typecheck the calling module and any additional modules using `jaxtyping`.
-
-    Args:
-        additional_modules: Additional modules to typecheck.
-    """
-    # Get the calling module's name.
-    # Here, we can just use beartype's internal implementation behind
-    # `beartype_this_package`.
-    from beartype._util.func.utilfuncframe import get_frame, get_frame_package_name
-
-    # Get the calling module's name.
-    assert get_frame is not None, "get_frame is None"
-    frame = get_frame(1)
-    assert frame is not None, "frame is None"
-    calling_module_name = get_frame_package_name(frame)
-    assert calling_module_name is not None, "calling_module_name is None"
-
-    # Typecheck the calling module + any additional modules.
-    typecheck_modules((calling_module_name, *additional_modules))
-
-
-def _make_error_str(input: Any, t: Any) -> str:
-    error_components: list[str] = []
-    error_components.append("Type checking error:")
-    if hasattr(t, "__instancecheck_str__"):
-        error_components.append(t.__instancecheck_str__(input))
-    if torch.is_tensor(input):
-        try:
-            from lovely_tensors import lovely
-
-            error_components.append(repr(lovely(input)))
-        except BaseException:
-            error_components.append(repr(input.shape))
-    error_components.append(shape_str(get_shape_memo()))
-
-    return "\n".join(error_components)
-
-
-T = TypeVar("T", torch.Tensor, np.ndarray, infer_variance=True)
-
-"""
-Patch to jaxtyping:
-
-In `jaxtyping._import_hook`, we add:
-    def _has_isinstance_or_tassert(func_def):
-        for node in ast.walk(func_def):
-            if isinstance(node, ast.Call):
-                if isinstance(node.func, ast.Name) and node.func.id == "isinstance":
-                    return True
-                elif isinstance(node.func, ast.Name) and node.func.id == "tassert":
-                    return True
-        return False
-
-and we check this when adding the decorators.
-"""
-
-
-def tassert(t: Any, input: T | tuple[T, ...]):
-    """
-    Typecheck the input against the given type.
-
-    Args:
-        t: Type to check against.
-        input: Input to check.
-    """
-
-    # Ignore typechecking if the environment variable is set.
-    if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
-        return
-
-    if isinstance(input, tuple):
-        for i in input:
-            assert isinstance(i, t), _make_error_str(i, t)
-        return
-    else:
-        assert isinstance(input, t), _make_error_str(input, t)
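
For reference, this is how the deleted helpers were wired up in 0.2.0 (the `"myproject.models"` module name is a hypothetical stand-in). The import hook is skipped entirely when `LL_DISABLE_TYPECHECKING=1`:

```python
from nshtrainer.typecheck import typecheck_this_module  # 0.2.0 path

# Called from a package's __init__.py: installs jaxtyping's import hook with
# beartype for the calling package, plus any extra modules listed.
typecheck_this_module(additional_modules=("myproject.models",))
```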
All remaining files are unchanged between 0.2.0 and 0.4.0; the diff viewer marks them RENAMED only because the package directory prefix changed (for example, {nshtrainer-0.2.0 → nshtrainer-0.4.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py). The full set of unchanged files is listed at the top with +0 -0.