nshutils 0.2.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshutils-0.2.0 → nshutils-0.3.0}/PKG-INFO +2 -1
- {nshutils-0.2.0 → nshutils-0.3.0}/pyproject.toml +14 -1
- {nshutils-0.2.0 → nshutils-0.3.0}/src/nshutils/__init__.py +1 -0
- nshutils-0.3.0/src/nshutils/actsave/__init__.py +6 -0
- nshutils-0.3.0/src/nshutils/actsave/_loader.py +144 -0
- nshutils-0.3.0/src/nshutils/actsave/_saver.py +360 -0
- nshutils-0.3.0/src/nshutils/collections.py +271 -0
- {nshutils-0.2.0 → nshutils-0.3.0}/README.md +0 -0
- {nshutils-0.2.0 → nshutils-0.3.0}/src/nshutils/snoop.py +0 -0
- {nshutils-0.2.0 → nshutils-0.3.0}/src/nshutils/typecheck.py +0 -0
{nshutils-0.2.0 → nshutils-0.3.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshutils
-Version: 0.2.0
+Version: 0.3.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -11,6 +11,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Requires-Dist: beartype (>=0.18.5,<0.19.0)
 Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
+Requires-Dist: numpy
 Requires-Dist: pysnooper (>=1.2.0,<2.0.0)
 Requires-Dist: typing-extensions
 Description-Content-Type: text/markdown
{nshutils-0.2.0 → nshutils-0.3.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshutils"
-version = "0.2.0"
+version = "0.3.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -11,8 +11,21 @@ pysnooper = "^1.2.0"
 jaxtyping = "^0.2.33"
 typing-extensions = "*"
 beartype = "^0.18.5"
+numpy = "*"
 
 
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.pyright]
+typeCheckingMode = "standard"
+deprecateTypingAliases = true
+strictListInference = true
+strictDictionaryInference = true
+strictSetInference = true
+reportPrivateImportUsage = false
+ignore = ["./build/"]
+
+[tool.ruff.lint]
+ignore = ["F722", "F821", "E731", "E741"]
nshutils-0.3.0/src/nshutils/actsave/__init__.py

@@ -0,0 +1,6 @@
+from ._loader import ActivationLoader as ActivationLoader
+from ._loader import ActLoad as ActLoad
+from ._saver import Activation as Activation
+from ._saver import ActivationSaver as ActivationSaver
+from ._saver import ActSave as ActSave
+from ._saver import Transform as Transform
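For illustration only (not part of the package diff): the new `nshutils.actsave` package re-exports each object under both a short and a long name, so either spelling can be imported.

```python
from nshutils.actsave import ActLoad, ActSave, ActivationLoader, ActivationSaver

# The long names are aliases for the short ones.
assert ActSave is ActivationSaver
assert ActLoad is ActivationLoader
```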
nshutils-0.3.0/src/nshutils/actsave/_loader.py

@@ -0,0 +1,144 @@
+import pprint
+from dataclasses import dataclass, field
+from functools import cached_property
+from logging import getLogger
+from pathlib import Path
+from typing import cast, overload
+
+import numpy as np
+from typing_extensions import TypeVar, override
+
+log = getLogger(__name__)
+
+T = TypeVar("T", infer_variance=True)
+
+
+@dataclass
+class LoadedActivation:
+    base_dir: Path = field(repr=False)
+    name: str
+    num_activations: int = field(init=False)
+    activation_files: list[Path] = field(init=False, repr=False)
+
+    def __post_init__(self):
+        if not self.activation_dir.exists():
+            raise ValueError(f"Activation dir {self.activation_dir} does not exist")
+
+        # The number of activations = the number of .npy files in the activation dir
+        self.activation_files = list(self.activation_dir.glob("*.npy"))
+        # Sort the activation files by the numerical index in the filename
+        self.activation_files.sort(key=lambda p: int(p.stem))
+        self.num_activations = len(self.activation_files)
+
+    @property
+    def activation_dir(self) -> Path:
+        return self.base_dir / self.name
+
+    def _load_activation(self, item: int):
+        activation_path = self.activation_files[item]
+        if not activation_path.exists():
+            raise ValueError(f"Activation {activation_path} does not exist")
+        return cast(np.ndarray, np.load(activation_path, allow_pickle=True))
+
+    @overload
+    def __getitem__(self, item: int) -> np.ndarray: ...
+
+    @overload
+    def __getitem__(self, item: slice | list[int]) -> list[np.ndarray]: ...
+
+    def __getitem__(
+        self, item: int | slice | list[int]
+    ) -> np.ndarray | list[np.ndarray]:
+        if isinstance(item, int):
+            return self._load_activation(item)
+        elif isinstance(item, slice):
+            return [
+                self._load_activation(i)
+                for i in range(*item.indices(self.num_activations))
+            ]
+        elif isinstance(item, list):
+            return [self._load_activation(i) for i in item]
+        else:
+            raise TypeError(f"Invalid type {type(item)} for item {item}")
+
+    def __iter__(self):
+        return iter(self[i] for i in range(self.num_activations))
+
+    def __len__(self):
+        return self.num_activations
+
+    def all_activations(self):
+        return [self[i] for i in range(self.num_activations)]
+
+    @override
+    def __repr__(self):
+        return f"<LoadedActivation {self.name} ({self.num_activations} activations)>"
+
+
+class ActLoad:
+    @classmethod
+    def all_versions(cls, dir: str | Path):
+        dir = Path(dir)
+
+        # If the dir is not an activation base directory, we return None
+        if not (dir / ".activationbase").exists():
+            return None
+
+        # The contents of `dir` should be directories, each of which is a version.
+        return [
+            (subdir, int(subdir.name)) for subdir in dir.iterdir() if subdir.is_dir()
+        ]
+
+    @classmethod
+    def is_valid_activation_base(cls, dir: str | Path):
+        return cls.all_versions(dir) is not None
+
+    @classmethod
+    def from_latest_version(cls, dir: str | Path):
+        # The contents of `dir` should be directories, each of which is a version
+        # We need to find the latest version
+        if (all_versions := cls.all_versions(dir)) is None:
+            raise ValueError(f"{dir} is not an activation base directory")
+
+        path, _ = max(all_versions, key=lambda p: p[1])
+        return cls(path)
+
+    def __init__(self, dir: Path):
+        self._dir = dir
+
+    def activation(self, name: str):
+        return LoadedActivation(self._dir, name)
+
+    @cached_property
+    def activations(self):
+        dirs = list(self._dir.iterdir())
+        # Sort the dirs by the last modified time
+        dirs.sort(key=lambda p: p.stat().st_mtime)
+
+        return {p.name: LoadedActivation(self._dir, p.name) for p in dirs}
+
+    def __iter__(self):
+        return iter(self.activations.values())
+
+    def __getitem__(self, item: str):
+        return self.activations[item]
+
+    def __len__(self):
+        return len(self.activations)
+
+    @override
+    def __repr__(self):
+        acts_str = pprint.pformat(
+            {
+                name: f"<{activation.num_activations} activations>"
+                for name, activation in self.activations.items()
+            }
+        )
+        acts_str = acts_str.replace("'<", "<").replace(">'", ">")
+        return f"ActLoad({acts_str})"
+
+    def get(self, name: str, /, default: T) -> LoadedActivation | T:
+        return self.activations.get(name, default)
+
+
+ActivationLoader = ActLoad
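For illustration only (not part of the package diff): a minimal sketch of reading activations back with the loader added above. The directory path and activation name here are hypothetical; `ActSave` (see `_saver.py` below) writes auto-incrementing version subdirectories under a base directory marked with a `.activationbase` file, and `ActLoad.from_latest_version` picks the highest-numbered one.

```python
from pathlib import Path

from nshutils.actsave import ActLoad

base_dir = Path("/tmp/my-activations")  # hypothetical save directory

if ActLoad.is_valid_activation_base(base_dir):
    loader = ActLoad.from_latest_version(base_dir)
    print(loader)  # ActLoad({"layer1.output": <N activations>, ...})

    acts = loader["layer1.output"]  # hypothetical name -> LoadedActivation
    first = acts[0]        # np.ndarray, lazily loaded from 0000.npy
    last_two = acts[-2:]   # slice or list indexing returns list[np.ndarray]
```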
nshutils-0.3.0/src/nshutils/actsave/_saver.py

@@ -0,0 +1,360 @@
+import contextlib
+import fnmatch
+import tempfile
+import uuid
+import weakref
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass
+from functools import wraps
+from logging import getLogger
+from pathlib import Path
+from typing import TYPE_CHECKING, Generic, TypeAlias, cast, overload
+
+import numpy as np
+from typing_extensions import Never, ParamSpec, TypeVar, override
+
+from ..collections import apply_to_collection
+
+try:
+    import torch
+
+    if not TYPE_CHECKING:
+        Tensor: TypeAlias = torch.Tensor
+except ImportError:
+    torch = None
+
+    if not TYPE_CHECKING:
+        Tensor: TypeAlias = Never
+
+if TYPE_CHECKING:
+    Tensor: TypeAlias = Never
+
+log = getLogger(__name__)
+
+Value: TypeAlias = int | float | complex | bool | str | np.ndarray | Tensor | None
+ValueOrLambda = Value | Callable[..., Value]
+
+
+def _torch_is_scripting() -> bool:
+    if torch is None:
+        return False
+
+    return torch.jit.is_scripting()
+
+
+def _to_numpy(activation: Value) -> np.ndarray:
+    # Make sure it's not `None`
+    if activation is None:
+        raise ValueError("Activation should not be `None`")
+
+    if isinstance(activation, (int, float, complex, str, bool)):
+        return np.array(activation)
+    elif isinstance(activation, np.ndarray):
+        return activation
+    elif isinstance(activation, Tensor):
+        activation = activation.detach()
+        if activation.is_floating_point():
+            # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
+            activation = activation.float()
+        return activation.cpu().numpy()
+    else:
+        log.warning(f"Unrecognized activation type {type(activation)}")
+
+    return activation
+
+
+T = TypeVar("T", infer_variance=True)
+
+
+# A wrapper around weakref.ref that allows for primitive types
+# To get around errors like:
+# TypeError: cannot create weak reference to 'int' object
+class WeakRef(Generic[T]):
+    _ref: Callable[[], T] | None
+
+    def __init__(self, obj: T):
+        try:
+            self._ref = cast(Callable[[], T], weakref.ref(obj))
+        except TypeError as e:
+            if "cannot create weak reference" not in str(e):
+                raise
+            self._ref = lambda: obj
+
+    def __call__(self) -> T:
+        if self._ref is None:
+            raise RuntimeError("WeakRef is deleted")
+        return self._ref()
+
+    def delete(self):
+        del self._ref
+        self._ref = None
+
+
+@dataclass
+class Activation:
+    name: str
+    ref: WeakRef[ValueOrLambda] | None
+    transformed: np.ndarray | None = None
+
+    def __post_init__(self):
+        # Update the `name` to replace `/` with `.`
+        self.name = self.name.replace("/", ".")
+
+    def __call__(self) -> np.ndarray | None:
+        # If we have a transformed value, we return it
+        if self.transformed is not None:
+            return self.transformed
+
+        if self.ref is None:
+            raise RuntimeError("Activation is deleted")
+
+        # If we have a lambda, we need to call it
+        unwrapped_ref = self.ref()
+        activation = unwrapped_ref
+        if callable(unwrapped_ref):
+            activation = unwrapped_ref()
+
+        # If we have a `None`, we return early
+        if activation is None:
+            return None
+
+        activation = apply_to_collection(activation, Tensor, _to_numpy)
+        activation = _to_numpy(activation)
+
+        # Set the transformed value
+        self.transformed = activation
+
+        # Delete the reference
+        self.ref.delete()
+        del self.ref
+        self.ref = None
+
+        return self.transformed
+
+    @classmethod
+    def from_value_or_lambda(cls, name: str, value_or_lambda: ValueOrLambda):
+        return cls(name, WeakRef(value_or_lambda))
+
+    @classmethod
+    def from_dict(cls, d: Mapping[str, ValueOrLambda]):
+        return [cls.from_value_or_lambda(k, v) for k, v in d.items()]
+
+
+Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
+
+
+def _ensure_supported():
+    try:
+        import torch.distributed as dist
+
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            raise RuntimeError("Only single GPU is supported at the moment")
+    except ImportError:
+        pass
+
+
+P = ParamSpec("P")
+
+
+def _ignore_if_scripting(fn: Callable[P, None]) -> Callable[P, None]:
+    @wraps(fn)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
+        if _torch_is_scripting():
+            return
+
+        _ensure_supported()
+        fn(*args, **kwargs)
+
+    return wrapper
+
+
+class _Saver:
+    def __init__(
+        self,
+        save_dir: Path,
+        prefixes_fn: Callable[[], list[str]],
+        *,
+        filters: list[str] | None = None,
+    ):
+        # Create a directory under `save_dir` by autoincrementing
+        # (i.e., every activation save context, we create a new directory)
+        # The id = the number of activation subdirectories
+        self._id = sum(1 for subdir in save_dir.glob("*") if subdir.is_dir())
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        # Add a .activationbase file to the save_dir to indicate that this is an activation base
+        (save_dir / ".activationbase").touch(exist_ok=True)
+
+        self._save_dir = save_dir / f"{self._id:04d}"
+        # Make sure `self._save_dir` does not exist and create it
+        self._save_dir.mkdir(exist_ok=False)
+
+        self._prefixes_fn = prefixes_fn
+        self._filters = filters
+
+    def _save_activation(self, activation: Activation):
+        # If the activation value is `None`, we skip it.
+        if (activation_value := activation()) is None:
+            return
+
+        # Save the activation to self._save_dir / name / {id}.npy, where id is an auto-incrementing integer
+        file_name = ".".join(self._prefixes_fn() + [activation.name])
+        path = self._save_dir / file_name
+        path.mkdir(exist_ok=True, parents=True)
+
+        # Get the next id and save the activation
+        id = len(list(path.glob("*.npy")))
+        np.save(path / f"{id:04d}.npy", activation_value)
+
+    @_ignore_if_scripting
+    def save(
+        self,
+        acts: dict[str, ValueOrLambda] | None = None,
+        /,
+        **kwargs: ValueOrLambda,
+    ):
+        kwargs.update(acts or {})
+
+        # Build activations
+        activations = Activation.from_dict(kwargs)
+
+        for activation in activations:
+            # Make sure name matches at least one filter if filters are specified
+            if self._filters is not None and all(
+                not fnmatch.fnmatch(activation.name, f) for f in self._filters
+            ):
+                continue
+
+            # Save the current activation
+            self._save_activation(activation)
+
+        del activations
+
+
+class ActSaveProvider:
+    _saver: _Saver | None = None
+    _prefixes: list[str] = []
+
+    def initialize(self, save_dir: Path | None = None):
+        """
+        Initializes the saver with the given configuration and save directory.
+
+        Args:
+            save_dir (Path): The directory where the saved files will be stored.
+        """
+        if self._saver is None:
+            if save_dir is None:
+                save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid.uuid4()}"
+                log.critical(f"No save_dir specified, using {save_dir=}")
+            self._saver = _Saver(
+                save_dir,
+                lambda: self._prefixes,
+            )
+
+    @contextlib.contextmanager
+    def enabled(self, save_dir: Path | None = None):
+        """
+        Context manager that enables the actsave functionality with the specified configuration.
+
+        Args:
+            save_dir (Path): The directory where the saved files will be stored.
+        """
+        prev = self._saver
+        self.initialize(save_dir)
+        try:
+            yield
+        finally:
+            self._saver = prev
+
+    @override
+    def __init__(self):
+        super().__init__()
+
+        self._saver = None
+        self._prefixes = []
+
+    @contextlib.contextmanager
+    def context(self, label: str):
+        """
+        A context manager that adds a label to the current context.
+
+        Args:
+            label (str): The label for the context.
+        """
+        if _torch_is_scripting():
+            yield
+            return
+
+        if self._saver is None:
+            yield
+            return
+
+        _ensure_supported()
+
+        log.debug(f"Entering ActSave context {label}")
+        self._prefixes.append(label)
+        try:
+            yield
+        finally:
+            _ = self._prefixes.pop()
+
+    prefix = context
+
+    @overload
+    def __call__(
+        self,
+        acts: dict[str, ValueOrLambda] | None = None,
+        /,
+        **kwargs: ValueOrLambda,
+    ):
+        """
+        Saves the activations to disk.
+
+        Args:
+            acts (dict[str, ValueOrLambda] | None, optional): A dictionary of acts. Defaults to None.
+            **kwargs (ValueOrLambda): Additional keyword arguments.
+
+        Returns:
+            None
+
+        """
+        ...
+
+    @overload
+    def __call__(self, acts: Callable[[], dict[str, ValueOrLambda]], /):
+        """
+        Saves the activations to disk.
+
+        Args:
+            acts (Callable[[], dict[str, ValueOrLambda]]): A callable that returns a dictionary of acts.
+
+        Returns:
+            None
+
+        """
+        ...
+
+    def __call__(
+        self,
+        acts: (
+            dict[str, ValueOrLambda] | Callable[[], dict[str, ValueOrLambda]] | None
+        ) = None,
+        /,
+        **kwargs: ValueOrLambda,
+    ):
+        if _torch_is_scripting():
+            return
+
+        if self._saver is None:
+            return
+
+        if acts is not None and callable(acts):
+            acts = acts()
+        self._saver.save(acts, **kwargs)
+
+    save = __call__
+
+
+ActSave = ActSaveProvider()
+ActivationSaver = ActSave
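For illustration only (not part of the package diff): a minimal usage sketch of the `ActSave` singleton defined above. The save directory is hypothetical, and the on-disk layout shown in the final comment is inferred from `_Saver._save_activation` (a version subdirectory per run, one directory per dotted activation name, auto-incrementing `.npy` files inside).

```python
from pathlib import Path

import numpy as np

from nshutils.actsave import ActSave

save_dir = Path("/tmp/my-activations")  # hypothetical base directory

with ActSave.enabled(save_dir):          # creates save_dir/0000/ for this run
    with ActSave.context("layer1"):      # prefixes subsequent names with "layer1."
        ActSave(output=np.zeros((8, 16)))        # eager value
        ActSave(lambda: {"hidden": np.ones(4)})  # lazy dict, evaluated only while saving is enabled

# Roughly: save_dir/0000/layer1.output/0000.npy and save_dir/0000/layer1.hidden/0000.npy
```

Outside an `enabled()` block, `ActSave(...)` is a no-op, so the calls can stay in model code without cost when saving is off.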
nshutils-0.3.0/src/nshutils/collections.py

@@ -0,0 +1,271 @@
+# Copyright The PyTorch Lightning team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import dataclasses
+from collections import OrderedDict, defaultdict
+from collections.abc import Callable, Mapping, Sequence
+from copy import deepcopy
+from typing import Any
+
+
+def is_namedtuple(obj: object) -> bool:
+    """Check if object is type namedtuple."""
+    # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
+    return (
+        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+    )
+
+
+def is_dataclass_instance(obj: object) -> bool:
+    """Check if object is dataclass."""
+    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
+    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
+
+
+def apply_to_collection(
+    data: Any,
+    dtype: type | Any | tuple[type | Any],
+    function: Callable,
+    *args: Any,
+    wrong_dtype: type | tuple[type, ...] | None = None,
+    include_none: bool = True,
+    allow_frozen: bool = False,
+    **kwargs: Any,
+) -> Any:
+    """Recursively applies a function to all elements of a certain dtype.
+
+    Args:
+        data: the collection to apply the function to
+        dtype: the given function will be applied to all elements of this dtype
+        function: the function to apply
+        *args: positional arguments (will be forwarded to calls of ``function``)
+        wrong_dtype: the given function won't be applied if this type is specified and the given collection
+            is of the ``wrong_dtype`` even if it is of type ``dtype``
+        include_none: Whether to include an element if the output of ``function`` is ``None``.
+        allow_frozen: Whether not to error upon encountering a frozen dataclass instance.
+        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+
+    Returns:
+        The resulting collection
+    """
+    # Breaking condition
+    if isinstance(data, dtype) and (
+        wrong_dtype is None or not isinstance(data, wrong_dtype)
+    ):
+        return function(data, *args, **kwargs)
+
+    elem_type = type(data)
+
+    # Recursively apply to collection items
+    if isinstance(data, Mapping):
+        out = []
+        for k, v in data.items():
+            v = apply_to_collection(
+                v,
+                dtype,
+                function,
+                *args,
+                wrong_dtype=wrong_dtype,
+                include_none=include_none,
+                allow_frozen=allow_frozen,
+                **kwargs,
+            )
+            if include_none or v is not None:
+                out.append((k, v))
+        if isinstance(data, defaultdict):
+            return elem_type(data.default_factory, OrderedDict(out))
+        return elem_type(OrderedDict(out))
+
+    is_namedtuple_ = is_namedtuple(data)
+    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
+    if is_namedtuple_ or is_sequence:
+        out = []
+        for d in data:
+            v = apply_to_collection(
+                d,
+                dtype,
+                function,
+                *args,
+                wrong_dtype=wrong_dtype,
+                include_none=include_none,
+                allow_frozen=allow_frozen,
+                **kwargs,
+            )
+            if include_none or v is not None:
+                out.append(v)
+        return elem_type(*out) if is_namedtuple_ else elem_type(out)
+
+    if is_dataclass_instance(data):
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        fields = {}
+        memo = {}
+        for field in dataclasses.fields(data):
+            field_value = getattr(data, field.name)
+            fields[field.name] = (field_value, field.init)
+            memo[id(field_value)] = field_value
+        result = deepcopy(data, memo=memo)
+        # apply function to each field
+        for field_name, (field_value, field_init) in fields.items():
+            v = None
+            if field_init:
+                v = apply_to_collection(
+                    field_value,
+                    dtype,
+                    function,
+                    *args,
+                    wrong_dtype=wrong_dtype,
+                    include_none=include_none,
+                    allow_frozen=allow_frozen,
+                    **kwargs,
+                )
+            if not field_init or (not include_none and v is None):  # retain old value
+                v = getattr(data, field_name)
+            try:
+                setattr(result, field_name, v)
+            except dataclasses.FrozenInstanceError as e:
+                if allow_frozen:
+                    # Quit early if we encounter a frozen data class; return `result` as is.
+                    break
+                raise ValueError(
+                    "A frozen dataclass was passed to `apply_to_collection` but this is not allowed."
+                ) from e
+        return result
+
+    # data is neither of dtype, nor a collection
+    return data
+
+
+def apply_to_collections(
+    data1: Any | None,
+    data2: Any | None,
+    dtype: type | Any | tuple[type | Any],
+    function: Callable,
+    *args: Any,
+    wrong_dtype: type | tuple[type] | None = None,
+    **kwargs: Any,
+) -> Any:
+    """Zips two collections and applies a function to their items of a certain dtype.
+
+    Args:
+        data1: The first collection
+        data2: The second collection
+        dtype: the given function will be applied to all elements of this dtype
+        function: the function to apply
+        *args: positional arguments (will be forwarded to calls of ``function``)
+        wrong_dtype: the given function won't be applied if this type is specified and the given collection
+            is of the ``wrong_dtype`` even if it is of type ``dtype``
+        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+
+    Returns:
+        The resulting collection
+
+    Raises:
+        ValueError:
+            If sequence collections have different data sizes.
+    """
+    if data1 is None:
+        if data2 is None:
+            return None
+        # in case they were passed reversed
+        data1, data2 = data2, None
+
+    elem_type = type(data1)
+
+    if (
+        isinstance(data1, dtype)
+        and data2 is not None
+        and (wrong_dtype is None or not isinstance(data1, wrong_dtype))
+    ):
+        return function(data1, data2, *args, **kwargs)
+
+    if isinstance(data1, Mapping) and data2 is not None:
+        # use union because we want to fail if a key does not exist in both
+        zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
+        return elem_type(
+            {
+                k: apply_to_collections(
+                    *v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+                )
+                for k, v in zipped.items()
+            }
+        )
+
+    is_namedtuple_ = is_namedtuple(data1)
+    is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
+    if (is_namedtuple_ or is_sequence) and data2 is not None:
+        if len(data1) != len(data2):
+            raise ValueError("Sequence collections have different sizes.")
+        out = [
+            apply_to_collections(
+                v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+            )
+            for v1, v2 in zip(data1, data2)
+        ]
+        return elem_type(*out) if is_namedtuple_ else elem_type(out)
+
+    if is_dataclass_instance(data1) and data2 is not None:
+        if not is_dataclass_instance(data2):
+            raise TypeError(
+                "Expected inputs to be dataclasses of the same type or to have identical fields"
+                f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
+            )
+        if not (
+            len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
+            and all(
+                map(
+                    lambda f1, f2: isinstance(f1, type(f2)),
+                    dataclasses.fields(data1),
+                    dataclasses.fields(data2),
+                )
+            )
+        ):
+            raise TypeError("Dataclasses fields do not match.")
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        data = [data1, data2]
+        fields: list[dict] = [{}, {}]
+        memo: dict = {}
+        for i in range(len(data)):
+            for field in dataclasses.fields(data[i]):
+                field_value = getattr(data[i], field.name)
+                fields[i][field.name] = (field_value, field.init)
+                if i == 0:
+                    memo[id(field_value)] = field_value
+
+        result = deepcopy(data1, memo=memo)
+
+        # apply function to each field
+        for (field_name, (field_value1, field_init1)), (
+            _,
+            (field_value2, field_init2),
+        ) in zip(fields[0].items(), fields[1].items()):
+            v = None
+            if field_init1 and field_init2:
+                v = apply_to_collections(
+                    field_value1,
+                    field_value2,
+                    dtype,
+                    function,
+                    *args,
+                    wrong_dtype=wrong_dtype,
+                    **kwargs,
+                )
+            if not field_init1 or not field_init2 or v is None:  # retain old value
+                return apply_to_collection(
+                    data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+                )
+            try:
+                setattr(result, field_name, v)
+            except dataclasses.FrozenInstanceError as e:
+                raise ValueError(
+                    "A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
+                ) from e
+        return result
+
+    return apply_to_collection(
+        data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+    )
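For illustration only (not part of the package diff): `apply_to_collection` recursively maps a function over every element of a matching type inside nested dicts, sequences, namedtuples, and dataclasses, which is how `_saver.py` converts arbitrary nested activation structures to NumPy arrays. A small sketch:

```python
from nshutils.collections import apply_to_collection

# Double every int inside a nested structure; other values are left untouched.
data = {"a": 1, "b": [2, 3], "c": ("x", 4)}
doubled = apply_to_collection(data, int, lambda x: x * 2)
print(doubled)  # {'a': 2, 'b': [4, 6], 'c': ('x', 8)}

# _saver.py uses the same helper with dtype=torch.Tensor and its `_to_numpy`
# converter to turn every tensor in a nested activation into an np.ndarray.
```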
{nshutils-0.2.0 → nshutils-0.3.0}/README.md: File without changes
{nshutils-0.2.0 → nshutils-0.3.0}/src/nshutils/snoop.py: File without changes
{nshutils-0.2.0 → nshutils-0.3.0}/src/nshutils/typecheck.py: File without changes