nshutils 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.1
2
+ Name: nshutils
3
+ Version: 0.1.0
4
+ Summary:
5
+ Author: Nima Shoghi
6
+ Author-email: nimashoghi@gmail.com
7
+ Requires-Python: >=3.10,<4.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Requires-Dist: beartype (>=0.18.5,<0.19.0)
13
+ Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
14
+ Requires-Dist: lovely-jax (>=0.1.3,<0.2.0)
15
+ Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
16
+ Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
17
+ Requires-Dist: pysnooper (>=1.2.0,<2.0.0)
18
+ Requires-Dist: typing-extensions
19
+ Description-Content-Type: text/markdown
20
+
21
+
File without changes
@@ -0,0 +1,21 @@
1
+ [tool.poetry]
2
+ name = "nshutils"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.10"
10
+ pysnooper = "^1.2.0"
11
+ jaxtyping = "^0.2.33"
12
+ typing-extensions = "*"
13
+ lovely-numpy = {version = "^0.2.13", optional = true}
14
+ lovely-tensors = {version = "^0.1.16", optional = true}
15
+ beartype = "^0.18.5"
16
+ lovely-jax = {version = "^0.1.3", optional = true}
17
+
18
+
19
+ [build-system]
20
+ requires = ["poetry-core"]
21
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,2 @@
1
+ from . import typecheck as typecheck
2
+ from .snoop import snoop as snoop
@@ -0,0 +1,216 @@
1
+ import contextlib
2
+ from typing import Any, Protocol, cast
3
+
4
+ from typing_extensions import TypeVar
5
+
6
# Generic type variable; `infer_variance=True` (a typing_extensions feature)
# lets type checkers infer its variance. Not referenced elsewhere in this module.
T = TypeVar("T", infer_variance=True)
7
+
8
+
9
class SnoopConstructor(Protocol):
    """Structural type of the module-level ``snoop`` object.

    Calling the object builds a context manager configured by whatever
    arguments are passed; ``disable()`` hands back a context manager that is
    used in place of a tracer.
    """

    def disable(self) -> contextlib.AbstractContextManager:
        """Return a context manager to use when snooping is switched off."""

    def __call__(self, *args, **kwargs) -> contextlib.AbstractContextManager:
        """Build a snooping context manager from the given arguments."""
13
+
14
+
15
import collections.abc
import warnings
from contextlib import ContextDecorator, nullcontext

