nshutils 0.22.6__tar.gz → 0.31.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshutils-0.22.6 → nshutils-0.31.0}/PKG-INFO +1 -1
- {nshutils-0.22.6 → nshutils-0.31.0}/pyproject.toml +9 -1
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/__init__.pyi +1 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/actsave/_loader.py +73 -3
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/actsave/_saver.py +44 -2
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/__init__.py +2 -0
- nshutils-0.31.0/src/nshutils/lovely/_base.py +160 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/_monkey_patch_all.py +28 -20
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/jax_.py +21 -12
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/numpy_.py +26 -23
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/torch_.py +28 -16
- nshutils-0.22.6/src/nshutils/lovely/_base.py +0 -158
- nshutils-0.22.6/src/nshutils/util.py +0 -92
- {nshutils-0.22.6 → nshutils-0.31.0}/README.md +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/__init__.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/actsave/__init__.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/collections.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/display.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/logging.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/config.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/lovely/utils.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/snoop.py +0 -0
- {nshutils-0.22.6 → nshutils-0.31.0}/src/nshutils/typecheck.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "nshutils"
|
3
|
-
version = "0.
|
3
|
+
version = "0.31.0"
|
4
4
|
description = ""
|
5
5
|
authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
|
6
6
|
requires-python = ">=3.9,<4.0"
|
@@ -28,6 +28,8 @@ basedpyright = "*"
|
|
28
28
|
ruff = "*"
|
29
29
|
ipykernel = "*"
|
30
30
|
ipywidgets = "*"
|
31
|
+
pytest = "*"
|
32
|
+
pytest-cov = "*"
|
31
33
|
|
32
34
|
[build-system]
|
33
35
|
requires = ["poetry-core"]
|
@@ -47,3 +49,9 @@ ignore = ["F722", "F821", "E731", "E741"]
|
|
47
49
|
|
48
50
|
[tool.ruff.lint.isort]
|
49
51
|
required-imports = ["from __future__ import annotations"]
|
52
|
+
|
53
|
+
|
54
|
+
[tool.pytest.ini_options]
|
55
|
+
testpaths = ["tests"]
|
56
|
+
python_files = ["test_*.py"]
|
57
|
+
addopts = "--cov=nshutils --cov-report=term-missing"
|
@@ -7,6 +7,7 @@ from .collections import apply_to_collection as apply_to_collection
|
|
7
7
|
from .display import display as display
|
8
8
|
from .logging import init_python_logging as init_python_logging
|
9
9
|
from .logging import setup_logging as setup_logging
|
10
|
+
from .lovely import lovely_monkey_patch as lovely_monkey_patch
|
10
11
|
from .snoop import snoop as snoop
|
11
12
|
from .typecheck import tassert as tassert
|
12
13
|
from .typecheck import typecheck_modules as typecheck_modules
|
@@ -105,14 +105,41 @@ class ActLoad:
|
|
105
105
|
path, _ = max(all_versions, key=lambda p: p[1])
|
106
106
|
return cls(path)
|
107
107
|
|
108
|
-
def __init__(
|
108
|
+
def __init__(
|
109
|
+
self,
|
110
|
+
dir: Path | None = None,
|
111
|
+
*,
|
112
|
+
_base_activations: dict[str, LoadedActivation] | None = None,
|
113
|
+
_prefix_chain: list[str] | None = None,
|
114
|
+
):
|
115
|
+
"""Initialize ActLoad from a directory or from filtered activations.
|
116
|
+
|
117
|
+
Args:
|
118
|
+
dir: Path to the activation directory. Required for root ActLoad instances.
|
119
|
+
_base_activations: Pre-filtered activations dict. Used internally for prefix filtering.
|
120
|
+
_prefix_chain: Chain of prefixes that have been applied. Used for repr.
|
121
|
+
"""
|
109
122
|
self._dir = dir
|
123
|
+
self._base_activations = _base_activations
|
124
|
+
self._prefix_chain = _prefix_chain or []
|
110
125
|
|
111
126
|
def activation(self, name: str):
|
112
|
-
|
127
|
+
if self._dir is None:
|
128
|
+
raise ValueError(
|
129
|
+
"Cannot create activation from filtered ActLoad without base directory"
|
130
|
+
)
|
131
|
+
# For filtered instances, we need to reconstruct the full name
|
132
|
+
full_name = "".join(self._prefix_chain) + name
|
133
|
+
return LoadedActivation(self._dir, full_name)
|
113
134
|
|
114
135
|
@cached_property
|
115
136
|
def activations(self):
|
137
|
+
if self._base_activations is not None:
|
138
|
+
return self._base_activations
|
139
|
+
|
140
|
+
if self._dir is None:
|
141
|
+
raise ValueError("ActLoad requires either dir or _base_activations")
|
142
|
+
|
116
143
|
dirs = list(self._dir.iterdir())
|
117
144
|
# Sort the dirs by the last modified time
|
118
145
|
dirs.sort(key=lambda p: p.stat().st_mtime)
|
@@ -128,6 +155,9 @@ class ActLoad:
|
|
128
155
|
def __len__(self):
|
129
156
|
return len(self.activations)
|
130
157
|
|
158
|
+
def _ipython_key_completions_(self):
|
159
|
+
return list(self.activations.keys())
|
160
|
+
|
131
161
|
@override
|
132
162
|
def __repr__(self):
|
133
163
|
acts_str = pprint.pformat(
|
@@ -137,10 +167,50 @@ class ActLoad:
|
|
137
167
|
}
|
138
168
|
)
|
139
169
|
acts_str = acts_str.replace("'<", "<").replace(">'", ">")
|
140
|
-
|
170
|
+
|
171
|
+
if self._prefix_chain:
|
172
|
+
prefix_str = "".join(self._prefix_chain)
|
173
|
+
return f"ActLoad(prefix='{prefix_str}', {acts_str})"
|
174
|
+
else:
|
175
|
+
return f"ActLoad({acts_str})"
|
141
176
|
|
142
177
|
def get(self, name: str, /, default: T) -> LoadedActivation | T:
|
143
178
|
return self.activations.get(name, default)
|
144
179
|
|
180
|
+
def filter_by_prefix(self, prefix: str) -> ActLoad:
|
181
|
+
"""Create a filtered view of activations that match the given prefix.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
prefix: The prefix to filter by. Only activations whose names start
|
185
|
+
with this prefix will be included in the filtered view.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
A new ActLoad instance that provides access to matching activations
|
189
|
+
with the prefix stripped from their keys. Can be chained multiple times.
|
190
|
+
|
191
|
+
Example:
|
192
|
+
>>> loader = ActLoad(some_dir)
|
193
|
+
>>> # If loader has keys "my.activation.first", "my.activation.second", "other.key"
|
194
|
+
>>> filtered = loader.filter_by_prefix("my.activation.")
|
195
|
+
>>> filtered["first"] # Accesses "my.activation.first"
|
196
|
+
>>> filtered["second"] # Accesses "my.activation.second"
|
197
|
+
>>> # Can be chained:
|
198
|
+
>>> double_filtered = loader.filter_by_prefix("my.").filter_by_prefix("activation.")
|
199
|
+
>>> double_filtered["first"] # Also accesses "my.activation.first"
|
200
|
+
"""
|
201
|
+
filtered_activations = {}
|
202
|
+
for name, activation in self.activations.items():
|
203
|
+
if name.startswith(prefix):
|
204
|
+
# Strip the prefix from the key
|
205
|
+
stripped_name = name[len(prefix) :]
|
206
|
+
filtered_activations[stripped_name] = activation
|
207
|
+
|
208
|
+
new_prefix_chain = self._prefix_chain + [prefix]
|
209
|
+
return ActLoad(
|
210
|
+
_base_activations=filtered_activations,
|
211
|
+
_prefix_chain=new_prefix_chain,
|
212
|
+
dir=self._dir,
|
213
|
+
)
|
214
|
+
|
145
215
|
|
146
216
|
ActivationLoader = ActLoad
|
@@ -241,6 +241,17 @@ class _Saver:
|
|
241
241
|
class ActSaveProvider:
|
242
242
|
_saver: _Saver | None = None
|
243
243
|
_prefixes: list[str] = []
|
244
|
+
_disable_count: int = 0
|
245
|
+
|
246
|
+
@property
|
247
|
+
def is_initialized(self) -> bool:
|
248
|
+
"""Returns True if ActSave.enable() has been called and not subsequently disabled."""
|
249
|
+
return self._saver is not None
|
250
|
+
|
251
|
+
@property
|
252
|
+
def is_enabled(self) -> bool:
|
253
|
+
"""Returns True if ActSave is currently active and will save activations."""
|
254
|
+
return self.is_initialized and self._disable_count == 0
|
244
255
|
|
245
256
|
def enable(self, save_dir: Path | None = None):
|
246
257
|
"""
|
@@ -294,6 +305,34 @@ class ActSaveProvider:
|
|
294
305
|
|
295
306
|
self._saver = None
|
296
307
|
self._prefixes = []
|
308
|
+
self._disable_count = 0
|
309
|
+
|
310
|
+
@contextlib.contextmanager
|
311
|
+
def disabled(self, condition: bool | Callable[[], bool] = True):
|
312
|
+
"""
|
313
|
+
Context manager to temporarily disable activation saving.
|
314
|
+
|
315
|
+
Args:
|
316
|
+
condition (bool | Callable[[], bool], optional):
|
317
|
+
If True or a callable returning True, saving is disabled within this context.
|
318
|
+
Defaults to True.
|
319
|
+
"""
|
320
|
+
if _torch_is_scripting():
|
321
|
+
yield
|
322
|
+
return
|
323
|
+
|
324
|
+
should_disable = condition() if callable(condition) else condition
|
325
|
+
if should_disable:
|
326
|
+
self._disable_count += 1
|
327
|
+
|
328
|
+
try:
|
329
|
+
yield
|
330
|
+
finally:
|
331
|
+
if should_disable:
|
332
|
+
self._disable_count -= 1
|
333
|
+
if self._disable_count < 0: # Should not happen
|
334
|
+
log.warning("ActSave disable count went below zero.")
|
335
|
+
self._disable_count = 0
|
297
336
|
|
298
337
|
@contextlib.contextmanager
|
299
338
|
def context(self, label: str):
|
@@ -307,7 +346,7 @@ class ActSaveProvider:
|
|
307
346
|
yield
|
308
347
|
return
|
309
348
|
|
310
|
-
if self.
|
349
|
+
if not self.is_enabled:
|
311
350
|
yield
|
312
351
|
return
|
313
352
|
|
@@ -368,9 +407,12 @@ class ActSaveProvider:
|
|
368
407
|
if _torch_is_scripting():
|
369
408
|
return
|
370
409
|
|
371
|
-
if self.
|
410
|
+
if not self.is_enabled:
|
372
411
|
return
|
373
412
|
|
413
|
+
# Ensure _saver is not None, which is guaranteed by is_enabled but mypy needs help
|
414
|
+
assert self._saver is not None
|
415
|
+
|
374
416
|
if acts is not None and callable(acts):
|
375
417
|
acts = acts()
|
376
418
|
self._saver.save(acts, **kwargs)
|
@@ -0,0 +1,160 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
import functools
|
5
|
+
import importlib.util
|
6
|
+
import logging
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from collections.abc import Callable
|
9
|
+
from typing import Any, Generic, Optional, cast
|
10
|
+
|
11
|
+
from typing_extensions import (
|
12
|
+
ParamSpec,
|
13
|
+
Protocol,
|
14
|
+
TypeAliasType,
|
15
|
+
TypeVar,
|
16
|
+
override,
|
17
|
+
runtime_checkable,
|
18
|
+
)
|
19
|
+
|
20
|
+
from .utils import LovelyStats, format_tensor_stats
|
21
|
+
|
22
|
+
log = logging.getLogger(__name__)
|
23
|
+
|
24
|
+
TArray = TypeVar("TArray", infer_variance=True)
|
25
|
+
P = ParamSpec("P")
|
26
|
+
|
27
|
+
LovelyStatsFn = TypeAliasType(
|
28
|
+
"LovelyStatsFn",
|
29
|
+
Callable[[TArray], Optional[LovelyStats]],
|
30
|
+
type_params=(TArray,),
|
31
|
+
)
|
32
|
+
|
33
|
+
|
34
|
+
@runtime_checkable
|
35
|
+
class LovelyReprFn(Protocol[TArray]):
|
36
|
+
@property
|
37
|
+
def __lovely_repr_instance__(self) -> lovely_repr[TArray]: ...
|
38
|
+
|
39
|
+
@__lovely_repr_instance__.setter
|
40
|
+
def __lovely_repr_instance__(self, value: lovely_repr[TArray]) -> None: ...
|
41
|
+
|
42
|
+
@property
|
43
|
+
def __name__(self) -> str: ...
|
44
|
+
|
45
|
+
def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None: ...
|
46
|
+
def __call__(self, value: TArray, /) -> str: ...
|
47
|
+
|
48
|
+
|
49
|
+
def _find_missing_deps(dependencies: list[str]):
|
50
|
+
missing_deps: list[str] = []
|
51
|
+
|
52
|
+
for dep in dependencies:
|
53
|
+
if importlib.util.find_spec(dep) is not None:
|
54
|
+
continue
|
55
|
+
|
56
|
+
missing_deps.append(dep)
|
57
|
+
|
58
|
+
return missing_deps
|
59
|
+
|
60
|
+
|
61
|
+
class lovely_repr(Generic[TArray]):
|
62
|
+
@override
|
63
|
+
def __init__(
|
64
|
+
self,
|
65
|
+
dependencies: list[str],
|
66
|
+
fallback_repr: Callable[[TArray], str] | None = None,
|
67
|
+
):
|
68
|
+
"""
|
69
|
+
Decorator to create a lovely representation function for an array.
|
70
|
+
|
71
|
+
Args:
|
72
|
+
dependencies: List of dependencies to check before running the function.
|
73
|
+
If any dependency is not available, the function will not run.
|
74
|
+
fallback_repr: A function that takes an array and returns its fallback representation.
|
75
|
+
Returns:
|
76
|
+
A decorator function that takes a function and returns a lovely representation function.
|
77
|
+
|
78
|
+
Example:
|
79
|
+
@lovely_repr(dependencies=["torch"])
|
80
|
+
def my_array_stats(array):
|
81
|
+
return {...}
|
82
|
+
"""
|
83
|
+
super().__init__()
|
84
|
+
|
85
|
+
if fallback_repr is None:
|
86
|
+
fallback_repr = repr
|
87
|
+
|
88
|
+
self._dependencies = dependencies
|
89
|
+
self._fallback_repr = fallback_repr
|
90
|
+
|
91
|
+
def set_fallback_repr(self, repr_fn: Callable[[TArray], str]) -> None:
|
92
|
+
self._fallback_repr = repr_fn
|
93
|
+
|
94
|
+
def __call__(
|
95
|
+
self, array_stats_fn: LovelyStatsFn[TArray], /
|
96
|
+
) -> LovelyReprFn[TArray]:
|
97
|
+
@functools.wraps(array_stats_fn)
|
98
|
+
def wrapper_fn(array: TArray) -> str:
|
99
|
+
if missing_deps := _find_missing_deps(self._dependencies):
|
100
|
+
log.warning(
|
101
|
+
f"Missing dependencies: {', '.join(missing_deps)}. "
|
102
|
+
"Skipping lovely representation."
|
103
|
+
)
|
104
|
+
return self._fallback_repr(array)
|
105
|
+
|
106
|
+
if (stats := array_stats_fn(array)) is None:
|
107
|
+
return self._fallback_repr(array)
|
108
|
+
|
109
|
+
return format_tensor_stats(stats)
|
110
|
+
|
111
|
+
wrapper = cast(LovelyReprFn[TArray], wrapper_fn)
|
112
|
+
wrapper.__lovely_repr_instance__ = self
|
113
|
+
wrapper.set_fallback_repr = self.set_fallback_repr
|
114
|
+
return wrapper
|
115
|
+
|
116
|
+
|
117
|
+
class lovely_patch(contextlib.AbstractContextManager["lovely_patch"], ABC):
|
118
|
+
def __init__(self):
|
119
|
+
self._patched = False
|
120
|
+
self.__enter__()
|
121
|
+
|
122
|
+
def dependencies(self) -> list[str]:
|
123
|
+
"""Subclasses can override this to specify the dependencies of the patch."""
|
124
|
+
return []
|
125
|
+
|
126
|
+
@abstractmethod
|
127
|
+
def patch(self):
|
128
|
+
"""Subclasses must implement this."""
|
129
|
+
|
130
|
+
@abstractmethod
|
131
|
+
def unpatch(self):
|
132
|
+
"""Subclasses must implement this."""
|
133
|
+
|
134
|
+
@override
|
135
|
+
def __enter__(self):
|
136
|
+
if self._patched:
|
137
|
+
return self
|
138
|
+
|
139
|
+
if missing_deps := _find_missing_deps(self.dependencies()):
|
140
|
+
log.warning(
|
141
|
+
f"Missing dependencies: {', '.join(missing_deps)}. "
|
142
|
+
"Skipping monkey patch."
|
143
|
+
)
|
144
|
+
return self
|
145
|
+
|
146
|
+
self.patch()
|
147
|
+
self._patched = True
|
148
|
+
return self
|
149
|
+
|
150
|
+
@override
|
151
|
+
def __exit__(self, *exc_info):
|
152
|
+
if not self._patched:
|
153
|
+
return
|
154
|
+
|
155
|
+
self.unpatch()
|
156
|
+
self._patched = False
|
157
|
+
|
158
|
+
def close(self):
|
159
|
+
"""Explicitly clean up the resource."""
|
160
|
+
self.__exit__(None, None, None)
|
@@ -5,9 +5,9 @@ import importlib.util
|
|
5
5
|
import logging
|
6
6
|
from typing import Literal
|
7
7
|
|
8
|
-
from typing_extensions import TypeAliasType, assert_never
|
8
|
+
from typing_extensions import TypeAliasType, assert_never, override
|
9
9
|
|
10
|
-
from
|
10
|
+
from ._base import lovely_patch
|
11
11
|
|
12
12
|
Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])
|
13
13
|
|
@@ -28,40 +28,48 @@ def _find_deps() -> list[Library]:
|
|
28
28
|
return deps
|
29
29
|
|
30
30
|
|
31
|
-
|
32
|
-
def
|
33
|
-
|
34
|
-
libraries
|
31
|
+
class monkey_patch(lovely_patch):
|
32
|
+
def __init__(self, libraries: list[Library] | Literal["auto"] = "auto"):
|
33
|
+
self.libraries = libraries
|
34
|
+
if self.libraries == "auto":
|
35
|
+
self.libraries = _find_deps()
|
35
36
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
37
|
+
if not self.libraries:
|
38
|
+
raise ValueError(
|
39
|
+
"No libraries found for monkey patching. "
|
40
|
+
"Please install numpy, torch, or jax."
|
41
|
+
)
|
42
|
+
|
43
|
+
self.stack = contextlib.ExitStack()
|
44
|
+
super().__init__()
|
41
45
|
|
42
|
-
|
43
|
-
|
46
|
+
@override
|
47
|
+
def patch(self):
|
48
|
+
for library in self.libraries:
|
44
49
|
if library == "torch":
|
45
50
|
from .torch_ import torch_monkey_patch
|
46
51
|
|
47
|
-
stack.enter_context(torch_monkey_patch())
|
52
|
+
self.stack.enter_context(torch_monkey_patch())
|
48
53
|
elif library == "jax":
|
49
54
|
from .jax_ import jax_monkey_patch
|
50
55
|
|
51
|
-
stack.enter_context(jax_monkey_patch())
|
56
|
+
self.stack.enter_context(jax_monkey_patch())
|
52
57
|
elif library == "numpy":
|
53
58
|
from .numpy_ import numpy_monkey_patch
|
54
59
|
|
55
|
-
stack.enter_context(numpy_monkey_patch())
|
60
|
+
self.stack.enter_context(numpy_monkey_patch())
|
56
61
|
else:
|
57
|
-
assert_never(library)
|
62
|
+
assert_never(library) # type: ignore
|
58
63
|
|
59
64
|
log.info(
|
60
|
-
f"Monkey patched libraries: {', '.join(libraries)}. "
|
65
|
+
f"Monkey patched libraries: {', '.join(self.libraries)}. "
|
61
66
|
"You can now use the lovely functions with these libraries."
|
62
67
|
)
|
63
|
-
|
68
|
+
|
69
|
+
@override
|
70
|
+
def unpatch(self):
|
71
|
+
self.stack.close()
|
64
72
|
log.info(
|
65
|
-
f"Unmonkey patched libraries: {', '.join(libraries)}. "
|
73
|
+
f"Unmonkey patched libraries: {', '.join(self.libraries)}. "
|
66
74
|
"You can now use the lovely functions with these libraries."
|
67
75
|
)
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
from typing import TYPE_CHECKING, cast
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
+
from typing_extensions import override
|
6
7
|
|
7
|
-
from ._base import
|
8
|
+
from ._base import lovely_patch, lovely_repr
|
8
9
|
from .utils import LovelyStats, array_stats, patch_to
|
9
10
|
|
10
11
|
if TYPE_CHECKING:
|
@@ -53,7 +54,7 @@ def _device(array: jax.Array) -> str:
|
|
53
54
|
return f"{device.platform}:{device.id}"
|
54
55
|
|
55
56
|
|
56
|
-
@lovely_repr(dependencies=["jax"]
|
57
|
+
@lovely_repr(dependencies=["jax"])
|
57
58
|
def jax_repr(array: jax.Array) -> LovelyStats | None:
|
58
59
|
import jax.numpy as jnp
|
59
60
|
|
@@ -77,17 +78,25 @@ def jax_repr(array: jax.Array) -> LovelyStats | None:
|
|
77
78
|
}
|
78
79
|
|
79
80
|
|
80
|
-
|
81
|
-
|
82
|
-
|
81
|
+
class jax_monkey_patch(lovely_patch):
|
82
|
+
@override
|
83
|
+
def dependencies(self) -> list[str]:
|
84
|
+
return ["jax"]
|
85
|
+
|
86
|
+
@override
|
87
|
+
def patch(self):
|
88
|
+
from jax._src import array
|
89
|
+
|
90
|
+
self.prev_repr = array.ArrayImpl.__repr__
|
91
|
+
self.prev_str = array.ArrayImpl.__str__
|
92
|
+
jax_repr.set_fallback_repr(self.prev_repr)
|
83
93
|
|
84
|
-
prev_repr = array.ArrayImpl.__repr__
|
85
|
-
prev_str = array.ArrayImpl.__str__
|
86
|
-
try:
|
87
94
|
patch_to(array.ArrayImpl, "__repr__", jax_repr)
|
88
95
|
patch_to(array.ArrayImpl, "__str__", jax_repr)
|
89
96
|
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
97
|
+
@override
|
98
|
+
def unpatch(self):
|
99
|
+
from jax._src import array
|
100
|
+
|
101
|
+
patch_to(array.ArrayImpl, "__repr__", self.prev_repr)
|
102
|
+
patch_to(array.ArrayImpl, "__str__", self.prev_str)
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
import logging
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
+
from typing_extensions import override
|
6
7
|
|
7
|
-
from ._base import
|
8
|
+
from ._base import lovely_patch, lovely_repr
|
8
9
|
from .utils import LovelyStats, array_stats
|
9
10
|
|
10
11
|
|
@@ -51,7 +52,7 @@ def _dtype_str(array: np.ndarray) -> str:
|
|
51
52
|
return dtype_base
|
52
53
|
|
53
54
|
|
54
|
-
@lovely_repr(dependencies=["numpy"]
|
55
|
+
@lovely_repr(dependencies=["numpy"])
|
55
56
|
def numpy_repr(array: np.ndarray) -> LovelyStats | None:
|
56
57
|
# For dtypes like `object` or `str`, we let the fallback repr handle it
|
57
58
|
if not np.issubdtype(array.dtype, np.number):
|
@@ -71,42 +72,44 @@ def numpy_repr(array: np.ndarray) -> LovelyStats | None:
|
|
71
72
|
}
|
72
73
|
|
73
74
|
|
74
|
-
|
75
|
-
if _np_ge_2():
|
75
|
+
numpy_repr.set_fallback_repr(np.array_repr)
|
76
76
|
|
77
|
-
|
78
|
-
|
79
|
-
|
77
|
+
|
78
|
+
class numpy_monkey_patch(lovely_patch):
|
79
|
+
@override
|
80
|
+
def dependencies(self) -> list[str]:
|
81
|
+
return ["numpy"]
|
82
|
+
|
83
|
+
@override
|
84
|
+
def patch(self):
|
85
|
+
if _np_ge_2():
|
86
|
+
self.original_options = np.get_printoptions()
|
80
87
|
np.set_printoptions(override_repr=numpy_repr)
|
81
88
|
logging.info(
|
82
89
|
f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
83
90
|
f"{np.get_printoptions()=}"
|
84
91
|
)
|
85
|
-
|
86
|
-
|
87
|
-
|
92
|
+
else:
|
93
|
+
# For legacy numpy, `set_string_function(None)` reverts to the default,
|
94
|
+
# so no state needs to be saved.
|
95
|
+
np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
|
96
|
+
np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
|
88
97
|
logging.info(
|
89
|
-
f"Numpy
|
98
|
+
f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
90
99
|
f"{np.get_printoptions()=}"
|
91
100
|
)
|
92
101
|
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
try:
|
98
|
-
np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
|
99
|
-
np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
|
100
|
-
|
102
|
+
@override
|
103
|
+
def unpatch(self):
|
104
|
+
if _np_ge_2():
|
105
|
+
np.set_printoptions(**self.original_options)
|
101
106
|
logging.info(
|
102
|
-
f"Numpy
|
107
|
+
f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
103
108
|
f"{np.get_printoptions()=}"
|
104
109
|
)
|
105
|
-
|
106
|
-
finally:
|
110
|
+
else:
|
107
111
|
np.set_string_function(None, True) # pyright: ignore[reportAttributeAccessIssue]
|
108
112
|
np.set_string_function(None, False) # pyright: ignore[reportAttributeAccessIssue]
|
109
|
-
|
110
113
|
logging.info(
|
111
114
|
f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
112
115
|
f"{np.get_printoptions()=}"
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
from typing import TYPE_CHECKING
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
+
from typing_extensions import override
|
6
7
|
|
7
|
-
from ._base import
|
8
|
+
from ._base import lovely_patch, lovely_repr
|
8
9
|
from .utils import LovelyStats, array_stats, patch_to
|
9
10
|
|
10
11
|
if TYPE_CHECKING:
|
@@ -59,7 +60,7 @@ def _to_np(tensor: torch.Tensor) -> np.ndarray:
|
|
59
60
|
return t_np
|
60
61
|
|
61
62
|
|
62
|
-
@lovely_repr(dependencies=["torch"]
|
63
|
+
@lovely_repr(dependencies=["torch"])
|
63
64
|
def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
|
64
65
|
return {
|
65
66
|
# Basic attributes
|
@@ -80,20 +81,31 @@ def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
|
|
80
81
|
}
|
81
82
|
|
82
83
|
|
83
|
-
|
84
|
-
|
85
|
-
|
84
|
+
class torch_monkey_patch(lovely_patch):
|
85
|
+
@override
|
86
|
+
def dependencies(self) -> list[str]:
|
87
|
+
return ["torch"]
|
88
|
+
|
89
|
+
@override
|
90
|
+
def patch(self):
|
91
|
+
import torch
|
92
|
+
|
93
|
+
self.original_repr = torch.Tensor.__repr__
|
94
|
+
self.original_str = torch.Tensor.__str__
|
95
|
+
self.original_parameter_repr = torch.nn.Parameter.__repr__
|
96
|
+
torch_repr.set_fallback_repr(self.original_repr)
|
86
97
|
|
87
|
-
original_repr = torch.Tensor.__repr__
|
88
|
-
original_str = torch.Tensor.__str__
|
89
|
-
original_parameter_repr = torch.nn.Parameter.__repr__
|
90
|
-
try:
|
91
98
|
patch_to(torch.Tensor, "__repr__", torch_repr)
|
92
99
|
patch_to(torch.Tensor, "__str__", torch_repr)
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
+
try:
|
101
|
+
delattr(torch.nn.Parameter, "__repr__")
|
102
|
+
except AttributeError:
|
103
|
+
pass
|
104
|
+
|
105
|
+
@override
|
106
|
+
def unpatch(self):
|
107
|
+
import torch
|
108
|
+
|
109
|
+
patch_to(torch.Tensor, "__repr__", self.original_repr)
|
110
|
+
patch_to(torch.Tensor, "__str__", self.original_str)
|
111
|
+
patch_to(torch.nn.Parameter, "__repr__", self.original_parameter_repr)
|
@@ -1,158 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import functools
|
4
|
-
import importlib.util
|
5
|
-
import logging
|
6
|
-
from collections.abc import Callable, Iterator
|
7
|
-
|
8
|
-
from typing_extensions import ParamSpec, TypeAliasType, TypeVar
|
9
|
-
|
10
|
-
from ..util import ContextResource, resource_factory_contextmanager
|
11
|
-
from .utils import LovelyStats, format_tensor_stats
|
12
|
-
|
13
|
-
log = logging.getLogger(__name__)
|
14
|
-
|
15
|
-
TArray = TypeVar("TArray", infer_variance=True)
|
16
|
-
P = ParamSpec("P")
|
17
|
-
|
18
|
-
LovelyStatsFn = TypeAliasType(
|
19
|
-
"LovelyStatsFn",
|
20
|
-
Callable[[TArray], LovelyStats | None],
|
21
|
-
type_params=(TArray,),
|
22
|
-
)
|
23
|
-
LovelyReprFn = TypeAliasType(
|
24
|
-
"LovelyReprFn",
|
25
|
-
Callable[[TArray], str],
|
26
|
-
type_params=(TArray,),
|
27
|
-
)
|
28
|
-
|
29
|
-
|
30
|
-
def _find_missing_deps(dependencies: list[str]):
|
31
|
-
missing_deps: list[str] = []
|
32
|
-
|
33
|
-
for dep in dependencies:
|
34
|
-
if importlib.util.find_spec(dep) is not None:
|
35
|
-
continue
|
36
|
-
|
37
|
-
missing_deps.append(dep)
|
38
|
-
|
39
|
-
return missing_deps
|
40
|
-
|
41
|
-
|
42
|
-
def lovely_repr(dependencies: list[str], fallback_repr: Callable[[TArray], str]):
|
43
|
-
"""
|
44
|
-
Decorator to create a lovely representation function for an array.
|
45
|
-
|
46
|
-
Args:
|
47
|
-
dependencies: List of dependencies to check before running the function.
|
48
|
-
If any dependency is not available, the function will not run.
|
49
|
-
fallback_repr: A function that takes an array and returns its fallback representation.
|
50
|
-
Returns:
|
51
|
-
A decorator function that takes a function and returns a lovely representation function.
|
52
|
-
|
53
|
-
Example:
|
54
|
-
@lovely_repr(dependencies=["torch"])
|
55
|
-
def my_array_stats(array):
|
56
|
-
return {...}
|
57
|
-
"""
|
58
|
-
|
59
|
-
def decorator_fn(array_stats_fn: LovelyStatsFn[TArray]) -> LovelyReprFn[TArray]:
|
60
|
-
"""
|
61
|
-
Decorator to create a lovely representation function for an array.
|
62
|
-
|
63
|
-
Args:
|
64
|
-
array_stats_fn: A function that takes an array and returns its stats,
|
65
|
-
or `None` if the array is not supported.
|
66
|
-
|
67
|
-
Returns:
|
68
|
-
A function that takes an array and returns its lovely representation.
|
69
|
-
"""
|
70
|
-
|
71
|
-
@functools.wraps(array_stats_fn)
|
72
|
-
def wrapper(array: TArray) -> str:
|
73
|
-
if missing_deps := _find_missing_deps(dependencies):
|
74
|
-
log.warning(
|
75
|
-
f"Missing dependencies: {', '.join(missing_deps)}. "
|
76
|
-
"Skipping lovely representation."
|
77
|
-
)
|
78
|
-
return fallback_repr(array)
|
79
|
-
|
80
|
-
if (stats := array_stats_fn(array)) is None:
|
81
|
-
return fallback_repr(array)
|
82
|
-
|
83
|
-
return format_tensor_stats(stats)
|
84
|
-
|
85
|
-
return wrapper
|
86
|
-
|
87
|
-
return decorator_fn
|
88
|
-
|
89
|
-
|
90
|
-
LovelyMonkeyPatchInputFn = TypeAliasType(
|
91
|
-
"LovelyMonkeyPatchInputFn",
|
92
|
-
Callable[P, Iterator[None]],
|
93
|
-
type_params=(P,),
|
94
|
-
)
|
95
|
-
LovelyMonkeyPatchFn = TypeAliasType(
|
96
|
-
"LovelyMonkeyPatchFn",
|
97
|
-
Callable[P, ContextResource[None]],
|
98
|
-
type_params=(P,),
|
99
|
-
)
|
100
|
-
|
101
|
-
|
102
|
-
def _nullcontext_generator():
|
103
|
-
"""A generator that does nothing."""
|
104
|
-
yield
|
105
|
-
|
106
|
-
|
107
|
-
def _wrap_monkey_patch_fn(
|
108
|
-
monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
|
109
|
-
dependencies: list[str],
|
110
|
-
) -> LovelyMonkeyPatchInputFn[P]:
|
111
|
-
@functools.wraps(monkey_patch_fn)
|
112
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
|
113
|
-
if missing_deps := _find_missing_deps(dependencies):
|
114
|
-
log.warning(
|
115
|
-
f"Missing dependencies: {', '.join(missing_deps)}. "
|
116
|
-
"Skipping monkey patch."
|
117
|
-
)
|
118
|
-
return _nullcontext_generator()
|
119
|
-
|
120
|
-
return monkey_patch_fn(*args, **kwargs)
|
121
|
-
|
122
|
-
return wrapper
|
123
|
-
|
124
|
-
|
125
|
-
def monkey_patch_contextmanager(dependencies: list[str]):
|
126
|
-
"""
|
127
|
-
Decorator to create a monkey patch function for an array.
|
128
|
-
|
129
|
-
Args:
|
130
|
-
dependencies: List of dependencies to check before running the function.
|
131
|
-
If any dependency is not available, the function will not run.
|
132
|
-
|
133
|
-
Returns:
|
134
|
-
A decorator function that takes a function and returns a monkey patch function.
|
135
|
-
|
136
|
-
Example:
|
137
|
-
@monkey_patch_contextmanager(dependencies=["torch"])
|
138
|
-
def my_array_monkey_patch():
|
139
|
-
...
|
140
|
-
"""
|
141
|
-
|
142
|
-
def decorator_fn(
|
143
|
-
monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
|
144
|
-
) -> LovelyMonkeyPatchFn[P]:
|
145
|
-
"""
|
146
|
-
Decorator to create a monkey patch function for an array.
|
147
|
-
|
148
|
-
Args:
|
149
|
-
monkey_patch_fn: A function that applies the monkey patch.
|
150
|
-
|
151
|
-
Returns:
|
152
|
-
A function that applies the monkey patch.
|
153
|
-
"""
|
154
|
-
|
155
|
-
wrapped_fn = _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
|
156
|
-
return resource_factory_contextmanager(wrapped_fn)
|
157
|
-
|
158
|
-
return decorator_fn
|
@@ -1,92 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import contextlib
|
4
|
-
import functools
|
5
|
-
from collections.abc import Callable, Iterator
|
6
|
-
from typing import Any, Generic
|
7
|
-
|
8
|
-
from typing_extensions import ParamSpec, TypeVar, override
|
9
|
-
|
10
|
-
R = TypeVar("R")
|
11
|
-
P = ParamSpec("P")
|
12
|
-
|
13
|
-
|
14
|
-
class ContextResource(contextlib.AbstractContextManager[R], Generic[R]):
|
15
|
-
"""A class that provides both direct access to a resource and context management."""
|
16
|
-
|
17
|
-
def __init__(self, resource: R, cleanup_func: Callable[[R], Any]):
|
18
|
-
self.resource = resource
|
19
|
-
self._cleanup_func = cleanup_func
|
20
|
-
|
21
|
-
@override
|
22
|
-
def __enter__(self) -> R:
|
23
|
-
"""When used as a context manager, return the wrapped resource."""
|
24
|
-
return self.resource
|
25
|
-
|
26
|
-
@override
|
27
|
-
def __exit__(self, *exc_info) -> None:
|
28
|
-
"""Clean up the resource when exiting the context."""
|
29
|
-
self._cleanup_func(self.resource)
|
30
|
-
|
31
|
-
def close(self) -> None:
|
32
|
-
"""Explicitly clean up the resource."""
|
33
|
-
self._cleanup_func(self.resource)
|
34
|
-
|
35
|
-
|
36
|
-
def resource_factory(
|
37
|
-
create_func: Callable[P, R], cleanup_func: Callable[[R], None]
|
38
|
-
) -> Callable[P, ContextResource[R]]:
|
39
|
-
"""
|
40
|
-
Create a factory function that returns a ContextResource.
|
41
|
-
|
42
|
-
Args:
|
43
|
-
create_func: Function that creates the resource
|
44
|
-
cleanup_func: Function that cleans up the resource
|
45
|
-
|
46
|
-
Returns:
|
47
|
-
A function that returns a ContextResource wrapping the created resource
|
48
|
-
"""
|
49
|
-
|
50
|
-
@functools.wraps(create_func)
|
51
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
|
52
|
-
resource = create_func(*args, **kwargs)
|
53
|
-
return ContextResource(resource, cleanup_func)
|
54
|
-
|
55
|
-
return wrapper
|
56
|
-
|
57
|
-
|
58
|
-
def resource_factory_from_context_fn(
|
59
|
-
context_func: Callable[P, contextlib.AbstractContextManager[R]],
|
60
|
-
) -> Callable[P, ContextResource[R]]:
|
61
|
-
"""
|
62
|
-
Create a factory function that returns a ContextResource.
|
63
|
-
|
64
|
-
Args:
|
65
|
-
context_func: Function that creates the resource
|
66
|
-
|
67
|
-
Returns:
|
68
|
-
A function that returns a ContextResource wrapping the created resource
|
69
|
-
"""
|
70
|
-
|
71
|
-
@functools.wraps(context_func)
|
72
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
|
73
|
-
context = context_func(*args, **kwargs)
|
74
|
-
resource = context.__enter__()
|
75
|
-
return ContextResource(resource, lambda _: context.__exit__(None, None, None))
|
76
|
-
|
77
|
-
return wrapper
|
78
|
-
|
79
|
-
|
80
|
-
def resource_factory_contextmanager(
|
81
|
-
context_func: Callable[P, Iterator[R]],
|
82
|
-
) -> Callable[P, ContextResource[R]]:
|
83
|
-
"""
|
84
|
-
Create a factory function that returns a ContextResource.
|
85
|
-
|
86
|
-
Args:
|
87
|
-
context_func: Generator function that creates the resource, yields it, and cleans up the resource when done.
|
88
|
-
|
89
|
-
Returns:
|
90
|
-
A function that returns a ContextResource wrapping the created resource
|
91
|
-
"""
|
92
|
-
return resource_factory_from_context_fn(contextlib.contextmanager(context_func))
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|