nshutils 0.30.1__py3-none-any.whl → 0.31.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshutils/__init__.pyi CHANGED
@@ -7,6 +7,7 @@ from .collections import apply_to_collection as apply_to_collection
7
7
  from .display import display as display
8
8
  from .logging import init_python_logging as init_python_logging
9
9
  from .logging import setup_logging as setup_logging
10
+ from .lovely import lovely_monkey_patch as lovely_monkey_patch
10
11
  from .snoop import snoop as snoop
11
12
  from .typecheck import tassert as tassert
12
13
  from .typecheck import typecheck_modules as typecheck_modules
@@ -105,14 +105,41 @@ class ActLoad:
105
105
  path, _ = max(all_versions, key=lambda p: p[1])
106
106
  return cls(path)
107
107
 
108
- def __init__(self, dir: Path):
108
+ def __init__(
109
+ self,
110
+ dir: Path | None = None,
111
+ *,
112
+ _base_activations: dict[str, LoadedActivation] | None = None,
113
+ _prefix_chain: list[str] | None = None,
114
+ ):
115
+ """Initialize ActLoad from a directory or from filtered activations.
116
+
117
+ Args:
118
+ dir: Path to the activation directory. Required for root ActLoad instances.
119
+ _base_activations: Pre-filtered activations dict. Used internally for prefix filtering.
120
+ _prefix_chain: Chain of prefixes that have been applied. Used for repr.
121
+ """
109
122
  self._dir = dir
123
+ self._base_activations = _base_activations
124
+ self._prefix_chain = _prefix_chain or []
110
125
 
111
126
  def activation(self, name: str):
112
- return LoadedActivation(self._dir, name)
127
+ if self._dir is None:
128
+ raise ValueError(
129
+ "Cannot create activation from filtered ActLoad without base directory"
130
+ )
131
+ # For filtered instances, we need to reconstruct the full name
132
+ full_name = "".join(self._prefix_chain) + name
133
+ return LoadedActivation(self._dir, full_name)
113
134
 
114
135
  @cached_property
115
136
  def activations(self):
137
+ if self._base_activations is not None:
138
+ return self._base_activations
139
+
140
+ if self._dir is None:
141
+ raise ValueError("ActLoad requires either dir or _base_activations")
142
+
116
143
  dirs = list(self._dir.iterdir())
117
144
  # Sort the dirs by the last modified time
118
145
  dirs.sort(key=lambda p: p.stat().st_mtime)
@@ -128,6 +155,9 @@ class ActLoad:
128
155
  def __len__(self):
129
156
  return len(self.activations)
130
157
 
158
+ def _ipython_key_completions_(self):
159
+ return list(self.activations.keys())
160
+
131
161
  @override
132
162
  def __repr__(self):
133
163
  acts_str = pprint.pformat(
@@ -137,10 +167,50 @@ class ActLoad:
137
167
  }
138
168
  )
139
169
  acts_str = acts_str.replace("'<", "<").replace(">'", ">")
140
- return f"ActLoad({acts_str})"
170
+
171
+ if self._prefix_chain:
172
+ prefix_str = "".join(self._prefix_chain)
173
+ return f"ActLoad(prefix='{prefix_str}', {acts_str})"
174
+ else:
175
+ return f"ActLoad({acts_str})"
141
176
 
142
177
  def get(self, name: str, /, default: T) -> LoadedActivation | T:
143
178
  return self.activations.get(name, default)
144
179
 
