nshutils-0.22.5.tar.gz → nshutils-0.30.1.tar.gz

This diff compares the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.
Files changed (22)
  1. {nshutils-0.22.5 → nshutils-0.30.1}/PKG-INFO +1 -1
  2. {nshutils-0.22.5 → nshutils-0.30.1}/pyproject.toml +1 -1
  3. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/actsave/_saver.py +44 -2
  4. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/_base.py +64 -35
  5. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/jax_.py +6 -1
  6. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/numpy_.py +8 -1
  7. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/torch_.py +7 -2
  8. {nshutils-0.22.5 → nshutils-0.30.1}/README.md +0 -0
  9. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/__init__.py +0 -0
  10. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/__init__.pyi +0 -0
  11. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/actsave/__init__.py +0 -0
  12. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/actsave/_loader.py +0 -0
  13. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/collections.py +0 -0
  14. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/display.py +0 -0
  15. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/logging.py +0 -0
  16. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/__init__.py +0 -0
  17. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/_monkey_patch_all.py +0 -0
  18. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/config.py +0 -0
  19. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/lovely/utils.py +0 -0
  20. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/snoop.py +0 -0
  21. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/typecheck.py +0 -0
  22. {nshutils-0.22.5 → nshutils-0.30.1}/src/nshutils/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.22.5
3
+ Version: 0.30.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "nshutils"
3
- version = "0.22.5"
3
+ version = "0.30.1"
4
4
  description = ""
5
5
  authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
6
6
  requires-python = ">=3.9,<4.0"
@@ -241,6 +241,17 @@ class _Saver:
241
241
  class ActSaveProvider:
242
242
  _saver: _Saver | None = None
243
243
  _prefixes: list[str] = []
244
+ _disable_count: int = 0
245
+
246
+ @property
247
+ def is_initialized(self) -> bool:
248
+ """Returns True if ActSave.enable() has been called and not subsequently disabled."""
249
+ return self._saver is not None
250
+
251
+ @property
252
+ def is_enabled(self) -> bool:
253
+ """Returns True if ActSave is currently active and will save activations."""
254
+ return self.is_initialized and self._disable_count == 0
244
255
 
245
256
  def enable(self, save_dir: Path | None = None):
