nshutils 0.5.0__tar.gz → 0.6.0__tar.gz

This diff compares the contents of two package versions as released to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshutils
-Version: 0.5.0
+Version: 0.6.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshutils"
-version = "0.5.0"
+version = "0.6.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -14,6 +14,12 @@ beartype = "^0.18.5"
 numpy = "*"
 
 
+[tool.poetry.group.dev.dependencies]
+pyright = "^1.1.373"
+ruff = "^0.5.4"
+ipykernel = "^6.29.5"
+ipywidgets = "^8.1.3"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
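The new dev group collects type-checking, linting, and notebook tooling; a plain `poetry install` still installs it by default, while `poetry install --without dev` yields a runtime-only environment.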
@@ -16,7 +16,7 @@ from typing_extensions import Never, ParamSpec, TypeVar, override
 from ..collections import apply_to_collection
 
 try:
-    import torch
+    import torch  # type: ignore
 
     if not TYPE_CHECKING:
         Tensor: TypeAlias = torch.Tensor
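The `# type: ignore` markers that recur throughout this release go hand in hand with the pyright dev dependency added above: they silence the checker on imports of packages that may be absent at runtime.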
@@ -145,7 +145,7 @@ Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
 
 def _ensure_supported():
     try:
-        import torch.distributed as dist
+        import torch.distributed as dist  # type: ignore
 
         if dist.is_initialized() and dist.get_world_size() > 1:
             raise RuntimeError("Only single GPU is supported at the moment")
@@ -6,7 +6,7 @@ import dataclasses
 from collections import OrderedDict, defaultdict
 from collections.abc import Callable, Mapping, Sequence
 from copy import deepcopy
-from typing import Any
+from typing import Any, cast
 
 
 def is_namedtuple(obj: object) -> bool:
@@ -97,6 +97,7 @@ def apply_to_collection(
         return elem_type(*out) if is_namedtuple_ else elem_type(out)
 
     if is_dataclass_instance(data):
+        data = cast(Any, data)
        # make a deepcopy of the data,
        # but do not deepcopy mapped fields since the computation would
        # be wasted on values that likely get immediately overwritten
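The inserted cast works around type-checker narrowing: after the `is_dataclass_instance` guard, pyright pins `data` to a dataclass protocol type, and the generic deepcopy-and-setattr code that follows no longer type-checks; widening back to `Any` restores the old behavior. A minimal sketch of the same pattern, using the stdlib guard instead of nshutils' helper (illustrative names only, not nshutils code):

    import dataclasses
    from typing import Any, cast

    def touch_fields(data: Any) -> Any:
        # The guard narrows `data` to a dataclass instance for the checker.
        if dataclasses.is_dataclass(data) and not isinstance(data, type):
            # Widen back to Any so the generic getattr/setattr field
            # manipulation below still type-checks under pyright.
            data = cast(Any, data)
            for field in dataclasses.fields(data):
                setattr(data, field.name, getattr(data, field.name))
        return data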
@@ -136,136 +137,3 @@ def apply_to_collection(
 
     # data is neither of dtype, nor a collection
     return data
-
-
-def apply_to_collections(
-    data1: Any | None,
-    data2: Any | None,
-    dtype: type | Any | tuple[type | Any],
-    function: Callable,
-    *args: Any,
-    wrong_dtype: type | tuple[type] | None = None,
-    **kwargs: Any,
-) -> Any:
-    """Zips two collections and applies a function to their items of a certain dtype.
-
-    Args:
-        data1: The first collection
-        data2: The second collection
-        dtype: the given function will be applied to all elements of this dtype
-        function: the function to apply
-        *args: positional arguments (will be forwarded to calls of ``function``)
-        wrong_dtype: the given function won't be applied if this type is specified and the given collections
-            is of the ``wrong_dtype`` even if it is of type ``dtype``
-        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
-
-    Returns:
-        The resulting collection
-
-    Raises:
-        AssertionError:
-            If sequence collections have different data sizes.
-    """
-    if data1 is None:
-        if data2 is None:
-            return None
-        # in case they were passed reversed
-        data1, data2 = data2, None
-
-    elem_type = type(data1)
-
-    if (
-        isinstance(data1, dtype)
-        and data2 is not None
-        and (wrong_dtype is None or not isinstance(data1, wrong_dtype))
-    ):
-        return function(data1, data2, *args, **kwargs)
-
-    if isinstance(data1, Mapping) and data2 is not None:
-        # use union because we want to fail if a key does not exist in both
-        zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
-        return elem_type(
-            {
-                k: apply_to_collections(
-                    *v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
-                )
-                for k, v in zipped.items()
-            }
-        )
-
-    is_namedtuple_ = is_namedtuple(data1)
-    is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
-    if (is_namedtuple_ or is_sequence) and data2 is not None:
-        if len(data1) != len(data2):
-            raise ValueError("Sequence collections have different sizes.")
-        out = [
-            apply_to_collections(
-                v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
-            )
-            for v1, v2 in zip(data1, data2)
-        ]
-        return elem_type(*out) if is_namedtuple_ else elem_type(out)
-
-    if is_dataclass_instance(data1) and data2 is not None:
-        if not is_dataclass_instance(data2):
-            raise TypeError(
-                "Expected inputs to be dataclasses of the same type or to have identical fields"
-                f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
-            )
-        if not (
-            len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
-            and all(
-                map(
-                    lambda f1, f2: isinstance(f1, type(f2)),
-                    dataclasses.fields(data1),
-                    dataclasses.fields(data2),
-                )
-            )
-        ):
-            raise TypeError("Dataclasses fields do not match.")
-        # make a deepcopy of the data,
-        # but do not deepcopy mapped fields since the computation would
-        # be wasted on values that likely get immediately overwritten
-        data = [data1, data2]
-        fields: list[dict] = [{}, {}]
-        memo: dict = {}
-        for i in range(len(data)):
-            for field in dataclasses.fields(data[i]):
-                field_value = getattr(data[i], field.name)
-                fields[i][field.name] = (field_value, field.init)
-                if i == 0:
-                    memo[id(field_value)] = field_value
-
-        result = deepcopy(data1, memo=memo)
-
-        # apply function to each field
-        for (field_name, (field_value1, field_init1)), (
-            _,
-            (field_value2, field_init2),
-        ) in zip(fields[0].items(), fields[1].items()):
-            v = None
-            if field_init1 and field_init2:
-                v = apply_to_collections(
-                    field_value1,
-                    field_value2,
-                    dtype,
-                    function,
-                    *args,
-                    wrong_dtype=wrong_dtype,
-                    **kwargs,
-                )
-            if not field_init1 or not field_init2 or v is None:  # retain old value
-                return apply_to_collection(
-                    data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
-                )
-            try:
-                setattr(result, field_name, v)
-            except dataclasses.FrozenInstanceError as e:
-                raise ValueError(
-                    "A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
-                ) from e
-        return result
-
-    return apply_to_collection(
-        data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs
-    )
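Note that 0.6.0 removes `apply_to_collections` (the two-collection variant) outright. Callers that still need zip-and-apply semantics must supply their own; a minimal, hypothetical stand-in covering only the flat mapping/sequence cases might look like this (`zip_apply` is not part of nshutils):

    from collections.abc import Callable, Mapping, Sequence
    from typing import Any

    def zip_apply(a: Any, b: Any, dtype: type, fn: Callable) -> Any:
        # Leaf case: both sides are of the target dtype.
        if isinstance(a, dtype) and isinstance(b, dtype):
            return fn(a, b)
        # Mappings: the union of keys fails loudly on one-sided keys,
        # mirroring the removed helper's behavior.
        if isinstance(a, Mapping) and isinstance(b, Mapping):
            return {k: zip_apply(a[k], b[k], dtype, fn) for k in a.keys() | b.keys()}
        # Sequences (but not strings) must have equal lengths.
        if isinstance(a, Sequence) and not isinstance(a, str):
            if len(a) != len(b):
                raise ValueError("Sequence collections have different sizes.")
            return type(a)(zip_apply(x, y, dtype, fn) for x, y in zip(a, b))
        return a

    # e.g. zip_apply([1, 2], [10, 20], int, lambda x, y: x + y) -> [11, 22]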
@@ -13,7 +13,7 @@ def init_python_logging(
 ):
     if lovely_tensors:
         try:
-            import lovely_tensors as _lovely_tensors
+            import lovely_tensors as _lovely_tensors  # type: ignore
 
             _lovely_tensors.monkey_patch()
         except ImportError:
@@ -23,7 +23,7 @@ def init_python_logging(
 
     if lovely_numpy:
         try:
-            import lovely_numpy as _lovely_numpy
+            import lovely_numpy as _lovely_numpy  # type: ignore
 
             _lovely_numpy.set_config(repr=_lovely_numpy.lovely)
         except ImportError:
@@ -39,7 +39,7 @@ def init_python_logging(
 
     if rich:
         try:
-            from rich.logging import RichHandler
+            from rich.logging import RichHandler  # type: ignore
 
             log_handlers.append(RichHandler(rich_tracebacks=rich_tracebacks))
         except ImportError:
@@ -16,18 +16,26 @@ try:
     import warnings
     from contextlib import nullcontext
 
-    import lovely_numpy as lo
-    import lovely_tensors as lt
-    import numpy
     import pysnooper
     import pysnooper.utils
-    import torch
     from pkg_resources import DistributionNotFound, get_distribution
 
+    try:
+        import torch  # type: ignore
+    except ImportError:
+        torch = None
+
+    try:
+        import numpy  # type: ignore
+    except ImportError:
+        numpy = None
+
     FLOATING_POINTS = set()
     for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
-        if hasattr(torch, i):  # older version of PyTorch do not have complex dtypes
-            FLOATING_POINTS.add(getattr(torch, i))
+        # older version of PyTorch do not have complex dtypes
+        if torch is not None and not hasattr(torch, i):
+            continue
+        FLOATING_POINTS.add(getattr(torch, i))
 
     try:
         __version__ = get_distribution(__name__).version
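This hunk is the clearest statement of the release's theme: torch and numpy go from hard requirements to optional imports bound to `None` on failure, with use sites guarded. A self-contained sketch of the pattern (illustrative, not the package's exact code):

    try:
        import torch  # type: ignore
    except ImportError:
        torch = None  # torch stays optional; call sites must handle None

    def brief_repr(x):
        # Only touch the module when the import actually succeeded.
        if torch is not None and torch.is_tensor(x):
            return f"Tensor{tuple(x.shape)}"
        return repr(x)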
@@ -37,13 +45,19 @@ try:
 
     def default_format(x):
         try:
-            formatted = str(lt.lovely(x))
-            return formatted
+            import lovely_tensors as lt  # type: ignore
+
+            return str(lt.lovely(x))
         except BaseException:
             return str(x.shape)
 
     def default_numpy_format(x):
-        return str(lo.lovely(x))
+        try:
+            import lovely_numpy as lo  # type: ignore
+
+            return str(lo.lovely(x))
+        except BaseException:
+            return str(x.shape)
 
     class TorchSnooper(pysnooper.tracer.Tracer):
         def __init__(
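Deferring the lovely-tensors / lovely-numpy imports to call time has the same effect for these formatters: the module now imports cleanly without them, and the broad `except BaseException` degrades the output to a plain shape string.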
@@ -155,9 +169,9 @@ try:
 
         def compute_repr(self, x):
             orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
-            if torch.is_tensor(x):
+            if torch is not None and torch.is_tensor(x):
                 return self.tensor_format(x)
-            if isinstance(x, numpy.ndarray):
+            if numpy is not None and isinstance(x, numpy.ndarray):
                 return self.numpy_format(x)
             if self.is_return_types(x):
                 return self.return_types_repr(x)
@@ -34,18 +34,18 @@ from jaxtyping._storage import get_shape_memo, shape_str
 from typing_extensions import TypeVar
 
 try:
-    import torch
+    import torch  # type: ignore
 except ImportError:
     torch = None
 
 try:
-    import np
+    import np  # type: ignore
 except ImportError:
     np = None
 
 
 try:
-    import jax
+    import jax  # type: ignore
 except ImportError:
     jax = None
 log = getLogger(__name__)
@@ -106,21 +106,21 @@ def _make_error_str(input: Any, t: Any) -> str:
         error_components.append(t.__instancecheck_str__(input))
     if torch is not None and torch.is_tensor(input):
         try:
-            from lovely_tensors import lovely
+            from lovely_tensors import lovely  # type: ignore
 
             error_components.append(repr(lovely(input)))
         except BaseException:
             error_components.append(repr(input.shape))
     elif jax is not None and isinstance(input, jax.Array):
         try:
-            from lovely_jax import lovely
+            from lovely_jax import lovely  # type: ignore
 
             error_components.append(repr(lovely(input)))
         except BaseException:
             error_components.append(repr(input.shape))
     elif np is not None and isinstance(input, np.ndarray):
         try:
-            from lovely_numpy import lovely
+            from lovely_numpy import lovely  # type: ignore
 
             error_components.append(repr(lovely(input)))
         except BaseException: