nshutils 0.31.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,7 +20,7 @@ from ..collections import apply_to_collection
20
20
 
21
21
  if not TYPE_CHECKING:
22
22
  try:
23
- import torch # type: ignore
23
+ import torch # pyright: ignore[reportMissingImports]
24
24
 
25
25
  Tensor = torch.Tensor
26
26
  _torch_installed = True
@@ -30,9 +30,9 @@ if not TYPE_CHECKING:
30
30
 
31
31
  Tensor = Never
32
32
  else:
33
- import torch # type: ignore
33
+ import torch # pyright: ignore[reportMissingImports]
34
34
 
35
- Tensor = torch.Tensor
35
+ Tensor = TypeAliasType("Tensor", torch.Tensor)
36
36
  _torch_installed: Literal[True] = True
37
37
 
38
38
  log = getLogger(__name__)
@@ -60,7 +60,7 @@ def _to_numpy(activation: Value) -> np.ndarray:
60
60
  return np.array(activation)
61
61
  elif isinstance(activation, np.ndarray):
62
62
  return activation
63
- elif _torch_installed and isinstance(activation, Tensor):
63
+ elif _torch_installed and isinstance(activation, torch.Tensor):
64
64
  activation_ = activation.detach()
65
65
  if activation_.is_floating_point():
66
66
  # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
@@ -126,7 +126,8 @@ class Activation:
126
126
  if activation is None:
127
127
  return None
128
128
 
129
- activation = apply_to_collection(activation, Tensor, _to_numpy)
129
+ if _torch_installed:
130
+ activation = apply_to_collection(activation, Tensor, _to_numpy)
130
131
  activation = _to_numpy(activation)
131
132
 
132
133
  # Set the transformed value
@@ -30,9 +30,10 @@ def _find_deps() -> list[Library]:
30
30
 
31
31
  class monkey_patch(lovely_patch):
32
32
  def __init__(self, libraries: list[Library] | Literal["auto"] = "auto"):
33
- self.libraries = libraries
34
- if self.libraries == "auto":
33
+ if libraries == "auto":
35
34
  self.libraries = _find_deps()
35
+ else:
36
+ self.libraries = libraries
36
37
 
37
38
  if not self.libraries:
