nshutils 0.31.0__tar.gz → 0.32.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshutils-0.31.0 → nshutils-0.32.0}/PKG-INFO +1 -1
- {nshutils-0.31.0 → nshutils-0.32.0}/pyproject.toml +1 -1
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/actsave/_saver.py +24 -6
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/_monkey_patch_all.py +4 -3
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/jax_.py +5 -5
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/torch_.py +5 -5
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/snoop.py +25 -8
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/typecheck.py +9 -9
- {nshutils-0.31.0 → nshutils-0.32.0}/README.md +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/__init__.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/__init__.pyi +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/actsave/__init__.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/actsave/_loader.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/collections.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/display.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/logging.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/__init__.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/_base.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/config.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/numpy_.py +0 -0
- {nshutils-0.31.0 → nshutils-0.32.0}/src/nshutils/lovely/utils.py +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import contextlib
|
4
4
|
import fnmatch
|
5
|
+
import os
|
5
6
|
import tempfile
|
6
7
|
import weakref
|
7
8
|
from collections.abc import Callable, Mapping
|
@@ -19,7 +20,7 @@ from ..collections import apply_to_collection
|
|
19
20
|
|
20
21
|
if not TYPE_CHECKING:
|
21
22
|
try:
|
22
|
-
import torch #
|
23
|
+
import torch # pyright: ignore[reportMissingImports]
|
23
24
|
|
24
25
|
Tensor = torch.Tensor
|
25
26
|
_torch_installed = True
|
@@ -29,9 +30,9 @@ if not TYPE_CHECKING:
|
|
29
30
|
|
30
31
|
Tensor = Never
|
31
32
|
else:
|
32
|
-
import torch #
|
33
|
+
import torch # pyright: ignore[reportMissingImports]
|
33
34
|
|
34
|
-
Tensor = torch.Tensor
|
35
|
+
Tensor = TypeAliasType("Tensor", torch.Tensor)
|
35
36
|
_torch_installed: Literal[True] = True
|
36
37
|
|
37
38
|
log = getLogger(__name__)
|
@@ -59,7 +60,7 @@ def _to_numpy(activation: Value) -> np.ndarray:
|
|
59
60
|
return np.array(activation)
|
60
61
|
elif isinstance(activation, np.ndarray):
|
61
62
|
return activation
|
62
|
-
elif _torch_installed and isinstance(activation, Tensor):
|
63
|
+
elif _torch_installed and isinstance(activation, torch.Tensor):
|
63
64
|
activation_ = activation.detach()
|
64
65
|
if activation_.is_floating_point():
|
65
66
|
# NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
|
@@ -125,7 +126,8 @@ class Activation:
|
|
125
126
|
if activation is None:
|
126
127
|
return None
|
127
128
|
|
128
|
-
|
129
|
+
if _torch_installed:
|
130
|
+
activation = apply_to_collection(activation, Tensor, _to_numpy)
|
129
131
|
activation = _to_numpy(activation)
|
130
132
|
|
131
133
|
# Set the transformed value
|
@@ -266,7 +268,11 @@ class ActSaveProvider:
|
|
266
268
|
|
267
269
|
if save_dir is None:
|
268
270
|
save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid7str()}"
|
269
|
-
log.
|
271
|
+
log.warning(
|
272
|
+
f"ActSave: Using temporary directory {save_dir} for activations."
|
273
|
+
)
|
274
|
+
else:
|
275
|
+
log.info(f"ActSave enabled. Saving to {save_dir}")
|
270
276
|
self._saver = _Saver(save_dir, lambda: self._prefixes)
|
271
277
|
|
272
278
|
def disable(self):
|
@@ -307,6 +313,18 @@ class ActSaveProvider:
|
|
307
313
|
self._prefixes = []
|
308
314
|
self._disable_count = 0
|
309
315
|
|
316
|
+
# Check for environment variable `ACTSAVE` to automatically enable saving.
|
317
|
+
# If set to "1" or "true" (case-insensitive), activations are saved to a temporary directory.
|
318
|
+
# If set to a path, activations are saved to that path.
|
319
|
+
if env_var := os.environ.get("ACTSAVE"):
|
320
|
+
log.info(
|
321
|
+
f"`ACTSAVE={env_var}` detected, attempting to auto-enable activation saving."
|
322
|
+
)
|
323
|
+
if env_var.lower() in ("1", "true"):
|
324
|
+
self.enable()
|
325
|
+
else:
|
326
|
+
self.enable(Path(env_var))
|
327
|
+
|
310
328
|
@contextlib.contextmanager
|
311
329
|
def disabled(self, condition: bool | Callable[[], bool] = True):
|
312
330
|
"""
|
@@ -30,9 +30,10 @@ def _find_deps() -> list[Library]:
|
|
30
30
|
|
31
31
|
class monkey_patch(lovely_patch):
|
32
32
|
def __init__(self, libraries: list[Library] | Literal["auto"] = "auto"):
|
33
|
-
|
34
|
-
if self.libraries == "auto":
|
33
|
+
if libraries == "auto":
|
35
34
|
self.libraries = _find_deps()
|
35
|
+
else:
|
36
|
+
self.libraries = libraries
|
36
37
|
|
37
38
|
if not self.libraries:
|
38
39
|
raise ValueError(
|
@@ -59,7 +60,7 @@ class monkey_patch(lovely_patch):
|
|
59
60
|
|
60
61
|
self.stack.enter_context(numpy_monkey_patch())
|
61
62
|
else:
|
62
|
-
assert_never(library)
|
63
|
+
assert_never(library)
|
63
64
|
|
64
65
|
log.info(
|
65
66
|
f"Monkey patched libraries: {', '.join(self.libraries)}. "
|
@@ -9,7 +9,7 @@ from ._base import lovely_patch, lovely_repr
|
|
9
9
|
from .utils import LovelyStats, array_stats, patch_to
|
10
10
|
|
11
11
|
if TYPE_CHECKING:
|
12
|
-
import jax
|
12
|
+
import jax # pyright: ignore[reportMissingImports]
|
13
13
|
|
14
14
|
|
15
15
|
def _type_name(array: jax.Array):
|
@@ -42,7 +42,7 @@ def _dtype_str(array: jax.Array) -> str:
|
|
42
42
|
|
43
43
|
|
44
44
|
def _device(array: jax.Array) -> str:
|
45
|
-
from jaxlib.xla_extension import Device
|
45
|
+
from jaxlib.xla_extension import Device # pyright: ignore[reportMissingImports]
|
46
46
|
|
47
47
|
if callable(device := array.device):
|
48
48
|
device = device()
|
@@ -56,7 +56,7 @@ def _device(array: jax.Array) -> str:
|
|
56
56
|
|
57
57
|
@lovely_repr(dependencies=["jax"])
|
58
58
|
def jax_repr(array: jax.Array) -> LovelyStats | None:
|
59
|
-
import jax.numpy as jnp
|
59
|
+
import jax.numpy as jnp # pyright: ignore[reportMissingImports]
|
60
60
|
|
61
61
|
# For dtypes like `object` or `str`, we let the fallback repr handle it
|
62
62
|
if not jnp.issubdtype(array.dtype, jnp.number):
|
@@ -85,7 +85,7 @@ class jax_monkey_patch(lovely_patch):
|
|
85
85
|
|
86
86
|
@override
|
87
87
|
def patch(self):
|
88
|
-
from jax._src import array
|
88
|
+
from jax._src import array # pyright: ignore[reportMissingImports]
|
89
89
|
|
90
90
|
self.prev_repr = array.ArrayImpl.__repr__
|
91
91
|
self.prev_str = array.ArrayImpl.__str__
|
@@ -96,7 +96,7 @@ class jax_monkey_patch(lovely_patch):
|
|
96
96
|
|
97
97
|
@override
|
98
98
|
def unpatch(self):
|
99
|
-
from jax._src import array
|
99
|
+
from jax._src import array # pyright: ignore[reportMissingImports]
|
100
100
|
|
101
101
|
patch_to(array.ArrayImpl, "__repr__", self.prev_repr)
|
102
102
|
patch_to(array.ArrayImpl, "__str__", self.prev_str)
|
@@ -9,11 +9,11 @@ from ._base import lovely_patch, lovely_repr
|
|
9
9
|
from .utils import LovelyStats, array_stats, patch_to
|
10
10
|
|
11
11
|
if TYPE_CHECKING:
|
12
|
-
import torch
|
12
|
+
import torch # pyright: ignore[reportMissingImports]
|
13
13
|
|
14
14
|
|
15
15
|
def _type_name(tensor: torch.Tensor):
|
16
|
-
import torch
|
16
|
+
import torch # pyright: ignore[reportMissingImports]
|
17
17
|
|
18
18
|
return (
|
19
19
|
"tensor"
|
@@ -45,7 +45,7 @@ def _dtype_str(tensor: torch.Tensor) -> str:
|
|
45
45
|
|
46
46
|
|
47
47
|
def _to_np(tensor: torch.Tensor) -> np.ndarray:
|
48
|
-
import torch
|
48
|
+
import torch # pyright: ignore[reportMissingImports]
|
49
49
|
|
50
50
|
# Get tensor data as CPU NumPy array for analysis
|
51
51
|
t_cpu = tensor.detach().cpu()
|
@@ -88,7 +88,7 @@ class torch_monkey_patch(lovely_patch):
|
|
88
88
|
|
89
89
|
@override
|
90
90
|
def patch(self):
|
91
|
-
import torch
|
91
|
+
import torch # pyright: ignore[reportMissingImports]
|
92
92
|
|
93
93
|
self.original_repr = torch.Tensor.__repr__
|
94
94
|
self.original_str = torch.Tensor.__str__
|
@@ -104,7 +104,7 @@ class torch_monkey_patch(lovely_patch):
|
|
104
104
|
|
105
105
|
@override
|
106
106
|
def unpatch(self):
|
107
|
-
import torch
|
107
|
+
import torch # pyright: ignore[reportMissingImports]
|
108
108
|
|
109
109
|
patch_to(torch.Tensor, "__repr__", self.original_repr)
|
110
110
|
patch_to(torch.Tensor, "__str__", self.original_str)
|
@@ -20,19 +20,24 @@ try:
|
|
20
20
|
import warnings
|
21
21
|
from contextlib import nullcontext
|
22
22
|
|
23
|
-
import pysnooper #
|
24
|
-
import pysnooper.utils #
|
23
|
+
import pysnooper # pyright: ignore[reportMissingImports]
|
24
|
+
import pysnooper.utils # pyright: ignore[reportMissingImports]
|
25
25
|
|
26
26
|
try:
|
27
|
-
import torch #
|
27
|
+
import torch # pyright: ignore[reportMissingImports]
|
28
28
|
except ImportError:
|
29
29
|
torch = None
|
30
30
|
|
31
31
|
try:
|
32
|
-
import numpy #
|
32
|
+
import numpy # pyright: ignore[reportMissingImports]
|
33
33
|
except ImportError:
|
34
34
|
numpy = None
|
35
35
|
|
36
|
+
try:
|
37
|
+
import jax # pyright: ignore[reportMissingImports]
|
38
|
+
except ImportError:
|
39
|
+
jax = None
|
40
|
+
|
36
41
|
FLOATING_POINTS = set()
|
37
42
|
for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
|
38
43
|
# older version of PyTorch do not have complex dtypes
|
@@ -48,17 +53,25 @@ try:
|
|
48
53
|
|
49
54
|
def default_format(x):
|
50
55
|
try:
|
51
|
-
|
56
|
+
from .lovely import torch_repr
|
52
57
|
|
53
|
-
return
|
58
|
+
return torch_repr(x)
|
54
59
|
except BaseException:
|
55
60
|
return str(x.shape)
|
56
61
|
|
57
62
|
def default_numpy_format(x):
|
58
63
|
try:
|
59
|
-
|
64
|
+
from .lovely import numpy_repr
|
65
|
+
|
66
|
+
return numpy_repr(x)
|
67
|
+
except BaseException:
|
68
|
+
return str(x.shape)
|
69
|
+
|
70
|
+
def default_jax_format(x):
|
71
|
+
try:
|
72
|
+
from .lovely import jax_repr
|
60
73
|
|
61
|
-
return
|
74
|
+
return jax_repr(x)
|
62
75
|
except BaseException:
|
63
76
|
return str(x.shape)
|
64
77
|
|
@@ -68,6 +81,7 @@ try:
|
|
68
81
|
*args,
|
69
82
|
tensor_format=default_format,
|
70
83
|
numpy_format=default_numpy_format,
|
84
|
+
jax_format=default_jax_format,
|
71
85
|
**kwargs,
|
72
86
|
):
|
73
87
|
self.orig_custom_repr = (
|
@@ -78,6 +92,7 @@ try:
|
|
78
92
|
super(TorchSnooper, self).__init__(*args, **kwargs)
|
79
93
|
self.tensor_format = tensor_format
|
80
94
|
self.numpy_format = numpy_format
|
95
|
+
self.jax_format = jax_format
|
81
96
|
|
82
97
|
@staticmethod
|
83
98
|
def is_return_types(x):
|
@@ -176,6 +191,8 @@ try:
|
|
176
191
|
return self.tensor_format(x)
|
177
192
|
if numpy is not None and isinstance(x, numpy.ndarray):
|
178
193
|
return self.numpy_format(x)
|
194
|
+
if jax is not None and isinstance(x, jax.Array):
|
195
|
+
return self.jax_format(x)
|
179
196
|
if self.is_return_types(x):
|
180
197
|
return self.return_types_repr(x)
|
181
198
|
if orig_repr_func is not repr:
|
@@ -38,18 +38,18 @@ from jaxtyping._storage import get_shape_memo, shape_str
|
|
38
38
|
from typing_extensions import TypeVar
|
39
39
|
|
40
40
|
try:
|
41
|
-
import torch #
|
41
|
+
import torch # pyright: ignore[reportMissingImports]
|
42
42
|
except ImportError:
|
43
43
|
torch = None
|
44
44
|
|
45
45
|
try:
|
46
|
-
import np #
|
46
|
+
import np # pyright: ignore[reportMissingImports]
|
47
47
|
except ImportError:
|
48
48
|
np = None
|
49
49
|
|
50
50
|
|
51
51
|
try:
|
52
|
-
import jax #
|
52
|
+
import jax # pyright: ignore[reportMissingImports]
|
53
53
|
except ImportError:
|
54
54
|
jax = None
|
55
55
|
|
@@ -124,23 +124,23 @@ def _make_error_str(input: Any, t: Any) -> str:
|
|
124
124
|
error_components.append(t.__instancecheck_str__(input))
|
125
125
|
if torch is not None and torch.is_tensor(input):
|
126
126
|
try:
|
127
|
-
from
|
127
|
+
from .lovely import torch_repr
|
128
128
|
|
129
|
-
error_components.append(
|
129
|
+
error_components.append(torch_repr(input))
|
130
130
|
except BaseException:
|
131
131
|
error_components.append(repr(input.shape))
|
132
132
|
elif jax is not None and isinstance(input, jax.Array):
|
133
133
|
try:
|
134
|
-
from
|
134
|
+
from .lovely import jax_repr
|
135
135
|
|
136
|
-
error_components.append(
|
136
|
+
error_components.append(jax_repr(input))
|
137
137
|
except BaseException:
|
138
138
|
error_components.append(repr(input.shape))
|
139
139
|
elif np is not None and isinstance(input, np.ndarray):
|
140
140
|
try:
|
141
|
-
from
|
141
|
+
from .lovely import numpy_repr
|
142
142
|
|
143
|
-
error_components.append(
|
143
|
+
error_components.append(numpy_repr(input))
|
144
144
|
except BaseException:
|
145
145
|
error_components.append(repr(input.shape))
|
146
146
|
error_components.append(shape_str(get_shape_memo()))
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|