nshutils 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshutils/__init__.py CHANGED
@@ -1,2 +1,3 @@
+ from . import actsave as actsave
  from . import typecheck as typecheck
  from .snoop import snoop as snoop
nshutils/actsave/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from ._loader import ActivationLoader as ActivationLoader
+ from ._loader import ActLoad as ActLoad
+ from ._saver import Activation as Activation
+ from ._saver import ActivationSaver as ActivationSaver
+ from ._saver import ActSave as ActSave
+ from ._saver import Transform as Transform
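Taken together, these re-exports define the public surface of the new `actsave` subpackage. A minimal sketch (assuming nshutils 0.3.0 is installed); the long and short names point at the same objects, per the alias assignments in `_loader.py` and `_saver.py` below:

    from nshutils.actsave import ActivationLoader, ActivationSaver, ActLoad, ActSave

    # `ActivationLoader = ActLoad` and `ActivationSaver = ActSave` in the
    # modules below, so either spelling can be used.
    assert ActivationLoader is ActLoad
    assert ActivationSaver is ActSave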
nshutils/actsave/_loader.py ADDED
@@ -0,0 +1,144 @@
+ import pprint
+ from dataclasses import dataclass, field
+ from functools import cached_property
+ from logging import getLogger
+ from pathlib import Path
+ from typing import cast, overload
+
+ import numpy as np
+ from typing_extensions import TypeVar, override
+
+ log = getLogger(__name__)
+
+ T = TypeVar("T", infer_variance=True)
+
+
+ @dataclass
+ class LoadedActivation:
+     base_dir: Path = field(repr=False)
+     name: str
+     num_activations: int = field(init=False)
+     activation_files: list[Path] = field(init=False, repr=False)
+
+     def __post_init__(self):
+         if not self.activation_dir.exists():
+             raise ValueError(f"Activation dir {self.activation_dir} does not exist")
+
+         # The number of activations = the number of .npy files in the activation dir
+         self.activation_files = list(self.activation_dir.glob("*.npy"))
+         # Sort the activation files by the numerical index in the filename
+         self.activation_files.sort(key=lambda p: int(p.stem))
+         self.num_activations = len(self.activation_files)
+
+     @property
+     def activation_dir(self) -> Path:
+         return self.base_dir / self.name
+
+     def _load_activation(self, item: int):
+         activation_path = self.activation_files[item]
+         if not activation_path.exists():
+             raise ValueError(f"Activation {activation_path} does not exist")
+         return cast(np.ndarray, np.load(activation_path, allow_pickle=True))
+
+     @overload
+     def __getitem__(self, item: int) -> np.ndarray: ...
+
+     @overload
+     def __getitem__(self, item: slice | list[int]) -> list[np.ndarray]: ...
+
+     def __getitem__(
+         self, item: int | slice | list[int]
+     ) -> np.ndarray | list[np.ndarray]:
+         if isinstance(item, int):
+             return self._load_activation(item)
+         elif isinstance(item, slice):
+             return [
+                 self._load_activation(i)
+                 for i in range(*item.indices(self.num_activations))
+             ]
+         elif isinstance(item, list):
+             return [self._load_activation(i) for i in item]
+         else:
+             raise TypeError(f"Invalid type {type(item)} for item {item}")
+
+     def __iter__(self):
+         return iter(self[i] for i in range(self.num_activations))
+
+     def __len__(self):
+         return self.num_activations
+
+     def all_activations(self):
+         return [self[i] for i in range(self.num_activations)]
+
+     @override
+     def __repr__(self):
+         return f"<LoadedActivation {self.name} ({self.num_activations} activations)>"
+
+
+ class ActLoad:
+     @classmethod
+     def all_versions(cls, dir: str | Path):
+         dir = Path(dir)
+
+         # If the dir is not an activation base directory, we return None
+         if not (dir / ".activationbase").exists():
+             return None
+
+         # The contents of `dir` should be directories, each of which is a version.
+         return [
+             (subdir, int(subdir.name)) for subdir in dir.iterdir() if subdir.is_dir()
+         ]
+
+     @classmethod
+     def is_valid_activation_base(cls, dir: str | Path):
+         return cls.all_versions(dir) is not None
+
+     @classmethod
+     def from_latest_version(cls, dir: str | Path):
+         # The contents of `dir` should be directories, each of which is a version
+         # We need to find the latest version
+         if (all_versions := cls.all_versions(dir)) is None:
+             raise ValueError(f"{dir} is not an activation base directory")
+
+         path, _ = max(all_versions, key=lambda p: p[1])
+         return cls(path)
+
+     def __init__(self, dir: Path):
+         self._dir = dir
+
+     def activation(self, name: str):
+         return LoadedActivation(self._dir, name)
+
+     @cached_property
+     def activations(self):
+         dirs = list(self._dir.iterdir())
+         # Sort the dirs by the last modified time
+         dirs.sort(key=lambda p: p.stat().st_mtime)
+
+         return {p.name: LoadedActivation(self._dir, p.name) for p in dirs}
+
+     def __iter__(self):
+         return iter(self.activations.values())
+
+     def __getitem__(self, item: str):
+         return self.activations[item]
+
+     def __len__(self):
+         return len(self.activations)
+
+     @override
+     def __repr__(self):
+         acts_str = pprint.pformat(
+             {
+                 name: f"<{activation.num_activations} activations>"
+                 for name, activation in self.activations.items()
+             }
+         )
+         acts_str = acts_str.replace("'<", "<").replace(">'", ">")
+         return f"ActLoad({acts_str})"
+
+     def get(self, name: str, /, default: T) -> LoadedActivation | T:
+         return self.activations.get(name, default)
+
+
+ ActivationLoader = ActLoad
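For reference, a short usage sketch of the loader above. The `./activations` base directory and the `forward.hidden` name are hypothetical; the on-disk layout it reads (`<base>/<version>/<name>/<index>.npy`, plus a `.activationbase` marker in the base directory) follows the code shown:

    from nshutils.actsave import ActLoad

    # Picks the highest-numbered version directory under the base.
    load = ActLoad.from_latest_version("./activations")
    print(load)  # ActLoad({'forward.hidden': <N activations>, ...})

    acts = load["forward.hidden"]  # a LoadedActivation; files are read lazily
    first = acts[0]                # np.ndarray loaded from 0000.npy
    tail = acts[-2:]               # slices and index lists return lists of arrays
    print(len(acts), first.shape)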
nshutils/actsave/_saver.py ADDED
@@ -0,0 +1,360 @@
+ import contextlib
+ import fnmatch
+ import tempfile
+ import uuid
+ import weakref
+ from collections.abc import Callable, Mapping
+ from dataclasses import dataclass
+ from functools import wraps
+ from logging import getLogger
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Generic, TypeAlias, cast, overload
+
+ import numpy as np
+ from typing_extensions import Never, ParamSpec, TypeVar, override
+
+ from ..collections import apply_to_collection
+
+ try:
+     import torch
+
+     if not TYPE_CHECKING:
+         Tensor: TypeAlias = torch.Tensor
+ except ImportError:
+     torch = None
+
+     if not TYPE_CHECKING:
+         Tensor: TypeAlias = Never
+
+ if TYPE_CHECKING:
+     Tensor: TypeAlias = Never
+
+ log = getLogger(__name__)
+
+ Value: TypeAlias = int | float | complex | bool | str | np.ndarray | Tensor | None
+ ValueOrLambda = Value | Callable[..., Value]
+
+
+ def _torch_is_scripting() -> bool:
+     if torch is None:
+         return False
+
+     return torch.jit.is_scripting()
+
+
+ def _to_numpy(activation: Value) -> np.ndarray:
+     # Make sure it's not `None`
+     if activation is None:
+         raise ValueError("Activation should not be `None`")
+
+     if isinstance(activation, (int, float, complex, str, bool)):
+         return np.array(activation)
+     elif isinstance(activation, np.ndarray):
+         return activation
+     elif isinstance(activation, Tensor):
+         activation = activation.detach()
+         if activation.is_floating_point():
+             # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
+             activation = activation.float()
+         return activation.cpu().numpy()
+     else:
+         log.warning(f"Unrecognized activation type {type(activation)}")
+
+     return activation
+
+
+ T = TypeVar("T", infer_variance=True)
+
+
+ # A wrapper around weakref.ref that allows for primitive types
+ # To get around errors like:
+ # TypeError: cannot create weak reference to 'int' object
+ class WeakRef(Generic[T]):
+     _ref: Callable[[], T] | None
+
+     def __init__(self, obj: T):
+         try:
+             self._ref = cast(Callable[[], T], weakref.ref(obj))
+         except TypeError as e:
+             if "cannot create weak reference" not in str(e):
+                 raise
+             self._ref = lambda: obj
+
+     def __call__(self) -> T:
+         if self._ref is None:
+             raise RuntimeError("WeakRef is deleted")
+         return self._ref()
+
+     def delete(self):
+         del self._ref
+         self._ref = None
+
+
+ @dataclass
+ class Activation:
+     name: str
+     ref: WeakRef[ValueOrLambda] | None
+     transformed: np.ndarray | None = None
+
+     def __post_init__(self):
+         # Update the `name` to replace `/` with `.`
+         self.name = self.name.replace("/", ".")
+
+     def __call__(self) -> np.ndarray | None:
+         # If we have a transformed value, we return it
+         if self.transformed is not None:
+             return self.transformed
+
+         if self.ref is None:
+             raise RuntimeError("Activation is deleted")
+
+         # If we have a lambda, we need to call it
+         unwrapped_ref = self.ref()
+         activation = unwrapped_ref
+         if callable(unwrapped_ref):
+             activation = unwrapped_ref()
+
+         # If we have a `None`, we return early
+         if activation is None:
+             return None
+
+         activation = apply_to_collection(activation, Tensor, _to_numpy)
+         activation = _to_numpy(activation)
+
+         # Set the transformed value
+         self.transformed = activation
+
+         # Delete the reference
+         self.ref.delete()
+         del self.ref
+         self.ref = None
+
+         return self.transformed
+
+     @classmethod
+     def from_value_or_lambda(cls, name: str, value_or_lambda: ValueOrLambda):
+         return cls(name, WeakRef(value_or_lambda))
+
+     @classmethod
+     def from_dict(cls, d: Mapping[str, ValueOrLambda]):
+         return [cls.from_value_or_lambda(k, v) for k, v in d.items()]
+
+
+ Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
+
+
+ def _ensure_supported():
+     try:
+         import torch.distributed as dist
+
+         if dist.is_initialized() and dist.get_world_size() > 1:
+             raise RuntimeError("Only single GPU is supported at the moment")
+     except ImportError:
+         pass
+
+
+ P = ParamSpec("P")
+
+
+ def _ignore_if_scripting(fn: Callable[P, None]) -> Callable[P, None]:
+     @wraps(fn)
+     def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
+         if _torch_is_scripting():
+             return
+
+         _ensure_supported()
+         fn(*args, **kwargs)
+
+     return wrapper
+
+
+ class _Saver:
+     def __init__(
+         self,
+         save_dir: Path,
+         prefixes_fn: Callable[[], list[str]],
+         *,
+         filters: list[str] | None = None,
+     ):
+         # Create a directory under `save_dir` by autoincrementing
+         # (i.e., every activation save context, we create a new directory)
+         # The id = the number of activation subdirectories
+         self._id = sum(1 for subdir in save_dir.glob("*") if subdir.is_dir())
+         save_dir.mkdir(parents=True, exist_ok=True)
+
+         # Add a .activationbase file to the save_dir to indicate that this is an activation base
+         (save_dir / ".activationbase").touch(exist_ok=True)
+
+         self._save_dir = save_dir / f"{self._id:04d}"
+         # Make sure `self._save_dir` does not exist and create it
+         self._save_dir.mkdir(exist_ok=False)
+
+         self._prefixes_fn = prefixes_fn
+         self._filters = filters
+
+     def _save_activation(self, activation: Activation):
+         # If the activation value is `None`, we skip it.
+         if (activation_value := activation()) is None:
+             return
+
+         # Save the activation to self._save_dir / name / {id}.npy, where id is an auto-incrementing integer
+         file_name = ".".join(self._prefixes_fn() + [activation.name])
+         path = self._save_dir / file_name
+         path.mkdir(exist_ok=True, parents=True)
+
+         # Get the next id and save the activation
+         id = len(list(path.glob("*.npy")))
+         np.save(path / f"{id:04d}.npy", activation_value)
+
+     @_ignore_if_scripting
+     def save(
+         self,
+         acts: dict[str, ValueOrLambda] | None = None,
+         /,
+         **kwargs: ValueOrLambda,
+     ):
+         kwargs.update(acts or {})
+
+         # Build activations
+         activations = Activation.from_dict(kwargs)
+
+         for activation in activations:
+             # Make sure name matches at least one filter if filters are specified
+             if self._filters is not None and all(
+                 not fnmatch.fnmatch(activation.name, f) for f in self._filters
+             ):
+                 continue
+
+             # Save the current activation
+             self._save_activation(activation)
+
+         del activations
+
+
+ class ActSaveProvider:
+     _saver: _Saver | None = None
+     _prefixes: list[str] = []
+
+     def initialize(self, save_dir: Path | None = None):
+         """
+         Initializes the saver with the given configuration and save directory.
+
+         Args:
+             save_dir (Path): The directory where the saved files will be stored.
+         """
+         if self._saver is None:
+             if save_dir is None:
+                 save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid.uuid4()}"
+                 log.critical(f"No save_dir specified, using {save_dir=}")
+             self._saver = _Saver(
+                 save_dir,
+                 lambda: self._prefixes,
+             )
+
+     @contextlib.contextmanager
+     def enabled(self, save_dir: Path | None = None):
+         """
+         Context manager that enables the actsave functionality with the specified configuration.
+
+         Args:
+             save_dir (Path): The directory where the saved files will be stored.
+         """
+         prev = self._saver
+         self.initialize(save_dir)
+         try:
+             yield
+         finally:
+             self._saver = prev
+
+     @override
+     def __init__(self):
+         super().__init__()
+
+         self._saver = None
+         self._prefixes = []
+
+     @contextlib.contextmanager
+     def context(self, label: str):
+         """
+         A context manager that adds a label to the current context.
+
+         Args:
+             label (str): The label for the context.
+         """
+         if _torch_is_scripting():
+             yield
+             return
+
+         if self._saver is None:
+             yield
+             return
+
+         _ensure_supported()
+
+         log.debug(f"Entering ActSave context {label}")
+         self._prefixes.append(label)
+         try:
+             yield
+         finally:
+             _ = self._prefixes.pop()
+
+     prefix = context
+
+     @overload
+     def __call__(
+         self,
+         acts: dict[str, ValueOrLambda] | None = None,
+         /,
+         **kwargs: ValueOrLambda,
+     ):
+         """
+         Saves the activations to disk.
+
+         Args:
+             acts (dict[str, ValueOrLambda] | None, optional): A dictionary of acts. Defaults to None.
+             **kwargs (ValueOrLambda): Additional activations to save, as keyword arguments.
+
+         Returns:
+             None
+         """
+         ...
+
+     @overload
+     def __call__(self, acts: Callable[[], dict[str, ValueOrLambda]], /):
+         """
+         Saves the activations to disk.
+
+         Args:
+             acts (Callable[[], dict[str, ValueOrLambda]]): A callable that returns a dictionary of acts.
+
+         Returns:
+             None
+         """
+         ...
+
+     def __call__(
+         self,
+         acts: (
+             dict[str, ValueOrLambda] | Callable[[], dict[str, ValueOrLambda]] | None
+         ) = None,
+         /,
+         **kwargs: ValueOrLambda,
+     ):
+         if _torch_is_scripting():
+             return
+
+         if self._saver is None:
+             return
+
+         if acts is not None and callable(acts):
+             acts = acts()
+         self._saver.save(acts, **kwargs)
+
+     save = __call__
+
+
+ ActSave = ActSaveProvider()
+ ActivationSaver = ActSave
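A round-trip sketch of the saver API above (the paths and activation names are hypothetical, and torch is assumed to be installed): values may be passed directly or as zero-argument lambdas that are only evaluated, and converted to numpy, while a saver is active, and `ActSave.context` prefixes saved names with dot-separated labels.

    from pathlib import Path

    import numpy as np

    from nshutils.actsave import ActLoad, ActSave

    save_dir = Path("./activations")  # hypothetical location

    with ActSave.enabled(save_dir):
        with ActSave.context("forward"):
            # Saved as "forward.hidden" and "forward.logits"; each call appends
            # one .npy file under <save_dir>/<version>/<name>/.
            ActSave(hidden=np.zeros((2, 4)), logits=lambda: np.ones(3))

    load = ActLoad.from_latest_version(save_dir)
    print(load["forward.hidden"][0].shape)  # (2, 4)

Outside an `ActSave.enabled(...)` block (or under `torch.jit` scripting), the call returns immediately, so the instrumentation can be left in model code permanently.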
nshutils/collections.py ADDED
@@ -0,0 +1,271 @@
+ # Copyright The PyTorch Lightning team.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ import dataclasses
+ from collections import OrderedDict, defaultdict
+ from collections.abc import Callable, Mapping, Sequence
+ from copy import deepcopy
+ from typing import Any
+
+
+ def is_namedtuple(obj: object) -> bool:
+     """Check if object is a namedtuple."""
+     # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
+     return (
+         isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+     )
+
+
+ def is_dataclass_instance(obj: object) -> bool:
+     """Check if object is a dataclass instance."""
+     # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
+     return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
+
+
+ def apply_to_collection(
+     data: Any,
+     dtype: type | Any | tuple[type | Any],
+     function: Callable,
+     *args: Any,
+     wrong_dtype: type | tuple[type, ...] | None = None,
+     include_none: bool = True,
+     allow_frozen: bool = False,
+     **kwargs: Any,
+ ) -> Any:
+     """Recursively applies a function to all elements of a certain dtype.
+
+     Args:
+         data: the collection to apply the function to
+         dtype: the given function will be applied to all elements of this dtype
+         function: the function to apply
+         *args: positional arguments (will be forwarded to calls of ``function``)
+         wrong_dtype: the given function won't be applied if this type is specified and the given collection
+             is of the ``wrong_dtype`` even if it is of type ``dtype``
+         include_none: Whether to include an element if the output of ``function`` is ``None``.
+         allow_frozen: Whether not to error upon encountering a frozen dataclass instance.
+         **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+
+     Returns:
+         The resulting collection
+     """
+     # Breaking condition
+     if isinstance(data, dtype) and (
+         wrong_dtype is None or not isinstance(data, wrong_dtype)
+     ):
+         return function(data, *args, **kwargs)
+
+     elem_type = type(data)
+
+     # Recursively apply to collection items
+     if isinstance(data, Mapping):
+         out = []
+         for k, v in data.items():
+             v = apply_to_collection(
+                 v,
+                 dtype,
+                 function,
+                 *args,
+                 wrong_dtype=wrong_dtype,
+                 include_none=include_none,
+                 allow_frozen=allow_frozen,
+                 **kwargs,
+             )
+             if include_none or v is not None:
+                 out.append((k, v))
+         if isinstance(data, defaultdict):
+             return elem_type(data.default_factory, OrderedDict(out))
+         return elem_type(OrderedDict(out))
+
+     is_namedtuple_ = is_namedtuple(data)
+     is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
+     if is_namedtuple_ or is_sequence:
+         out = []
+         for d in data:
+             v = apply_to_collection(
+                 d,
+                 dtype,
+                 function,
+                 *args,
+                 wrong_dtype=wrong_dtype,
+                 include_none=include_none,
+                 allow_frozen=allow_frozen,
+                 **kwargs,
+             )
+             if include_none or v is not None:
+                 out.append(v)
+         return elem_type(*out) if is_namedtuple_ else elem_type(out)
+
+     if is_dataclass_instance(data):
+         # make a deepcopy of the data,
+         # but do not deepcopy mapped fields since the computation would
+         # be wasted on values that likely get immediately overwritten
+         fields = {}
+         memo = {}
+         for field in dataclasses.fields(data):
+             field_value = getattr(data, field.name)
+             fields[field.name] = (field_value, field.init)
+             memo[id(field_value)] = field_value
+         result = deepcopy(data, memo=memo)
+         # apply function to each field
+         for field_name, (field_value, field_init) in fields.items():
+             v = None
+             if field_init:
+                 v = apply_to_collection(
+                     field_value,
+                     dtype,
+                     function,
+                     *args,
+                     wrong_dtype=wrong_dtype,
+                     include_none=include_none,
+                     allow_frozen=allow_frozen,
+                     **kwargs,
+                 )
+             if not field_init or (not include_none and v is None):  # retain old value
+                 v = getattr(data, field_name)
+             try:
+                 setattr(result, field_name, v)
+             except dataclasses.FrozenInstanceError as e:
+                 if allow_frozen:
+                     # Quit early if we encounter a frozen data class; return `result` as is.
+                     break
+                 raise ValueError(
+                     "A frozen dataclass was passed to `apply_to_collection` but this is not allowed."
+                 ) from e
+         return result
+
+     # data is neither of dtype, nor a collection
+     return data
+
+
+ def apply_to_collections(
+     data1: Any | None,
+     data2: Any | None,
+     dtype: type | Any | tuple[type | Any],
+     function: Callable,
+     *args: Any,
+     wrong_dtype: type | tuple[type] | None = None,
+     **kwargs: Any,
+ ) -> Any:
+     """Zips two collections and applies a function to their items of a certain dtype.
+
+     Args:
+         data1: The first collection
+         data2: The second collection
+         dtype: the given function will be applied to all elements of this dtype
+         function: the function to apply
+         *args: positional arguments (will be forwarded to calls of ``function``)
+         wrong_dtype: the given function won't be applied if this type is specified and the given collection
+             is of the ``wrong_dtype`` even if it is of type ``dtype``
+         **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+
+     Returns:
+         The resulting collection
+
+     Raises:
+         ValueError:
+             If sequence collections have different data sizes.
+     """
+     if data1 is None:
+         if data2 is None:
+             return None
+         # in case they were passed reversed
+         data1, data2 = data2, None
+
+     elem_type = type(data1)
+
+     if (
+         isinstance(data1, dtype)
+         and data2 is not None
+         and (wrong_dtype is None or not isinstance(data1, wrong_dtype))
+     ):
+         return function(data1, data2, *args, **kwargs)
+
+     if isinstance(data1, Mapping) and data2 is not None:
+         # use union because we want to fail if a key does not exist in both
+         zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
+         return elem_type(
+             {
+                 k: apply_to_collections(
+                     *v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+                 )
+                 for k, v in zipped.items()
+             }
+         )
+
+     is_namedtuple_ = is_namedtuple(data1)
+     is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
+     if (is_namedtuple_ or is_sequence) and data2 is not None:
+         if len(data1) != len(data2):
+             raise ValueError("Sequence collections have different sizes.")
+         out = [
+             apply_to_collections(
+                 v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+             )
+             for v1, v2 in zip(data1, data2)
+         ]
+         return elem_type(*out) if is_namedtuple_ else elem_type(out)
+
+     if is_dataclass_instance(data1) and data2 is not None:
+         if not is_dataclass_instance(data2):
+             raise TypeError(
+                 "Expected inputs to be dataclasses of the same type or to have identical fields"
+                 f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
+             )
+         if not (
+             len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
+             and all(
+                 map(
+                     lambda f1, f2: isinstance(f1, type(f2)),
+                     dataclasses.fields(data1),
+                     dataclasses.fields(data2),
+                 )
+             )
+         ):
+             raise TypeError("Dataclasses fields do not match.")
+         # make a deepcopy of the data,
+         # but do not deepcopy mapped fields since the computation would
+         # be wasted on values that likely get immediately overwritten
+         data = [data1, data2]
+         fields: list[dict] = [{}, {}]
+         memo: dict = {}
+         for i in range(len(data)):
+             for field in dataclasses.fields(data[i]):
+                 field_value = getattr(data[i], field.name)
+                 fields[i][field.name] = (field_value, field.init)
+                 if i == 0:
+                     memo[id(field_value)] = field_value
+
+         result = deepcopy(data1, memo=memo)
+
+         # apply function to each field
+         for (field_name, (field_value1, field_init1)), (
+             _,
+             (field_value2, field_init2),
+         ) in zip(fields[0].items(), fields[1].items()):
+             v = None
+             if field_init1 and field_init2:
+                 v = apply_to_collections(
+                     field_value1,
+                     field_value2,
+                     dtype,
+                     function,
+                     *args,
+                     wrong_dtype=wrong_dtype,
+                     **kwargs,
+                 )
+             if not field_init1 or not field_init2 or v is None:  # retain old value
+                 return apply_to_collection(
+                     data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+                 )
+             try:
+                 setattr(result, field_name, v)
+             except dataclasses.FrozenInstanceError as e:
+                 raise ValueError(
+                     "A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
+                 ) from e
+         return result
+
+     return apply_to_collection(
+         data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
+     )
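A small sketch of `apply_to_collection` on a hypothetical nested batch; this helper is what `_saver.py` uses to convert tensors nested inside dicts, tuples, namedtuples, and dataclasses. The function is applied to every element of the requested dtype and the surrounding containers are rebuilt around the results:

    import numpy as np

    from nshutils.collections import apply_to_collection

    batch = {
        "x": np.arange(3),
        "meta": {"ids": np.array([1, 2]), "note": "unchanged"},
        "pair": (np.zeros(2), "also unchanged"),
    }

    # Double every ndarray; values of other types pass through untouched.
    doubled = apply_to_collection(batch, np.ndarray, lambda a: a * 2)
    print(doubled["x"])            # [0 2 4]
    print(doubled["meta"]["ids"])  # [2 4]
    print(doubled["pair"][0])      # [0. 0.]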
nshutils-{0.2.0 → 0.3.0}.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshutils
- Version: 0.2.0
+ Version: 0.3.0
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
@@ -11,6 +11,7 @@ Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Requires-Dist: beartype (>=0.18.5,<0.19.0)
  Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
+ Requires-Dist: numpy
  Requires-Dist: pysnooper (>=1.2.0,<2.0.0)
  Requires-Dist: typing-extensions
  Description-Content-Type: text/markdown
nshutils-0.3.0.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+ nshutils/__init__.py,sha256=Uz8WR2N_h4lsbmPb3HwJHrzraxP30zGjH5YgLqLm3k8,104
+ nshutils/actsave/__init__.py,sha256=kafzChiViTeOY28GnwKD6yBJfiB7aSU8t1u5cZicdmc,280
+ nshutils/actsave/_loader.py,sha256=fAhD32DrJa4onkYfcwc21YIeGEYzOSXCK_HVo9SZLgQ,4604
+ nshutils/actsave/_saver.py,sha256=Uo7rNFce0KzGcS2Q03UYXoMzW8z9cRkgs_r1gLiW5Tc,9853
+ nshutils/collections.py,sha256=tFzFqhqzTX_XI6IuciWJ0atZ6tj7ajGn1nzaLC4XIGQ,10155
+ nshutils/snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
+ nshutils/typecheck.py,sha256=wrjL-H2f3J8V1lojXIbcwQBh3039bz3HBVgG9DINYK4,4819
+ nshutils-0.3.0.dist-info/METADATA,sha256=y4R0kKMMAQuTKVkso9tcd4YuaBcKAvNj1toYUpvZGBc,571
+ nshutils-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshutils-0.3.0.dist-info/RECORD,,
nshutils-0.2.0.dist-info/RECORD DELETED
@@ -1,6 +0,0 @@
- nshutils/__init__.py,sha256=uNKU8zJk7Un4JC-fNHNicz_iUNGIequDgCXxIDqeKr4,71
- nshutils/snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
- nshutils/typecheck.py,sha256=wrjL-H2f3J8V1lojXIbcwQBh3039bz3HBVgG9DINYK4,4819
- nshutils-0.2.0.dist-info/METADATA,sha256=ohNMCYDmB9CR0XTpCpJ-AXqwVjbrGH0s96LTHe5Zwkc,550
- nshutils-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshutils-0.2.0.dist-info/RECORD,,