nshutils 0.22.6__tar.gz → 0.31.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {nshutils-0.22.6 → nshutils-0.31.0}/PKG-INFO +1 -1
  2. {nshutils-0.22.6 → nshutils-0.31.0}/pyproject.toml +9 -1
  3. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/__init__.pyi +1 -0
  4. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/actsave/_loader.py +73 -3
  5. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/actsave/_saver.py +44 -2
  6. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/__init__.py +2 -0
  7. nshutils-0.31.0/src/nshutils/lovely/_base.py +160 -0
  8. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/_monkey_patch_all.py +28 -20
  9. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/jax_.py +21 -12
  10. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/numpy_.py +26 -23
  11. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/torch_.py +28 -16
  12. nshutils-0.22.6/src/nshutils/lovely/_base.py +0 -158
  13. nshutils-0.22.6/src/nshutils/util.py +0 -92
  14. {nshutils-0.22.6 → nshutils-0.31.0}/README.md +0 -0
  15. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/__init__.py +0 -0
  16. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/actsave/__init__.py +0 -0
  17. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/collections.py +0 -0
  18. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/display.py +0 -0
  19. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/logging.py +0 -0
  20. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/config.py +0 -0
  21. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/utils.py +0 -0
  22. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/snoop.py +0 -0
  23. {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/typecheck.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.22.6
3
+ Version: 0.31.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "nshutils"
3
- version = "0.22.6"
3
+ version = "0.31.0"
4
4
  description = ""
5
5
  authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
6
6
  requires-python = ">=3.9,<4.0"
@@ -28,6 +28,8 @@ basedpyright = "*"
28
28
  ruff = "*"
29
29
  ipykernel = "*"
30
30
  ipywidgets = "*"
31
+ pytest = "*"
32
+ pytest-cov = "*"
31
33
 
32
34
  [build-system]
33
35
  requires = ["poetry-core"]
@@ -47,3 +49,9 @@ ignore = ["F722", "F821", "E731", "E741"]
47
49
 
48
50
  [tool.ruff.lint.isort]
49
51
  required-imports = ["from __future__ import annotations"]
52
+
53
+
54
+ [tool.pytest.ini_options]
55
+ testpaths = ["tests"]
56
+ python_files = ["test_*.py"]
57
+ addopts = "--cov=nshutils --cov-report=term-missing"
@@ -7,6 +7,7 @@ from .collections import apply_to_collection as apply_to_collection
7
7
  from .display import display as display
8
8
  from .logging import init_python_logging as init_python_logging
9
9
  from .logging import setup_logging as setup_logging
10
+ from .lovely import lovely_monkey_patch as lovely_monkey_patch
10
11
  from .snoop import snoop as snoop
11
12
  from .typecheck import tassert as tassert
12
13
  from .typecheck import typecheck_modules as typecheck_modules
@@ -105,14 +105,41 @@ class ActLoad:
105
105
  path, _ = max(all_versions, key=lambda p: p[1])
106
106
  return cls(path)
107
107
 
108
- def __init__(self, dir: Path):
108
+ def __init__(
109
+ self,
110
+ dir: Path | None = None,
111
+ *,
112
+ _base_activations: dict[str, LoadedActivation] | None = None,
113
+ _prefix_chain: list[str] | None = None,
114
+ ):
115
+ """Initialize ActLoad from a directory or from filtered activations.
116
+
117
+ Args:
118
+ dir: Path to the activation directory. Required for root ActLoad instances.
119
+ _base_activations: Pre-filtered activations dict. Used internally for prefix filtering.
120
+ _prefix_chain: Chain of prefixes that have been applied. Used for repr.
121
+ """
109
122
  self._dir = dir
123
+ self._base_activations = _base_activations
124
+ self._prefix_chain = _prefix_chain or []
110
125
 
111
126
  def activation(self, name: str):
112
- return LoadedActivation(self._dir, name)
127
+ if self._dir is None:
128
+ raise ValueError(
129
+ "Cannot create activation from filtered ActLoad without base directory"
130
+ )
131
+ # For filtered instances, we need to reconstruct the full name
132
+ full_name = "".join(self._prefix_chain) + name
133
+ return LoadedActivation(self._dir, full_name)
113
134
 
114
135
  @cached_property
115
136
  def activations(self):
137
+ if self._base_activations is not None:
138
+ return self._base_activations
139
+
140
+ if self._dir is None:
141
+ raise ValueError("ActLoad requires either dir or _base_activations")
142
+
116
143
  dirs = list(self._dir.iterdir())
117
144
  # Sort the dirs by the last modified time
118
145
  dirs.sort(key=lambda p: p.stat().st_mtime)
@@ -128,6 +155,9 @@ class ActLoad:
128
155
  def __len__(self):
129
156
  return len(self.activations)
130
157
 
158
+ def _ipython_key_completions_(self):
159
+ return list(self.activations.keys())
160
+
131
161
  @override
132
162
  def __repr__(self):
133
163
  acts_str = pprint.pformat(
@@ -137,10 +167,50 @@ class ActLoad:
137
167
  }
138
168
  )
