nshutils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshutils/__init__.py +2 -0
- nshutils/snoop.py +216 -0
- nshutils/typecheck.py +155 -0
- nshutils-0.1.0.dist-info/METADATA +21 -0
- nshutils-0.1.0.dist-info/RECORD +6 -0
- nshutils-0.1.0.dist-info/WHEEL +4 -0
nshutils/__init__.py
ADDED
nshutils/snoop.py
ADDED
@@ -0,0 +1,216 @@
|
|
1
|
+
import contextlib
|
2
|
+
from typing import Any, Protocol, cast
|
3
|
+
|
4
|
+
from typing_extensions import TypeVar
|
5
|
+
|
6
|
+
# NOTE(review): `T` is not referenced anywhere else in this module and looks
# unused here; confirm no other module imports it before removing.
T = TypeVar("T", infer_variance=True)
|
7
|
+
|
8
|
+
|
9
|
+
class SnoopConstructor(Protocol):
    """Structural type of the module-level ``snoop`` object.

    Calling ``snoop(...)`` yields a context manager. ``snoop.disable()``
    also yields one; both implementations in this module return a
    ``nullcontext`` from it.
    """

    def __call__(self, *args, **kwargs) -> contextlib.AbstractContextManager:
        """Create a context manager; the accepted arguments depend on the implementation."""

    def disable(self) -> contextlib.AbstractContextManager:
        """Return a context manager that performs no snooping."""
|
13
|
+
|
14
|
+
|
15
|
+
try:
|
16
|
+
import importlib.metadata
import warnings
from contextlib import nullcontext

import lovely_numpy as lo
import lovely_tensors as lt
import numpy
import pysnooper
import pysnooper.utils
import torch

# Torch dtypes treated as floating point. The ``hasattr`` guard skips dtypes
# that older PyTorch versions do not define (e.g. the complex ones).
FLOATING_POINTS = {
    getattr(torch, dtype_name)
    for dtype_name in (
        "float",
        "double",
        "half",
        "complex128",
        "complex32",
        "complex64",
    )
    if hasattr(torch, dtype_name)
}

# Version of the installed distribution that provides this module. The lookup
# uses the top-level package name: the dotted module name ("nshutils.snoop")
# is never a distribution name. ``pkg_resources`` is deliberately not used. It
# is deprecated, and when setuptools is missing its ImportError would trigger
# the ``except ImportError`` fallback and silently turn snoop into a no-op.
try:
    __version__ = importlib.metadata.version(__name__.partition(".")[0])
except importlib.metadata.PackageNotFoundError:
    # package is not installed
    pass
|
37
|
+
|
38
|
+
def default_format(x):
    """Render tensor ``x`` with ``lovely_tensors``, falling back to its shape.

    If formatting raises an ``Exception``, ``str(x.shape)`` is returned
    instead so that tracing keeps going. ``BaseException`` subclasses such as
    ``KeyboardInterrupt`` and ``SystemExit`` are intentionally not caught, so
    an interrupt during tracing still stops the program.
    """
    try:
        return str(lt.lovely(x))
    except Exception:
        return str(x.shape)
|
44
|
+
|
45
|
+
def default_numpy_format(x):
    """Return the ``lovely_numpy`` summary of array ``x`` as a string."""
    summary = lo.lovely(x)
    return str(summary)