180
+ def filter_by_prefix(self, prefix: str) -> ActLoad:
181
+ """Create a filtered view of activations that match the given prefix.
182
+
183
+ Args:
184
+ prefix: The prefix to filter by. Only activations whose names start
185
+ with this prefix will be included in the filtered view.
186
+
187
+ Returns:
188
+ A new ActLoad instance that provides access to matching activations
189
+ with the prefix stripped from their keys. Can be chained multiple times.
190
+
191
+ Example:
192
+ >>> loader = ActLoad(some_dir)
193
+ >>> # If loader has keys "my.activation.first", "my.activation.second", "other.key"
194
+ >>> filtered = loader.filter_by_prefix("my.activation.")
195
+ >>> filtered["first"] # Accesses "my.activation.first"
196
+ >>> filtered["second"] # Accesses "my.activation.second"
197
+ >>> # Can be chained:
198
+ >>> double_filtered = loader.filter_by_prefix("my.").filter_by_prefix("activation.")
199
+ >>> double_filtered["first"] # Also accesses "my.activation.first"
200
+ """
201
+ filtered_activations = {}
202
+ for name, activation in self.activations.items():
203
+ if name.startswith(prefix):
204
+ # Strip the prefix from the key
205
+ stripped_name = name[len(prefix) :]
206
+ filtered_activations[stripped_name] = activation
207
+
208
+ new_prefix_chain = self._prefix_chain + [prefix]
209
+ return ActLoad(
210
+ _base_activations=filtered_activations,
211
+ _prefix_chain=new_prefix_chain,
212
+ dir=self._dir,
213
+ )
214
+
145
215
 
146
216
  ActivationLoader = ActLoad
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import contextlib
4
4
  import fnmatch
5
+ import os
5
6
  import tempfile
6
7
  import weakref
7
8
  from collections.abc import Callable, Mapping
@@ -266,7 +267,11 @@ class ActSaveProvider:
266
267
 
267
268
  if save_dir is None:
268
269
  save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid7str()}"
269
- log.critical(f"No save_dir specified, using {save_dir=}")
270
+ log.warning(
271
+ f"ActSave: Using temporary directory {save_dir} for activations."
272
+ )
273
+ else:
274
+ log.info(f"ActSave enabled. Saving to {save_dir}")
270
275
  self._saver = _Saver(save_dir, lambda: self._prefixes)
271
276
 
272
277
  def disable(self):
@@ -307,6 +312,18 @@ class ActSaveProvider:
307
312
  self._prefixes = []
308
313
  self._disable_count = 0
309
314
 
315
+ # Check for environment variable `ACTSAVE` to automatically enable saving.
316
+ # If set to "1" or "true" (case-insensitive), activations are saved to a temporary directory.
317
+ # If set to a path, activations are saved to that path.
318
+ if env_var := os.environ.get("ACTSAVE"):
319
+ log.info(
320
+ f"`ACTSAVE={env_var}` detected, attempting to auto-enable activation saving."
321
+ )
322
+ if env_var.lower() in ("1", "true"):
323
+ self.enable()
324
+ else:
325
+ self.enable(Path(env_var))
326
+
310
327
  @contextlib.contextmanager
311
328
  def disabled(self, condition: bool | Callable[[], bool] = True):