try:
    # Optional heavy dependencies: if any of them is missing, `snoop` falls
    # back to a warning no-op context manager (see the `except` branch).
    import lovely_numpy as lo
    import lovely_tensors as lt
    import numpy
    import pysnooper
    import pysnooper.tracer
    import pysnooper.utils
    import torch

    # The torch dtypes from this list that exist in the installed PyTorch;
    # older versions do not have the complex dtypes.
    FLOATING_POINTS = {
        getattr(torch, dtype_name)
        for dtype_name in (
            "float",
            "double",
            "half",
            "complex128",
            "complex32",
            "complex64",
        )
        if hasattr(torch, dtype_name)
    }

    def default_format(x):
        """Format a tensor with lovely-tensors, falling back to its shape.

        Formatting is best effort: any ``Exception`` raised while formatting
        yields ``str(x.shape)`` instead (interrupts are not swallowed).
        """
        try:
            return str(lt.lovely(x))
        except Exception:
            return str(x.shape)

    def default_numpy_format(x):
        """Format a numpy array with lovely-numpy."""
        return str(lo.lovely(x))

    class TorchSnooper(pysnooper.tracer.Tracer):
        """pysnooper tracer that renders tensors and arrays compactly.

        Every value pysnooper wants to display is routed through
        `compute_repr`, which formats torch tensors with ``tensor_format``,
        numpy arrays with ``numpy_format``, ``torch.return_types`` results
        field by field, and recurses into lists, tuples and dicts. A
        user-supplied ``custom_repr`` is still honoured for all other values.
        """

        # Field names of the `torch.return_types` structseqs rendered with
        # named fields, keyed by the structseq's type name.
        _RETURN_TYPE_FIELDS: dict[str, tuple[str, ...]] = {
            **dict.fromkeys(
                ("max", "min", "median", "mode", "sort", "topk", "kthvalue"),
                ("values", "indices"),
            ),
            "svd": ("U", "S", "V"),
            "slogdet": ("sign", "logabsdet"),
            "qr": ("Q", "R"),
            "solve": ("solution", "LU"),
            "geqrf": ("a", "tau"),
            "symeig": ("eigenvalues", "eigenvectors"),
            "eig": ("eigenvalues", "eigenvectors"),
            "triangular_solve": ("solution", "cloned_coefficient"),
            "gels": ("solution", "QR"),
        }

        def __init__(
            self,
            *args,
            tensor_format=default_format,
            numpy_format=default_numpy_format,
            **kwargs,
        ):
            """
            Args:
                *args, **kwargs: Forwarded to `pysnooper.tracer.Tracer`.
                tensor_format: Callable turning a torch tensor into a string.
                numpy_format: Callable turning a numpy array into a string.
            """
            custom_repr = kwargs.get("custom_repr", ())
            # pysnooper accepts either a sequence of (condition, action)
            # pairs or a single bare pair; normalise the same way it does so
            # `get_repr_function` can always iterate over pairs.
            if len(custom_repr) == 2 and not all(
                isinstance(item, collections.abc.Iterable) for item in custom_repr
            ):
                custom_repr = (custom_repr,)
            self.orig_custom_repr = custom_repr
            # Route every value through `compute_repr`, which consults the
            # user's custom_repr itself.
            kwargs["custom_repr"] = ((lambda _: True, self.compute_repr),)
            super().__init__(*args, **kwargs)
            self.tensor_format = tensor_format
            self.numpy_format = numpy_format

        @staticmethod
        def is_return_types(x):
            """Return True if ``x`` is an instance of a ``torch.return_types`` type."""
            return type(x).__module__ == "torch.return_types"

        def return_types_repr(self, x):
            """Render a ``torch.return_types`` value as ``name(field=..., ...)``.

            Unknown return types trigger a warning and are rendered
            positionally; a string is always returned because pysnooper
            post-processes the result as a ``str``.
            """
            type_name = type(x).__name__
            fields = self._RETURN_TYPE_FIELDS.get(type_name)
            if fields is not None:
                rendered = ", ".join(
                    f"{field}={self.tensor_format(getattr(x, field))}"
                    for field in fields
                )
                return f"{type_name}({rendered})"
            warnings.warn("Unknown return_types encountered, open a bug report!")
            if isinstance(x, tuple):
                return type_name + self.compute_repr(tuple(x))
            return repr(x)

        def compute_repr(self, x):
            """Return the string pysnooper should display for ``x``."""
            if torch.is_tensor(x):
                return self.tensor_format(x)
            if isinstance(x, numpy.ndarray):
                return self.numpy_format(x)
            if self.is_return_types(x):
                return self.return_types_repr(x)
            # Only consult the user's custom_repr once the tensor-specific
            # formatting (which takes precedence) has been ruled out.
            orig_repr_func = pysnooper.utils.get_repr_function(
                x, self.orig_custom_repr
            )
            if orig_repr_func is not repr:
                return orig_repr_func(x)
            if isinstance(x, (list, tuple)):
                content = ", ".join(self.compute_repr(item) for item in x)
                if isinstance(x, tuple):
                    # Keep Python's trailing comma for 1-tuples.
                    return "(" + content + ("," if len(x) == 1 else "") + ")"
                return "[" + content + "]"
            if isinstance(x, dict):
                content = ", ".join(
                    f"{self.compute_repr(k)}: {self.compute_repr(v)}"
                    for k, v in x.items()
                )
                return "{" + content + "}"
            return repr(x)

    class _Snoop:
        # `snoop(...)` constructs a TorchSnooper; `snoop.disable()` a no-op.
        disable = nullcontext
        __call__ = TorchSnooper

    snoop: "SnoopConstructor" = cast(Any, _Snoop())