38
39
  raise ValueError(
@@ -59,7 +60,7 @@ class monkey_patch(lovely_patch):
59
60
 
60
61
  self.stack.enter_context(numpy_monkey_patch())
61
62
  else:
62
- assert_never(library) # type: ignore
63
+ assert_never(library)
63
64
 
64
65
  log.info(
65
66
  f"Monkey patched libraries: {', '.join(self.libraries)}. "
nshutils/lovely/jax_.py CHANGED
@@ -9,7 +9,7 @@ from ._base import lovely_patch, lovely_repr
9
9
  from .utils import LovelyStats, array_stats, patch_to
10
10
 
11
11
  if TYPE_CHECKING:
12
- import jax
12
+ import jax # pyright: ignore[reportMissingImports]
13
13
 
14
14
 
15
15
  def _type_name(array: jax.Array):
@@ -42,7 +42,7 @@ def _dtype_str(array: jax.Array) -> str:
42
42
 
43
43
 
44
44
  def _device(array: jax.Array) -> str:
45
- from jaxlib.xla_extension import Device
45
+ from jaxlib.xla_extension import Device # pyright: ignore[reportMissingImports]
46
46
 
47
47
  if callable(device := array.device):
48
48
  device = device()
@@ -56,7 +56,7 @@ def _device(array: jax.Array) -> str:
56
56
 
57
57
  @lovely_repr(dependencies=["jax"])
58
58
  def jax_repr(array: jax.Array) -> LovelyStats | None:
59
- import jax.numpy as jnp
59
+ import jax.numpy as jnp # pyright: ignore[reportMissingImports]
60
60
 
61
61
  # For dtypes like `object` or `str`, we let the fallback repr handle it
62
62
  if not jnp.issubdtype(array.dtype, jnp.number):
@@ -85,7 +85,7 @@ class jax_monkey_patch(lovely_patch):
85
85
 
86
86
  @override
87
87
  def patch(self):
88
- from jax._src import array
88
+ from jax._src import array # pyright: ignore[reportMissingImports]
89
89
 
90
90
  self.prev_repr = array.ArrayImpl.__repr__
91
91
  self.prev_str = array.ArrayImpl.__str__
@@ -96,7 +96,7 @@ class jax_monkey_patch(lovely_patch):
96
96
 
97
97
  @override
98
98
  def unpatch(self):
99
- from jax._src import array
99
+ from jax._src import array # pyright: ignore[reportMissingImports]
100
100
 
101
101
  patch_to(array.ArrayImpl, "__repr__", self.prev_repr)
102
102
  patch_to(array.ArrayImpl, "__str__", self.prev_str)
nshutils/lovely/torch_.py CHANGED
@@ -9,11 +9,11 @@ from ._base import lovely_patch, lovely_repr
9
9
  from .utils import LovelyStats, array_stats, patch_to
10
10
 
11
11
  if TYPE_CHECKING:
12
- import torch
12
+ import torch # pyright: ignore[reportMissingImports]
13
13
 
14
14
 
15
15
  def _type_name(tensor: torch.Tensor):
16
- import torch
16
+ import torch # pyright: ignore[reportMissingImports]
17
17
 
18
18
  return (
19
19
  "tensor"
@@ -45,7 +45,7 @@ def _dtype_str(tensor: torch.Tensor) -> str:
45
45
 
46
46
 
47
47
  def _to_np(tensor: torch.Tensor) -> np.ndarray:
48
- import torch
48
+ import torch # pyright: ignore[reportMissingImports]
49
49
 
50
50
  # Get tensor data as CPU NumPy array for analysis
51
51
  t_cpu = tensor.detach().cpu()
@@ -88,7 +88,7 @@ class torch_monkey_patch(lovely_patch):
88
88
 
89
89
  @override
90
90
  def patch(self):
91
- import torch
91
+ import torch # pyright: ignore[reportMissingImports]
92
92
 
93
93
  self.original_repr = torch.Tensor.__repr__
94
94
  self.original_str = torch.Tensor.__str__
@@ -104,7 +104,7 @@ class torch_monkey_patch(lovely_patch):
104
104
 
105
105
  @override
106
106
  def unpatch(self):
107
- import torch
107
+ import torch # pyright: ignore[reportMissingImports]
108
108
 
109
109
  patch_to(torch.Tensor, "__repr__", self.original_repr)
110
110
  patch_to(torch.Tensor, "__str__", self.original_str)
nshutils/snoop.py CHANGED
@@ -20,19 +20,24 @@ try:
20
20
  import warnings
21
21
  from contextlib import nullcontext
22
22
 
23
- import pysnooper # type: ignore
24
- import pysnooper.utils # type: ignore
23
+ import pysnooper # pyright: ignore[reportMissingImports]
24
+ import pysnooper.utils # pyright: ignore[reportMissingImports]
25
25
 
26
26
  try:
27
- import torch # type: ignore
27
+ import torch # pyright: ignore[reportMissingImports]
28
28
  except ImportError:
29
29
  torch = None
30
30
 
31
31
  try:
32
- import numpy # type: ignore
32
+ import numpy # pyright: ignore[reportMissingImports]
33
33
  except ImportError:
34
34
  numpy = None
35
35
 
36
+ try:
37
+ import jax # pyright: ignore[reportMissingImports]
38
+ except ImportError:
39
+ jax = None
40
+
36
41
  FLOATING_POINTS = set()
37
42
  for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
38
43
  # older version of PyTorch do not have complex dtypes
@@ -48,17 +53,25 @@ try:
48
53
 
49
54
  def default_format(x):
50
55
  try:
51
- import lovely_tensors as lt # type: ignore
56
+ from .lovely import torch_repr
52
57
 
53
- return str(lt.lovely(x))
58
+ return torch_repr(x)
54
59
  except BaseException:
55
60
  return str(x.shape)
56
61
 
57
62
  def default_numpy_format(x):
58
63
  try:
59
- import lovely_numpy as lo # type: ignore
64
+ from .lovely import numpy_repr
65
+
66
+ return numpy_repr(x)
67
+ except BaseException:
68
+ return str(x.shape)
69
+
70
+ def default_jax_format(x):
71
+ try:
72
+ from .lovely import jax_repr
60
73
 
61
- return str(lo.lovely(x))
74
+ return jax_repr(x)
62
75
  except BaseException:
63
76
  return str(x.shape)
64
77
 
@@ -68,6 +81,7 @@ try:
68
81
  *args,
69
82
  tensor_format=default_format,
70
83
  numpy_format=default_numpy_format,
84
+ jax_format=default_jax_format,
71
85
  **kwargs,
72
86
  ):
73
87
  self.orig_custom_repr = (
@@ -78,6 +92,7 @@ try:
78
92
  super(TorchSnooper, self).__init__(*args, **kwargs)
79
93
  self.tensor_format = tensor_format
80
94
  self.numpy_format = numpy_format
95
+ self.jax_format = jax_format
81
96
 
82
97
  @staticmethod
83
98
  def is_return_types(x):
@@ -176,6 +191,8 @@ try:
176
191
  return self.tensor_format(x)
177
192
  if numpy is not None and isinstance(x, numpy.ndarray):
178
193
  return self.numpy_format(x)
194
+ if jax is not None and isinstance(x, jax.Array):
195
+ return self.jax_format(x)
179
196
  if self.is_return_types(x):
180
197
  return self.return_types_repr(x)
181
198
  if orig_repr_func is not repr:
nshutils/typecheck.py CHANGED
@@ -38,18 +38,18 @@ from jaxtyping._storage import get_shape_memo, shape_str
38
38
  from typing_extensions import TypeVar
39
39
 
40
40
  try:
41
- import torch # type: ignore
41
+ import torch # pyright: ignore[reportMissingImports]
42
42
  except ImportError:
43
43
  torch = None
44
44
 
45
45
  try:
46
- import np # type: ignore
46
+ import np # pyright: ignore[reportMissingImports]
47
47
  except ImportError:
48
48
  np = None
49
49
 
50
50
 
51
51
  try:
52
- import jax # type: ignore
52
+ import jax # pyright: ignore[reportMissingImports]
53
53
  except ImportError:
54
54
  jax = None
55
55
 
@@ -124,23 +124,23 @@ def _make_error_str(input: Any, t: Any) -> str:
124
124
  error_components.append(t.__instancecheck_str__(input))
125
125
  if torch is not None and torch.is_tensor(input):
126
126
  try:
127
- from lovely_tensors import lovely # type: ignore
127
+ from .lovely import torch_repr
128
128
 
129
- error_components.append(repr(lovely(input)))
129
+ error_components.append(torch_repr(input))
130
130
  except BaseException:
131
131
  error_components.append(repr(input.shape))
132
132
  elif jax is not None and isinstance(input, jax.Array):
133
133
  try:
134
- from lovely_jax import lovely # type: ignore
134
+ from .lovely import jax_repr
135
135
 
136
- error_components.append(repr(lovely(input)))
136
+ error_components.append(jax_repr(input))
137
137
  except BaseException:
138
138
  error_components.append(repr(input.shape))
139
139
  elif np is not None and isinstance(input, np.ndarray):
140
140
  try:
141
- from lovely_numpy import lovely # type: ignore
141
+ from .lovely import numpy_repr
142
142
 
143
- error_components.append(repr(lovely(input)))
143
+ error_components.append(numpy_repr(input))
144
144
  except BaseException:
145
145
  error_components.append(repr(input.shape))
146
146
  error_components.append(shape_str(get_shape_memo()))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.31.1
3
+ Version: 0.32.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -2,20 +2,20 @@ nshutils/__init__.py,sha256=AFx1d5k34MyJ2kCHQL5vrZB8GDp2nYUaIUEjszSa25I,477
2
2
  nshutils/__init__.pyi,sha256=R4TIk--jAgVyTibdgezJQTMce3HpMCNakAJeaDqA6bc,676
3
3
  nshutils/actsave/__init__.py,sha256=hAVsog9d1g3_rQN1TRslrl6sK1PhCGbjy8PPUAmJI58,203
4
4
  nshutils/actsave/_loader.py,sha256=btLSQdErpTmK6VyG8PxJrJNsztzyavSF71n4Ec3_49E,7619
5
- nshutils/actsave/_saver.py,sha256=_qkX0NZYvy31hdlyfhneac4kUNS_44XjOG0ZtKpdqrg,12720
5
+ nshutils/actsave/_saver.py,sha256=zcsuAP7JhaXy3EzBZillyiIf1EeIsoilNoDXWKvDta0,12834
6
6
  nshutils/collections.py,sha256=QWGyANmo4Efq4XRNHDSTE9tRLStwEZHGwE0ATHR-Vqo,5233
7
7
  nshutils/display.py,sha256=Ge63yllx7gi-MKL3mKQeQ5doql_nj56-o5aoTVmusDg,1473
8
8
  nshutils/logging.py,sha256=78pv3-I_gmbKSf5_mYYBr6_H4GNBGErghAdhH9wfYIc,2205
9
9
  nshutils/lovely/__init__.py,sha256=684eZOKLmSgsTcCVlWk1Ip1cxJxmz-rKeXLmWXuCEWA,487
10
10
  nshutils/lovely/_base.py,sha256=kJY-UhdFTRBlAg_YzfJmG4ICb6vSdOJKiRc6vksxvoE,4424
11
- nshutils/lovely/_monkey_patch_all.py,sha256=xq09InGcOsGDrELV_KIrhE0H4EWyMdrUZ_1_BR2e_b0,2224
11
+ nshutils/lovely/_monkey_patch_all.py,sha256=dgfqrJU1sYY7icV_yupG1gMdgTHo-Mn9EeiWVxcfq9I,2221
12
12
  nshutils/lovely/config.py,sha256=lVNMuU1oUvsYlGN0Sn-m6iOLbJIchVnWDpyHm09nWo8,1224
13
- nshutils/lovely/jax_.py,sha256=PGnv33LrEM3aLvXLBbAx4b7dOkJwONidyPZjToZ62Og,2592
13
+ nshutils/lovely/jax_.py,sha256=Hju21e2vjc2ps6tu6yXiR_GAq03S16WLUzcITrHtk60,2797
14
14
  nshutils/lovely/numpy_.py,sha256=BBP9663l4Hr-TB34xDMHQQZ1zpuOgBegUOGl7_wV6R0,3503
15
- nshutils/lovely/torch_.py,sha256=J1pDJY1zzEANqa6EaJpG1pc_SYgM8YWOo1TjWdVeiA0,2946
15
+ nshutils/lovely/torch_.py,sha256=4r46p4tTaT6cWHl1_VcHh6x3RBST4xijqF2h7f9fnVg,3151
16
16
  nshutils/lovely/utils.py,sha256=2ksT5YGVViFuWc8jSkwVCsABripJmyVJdEDDH7aab70,10459
17
- nshutils/snoop.py,sha256=7d7_Q5sJmINL1J29wcnxEvpV95zvZYNoVn5frCq-rww,7393
18
- nshutils/typecheck.py,sha256=Gi7xtfilN_UwZ1FTFqBVKDhcQzBEDonVxIv3bUj-uXY,5582
19
- nshutils-0.31.1.dist-info/METADATA,sha256=pW-XE-rF3TPtok9VPan8cG6IlaEjlX55NBL4SKaNOHQ,4406
20
- nshutils-0.31.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
21
- nshutils-0.31.1.dist-info/RECORD,,
17
+ nshutils/snoop.py,sha256=ZkqodwLOey3cAIZBK_OIV4xT7PyU10Y4L2oApi_87k8,7941
18
+ nshutils/typecheck.py,sha256=U6QDrAwxGtoQW5GYcifWJQkUicCz2iutPw1kexGPKaU,5596
19
+ nshutils-0.32.0.dist-info/METADATA,sha256=Hl3SklKnO3BbyZORJTdMxXTaBUmdnYB3KW0r02ucHM8,4406
20
+ nshutils-0.32.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
21
+ nshutils-0.32.0.dist-info/RECORD,,