nshutils 0.5.1.tar.gz → 0.6.1.tar.gz
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {nshutils-0.5.1 → nshutils-0.6.1}/PKG-INFO +1 -1
- {nshutils-0.5.1 → nshutils-0.6.1}/pyproject.toml +7 -1
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/actsave/_saver.py +2 -2
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/collections.py +2 -134
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/logging.py +3 -3
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/snoop.py +6 -6
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/typecheck.py +6 -6
- {nshutils-0.5.1 → nshutils-0.6.1}/README.md +0 -0
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/__init__.py +0 -0
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/actsave/__init__.py +0 -0
- {nshutils-0.5.1 → nshutils-0.6.1}/src/nshutils/actsave/_loader.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
[tool.poetry]
|
2
2
|
name = "nshutils"
|
3
|
-
version = "0.5.1"
|
3
|
+
version = "0.6.1"
|
4
4
|
description = ""
|
5
5
|
authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
|
6
6
|
readme = "README.md"
|
@@ -14,6 +14,12 @@ beartype = "^0.18.5"
|
|
14
14
|
numpy = "*"
|
15
15
|
|
16
16
|
|
17
|
+
[tool.poetry.group.dev.dependencies]
|
18
|
+
pyright = "^1.1.373"
|
19
|
+
ruff = "^0.5.4"
|
20
|
+
ipykernel = "^6.29.5"
|
21
|
+
ipywidgets = "^8.1.3"
|
22
|
+
|
17
23
|
[build-system]
|
18
24
|
requires = ["poetry-core"]
|
19
25
|
build-backend = "poetry.core.masonry.api"
|
@@ -16,7 +16,7 @@ from typing_extensions import Never, ParamSpec, TypeVar, override
|
|
16
16
|
from ..collections import apply_to_collection
|
17
17
|
|
18
18
|
try:
|
19
|
-
import torch
|
19
|
+
import torch # type: ignore
|
20
20
|
|
21
21
|
if not TYPE_CHECKING:
|
22
22
|
Tensor: TypeAlias = torch.Tensor
|
@@ -145,7 +145,7 @@ Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
|
|
145
145
|
|
146
146
|
def _ensure_supported():
|
147
147
|
try:
|
148
|
-
import torch.distributed as dist
|
148
|
+
import torch.distributed as dist # type: ignore
|
149
149
|
|
150
150
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
151
151
|
raise RuntimeError("Only single GPU is supported at the moment")
|
@@ -6,7 +6,7 @@ import dataclasses
|
|
6
6
|
from collections import OrderedDict, defaultdict
|
7
7
|
from collections.abc import Callable, Mapping, Sequence
|
8
8
|
from copy import deepcopy
|
9
|
-
from typing import Any
|
9
|
+
from typing import Any, cast
|
10
10
|
|
11
11
|
|
12
12
|
def is_namedtuple(obj: object) -> bool:
|
@@ -97,6 +97,7 @@ def apply_to_collection(
|
|
97
97
|
return elem_type(*out) if is_namedtuple_ else elem_type(out)
|
98
98
|
|
99
99
|
if is_dataclass_instance(data):
|
100
|
+
data = cast(Any, data)
|
100
101
|
# make a deepcopy of the data,
|
101
102
|
# but do not deepcopy mapped fields since the computation would
|
102
103
|
# be wasted on values that likely get immediately overwritten
|
@@ -136,136 +137,3 @@ def apply_to_collection(
|
|
136
137
|
|
137
138
|
# data is neither of dtype, nor a collection
|
138
139
|
return data
|
139
|
-
|
140
|
-
|
141
|
-
def apply_to_collections(
|
142
|
-
data1: Any | None,
|
143
|
-
data2: Any | None,
|
144
|
-
dtype: type | Any | tuple[type | Any],
|
145
|
-
function: Callable,
|
146
|
-
*args: Any,
|
147
|
-
wrong_dtype: type | tuple[type] | None = None,
|
148
|
-
**kwargs: Any,
|
149
|
-
) -> Any:
|
150
|
-
"""Zips two collections and applies a function to their items of a certain dtype.
|
151
|
-
|
152
|
-
Args:
|
153
|
-
data1: The first collection
|
154
|
-
data2: The second collection
|
155
|
-
dtype: the given function will be applied to all elements of this dtype
|
156
|
-
function: the function to apply
|
157
|
-
*args: positional arguments (will be forwarded to calls of ``function``)
|
158
|
-
wrong_dtype: the given function won't be applied if this type is specified and the given collections
|
159
|
-
is of the ``wrong_dtype`` even if it is of type ``dtype``
|
160
|
-
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
|
161
|
-
|
162
|
-
Returns:
|
163
|
-
The resulting collection
|
164
|
-
|
165
|
-
Raises:
|
166
|
-
AssertionError:
|
167
|
-
If sequence collections have different data sizes.
|
168
|
-
"""
|
169
|
-
if data1 is None:
|
170
|
-
if data2 is None:
|
171
|
-
return None
|
172
|
-
# in case they were passed reversed
|
173
|
-
data1, data2 = data2, None
|
174
|
-
|
175
|
-
elem_type = type(data1)
|
176
|
-
|
177
|
-
if (
|
178
|
-
isinstance(data1, dtype)
|
179
|
-
and data2 is not None
|
180
|
-
and (wrong_dtype is None or not isinstance(data1, wrong_dtype))
|
181
|
-
):
|
182
|
-
return function(data1, data2, *args, **kwargs)
|
183
|
-
|
184
|
-
if isinstance(data1, Mapping) and data2 is not None:
|
185
|
-
# use union because we want to fail if a key does not exist in both
|
186
|
-
zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
|
187
|
-
return elem_type(
|
188
|
-
{
|
189
|
-
k: apply_to_collections(
|
190
|
-
*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
|
191
|
-
)
|
192
|
-
for k, v in zipped.items()
|
193
|
-
}
|
194
|
-
)
|
195
|
-
|
196
|
-
is_namedtuple_ = is_namedtuple(data1)
|
197
|
-
is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
|
198
|
-
if (is_namedtuple_ or is_sequence) and data2 is not None:
|
199
|
-
if len(data1) != len(data2):
|
200
|
-
raise ValueError("Sequence collections have different sizes.")
|
201
|
-
out = [
|
202
|
-
apply_to_collections(
|
203
|
-
v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
|
204
|
-
)
|
205
|
-
for v1, v2 in zip(data1, data2)
|
206
|
-
]
|
207
|
-
return elem_type(*out) if is_namedtuple_ else elem_type(out)
|
208
|
-
|
209
|
-
if is_dataclass_instance(data1) and data2 is not None:
|
210
|
-
if not is_dataclass_instance(data2):
|
211
|
-
raise TypeError(
|
212
|
-
"Expected inputs to be dataclasses of the same type or to have identical fields"
|
213
|
-
f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
|
214
|
-
)
|
215
|
-
if not (
|
216
|
-
len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
|
217
|
-
and all(
|
218
|
-
map(
|
219
|
-
lambda f1, f2: isinstance(f1, type(f2)),
|
220
|
-
dataclasses.fields(data1),
|
221
|
-
dataclasses.fields(data2),
|
222
|
-
)
|
223
|
-
)
|
224
|
-
):
|
225
|
-
raise TypeError("Dataclasses fields do not match.")
|
226
|
-
# make a deepcopy of the data,
|
227
|
-
# but do not deepcopy mapped fields since the computation would
|
228
|
-
# be wasted on values that likely get immediately overwritten
|
229
|
-
data = [data1, data2]
|
230
|
-
fields: list[dict] = [{}, {}]
|
231
|
-
memo: dict = {}
|
232
|
-
for i in range(len(data)):
|
233
|
-
for field in dataclasses.fields(data[i]):
|
234
|
-
field_value = getattr(data[i], field.name)
|
235
|
-
fields[i][field.name] = (field_value, field.init)
|
236
|
-
if i == 0:
|
237
|
-
memo[id(field_value)] = field_value
|
238
|
-
|
239
|
-
result = deepcopy(data1, memo=memo)
|
240
|
-
|
241
|
-
# apply function to each field
|
242
|
-
for (field_name, (field_value1, field_init1)), (
|
243
|
-
_,
|
244
|
-
(field_value2, field_init2),
|
245
|
-
) in zip(fields[0].items(), fields[1].items()):
|
246
|
-
v = None
|
247
|
-
if field_init1 and field_init2:
|
248
|
-
v = apply_to_collections(
|
249
|
-
field_value1,
|
250
|
-
field_value2,
|
251
|
-
dtype,
|
252
|
-
function,
|
253
|
-
*args,
|
254
|
-
wrong_dtype=wrong_dtype,
|
255
|
-
**kwargs,
|
256
|
-
)
|
257
|
-
if not field_init1 or not field_init2 or v is None: # retain old value
|
258
|
-
return apply_to_collection(
|
259
|
-
data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
|
260
|
-
)
|
261
|
-
try:
|
262
|
-
setattr(result, field_name, v)
|
263
|
-
except dataclasses.FrozenInstanceError as e:
|
264
|
-
raise ValueError(
|
265
|
-
"A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
|
266
|
-
) from e
|
267
|
-
return result
|
268
|
-
|
269
|
-
return apply_to_collection(
|
270
|
-
data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
|
271
|
-
)
|
@@ -13,7 +13,7 @@ def init_python_logging(
|
|
13
13
|
):
|
14
14
|
if lovely_tensors:
|
15
15
|
try:
|
16
|
-
import lovely_tensors as _lovely_tensors
|
16
|
+
import lovely_tensors as _lovely_tensors # type: ignore
|
17
17
|
|
18
18
|
_lovely_tensors.monkey_patch()
|
19
19
|
except ImportError:
|
@@ -23,7 +23,7 @@ def init_python_logging(
|
|
23
23
|
|
24
24
|
if lovely_numpy:
|
25
25
|
try:
|
26
|
-
import lovely_numpy as _lovely_numpy
|
26
|
+
import lovely_numpy as _lovely_numpy # type: ignore
|
27
27
|
|
28
28
|
_lovely_numpy.set_config(repr=_lovely_numpy.lovely)
|
29
29
|
except ImportError:
|
@@ -39,7 +39,7 @@ def init_python_logging(
|
|
39
39
|
|
40
40
|
if rich:
|
41
41
|
try:
|
42
|
-
from rich.logging import RichHandler
|
42
|
+
from rich.logging import RichHandler # type: ignore
|
43
43
|
|
44
44
|
log_handlers.append(RichHandler(rich_tracebacks=rich_tracebacks))
|
45
45
|
except ImportError:
|
@@ -21,19 +21,19 @@ try:
|
|
21
21
|
from pkg_resources import DistributionNotFound, get_distribution
|
22
22
|
|
23
23
|
try:
|
24
|
-
import torch
|
24
|
+
import torch # type: ignore
|
25
25
|
except ImportError:
|
26
26
|
torch = None
|
27
27
|
|
28
28
|
try:
|
29
|
-
import numpy
|
29
|
+
import numpy # type: ignore
|
30
30
|
except ImportError:
|
31
31
|
numpy = None
|
32
32
|
|
33
33
|
FLOATING_POINTS = set()
|
34
34
|
for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
|
35
35
|
# older version of PyTorch do not have complex dtypes
|
36
|
-
if torch is
|
36
|
+
if torch is None or not hasattr(torch, i):
|
37
37
|
continue
|
38
38
|
FLOATING_POINTS.add(getattr(torch, i))
|
39
39
|
|
@@ -45,15 +45,15 @@ try:
|
|
45
45
|
|
46
46
|
def default_format(x):
|
47
47
|
try:
|
48
|
-
import lovely_tensors as lt
|
48
|
+
import lovely_tensors as lt # type: ignore
|
49
49
|
|
50
|
-
return
|
50
|
+
return str(lt.lovely(x))
|
51
51
|
except BaseException:
|
52
52
|
return str(x.shape)
|
53
53
|
|
54
54
|
def default_numpy_format(x):
|
55
55
|
try:
|
56
|
-
import lovely_numpy as lo
|
56
|
+
import lovely_numpy as lo # type: ignore
|
57
57
|
|
58
58
|
return str(lo.lovely(x))
|
59
59
|
except BaseException:
|
@@ -34,18 +34,18 @@ from jaxtyping._storage import get_shape_memo, shape_str
|
|
34
34
|
from typing_extensions import TypeVar
|
35
35
|
|
36
36
|
try:
|
37
|
-
import torch
|
37
|
+
import torch # type: ignore
|
38
38
|
except ImportError:
|
39
39
|
torch = None
|
40
40
|
|
41
41
|
try:
|
42
|
-
import np
|
42
|
+
import np # type: ignore
|
43
43
|
except ImportError:
|
44
44
|
np = None
|
45
45
|
|
46
46
|
|
47
47
|
try:
|
48
|
-
import jax
|
48
|
+
import jax # type: ignore
|
49
49
|
except ImportError:
|
50
50
|
jax = None
|
51
51
|
log = getLogger(__name__)
|
@@ -106,21 +106,21 @@ def _make_error_str(input: Any, t: Any) -> str:
|
|
106
106
|
error_components.append(t.__instancecheck_str__(input))
|
107
107
|
if torch is not None and torch.is_tensor(input):
|
108
108
|
try:
|
109
|
-
from lovely_tensors import lovely
|
109
|
+
from lovely_tensors import lovely # type: ignore
|
110
110
|
|
111
111
|
error_components.append(repr(lovely(input)))
|
112
112
|
except BaseException:
|
113
113
|
error_components.append(repr(input.shape))
|
114
114
|
elif jax is not None and isinstance(input, jax.Array):
|
115
115
|
try:
|
116
|
-
from lovely_jax import lovely
|
116
|
+
from lovely_jax import lovely # type: ignore
|
117
117
|
|
118
118
|
error_components.append(repr(lovely(input)))
|
119
119
|
except BaseException:
|
120
120
|
error_components.append(repr(input.shape))
|
121
121
|
elif np is not None and isinstance(input, np.ndarray):
|
122
122
|
try:
|
123
|
-
from lovely_numpy import lovely
|
123
|
+
from lovely_numpy import lovely # type: ignore
|
124
124
|
|
125
125
|
error_components.append(repr(lovely(input)))
|
126
126
|
except BaseException:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|