|
47
|
+
|
48
|
+
class TorchSnooper(pysnooper.tracer.Tracer):
    """A ``pysnooper`` tracer that renders tensors, arrays and torch return types compactly.

    Every value pysnooper displays is routed through :meth:`compute_repr`,
    which is installed as an always-matching ``custom_repr`` entry:

    - torch tensors are rendered with ``tensor_format``;
    - numpy arrays are rendered with ``numpy_format``;
    - ``torch.return_types`` structs are rendered field by field;
    - lists, tuples and dicts are rendered recursively, so tensors nested
      inside containers are formatted too.

    Args:
        *args: Positional arguments forwarded to ``pysnooper.tracer.Tracer``.
        tensor_format: Callable mapping a tensor to its display string.
        numpy_format: Callable mapping a numpy array to its display string.
        **kwargs: Keyword arguments forwarded to ``pysnooper.tracer.Tracer``.
            A caller-supplied ``custom_repr`` is kept and consulted for
            values that are not tensors, arrays or return types.
    """

    # Field names of the ``torch.return_types`` structs rendered by name,
    # keyed by the struct's type name. Fields are printed in this order.
    _RETURN_TYPE_FIELDS = {
        "max": ("values", "indices"),
        "min": ("values", "indices"),
        "median": ("values", "indices"),
        "mode": ("values", "indices"),
        "sort": ("values", "indices"),
        "topk": ("values", "indices"),
        "kthvalue": ("values", "indices"),
        "svd": ("U", "S", "V"),
        "slogdet": ("sign", "logabsdet"),
        "qr": ("Q", "R"),
        "solve": ("solution", "LU"),
        "geqrf": ("a", "tau"),
        "symeig": ("eigenvalues", "eigenvectors"),
        "eig": ("eigenvalues", "eigenvectors"),
        "triangular_solve": ("solution", "cloned_coefficient"),
        "gels": ("solution", "QR"),
    }

    def __init__(
        self,
        *args,
        tensor_format=default_format,
        numpy_format=default_numpy_format,
        **kwargs,
    ):
        from collections.abc import Iterable

        orig_custom_repr = kwargs.get("custom_repr") or ()
        # pysnooper also accepts a single ``(condition, action)`` pair instead
        # of a sequence of pairs. Normalize it the same way pysnooper does, so
        # that ``get_repr_function`` in ``compute_repr`` can iterate it as pairs.
        if len(orig_custom_repr) == 2 and not all(
            isinstance(entry, Iterable) for entry in orig_custom_repr
        ):
            orig_custom_repr = (orig_custom_repr,)
        self.orig_custom_repr = orig_custom_repr

        # Route every value through ``compute_repr``. The caller's own
        # ``custom_repr`` is consulted from inside ``compute_repr``.
        kwargs["custom_repr"] = ((lambda x: True, self.compute_repr),)
        super().__init__(*args, **kwargs)
        self.tensor_format = tensor_format
        self.numpy_format = numpy_format

    @staticmethod
    def is_return_types(x):
        """Return True if ``x``'s type is defined in ``torch.return_types``."""
        return type(x).__module__ == "torch.return_types"

    def return_types_repr(self, x):
        """Render a ``torch.return_types`` struct as ``name(field=..., ...)``.

        Known structs (see ``_RETURN_TYPE_FIELDS``) are rendered with their
        field names and ``tensor_format``. Unknown ones emit a warning and are
        rendered positionally through :meth:`compute_repr`. A string is always
        returned, because pysnooper post-processes the result with string
        methods and would crash on ``None``.
        """
        type_name = type(x).__name__
        field_names = self._RETURN_TYPE_FIELDS.get(type_name)
        if field_names is None:
            warnings.warn("Unknown return_types encountered, open a bug report!")
            # Iterating the struct yields its values in field order.
            # NOTE(review): assumes return_types structs are iterable
            # (tuple-like); if iteration raises, pysnooper reports a failed repr.
            parts = [self.compute_repr(value) for value in x]
        else:
            parts = [
                f"{name}={self.tensor_format(getattr(x, name))}"
                for name in field_names
            ]
        return f"{type_name}({', '.join(parts)})"

    def compute_repr(self, x):
        """Return the display string for any value pysnooper wants to show.

        Precedence, in order:

        1. tensors, via ``tensor_format``;
        2. numpy arrays, via ``numpy_format``;
        3. ``torch.return_types`` structs;
        4. the caller's own ``custom_repr``;
        5. recursive rendering of lists, tuples and dicts;
        6. plain ``repr``.
        """
        if torch.is_tensor(x):
            return self.tensor_format(x)
        if isinstance(x, numpy.ndarray):
            return self.numpy_format(x)
        if self.is_return_types(x):
            return self.return_types_repr(x)

        orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
        if orig_repr_func is not repr:
            return orig_repr_func(x)

        if isinstance(x, (list, tuple)):
            content = ", ".join(self.compute_repr(item) for item in x)
            if isinstance(x, tuple):
                # A 1-tuple needs its trailing comma to read as a tuple.
                trailing = "," if len(x) == 1 else ""
                return "(" + content + trailing + ")"
            return "[" + content + "]"
        if isinstance(x, dict):
            content = ", ".join(
                self.compute_repr(key) + ": " + self.compute_repr(value)
                for key, value in x.items()
            )
            return "{" + content + "}"
        return repr(x)
|
185
|
+
|
186
|
+
class _Snoop:
    """Instance-callable facade behind the module-level ``snoop`` object.

    Both attributes are classes rather than functions. Classes define no
    ``__get__``, so looking them up on an instance does not bind ``self``.
    As a result, ``snoop(...)`` is exactly ``TorchSnooper(...)`` and
    ``snoop.disable()`` is exactly ``nullcontext()``.
    """

    __call__ = TorchSnooper
    disable = nullcontext


# ``cast(Any, ...)`` lets the instance be annotated with the structural
# ``SnoopConstructor`` protocol for type checkers.
snoop: SnoopConstructor = cast(Any, _Snoop())
|
191
|
+
|
192
|
+
except ImportError:
|
193
|
+
import warnings
from contextlib import nullcontext

from typing_extensions import override

# Set once the "snoop unavailable" warning has been emitted, so it is shown
# at most once per process.
_has_warned = False


def _warn_snoop_unavailable() -> None:
    """Emit the "snoop unavailable" warning the first time it is needed."""
    global _has_warned
    if not _has_warned:
        warnings.warn(
            "snoop is unavailable because one of its optional dependencies "
            "(pysnooper, torch, lovely-tensors, lovely-numpy) could not be "
            "imported; snoop() is a no-op"
        )
        _has_warned = True


class _snoop_cls(nullcontext):
    """No-op stand-in for ``TorchSnooper``, used when its imports failed.

    It accepts and ignores any arguments the real tracer would take, so call
    sites such as ``snoop(depth=2)`` keep working. It can be used both as a
    context manager (``with snoop(): ...``) and as a decorator
    (``@snoop()``). Either use warns once that snooping is disabled.
    """

    @override
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The arguments are meant for the real tracer; ignore them.
        super().__init__()

    @classmethod
    def disable(cls):
        return nullcontext()

    @override
    def __enter__(self):
        _warn_snoop_unavailable()
        return super().__enter__()

    def __call__(self, func):
        # Decorator form: leave the function untouched.
        _warn_snoop_unavailable()
        return func


snoop: SnoopConstructor = cast(Any, _snoop_cls)
|
nshutils/typecheck.py
ADDED
@@ -0,0 +1,155 @@
|
|
1
|
+
import os
|
2
|
+
from collections.abc import Sequence
|
3
|
+
from logging import getLogger
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from jaxtyping import BFloat16 as BFloat16
|
7
|
+
from jaxtyping import Bool as Bool
|
8
|
+
from jaxtyping import Complex as Complex
|
9
|
+
from jaxtyping import Complex64 as Complex64
|
10
|
+
from jaxtyping import Complex128 as Complex128
|
11
|
+
from jaxtyping import Float as Float
|
12
|
+
from jaxtyping import Float16 as Float16
|
13
|
+
from jaxtyping import Float32 as Float32
|
14
|
+
from jaxtyping import Float64 as Float64
|
15
|
+
from jaxtyping import Inexact as Inexact
|
16
|
+
from jaxtyping import Int as Int
|
17
|
+
from jaxtyping import Int4 as Int4
|
18
|
+
from jaxtyping import Int8 as Int8
|
19
|
+
from jaxtyping import Int16 as Int16
|
20
|
+
from jaxtyping import Int32 as Int32
|
21
|
+
from jaxtyping import Int64 as Int64
|
22
|
+
from jaxtyping import Integer as Integer
|
23
|
+
from jaxtyping import Key as Key
|
24
|
+
from jaxtyping import Num as Num
|
25
|
+
from jaxtyping import Real as Real
|
26
|
+
from jaxtyping import Shaped as Shaped
|
27
|
+
from jaxtyping import UInt as UInt
|
28
|
+
from jaxtyping import UInt4 as UInt4
|
29
|
+
from jaxtyping import UInt8 as UInt8
|
30
|
+
from jaxtyping import UInt16 as UInt16
|
31
|
+
from jaxtyping import UInt32 as UInt32
|
32
|
+
from jaxtyping import UInt64 as UInt64
|
33
|
+
from jaxtyping._storage import get_shape_memo, shape_str
|
34
|
+
from typing_extensions import TypeVar
|
35
|
+
|
36
|
+
# Optional array libraries. Each name is None when its library is not
# installed, and `_make_error_str` checks for None before using it.
try:
    import torch
