nshutils 0.22.5__tar.gz → 0.30.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshutils-0.22.5 → nshutils-0.30.1}/PKG-INFO +1 -1
- {nshutils-0.22.5 → nshutils-0.30.1}/pyproject.toml +1 -1
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/actsave/_saver.py +44 -2
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/_base.py +64 -35
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/jax_.py +6 -1
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/numpy_.py +8 -1
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/torch_.py +7 -2
- {nshutils-0.22.5 → nshutils-0.30.1}/README.md +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/__init__.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/__init__.pyi +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/actsave/__init__.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/actsave/_loader.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/collections.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/display.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/logging.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/__init__.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/_monkey_patch_all.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/config.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/utils.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/snoop.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/typecheck.py +0 -0
- {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/util.py +0 -0
@@ -241,6 +241,17 @@ class _Saver:
|
|
241
241
|
class ActSaveProvider:
|
242
242
|
_saver: _Saver | None = None
|
243
243
|
_prefixes: list[str] = []
|
244
|
+
_disable_count: int = 0
|
245
|
+
|
246
|
+
@property
|
247
|
+
def is_initialized(self) -> bool:
|
248
|
+
"""Returns True if ActSave.enable() has been called and not subsequently disabled."""
|
249
|
+
return self._saver is not None
|
250
|
+
|
251
|
+
@property
|
252
|
+
def is_enabled(self) -> bool:
|
253
|
+
"""Returns True if ActSave is currently active and will save activations."""
|
254
|
+
return self.is_initialized and self._disable_count == 0
|
244
255
|
|
245
256
|
def enable(self, save_dir: Path | None = None):
|
246
257
|
"""
|
@@ -294,6 +305,34 @@ class ActSaveProvider:
|
|
294
305
|
|
295
306
|
self._saver = None
|
296
307
|
self._prefixes = []
|
308
|
+
self._disable_count = 0
|
309
|
+
|
310
|
+
@contextlib.contextmanager
|
311
|
+
def disabled(self, condition: bool | Callable[[], bool] = True):
|
312
|
+
"""
|
313
|
+
Context manager to temporarily disable activation saving.
|
314
|
+
|
315
|
+
Args:
|
316
|
+
condition (bool | Callable[[], bool], optional):
|
317
|
+
If True or a callable returning True, saving is disabled within this context.
|
318
|
+
Defaults to True.
|
319
|
+
"""
|
320
|
+
if _torch_is_scripting():
|
321
|
+
yield
|
322
|
+
return
|
323
|
+
|
324
|
+
should_disable = condition() if callable(condition) else condition
|
325
|
+
if should_disable:
|
326
|
+
self._disable_count += 1
|
327
|
+
|
328
|
+
try:
|
329
|
+
yield
|
330
|
+
finally:
|
331
|
+
if should_disable:
|
332
|
+
self._disable_count -= 1
|
333
|
+
if self._disable_count < 0: # Should not happen
|
334
|
+
log.warning("ActSave disable count went below zero.")
|
335
|
+
self._disable_count = 0
|
297
336
|
|
298
337
|
@contextlib.contextmanager
|
299
338
|
def context(self, label: str):
|
@@ -307,7 +346,7 @@ class ActSaveProvider:
|
|
307
346
|
yield
|
308
347
|
return
|
309
348
|
|
310
|
-
if self.
|
349
|
+
if not self.is_enabled:
|
311
350
|
yield
|
312
351
|
return
|
313
352
|
|
@@ -368,9 +407,12 @@ class ActSaveProvider:
|
|
368
407
|
if _torch_is_scripting():
|
369
408
|
return
|
370
409
|
|
371
|
-
if self.
|
410
|
+
if not self.is_enabled:
|
372
411
|
return
|
373
412
|
|
413
|
+
# Ensure _saver is not None, which is guaranteed by is_enabled but mypy needs help
|
414
|
+
assert self._saver is not None
|
415
|
+
|
374
416
|
if acts is not None and callable(acts):
|
375
417
|
acts = acts()
|
376
418
|
self._saver.save(acts, **kwargs)
|
@@ -4,8 +4,16 @@ import functools
|
|
4
4
|
import importlib.util
|
5
5
|
import logging
|
6
6
|
from collections.abc import Callable, Iterator
|
7
|
-
|
8
|
-
|
7
|
+
from typing import Generic, Optional, cast
|
8
|
+
|
9
|
+
from typing_extensions import (
|
10
|
+
ParamSpec,
|
11
|
+
Protocol,
|
12
|
+
TypeAliasType,
|
13
|
+
TypeVar,
|
14
|
+
override,
|
15
|
+
runtime_checkable,
|
16
|
+
)
|
9
17
|
|
10
18
|
from ..util import ContextResource, resource_factory_contextmanager
|
11
19
|
from .utils import LovelyStats, format_tensor_stats
|
@@ -17,16 +25,26 @@ P = ParamSpec("P")
|
|
17
25
|
|
18
26
|
LovelyStatsFn = TypeAliasType(
|
19
27
|
"LovelyStatsFn",
|
20
|
-
Callable[[TArray], LovelyStats],
|
21
|
-
type_params=(TArray,),
|
22
|
-
)
|
23
|
-
LovelyReprFn = TypeAliasType(
|
24
|
-
"LovelyReprFn",
|
25
|
-
Callable[[TArray], str],
|
28
|
+
Callable[[TArray], Optional[LovelyStats]],
|
26
29
|
type_params=(TArray,),
|
27
30
|
)
|
28
31
|
|
29
32
|
|
33
|
+
@runtime_checkable
|
34
|
+
class LovelyReprFn(Protocol[TArray]):
|
35
|
+
@property
|
36
|
+
def __lovely_repr_instance__(self) -> lovely_repr[TArray]: ...
|
37
|
+
|
38
|
+
@__lovely_repr_instance__.setter
|
39
|
+
def __lovely_repr_instance__(self, value: lovely_repr[TArray]) -> None: ...
|
40
|
+
|
41
|
+
@property
|
42
|
+
def __name__(self) -> str: ...
|
43
|
+
|
44
|
+
def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None: ...
|
45
|
+
def __call__(self, value: TArray, /) -> str: ...
|
46
|
+
|
47
|
+
|
30
48
|
def _find_missing_deps(dependencies: list[str]):
|
31
49
|
missing_deps: list[str] = []
|
32
50
|
|
@@ -39,50 +57,61 @@ def _find_missing_deps(dependencies: list[str]):
|
|
39
57
|
return missing_deps
|
40
58
|
|
41
59
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
Returns:
|
51
|
-
A decorator function that takes a function and returns a lovely representation function.
|
52
|
-
|
53
|
-
Example:
|
54
|
-
@lovely_repr(dependencies=["torch"])
|
55
|
-
def my_array_stats(array):
|
56
|
-
return {...}
|
57
|
-
"""
|
58
|
-
|
59
|
-
def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
|
60
|
+
class lovely_repr(Generic[TArray]):
|
61
|
+
@override
|
62
|
+
def __init__(
|
63
|
+
self,
|
64
|
+
dependencies: list[str],
|
65
|
+
fallback_repr: Callable[[TArray], str] | None = None,
|
66
|
+
):
|
60
67
|
"""
|
61
68
|
Decorator to create a lovely representation function for an array.
|
62
69
|
|
63
70
|
Args:
|
64
|
-
|
65
|
-
|
71
|
+
dependencies: List of dependencies to check before running the function.
|
72
|
+
If any dependency is not available, the function will not run.
|
73
|
+
fallback_repr: A function that takes an array and returns its fallback representation.
|
66
74
|
Returns:
|
67
|
-
A function that takes
|
75
|
+
A decorator function that takes a function and returns a lovely representation function.
|
76
|
+
|
77
|
+
Example:
|
78
|
+
@lovely_repr(dependencies=["torch"])
|
79
|
+
def my_array_stats(array):
|
80
|
+
return {...}
|
68
81
|
"""
|
82
|
+
super().__init__()
|
83
|
+
|
84
|
+
if fallback_repr is None:
|
85
|
+
fallback_repr = repr
|
86
|
+
|
87
|
+
self._dependencies = dependencies
|
88
|
+
self._fallback_repr = fallback_repr
|
89
|
+
|
90
|
+
def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None:
|
91
|
+
self._fallback_repr = repr_fn
|
69
92
|
|
93
|
+
def __call__(
|
94
|
+
self, array_stats_fn: LovelyStatsFn[TArray], /
|
95
|
+
) -> LovelyReprFn[TArray]:
|
70
96
|
@functools.wraps(array_stats_fn)
|
71
|
-
def
|
72
|
-
if missing_deps := _find_missing_deps(
|
97
|
+
def wrapper_fn(array: TArray) -> str:
|
98
|
+
if missing_deps := _find_missing_deps(self._dependencies):
|
73
99
|
log.warning(
|
74
100
|
f"Missing dependencies: {', '.join(missing_deps)}. "
|
75
101
|
"Skipping lovely representation."
|
76
102
|
)
|
77
|
-
return
|
103
|
+
return self._fallback_repr(array)
|
104
|
+
|
105
|
+
if (stats := array_stats_fn(array)) is None:
|
106
|
+
return self._fallback_repr(array)
|
78
107
|
|
79
|
-
stats = array_stats_fn(array)
|
80
108
|
return format_tensor_stats(stats)
|
81
109
|
|
110
|
+
wrapper = cast(LovelyReprFn[TArray], wrapper_fn)
|
111
|
+
wrapper.__lovely_repr_instance__ = self
|
112
|
+
wrapper.set_fallback_repr = self.set_fallback_repr
|
82
113
|
return wrapper
|
83
114
|
|
84
|
-
return decorator_fn
|
85
|
-
|
86
115
|
|
87
116
|
LovelyMonkeyPatchInputFn = TypeAliasType(
|
88
117
|
"LovelyMonkeyPatchInputFn",
|
@@ -54,9 +54,13 @@ def _device(array: jax.Array) -> str:
|
|
54
54
|
|
55
55
|
|
56
56
|
@lovely_repr(dependencies=["jax"])
|
57
|
-
def jax_repr(array: jax.Array) -> LovelyStats:
|
57
|
+
def jax_repr(array: jax.Array) -> LovelyStats | None:
|
58
58
|
import jax.numpy as jnp
|
59
59
|
|
60
|
+
# For dtypes like `object` or `str`, we let the fallback repr handle it
|
61
|
+
if not jnp.issubdtype(array.dtype, jnp.number):
|
62
|
+
return None
|
63
|
+
|
60
64
|
return {
|
61
65
|
# Basic attributes
|
62
66
|
"shape": array.shape,
|
@@ -79,6 +83,7 @@ def jax_monkey_patch():
|
|
79
83
|
|
80
84
|
prev_repr = array.ArrayImpl.__repr__
|
81
85
|
prev_str = array.ArrayImpl.__str__
|
86
|
+
jax_repr.set_fallback_repr(prev_repr)
|
82
87
|
try:
|
83
88
|
patch_to(array.ArrayImpl, "__repr__", jax_repr)
|
84
89
|
patch_to(array.ArrayImpl, "__str__", jax_repr)
|
@@ -52,7 +52,11 @@ def _dtype_str(array: np.ndarray) -> str:
|
|
52
52
|
|
53
53
|
|
54
54
|
@lovely_repr(dependencies=["numpy"])
|
55
|
-
def numpy_repr(array: np.ndarray) -> LovelyStats:
|
55
|
+
def numpy_repr(array: np.ndarray) -> LovelyStats | None:
|
56
|
+
# For dtypes like `object` or `str`, we let the fallback repr handle it
|
57
|
+
if not np.issubdtype(array.dtype, np.number):
|
58
|
+
return None
|
59
|
+
|
56
60
|
return {
|
57
61
|
# Basic attributes
|
58
62
|
"shape": array.shape,
|
@@ -67,6 +71,9 @@ def numpy_repr(array: np.ndarray) -> LovelyStats:
|
|
67
71
|
}
|
68
72
|
|
69
73
|
|
74
|
+
numpy_repr.set_fallback_repr(np.array_repr)
|
75
|
+
|
76
|
+
|
70
77
|
# If numpy 2.0, use the new API override_repr.
|
71
78
|
if _np_ge_2():
|
72
79
|
|
@@ -60,7 +60,7 @@ def _to_np(tensor: torch.Tensor) -> np.ndarray:
|
|
60
60
|
|
61
61
|
|
62
62
|
@lovely_repr(dependencies=["torch"])
|
63
|
-
def torch_repr(tensor: torch.Tensor) -> LovelyStats:
|
63
|
+
def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
|
64
64
|
return {
|
65
65
|
# Basic attributes
|
66
66
|
"shape": tensor.shape,
|
@@ -87,10 +87,15 @@ def torch_monkey_patch():
|
|
87
87
|
original_repr = torch.Tensor.__repr__
|
88
88
|
original_str = torch.Tensor.__str__
|
89
89
|
original_parameter_repr = torch.nn.Parameter.__repr__
|
90
|
+
torch_repr.set_fallback_repr(original_repr)
|
91
|
+
|
90
92
|
try:
|
91
93
|
patch_to(torch.Tensor, "__repr__", torch_repr)
|
92
94
|
patch_to(torch.Tensor, "__str__", torch_repr)
|
93
|
-
|
95
|
+
try:
|
96
|
+
delattr(torch.nn.Parameter, "__repr__")
|
97
|
+
except AttributeError:
|
98
|
+
pass
|
94
99
|
|
95
100
|
yield
|
96
101
|
finally:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|