nshutils 0.5.1__tar.gz → 0.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshutils
3
- Version: 0.5.1
3
+ Version: 0.6.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshutils"
3
- version = "0.5.1"
3
+ version = "0.6.1"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -14,6 +14,12 @@ beartype = "^0.18.5"
14
14
  numpy = "*"
15
15
 
16
16
 
17
+ [tool.poetry.group.dev.dependencies]
18
+ pyright = "^1.1.373"
19
+ ruff = "^0.5.4"
20
+ ipykernel = "^6.29.5"
21
+ ipywidgets = "^8.1.3"
22
+
17
23
  [build-system]
18
24
  requires = ["poetry-core"]
19
25
  build-backend = "poetry.core.masonry.api"
@@ -16,7 +16,7 @@ from typing_extensions import Never, ParamSpec, TypeVar, override
16
16
  from ..collections import apply_to_collection
17
17
 
18
18
  try:
19
- import torch
19
+ import torch # type: ignore
20
20
 
21
21
  if not TYPE_CHECKING:
22
22
  Tensor: TypeAlias = torch.Tensor
@@ -145,7 +145,7 @@ Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
145
145
 
146
146
  def _ensure_supported():
147
147
  try:
148
- import torch.distributed as dist
148
+ import torch.distributed as dist # type: ignore
149
149
 
150
150
  if dist.is_initialized() and dist.get_world_size() > 1:
151
151
  raise RuntimeError("Only single GPU is supported at the moment")
@@ -6,7 +6,7 @@ import dataclasses
6
6
  from collections import OrderedDict, defaultdict
7
7
  from collections.abc import Callable, Mapping, Sequence
8
8
  from copy import deepcopy
9
- from typing import Any
9
+ from typing import Any, cast
10
10
 
11
11
 
12
12
  def is_namedtuple(obj: object) -> bool:
