nshutils 0.22.6__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -241,6 +241,17 @@ class _Saver:
241
241
  class ActSaveProvider:
242
242
  _saver: _Saver | None = None
243
243
  _prefixes: list[str] = []
244
+ _disable_count: int = 0
245
+
246
+ @property
247
+ def is_initialized(self) -> bool:
248
+ """Returns True if ActSave.enable() has been called and not subsequently disabled."""
249
+ return self._saver is not None
250
+
251
+ @property
252
+ def is_enabled(self) -> bool:
253
+ """Returns True if ActSave is currently active and will save activations."""
254
+ return self.is_initialized and self._disable_count == 0
244
255
 
245
256
  def enable(self, save_dir: Path | None = None):
246
257
  """
@@ -294,6 +305,34 @@ class ActSaveProvider:
294
305
 
295
306
  self._saver = None
296
307
  self._prefixes = []
308
+ self._disable_count = 0
309
+
310
+ @contextlib.contextmanager
311
+ def disabled(self, condition: bool | Callable[[], bool] = True):
312
+ """
313
+ Context manager to temporarily disable activation saving.
314
+
315
+ Args:
316
+ condition (bool | Callable[[], bool], optional):
317
+ If True or a callable returning True, saving is disabled within this context.
318
+ Defaults to True.
319
+ """
320
+ if _torch_is_scripting():
321
+ yield
322
+ return
323
+
324
+ should_disable = condition() if callable(condition) else condition
325
+ if should_disable:
326
+ self._disable_count += 1
327
+
328
+ try:
329
+ yield
330
+ finally:
331
+ if should_disable:
332
+ self._disable_count -= 1
333
+ if self._disable_count < 0: # Should not happen
334
+ log.warning("ActSave disable count went below zero.")
335
+ self._disable_count = 0
297
336
 
298
337
  @contextlib.contextmanager
299
338
  def context(self, label: str):
@@ -307,7 +346,7 @@ class ActSaveProvider:
307
346
  yield
308
347
  return
309
348
 
310
- if self._saver is None:
349
+ if not self.is_enabled:
311
350
  yield
312
351
  return
313
352
 
@@ -368,9 +407,12 @@ class ActSaveProvider:
368
407
  if _torch_is_scripting():
369
408
  return
370
409
 
371
- if self._saver is None:
410
+ if not self.is_enabled:
372
411
  return
373
412
 
413
+ # Ensure _saver is not None, which is guaranteed by is_enabled but mypy needs help
414
+ assert self._saver is not None
415
+
374
416
  if acts is not None and callable(acts):
375
417
  acts = acts()
376
418
  self._saver.save(acts, **kwargs)
nshutils/lovely/_base.py CHANGED
@@ -4,8 +4,16 @@ import functools
4
4
  import importlib.util
5
5
  import logging
6
6
  from collections.abc import Callable, Iterator
7
-
8
- from typing_extensions import ParamSpec, TypeAliasType, TypeVar
7
+ from typing import Generic, Optional, cast
8
+
9
+ from typing_extensions import (
10
+ ParamSpec,
11
+ Protocol,
12
+ TypeAliasType,
13
+ TypeVar,
14
+ override,
15
+ runtime_checkable,
16
+ )
9
17
 
10
18
  from ..util import ContextResource, resource_factory_contextmanager
11
19
  from .utils import LovelyStats, format_tensor_stats
@@ -17,16 +25,26 @@ P = ParamSpec("P")
17
25
 
18
26
  LovelyStatsFn = TypeAliasType(
19
27
  "LovelyStatsFn",
20
- Callable[[TArray], LovelyStats | None],
21
- type_params=(TArray,),
22
- )
23
- LovelyReprFn = TypeAliasType(
24
- "LovelyReprFn",
25
- Callable[[TArray], str],
28
+ Callable[[TArray], Optional[LovelyStats]],
26
29
  type_params=(TArray,),
27
30
  )
28
31
 
29
32
 
33
+ @runtime_checkable
34
+ class LovelyReprFn(Protocol[TArray]):
35
+ @property
36
+ def __lovely_repr_instance__(self) -> lovely_repr[TArray]: ...
37
+
38
+ @__lovely_repr_instance__.setter
39
+ def __lovely_repr_instance__(self, value: lovely_repr[TArray]) -> None: ...
40
+
41
+ @property
42
+ def __name__(self) -> str: ...
43
+
44
+ def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None: ...
45
+ def __call__(self, value: TArray, /) -> str: ...
46
+
47
+
30
48
  def _find_missing_deps(dependencies: list[str]):
31
49
  missing_deps: list[str] = []
32
50
 
@@ -39,53 +57,61 @@ def _find_missing_deps(dependencies: list[str]):
39
57
  return missing_deps
40
58
 
41
59
 
42
- def lovely_repr(dependencies: list[str], fallback_repr: Callable[[TArray], str]):
43
- """
44
- Decorator to create a lovely representation function for an array.
45
-
46
- Args:
47
- dependencies: List of dependencies to check before running the function.
48
- If any dependency is not available, the function will not run.
49
- fallback_repr: A function that takes an array and returns its fallback representation.
50
- Returns:
51
- A decorator function that takes a function and returns a lovely representation function.
52
-
53
- Example:
54
- @lovely_repr(dependencies=["torch"])
55
- def my_array_stats(array):
56
- return {...}
57
- """
58
-
59
- def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
60
+ class lovely_repr(Generic[TArray]):
61
+ @override
62
+ def __init__(
63
+ self,
64
+ dependencies: list[str],
65
+ fallback_repr: Callable[[TArray], str] | None = None,
66
+ ):
60
67
  """
