nshutils 0.30.1__tar.gz → 0.31.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshutils-0.30.1 → nshutils-0.31.0}/PKG-INFO +1 -1
- {nshutils-0.30.1 → nshutils-0.31.0}/pyproject.toml +9 -1
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/__init__.pyi +1 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/actsave/_loader.py +73 -3
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/__init__.py +2 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/_base.py +36 -60
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/_monkey_patch_all.py +28 -20
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/jax_.py +20 -12
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/numpy_.py +22 -22
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/torch_.py +21 -14
- nshutils-0.30.1/src/nshutils/util.py +0 -92
- {nshutils-0.30.1 → nshutils-0.31.0}/README.md +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/__init__.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/actsave/__init__.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/actsave/_saver.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/collections.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/display.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/logging.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/config.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/lovely/utils.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/snoop.py +0 -0
- {nshutils-0.30.1 → nshutils-0.31.0}/src/nshutils/typecheck.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "nshutils"
|
3
|
-
version = "0.
|
3
|
+
version = "0.31.0"
|
4
4
|
description = ""
|
5
5
|
authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
|
6
6
|
requires-python = ">=3.9,<4.0"
|
@@ -28,6 +28,8 @@ basedpyright = "*"
|
|
28
28
|
ruff = "*"
|
29
29
|
ipykernel = "*"
|
30
30
|
ipywidgets = "*"
|
31
|
+
pytest = "*"
|
32
|
+
pytest-cov = "*"
|
31
33
|
|
32
34
|
[build-system]
|
33
35
|
requires = ["poetry-core"]
|
@@ -47,3 +49,9 @@ ignore = ["F722", "F821", "E731", "E741"]
|
|
47
49
|
|
48
50
|
[tool.ruff.lint.isort]
|
49
51
|
required-imports = ["from __future__ import annotations"]
|
52
|
+
|
53
|
+
|
54
|
+
[tool.pytest.ini_options]
|
55
|
+
testpaths = ["tests"]
|
56
|
+
python_files = ["test_*.py"]
|
57
|
+
addopts = "--cov=nshutils --cov-report=term-missing"
|
@@ -7,6 +7,7 @@ from .collections import apply_to_collection as apply_to_collection
|
|
7
7
|
from .display import display as display
|
8
8
|
from .logging import init_python_logging as init_python_logging
|
9
9
|
from .logging import setup_logging as setup_logging
|
10
|
+
from .lovely import lovely_monkey_patch as lovely_monkey_patch
|
10
11
|
from .snoop import snoop as snoop
|
11
12
|
from .typecheck import tassert as tassert
|
12
13
|
from .typecheck import typecheck_modules as typecheck_modules
|
@@ -105,14 +105,41 @@ class ActLoad:
|
|
105
105
|
path, _ = max(all_versions, key=lambda p: p[1])
|
106
106
|
return cls(path)
|
107
107
|
|
108
|
-
def __init__(
|
108
|
+
def __init__(
|
109
|
+
self,
|
110
|
+
dir: Path | None = None,
|
111
|
+
*,
|
112
|
+
_base_activations: dict[str, LoadedActivation] | None = None,
|
113
|
+
_prefix_chain: list[str] | None = None,
|
114
|
+
):
|
115
|
+
"""Initialize ActLoad from a directory or from filtered activations.
|
116
|
+
|
117
|
+
Args:
|
118
|
+
dir: Path to the activation directory. Required for root ActLoad instances.
|
119
|
+
_base_activations: Pre-filtered activations dict. Used internally for prefix filtering.
|
120
|
+
_prefix_chain: Chain of prefixes that have been applied. Used for repr.
|
121
|
+
"""
|
109
122
|
self._dir = dir
|
123
|
+
self._base_activations = _base_activations
|
124
|
+
self._prefix_chain = _prefix_chain or []
|
110
125
|
|
111
126
|
def activation(self, name: str):
|
112
|
-
|
127
|
+
if self._dir is None:
|
128
|
+
raise ValueError(
|
129
|
+
"Cannot create activation from filtered ActLoad without base directory"
|
130
|
+
)
|
131
|
+
# For filtered instances, we need to reconstruct the full name
|
132
|
+
full_name = "".join(self._prefix_chain) + name
|
133
|
+
return LoadedActivation(self._dir, full_name)
|
113
134
|
|
114
135
|
@cached_property
|
115
136
|
def activations(self):
|
137
|
+
if self._base_activations is not None:
|
138
|
+
return self._base_activations
|
139
|
+
|
140
|
+
if self._dir is None:
|
141
|
+
raise ValueError("ActLoad requires either dir or _base_activations")
|
142
|
+
|
116
143
|
dirs = list(self._dir.iterdir())
|
117
144
|
# Sort the dirs by the last modified time
|
118
145
|
dirs.sort(key=lambda p: p.stat().st_mtime)
|
@@ -128,6 +155,9 @@ class ActLoad:
|
|
128
155
|
def __len__(self):
|
129
156
|
return len(self.activations)
|
130
157
|
|
158
|
+
def _ipython_key_completions_(self):
|
159
|
+
return list(self.activations.keys())
|
160
|
+
|
131
161
|
@override
|
132
162
|
def __repr__(self):
|
133
163
|
acts_str = pprint.pformat(
|
@@ -137,10 +167,50 @@ class ActLoad:
|
|
137
167
|
}
|
138
168
|
)
|
139
169
|
acts_str = acts_str.replace("'<", "<").replace(">'", ">")
|
140
|
-
|
170
|
+
|
171
|
+
if self._prefix_chain:
|
172
|
+
prefix_str = "".join(self._prefix_chain)
|
173
|
+
return f"ActLoad(prefix='{prefix_str}', {acts_str})"
|
174
|
+
else:
|
175
|
+
return f"ActLoad({acts_str})"
|
141
176
|
|
142
177
|
def get(self, name: str, /, default: T) -> LoadedActivation | T:
|
143
178
|
return self.activations.get(name, default)
|
144
179
|
|
180
|
+
def filter_by_prefix(self, prefix: str) -> ActLoad:
|
181
|
+
"""Create a filtered view of activations that match the given prefix.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
prefix: The prefix to filter by. Only activations whose names start
|
185
|
+
with this prefix will be included in the filtered view.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
A new ActLoad instance that provides access to matching activations
|
189
|
+
with the prefix stripped from their keys. Can be chained multiple times.
|
190
|
+
|
191
|
+
Example:
|
192
|
+
>>> loader = ActLoad(some_dir)
|
193
|
+
>>> # If loader has keys "my.activation.first", "my.activation.second", "other.key"
|
194
|
+
>>> filtered = loader.filter_by_prefix("my.activation.")
|
195
|
+
>>> filtered["first"] # Accesses "my.activation.first"
|
196
|
+
>>> filtered["second"] # Accesses "my.activation.second"
|
197
|
+
>>> # Can be chained:
|
198
|
+
>>> double_filtered = loader.filter_by_prefix("my.").filter_by_prefix("activation.")
|
199
|
+
>>> double_filtered["first"] # Also accesses "my.activation.first"
|
200
|
+
"""
|
201
|
+
filtered_activations = {}
|
202
|
+
for name, activation in self.activations.items():
|
203
|
+
if name.startswith(prefix):
|
204
|
+
# Strip the prefix from the key
|
205
|
+
stripped_name = name[len(prefix) :]
|
206
|
+
filtered_activations[stripped_name] = activation
|
207
|
+
|
208
|
+
new_prefix_chain = self._prefix_chain + [prefix]
|
209
|
+
return ActLoad(
|
210
|
+
_base_activations=filtered_activations,
|
211
|
+
_prefix_chain=new_prefix_chain,
|
212
|
+
dir=self._dir,
|
213
|
+
)
|
214
|
+
|
145
215
|
|
146
216
|
ActivationLoader = ActLoad
|
@@ -1,10 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import contextlib
|
3
4
|
import functools
|
4
5
|
import importlib.util
|
5
6
|
import logging
|
6
|
-
from
|
7
|
-
from
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from collections.abc import Callable
|
9
|
+
from typing import Any, Generic, Optional, cast
|
8
10
|
|
9
11
|
from typing_extensions import (
|
10
12
|
ParamSpec,
|
@@ -15,7 +17,6 @@ from typing_extensions import (
|
|
15
17
|
runtime_checkable,
|
16
18
|
)
|
17
19
|
|
18
|
-
from ..util import ContextResource, resource_factory_contextmanager
|
19
20
|
from .utils import LovelyStats, format_tensor_stats
|
20
21
|
|
21
22
|
log = logging.getLogger(__name__)
|
@@ -113,72 +114,47 @@ class lovely_repr(Generic[TArray]):
|
|
113
114
|
return wrapper
|
114
115
|
|
115
116
|
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
)
|
121
|
-
LovelyMonkeyPatchFn = TypeAliasType(
|
122
|
-
"LovelyMonkeyPatchFn",
|
123
|
-
Callable[P, ContextResource[None]],
|
124
|
-
type_params=(P,),
|
125
|
-
)
|
117
|
+
class lovely_patch(contextlib.AbstractContextManager["lovely_patch"], ABC):
|
118
|
+
def __init__(self):
|
119
|
+
self._patched = False
|
120
|
+
self.__enter__()
|
126
121
|
|
122
|
+
def dependencies(self) -> list[str]:
|
123
|
+
"""Subclasses can override this to specify the dependencies of the patch."""
|
124
|
+
return []
|
127
125
|
|
128
|
-
|
129
|
-
|
130
|
-
|
126
|
+
@abstractmethod
|
127
|
+
def patch(self):
|
128
|
+
"""Subclasses must implement this."""
|
131
129
|
|
130
|
+
@abstractmethod
|
131
|
+
def unpatch(self):
|
132
|
+
"""Subclasses must implement this."""
|
132
133
|
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
if missing_deps := _find_missing_deps(dependencies):
|
134
|
+
@override
|
135
|
+
def __enter__(self):
|
136
|
+
if self._patched:
|
137
|
+
return self
|
138
|
+
|
139
|
+
if missing_deps := _find_missing_deps(self.dependencies()):
|
140
140
|
log.warning(
|
141
141
|
f"Missing dependencies: {', '.join(missing_deps)}. "
|
142
142
|
"Skipping monkey patch."
|
143
143
|
)
|
144
|
-
return
|
145
|
-
|
146
|
-
return monkey_patch_fn(*args, **kwargs)
|
147
|
-
|
148
|
-
return wrapper
|
149
|
-
|
150
|
-
|
151
|
-
def monkey_patch_contextmanager(dependencies: list[str]):
|
152
|
-
"""
|
153
|
-
Decorator to create a monkey patch function for an array.
|
154
|
-
|
155
|
-
Args:
|
156
|
-
dependencies: List of dependencies to check before running the function.
|
157
|
-
If any dependency is not available, the function will not run.
|
144
|
+
return self
|
158
145
|
|
159
|
-
|
160
|
-
|
146
|
+
self.patch()
|
147
|
+
self._patched = True
|
148
|
+
return self
|
161
149
|
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
"""
|
167
|
-
|
168
|
-
def decorator_fn(
|
169
|
-
monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
|
170
|
-
) -> LovelyMonkeyPatchFn[P]:
|
171
|
-
"""
|
172
|
-
Decorator to create a monkey patch function for an array.
|
173
|
-
|
174
|
-
Args:
|
175
|
-
monkey_patch_fn: A function that applies the monkey patch.
|
176
|
-
|
177
|
-
Returns:
|
178
|
-
A function that applies the monkey patch.
|
179
|
-
"""
|
150
|
+
@override
|
151
|
+
def __exit__(self, *exc_info):
|
152
|
+
if not self._patched:
|
153
|
+
return
|
180
154
|
|
181
|
-
|
182
|
-
|
155
|
+
self.unpatch()
|
156
|
+
self._patched = False
|
183
157
|
|
184
|
-
|
158
|
+
def close(self):
|
159
|
+
"""Explicitly clean up the resource."""
|
160
|
+
self.__exit__(None, None, None)
|
@@ -5,9 +5,9 @@ import importlib.util
|
|
5
5
|
import logging
|
6
6
|
from typing import Literal
|
7
7
|
|
8
|
-
from typing_extensions import TypeAliasType, assert_never
|
8
|
+
from typing_extensions import TypeAliasType, assert_never, override
|
9
9
|
|
10
|
-
from
|
10
|
+
from ._base import lovely_patch
|
11
11
|
|
12
12
|
Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])
|
13
13
|
|
@@ -28,40 +28,48 @@ def _find_deps() -> list[Library]:
|
|
28
28
|
return deps
|
29
29
|
|
30
30
|
|
31
|
-
|
32
|
-
def
|
33
|
-
|
34
|
-
libraries
|
31
|
+
class monkey_patch(lovely_patch):
|
32
|
+
def __init__(self, libraries: list[Library] | Literal["auto"] = "auto"):
|
33
|
+
self.libraries = libraries
|
34
|
+
if self.libraries == "auto":
|
35
|
+
self.libraries = _find_deps()
|
35
36
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
37
|
+
if not self.libraries:
|
38
|
+
raise ValueError(
|
39
|
+
"No libraries found for monkey patching. "
|
40
|
+
"Please install numpy, torch, or jax."
|
41
|
+
)
|
42
|
+
|
43
|
+
self.stack = contextlib.ExitStack()
|
44
|
+
super().__init__()
|
41
45
|
|
42
|
-
|
43
|
-
|
46
|
+
@override
|
47
|
+
def patch(self):
|
48
|
+
for library in self.libraries:
|
44
49
|
if library == "torch":
|
45
50
|
from .torch_ import torch_monkey_patch
|
46
51
|
|
47
|
-
stack.enter_context(torch_monkey_patch())
|
52
|
+
self.stack.enter_context(torch_monkey_patch())
|
48
53
|
elif library == "jax":
|
49
54
|
from .jax_ import jax_monkey_patch
|
50
55
|
|
51
|
-
stack.enter_context(jax_monkey_patch())
|
56
|
+
self.stack.enter_context(jax_monkey_patch())
|
52
57
|
elif library == "numpy":
|
53
58
|
from .numpy_ import numpy_monkey_patch
|
54
59
|
|
55
|
-
stack.enter_context(numpy_monkey_patch())
|
60
|
+
self.stack.enter_context(numpy_monkey_patch())
|
56
61
|
else:
|
57
|
-
assert_never(library)
|
62
|
+
assert_never(library) # type: ignore
|
58
63
|
|
59
64
|
log.info(
|
60
|
-
f"Monkey patched libraries: {', '.join(libraries)}. "
|
65
|
+
f"Monkey patched libraries: {', '.join(self.libraries)}. "
|
61
66
|
"You can now use the lovely functions with these libraries."
|
62
67
|
)
|
63
|
-
|
68
|
+
|
69
|
+
@override
|
70
|
+
def unpatch(self):
|
71
|
+
self.stack.close()
|
64
72
|
log.info(
|
65
|
-
f"Unmonkey patched libraries: {', '.join(libraries)}. "
|
73
|
+
f"Unmonkey patched libraries: {', '.join(self.libraries)}. "
|
66
74
|
"You can now use the lovely functions with these libraries."
|
67
75
|
)
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
from typing import TYPE_CHECKING, cast
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
+
from typing_extensions import override
|
6
7
|
|
7
|
-
from ._base import
|
8
|
+
from ._base import lovely_patch, lovely_repr
|
8
9
|
from .utils import LovelyStats, array_stats, patch_to
|
9
10
|
|
10
11
|
if TYPE_CHECKING:
|
@@ -77,18 +78,25 @@ def jax_repr(array: jax.Array) -> LovelyStats | None:
|
|
77
78
|
}
|
78
79
|
|
79
80
|
|
80
|
-
|
81
|
-
|
82
|
-
|
81
|
+
class jax_monkey_patch(lovely_patch):
|
82
|
+
@override
|
83
|
+
def dependencies(self) -> list[str]:
|
84
|
+
return ["jax"]
|
85
|
+
|
86
|
+
@override
|
87
|
+
def patch(self):
|
88
|
+
from jax._src import array
|
89
|
+
|
90
|
+
self.prev_repr = array.ArrayImpl.__repr__
|
91
|
+
self.prev_str = array.ArrayImpl.__str__
|
92
|
+
jax_repr.set_fallback_repr(self.prev_repr)
|
83
93
|
|
84
|
-
prev_repr = array.ArrayImpl.__repr__
|
85
|
-
prev_str = array.ArrayImpl.__str__
|
86
|
-
jax_repr.set_fallback_repr(prev_repr)
|
87
|
-
try:
|
88
94
|
patch_to(array.ArrayImpl, "__repr__", jax_repr)
|
89
95
|
patch_to(array.ArrayImpl, "__str__", jax_repr)
|
90
96
|
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
97
|
+
@override
|
98
|
+
def unpatch(self):
|
99
|
+
from jax._src import array
|
100
|
+
|
101
|
+
patch_to(array.ArrayImpl, "__repr__", self.prev_repr)
|
102
|
+
patch_to(array.ArrayImpl, "__str__", self.prev_str)
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
import logging
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
+
from typing_extensions import override
|
6
7
|
|
7
|
-
from ._base import
|
8
|
+
from ._base import lovely_patch, lovely_repr
|
8
9
|
from .utils import LovelyStats, array_stats
|
9
10
|
|
10
11
|
|
@@ -74,42 +75,41 @@ def numpy_repr(array: np.ndarray) -> LovelyStats | None:
|
|
74
75
|
numpy_repr.set_fallback_repr(np.array_repr)
|
75
76
|
|
76
77
|
|
77
|
-
|
78
|
-
|
78
|
+
class numpy_monkey_patch(lovely_patch):
|
79
|
+
@override
|
80
|
+
def dependencies(self) -> list[str]:
|
81
|
+
return ["numpy"]
|
79
82
|
|
80
|
-
@
|
81
|
-
def
|
82
|
-
|
83
|
+
@override
|
84
|
+
def patch(self):
|
85
|
+
if _np_ge_2():
|
86
|
+
self.original_options = np.get_printoptions()
|
83
87
|
np.set_printoptions(override_repr=numpy_repr)
|
84
88
|
logging.info(
|
85
89
|
f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
86
90
|
f"{np.get_printoptions()=}"
|
87
91
|
)
|
88
|
-
|
89
|
-
|
90
|
-
|
92
|
+
else:
|
93
|
+
# For legacy numpy, `set_string_function(None)` reverts to the default,
|
94
|
+
# so no state needs to be saved.
|
95
|
+
np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
|
96
|
+
np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
|
91
97
|
logging.info(
|
92
|
-
f"Numpy
|
98
|
+
f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
93
99
|
f"{np.get_printoptions()=}"
|
94
100
|
)
|
95
101
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
try:
|
101
|
-
np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
|
102
|
-
np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
|
103
|
-
|
102
|
+
@override
|
103
|
+
def unpatch(self):
|
104
|
+
if _np_ge_2():
|
105
|
+
np.set_printoptions(**self.original_options)
|
104
106
|
logging.info(
|
105
|
-
f"Numpy
|
107
|
+
f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
106
108
|
f"{np.get_printoptions()=}"
|
107
109
|
)
|
108
|
-
|
109
|
-
finally:
|
110
|
+
else:
|
110
111
|
np.set_string_function(None, True) # pyright: ignore[reportAttributeAccessIssue]
|
111
112
|
np.set_string_function(None, False) # pyright: ignore[reportAttributeAccessIssue]
|
112
|
-
|
113
113
|
logging.info(
|
114
114
|
f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
|
115
115
|
f"{np.get_printoptions()=}"
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
from typing import TYPE_CHECKING
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
+
from typing_extensions import override
|
6
7
|
|
7
|
-
from ._base import
|
8
|
+
from ._base import lovely_patch, lovely_repr
|
8
9
|
from .utils import LovelyStats, array_stats, patch_to
|
9
10
|
|
10
11
|
if TYPE_CHECKING:
|
@@ -80,16 +81,20 @@ def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
|
|
80
81
|
}
|
81
82
|
|
82
83
|
|
83
|
-
|
84
|
-
|
85
|
-
|
84
|
+
class torch_monkey_patch(lovely_patch):
|
85
|
+
@override
|
86
|
+
def dependencies(self) -> list[str]:
|
87
|
+
return ["torch"]
|
88
|
+
|
89
|
+
@override
|
90
|
+
def patch(self):
|
91
|
+
import torch
|
86
92
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
93
|
+
self.original_repr = torch.Tensor.__repr__
|
94
|
+
self.original_str = torch.Tensor.__str__
|
95
|
+
self.original_parameter_repr = torch.nn.Parameter.__repr__
|
96
|
+
torch_repr.set_fallback_repr(self.original_repr)
|
91
97
|
|
92
|
-
try:
|
93
98
|
patch_to(torch.Tensor, "__repr__", torch_repr)
|
94
99
|
patch_to(torch.Tensor, "__str__", torch_repr)
|
95
100
|
try:
|
@@ -97,8 +102,10 @@ def torch_monkey_patch():
|
|
97
102
|
except AttributeError:
|
98
103
|
pass
|
99
104
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
patch_to(torch.
|
105
|
+
@override
|
106
|
+
def unpatch(self):
|
107
|
+
import torch
|
108
|
+
|
109
|
+
patch_to(torch.Tensor, "__repr__", self.original_repr)
|
110
|
+
patch_to(torch.Tensor, "__str__", self.original_str)
|
111
|
+
patch_to(torch.nn.Parameter, "__repr__", self.original_parameter_repr)
|
@@ -1,92 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import contextlib
|
4
|
-
import functools
|
5
|
-
from collections.abc import Callable, Iterator
|
6
|
-
from typing import Any, Generic
|
7
|
-
|
8
|
-
from typing_extensions import ParamSpec, TypeVar, override
|
9
|
-
|
10
|
-
R = TypeVar("R")
|
11
|
-
P = ParamSpec("P")
|
12
|
-
|
13
|
-
|
14
|
-
class ContextResource(contextlib.AbstractContextManager[R], Generic[R]):
|
15
|
-
"""A class that provides both direct access to a resource and context management."""
|
16
|
-
|
17
|
-
def __init__(self, resource: R, cleanup_func: Callable[[R], Any]):
|
18
|
-
self.resource = resource
|
19
|
-
self._cleanup_func = cleanup_func
|
20
|
-
|
21
|
-
@override
|
22
|
-
def __enter__(self) -> R:
|
23
|
-
"""When used as a context manager, return the wrapped resource."""
|
24
|
-
return self.resource
|
25
|
-
|
26
|
-
@override
|
27
|
-
def __exit__(self, *exc_info) -> None:
|
28
|
-
"""Clean up the resource when exiting the context."""
|
29
|
-
self._cleanup_func(self.resource)
|
30
|
-
|
31
|
-
def close(self) -> None:
|
32
|
-
"""Explicitly clean up the resource."""
|
33
|
-
self._cleanup_func(self.resource)
|
34
|
-
|
35
|
-
|
36
|
-
def resource_factory(
|
37
|
-
create_func: Callable[P, R], cleanup_func: Callable[[R], None]
|
38
|
-
) -> Callable[P, ContextResource[R]]:
|
39
|
-
"""
|
40
|
-
Create a factory function that returns a ContextResource.
|
41
|
-
|
42
|
-
Args:
|
43
|
-
create_func: Function that creates the resource
|
44
|
-
cleanup_func: Function that cleans up the resource
|
45
|
-
|
46
|
-
Returns:
|
47
|
-
A function that returns a ContextResource wrapping the created resource
|
48
|
-
"""
|
49
|
-
|
50
|
-
@functools.wraps(create_func)
|
51
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
|
52
|
-
resource = create_func(*args, **kwargs)
|
53
|
-
return ContextResource(resource, cleanup_func)
|
54
|
-
|
55
|
-
return wrapper
|
56
|
-
|
57
|
-
|
58
|
-
def resource_factory_from_context_fn(
|
59
|
-
context_func: Callable[P, contextlib.AbstractContextManager[R]],
|
60
|
-
) -> Callable[P, ContextResource[R]]:
|
61
|
-
"""
|
62
|
-
Create a factory function that returns a ContextResource.
|
63
|
-
|
64
|
-
Args:
|
65
|
-
context_func: Function that creates the resource
|
66
|
-
|
67
|
-
Returns:
|
68
|
-
A function that returns a ContextResource wrapping the created resource
|
69
|
-
"""
|
70
|
-
|
71
|
-
@functools.wraps(context_func)
|
72
|
-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
|
73
|
-
context = context_func(*args, **kwargs)
|
74
|
-
resource = context.__enter__()
|
75
|
-
return ContextResource(resource, lambda _: context.__exit__(None, None, None))
|
76
|
-
|
77
|
-
return wrapper
|
78
|
-
|
79
|
-
|
80
|
-
def resource_factory_contextmanager(
|
81
|
-
context_func: Callable[P, Iterator[R]],
|
82
|
-
) -> Callable[P, ContextResource[R]]:
|
83
|
-
"""
|
84
|
-
Create a factory function that returns a ContextResource.
|
85
|
-
|
86
|
-
Args:
|
87
|
-
context_func: Generator function that creates the resource, yields it, and cleans up the resource when done.
|
88
|
-
|
89
|
-
Returns:
|
90
|
-
A function that returns a ContextResource wrapping the created resource
|
91
|
-
"""
|
92
|
-
return resource_factory_from_context_fn(contextlib.contextmanager(context_func))
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|