312
329
  """
@@ -8,3 +8,5 @@ from .numpy_ import numpy_monkey_patch as numpy_monkey_patch
8
8
  from .numpy_ import numpy_repr as numpy_repr
9
9
  from .torch_ import torch_monkey_patch as torch_monkey_patch
10
10
  from .torch_ import torch_repr as torch_repr
11
+
12
+ lovely_monkey_patch = monkey_patch
nshutils/lovely/_base.py CHANGED
@@ -1,10 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import contextlib
3
4
  import functools
4
5
  import importlib.util
5
6
  import logging
6
- from collections.abc import Callable, Iterator
7
- from typing import Generic, Optional, cast
7
+ from abc import ABC, abstractmethod
8
+ from collections.abc import Callable
9
+ from typing import Any, Generic, Optional, cast
8
10
 
9
11
  from typing_extensions import (
10
12
  ParamSpec,
@@ -15,7 +17,6 @@ from typing_extensions import (
15
17
  runtime_checkable,
16
18
  )
17
19
 
18
- from ..util import ContextResource, resource_factory_contextmanager
19
20
  from .utils import LovelyStats, format_tensor_stats
20
21
 
21
22
  log = logging.getLogger(__name__)
@@ -113,72 +114,47 @@ class lovely_repr(Generic[TArray]):
113
114
  return wrapper
114
115
 
115
116
 
116
- LovelyMonkeyPatchInputFn = TypeAliasType(
117
- "LovelyMonkeyPatchInputFn",
118
- Callable[P, Iterator[None]],
119
- type_params=(P,),
120
- )
121
- LovelyMonkeyPatchFn = TypeAliasType(
122
- "LovelyMonkeyPatchFn",
123
- Callable[P, ContextResource[None]],
124
- type_params=(P,),
125
- )
117
+ class lovely_patch(contextlib.AbstractContextManager["lovely_patch"], ABC):
118
+ def __init__(self):
119
+ self._patched = False
120
+ self.__enter__()
126
121
 
122
+ def dependencies(self) -> list[str]:
123
+ """Subclasses can override this to specify the dependencies of the patch."""
124
+ return []
127
125
 
128
- def _nullcontext_generator():
129
- """A generator that does nothing."""
130
- yield
126
+ @abstractmethod
127
+ def patch(self):
128
+ """Subclasses must implement this."""
131
129
 
130
+ @abstractmethod
131
+ def unpatch(self):
132
+ """Subclasses must implement this."""
132
133
 
133
- def _wrap_monkey_patch_fn(
134
- monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
135
- dependencies: list[str],
136
- ) -> LovelyMonkeyPatchInputFn[P]:
137
- @functools.wraps(monkey_patch_fn)
138
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
139
- if missing_deps := _find_missing_deps(dependencies):
134
+ @override
135
+ def __enter__(self):
136
+ if self._patched:
137
+ return self
138
+
139
+ if missing_deps := _find_missing_deps(self.dependencies()):
140
140
  log.warning(
141
141
  f"Missing dependencies: {', '.join(missing_deps)}. "
142
142
  "Skipping monkey patch."
143
143
  )
144
- return _nullcontext_generator()
145
-
146
- return monkey_patch_fn(*args, **kwargs)
147
-
148
- return wrapper
149
-
150
-
151
- def monkey_patch_contextmanager(dependencies: list[str]):
152
- """
153
- Decorator to create a monkey patch function for an array.
154
-
155
- Args:
156
- dependencies: List of dependencies to check before running the function.
157
- If any dependency is not available, the function will not run.
144
+ return self
158
145
 
159
- Returns:
160
- A decorator function that takes a function and returns a monkey patch function.
146
+ self.patch()
147
+ self._patched = True
148
+ return self
161
149
 