except ImportError:
    torch = None

try:
    # The module is named ``numpy``; ``np`` is only the conventional alias.
    import numpy as np
except ImportError:
    np = None


try:
    import jax
except ImportError:
    jax = None
log = getLogger(__name__)

# Name of the environment variable that disables type checking when set
# (read by `typecheck_modules` and `tassert`).
DISABLE_ENV_KEY = "NSHUTILS_DISABLE_TYPECHECKING"
|
54
|
+
|
55
|
+
|
56
|
+
def typecheck_modules(modules: Sequence[str]):
    """
    Typecheck the given modules using `jaxtyping`.

    Installs jaxtyping's import hook with ``"beartype.beartype"`` as the
    checker. The hook instruments the named modules when they are imported
    after this call.

    The call does nothing except log a message when the environment variable
    named by ``DISABLE_ENV_KEY`` holds a truthy value. Truthy means a nonzero
    integer, or one of ``true``/``t``/``yes``/``y``/``on`` (case-insensitive).

    Args:
        modules: Modules to typecheck.
    """
    # Skip typechecking if `DISABLE_ENV_KEY` is configured and its variable is
    # truthy. Non-integer values such as "true" are accepted instead of
    # raising ValueError; an unrecognized value counts as "not disabled".
    disabled = False
    if DISABLE_ENV_KEY is not None:
        raw_flag = os.environ.get(DISABLE_ENV_KEY, "0").strip().lower()
        try:
            disabled = bool(int(raw_flag))
        except ValueError:
            disabled = raw_flag in ("true", "t", "yes", "y", "on")
    if disabled:
        log.critical(
            f"Type checking is disabled due to the environment variable {DISABLE_ENV_KEY}."
        )
        return

    # Install the jaxtyping import hook for these modules.
    from jaxtyping import install_import_hook

    install_import_hook(modules, "beartype.beartype")

    log.critical(f"Type checking the following modules: {modules}")
|
77
|
+
|
78
|
+
|
79
|
+
def typecheck_this_module(additional_modules: Sequence[str] = ()):
    """
    Typecheck the calling code, plus any additional modules, using `jaxtyping`.

    The caller is located with beartype's private frame helpers, the same
    machinery behind ``beartype_this_package``. Despite this function's name,
    the name passed on comes from ``get_frame_package_name``, so it is
    presumably the caller's *package* rather than its module. Confirm against
    beartype before relying on the distinction.

    Args:
        additional_modules: Additional modules to typecheck.
    """
    from beartype._util.func.utilfuncframe import get_frame, get_frame_package_name

    assert get_frame is not None, "get_frame is None"
    # get_frame(1) is expected to return the frame that called this function
    # (sys._getframe-style depth), so this call must stay directly in here.
    caller_frame = get_frame(1)
    assert caller_frame is not None, "frame is None"

    caller_name = get_frame_package_name(caller_frame)
    assert caller_name is not None, "calling_module_name is None"

    names_to_check = (caller_name, *additional_modules)
    typecheck_modules(names_to_check)