139
169
  acts_str = acts_str.replace("'<", "<").replace(">'", ">")
140
- return f"ActLoad({acts_str})"
170
+
171
+ if self._prefix_chain:
172
+ prefix_str = "".join(self._prefix_chain)
173
+ return f"ActLoad(prefix='{prefix_str}', {acts_str})"
174
+ else:
175
+ return f"ActLoad({acts_str})"
141
176
 
142
177
  def get(self, name: str, /, default: T) -> LoadedActivation | T:
143
178
  return self.activations.get(name, default)
144
179
 
180
+ def filter_by_prefix(self, prefix: str) -> ActLoad:
181
+ """Create a filtered view of activations that match the given prefix.
182
+
183
+ Args:
184
+ prefix: The prefix to filter by. Only activations whose names start
185
+ with this prefix will be included in the filtered view.
186
+
187
+ Returns:
188
+ A new ActLoad instance that provides access to matching activations
189
+ with the prefix stripped from their keys. Can be chained multiple times.
190
+
191
+ Example:
192
+ >>> loader = ActLoad(some_dir)
193
+ >>> # If loader has keys "my.activation.first", "my.activation.second", "other.key"
194
+ >>> filtered = loader.filter_by_prefix("my.activation.")
195
+ >>> filtered["first"] # Accesses "my.activation.first"
196
+ >>> filtered["second"] # Accesses "my.activation.second"
197
+ >>> # Can be chained:
198
+ >>> double_filtered = loader.filter_by_prefix("my.").filter_by_prefix("activation.")
199
+ >>> double_filtered["first"] # Also accesses "my.activation.first"
200
+ """
201
+ filtered_activations = {}
202
+ for name, activation in self.activations.items():
203
+ if name.startswith(prefix):
204
+ # Strip the prefix from the key
205
+ stripped_name = name[len(prefix) :]
206
+ filtered_activations[stripped_name] = activation
207
+
208
+ new_prefix_chain = self._prefix_chain + [prefix]
209
+ return ActLoad(
210
+ _base_activations=filtered_activations,
211
+ _prefix_chain=new_prefix_chain,
212
+ dir=self._dir,
213
+ )
214
+
145
215
 
146
216
  ActivationLoader = ActLoad
@@ -241,6 +241,17 @@ class _Saver:
241
241
  class ActSaveProvider:
242
242
  _saver: _Saver | None = None
243
243
  _prefixes: list[str] = []
244
+ _disable_count: int = 0
245
+
246
+ @property
247
+ def is_initialized(self) -> bool:
248
+ """Returns True if ActSave.enable() has been called and not subsequently disabled."""
249
+ return self._saver is not None
250
+
251
+ @property
252
+ def is_enabled(self) -> bool:
253
+ """Returns True if ActSave is currently active and will save activations."""
254
+ return self.is_initialized and self._disable_count == 0
244
255
 
245
256
  def enable(self, save_dir: Path | None = None):