61
68
  Decorator to create a lovely representation function for an array.
62
69
 
63
70
  Args:
64
- array_stats_fn: A function that takes an array and returns its stats,
65
- or `None` if the array is not supported.
66
-
71
+ dependencies: List of dependencies to check before running the function.
72
+ If any dependency is not available, the function will not run.
73
+ fallback_repr: A function that takes an array and returns its fallback representation.
67
74
  Returns:
68
- A function that takes an array and returns its lovely representation.
75
+ A decorator function that takes a function and returns a lovely representation function.
76
+
77
+ Example:
78
+ @lovely_repr(dependencies=["torch"])
79
+ def my_array_stats(array):
80
+ return {...}
69
81
  """
82
+ super().__init__()
83
+
84
+ if fallback_repr is None:
85
+ fallback_repr = repr
86
+
87
+ self._dependencies = dependencies
88
+ self._fallback_repr = fallback_repr
70
89
 
90
+ def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None:
91
+ self._fallback_repr = repr_fn
92
+
93
+ def __call__(
94
+ self, array_stats_fn: LovelyStatsFn[TArray], /
95
+ ) -> LovelyReprFn[TArray]:
71
96
  @functools.wraps(array_stats_fn)
72
- def wrapper(array: TArray) -> str:
73
- if missing_deps := _find_missing_deps(dependencies):
97
+ def wrapper_fn(array: TArray) -> str:
98
+ if missing_deps := _find_missing_deps(self._dependencies):
74
99
  log.warning(
75
100
  f"Missing dependencies: {', '.join(missing_deps)}. "
76
101
  "Skipping lovely representation."
77
102
  )
78
- return fallback_repr(array)
103
+ return self._fallback_repr(array)
79
104
 
80
105
  if (stats := array_stats_fn(array)) is None:
81
- return fallback_repr(array)
106
+ return self._fallback_repr(array)
82
107
 
83
108
  return format_tensor_stats(stats)
84
109
 
110
+ wrapper = cast(LovelyReprFn[TArray], wrapper_fn)
111
+ wrapper.__lovely_repr_instance__ = self
112
+ wrapper.set_fallback_repr = self.set_fallback_repr
85
113
  return wrapper
86
114
 
87
- return decorator_fn
88
-
89
115
 
90
116
  LovelyMonkeyPatchInputFn = TypeAliasType(
91
117
  "LovelyMonkeyPatchInputFn",
nshutils/lovely/jax_.py CHANGED
@@ -53,7 +53,7 @@ def _device(array: jax.Array) -> str:
53
53
  return f"{device.platform}:{device.id}"
54
54
 
55
55
 
56
- @lovely_repr(dependencies=["jax"], fallback_repr=jax.Array.__repr__)
56
+ @lovely_repr(dependencies=["jax"])
57
57
  def jax_repr(array: jax.Array) -> LovelyStats | None:
58
58
  import jax.numpy as jnp
59
59
 
@@ -83,6 +83,7 @@ def jax_monkey_patch():
83
83
 
84
84
  prev_repr = array.ArrayImpl.__repr__
85
85
  prev_str = array.ArrayImpl.__str__
86
+ jax_repr.set_fallback_repr(prev_repr)
86
87
  try:
87
88
  patch_to(array.ArrayImpl, "__repr__", jax_repr)
88
89
  patch_to(array.ArrayImpl, "__str__", jax_repr)
nshutils/lovely/numpy_.py CHANGED
@@ -51,7 +51,7 @@ def _dtype_str(array: np.ndarray) -> str:
51
51
  return dtype_base
52
52
 
53
53
 
54
- @lovely_repr(dependencies=["numpy"], fallback_repr=np.array_repr)
54
+ @lovely_repr(dependencies=["numpy"])
55
55
  def numpy_repr(array: np.ndarray) -> LovelyStats | None:
56
56
  # For dtypes like `object` or `str`, we let the fallback repr handle it
57
57
  if not np.issubdtype(array.dtype, np.number):
@@ -71,6 +71,9 @@ def numpy_repr(array: np.ndarray) -> LovelyStats | None:
71
71
  }
72
72
 
73
73
 
74
+ numpy_repr.set_fallback_repr(np.array_repr)
75
+
76
+
74
77
  # If numpy 2.0, use the new API override_repr.
75
78
  if _np_ge_2():
76
79
 
nshutils/lovely/torch_.py CHANGED
@@ -59,7 +59,7 @@ def _to_np(tensor: torch.Tensor) -> np.ndarray:
59
59
  return t_np
60
60
 
61
61
 
62
- @lovely_repr(dependencies=["torch"], fallback_repr=torch.Tensor.__repr__)
62
+ @lovely_repr(dependencies=["torch"])
63
63
  def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
64
64
  return {
65
65
  # Basic attributes
@@ -87,10 +87,15 @@ def torch_monkey_patch():
87
87
  original_repr = torch.Tensor.__repr__
88
88
  original_str = torch.Tensor.__str__
89
89
  original_parameter_repr = torch.nn.Parameter.__repr__
90
+ torch_repr.set_fallback_repr(original_repr)
91
+
90
92
  try:
91
93
  patch_to(torch.Tensor, "__repr__", torch_repr)
92
94
  patch_to(torch.Tensor, "__str__", torch_repr)
93
- del torch.nn.Parameter.__repr__
95
+ try:
96
+ delattr(torch.nn.Parameter, "__repr__")
97
+ except AttributeError:
98
+ pass
94
99
 
95
100
  yield
96
101
  finally:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.22.6
3
+ Version: 0.30.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -2,21 +2,21 @@ nshutils/__init__.py,sha256=AFx1d5k34MyJ2kCHQL5vrZB8GDp2nYUaIUEjszSa25I,477
2
2
  nshutils/__init__.pyi,sha256=ICbY2_XBAlXIVOGyK4PQpatmlUFHHc5-bqM4sfFZoAY,613
3
3
  nshutils/actsave/__init__.py,sha256=hAVsog9d1g3_rQN1TRslrl6sK1PhCGbjy8PPUAmJI58,203
4
4
  nshutils/actsave/_loader.py,sha256=mof3HezUNvLliz7macstX6ewXW05L0Mtv3zJyrbmImg,4640
5
- nshutils/actsave/_saver.py,sha256=GyoJpeJIIG_Y2Hr3MTXylTM6YNyoInf2mvKs0foU7no,10528
5
+ nshutils/actsave/_saver.py,sha256=IS9TVP8WUizoj5fHrQ6hodtjidT__LDRwz5aoWHupVo,12013
6
6
  nshutils/collections.py,sha256=QWGyANmo4Efq4XRNHDSTE9tRLStwEZHGwE0ATHR-Vqo,5233
7
7
  nshutils/display.py,sha256=Ge63yllx7gi-MKL3mKQeQ5doql_nj56-o5aoTVmusDg,1473
8
8
  nshutils/logging.py,sha256=78pv3-I_gmbKSf5_mYYBr6_H4GNBGErghAdhH9wfYIc,2205
9
9
  nshutils/lovely/__init__.py,sha256=gbWMNs7xfK1CiNdkHvfH0KcyaGjdZ8_WUBGfaEUDN4I,451
10
- nshutils/lovely/_base.py,sha256=SH8MY6wbucB7us4pU1cekoU2glanutEB8vDrLDkLEL0,4544
10
+ nshutils/lovely/_base.py,sha256=-JYF2zci04PJjmkBdm_iV3uWgD_d7e5zCIAINDlQIKc,5266
11
11
  nshutils/lovely/_monkey_patch_all.py,sha256=zgMupp2Wc_O9R3arl-BAIePpvQSi6TCeshGMaui-Cc8,1986
12
12
  nshutils/lovely/config.py,sha256=lVNMuU1oUvsYlGN0Sn-m6iOLbJIchVnWDpyHm09nWo8,1224
13
- nshutils/lovely/jax_.py,sha256=J64CtP2mx_yPQ7ZiLKSrqEJfPDBBbx3j6o4zKc1f1eg,2404
14
- nshutils/lovely/numpy_.py,sha256=iaTUQpamudyWd7pTBJyj_hcS9X8H53muCw2IYY9ZcUM,3345
15
- nshutils/lovely/torch_.py,sha256=Jt0c7ysfrbgxORhL6fpbxCnfBDoYaTtTyc935yCJ7H0,2682
13
+ nshutils/lovely/jax_.py,sha256=c_hvlch_c9OZ0WJjFIeY46kKQcCELspwdmoexkKLsCg,2412
14
+ nshutils/lovely/numpy_.py,sha256=GDOOuhCYfShfKUZiuI8J91eAm27urrYyxETTR-Mxz0E,3362
15
+ nshutils/lovely/torch_.py,sha256=9diSkM1L2B6l0yQqTRBoZUVElyqgHUcJdFCXD3NvTxk,2767
16
16
  nshutils/lovely/utils.py,sha256=2ksT5YGVViFuWc8jSkwVCsABripJmyVJdEDDH7aab70,10459
17
17
  nshutils/snoop.py,sha256=7d7_Q5sJmINL1J29wcnxEvpV95zvZYNoVn5frCq-rww,7393
18
18
  nshutils/typecheck.py,sha256=Gi7xtfilN_UwZ1FTFqBVKDhcQzBEDonVxIv3bUj-uXY,5582
19
19
  nshutils/util.py,sha256=tx-XiRbOrpafV3OkJDE5IVFtzn3kN7uSZ8FkMor0H5c,2845
20
- nshutils-0.22.6.dist-info/METADATA,sha256=cclgVNUapsycUT65AJHSQXQpej_WEwwhRIl9vwrLMbY,4406
21
- nshutils-0.22.6.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
22
- nshutils-0.22.6.dist-info/RECORD,,
20
+ nshutils-0.30.1.dist-info/METADATA,sha256=iFnO4L_bpOdtCm8f-m7IvA4A8WtkWb5T07dYi-aUnzI,4406
21
+ nshutils-0.30.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
22
+ nshutils-0.30.1.dist-info/RECORD,,