246
257
  """
@@ -294,6 +305,34 @@ class ActSaveProvider:
294
305
 
295
306
  self._saver = None
296
307
  self._prefixes = []
308
+ self._disable_count = 0
309
+
310
+ @contextlib.contextmanager
311
+ def disabled(self, condition: bool | Callable[[], bool] = True):
312
+ """
313
+ Context manager to temporarily disable activation saving.
314
+
315
+ Args:
316
+ condition (bool | Callable[[], bool], optional):
317
+ If True or a callable returning True, saving is disabled within this context.
318
+ Defaults to True.
319
+ """
320
+ if _torch_is_scripting():
321
+ yield
322
+ return
323
+
324
+ should_disable = condition() if callable(condition) else condition
325
+ if should_disable:
326
+ self._disable_count += 1
327
+
328
+ try:
329
+ yield
330
+ finally:
331
+ if should_disable:
332
+ self._disable_count -= 1
333
+ if self._disable_count < 0: # Should not happen
334
+ log.warning("ActSave disable count went below zero.")
335
+ self._disable_count = 0
297
336
 
298
337
  @contextlib.contextmanager
299
338
  def context(self, label: str):
@@ -307,7 +346,7 @@ class ActSaveProvider:
307
346
  yield
308
347
  return
309
348
 
310
- if self._saver is None:
349
+ if not self.is_enabled:
311
350
  yield
312
351
  return
313
352
 
@@ -368,9 +407,12 @@ class ActSaveProvider:
368
407
  if _torch_is_scripting():
369
408
  return
370
409
 
371
- if self._saver is None:
410
+ if not self.is_enabled:
372
411
  return
373
412
 
413
+ # Ensure _saver is not None, which is guaranteed by is_enabled but mypy needs help
414
+ assert self._saver is not None
415
+
374
416
  if acts is not None and callable(acts):
375
417
  acts = acts()
376
418
  self._saver.save(acts, **kwargs)
@@ -4,8 +4,16 @@ import functools
4
4
  import importlib.util
5
5
  import logging
6
6
  from collections.abc import Callable, Iterator
7
-
8
- from typing_extensions import ParamSpec, TypeAliasType, TypeVar
7
+ from typing import Generic, Optional, cast
8
+
9
+ from typing_extensions import (
10
+ ParamSpec,
11
+ Protocol,
12
+ TypeAliasType,
13
+ TypeVar,
14
+ override,
15
+ runtime_checkable,
16
+ )
9
17
 
10
18
  from ..util import ContextResource, resource_factory_contextmanager
11
19
  from .utils import LovelyStats, format_tensor_stats
@@ -17,16 +25,26 @@ P = ParamSpec("P")
17
25
 
18
26
  LovelyStatsFn = TypeAliasType(
19
27
  "LovelyStatsFn",
20
- Callable[[TArray], LovelyStats],
21
- type_params=(TArray,),
22
- )
23
- LovelyReprFn = TypeAliasType(
24
- "LovelyReprFn",
25
- Callable[[TArray], str],
28
+ Callable[[TArray], Optional[LovelyStats]],
26
29
  type_params=(TArray,),
27
30
  )
28
31
 
29
32
 
33
+ @runtime_checkable
34
+ class LovelyReprFn(Protocol[TArray]):
35
+ @property
36
+ def __lovely_repr_instance__(self) -> lovely_repr[TArray]: ...
37
+
38
+ @__lovely_repr_instance__.setter
39
+ def __lovely_repr_instance__(self, value: lovely_repr[TArray]) -> None: ...
40
+
41
+ @property
42
+ def __name__(self) -> str: ...
43
+
44
+ def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None: ...
45
+ def __call__(self, value: TArray, /) -> str: ...
46
+
47
+
30
48
  def _find_missing_deps(dependencies: list[str]):
31
49
  missing_deps: list[str] = []
32
50
 
@@ -39,50 +57,61 @@ def _find_missing_deps(dependencies: list[str]):
39
57
  return missing_deps
40
58
 
41
59
 
42
- def lovely_repr(dependencies: list[str]):
43
- """
44
- Decorator to create a lovely representation function for an array.
45
-
46
- Args:
47
- dependencies: List of dependencies to check before running the function.
48
- If any dependency is not available, the function will not run.
49
-
50
- Returns:
51
- A decorator function that takes a function and returns a lovely representation function.
52
-
53
- Example:
54
- @lovely_repr(dependencies=["torch"])
55
- def my_array_stats(array):
56
- return {...}
57
- """
58
-
59
- def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
60
+ class lovely_repr(Generic[TArray]):
61
+ @override
62
+ def __init__(
63
+ self,
64
+ dependencies: list[str],
65
+ fallback_repr: Callable[[TArray], str] | None = None,
66
+ ):
60
67
  """
61
68
  Decorator to create a lovely representation function for an array.
62
69
 
63
70
  Args:
64
- array_stats_fn: A function that takes an array and returns its stats.
65
-
71
+ dependencies: List of dependencies to check before running the function.
72
+ If any dependency is not available, the function will not run.
73
+ fallback_repr: A function that takes an array and returns its fallback representation.
66
74
  Returns:
67
- A function that takes an array and returns its lovely representation.
75
+ A decorator function that takes a function and returns a lovely representation function.
76
+
77
+ Example:
78
+ @lovely_repr(dependencies=["torch"])
79
+ def my_array_stats(array):
80
+ return {...}
68
81
  """
82
+ super().__init__()
83
+
84
+ if fallback_repr is None:
85
+ fallback_repr = repr
86
+
87
+ self._dependencies = dependencies
88
+ self._fallback_repr = fallback_repr
89
+
90
+ def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None:
91
+ self._fallback_repr = repr_fn
69
92
 
93
+ def __call__(
94
+ self, array_stats_fn: LovelyStatsFn[TArray], /
95
+ ) -> LovelyReprFn[TArray]:
70
96
  @functools.wraps(array_stats_fn)
71
- def wrapper(array: TArray) -> str:
72
- if missing_deps := _find_missing_deps(dependencies):
97
+ def wrapper_fn(array: TArray) -> str:
98
+ if missing_deps := _find_missing_deps(self._dependencies):
73
99
  log.warning(
74
100
  f"Missing dependencies: {', '.join(missing_deps)}. "
75
101
  "Skipping lovely representation."
76
102
  )
77
- return repr(array)
103
+ return self._fallback_repr(array)
104
+
105
+ if (stats := array_stats_fn(array)) is None:
106
+ return self._fallback_repr(array)
78
107
 
79
- stats = array_stats_fn(array)
80
108
  return format_tensor_stats(stats)
81
109
 
110
+ wrapper = cast(LovelyReprFn[TArray], wrapper_fn)
111
+ wrapper.__lovely_repr_instance__ = self
112
+ wrapper.set_fallback_repr = self.set_fallback_repr
82
113
  return wrapper
83
114
 
84
- return decorator_fn
85
-
86
115
 
87
116
  LovelyMonkeyPatchInputFn = TypeAliasType(
88
117
  "LovelyMonkeyPatchInputFn",
@@ -54,9 +54,13 @@ def _device(array: jax.Array) -> str:
54
54
 
55
55
 
56
56
  @lovely_repr(dependencies=["jax"])
57
- def jax_repr(array: jax.Array) -> LovelyStats:
57
+ def jax_repr(array: jax.Array) -> LovelyStats | None:
58
58
  import jax.numpy as jnp
59
59
 
60
+ # For dtypes like `object` or `str`, we let the fallback repr handle it
61
+ if not jnp.issubdtype(array.dtype, jnp.number):
62
+ return None
63
+
60
64
  return {
61
65
  # Basic attributes
62
66
  "shape": array.shape,
@@ -79,6 +83,7 @@ def jax_monkey_patch():
79
83
 
80
84
  prev_repr = array.ArrayImpl.__repr__
81
85
  prev_str = array.ArrayImpl.__str__
86
+ jax_repr.set_fallback_repr(prev_repr)
82
87
  try:
83
88
  patch_to(array.ArrayImpl, "__repr__", jax_repr)
84
89
  patch_to(array.ArrayImpl, "__str__", jax_repr)
@@ -52,7 +52,11 @@ def _dtype_str(array: np.ndarray) -> str:
52
52
 
53
53
 
54
54
  @lovely_repr(dependencies=["numpy"])
55
- def numpy_repr(array: np.ndarray) -> LovelyStats:
55
+ def numpy_repr(array: np.ndarray) -> LovelyStats | None:
56
+ # For dtypes like `object` or `str`, we let the fallback repr handle it
57
+ if not np.issubdtype(array.dtype, np.number):
58
+ return None
59
+
56
60
  return {
57
61
  # Basic attributes
58
62
  "shape": array.shape,
@@ -67,6 +71,9 @@ def numpy_repr(array: np.ndarray) -> LovelyStats:
67
71
  }
68
72
 
69
73
 
74
+ numpy_repr.set_fallback_repr(np.array_repr)
75
+
76
+
70
77
  # If numpy 2.0, use the new API override_repr.
71
78
  if _np_ge_2():
72
79
 
@@ -60,7 +60,7 @@ def _to_np(tensor: torch.Tensor) -> np.ndarray:
60
60
 
61
61
 
62
62
  @lovely_repr(dependencies=["torch"])
63
- def torch_repr(tensor: torch.Tensor) -> LovelyStats:
63
+ def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
64
64
  return {
65
65
  # Basic attributes
66
66
  "shape": tensor.shape,
@@ -87,10 +87,15 @@ def torch_monkey_patch():
87
87
  original_repr = torch.Tensor.__repr__
88
88
  original_str = torch.Tensor.__str__
89
89
  original_parameter_repr = torch.nn.Parameter.__repr__
90
+ torch_repr.set_fallback_repr(original_repr)
91
+
90
92
  try:
91
93
  patch_to(torch.Tensor, "__repr__", torch_repr)
92
94
  patch_to(torch.Tensor, "__str__", torch_repr)
93
- del torch.nn.Parameter.__repr__
95
+ try:
96
+ delattr(torch.nn.Parameter, "__repr__")
97
+ except AttributeError:
98
+ pass
94
99
 
95
100
  yield
96
101
  finally:
File without changes