|
100
|
+
|
101
|
+
|
102
|
+
def _make_error_str(input: Any, t: Any) -> str:
    """
    Build the assertion message `tassert` shows when ``isinstance(input, t)`` fails.

    The message has one entry per line:

    - a fixed header;
    - ``t``'s own explanation, when ``t`` defines ``__instancecheck_str__``;
    - a summary of ``input`` when it is a torch tensor, jax array or numpy
      array, via the matching ``lovely_*`` package, falling back to the
      array's shape;
    - the current jaxtyping shape memo.
    """
    error_components: list[str] = ["Type checking error:"]
    if hasattr(t, "__instancecheck_str__"):
        # NOTE(review): presumably provided by jaxtyping array annotations to
        # explain the failed check; confirm.
        error_components.append(t.__instancecheck_str__(input))

    # The lovely_* formatters are optional. If one is missing or fails, fall
    # back to the bare shape. ``Exception`` rather than ``BaseException`` is
    # caught so that KeyboardInterrupt/SystemExit still propagate.
    if torch is not None and torch.is_tensor(input):
        try:
            from lovely_tensors import lovely

            error_components.append(repr(lovely(input)))
        except Exception:
            error_components.append(repr(input.shape))
    elif jax is not None and isinstance(input, jax.Array):
        try:
            from lovely_jax import lovely

            error_components.append(repr(lovely(input)))
        except Exception:
            error_components.append(repr(input.shape))
    elif np is not None and isinstance(input, np.ndarray):
        try:
            from lovely_numpy import lovely

            error_components.append(repr(lovely(input)))
        except Exception:
            error_components.append(repr(input.shape))
    error_components.append(shape_str(get_shape_memo()))

    return "\n".join(error_components)
|
131
|
+
|
132
|
+
|
133
|
+
T = TypeVar("T", infer_variance=True)


def tassert(t: Any, input: T | tuple[T, ...]):
    """
    Typecheck the input against the given type.

    ``input`` may be a single value or a tuple of values. For a tuple, every
    element is checked against ``t`` individually. The checks are plain
    ``assert`` statements, so they are skipped under ``python -O``. The
    failure message is built by ``_make_error_str``.

    The check is skipped entirely when the environment variable named by
    ``DISABLE_ENV_KEY`` holds a truthy value. Truthy means a nonzero integer,
    or one of ``true``/``t``/``yes``/``y``/``on`` (case-insensitive).

    Args:
        t: Type to check against.
        input: Input to check.

    Raises:
        AssertionError: If ``input``, or any element of a tuple ``input``,
            is not an instance of ``t``.
    """
    # pytest convention: hide this frame so failures point at the caller.
    __tracebackhide__ = True

    # Ignore typechecking if the environment variable is set. Non-integer
    # values such as "true" are accepted instead of raising ValueError on
    # every call; an unrecognized value counts as "not disabled".
    if DISABLE_ENV_KEY is not None:
        raw_flag = os.environ.get(DISABLE_ENV_KEY, "0").strip().lower()
        try:
            disabled = bool(int(raw_flag))
        except ValueError:
            disabled = raw_flag in ("true", "t", "yes", "y", "on")
        if disabled:
            return

    # Treat a single value as a one-element tuple so both cases share a loop.
    values = input if isinstance(input, tuple) else (input,)
    for value in values:
        assert isinstance(value, t), _make_error_str(value, t)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: nshutils
|
3
|
+
Version: 0.1.0
|
4
|
+
Summary:
|
5
|
+
Author: Nima Shoghi
|
6
|
+
Author-email: nimashoghi@gmail.com
|
7
|
+
Requires-Python: >=3.10,<4.0
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
12
|
+
Requires-Dist: beartype (>=0.18.5,<0.19.0)
|
13
|
+
Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
|
14
|
+
Requires-Dist: lovely-jax (>=0.1.3,<0.2.0)
|
15
|
+
Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
|
16
|
+
Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
|
17
|
+
Requires-Dist: pysnooper (>=1.2.0,<2.0.0)
|
18
|
+
Requires-Dist: typing-extensions
|
19
|
+
Description-Content-Type: text/markdown
|
20
|
+
|
21
|
+
|
@@ -0,0 +1,6 @@
|
|
1
|
+
nshutils/__init__.py,sha256=uNKU8zJk7Un4JC-fNHNicz_iUNGIequDgCXxIDqeKr4,71
|
2
|
+
nshutils/snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
|
3
|
+
nshutils/typecheck.py,sha256=wrjL-H2f3J8V1lojXIbcwQBh3039bz3HBVgG9DINYK4,4819
|
4
|
+
nshutils-0.1.0.dist-info/METADATA,sha256=pfHZUxtiSb_TEXTkUUCpl9Q-bqSDQH_2tjI9dVO_jmk,687
|
5
|
+
nshutils-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
6
|
+
nshutils-0.1.0.dist-info/RECORD,,
|