klongpy 0.6.8__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klongpy/__init__.py +19 -1
- klongpy/adverbs.py +5 -5
- klongpy/autograd.py +308 -0
- klongpy/backend.py +167 -99
- klongpy/backends/__init__.py +94 -0
- klongpy/backends/base.py +320 -0
- klongpy/backends/numpy_backend.py +122 -0
- klongpy/backends/torch_backend.py +995 -0
- klongpy-0.6.8.data/scripts/kgpy → klongpy/cli.py +65 -88
- klongpy/core.py +228 -106
- klongpy/db/sys_fn_db.py +4 -3
- klongpy/dyads.py +173 -32
- klongpy/interpreter.py +31 -3
- klongpy/lib/help.kg +2 -2
- klongpy/monads.py +49 -12
- klongpy/repl.py +91 -0
- klongpy/sys_fn.py +129 -18
- klongpy/sys_fn_autograd.py +290 -0
- klongpy/sys_fn_ipc.py +18 -7
- klongpy/sys_fn_timer.py +13 -3
- klongpy/web/sys_fn_web.py +28 -6
- klongpy-0.7.0.dist-info/METADATA +493 -0
- klongpy-0.7.0.dist-info/RECORD +48 -0
- {klongpy-0.6.8.dist-info → klongpy-0.7.0.dist-info}/WHEEL +1 -1
- klongpy-0.7.0.dist-info/entry_points.txt +2 -0
- {klongpy-0.6.8.dist-info → klongpy-0.7.0.dist-info}/top_level.txt +0 -1
- klongpy-0.6.8.dist-info/METADATA +0 -412
- klongpy-0.6.8.dist-info/RECORD +0 -72
- tests/__init__.py +0 -6
- tests/gen_join_over.py +0 -119
- tests/gen_py_suite.py +0 -77
- tests/gen_test_fn.py +0 -259
- tests/perf_async.py +0 -25
- tests/perf_avg.py +0 -18
- tests/perf_duckdb.py +0 -32
- tests/perf_gen.py +0 -38
- tests/perf_ipc_overhead.py +0 -34
- tests/perf_join.py +0 -53
- tests/perf_load.py +0 -17
- tests/perf_prog.py +0 -18
- tests/perf_serdes.py +0 -52
- tests/perf_sys_fn_db.py +0 -263
- tests/perf_vector.py +0 -40
- tests/test_accel.py +0 -227
- tests/test_df_cache.py +0 -85
- tests/test_examples.py +0 -64
- tests/test_extra_suite.py +0 -382
- tests/test_file_cache.py +0 -185
- tests/test_interop.py +0 -181
- tests/test_kgtests.py +0 -65
- tests/test_known_bugs.py +0 -206
- tests/test_prog.py +0 -107
- tests/test_suite.py +0 -1479
- tests/test_suite_file.py +0 -153
- tests/test_sys_fn.py +0 -420
- tests/test_sys_fn_db.py +0 -88
- tests/test_sys_fn_ipc.py +0 -587
- tests/test_sys_fn_timer.py +0 -133
- tests/test_util.py +0 -233
- tests/utils.py +0 -126
- {klongpy-0.6.8.dist-info → klongpy-0.7.0.dist-info/licenses}/LICENSE +0 -0
klongpy/backends/torch_backend.py
@@ -0,0 +1,995 @@
+"""
+PyTorch backend provider for KlongPy.
+
+This backend uses PyTorch tensors for array operations, enabling GPU acceleration.
+It does not support object dtype or string operations.
+"""
+import math
+import numpy
+import torch
+
+from .base import BackendProvider, UnsupportedDtypeError, is_jagged_array
+
+# numpy 2.x moved VisibleDeprecationWarning to numpy.exceptions
+from numpy.exceptions import VisibleDeprecationWarning as NumpyVisibleDeprecationWarning
+
+
+class TorchUnsupportedDtypeError(UnsupportedDtypeError):
+    """Raised when an operation requires object dtype which is not supported by PyTorch."""
+    pass
+
+
+class TorchRandomModule:
+    """NumPy-compatible random module using PyTorch tensors."""
+    def __init__(self, backend):
+        self._backend = backend
+
+    def random(self, size=None):
+        """Return random floats in the half-open interval [0.0, 1.0)."""
+        if size is None:
+            return torch.rand(1, device=self._backend.device).item()
+        if isinstance(size, int):
+            size = (size,)
+        return torch.rand(*size, device=self._backend.device)
+
+    def rand(self, *shape):
+        if len(shape) == 0:
+            return torch.rand(1, device=self._backend.device).item()
+        return torch.rand(*shape, device=self._backend.device)
+
+    def randn(self, *shape):
+        if len(shape) == 0:
+            return torch.randn(1, device=self._backend.device).item()
+        return torch.randn(*shape, device=self._backend.device)
+
+    def randint(self, low, high=None, size=None):
+        if high is None:
+            high = low
+            low = 0
+        if size is None:
+            return torch.randint(low, high, (1,), device=self._backend.device).item()
+        if isinstance(size, int):
+            size = (size,)
+        return torch.randint(low, high, size, device=self._backend.device)
+
+    def choice(self, a, size=None, replace=True):
+        if isinstance(a, int):
+            a = torch.arange(a, device=self._backend.device)
+        elif not isinstance(a, torch.Tensor):
+            a = torch.tensor(a, device=self._backend.device)
+        n = len(a)
+        if size is None:
+            idx = torch.randint(0, n, (1,), device=self._backend.device).item()
+            return a[idx]
+        if isinstance(size, int):
+            size = (size,)
+        total = 1
+        for s in size:
+            total *= s
+        if replace:
+            indices = torch.randint(0, n, (total,), device=self._backend.device)
+        else:
+            indices = torch.randperm(n, device=self._backend.device)[:total]
+        return a[indices].reshape(size)
+
+    def seed(self, seed):
+        torch.manual_seed(seed)
+
+
+class TorchDtype:
+    """Wrapper for torch dtype providing numpy-compatible 'kind' attribute."""
+
+    def __init__(self, torch_dtype):
+        self._dtype = torch_dtype
+        kind_map = {
+            torch.float16: 'f',
+            torch.float32: 'f',
+            torch.float64: 'f',
+            torch.bfloat16: 'f',
+            torch.int8: 'i',
+            torch.int16: 'i',
+            torch.int32: 'i',
+            torch.int64: 'i',
+            torch.uint8: 'u',
+            torch.bool: 'b',
+            torch.complex64: 'c',
+            torch.complex128: 'c',
+        }
+        self.kind = kind_map.get(torch_dtype, 'f')  # default to float
+
+    def __eq__(self, other):
+        if isinstance(other, TorchDtype):
+            return self._dtype == other._dtype
+        if isinstance(other, str):
+            return False  # torch dtype != string like 'O'
+        return self._dtype == other
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __repr__(self):
+        return repr(self._dtype)
+
+
+class TorchBackend:
+    """NumPy-compatible interface using PyTorch tensors for GPU acceleration."""
+
+    def __init__(self, device=None):
+        self._numpy = numpy
+        self._torch = torch
+        self._random = None
+        self._add = None
+        self._subtract = None
+        self._multiply = None
+        self._divide = None
+
+        # Device priority: explicit > CUDA > MPS (Apple Silicon) > CPU
+        if device is not None:
+            self.device = torch.device(device)
+        elif torch.cuda.is_available():
+            self.device = torch.device('cuda')
+        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+            self.device = torch.device('mps')
+        else:
+            self.device = torch.device('cpu')
+
+    @property
+    def random(self):
+        if self._random is None:
+            self._random = TorchRandomModule(self)
+        return self._random
+
+    def __getattr__(self, name):
+        # First check if torch has this attribute
+        if hasattr(self._torch, name):
+            attr = getattr(self._torch, name)
+            if callable(attr):
+                return self._wrap_torch_func(attr, name)
+            return attr
+        # Fall back to numpy for things torch doesn't have
+        if hasattr(self._numpy, name):
+            return getattr(self._numpy, name)
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    def _wrap_torch_func(self, func, name):
+        # Functions that require tensor inputs (not Python scalars)
+        tensor_required_funcs = {
+            'abs', 'trunc', 'floor', 'ceil', 'round', 'sign',
+            'sin', 'cos', 'tan', 'exp', 'log', 'sqrt',
+            'isinf', 'isnan', 'isfinite',
+            'minimum', 'maximum', 'fmod',
+            'less', 'greater', 'less_equal', 'greater_equal',
+        }
+
+        def wrapper(*args, **kwargs):
+            converted_args = []
+            needs_tensor = name in tensor_required_funcs
+            for arg in args:
+                if isinstance(arg, numpy.ndarray):
+                    # Handle float64 arrays on MPS by converting to float32
+                    if arg.dtype == numpy.float64 and self.device.type == 'mps':
+                        arg = arg.astype(numpy.float32)
+                    converted_args.append(torch.from_numpy(arg).to(self.device))
+                elif isinstance(arg, list):
+                    try:
+                        converted_args.append(torch.tensor(arg, device=self.device))
+                    except (ValueError, TypeError):
+                        converted_args.append(arg)
+                elif needs_tensor and isinstance(arg, (int, float)):
+                    # Convert Python scalars to tensors for functions that require it
+                    dtype = torch.float32 if isinstance(arg, float) else torch.int64
+                    converted_args.append(torch.tensor(arg, dtype=dtype, device=self.device))
+                else:
+                    converted_args.append(arg)
+            return func(*converted_args, **kwargs)
+        return wrapper
+
+    def asarray(self, a, dtype=None):
+        """Convert input to a torch tensor.
+
+        Note: MPS (Apple Silicon) doesn't support float64, so we convert to float32.
+        Object dtypes are not supported - use numpy backend for heterogeneous data.
+        """
+        if dtype is not None and (dtype == object or (hasattr(dtype, 'kind') and dtype.kind == 'O')):
+            raise TorchUnsupportedDtypeError(
+                "PyTorch backend does not support object dtype."
+            )
+        if isinstance(a, torch.Tensor):
+            if a.device != self.device:
+                return a.to(self.device)
+            return a
+        if isinstance(a, numpy.ndarray):
+            if a.dtype == object:
+                raise TorchUnsupportedDtypeError(
+                    "PyTorch backend does not support object dtype arrays."
+                )
+            if a.dtype == numpy.float64 and self.device.type == 'mps':
+                a = a.astype(numpy.float32)
+            return torch.from_numpy(a).to(self.device)
+        # Check if input is a list/tuple of tensors - use stack to preserve gradients
+        if isinstance(a, (list, tuple)) and len(a) > 0 and all(isinstance(x, torch.Tensor) for x in a):
+            # torch.stack preserves requires_grad, torch.tensor does not
+            result = torch.stack(a)
+            if result.device != self.device:
+                result = result.to(self.device)
+            # Handle float64 on MPS
+            if result.dtype == torch.float64 and self.device.type == 'mps':
+                result = result.to(torch.float32)
+            return result
+        # Check if input contains any arrays/tensors mixed with non-arrays - these need object dtype
+        if isinstance(a, (list, tuple)) and len(a) > 0:
+            has_array = any(isinstance(x, (torch.Tensor, numpy.ndarray, list)) for x in a)
+            has_scalar = any(isinstance(x, (int, float)) and not isinstance(x, bool) for x in a)
+            if has_array and has_scalar:
+                # Mixed array/scalar list - can't represent in torch without losing structure
+                raise TorchUnsupportedDtypeError(
+                    "PyTorch backend cannot convert mixed array/scalar lists without losing structure."
+                )
+        try:
+            t = torch.tensor(a, device=self.device)
+            # Handle float64 on MPS
+            if t.dtype == torch.float64 and self.device.type == 'mps':
+                t = t.to(torch.float32)
+            return t
+        except (ValueError, TypeError, RuntimeError) as e:
+            raise TorchUnsupportedDtypeError(
+                f"PyTorch backend cannot convert this data: {e}"
+            )
+
+    def array(self, a, dtype=None):
+        """Create a torch tensor."""
+        return self.asarray(a, dtype=dtype)
+
+    def isarray(self, x):
+        """Check if x is an array (numpy or torch tensor)."""
+        return isinstance(x, (numpy.ndarray, torch.Tensor))
+
+    def zeros(self, shape, dtype=None):
+        return torch.zeros(shape, device=self.device)
+
+    def ones(self, shape, dtype=None):
+        return torch.ones(shape, device=self.device)
+
+    def arange(self, *args, **kwargs):
+        return torch.arange(*args, device=self.device, **kwargs)
+
+    def concatenate(self, arrays, axis=0):
+        tensors = [self.asarray(a) for a in arrays]
+        return torch.cat(tensors, dim=axis)
+
+    def hstack(self, arrays):
+        tensors = [self.asarray(a) for a in arrays]
+        return torch.hstack(tensors)
+
+    def vstack(self, arrays):
+        tensors = [self.asarray(a) for a in arrays]
+        return torch.vstack(tensors)
+
+    def stack(self, arrays, axis=0):
+        tensors = [self.asarray(a) for a in arrays]
+        return torch.stack(tensors, dim=axis)
+
+    @property
+    def ndarray(self):
+        """Return the tensor class for isinstance checks."""
+        return torch.Tensor
+
+    @property
+    def integer(self):
+        return numpy.integer
+
+    @property
+    def floating(self):
+        return numpy.floating
+
+    def copy(self, a):
+        if isinstance(a, torch.Tensor):
+            return a.clone()
+        return self.asarray(a).clone()
+
+    def isclose(self, a, b, rtol=1e-05, atol=1e-08):
+        # For scalars, use numpy's isclose to avoid tensor conversion issues
+        if not hasattr(a, '__len__') and not hasattr(b, '__len__'):
+            return self._numpy.isclose(float(a), float(b), rtol=rtol, atol=atol)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        b_t = self.asarray(b) if not isinstance(b, torch.Tensor) else b
+        # torch.isclose requires same dtype, convert to float if needed
+        if a_t.dtype != b_t.dtype:
+            # Use float32 for MPS compatibility (MPS doesn't support float64)
+            a_t = a_t.to(torch.float32)
+            b_t = b_t.to(torch.float32)
+        return torch.isclose(a_t, b_t, rtol=rtol, atol=atol)
+
+    def array_equal(self, a, b):
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        b_t = self.asarray(b) if not isinstance(b, torch.Tensor) else b
+        return torch.equal(a_t, b_t)
+
+    def take(self, a, indices, axis=None):
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        indices_t = self.asarray(indices) if not isinstance(indices, torch.Tensor) else indices
+        if axis is None:
+            return a_t.flatten()[indices_t.long()]
+        return torch.index_select(a_t, axis, indices_t.long())
+
+    def transpose(self, a, axes=None):
+        """Transpose a tensor."""
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        if axes is None:
+            return a_t.T if a_t.ndim >= 2 else a_t
+        return a_t.permute(*axes)
+
+    def sum(self, a, axis=None, dtype=None, out=None, keepdims=False):
+        """Sum array elements."""
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        if axis is None:
+            return a_t.sum()
+        return a_t.sum(dim=axis, keepdim=keepdims)
+
+    def abs(self, a):
+        """Absolute value."""
+        if isinstance(a, (int, float)):
+            return abs(a)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        return torch.abs(a_t)
+
+    def minimum(self, a, b):
+        """Element-wise minimum."""
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        b_t = self.asarray(b) if not isinstance(b, torch.Tensor) else b
+        return torch.minimum(a_t, b_t)
+
+    def maximum(self, a, b):
+        """Element-wise maximum."""
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        b_t = self.asarray(b) if not isinstance(b, torch.Tensor) else b
+        return torch.maximum(a_t, b_t)
+
+    def floor(self, a):
+        """Floor of input."""
+        if isinstance(a, (int, float)):
+            return math.floor(a)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        return torch.floor(a_t)
+
+    def ceil(self, a):
+        """Ceiling of input."""
+        if isinstance(a, (int, float)):
+            return math.ceil(a)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        return torch.ceil(a_t)
+
+    def trunc(self, a):
+        """Truncate to integer."""
+        if isinstance(a, (int, float)):
+            return math.trunc(a)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        return torch.trunc(a_t)
+
+    def isinf(self, a):
+        """Check for infinity."""
+        if isinstance(a, (int, float)):
+            return math.isinf(a)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        return torch.isinf(a_t)
+
+    def isnan(self, a):
+        """Check for NaN."""
+        if isinstance(a, (int, float)):
+            return math.isnan(a)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        return torch.isnan(a_t)
+
+    def sign(self, a):
+        """Sign of elements."""
+        if isinstance(a, (int, float)):
+            return (a > 0) - (a < 0)
+        a_t = self.asarray(a) if not isinstance(a, torch.Tensor) else a
+        return torch.sign(a_t)
+
+    class TorchUfunc:
+        """Wraps torch ops to support numpy ufunc interface (reduce, accumulate).
+
+        Falls back to numpy for object arrays since torch doesn't support them.
+        """
+        def __init__(self, backend, op, reduce_op, accumulate_op=None, numpy_ufunc=None):
+            self._backend = backend
+            self._op = op
+            self._reduce_op = reduce_op
+            self._accumulate_op = accumulate_op
+            self._torch = torch
+            self._numpy_ufunc = numpy_ufunc
+
+        def _is_object_array(self, x):
+            return isinstance(x, numpy.ndarray) and x.dtype == object
+
+        def _to_numpy(self, x):
+            if isinstance(x, self._torch.Tensor):
+                return x.detach().cpu().numpy()
+            return x
+
+        def __call__(self, a, b):
+            a_is_tensor = isinstance(a, self._torch.Tensor)
+            b_is_tensor = isinstance(b, self._torch.Tensor)
+            # Fast path for tensor operations
+            if a_is_tensor and b_is_tensor and a.device == b.device:
+                return self._op(a, b)
+            if (a_is_tensor and isinstance(b, (int, float))) or \
+                    (b_is_tensor and isinstance(a, (int, float))):
+                return self._op(a, b)
+            # Numpy fallback for object arrays
+            if self._numpy_ufunc and (self._is_object_array(a) or self._is_object_array(b)):
+                return self._numpy_ufunc(self._to_numpy(a), self._to_numpy(b))
+            try:
+                return self._op(self._backend.asarray(a), self._backend.asarray(b))
+            except TorchUnsupportedDtypeError:
+                if self._numpy_ufunc:
+                    return self._numpy_ufunc(self._to_numpy(a), self._to_numpy(b))
+                raise
+
+        def reduce(self, a, axis=None):
+            if self._numpy_ufunc and self._is_object_array(a):
+                return self._numpy_ufunc.reduce(self._to_numpy(a), axis=axis)
+            try:
+                arr = self._backend.asarray(a)
+                if axis is None:
+                    return self._reduce_op(arr)
+                return self._reduce_op(arr, dim=axis)
+            except TorchUnsupportedDtypeError:
+                if self._numpy_ufunc:
+                    return self._numpy_ufunc.reduce(self._to_numpy(a), axis=axis)
+                raise
+
+        def accumulate(self, a, axis=0):
+            if self._numpy_ufunc and self._is_object_array(a):
+                return self._numpy_ufunc.accumulate(self._to_numpy(a), axis=axis)
+            try:
+                arr = self._backend.asarray(a)
+                if self._accumulate_op:
+                    return self._accumulate_op(arr, dim=axis)
+                result = [arr[0]]
+                for i in range(1, len(arr)):
+                    result.append(self._op(result[-1], arr[i]))
+                return self._torch.stack(result)
+            except TorchUnsupportedDtypeError:
+                if self._numpy_ufunc:
+                    return self._numpy_ufunc.accumulate(self._to_numpy(a), axis=axis)
+                raise
+
+    @property
+    def add(self):
+        if self._add is None:
+            self._add = self.TorchUfunc(self, torch.add, torch.sum, torch.cumsum, numpy.add)
+        return self._add
+
+    @property
+    def subtract(self):
+        def cumulative_subtract(a, dim=0):
+            result = [a[0]]
+            for i in range(1, a.shape[dim]):
+                result.append(result[-1] - a[i])
+            return torch.stack(result)
+        return self.TorchUfunc(
+            self, torch.subtract,
+            lambda a, dim=None: a[0] - torch.sum(a[1:]) if dim is None else None,
+            cumulative_subtract,
+            numpy.subtract
+        )
+
+    @property
+    def multiply(self):
+        if self._multiply is None:
+            self._multiply = self.TorchUfunc(self, torch.multiply, torch.prod, torch.cumprod, numpy.multiply)
+        return self._multiply
+
+    @property
+    def divide(self):
+        def reduce_divide(a, dim=None):
+            if dim is None:
+                result = a.flatten()[0]
+                for x in a.flatten()[1:]:
+                    result = result / x
+                return result
+            return None
+        return self.TorchUfunc(self, torch.divide, reduce_divide, None, numpy.divide)
+
+    @property
+    def inf(self):
+        return float('inf')
+
+    def seterr(self, **kwargs):
+        pass
+
+    @property
+    def VisibleDeprecationWarning(self):
+        return NumpyVisibleDeprecationWarning
+
+
+class TorchBackendProvider(BackendProvider):
+    """PyTorch-based backend provider."""
+
+    def __init__(self, device=None):
+        self._torch_backend = TorchBackend(device)
+        self._device = device
+
+    @property
+    def name(self) -> str:
+        return 'torch'
+
+    @property
+    def np(self):
+        return self._torch_backend
+
+    @property
+    def device(self):
+        return self._torch_backend.device
+
+    def supports_object_dtype(self) -> bool:
+        return False
+
+    def supports_strings(self) -> bool:
+        return False
+
+    def supports_float64(self) -> bool:
+        # MPS device doesn't support float64
+        return 'mps' not in str(self.device).lower()
+
+    def is_array(self, x) -> bool:
+        return isinstance(x, (numpy.ndarray, torch.Tensor))
+
+    def is_backend_array(self, x) -> bool:
+        """Check if x is specifically a torch tensor (not numpy)."""
+        return isinstance(x, torch.Tensor)
+
+    def get_dtype_kind(self, arr) -> str:
+        if hasattr(arr, 'dtype'):
+            dtype = arr.dtype
+            # numpy arrays have dtype.kind
+            if hasattr(dtype, 'kind'):
+                return dtype.kind
+            # torch tensors need manual mapping
+            kind_map = {
+                torch.float16: 'f',
+                torch.float32: 'f',
+                torch.float64: 'f',
+                torch.bfloat16: 'f',
+                torch.int8: 'i',
+                torch.int16: 'i',
+                torch.int32: 'i',
+                torch.int64: 'i',
+                torch.uint8: 'u',
+                torch.bool: 'b',
+                torch.complex64: 'c',
+                torch.complex128: 'c',
+            }
+            return kind_map.get(dtype, 'f')
+        return None
+
+    def to_numpy(self, x):
+        """Convert torch tensor to numpy array."""
+        if isinstance(x, torch.Tensor):
+            return x.detach().cpu().numpy()
+        return x
+
+    def is_scalar_integer(self, x) -> bool:
+        if isinstance(x, torch.Tensor) and x.ndim == 0:
+            return x.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
+        return False
+
+    def is_scalar_float(self, x) -> bool:
+        if isinstance(x, torch.Tensor) and x.ndim == 0:
+            return x.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16)
+        return False
+
+    def argsort(self, a, descending=False):
+        """Return indices that would sort the array."""
+        if not isinstance(a, torch.Tensor):
+            a = self._torch_backend.asarray(a)
+        return torch.argsort(a, descending=descending)
+
+    def array_size(self, a):
+        """Get the total number of elements in an array/tensor."""
+        if isinstance(a, torch.Tensor):
+            return a.numel()
+        if hasattr(a, 'size'):
+            size = a.size
+            return size if isinstance(size, int) else size()
+        return len(a) if hasattr(a, '__len__') else 1
+
+    def safe_equal(self, x, y):
+        """Compare two values, handling torch tensors correctly."""
+        if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
+            # Convert scalars to tensors for comparison
+            if isinstance(x, torch.Tensor) and x.dim() == 0:
+                x = x.item()
+            if isinstance(y, torch.Tensor) and y.dim() == 0:
+                y = y.item()
+            return x == y
+        # Default numpy comparison
+        return numpy.asarray(x, dtype=object) == numpy.asarray(y, dtype=object)
+
+    def detach_if_needed(self, x):
+        """Detach tensor if it requires grad, to allow type conversions."""
+        if isinstance(x, torch.Tensor) and x.requires_grad:
+            return x.detach()
+        return x
+
+    def to_int_array(self, a):
+        """Convert array/tensor to integer type."""
+        if isinstance(a, torch.Tensor):
+            return a.to(int)
+        return numpy.asarray(a, dtype=int) if isinstance(a, numpy.ndarray) else int(a)
+
+    def power(self, a, b):
+        """Compute a^b, handling gradient tracking for torch tensors."""
+        # Use torch.pow for tensors to maintain gradients when possible
+        if isinstance(a, torch.Tensor):
+            return a.pow(b)
+        # For numpy arrays or scalars
+        a_val = float(a) if isinstance(a, (int, numpy.integer)) else a
+        return numpy.power(a_val, b)
+
+    def has_gradient(self, x) -> bool:
+        """Check if x is tracking gradients."""
+        return isinstance(x, torch.Tensor) and x.requires_grad
+
+    def supports_autograd(self) -> bool:
+        """Torch backend supports automatic differentiation."""
+        return True
+
+    def create_grad_tensor(self, x):
+        """Create a tensor that tracks gradients."""
+        if isinstance(x, torch.Tensor):
+            return x.clone().detach().float().requires_grad_(True)
+        elif isinstance(x, numpy.ndarray):
+            return torch.from_numpy(x.astype(numpy.float64)).float().requires_grad_(True)
+        else:
+            return torch.tensor(x, dtype=torch.float32, requires_grad=True)
+
+    def compute_autograd(self, func, x):
+        """Compute gradient using PyTorch automatic differentiation."""
+        from ..autograd import AutogradChainBrokenError, NonScalarLossError
+
+        x_tensor = self.create_grad_tensor(x)
+
+        # Compute the function value
+        y = func(x_tensor)
+
+        # Check result type - must be a tensor for autograd to work
+        if not isinstance(y, torch.Tensor):
+            if isinstance(y, numpy.ndarray):
+                raise AutogradChainBrokenError(
+                    "function output",
+                    "torch.Tensor",
+                    "numpy.ndarray",
+                    "Avoid numpy operations. Use torch-compatible functions."
+                )
+            raise AutogradChainBrokenError(
+                "function output",
+                "torch.Tensor",
+                type(y).__name__,
+                "For autograd, use torch-compatible operations."
+            )
+
+        # Ensure y is a scalar
+        if y.numel() != 1:
+            raise NonScalarLossError(tuple(y.shape))
+
+        # Check requires_grad
+        if not y.requires_grad:
+            raise AutogradChainBrokenError(
+                "gradient computation",
+                "requires_grad=True",
+                "requires_grad=False",
+                "Output lost gradient tracking. Avoid .item(), .numpy(), or Python float()."
+            )
+
+        # Compute gradient
+        y.backward()
+
+        return x_tensor.grad
+
+    def compute_multi_autograd(self, func, params):
+        """
+        Compute gradients for multiple parameters using torch.autograd.grad().
+
+        Args:
+            func: Callable that takes a list of tensors and returns a scalar loss
+            params: List of parameter values to compute gradients for
+
+        Returns:
+            List of gradients, one per parameter
+        """
+        from ..autograd import AutogradChainBrokenError, NonScalarLossError
+
+        # Create grad tensors for all parameters
+        grad_tensors = [self.create_grad_tensor(p) for p in params]
+
+        # Compute the function value (loss)
+        y = func(grad_tensors)
+
+        # Validate output is a tensor
+        if not isinstance(y, torch.Tensor):
+            if isinstance(y, numpy.ndarray):
+                raise AutogradChainBrokenError(
+                    "loss computation",
+                    "torch.Tensor",
+                    "numpy.ndarray",
+                    "Avoid numpy operations in the loss function."
+                )
+            raise AutogradChainBrokenError(
+                "loss computation",
+                "torch.Tensor",
+                type(y).__name__,
+                "For autograd, use torch-compatible operations."
+            )
+
+        # Ensure y is a scalar
+        if y.numel() != 1:
+            raise NonScalarLossError(tuple(y.shape))
+
+        # Compute all gradients in one backward pass using torch.autograd.grad
+        grads = torch.autograd.grad(y, grad_tensors, create_graph=False)
+
+        return list(grads)
+
+    def compute_jacobian(self, func, x):
+        """
+        Compute Jacobian matrix using torch.autograd.functional.jacobian().
+
+        Args:
+            func: Callable that takes x and returns a vector
+            x: Input point
+
+        Returns:
+            Jacobian matrix J where J[i,j] = df_i/dx_j
+        """
+        import torch.autograd.functional as F
+
+        x_tensor = self.create_grad_tensor(x)
+
+        # torch.autograd.functional.jacobian expects func(inputs) -> outputs
+        jacobian = F.jacobian(func, x_tensor)
+
+        return jacobian
+
+    def str_to_char_array(self, s):
+        """Not supported in torch backend."""
+        raise TorchUnsupportedDtypeError(
+            "PyTorch backend does not support string-to-character array conversion."
+        )
+
+    def compile_function(self, func, example_input, output_path=None, mode="default",
+                         backend="inductor", fullgraph=False, dynamic=None):
+        """
+        Compile a function using torch.compile with configurable options.
+
+        Args:
+            func: Callable to compile
+            example_input: Example input for tracing the function
+            output_path: Optional path to save the exported graph (.pt2 file)
+            mode: Compilation mode - affects speed/quality tradeoff
+                - "default": Balanced compilation (default)
+                - "reduce-overhead": Faster compile, less optimization
+                - "max-autotune": Slower compile, maximum runtime performance
+            backend: Compilation backend
+                - "inductor": Default backend with C++/Triton codegen
+                - "eager": No compilation (for debugging)
+                - "aot_eager": Ahead-of-time eager (debugging with autograd)
+                - "cudagraphs": CUDA graphs for GPU (reduces launch overhead)
+            fullgraph: If True, requires entire function to compile as one graph
+            dynamic: If True, enables dynamic shapes; if False, assumes static shapes
+
+        Returns:
+            If output_path is None: compiled function
+            If output_path is provided: dict with compiled function and export info
+
+        Compilation Modes Comparison:
+            | Mode            | Compile Time | Runtime Speed | Best For            |
+            |-----------------|--------------|---------------|---------------------|
+            | default         | Medium       | Good          | General use         |
+            | reduce-overhead | Fast         | Moderate      | Quick iteration     |
+            | max-autotune    | Slow         | Best          | Production/training |
+            | (eager backend) | None         | Baseline      | Debugging           |
+        """
+        # Convert example input to tensor if needed
+        if not isinstance(example_input, torch.Tensor):
+            example_input = self.create_grad_tensor(example_input)
+
+        # Build compile options
+        compile_kwargs = {
+            'mode': mode,
+            'backend': backend,
+            'fullgraph': fullgraph,
+        }
+        if dynamic is not None:
+            compile_kwargs['dynamic'] = dynamic
+
+        # Compile the function with specified options
+        compiled_fn = torch.compile(func, **compile_kwargs)
+
+        # Warm up the compiled function (triggers actual compilation)
+        _ = compiled_fn(example_input)
+
+        if output_path is None:
+            return compiled_fn
+
+        # Export the function graph for inspection
+        try:
+            # Use torch.export for graph capture
+            exported = torch.export.export(func, (example_input,))
+
+            # Save the exported program
+            torch.export.save(exported, output_path)
+
+            # Get graph representation for inspection
+            graph_str = str(exported.graph_module.graph)
+
+            return {
+                'compiled_fn': compiled_fn,
+                'export_path': output_path,
+                'graph': graph_str,
+                'graph_module': exported.graph_module,
+                'mode': mode,
+                'backend': backend,
+            }
+        except Exception as e:
+            # If export fails, still return the compiled function
+            return {
+                'compiled_fn': compiled_fn,
+                'export_path': None,
+                'export_error': str(e),
+                'mode': mode,
+                'backend': backend,
+            }
+
+    def get_compile_modes(self):
+        """
+        Return information about available compilation modes.
+
+        Returns:
+            Dict with mode descriptions and recommendations
+        """
+        return {
+            'modes': {
+                'default': 'Balanced compilation - good for most cases',
+                'reduce-overhead': 'Faster compile time, less optimization - good for development',
+                'max-autotune': 'Maximum optimization - best for production/training loops',
+            },
+            'backends': {
+                'inductor': 'Default backend with C++/Triton code generation',
+                'eager': 'No compilation - runs original Python (for debugging)',
+                'aot_eager': 'Ahead-of-time eager - captures autograd graph (debugging)',
+                'cudagraphs': 'CUDA graphs - reduces kernel launch overhead (GPU only)',
+            },
+            'recommendations': {
+                'development': {'mode': 'reduce-overhead', 'backend': 'inductor'},
+                'production': {'mode': 'max-autotune', 'backend': 'inductor'},
+                'debugging': {'mode': 'default', 'backend': 'eager'},
+                'gpu_inference': {'mode': 'max-autotune', 'backend': 'cudagraphs'},
+            }
+        }
+
+    def gradcheck(self, func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3):
+        """
+        Check gradients computed by autograd against numeric gradients.
+
+        Uses torch.autograd.gradcheck to verify correctness.
+
+        Args:
+            func: Function to check (should return scalar or tensor)
+            inputs: Tuple of input tensors (must have requires_grad=True)
+            eps: Step size for numeric differentiation
+            atol: Absolute tolerance
+            rtol: Relative tolerance
+
+        Returns:
+            True if gradients match, raises GradcheckError otherwise
+        """
+        # Ensure inputs are tensors with gradients
+        tensor_inputs = []
+        for inp in inputs:
+            if isinstance(inp, torch.Tensor):
+                if not inp.requires_grad:
+                    inp = inp.clone().detach().float().requires_grad_(True)
+                tensor_inputs.append(inp)
+            else:
+                tensor_inputs.append(self.create_grad_tensor(inp))
+
+        return torch.autograd.gradcheck(
+            func,
+            tuple(tensor_inputs),
+            eps=eps,
+            atol=atol,
+            rtol=rtol,
+            raise_exception=True
+        )
+
+    def klong_gradcheck(self, klong, fn, inputs):
+        """
+        Check gradients for a Klong function.
+
+        Handles wrapping the Klong function, dtype selection based on device,
+        and tolerance adjustment for float32 (MPS).
+
+        Args:
+            klong: KlongInterpreter instance
+            fn: Klong function to check
+            inputs: Input value or list of inputs
+
+        Returns:
+            1 if gradients are correct, raises error otherwise
+        """
+        from ..autograd import _invoke_fn
+
+        # Determine dtype based on device support
+        use_float32 = self.device.type == 'mps'  # MPS doesn't support float64
+        dtype = torch.float32 if use_float32 else torch.float64
+
+        # Wrap the Klong function
+        def wrapped_fn(v):
+            result = _invoke_fn(klong, fn, [v])
+            # Ensure result is a scalar tensor for gradcheck
+            if isinstance(result, torch.Tensor) and result.numel() > 1:
+                result = result.sum()
+            return result
+
+        # Convert inputs to tensor on CPU for gradcheck (avoids MPS float64 issues)
+        if isinstance(inputs, (list, tuple)) and not isinstance(inputs[0], torch.Tensor):
+            tensor_inputs = torch.tensor(inputs, dtype=dtype, device='cpu', requires_grad=True)
+        elif not isinstance(inputs, torch.Tensor):
+            tensor_inputs = torch.tensor([inputs], dtype=dtype, device='cpu', requires_grad=True)
+        else:
+            tensor_inputs = inputs.to(dtype=dtype, device='cpu').requires_grad_(True)
+
+        # Run gradcheck with adjusted tolerances for float32
+        if use_float32:
+            result = self.gradcheck(wrapped_fn, (tensor_inputs,), eps=1e-4, atol=1e-3, rtol=1e-2)
+        else:
+            result = self.gradcheck(wrapped_fn, (tensor_inputs,))
+
+        return 1 if result else 0
+
+    def kg_asarray(self, a):
+        """
+        Converts input data into a PyTorch tensor for KlongPy.
+
+        For data that can't be converted to tensors (strings, heterogeneous
+        types, jagged arrays), falls back to numpy object arrays to maintain
+        compatibility with Klong's list semantics.
+        """
+        if isinstance(a, str):
+            # Strings become numpy character arrays like in numpy backend
+            return numpy.array(list(a))
+        try:
+            # Check for jagged arrays early - torch converts them incorrectly
+            if is_jagged_array(a):
+                raise TorchUnsupportedDtypeError("Jagged arrays not supported")
+            arr = self._torch_backend.asarray(a)
+            if hasattr(arr, 'dtype'):
+                # For torch tensors, dtype doesn't have .kind attribute
+                if hasattr(arr.dtype, 'kind'):
+                    if arr.dtype.kind not in ['O', 'i', 'f']:
+                        raise ValueError
+            return arr
+        except (NumpyVisibleDeprecationWarning, ValueError, TypeError, RuntimeError, TorchUnsupportedDtypeError):
+            # Fall back to numpy object array for heterogeneous/unsupported data
+            # Use numpy for inner conversions to avoid MPS tensor issues
+            def _numpy_convert(x):
+                if isinstance(x, list):
+                    try:
+                        return numpy.asarray(x)
+                    except (ValueError, TypeError):
+                        return numpy.asarray([_numpy_convert(i) for i in x], dtype=object)
+                return x
+            try:
+                arr = numpy.asarray(a, dtype=object)
+                # Recursively convert inner lists to numpy arrays
+                arr = numpy.asarray(
+                    [_numpy_convert(x) if isinstance(x, list) else x for x in arr],
+                    dtype=object
+                )
+                return arr
+            except (ValueError, TypeError):
+                # Last resort: keep as list
+                return a