@@ -97,6 +97,7 @@ def apply_to_collection(
97
97
  return elem_type(*out) if is_namedtuple_ else elem_type(out)
98
98
 
99
99
  if is_dataclass_instance(data):
100
+ data = cast(Any, data)
100
101
  # make a deepcopy of the data,
101
102
  # but do not deepcopy mapped fields since the computation would
102
103
  # be wasted on values that likely get immediately overwritten
@@ -136,136 +137,3 @@ def apply_to_collection(
136
137
 
137
138
  # data is neither of dtype, nor a collection
138
139
  return data
139
-
140
-
141
- def apply_to_collections(
142
- data1: Any | None,
143
- data2: Any | None,
144
- dtype: type | Any | tuple[type | Any],
145
- function: Callable,
146
- *args: Any,
147
- wrong_dtype: type | tuple[type] | None = None,
148
- **kwargs: Any,
149
- ) -> Any:
150
- """Zips two collections and applies a function to their items of a certain dtype.
151
-
152
- Args:
153
- data1: The first collection
154
- data2: The second collection
155
- dtype: the given function will be applied to all elements of this dtype
156
- function: the function to apply
157
- *args: positional arguments (will be forwarded to calls of ``function``)
158
- wrong_dtype: the given function won't be applied if this type is specified and the given collections
159
- is of the ``wrong_dtype`` even if it is of type ``dtype``
160
- **kwargs: keyword arguments (will be forwarded to calls of ``function``)
161
-
162
- Returns:
163
- The resulting collection
164
-
165
- Raises:
166
- AssertionError:
167
- If sequence collections have different data sizes.
168
- """
169
- if data1 is None:
170
- if data2 is None:
171
- return None
172
- # in case they were passed reversed
173
- data1, data2 = data2, None
174
-
175
- elem_type = type(data1)
176
-
177
- if (
178
- isinstance(data1, dtype)
179
- and data2 is not None
180
- and (wrong_dtype is None or not isinstance(data1, wrong_dtype))
181
- ):
182
- return function(data1, data2, *args, **kwargs)
183
-
184
- if isinstance(data1, Mapping) and data2 is not None:
185
- # use union because we want to fail if a key does not exist in both
186
- zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
187
- return elem_type(
188
- {
189
- k: apply_to_collections(
190
- *v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
191
- )
192
- for k, v in zipped.items()
193
- }
194
- )
195
-
196
- is_namedtuple_ = is_namedtuple(data1)
197
- is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
198
- if (is_namedtuple_ or is_sequence) and data2 is not None:
199
- if len(data1) != len(data2):
200
- raise ValueError("Sequence collections have different sizes.")
201
- out = [
202
- apply_to_collections(
203
- v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
204
- )
205
- for v1, v2 in zip(data1, data2)
206
- ]
207
- return elem_type(*out) if is_namedtuple_ else elem_type(out)
208
-
209
- if is_dataclass_instance(data1) and data2 is not None:
210
- if not is_dataclass_instance(data2):
211
- raise TypeError(
212
- "Expected inputs to be dataclasses of the same type or to have identical fields"
213
- f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
214
- )
215
- if not (
216
- len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
217
- and all(
218
- map(
219
- lambda f1, f2: isinstance(f1, type(f2)),
220
- dataclasses.fields(data1),
221
- dataclasses.fields(data2),
222
- )
223
- )
224
- ):
225
- raise TypeError("Dataclasses fields do not match.")
226
- # make a deepcopy of the data,
227
- # but do not deepcopy mapped fields since the computation would
228
- # be wasted on values that likely get immediately overwritten
229
- data = [data1, data2]
230
- fields: list[dict] = [{}, {}]
231
- memo: dict = {}
232
- for i in range(len(data)):
233
- for field in dataclasses.fields(data[i]):
234
- field_value = getattr(data[i], field.name)
235
- fields[i][field.name] = (field_value, field.init)
236
- if i == 0:
237
- memo[id(field_value)] = field_value
238
-
239
- result = deepcopy(data1, memo=memo)
240
-
241
- # apply function to each field
242
- for (field_name, (field_value1, field_init1)), (
243
- _,
244
- (field_value2, field_init2),
245
- ) in zip(fields[0].items(), fields[1].items()):
246
- v = None
247
- if field_init1 and field_init2:
248
- v = apply_to_collections(
249
- field_value1,
250
- field_value2,
251
- dtype,
252
- function,
253
- *args,
254
- wrong_dtype=wrong_dtype,
255
- **kwargs,
256
- )
257
- if not field_init1 or not field_init2 or v is None: # retain old value
258
- return apply_to_collection(
259
- data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
260
- )
261
- try:
262
- setattr(result, field_name, v)
263
- except dataclasses.FrozenInstanceError as e:
264
- raise ValueError(
265
- "A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
266
- ) from e
267
- return result
268
-
269
- return apply_to_collection(
270
- data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
271
- )
@@ -13,7 +13,7 @@ def init_python_logging(
13
13
  ):
14
14
  if lovely_tensors:
15
15
  try:
16
- import lovely_tensors as _lovely_tensors
16
+ import lovely_tensors as _lovely_tensors # type: ignore
17
17
 
18
18
  _lovely_tensors.monkey_patch()
19
19
  except ImportError:
@@ -23,7 +23,7 @@ def init_python_logging(
23
23
 
24
24
  if lovely_numpy:
25
25
  try:
26
- import lovely_numpy as _lovely_numpy
26
+ import lovely_numpy as _lovely_numpy # type: ignore
27
27
 
28
28
  _lovely_numpy.set_config(repr=_lovely_numpy.lovely)
29
29
  except ImportError:
@@ -39,7 +39,7 @@ def init_python_logging(
39
39
 
40
40
  if rich:
41
41
  try:
42
- from rich.logging import RichHandler
42
+ from rich.logging import RichHandler # type: ignore
43
43
 
44
44
  log_handlers.append(RichHandler(rich_tracebacks=rich_tracebacks))
45
45
  except ImportError:
@@ -21,19 +21,19 @@ try:
21
21
  from pkg_resources import DistributionNotFound, get_distribution
22
22
 
23
23
  try:
24
- import torch
24
+ import torch # type: ignore
25
25
  except ImportError:
26
26
  torch = None
27
27
 
28
28
  try:
29
- import numpy
29
+ import numpy # type: ignore
30
30
  except ImportError:
31
31
  numpy = None
32
32
 
33
33
  FLOATING_POINTS = set()
34
34
  for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
35
35
  # older version of PyTorch do not have complex dtypes
36
- if torch is not None and not hasattr(torch, i):
36
+ if torch is None or not hasattr(torch, i):
37
37
  continue
38
38
  FLOATING_POINTS.add(getattr(torch, i))
39
39
 
@@ -45,15 +45,15 @@ try:
45
45
 
46
46
  def default_format(x):
47
47
  try:
48
- import lovely_tensors as lt
48
+ import lovely_tensors as lt # type: ignore
49
49
 
50
- return = str(lt.lovely(x))
50
+ return str(lt.lovely(x))
51
51
  except BaseException:
52
52
  return str(x.shape)
53
53
 
54
54
  def default_numpy_format(x):
55
55
  try:
56
- import lovely_numpy as lo
56
+ import lovely_numpy as lo # type: ignore
57
57
 
58
58
  return str(lo.lovely(x))
59
59
  except BaseException:
@@ -34,18 +34,18 @@ from jaxtyping._storage import get_shape_memo, shape_str
34
34
  from typing_extensions import TypeVar
35
35
 
36
36
  try:
37
- import torch
37
+ import torch # type: ignore
38
38
  except ImportError:
39
39
  torch = None
40
40
 
41
41
  try:
42
- import np
42
+ import np # type: ignore
43
43
  except ImportError:
44
44
  np = None
45
45
 
46
46
 
47
47
  try:
48
- import jax
48
+ import jax # type: ignore
49
49
  except ImportError:
50
50
  jax = None
51
51
  log = getLogger(__name__)
@@ -106,21 +106,21 @@ def _make_error_str(input: Any, t: Any) -> str:
106
106
  error_components.append(t.__instancecheck_str__(input))
107
107
  if torch is not None and torch.is_tensor(input):
108
108
  try:
109
- from lovely_tensors import lovely
109
+ from lovely_tensors import lovely # type: ignore
110
110
 
111
111
  error_components.append(repr(lovely(input)))
112
112
  except BaseException:
113
113
  error_components.append(repr(input.shape))
114
114
  elif jax is not None and isinstance(input, jax.Array):
115
115
  try:
116
- from lovely_jax import lovely
116
+ from lovely_jax import lovely # type: ignore
117
117
 
118
118
  error_components.append(repr(lovely(input)))
119
119
  except BaseException:
120
120
  error_components.append(repr(input.shape))
121
121
  elif np is not None and isinstance(input, np.ndarray):
122
122
  try:
123
- from lovely_numpy import lovely
123
+ from lovely_numpy import lovely # type: ignore
124
124
 
125
125
  error_components.append(repr(lovely(input)))
126
126
  except BaseException:
File without changes