162
- Example:
163
- @monkey_patch_contextmanager(dependencies=["torch"])
164
- def my_array_monkey_patch():
165
- ...
166
- """
167
-
168
- def decorator_fn(
169
- monkey_patch_fn: LovelyMonkeyPatchInputFn[P],
170
- ) -> LovelyMonkeyPatchFn[P]:
171
- """
172
- Decorator to create a monkey patch function for an array.
173
-
174
- Args:
175
- monkey_patch_fn: A function that applies the monkey patch.
176
-
177
- Returns:
178
- A function that applies the monkey patch.
179
- """
150
+ @override
151
+ def __exit__(self, *exc_info):
152
+ if not self._patched:
153
+ return
180
154
 
181
- wrapped_fn = _wrap_monkey_patch_fn(monkey_patch_fn, dependencies)
182
- return resource_factory_contextmanager(wrapped_fn)
155
+ self.unpatch()
156
+ self._patched = False
183
157
 
184
- return decorator_fn
158
+ def close(self):
159
+ """Explicitly clean up the resource."""
160
+ self.__exit__(None, None, None)
@@ -5,9 +5,9 @@ import importlib.util
5
5
  import logging
6
6
  from typing import Literal
7
7
 
8
- from typing_extensions import TypeAliasType, assert_never
8
+ from typing_extensions import TypeAliasType, assert_never, override
9
9
 
10
- from ..util import resource_factory_contextmanager
10
+ from ._base import lovely_patch
11
11
 
12
12
  Library = TypeAliasType("Library", Literal["numpy", "torch", "jax"])
13
13
 
@@ -28,40 +28,48 @@ def _find_deps() -> list[Library]:
28
28
  return deps
29
29
 
30
30
 
31
- @resource_factory_contextmanager
32
- def monkey_patch(libraries: list[Library] | Literal["auto"] = "auto"):
33
- if libraries == "auto":
34
- libraries = _find_deps()
31
+ class monkey_patch(lovely_patch):
32
+ def __init__(self, libraries: list[Library] | Literal["auto"] = "auto"):
33
+ self.libraries = libraries
34
+ if self.libraries == "auto":
35
+ self.libraries = _find_deps()
35
36
 
36
- if not libraries:
37
- raise ValueError(
38
- "No libraries found for monkey patching. "
39
- "Please install numpy, torch, or jax."
40
- )
37
+ if not self.libraries:
38
+ raise ValueError(
39
+ "No libraries found for monkey patching. "
40
+ "Please install numpy, torch, or jax."
41
+ )
42
+
43
+ self.stack = contextlib.ExitStack()
44
+ super().__init__()
41
45
 
42
- with contextlib.ExitStack() as stack:
43
- for library in libraries:
46
+ @override
47
+ def patch(self):
48
+ for library in self.libraries:
44
49
  if library == "torch":
45
50
  from .torch_ import torch_monkey_patch
46
51
 
47
- stack.enter_context(torch_monkey_patch())
52
+ self.stack.enter_context(torch_monkey_patch())
48
53
  elif library == "jax":
49
54
  from .jax_ import jax_monkey_patch
50
55
 
51
- stack.enter_context(jax_monkey_patch())
56
+ self.stack.enter_context(jax_monkey_patch())
52
57
  elif library == "numpy":
53
58
  from .numpy_ import numpy_monkey_patch
54
59
 
55
- stack.enter_context(numpy_monkey_patch())
60
+ self.stack.enter_context(numpy_monkey_patch())
56
61
  else:
57
- assert_never(library)
62
+ assert_never(library) # type: ignore
58
63
 
59
64
  log.info(
60
- f"Monkey patched libraries: {', '.join(libraries)}. "
65
+ f"Monkey patched libraries: {', '.join(self.libraries)}. "
61
66
  "You can now use the lovely functions with these libraries."
62
67
  )
63
- yield
68
+
69
+ @override
70
+ def unpatch(self):
71
+ self.stack.close()
64
72
  log.info(
65
- f"Unmonkey patched libraries: {', '.join(libraries)}. "
73
+ f"Unmonkey patched libraries: {', '.join(self.libraries)}. "
66
74
  "You can now use the lovely functions with these libraries."
67
75
  )
nshutils/lovely/jax_.py CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, cast
4
4
 
5
5
  import numpy as np
6
+ from typing_extensions import override
6
7
 
7
- from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from ._base import lovely_patch, lovely_repr
8
9
  from .utils import LovelyStats, array_stats, patch_to
9
10
 
10
11
  if TYPE_CHECKING:
@@ -77,18 +78,25 @@ def jax_repr(array: jax.Array) -> LovelyStats | None:
77
78
  }
78
79
 
79
80
 
80
- @monkey_patch_contextmanager(dependencies=["jax"])
81
- def jax_monkey_patch():
82
- from jax._src import array
81
+ class jax_monkey_patch(lovely_patch):
82
+ @override
83
+ def dependencies(self) -> list[str]:
84
+ return ["jax"]
85
+
86
+ @override
87
+ def patch(self):
88
+ from jax._src import array
89
+
90
+ self.prev_repr = array.ArrayImpl.__repr__
91
+ self.prev_str = array.ArrayImpl.__str__
92
+ jax_repr.set_fallback_repr(self.prev_repr)
83
93
 
84
- prev_repr = array.ArrayImpl.__repr__
85
- prev_str = array.ArrayImpl.__str__
86
- jax_repr.set_fallback_repr(prev_repr)
87
- try:
88
94
  patch_to(array.ArrayImpl, "__repr__", jax_repr)
89
95
  patch_to(array.ArrayImpl, "__str__", jax_repr)
90
96
 
91
- yield
92
- finally:
93
- patch_to(array.ArrayImpl, "__repr__", prev_repr)
94
- patch_to(array.ArrayImpl, "__str__", prev_str)
97
+ @override
98
+ def unpatch(self):
99
+ from jax._src import array
100
+
101
+ patch_to(array.ArrayImpl, "__repr__", self.prev_repr)
102
+ patch_to(array.ArrayImpl, "__str__", self.prev_str)
nshutils/lovely/numpy_.py CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
  import logging
4
4
 
5
5
  import numpy as np
6
+ from typing_extensions import override
6
7
 
7
- from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from ._base import lovely_patch, lovely_repr
8
9
  from .utils import LovelyStats, array_stats
9
10
 
10
11
 
@@ -74,42 +75,41 @@ def numpy_repr(array: np.ndarray) -> LovelyStats | None:
74
75
  numpy_repr.set_fallback_repr(np.array_repr)
75
76
 
76
77
 
77
- # If numpy 2.0, use the new API override_repr.
78
- if _np_ge_2():
78
+ class numpy_monkey_patch(lovely_patch):
79
+ @override
80
+ def dependencies(self) -> list[str]:
81
+ return ["numpy"]
79
82
 
80
- @monkey_patch_contextmanager(dependencies=["numpy"])
81
- def numpy_monkey_patch():
82
- try:
83
+ @override
84
+ def patch(self):
85
+ if _np_ge_2():
86
+ self.original_options = np.get_printoptions()
83
87
  np.set_printoptions(override_repr=numpy_repr)
84
88
  logging.info(
85
89
  f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
86
90
  f"{np.get_printoptions()=}"
87
91
  )
88
- yield
89
- finally:
90
- np.set_printoptions(override_repr=None)
92
+ else:
93
+ # For legacy numpy, `set_string_function(None)` reverts to the default,
94
+ # so no state needs to be saved.
95
+ np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
96
+ np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
91
97
  logging.info(
92
- f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
98
+ f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
93
99
  f"{np.get_printoptions()=}"
94
100
  )
95
101
 
96
- else:
97
-
98
- @monkey_patch_contextmanager(dependencies=["numpy"])
99
- def numpy_monkey_patch():
100
- try:
101
- np.set_string_function(numpy_repr, True) # pyright: ignore[reportAttributeAccessIssue]
102
- np.set_string_function(numpy_repr, False) # pyright: ignore[reportAttributeAccessIssue]
103
-
102
+ @override
103
+ def unpatch(self):
104
+ if _np_ge_2():
105
+ np.set_printoptions(**self.original_options)
104
106
  logging.info(
105
- f"Numpy monkey patching: using {numpy_repr.__name__} for numpy arrays. "
107
+ f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
106
108
  f"{np.get_printoptions()=}"
107
109
  )
108
- yield
109
- finally:
110
+ else:
110
111
  np.set_string_function(None, True) # pyright: ignore[reportAttributeAccessIssue]
111
112
  np.set_string_function(None, False) # pyright: ignore[reportAttributeAccessIssue]
112
-
113
113
  logging.info(
114
114
  f"Numpy unmonkey patching: using {numpy_repr.__name__} for numpy arrays. "
115
115
  f"{np.get_printoptions()=}"
nshutils/lovely/torch_.py CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
 
5
5
  import numpy as np
6
+ from typing_extensions import override
6
7
 
7
- from ._base import lovely_repr, monkey_patch_contextmanager
8
+ from ._base import lovely_patch, lovely_repr
8
9
  from .utils import LovelyStats, array_stats, patch_to
9
10
 
10
11
  if TYPE_CHECKING:
@@ -80,16 +81,20 @@ def torch_repr(tensor: torch.Tensor) -> LovelyStats | None:
80
81
  }
81
82
 
82
83
 
83
- @monkey_patch_contextmanager(dependencies=["torch"])
84
- def torch_monkey_patch():
85
- import torch
84
+ class torch_monkey_patch(lovely_patch):
85
+ @override
86
+ def dependencies(self) -> list[str]:
87
+ return ["torch"]
88
+
89
+ @override
90
+ def patch(self):
91
+ import torch
86
92
 
87
- original_repr = torch.Tensor.__repr__
88
- original_str = torch.Tensor.__str__
89
- original_parameter_repr = torch.nn.Parameter.__repr__
90
- torch_repr.set_fallback_repr(original_repr)
93
+ self.original_repr = torch.Tensor.__repr__
94
+ self.original_str = torch.Tensor.__str__
95
+ self.original_parameter_repr = torch.nn.Parameter.__repr__
96
+ torch_repr.set_fallback_repr(self.original_repr)
91
97
 
92
- try:
93
98
  patch_to(torch.Tensor, "__repr__", torch_repr)
94
99
  patch_to(torch.Tensor, "__str__", torch_repr)
95
100
  try:
@@ -97,8 +102,10 @@ def torch_monkey_patch():
97
102
  except AttributeError:
98
103
  pass
99
104
 
100
- yield
101
- finally:
102
- patch_to(torch.Tensor, "__repr__", original_repr)
103
- patch_to(torch.Tensor, "__str__", original_str)
104
- patch_to(torch.nn.Parameter, "__repr__", original_parameter_repr)
105
+ @override
106
+ def unpatch(self):
107
+ import torch
108
+
109
+ patch_to(torch.Tensor, "__repr__", self.original_repr)
110
+ patch_to(torch.Tensor, "__str__", self.original_str)
111
+ patch_to(torch.nn.Parameter, "__repr__", self.original_parameter_repr)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshutils
3
- Version: 0.30.1
3
+ Version: 0.31.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -0,0 +1,21 @@
1
+ nshutils/__init__.py,sha256=AFx1d5k34MyJ2kCHQL5vrZB8GDp2nYUaIUEjszSa25I,477
2
+ nshutils/__init__.pyi,sha256=R4TIk--jAgVyTibdgezJQTMce3HpMCNakAJeaDqA6bc,676
3
+ nshutils/actsave/__init__.py,sha256=hAVsog9d1g3_rQN1TRslrl6sK1PhCGbjy8PPUAmJI58,203
4
+ nshutils/actsave/_loader.py,sha256=btLSQdErpTmK6VyG8PxJrJNsztzyavSF71n4Ec3_49E,7619
5
+ nshutils/actsave/_saver.py,sha256=_qkX0NZYvy31hdlyfhneac4kUNS_44XjOG0ZtKpdqrg,12720
6
+ nshutils/collections.py,sha256=QWGyANmo4Efq4XRNHDSTE9tRLStwEZHGwE0ATHR-Vqo,5233
7
+ nshutils/display.py,sha256=Ge63yllx7gi-MKL3mKQeQ5doql_nj56-o5aoTVmusDg,1473
8
+ nshutils/logging.py,sha256=78pv3-I_gmbKSf5_mYYBr6_H4GNBGErghAdhH9wfYIc,2205
9
+ nshutils/lovely/__init__.py,sha256=684eZOKLmSgsTcCVlWk1Ip1cxJxmz-rKeXLmWXuCEWA,487
10
+ nshutils/lovely/_base.py,sha256=kJY-UhdFTRBlAg_YzfJmG4ICb6vSdOJKiRc6vksxvoE,4424
11
+ nshutils/lovely/_monkey_patch_all.py,sha256=xq09InGcOsGDrELV_KIrhE0H4EWyMdrUZ_1_BR2e_b0,2224
12
+ nshutils/lovely/config.py,sha256=lVNMuU1oUvsYlGN0Sn-m6iOLbJIchVnWDpyHm09nWo8,1224
13
+ nshutils/lovely/jax_.py,sha256=PGnv33LrEM3aLvXLBbAx4b7dOkJwONidyPZjToZ62Og,2592
14
+ nshutils/lovely/numpy_.py,sha256=BBP9663l4Hr-TB34xDMHQQZ1zpuOgBegUOGl7_wV6R0,3503
15
+ nshutils/lovely/torch_.py,sha256=J1pDJY1zzEANqa6EaJpG1pc_SYgM8YWOo1TjWdVeiA0,2946
16
+ nshutils/lovely/utils.py,sha256=2ksT5YGVViFuWc8jSkwVCsABripJmyVJdEDDH7aab70,10459
17
+ nshutils/snoop.py,sha256=7d7_Q5sJmINL1J29wcnxEvpV95zvZYNoVn5frCq-rww,7393
18
+ nshutils/typecheck.py,sha256=Gi7xtfilN_UwZ1FTFqBVKDhcQzBEDonVxIv3bUj-uXY,5582
19
+ nshutils-0.31.1.dist-info/METADATA,sha256=pW-XE-rF3TPtok9VPan8cG6IlaEjlX55NBL4SKaNOHQ,4406
20
+ nshutils-0.31.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
21
+ nshutils-0.31.1.dist-info/RECORD,,
nshutils/util.py DELETED
@@ -1,92 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import contextlib
4
- import functools
5
- from collections.abc import Callable, Iterator
6
- from typing import Any, Generic
7
-
8
- from typing_extensions import ParamSpec, TypeVar, override
9
-
10
- R = TypeVar("R")
11
- P = ParamSpec("P")
12
-
13
-
14
- class ContextResource(contextlib.AbstractContextManager[R], Generic[R]):
15
- """A class that provides both direct access to a resource and context management."""
16
-
17
- def __init__(self, resource: R, cleanup_func: Callable[[R], Any]):
18
- self.resource = resource
19
- self._cleanup_func = cleanup_func
20
-
21
- @override
22
- def __enter__(self) -> R:
23
- """When used as a context manager, return the wrapped resource."""
24
- return self.resource
25
-
26
- @override
27
- def __exit__(self, *exc_info) -> None:
28
- """Clean up the resource when exiting the context."""
29
- self._cleanup_func(self.resource)
30
-
31
- def close(self) -> None:
32
- """Explicitly clean up the resource."""
33
- self._cleanup_func(self.resource)
34
-
35
-
36
- def resource_factory(
37
- create_func: Callable[P, R], cleanup_func: Callable[[R], None]
38
- ) -> Callable[P, ContextResource[R]]:
39
- """
40
- Create a factory function that returns a ContextResource.
41
-
42
- Args:
43
- create_func: Function that creates the resource
44
- cleanup_func: Function that cleans up the resource
45
-
46
- Returns:
47
- A function that returns a ContextResource wrapping the created resource
48
- """
49
-
50
- @functools.wraps(create_func)
51
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
52
- resource = create_func(*args, **kwargs)
53
- return ContextResource(resource, cleanup_func)
54
-
55
- return wrapper
56
-
57
-
58
- def resource_factory_from_context_fn(
59
- context_func: Callable[P, contextlib.AbstractContextManager[R]],
60
- ) -> Callable[P, ContextResource[R]]:
61
- """
62
- Create a factory function that returns a ContextResource.
63
-
64
- Args:
65
- context_func: Function that creates the resource
66
-
67
- Returns:
68
- A function that returns a ContextResource wrapping the created resource
69
- """
70
-
71
- @functools.wraps(context_func)
72
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextResource[R]:
73
- context = context_func(*args, **kwargs)
74
- resource = context.__enter__()
75
- return ContextResource(resource, lambda _: context.__exit__(None, None, None))
76
-
77
- return wrapper
78
-
79
-
80
- def resource_factory_contextmanager(
81
- context_func: Callable[P, Iterator[R]],
82
- ) -> Callable[P, ContextResource[R]]:
83
- """
84
- Create a factory function that returns a ContextResource.
85
-
86
- Args:
87
- context_func: Generator function that creates the resource, yields it, and cleans up the resource when done.
88
-
89
- Returns:
90
- A function that returns a ContextResource wrapping the created resource
91
- """
92
- return resource_factory_from_context_fn(contextlib.contextmanager(context_func))
@@ -1,22 +0,0 @@
1
- nshutils/__init__.py,sha256=AFx1d5k34MyJ2kCHQL5vrZB8GDp2nYUaIUEjszSa25I,477
2
- nshutils/__init__.pyi,sha256=ICbY2_XBAlXIVOGyK4PQpatmlUFHHc5-bqM4sfFZoAY,613
3
- nshutils/actsave/__init__.py,sha256=hAVsog9d1g3_rQN1TRslrl6sK1PhCGbjy8PPUAmJI58,203
4
- nshutils/actsave/_loader.py,sha256=mof3HezUNvLliz7macstX6ewXW05L0Mtv3zJyrbmImg,4640
5
- nshutils/actsave/_saver.py,sha256=IS9TVP8WUizoj5fHrQ6hodtjidT__LDRwz5aoWHupVo,12013
6
- nshutils/collections.py,sha256=QWGyANmo4Efq4XRNHDSTE9tRLStwEZHGwE0ATHR-Vqo,5233
7
- nshutils/display.py,sha256=Ge63yllx7gi-MKL3mKQeQ5doql_nj56-o5aoTVmusDg,1473
8
- nshutils/logging.py,sha256=78pv3-I_gmbKSf5_mYYBr6_H4GNBGErghAdhH9wfYIc,2205
9
- nshutils/lovely/__init__.py,sha256=gbWMNs7xfK1CiNdkHvfH0KcyaGjdZ8_WUBGfaEUDN4I,451
10
- nshutils/lovely/_base.py,sha256=-JYF2zci04PJjmkBdm_iV3uWgD_d7e5zCIAINDlQIKc,5266
11
- nshutils/lovely/_monkey_patch_all.py,sha256=zgMupp2Wc_O9R3arl-BAIePpvQSi6TCeshGMaui-Cc8,1986
12
- nshutils/lovely/config.py,sha256=lVNMuU1oUvsYlGN0Sn-m6iOLbJIchVnWDpyHm09nWo8,1224
13
- nshutils/lovely/jax_.py,sha256=c_hvlch_c9OZ0WJjFIeY46kKQcCELspwdmoexkKLsCg,2412
14
- nshutils/lovely/numpy_.py,sha256=GDOOuhCYfShfKUZiuI8J91eAm27urrYyxETTR-Mxz0E,3362
15
- nshutils/lovely/torch_.py,sha256=9diSkM1L2B6l0yQqTRBoZUVElyqgHUcJdFCXD3NvTxk,2767
16
- nshutils/lovely/utils.py,sha256=2ksT5YGVViFuWc8jSkwVCsABripJmyVJdEDDH7aab70,10459
17
- nshutils/snoop.py,sha256=7d7_Q5sJmINL1J29wcnxEvpV95zvZYNoVn5frCq-rww,7393
18
- nshutils/typecheck.py,sha256=Gi7xtfilN_UwZ1FTFqBVKDhcQzBEDonVxIv3bUj-uXY,5582
19
- nshutils/util.py,sha256=tx-XiRbOrpafV3OkJDE5IVFtzn3kN7uSZ8FkMor0H5c,2845
20
- nshutils-0.30.1.dist-info/METADATA,sha256=iFnO4L_bpOdtCm8f-m7IvA4A8WtkWb5T07dYi-aUnzI,4406
21
- nshutils-0.30.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
22
- nshutils-0.30.1.dist-info/RECORD,,