except ImportError as _import_error:
    from typing_extensions import override

    # The `as` target is unbound when this clause ends; keep the message.
    _import_error_message = str(_import_error)
    _has_warned = False

    class _snoop_cls(nullcontext, ContextDecorator):
        """No-op stand-in for `TorchSnooper` used when a dependency is missing.

        Accepts (and ignores) any arguments, works both as a context manager
        and as a decorator, and warns once on first use.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()

        @classmethod
        def disable(cls):
            return nullcontext()

        @override
        def __enter__(self):
            global _has_warned
            if not _has_warned:
                warnings.warn(
                    "snoop is disabled because an optional dependency could not "
                    f"be imported ({_import_error_message}); install pysnooper, "
                    "torch, numpy, lovely-numpy and lovely-tensors to enable snoop",
                    stacklevel=2,
                )
                _has_warned = True

            return super().__enter__()

    snoop: "SnoopConstructor" = cast(Any, _snoop_cls)
@@ -0,0 +1,155 @@
1
+ import os
2
+ from collections.abc import Sequence
3
+ from logging import getLogger
4
+ from typing import Any
5
+
6
+ from jaxtyping import BFloat16 as BFloat16
7
+ from jaxtyping import Bool as Bool
8
+ from jaxtyping import Complex as Complex
9
+ from jaxtyping import Complex64 as Complex64
10
+ from jaxtyping import Complex128 as Complex128
11
+ from jaxtyping import Float as Float
12
+ from jaxtyping import Float16 as Float16
13
+ from jaxtyping import Float32 as Float32
14
+ from jaxtyping import Float64 as Float64
15
+ from jaxtyping import Inexact as Inexact
16
+ from jaxtyping import Int as Int
17
+ from jaxtyping import Int4 as Int4
18
+ from jaxtyping import Int8 as Int8
19
+ from jaxtyping import Int16 as Int16
20
+ from jaxtyping import Int32 as Int32
21
+ from jaxtyping import Int64 as Int64
22
+ from jaxtyping import Integer as Integer
23
+ from jaxtyping import Key as Key
24
+ from jaxtyping import Num as Num
25
+ from jaxtyping import Real as Real
26
+ from jaxtyping import Shaped as Shaped
27
+ from jaxtyping import UInt as UInt
28
+ from jaxtyping import UInt4 as UInt4
29
+ from jaxtyping import UInt8 as UInt8
30
+ from jaxtyping import UInt16 as UInt16
31
+ from jaxtyping import UInt32 as UInt32
32
+ from jaxtyping import UInt64 as UInt64
33
+ from jaxtyping._storage import get_shape_memo, shape_str
34
+ from typing_extensions import TypeVar
35
+
36
# Optional array libraries: each is bound to None when unavailable so the
# error-message helpers can check for them cheaply.
try:
    import torch
except ImportError:
    torch = None

try:
    # Was `import np`, which always failed (there is no module named `np`),
    # leaving `np` permanently None.
    import numpy as np
except ImportError:
    np = None

try:
    import jax
except ImportError:
    jax = None

log = getLogger(__name__)

# Name of the environment variable that disables `typecheck_modules` and
# `tassert` when set (e.g. to `1`).
DISABLE_ENV_KEY = "NSHUTILS_DISABLE_TYPECHECKING"
54
+
55
+
56
def typecheck_modules(modules: Sequence[str]):
    """
    Typecheck the given modules using `jaxtyping`.

    Installs jaxtyping's import hook for ``modules`` with
    ``"beartype.beartype"`` as the runtime typechecker, then logs the module
    list at CRITICAL level.

    If ``DISABLE_ENV_KEY`` is not None and the environment variable it names
    is truthy, nothing is installed and a CRITICAL message is logged instead.
    Truthy means any non-zero integer, or ``true``/``yes``/``on``
    (case-insensitive). Other non-integer values, including the empty
    string, count as "not disabled" instead of raising ValueError.

    Args:
        modules: Modules to typecheck.
    """
    # `DISABLE_ENV_KEY` may be set to None to make typechecking unconditional.
    if DISABLE_ENV_KEY is not None:
        raw_value = os.environ.get(DISABLE_ENV_KEY, "0").strip()
        try:
            disabled = bool(int(raw_value))
        except ValueError:
            disabled = raw_value.lower() in {"true", "yes", "on"}
        if disabled:
            log.critical(
                f"Type checking is disabled due to the environment variable {DISABLE_ENV_KEY}."
            )
            return

    # Imported lazily so the hook machinery is only loaded when needed.
    from jaxtyping import install_import_hook

    install_import_hook(modules, "beartype.beartype")

    log.critical(f"Type checking the following modules: {modules}")
75
+
76
+ log.critical(f"Type checking the following modules: {modules}")
77
+
78
+
79
def typecheck_this_module(additional_modules: Sequence[str] = ()):
    """
    Typecheck the caller's package and any additional modules using `jaxtyping`.

    The name handed to `typecheck_modules` comes from beartype's private
    helper `get_frame_package_name` applied to the caller's frame, the same
    helper behind `beartype_this_package`. Presumably this is the caller's
    enclosing package (its ``__package__``), not the calling module itself;
    confirm against the installed beartype version.

    NOTE(review): import hooks generally only affect modules imported after
    they are installed, so this is likely most useful from a package's
    ``__init__.py`` before its submodules are imported. Verify for your use.

    Args:
        additional_modules: Additional modules to typecheck.

    Raises:
        RuntimeError: If the calling frame or its package name cannot be
            determined (e.g. when called from a top-level script). These were
            previously ``assert`` checks, which ``python -O`` strips.
    """
    # Private beartype helpers; see the docstring.
    from beartype._util.func.utilfuncframe import get_frame, get_frame_package_name

    if get_frame is None:
        raise RuntimeError(
            "typecheck_this_module() needs stack-frame introspection, which is "
            "unavailable on this Python implementation; use typecheck_modules() instead."
        )
    # Frame 0 is this function; frame 1 is its caller.
    frame = get_frame(1)
    if frame is None:
        raise RuntimeError("typecheck_this_module() could not find its calling frame.")
    calling_package_name = get_frame_package_name(frame)
    if not calling_package_name:
        raise RuntimeError(
            "typecheck_this_module() must be called from a module inside a package; "
            "use typecheck_modules() with explicit module names instead."
        )

    typecheck_modules((calling_package_name, *additional_modules))
100
+
101
+
102
def _make_error_str(input: Any, t: Any) -> str:
    """
    Build the assertion message for a failed `tassert` check.

    The message is newline-joined and contains, in order:
    a fixed header; ``t.__instancecheck_str__(input)`` when ``t`` provides it;
    a lovely-* summary of ``input`` (or ``repr(input.shape)`` if that fails)
    for torch tensors, jax arrays and numpy arrays; and jaxtyping's current
    shape memo.

    Args:
        input: The value that failed the check.
        t: The type it was checked against.
    """
    import importlib

    error_components: list[str] = ["Type checking error:"]
    if hasattr(t, "__instancecheck_str__"):
        error_components.append(t.__instancecheck_str__(input))

    # Pick the lovely-* package matching the array library, if any.
    lovely_module_name: str | None = None
    if torch is not None and torch.is_tensor(input):
        lovely_module_name = "lovely_tensors"
    elif jax is not None and isinstance(input, jax.Array):
        lovely_module_name = "lovely_jax"
    elif np is not None and isinstance(input, np.ndarray):
        lovely_module_name = "lovely_numpy"

    if lovely_module_name is not None:
        try:
            lovely = importlib.import_module(lovely_module_name).lovely
            error_components.append(repr(lovely(input)))
        except Exception:
            # Best effort: the lovely package may be missing or fail to
            # format; fall back to the shape. Interrupts are not swallowed.
            error_components.append(repr(input.shape))

    error_components.append(shape_str(get_shape_memo()))

    return "\n".join(error_components)
131
+
132
+
133
# Type variable for the value(s) passed to `tassert`; `infer_variance=True`
# (a typing_extensions feature) lets type checkers infer its variance.
T = TypeVar("T", infer_variance=True)
134
+
135
+
136
def tassert(t: Any, input: T | tuple[T, ...]):
    """
    Typecheck the input against the given type.

    Checks use plain ``assert`` statements, so they are removed under
    ``python -O``. On failure the AssertionError message comes from
    `_make_error_str`.

    A tuple ``input`` is always treated as several values, each checked
    against ``t`` separately (an empty tuple checks nothing). A tuple can
    therefore not be checked as a single value.

    If ``DISABLE_ENV_KEY`` is not None and the environment variable it names
    is truthy, the check is skipped. Truthy means any non-zero integer, or
    ``true``/``yes``/``on`` (case-insensitive). Other non-integer values
    count as "not disabled" instead of raising ValueError.

    Args:
        t: Type to check against.
        input: Input to check.
    """
    # pytest omits frames that set `__tracebackhide__` from its tracebacks.
    __tracebackhide__ = True

    # Ignore typechecking if the environment variable is set.
    if DISABLE_ENV_KEY is not None:
        raw_value = os.environ.get(DISABLE_ENV_KEY, "0").strip()
        try:
            disabled = bool(int(raw_value))
        except ValueError:
            disabled = raw_value.lower() in {"true", "yes", "on"}
        if disabled:
            return

    values = input if isinstance(input, tuple) else (input,)
    for value in values:
        assert isinstance(value, t), _make_error_str(value, t)