246
257
  """
@@ -294,6 +305,34 @@ class ActSaveProvider:
294
305
 
295
306
  self._saver = None
296
307
  self._prefixes = []
308
+ self._disable_count = 0
309
+
310
+ @contextlib.contextmanager
311
+ def disabled(self, condition: bool | Callable[[], bool] = True):
312
+ """
313
+ Context manager to temporarily disable activation saving.
314
+
315
+ Args:
316
+ condition (bool | Callable[[], bool], optional):
317
+ If True or a callable returning True, saving is disabled within this context.
318
+ Defaults to True.
319
+ """
320
+ if _torch_is_scripting():
321
+ yield
322
+ return
323
+
324
+ should_disable = condition() if callable(condition) else condition
325
+ if should_disable:
326
+ self._disable_count += 1
327
+
328
+ try:
329
+ yield
330
+ finally:
331
+ if should_disable:
332
+ self._disable_count -= 1
333
+ if self._disable_count < 0: # Should not happen
334
+ log.warning("ActSave disable count went below zero.")
335
+ self._disable_count = 0
297
336
 
298
337
  @contextlib.contextmanager
299
338
  def context(self, label: str):
@@ -307,7 +346,7 @@ class ActSaveProvider:
307
346
  yield
308
347
  return
309
348
 
310
- if self._saver is None:
349
+ if not self.is_enabled:
311
350
  yield
312
351
  return
313
352
 
@@ -368,9 +407,12 @@ class ActSaveProvider:
368
407
  if _torch_is_scripting():
369
408
  return
370
409
 
371
- if self._saver is None:
410
+ if not self.is_enabled:
372
411
  return
373
412
 
413
+ # Ensure _saver is not None, which is guaranteed by is_enabled but mypy needs help
414
+ assert self._saver is not None
415
+
374
416
  if acts is not None and callable(acts):
375
417
  acts = acts()
376
418
  self._saver.save(acts, **kwargs)
@@ -8,3 +8,5 @@ from .numpy_ import numpy_monkey_patch as numpy_monkey_patch
8
8
  from .numpy_ import numpy_repr as numpy_repr
9
9
  from .torch_ import torch_monkey_patch as torch_monkey_patch
10
10
  from .torch_ import torch_repr as torch_repr
11
+
12
+ lovely_monkey_patch = monkey_patch
@@ -0,0 +1,160 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import functools
5
+ import importlib.util
6
+ import logging
7
+ from abc import ABC, abstractmethod
8
+ from collections.abc import Callable
9
+ from typing import Any, Generic, Optional, cast
10
+
11
+ from typing_extensions import (
12
+ ParamSpec,
13
+ Protocol,
14
+ TypeAliasType,
15
+ TypeVar,
16
+ override,
17
+ runtime_checkable,
18
+ )
19
+
20
+ from .utils import LovelyStats, format_tensor_stats
21
+
22
+ log = logging.getLogger(__name__)
23
+
24
+ TArray = TypeVar("TArray", infer_variance=True)
25
+ P = ParamSpec("P")
26
+
27
+ LovelyStatsFn = TypeAliasType(
28
+ "LovelyStatsFn",
29
+ Callable[[TArray], Optional[LovelyStats]],
30
+ type_params=(TArray,),
31
+ )
32
+
33
+
34
+ @runtime_checkable
35
+ class LovelyReprFn(Protocol[TArray]):
36
+ @property
37
+ def __lovely_repr_instance__(self) -> lovely_repr[TArray]: ...
38
+
39
+ @__lovely_repr_instance__.setter
40
+ def __lovely_repr_instance__(self, value: lovely_repr[TArray]) -> None: ...
41
+
42
+ @property
43
+ def __name__(self) -> str: ...
44
+
45
+ def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None: ...
46
+ def __call__(self, value: TArray, /) -> str: ...
47
+
48
+
49
+ def _find_missing_deps(dependencies: list[str]):
50
+ missing_deps: list[str] = []
51
+
52
+ for dep in dependencies:
53
+ if importlib.util.find_spec(dep) is not None:
54
+ continue
55
+
56
+ missing_deps.append(dep)
57
+
58
+ return missing_deps
59
+
60
+
61
+ class lovely_repr(Generic[TArray]):
62
+ @override
63
+ def __init__(
64
+ self,
65
+ dependencies: list[str],
66
+ fallback_repr: Callable[[TArray], str] | None = None,
67
+ ):
68
+ """
69
+ Decorator to create a lovely representation function for an array.
70
+
71
+ Args:
72
+ dependencies: List of dependencies to check before running the function.
73
+ If any dependency is not available, the function will not run.
74
+ fallback_repr: A function that takes an array and returns its fallback representation.
75
+ Returns:
76
+ A decorator function that takes a function and returns a lovely representation function.
77
+
78
+ Example:
79
+ @lovely_repr(dependencies=["torch"])
80
+ def my_array_stats(array):
81
+ return {...}
82
+ """
83
+ super().__init__()
84
+
85
+ if fallback_repr is None:
86
+ fallback_repr = repr
87
+
88
+ self._dependencies = dependencies
89
+ self._fallback_repr = fallback_repr
90
+
91
+ def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None:
92
+ self._fallback_repr = repr_fn
93
+
94
+ def __call__(
95
+ self, array_stats_fn: LovelyStatsFn[TArray], /
96
+ ) -> LovelyReprFn[TArray]:
97
+ @functools.wraps(array_stats_fn)
98
+ def wrapper_fn(array: TArray) -> str:
99
+ if missing_deps := _find_missing_deps(self._dependencies):
100
+ log.warning(
101
+ f"Missing dependencies: {', '.join(missing_deps)}. "
102
+ "Skipping lovely representation."
103
+ )
104
+ return self._fallback_repr(array)
105
+
106
+ if (stats := array_stats_fn(array)) is None:
107
+ return self._fallback_repr(array)
108
+
109
+ return format_tensor_stats(stats)
110
+
111
+ wrapper = cast(LovelyReprFn[TArray], wrapper_fn)
112
+ wrapper.__lovely_repr_instance__ = self
113
+ wrapper.set_fallback_repr = self.set_fallback_repr
114
+ return wrapper
115
+
116
+
117
+ class lovely_patch(contextlib.AbstractContextManager["lovely_patch"], ABC):
118
+ def __init__(self):
119
+ self._patched = False
120
+ self.__enter__()
121
+
122
+ def dependencies(self) -> list[str]:
123
+ """Subclasses can override this to specify the dependencies of the patch."""
124
+ return []
125
+
126
+ @abstractmethod
127
+ def patch(self):
128
+ """Subclasses must implement this."""
129
+
130
+ @abstractmethod
131
+ def unpatch(self):
132
+ """Subclasses must implement this."""
133
+
134
+ @override
135
+ def __enter__(self):
136
+ if self._patched:
137
+ return self
138
+
139
+ if missing_deps := _find_missing_deps(self.dependencies()):
140
+ log.warning(
141
+ f"Missing dependencies: {', '.join(missing_deps)}. "
142
+ "Skipping monkey patch."
143
+ )
144
+ return self
145
+
146
+ self.patch()
147
+ self._patched = True
148
+ return self
149
+
150
+ @override
151
+ def __exit__(self, *exc_info):
152
+ if not self._patched:
153
+ return
154
+
155
+ self.unpatch()
156
+ self._patched = False
157
+
158
+ def close(self):
159
+ """Explicitly clean up the resource."""
160
+ self.__exit__(None, None, None)
@@ -5,9 +5,9 @@ import importlib.util
5
5
  import logging
6
6
  from typing import Literal
7
7
 
8
- from typing_extensions import TypeAliasType, assert_never
8
+ from typing_extensions import TypeAliasType, assert_never, override
9
9
 
10
- from ..util import resource_factory_contextmanager
10
+ from ._base import lovely_patch
11
11
 
12
12
  Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])
13
13
 
@@ -28,40 +28,48 @@ def _find_deps() -> list[Library]:
28
28
  return deps
29
29
 
30
30
 
31
- @resource_factory_contextmanager
32
- def monkey_patch(libraries: list[Library] | Literal["auto"] = "auto"):
33
- if libraries == "auto":
34
- libraries = _find_deps()
31
+ class monkey_patch(lovely_patch):
32
+ def __init__(self, libraries: list[Library] | Literal["auto"] = "auto"):
33
+ self.libraries = libraries
34
+ if self.libraries == "auto":
35
+ self.libraries = _find_deps()
35
36
 
36
- if not libraries:
37
- raise ValueError(
38
- "No libraries found for monkey patching. "
39
- "Please install numpy, torch, or jax."
40
- )
37
+ if not self.libraries:
38
+ raise ValueError(
39
+ "No libraries found for monkey patching. "
40
+ "Please install numpy, torch, or jax."
41
+ )
42
+
43
+ self.stack = contextlib.ExitStack()
44
+ super().__init__()
41
45
 
42
- with contextlib.ExitStack() as stack:
43
- for library in libraries:
46
+ @override
47
+ def patch(self):
48
+ for library in self.libraries:
44
49
  if library == "torch":
45
50
  from .torch_ import torch_monkey_patch
46
51
 
47
- stack.enter_context(torch_monkey_patch())
52
+ self.stack.enter_context(torch_monkey_patch())
48
53
  elif library == "jax":
49
54
  from .jax_ import jax_monkey_patch
50
55
 
51
- stack.enter_context(jax_monkey_patch())
56
+ self.stack.enter_context(jax_monkey_patch())
52
57
  elif library == "numpy":
53
58
  from .numpy_ import numpy_monkey_patch
54
59
 
55
- stack.enter_context(numpy_monkey_patch())
60
+ self.stack.enter_context(numpy_monkey_patch())
56
61
  else:
57
- assert_never(library)
62
+ assert_never(library) # type: ignore
58
63
 
59
64
  log.info(
60
- f"Monkey patched libraries: {', '.join(libraries)}. "
65
+ f"Monkey patched libraries: {', '.join(self.libraries)}. "
61
66
  "You can now use the lovely functions with these libraries."
62
67
  )
63
- yield
68
+
69
+ @override
70
+ def unpatch(self):
71
+ self.stack.close()
64
72
  log.info(
65
- f"Unmonkey patched libraries: {', '.join(libraries)}. "
73
+ f"Unmonkey patched libraries: {', '.join(self.libraries)}. "
66
74
  "You can now use the lovely functions with these libraries."
67
75
  )
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, cast
4
4
 
5
5
  import numpy as np
6
+ from typing_extensions import override
6
7
 
7
- from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from ._base import lovely_patch, lovely_repr
8
9
  from .utils import LovelyStats, array_stats, patch_to
9
10
 
10
11
  if TYPE_CHECKING:
@@ -53,7 +54,7 @@ def _device(array: jax.Array) -> str:
53
54
  return f"{device.platform}:{device.id}"
54
55
 
55
56
 
56
- @lovely_repr(dependencies=["jax"], fallback_repr=jax.Array.__repr__)
57
+ @lovely_repr(dependencies=["jax"])
57
58
  def jax_repr(array: jax.Array) -> LovelyStats | None:
58
59
  import jax.numpy as jnp
59
60
 
@@ -77,17 +78,25 @@ def jax_repr(array: jax.Array) -> LovelyStats | None:
77
78
  }
78
79
 
79
80
 
80
- @monkey_patch_contextmanager(dependencies=["jax"])
81
- def jax_monkey_patch():
82
- from jax._src import array
81
+ class jax_monkey_patch(lovely_patch):
82
+ @override
83
+ def dependencies(self) -> list[str]:
84
+ return ["jax"]
85
+
86
+ @override
87
+ def patch(self):
88
+ from jax._src import array
89
+
90
+ self.prev_repr = array.ArrayImpl.__repr__
91
+ self.prev_str = array.ArrayImpl.__str__
92
+ jax_repr.set_fallback_repr(self.prev_repr)
83
93
 
84
- prev_repr = array.ArrayImpl.__repr__
85
- prev_str = array.ArrayImpl.__str__
86
- try:
87
94
  patch_to(array.ArrayImpl, "__repr__", jax_repr)
88
95
  patch_to(array.ArrayImpl, "__str__", jax_repr)
89
96
 
90
- yield
91
- finally:
92
- patch_to(array.ArrayImpl, "__repr__", prev_repr)
93
- patch_to(array.ArrayImpl, "__str__", prev_str)
97
+ @override
98
+ def unpatch(self):
99
+ from jax._src import array
100
+
101
+ patch_to(array.ArrayImpl, "__repr__", self.prev_repr)
102
+ patch_to(array.ArrayImpl, "__str__", self.prev_str)
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
  import logging
4
4
 
5
5
  import numpy as np
6
+ from typing_extensions import override
6
7
 
7
- from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from ._base import lovely_patch, lovely_repr
8
9
  from .utils import LovelyStats, array_stats
9
10
 
10
11
 
@@ -51,7 +52,7 @@ def _dtype_str(array: np.ndarray) -> str:
51
52
  return dtype_base
52
53
 
53
54
 
54
- @lovely_repr(dependencies=["numpy"], fallback_repr=np.array_repr)
55
+ @lovely_repr(dependencies=["numpy"])
55
56
  def numpy_repr(array: np.ndarray) -> LovelyStats | None:
56
57
  # For dtypes like `object` or `str`, we let the fallback repr handle it
57
58
  if not np.issubdtype(array.dtype, np.number):
@@ -71,42 +72,44 @@ def numpy_repr(array: np.ndarray) -> LovelyStats | None:
71
72
  }
72
73
 
73
74
 
74
- # If numpy 2.0, use the new API override_repr.
75
- if _np_ge_2():
75
+ numpy_repr.set_fallback_repr(np.array_repr)
76
76
 
77
- @monkey_patch_contextmanager(dependencies=["numpy"])
78
- def numpy_monkey_patch():
79
- try:
77
+
78
+ class numpy_monkey_patch(lovely_patch):
79
+ @override
80
+ def dependencies(self) -> list[str]:
81
+ return ["numpy"]
82
+
83
+ @override
84
+ def patch(self):
85
+ if _np_ge_2():
86
+ self.original_options = np.get_printoptions()
80
87
  np.set_printoptions(override_repr=numpy_repr)
81
88
  logging.info(
82
89
  f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
83
90
  f"{np.get_printoptions()=}"
84
91
  )
85
- yield
86
- finally:
87
- np.set_printoptions(override_repr=None)
92
+ else:
93
+ # For legacy numpy, `set_string_function(None)` reverts to the default,
94
+ # so no state needs to be saved.
95
+ np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
96
+ np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
88
97
  logging.info(
89
- f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
98
+ f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
90
99
  f"{np.get_printoptions()=}"
91
100
  )
92
101
 
93
- else:
94
-
95
- @monkey_patch_contextmanager(dependencies=["numpy"])
96
- def numpy_monkey_patch():
97
- try:
98
- np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
99
- np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
100
-
102
+ @override
103
+ def unpatch(self):
104
+ if _np_ge_2():
105
+ np.set_printoptions(**self.original_options)
101
106
  logging.info(
102
- f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
107
+ f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
103
108
  f"{np.get_printoptions()=}"
104
109
  )
105
- yield
106
- finally:
110
+ else:
107
111
  np.set_string_function(None, True) # pyright: ignore[reportAttributeAccessIssue]
108
112
  np.set_string_function(None, False) # pyright: ignore[reportAttributeAccessIssue]
109
-
110
113
  logging.info(
111
114
  f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
112
115
  f"{np.get_printoptions()=}"
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
 
5
5
  import numpy as np
6
+ from typing_extensions import override
6
7
 
7
- from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from ._base import lovely_patch, lovely_repr
8
9
  from .utils import LovelyStats, array_stats, patch_to
9
10
 
10
11
  if TYPE_CHECKING:
@@ -59,7 +60,7 @@ def _to_np(tensor: torch.Tensor) -> np.ndarray:
59
60
  return t_np
60
61
 
61
62
 
62
- @lovely_repr(dependencies=["torch"], fallback_repr=torch.Tensor.__repr__)
63
+ @lovely_repr(dependencies=["torch"])
63
64
  def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
64
65
  return {
65
66
  # Basic attributes
@@ -80,20 +81,31 @@ def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
80
81
  }
81
82
 
82
83
 
83
- @monkey_patch_contextmanager(dependencies=["torch"])
84
- def torch_monkey_patch():
85
- import torch
84
+ class torch_monkey_patch(lovely_patch):
85
+ @override
86
+ def dependencies(self) -> list[str]:
87
+ return ["torch"]
88
+
89
+ @override
90
+ def patch(self):
91
+ import torch
92
+
93
+ self.original_repr = torch.Tensor.__repr__
94
+ self.original_str = torch.Tensor.__str__
95
+ self.original_parameter_repr = torch.nn.Parameter.__repr__
96
+ torch_repr.set_fallback_repr(self.original_repr)
86
97
 
87
- original_repr = torch.Tensor.__repr__
88
- original_str = torch.Tensor.__str__
89
- original_parameter_repr = torch.nn.Parameter.__repr__
90
- try:
91
98
  patch_to(torch.Tensor, "__repr__", torch_repr)
92
99
  patch_to(torch.Tensor, "__str__", torch_repr)
93
- del torch.nn.Parameter.__repr__
94
-
95
- yield
96
- finally:
97
- patch_to(torch.Tensor, "__repr__", original_repr)
98
- patch_to(torch.Tensor, "__str__", original_str)
99
- patch_to(torch.nn.Parameter, "__repr__", original_parameter_repr)
100
+ try:
101
+ delattr(torch.nn.Parameter, "__repr__")
102
+ except AttributeError:
103
+ pass
104
+
105
+ @override
106
+ def unpatch(self):
107
+ import torch
108
+
109
+ patch_to(torch.Tensor, "__repr__", self.original_repr)
110
+ patch_to(torch.Tensor, "__str__", self.original_str)
111
+ patch_to(torch.nn.Parameter, "__repr__", self.original_parameter_repr)
@@ -1,158 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import functools
4
- import importlib.util
5
- import logging
6
- from collections.abc import Callable, Iterator
7
-
8
- from typing_extensions import ParamSpec, TypeAliasType, TypeVar
9
-
10
- from ..util import ContextResource, resource_factory_contextmanager
11
- from .utils import LovelyStats, format_tensor_stats
12
-
13
- log = logging.getLogger(__name__)
14
-
15
- TArray = TypeVar("TArray", infer_variance=True)
16
- P = ParamSpec("P")
17
-
18
- LovelyStatsFn = TypeAliasType(
19
- "LovelyStatsFn",
20
- Callable[[TArray], LovelyStats | None],
21
- type_params=(TArray,),
22
- )
23
- LovelyReprFn = TypeAliasType(
24
- "LovelyReprFn",
25
- Callable[[TArray], str],
26
- type_params=(TArray,),
27
- )
28
-
29
-
30
- def _find_missing_deps(dependencies: list[str]):
31
- missing_deps: list[str] = []
32
-
33
- for dep in dependencies:
34
- if importlib.util.find_spec(dep) is not None:
35
- continue
36
-
37
- missing_deps.append(dep)
38
-
39
- return missing_deps
40
-
41
-
42
- def lovely_repr(dependencies: list[str], fallback_repr: Callable[[TArray], str]):
43
- """
44
- Decorator to create a lovely representation function for an array.
45
-
46
- Args:
47
- dependencies: List of dependencies to check before running the function.
48
- If any dependency is not available, the function will not run.
49
- fallback_repr: A function that takes an array and returns its fallback representation.
50
- Returns:
51
- A decorator function that takes a function and returns a lovely representation function.
52
-
53
- Example:
54
- @lovely_repr(dependencies=["torch"])
55
- def my_array_stats(array):
56
- return {...}
57
- """
58
-
59
- def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
60
- """
61
- Decorator to create a lovely representation function for an array.
62
-
63
- Args:
64
- array_stats_fn: A function that takes an array and returns its stats,
65
- or `None` if the array is not supported.
66
-
67
- Returns:
68
- A function that takes an array and returns its lovely representation.
69
- """
70
-
71
- @functools.wraps(array_stats_fn)
72
- def wrapper(array: TArray) -> str:
73
- if missing_deps := _find_missing_deps(dependencies):
74
- log.warning(
75
- f"Missing dependencies: {', '.join(missing_deps)}. "
76
- "Skipping lovely representation."
77
- )
78
- return fallback_repr(array)
79
-
80
- if (stats := array_stats_fn(array)) is None:
81
- return fallback_repr(array)
82
-
83
- return format_tensor_stats(stats)
84
-
85
- return wrapper
86
-
87
- return decorator_fn
88
-
89
-
90
- LovelyMonkeyPatchInputFn = TypeAliasType(
91
- "LovelyMonkeyPatchInputFn",
92
- Callable[P, Iterator[None]],
93
- type_params=(P,),
94
- )
95
- LovelyMonkeyPatchFn = TypeAliasType(
96
- "LovelyMonkeyPatchFn",
97
- Callable[P, ContextResource[None]],
98
- type_params=(P,),
99
- )
100
-
101
-
102
- def _nullcontext_generator():
103
- """A generator that does nothing."""
104
- yield
105
-
106
-
107
- def _wrap_monkey_patch_fn(
108
- monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
109
- dependencies: list[str],
110
- ) -> LovelyMonkeyPatchInputFn[P]:
111
- @functools.wraps(monkey_patch_fn)
112
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
113
- if missing_deps := _find_missing_deps(dependencies):
114
- log.warning(
115
- f"Missing dependencies: {', '.join(missing_deps)}. "
116
- "Skipping monkey patch."
117
- )
118
- return _nullcontext_generator()
119
-
120
- return monkey_patch_fn(*args, **kwargs)
121
-
122
- return wrapper
123
-
124
-
125
- def monkey_patch_contextmanager(dependencies: list[str]):
126
- """
127
- Decorator to create a monkey patch function for an array.
128
-
129
- Args:
130
- dependencies: List of dependencies to check before running the function.
131
- If any dependency is not available, the function will not run.
132
-
133
- Returns:
134
- A decorator function that takes a function and returns a monkey patch function.
135
-
136
- Example:
137
- @monkey_patch_contextmanager(dependencies=["torch"])
138
- def my_array_monkey_patch():
139
- ...
140
- """
141
-
142
- def decorator_fn(
143
- monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
144
- ) -> LovelyMonkeyPatchFn[P]:
145
- """
146
- Decorator to create a monkey patch function for an array.
147
-
148
- Args:
149
- monkey_patch_fn: A function that applies the monkey patch.
150
-
151
- Returns:
152
- A function that applies the monkey patch.
153
- """
154
-
155
- wrapped_fn = _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
156
- return resource_factory_contextmanager(wrapped_fn)
157
-
158
- return decorator_fn
@@ -1,92 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import contextlib
4
- import functools
5
- from collections.abc import Callable, Iterator
6
- from typing import Any, Generic
7
-
8
- from typing_extensions import ParamSpec, TypeVar, override
9
-
10
- R = TypeVar("R")
11
- P = ParamSpec("P")
12
-
13
-
14
- class ContextResource(contextlib.AbstractContextManager[R], Generic[R]):
15
- """A class that provides both direct access to a resource and context management."""
16
-
17
- def __init__(self, resource: R, cleanup_func: Callable[[R], Any]):
18
- self.resource = resource
19
- self._cleanup_func = cleanup_func
20
-
21
- @override
22
- def __enter__(self) -> R:
23
- """When used as a context manager, return the wrapped resource."""
24
- return self.resource
25
-
26
- @override
27
- def __exit__(self, *exc_info) -> None:
28
- """Clean up the resource when exiting the context."""
29
- self._cleanup_func(self.resource)
30
-
31
- def close(self) -> None:
32
- """Explicitly clean up the resource."""
33
- self._cleanup_func(self.resource)
34
-
35
-
36
- def resource_factory(
37
- create_func: Callable[P, R], cleanup_func: Callable[[R], None]
38
- ) -> Callable[P, ContextResource[R]]:
39
- """
40
- Create a factory function that returns a ContextResource.
41
-
42
- Args:
43
- create_func: Function that creates the resource
44
- cleanup_func: Function that cleans up the resource
45
-
46
- Returns:
47
- A function that returns a ContextResource wrapping the created resource
48
- """
49
-
50
- @functools.wraps(create_func)
51
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
52
- resource = create_func(*args, **kwargs)
53
- return ContextResource(resource, cleanup_func)
54
-
55
- return wrapper
56
-
57
-
58
- def resource_factory_from_context_fn(
59
- context_func: Callable[P, contextlib.AbstractContextManager[R]],
60
- ) -> Callable[P, ContextResource[R]]:
61
- """
62
- Create a factory function that returns a ContextResource.
63
-
64
- Args:
65
- context_func: Function that creates the resource
66
-
67
- Returns:
68
- A function that returns a ContextResource wrapping the created resource
69
- """
70
-
71
- @functools.wraps(context_func)
72
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
73
- context = context_func(*args, **kwargs)
74
- resource = context.__enter__()
75
- return ContextResource(resource, lambda _: context.__exit__(None, None, None))
76
-
77
- return wrapper
78
-
79
-
80
- def resource_factory_contextmanager(
81
- context_func: Callable[P, Iterator[R]],
82
- ) -> Callable[P, ContextResource[R]]:
83
- """
84
- Create a factory function that returns a ContextResource.
85
-
86
- Args:
87
- context_func: Generator function that creates the resource, yields it, and cleans up the resource when done.
88
-
89
- Returns:
90
- A function that returns a ContextResource wrapping the created resource
91
- """
92
- return resource_factory_from_context_fn(contextlib.contextmanager(context_func))
File without changes