torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/ops/jaten.py
CHANGED
@@ -15,75 +15,22 @@ from torchax.ops import ops_registry
 from torchax.ops import op_base, mappings
 from torchax import interop
 from torchax.ops import jax_reimplement
-
+from torchax.view import View
 # Keys are OpOverload, value is a callable that takes
 # Tensor
 all_ops = {}

-# list all Aten ops from pytorch that does mutation
-# and need to be implemented in jax
-
-mutation_ops_to_functional = {
-    torch.ops.aten.add_: torch.ops.aten.add,
-    torch.ops.aten.sub_: torch.ops.aten.sub,
-    torch.ops.aten.mul_: torch.ops.aten.mul,
-    torch.ops.aten.div_: torch.ops.aten.div,
-    torch.ops.aten.pow_: torch.ops.aten.pow,
-    torch.ops.aten.lt_: torch.ops.aten.lt,
-    torch.ops.aten.le_: torch.ops.aten.le,
-    torch.ops.aten.gt_: torch.ops.aten.gt,
-    torch.ops.aten.ge_: torch.ops.aten.ge,
-    torch.ops.aten.eq_: torch.ops.aten.eq,
-    torch.ops.aten.ne_: torch.ops.aten.ne,
-    torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
-    torch.ops.aten.geometric_: torch.ops.aten.geometric,
-    torch.ops.aten.normal_: torch.ops.aten.normal,
-    torch.ops.aten.random_: torch.ops.aten.uniform,
-    torch.ops.aten.uniform_: torch.ops.aten.uniform,
-    torch.ops.aten.relu_: torch.ops.aten.relu,
-    # squeeze_ is expected to change tensor's shape. So replace with new value
-    torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
-    torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
-    torch.ops.aten.clamp_: torch.ops.aten.clamp,
-    torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
-    torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
-    torch.ops.aten.tanh_: torch.ops.aten.tanh,
-    torch.ops.aten.ceil_: torch.ops.aten.ceil,
-    torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
-    torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
-    torch.ops.aten.transpose_: torch.ops.aten.transpose,
-    torch.ops.aten.log_normal_: torch.ops.aten.log_normal,
-    torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
-    torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
-    torch.ops.aten.scatter_: torch.ops.aten.scatter,
-}
-
-# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
-_jax_version = tuple(int(v) for v in jax.version._version.split("."))
-
-
-def make_mutation(op):
-  if type(mutation_ops_to_functional[op]) is tuple:
-    return op_base.InplaceOp(mutation_ops_to_functional[op][0],
-                             replace=mutation_ops_to_functional[op][1],
-                             position_to_mutate=0)
-  return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0)
-
-
-for op in mutation_ops_to_functional.keys():
-  ops_registry.register_torch_dispatch_op(
-      op, make_mutation(op), is_jax_function=False
-  )
-

 def op(*aten, **kwargs):
+
   def inner(func):
     for a in aten:
       ops_registry.register_torch_dispatch_op(a, func, **kwargs)
       continue

       if isinstance(a, torch._ops.OpOverloadPacket):
-        opname = a.default.name() if 'default' in a.overloads(
+        opname = a.default.name() if 'default' in a.overloads(
+        ) else a._qualified_op_name
       elif isinstance(a, torch._ops.OpOverload):
         opname = a.name()
       else:
@@ -91,17 +38,18 @@ def op(*aten, **kwargs):

       torchfunc = functools.partial(interop.call_jax, func)
       # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
-      torch.library.impl(opname, 'privateuseone')(
+      torch.library.impl(opname, 'privateuseone')(
+          torchfunc if a != torch.ops.aten._to_copy else func)
     return func

   return inner


 @op(
+    torch.ops.aten.view_copy,
+    torch.ops.aten.view,
+    torch.ops.aten._unsafe_view,
+    torch.ops.aten.reshape,
 )
 def _aten_unsafe_view(x, shape):
   return jnp.reshape(x, shape)
@@ -121,8 +69,19 @@ def _aten_add(x, y, *, alpha=1):
   return res


-@op(torch.ops.aten.copy_,
+@op(torch.ops.aten.copy_,
+    is_jax_function=False,
+    is_view_op=True,
+    needs_env=True)
+def _aten_copy(x, y, memory_format=None, env=None):
+
+  if y.device.type == 'cpu':
+    y = env.to_xla(y)
+
+  if isinstance(x, View):
+    x.update(y)
+    return x
+
   if x.ndim == 1 and y.ndim == 0:
     # case of torch.empty((1,)).copy_(tensor(N))
     # we need to return 0D tensor([N]) and not scalar tensor(N)
@@ -147,14 +106,20 @@ def _aten_trunc(x):

 @op(torch.ops.aten.index_copy)
 def _aten_index_copy(x, dim, indexes, source):
+  if x.ndim == 0:
+    return source
+  if x.ndim == 1:
+    source = jnp.squeeze(source)
   # return jax.lax.scatter(x, index, dim)
+  if dim < 0:
+    dim = dim + x.ndim
   dims = []
   for i in range(len(x.shape)):
     if i == dim:
       dims.append(indexes)
     else:
       dims.append(slice(None, None, None))
-  return x.at[
+  return x.at[tuple(dims)].set(source)


 # aten.cauchy_
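Note: the hunks above keep the `@op(...)` decorator as the single registration path in this file (the separate in-place/mutation table was removed). Below is a minimal, self-contained sketch of that decorator pattern, assuming only that `torch` and `jax` are installed; the dict `REGISTRY` is a hypothetical stand-in for `torchax.ops.ops_registry`, not the real API.

```python
import torch
import jax.numpy as jnp

# Hypothetical stand-in for ops_registry: maps an aten op to its JAX implementation.
REGISTRY = {}

def op(*aten_ops, **kwargs):
  """Register `func` as the JAX implementation of each given aten op."""
  def inner(func):
    for a in aten_ops:
      REGISTRY[a] = func  # torchax calls ops_registry.register_torch_dispatch_op here
    return func
  return inner

@op(torch.ops.aten.view, torch.ops.aten.reshape)
def _view(x, shape):
  return jnp.reshape(x, shape)

print(REGISTRY[torch.ops.aten.view](jnp.arange(6), (2, 3)).shape)  # (2, 3)
```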
@@ -199,7 +164,9 @@ def _aten_complex(real, imag):
   Returns:
     A complex array with the specified real and imaginary parts.
   """
-  return jnp.array(
+  return jnp.array(
+      real, dtype=jnp.float32) + 1j * jnp.array(
+          imag, dtype=jnp.float32)


 # aten.exponential_
@@ -223,13 +190,14 @@ def _aten_exponential_(x, lambd=1.0):
 # aten.linalg_householder_product
 @op(torch.ops.aten.linalg_householder_product)
 def _aten_linalg_householder_product(input, tau):
-  return jax.lax.linalg.householder_product(a
+  return jax.lax.linalg.householder_product(a=input, taus=tau)


 @op(torch.ops.aten.select)
 def _aten_select(x, dim, indexes):
   return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False)

+
 @op(torch.ops.aten.index_select)
 @op(torch.ops.aten.select_copy)
 def _aten_index_select(x, dim, index):
@@ -249,11 +217,10 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
     raise NotImplementedError(
         "check_errors=True is not supported in this JAX implementation. "
         "Check for positive definiteness using jnp.linalg.eigvalsh before "
-        "calling this function."
-    )
+        "calling this function.")

   L = jax.scipy.linalg.cholesky(input, lower=not upper)
-  if len(L.shape) >2:
+  if len(L.shape) > 2:
     info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32)
   else:
     info = jnp.array(0, dtype=jnp.int32)
@@ -263,7 +230,7 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
 @op(torch.ops.aten.cholesky_solve)
 def _aten_cholesky_solve(input, input2, upper=False):
   # Ensure input2 is lower triangular for cho_solve
-  L = input2 if not upper else input2.T
+  L = input2 if not upper else input2.T
   # Use cho_solve to solve the linear system
   solution = jax.scipy.linalg.cho_solve((L, True), input)
   return solution
@@ -275,7 +242,7 @@ def _aten_special_zeta(x, q):
   res = jax.scipy.special.zeta(x, q)
   if isinstance(x, int) or isinstance(q, int):
     res = res.astype(new_dtype)
-  return res
+  return res  # jax.scipy.special.zeta(x, q)


 # aten.igammac
@@ -286,7 +253,7 @@ def _aten_igammac(input, other):
   if isinstance(other, jnp.ndarray):
     other = jnp.where(other < 0, jnp.nan, other)
   else:
-    if (input==0 and other==0) or (input < 0) or (other < 0):
+    if (input == 0 and other == 0) or (input < 0) or (other < 0):
       other = jnp.nan
   return jnp.array(jax.scipy.special.gammaincc(input, other))

@@ -294,7 +261,7 @@ def _aten_igammac(input, other):
 @op(torch.ops.aten.mean)
 def _aten_mean(x, dim=None, keepdim=False):
   if x.shape == () and dim is not None:
-    dim = None
+    dim = None  # disable dim for jax array without dim
   return jnp.mean(x, dim, keepdims=keepdim)

@@ -310,13 +277,14 @@ def _torch_binary_scalar_type(scalar, tensor):


 @op(torch.ops.aten.searchsorted.Tensor)
-def _aten_searchsorted(sorted_sequence, values):
+def _aten_searchsorted(sorted_sequence, values):
   new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
   res = jnp.searchsorted(sorted_sequence, values)
-  if sorted_sequence.dtype == np.dtype(
+  if sorted_sequence.dtype == np.dtype(
+      np.int32) or sorted_sequence.dtype == np.dtype(np.int32):
     # res = res.astype(new_dtype)
     res = res.astype(np.dtype(np.int64))
-  return res
+  return res  # jnp.searchsorted(sorted_sequence, values)


 @op(torch.ops.aten.sub.Tensor)
@@ -328,7 +296,7 @@ def _aten_sub(x, y, alpha=1):
   if isinstance(y, float):
     dtype = _torch_binary_scalar_type(y, x)
     y = jnp.array(y, dtype=dtype)
-  return x - y*alpha
+  return x - y * alpha


 @op(torch.ops.aten.numpy_T)
@@ -345,7 +313,6 @@ def _aten_numpy_T(input):
   return jnp.transpose(input)


-
 @op(torch.ops.aten.mm)
 def _aten_mm(x, y):
   res = x @ y
@@ -379,13 +346,15 @@ def _aten_t(x):
 @op(torch.ops.aten.transpose)
 @op(torch.ops.aten.transpose_copy)
 def _aten_transpose(x, dim0, dim1):
+  if x.ndim == 0:
+    return x
+  dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim
+  dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim
+  return jnp.swapaxes(x, dim0, dim1)


 @op(torch.ops.aten.triu)
-def _aten_triu(m, k):
+def _aten_triu(m, k=0):
   return jnp.triu(m, k)


@@ -406,6 +375,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
   return self[tuple(dims)]


+@op(torch.ops.aten.positive)
 @op(torch.ops.aten.detach)
 def _aten_detach(self):
   return self
@@ -439,7 +409,8 @@ def _aten_resize_as_(x, y):

 @op(torch.ops.aten.repeat_interleave.Tensor)
 def repeat_interleave(repeats, dim=0):
-  return jnp.repeat(
+  return jnp.repeat(np.arange(repeats.shape[dim]), repeats)
+

 @op(torch.ops.aten.repeat_interleave.self_int)
 @op(torch.ops.aten.repeat_interleave.self_Tensor)
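Note: the rewritten `repeat_interleave(repeats, dim=0)` overload builds the result directly with `jnp.repeat`, repeating each index by its count. A small worked example with made-up values:

```python
import numpy as np
import jax.numpy as jnp

repeats = jnp.array([2, 1, 3])
# Same result as torch.repeat_interleave(torch.tensor([2, 1, 3])):
# element i appears repeats[i] times.
out = jnp.repeat(np.arange(repeats.shape[0]), repeats)
print(out)  # [0 0 1 2 2 2]
```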
@@ -451,12 +422,6 @@ def repeat_interleave(self, repeats, dim=0):
   return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length)


-# aten.upsample_bilinear2d
-@op(torch.ops.aten.upsample_bilinear2d)
-def _aten_upsample_bilinear2d(x, output_size, align_corners=False, scale_h=None, scale_w=None):
-  return _aten_upsample_bilinear2d_aa(x, output_size=output_size, align_corners=align_corners, scale_factors=None, scales_h=scale_h, scales_w=scale_w)
-
-
 @op(torch.ops.aten.view_as_real)
 def _aten_view_as_real(x):
   real = jnp.real(x)
@@ -473,7 +438,7 @@ def _aten_stack(tensors, dim=0):
 @op(torch.ops.aten._softmax)
 @op(torch.ops.aten.softmax)
 @op(torch.ops.aten.softmax.int)
-def _aten_softmax(x, dim, halftofloat
+def _aten_softmax(x, dim, halftofloat=False):
   if x.shape == ():
     return jax.nn.softmax(x.reshape([1]), axis=0).reshape([])
   return jax.nn.softmax(x, dim)
@@ -482,10 +447,12 @@ def _aten_softmax(x, dim, halftofloat = False):
 def _is_int(x):
   if isinstance(x, int):
     return True
-  if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or
+  if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or
+                                   x.dtype.name.startswith('uint')):
     return True
   return False

+
 def highest_precision_int_dtype(tensor1, tensor2):
   if isinstance(tensor1, int):
     return tensor2.dtype
@@ -493,12 +460,20 @@ def highest_precision_int_dtype(tensor1, tensor2):
     return tensor1.dtype

   dtype_hierarchy = {
-      'uint8': 8,
+      'uint8': 8,
+      'int8': 8,
+      'uint16': 16,
+      'int16': 16,
+      'uint32': 32,
+      'int32': 32,
+      'uint64': 64,
+      'int64': 64,
   }
-  return max(
+  return max(
+      tensor1.dtype,
+      tensor2.dtype,
+      key=lambda dtype: dtype_hierarchy[str(dtype)])
+

 @op(torch.ops.aten.pow)
 def _aten_pow(x, y):
@@ -553,11 +528,13 @@ def _aten_div(x, y, rounding_mode=""):
 def _aten_true_divide(x, y):
   return x / y

+
 @op(torch.ops.aten.dist)
 def _aten_dist(input, other, p=2):
   diff = jnp.abs(jnp.subtract(input, other))
   return _aten_linalg_vector_norm(diff, ord=p)

+
 @op(torch.ops.aten.bmm)
 def _aten_bmm(x, y):
   res = x @ y
@@ -567,9 +544,14 @@ def _aten_bmm(x, y):

 @op(torch.ops.aten.embedding)
 # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
-def _aten_embedding(a,
+def _aten_embedding(a,
+                    w,
+                    padding_idx=-1,
+                    scale_grad_by_freq=False,
+                    sparse=False):
   return jnp.take(a, w, axis=0)

+
 @op(torch.ops.aten.embedding_renorm_)
 def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
   # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
@@ -587,27 +569,26 @@ def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):

     indices_to_update = unique_indices[indice_idx]

-    weight = weight.at[indices_to_update].set(
-    )
+    weight = weight.at[indices_to_update].set(weight[indices_to_update] *
+                                              scale[:, None])
   return weight

+
 #- func: _embedding_bag_forward_only(
 # Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
 # int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
 @op(torch.ops.aten._embedding_bag)
 @op(torch.ops.aten._embedding_bag_forward_only)
-def _aten__embedding_bag(
-  """Jax implementation of the PyTorch _embedding_bag function.
+def _aten__embedding_bag(weight,
+                         indices,
+                         offsets=None,
+                         scale_grad_by_freq=False,
+                         mode=0,
+                         sparse=False,
+                         per_sample_weights=None,
+                         include_last_offset=False,
+                         padding_idx=-1):
+  """Jax implementation of the PyTorch _embedding_bag function.

   Args:
     weight: The learnable weights of the module of shape (num_embeddings, embedding_dim).
@@ -623,48 +604,50 @@ def _aten__embedding_bag(
   Returns:
     A tuple of (output, offset2bag, bag_size, max_indices).
   """
+  embedded = _aten_embedding(weight, indices, padding_idx)
+
+  if offsets is None:
+    # offsets is None only when indices.ndim > 1
+    if mode == 0:  # sum
+      output = jnp.sum(embedded, axis=1)
+    elif mode == 1:  # mean
+      output = jnp.mean(embedded, axis=1)
+    elif mode == 2:  # max
+      output = jnp.max(embedded, axis=1)
+    return output, None, None, None
+
+  if isinstance(offsets, jax.Array):
+    offsets_np = np.array(offsets)
+  else:
+    offsets_np = offsets
+  offset2bag = np.zeros(indices.shape[0], dtype=np.int64)
+  bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64)
+  max_indices = jnp.full_like(indices, -1)

+  for bag in range(offsets_np.shape[0]):
+    start = int(offsets_np[bag])

+    end = int(indices.shape[0] if bag +
+              1 == offsets_np.shape[0] else offsets_np[bag + 1])
+    bag_size[bag] = end - start
+    offset2bag = offset2bag.at[start:end].set(bag)

+    if end - start > 0:
+      if mode == 0:
+        output_bag = jnp.sum(embedded[start:end], axis=0)
+      elif mode == 1:
+        output_bag = jnp.mean(embedded[start:end], axis=0)
+      elif mode == 2:
+        output_bag = jnp.max(embedded[start:end], axis=0)
+        max_indices = max_indices.at[start:end].set(
+            jnp.argmax(embedded[start:end], axis=0))

+  # The original code returned offset2bag, bag_size, and max_indices as numpy arrays.
+  # Converting them to JAX arrays for consistency.
+  offset2bag = jnp.array(offset2bag)
+  bag_size = jnp.array(bag_size)

+  return output_bag, offset2bag, bag_size, max_indices


 @op(torch.ops.aten.rsqrt)
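Note: the new `_aten__embedding_bag` first gathers the embedding rows and then reduces each bag with sum/mean/max depending on `mode`. A tiny worked example of the `mode == 0` (sum) arithmetic, using made-up values rather than the package's own helpers:

```python
import jax.numpy as jnp

weight = jnp.arange(12.0).reshape(6, 2)  # (num_embeddings, embedding_dim)
indices = jnp.array([0, 2, 4, 1])
offsets = jnp.array([0, 2])              # bag 0 = indices[0:2], bag 1 = indices[2:4]

embedded = weight[indices]               # gathered rows, as _aten_embedding would return
start, end = 2, 4                        # the second bag
print(jnp.sum(embedded[start:end], axis=0))  # rows 4 and 1 summed -> [10. 12.]
```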
@@ -676,6 +659,7 @@ def _aten_rsqrt(x):
 @op(torch.ops.aten.expand)
 @op(torch.ops.aten.expand_copy)
 def _aten_expand(x, dims):
+
   def fix_dims(d, xs):
     if d == -1:
       return xs
@@ -683,7 +667,9 @@ def _aten_expand(x, dims):

   shape = list(x.shape)
   if len(shape) < len(dims):
-    shape = [
+    shape = [
+        1,
+    ] * (len(dims) - len(shape)) + shape
   # make sure that dims and shape is the same by
   # left pad with 1s. Otherwise the zip below will
   # truncate
@@ -705,15 +691,15 @@ def _aten__to_copy(self, **kwargs):


 @op(torch.ops.aten.empty)
-@op_base.convert_dtype()
+@op_base.convert_dtype(use_default_dtype=False)
 def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs):
   return jnp.empty(size, dtype=dtype)


 @op(torch.ops.aten.empty_like)
-@op_base.convert_dtype()
+@op_base.convert_dtype(use_default_dtype=False)
 def _aten_empty_like(input, *, dtype=None, **kwargs):
-  return jnp.empty_like(input, dtype
+  return jnp.empty_like(input, dtype)


 @op(torch.ops.aten.ones)
@@ -750,7 +736,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
   return jnp.empty(sizes, dtype=dtype)


-@op(torch.ops.aten.index_put_)
 @op(torch.ops.aten.index_put)
 def _aten_index_put(self, indexes, values, accumulate=False):
   indexes = [slice(None, None, None) if i is None else i for i in indexes]
@@ -784,8 +769,8 @@ def split_with_sizes(x, sizes, dim=0):
     A list of sub-arrays.
   """
   if isinstance(sizes, int):
-    # split equal size
-    new_sizes = [sizes] * (x.shape[dim] // sizes)
+    # split equal size, round up
+    new_sizes = [sizes] * (-(-x.shape[dim] // sizes))
     sizes = new_sizes
   rank = x.ndim
   splits = np.cumsum(sizes)  # Cumulative sum for split points
@@ -796,14 +781,15 @@ def split_with_sizes(x, sizes, dim=0):
     return tuple(res)

   return [
+      x[make_range(rank, dim, start, end)]
+      for start, end in zip([0] + list(splits[:-1]), splits)
   ]


 @op(torch.ops.aten.permute)
 @op(torch.ops.aten.permute_copy)
 def permute(t, dims):
+  # TODO: return a View instead
   return jnp.transpose(t, dims)

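Note: `split_with_sizes` now rounds the chunk count up (`-(-n // k)` is ceil division), so a trailing partial chunk is kept instead of being dropped. An illustration with plain jnp/np only (the module's `make_range` helper is not used here):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(11)
sizes = 4
new_sizes = [sizes] * (-(-x.shape[0] // sizes))   # ceil(11 / 4) = 3 chunks
chunks = jnp.split(x, np.cumsum(new_sizes)[:-1])  # split points at 4 and 8
print([len(c) for c in chunks])                   # [4, 4, 3]
```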
@@ -819,6 +805,7 @@ def _aten_unsqueeze(self, dim):
 def _aten_ne(x, y):
   return jnp.not_equal(x, y)

+
 # Create indices along a specific axis
 #
 # For example
@@ -832,14 +819,12 @@ def _aten_ne(x, y):
 def _indices_along_axis(x, axis):
   return jnp.expand_dims(
       jnp.arange(x.shape[axis]),
-      axis
+      axis=[d for d in range(len(x.shape)) if d != axis])
+

 def _broadcast_indices(indices, shape):
-  return jnp.broadcast_to(
-      shape
-  )
+  return jnp.broadcast_to(indices, shape)
+

 @op(torch.ops.aten.cummax)
 def _aten_cummax(x, dim):
@@ -851,36 +836,45 @@ def _aten_cummax(x, dim):
   indice_along_axis = _indices_along_axis(x, axis)
   indices = _broadcast_indices(indice_along_axis, x.shape)

-
   def cummax_reduce_func(carry, elem):
-    v1, v2 = carry['val'], elem['val']
+    v1, v2 = carry['val'], elem['val']
     i1, i2 = carry['idx'], elem['idx']

     v = jnp.maximum(v1, v2)
     i = jnp.where(v1 > v2, i1, i2)
     return {'val': v, 'idx': i}
-
+
+  res = jax.lax.associative_scan(
+      cummax_reduce_func, {
+          'val': x,
+          'idx': indices
+      }, axis=axis)
   return res['val'], res['idx']

+
 @op(torch.ops.aten.cummin)
 def _aten_cummin(x, dim):
   if not x.shape:
     return x, jnp.zeros_like(x, dtype=jnp.int64)
-
+
   axis = dim

   indice_along_axis = _indices_along_axis(x, axis)
   indices = _broadcast_indices(indice_along_axis, x.shape)

   def cummin_reduce_func(carry, elem):
-    v1, v2 = carry['val'], elem['val']
+    v1, v2 = carry['val'], elem['val']
     i1, i2 = carry['idx'], elem['idx']

     v = jnp.minimum(v1, v2)
     i = jnp.where(v1 < v2, i1, i2)
     return {'val': v, 'idx': i}

-  res = jax.lax.associative_scan(
+  res = jax.lax.associative_scan(
+      cummin_reduce_func, {
+          'val': x,
+          'idx': indices
+      }, axis=axis)
   return res['val'], res['idx']


@@ -908,9 +902,11 @@ def _aten_cumprod(input, dim, dtype=None, out=None):


 @op(torch.ops.aten.native_layer_norm)
-def _aten_native_layer_norm(
+def _aten_native_layer_norm(input,
+                            normalized_shape,
+                            weight=None,
+                            bias=None,
+                            eps=1e-5):
   """Implements layer normalization in Jax as defined by `aten::native_layer_norm`.

   Args:
@@ -944,7 +940,7 @@ def _aten_native_layer_norm(
     norm_x += bias
   return norm_x, mean, rstd

-
+
 @op(torch.ops.aten.matmul)
 def _aten_matmul(x, y):
   return x @ y
@@ -960,6 +956,7 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
   self += alpha * jnp.matmul(mat1, mat2)
   return self

+
 @op(torch.ops.aten.sparse_sampled_addmm)
 def _aten_sparse_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
   alpha = jnp.array(alpha).astype(mat1.dtype)
@@ -974,9 +971,8 @@ def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
   alpha = jnp.array(alpha).astype(batch1.dtype)
   beta = jnp.array(beta).astype(batch1.dtype)
   mm = jnp.einsum("bxy, byz -> xz", batch1, batch2)
-  return jax.lax.cond(
-  )
+  return jax.lax.cond(beta == 0, lambda: alpha * mm,
+                      lambda: beta * input + alpha * mm)


 @op(torch.ops.aten.gelu)
@@ -987,73 +983,69 @@ def _aten_gelu(self, *, approximate="none"):

 @op(torch.ops.aten.squeeze)
 @op(torch.ops.aten.squeeze_copy)
-def _aten_squeeze_dim(self, dim):
-  Args:
-    self: The input tensor.
-    dim: The dimension to squeeze.
-
-  Returns:
-    The squeezed tensor with the specified dimension removed if it is 1,
-    otherwise the original tensor is returned.
-  """
-
-  # Validate input arguments
-  if not isinstance(self, jnp.ndarray):
-    raise TypeError(f"Expected a Jax tensor, got {type(self)}.")
-  if isinstance(dim, int):
-    dim = [dim]
-
-  # Check if the specified dimension has size 1
-  if (len(self.shape) == 0) or all([self.shape[d] != 1 for d in dim]):
+def _aten_squeeze_dim(self, dim=None):
+  if self.ndim == 0:
     return self
+  if dim is not None:
+    if isinstance(dim, int):
+      if self.shape[dim] != 1:
+        return self
+      if dim < 0:
+        dim += self.ndim
+    else:
+      # NOTE: torch leaves the dims that is not 1 unchanged,
+      # but jax raises error.
+      dim = [
+          i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1
+      ]

-  new_shape = list(self.shape)
-  def fix_dim(p):
-    if p < 0:
-      return p + len(self.shape)
-    return p
+  return jnp.squeeze(self, dim)

-  dim = [fix_dim(d) for d in dim]
-  new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
-  return self.reshape(new_shape)

 @op(torch.ops.aten.bucketize)
-def _aten_bucketize(input,
+def _aten_bucketize(input,
+                    boundaries,
+                    *,
+                    out_int32=False,
+                    right=False,
+                    out=None):
   return_type = jnp.int32 if out_int32 else jnp.int64
   return jnp.digitize(input, boundaries, right=not right).astype(return_type)


 @op(torch.ops.aten.conv2d)
 def _aten_conv2d(
+    input,
+    weight,
+    bias,
+    stride,
+    padding,
+    dilation,
+    groups,
 ):
   return _aten_convolution(
+      input,
+      weight,
+      bias,
+      stride,
+      padding,
+      dilation,
+      transposed=False,
+      output_padding=1,
+      groups=groups)
+

 @op(torch.ops.aten.convolution)
 def _aten_convolution(
+    input,
+    weight,
+    bias,
+    stride,
+    padding,
+    dilation,
+    transposed,
+    output_padding,
+    groups,
 ):
   num_shape_dim = weight.ndim - 1
   batch_dims = input.shape[:-num_shape_dim]
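Note: `torch.squeeze` silently skips requested dims whose size is not 1, while `jnp.squeeze` raises for them, which is why the rewritten `_aten_squeeze_dim` filters the dims before delegating. A short illustration of that filtering step:

```python
import jax.numpy as jnp

x = jnp.zeros((2, 1, 3, 1))
dim = [1, 2, -1]                  # dim 2 has size 3 and must be ignored, as torch does
dim = [i if i >= 0 else i + x.ndim for i in dim if x.shape[i] == 1]
print(jnp.squeeze(x, dim).shape)  # (2, 3)
```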
@@ -1068,7 +1060,7 @@ def _aten_convolution(
     # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
     pad_out = []
     for i in range(num_spatial_dims):
-      front = dilation[i] * (weight.shape[i+2] - 1) - padding[i]
+      front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i]
       back = front + output_padding[i]
       pad_out.append((front, back))
     return pad_out
@@ -1089,39 +1081,38 @@ def _aten_convolution(
       rhs_spec.append(i + 2)
       out_spec.append(i + 2)
     return jax.lax.ConvDimensionNumbers(
-    )
+        *map(tuple, (lhs_spec, rhs_spec, out_spec)))

   if transposed:
-    rhs = jnp.flip(weight, range(2, 1+num_shape_dim))
+    rhs = jnp.flip(weight, range(2, 1 + num_shape_dim))
     if groups != 1:
       # reshape filters for tranposed depthwise convolution
       assert rhs.shape[0] % groups == 0
-      rhs_shape = [rhs.shape[0]//groups, rhs.shape[1]*groups]
+      rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups]
       rhs_shape.extend(rhs.shape[2:])
       rhs = jnp.reshape(rhs, rhs_shape)
     res = jax.lax.conv_general_dilated(
+        input,
+        rhs,
+        (1,) * len(stride),
+        make_padding(padding, len(stride)),
+        lhs_dilation=stride,
+        rhs_dilation=dilation,
+        dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
+        feature_group_count=groups,
+        batch_group_count=1,
     )
   else:
     res = jax.lax.conv_general_dilated(
+        input,
+        weight,
+        stride,
+        make_padding(padding, len(stride)),
+        lhs_dilation=(1,) * len(stride),
+        rhs_dilation=dilation,
+        dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
+        feature_group_count=groups,
+        batch_group_count=1,
     )

   if bias is not None:
@@ -1137,10 +1128,9 @@ def _aten_convolution(


 # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps)
-@op(torch.ops.aten._native_batch_norm_legit)
-def _aten__native_batch_norm_legit(
-):
+@op(torch.ops.aten._native_batch_norm_legit.default)
+def _aten__native_batch_norm_legit(input, weight, bias, running_mean,
+                                   running_var, training, momentum, eps):
   """JAX implementation of batch normalization with optional parameters.
   Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.

@@ -1161,8 +1151,7 @@ def _aten__native_batch_norm_legit(
     DeviceArray: Reversed batch variance (C,) or empty if training is False
   """
   reduction_dims = [0] + list(range(2, input.ndim))
-  reshape_dims = [1, -1] + [1]*(input.ndim-2)
-
+  reshape_dims = [1, -1] + [1] * (input.ndim - 2)
   if training:
     # Calculate batch mean and variance
     mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
@@ -1175,7 +1164,9 @@ def _aten__native_batch_norm_legit(
     saved_rstd = jnp.squeeze(rstd, reduction_dims)
   else:
     rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
-    saved_mean = jnp.array(
+    saved_mean = jnp.array(
+        [], dtype=input.dtype
+    )  # No need to calculate batch statistics in inference mode
     saved_rstd = jnp.array([], dtype=input.dtype)

   # Normalize
@@ -1190,19 +1181,17 @@ def _aten__native_batch_norm_legit(
   if weight is not None:
     x_hat *= weight.reshape(reshape_dims)  # Reshape weight for broadcasting
   if bias is not None:
-    x_hat += bias.reshape(reshape_dims)
+    x_hat += bias.reshape(reshape_dims)  # Reshape bias for broadcasting

   return x_hat, saved_mean, saved_rstd


-
 @op(torch.ops.aten._native_batch_norm_legit_no_training)
-def _aten__native_batch_norm_legit_no_training(
-):
-  return _aten__native_batch_norm_legit(
-  )
+def _aten__native_batch_norm_legit_no_training(input, weight, bias,
+                                               running_mean, running_var,
+                                               momentum, eps):
+  return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
+                                        running_var, False, momentum, eps)


 @op(torch.ops.aten.relu)
@@ -1212,7 +1201,15 @@ def _aten_relu(self):

 @op(torch.ops.aten.cat)
 def _aten_cat(tensors, dims=0):
-
+  # handle empty tensors as a special case.
+  # torch.cat will ignore the empty tensor, while jnp.concatenate
+  # will error if the dims > 0.
+  filtered_tensors = [
+      t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0)
+  ]
+  if filtered_tensors:
+    return jnp.concatenate(filtered_tensors, dims)
+  return tensors[0]


 def _ceil_mode_padding(
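Note: the new `_aten_cat` special-cases 1-D empty tensors, which `torch.cat` ignores but `jnp.concatenate` rejects when the other inputs have more dimensions. A minimal illustration of the filtering:

```python
import jax.numpy as jnp

tensors = [jnp.ones((2, 3)), jnp.zeros((0,)), jnp.ones((1, 3))]
filtered = [t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0)]
print(jnp.concatenate(filtered, 0).shape)  # (3, 3)
```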
@@ -1220,6 +1217,7 @@ def _ceil_mode_padding(
     input_shape: list[int],
     kernel_size: list[int],
     stride: list[int],
+    dilation: list[int],
     ceil_mode: bool,
 ):
   """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode.
@@ -1232,20 +1230,13 @@ def _ceil_mode_padding(
     right_padding = left_padding

     input_size = input_shape[2 + i]
-    output_size_rem = (input_size + 2 * left_padding -
+    output_size_rem = (input_size + 2 * left_padding -
+                       (kernel_size[i] - 1) * dilation[i] - 1) % stride[i]
     if ceil_mode and output_size_rem != 0:
       extra_padding = stride[i] - output_size_rem
-      new_output_size = (
-          + right_padding
-          + extra_padding
-          - kernel_size[i]
-          + stride[i]
-          - 1
-      ) // stride[i] + 1
+      new_output_size = (input_size + left_padding + right_padding +
+                         extra_padding - (kernel_size[i] - 1) * dilation[i] -
+                         1 + stride[i] - 1) // stride[i] + 1
       # Ensure that the last pooling starts inside the image.
       size_to_compare = input_size + left_padding

@@ -1258,30 +1249,36 @@ def _ceil_mode_padding(

 @op(torch.ops.aten.max_pool2d_with_indices)
 @op(torch.ops.aten.max_pool3d_with_indices)
-def _aten_max_pool2d_with_indices(
+def _aten_max_pool2d_with_indices(inputs,
+                                  kernel_size,
+                                  strides=None,
+                                  padding=0,
+                                  dilation=1,
+                                  ceil_mode=False):
   num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
   kernel_size = tuple(kernel_size)
-
+  # Default stride is kernel_size
+  strides = tuple(strides) if strides else kernel_size
   if isinstance(padding, int):
     padding = [padding for _ in range(len(kernel_size))]
+  if isinstance(dilation, int):
+    dilation = tuple(dilation for _ in range(len(kernel_size)))
+  elif isinstance(dilation, list):
+    dilation = tuple(dilation)

   input_shape = inputs.shape
   if num_batch_dims == 0:
     input_shape = [1, *input_shape]
-  padding = _ceil_mode_padding(
-  )
+  padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
+                               dilation, ceil_mode)

-      strides
-  ), f"len({window_shape}) must equal len({strides})"
+  assert len(kernel_size) == len(
+      strides), f"len({kernel_size=}) must equal len({strides=})"
+  assert len(kernel_size) == len(
+      dilation), f"len({kernel_size=}) must equal len({dilation=})"
   strides = (1,) * (1 + num_batch_dims) + strides
-  dims = (1,) * (1 + num_batch_dims) +
+  dims = (1,) * (1 + num_batch_dims) + kernel_size
+  dilation = (1,) * (1 + num_batch_dims) + dilation

   is_single_input = False
   if num_batch_dims == 0:
@@ -1290,26 +1287,27 @@ def _aten_max_pool2d_with_indices(
     inputs = inputs[None]
     strides = (1,) + strides
     dims = (1,) + dims
+    dilation = (1,) + dilation
     is_single_input = True

   assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
   if not isinstance(padding, str):
     padding = tuple(map(tuple, padding))
-    assert len(padding) == len(
-    )
-    [len(x) == 2 for x in padding]
-    ), f"each entry in padding {padding} must be length 2"
+    assert len(padding) == len(kernel_size), (
+        f"padding {padding} must specify pads for same number of dims as "
+        f"kernel_size {kernel_size}")
+    assert all([len(x) == 2 for x in padding
+               ]), f"each entry in padding {padding} must be length 2"
     padding = ((0, 0), (0, 0)) + padding

-  indices = jnp.arange(np.prod(inputs.shape))
+  indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size):]))
+  indices = indices.reshape(inputs.shape[-len(kernel_size):])
+  indices = jnp.broadcast_to(indices, inputs.shape)

   def reduce_fn(a, b):
     ai, av = a
     bi, bv = b
-    which = av
+    which = av >= bv  # torch breaks ties in favor of later indices
     return jnp.where(which, ai, bi), jnp.where(which, av, bv)

   init_val = -jnp.inf
@@ -1317,44 +1315,90 @@ def _aten_max_pool2d_with_indices(
     init_val = -(1 << 31)
   init_val = jnp.array(init_val).astype(inputs.dtype)

-  # Separate maxpool result and indices into two reduce_window ops. Since
-  # the indices tensor is usually unused in inference, separating the two
+  # Separate maxpool result and indices into two reduce_window ops. Since
+  # the indices tensor is usually unused in inference, separating the two
   # can help DCE computations for argmax.
   y = jax.lax.reduce_window(
-      inputs,
+      inputs,
+      init_val,
+      jax.lax.max,
+      dims,
+      strides,
+      padding,
+      window_dilation=dilation)
   indices, _ = jax.lax.reduce_window(
-      (indices, inputs),
+      (indices, inputs),
+      (0, init_val),
+      reduce_fn,
+      dims,
+      strides,
+      padding,
+      window_dilation=dilation,
   )
   if is_single_input:
     indices = jnp.squeeze(indices, axis=0)
     y = jnp.squeeze(y, axis=0)
-
+
   return y, indices


+# Aten ops registered under the `xla` library.
+try:
+
+  @op(torch.ops.xla.max_pool2d_forward)
+  def _xla_max_pool2d_forward(*args, **kwargs):
+    return _aten_max_pool2d_with_indices(*args, **kwargs)[0]
+
+  @op(torch.ops.xla.aot_mark_sharding)
+  def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str):
+    from jax.sharding import PartitionSpec as P, NamedSharding
+    import ast
+    import torch_xla.distributed.spmd as xs
+    pmesh = xs.Mesh.from_str(mesh)
+    assert pmesh is not None
+    partition_spec_eval = ast.literal_eval(partition_spec)
+    jmesh = pmesh.get_jax_mesh()
+    return jax.lax.with_sharding_constraint(
+        t, NamedSharding(jmesh, P(*partition_spec_eval)))
+
+  @op(torch.ops.xla.einsum_linear_forward)
+  def _xla_einsum_linear_forward(input, weight, bias):
+    with jax.named_scope('einsum_linear_forward'):
+      product = jax.numpy.einsum('...n,mn->...m', input, weight)
+      if bias is not None:
+        return product + bias
+      return product
+
+except AttributeError:
+  pass
+
 # TODO add more ops


 @op(torch.ops.aten.min)
 def _aten_min(x, dim=None, keepdim=False):
   if dim is not None:
-    return _with_reduction_scalar(jnp.min, x, dim,
+    return _with_reduction_scalar(jnp.min, x, dim,
+                                  keepdim), _with_reduction_scalar(
+                                      jnp.argmin, x, dim,
+                                      keepdim).astype(jnp.int64)
   else:
     return _with_reduction_scalar(jnp.min, x, dim, keepdim)


 @op(torch.ops.aten.mode)
 def _aten_mode(input, dim=-1, keepdim=False, *, out=None):
-  if input.ndim == 0:
+  if input.ndim == 0:  # single number
     return input, jnp.array(0)
-  dim = (input.ndim +
+  dim = (input.ndim +
+         dim) % input.ndim  # jnp.scipy.stats.mode does not accept -1 as dim
   # keepdims must be True for accurate broadcasting
   mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True)
   mode_broadcast = jnp.broadcast_to(mode, input.shape)
   if not keepdim:
     mode = mode.squeeze(axis=dim)
-  indices = jnp.argmax(
+  indices = jnp.argmax(
+      jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim)
   return mode, indices

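Note: dilation support in `_aten_max_pool2d_with_indices` comes from passing `window_dilation` to `jax.lax.reduce_window`. A 1-D example of what that argument does, not taken from the package itself:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0)
# Window of 2 taps spaced 2 apart (dilation 2), stride 1, no padding:
# the effective window length is 3, so the output has 4 elements, max(x[i], x[i + 2]).
y = jax.lax.reduce_window(
    x, -jnp.inf, jax.lax.max,
    window_dimensions=(2,), window_strides=(1,),
    padding=((0, 0),), window_dilation=(2,))
print(y)  # [2. 3. 4. 5.]
```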
@@ -1388,8 +1432,7 @@ def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None):
 @op(torch.ops.prims.broadcast_in_dim)
 def _prims_broadcast_in_dim(t, shape, broadcast_dimensions):
   return jax.lax.broadcast_in_dim(
-  )
+      t, shape, broadcast_dimensions=broadcast_dimensions)


 # aten.native_group_norm -- should use decomp table
@@ -1432,17 +1475,15 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):
     normalized = (x - mean) * rstd
     return normalized, mean, rstd

-  normalized, group_mean, group_rstd = jax.lax.map(
-  )
+  normalized, group_mean, group_rstd = jax.lax.map(group_norm_body,
+                                                   reshaped_input)

   # Reshape back to original input shape
   output = jnp.reshape(normalized, input_shape)

   # **Affine transformation**
-  affine_shape = [
-  ] # Shape for broadcasting
+  affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim)
+                 ]  # Shape for broadcasting
   if weight is not None and bias is not None:
     output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape)
   elif weight is not None:
@@ -1474,22 +1515,25 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
     The tensor containing the calculated vector norms.
   """

-  if ord not in {2, float("inf"), float("-inf"), "fro"
+  if ord not in {2, float("inf"), float("-inf"), "fro"
+                } and not isinstance(ord, (int, float)):
     raise ValueError(
+        f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
+        " 'fro'.")
+
   # Special cases (for efficiency and clarity)
   if ord == 0:
     if self.shape == ():
       # float sets it to float64. set it back to input type
       result = jnp.astype(jnp.array(float(self != 0)), self.dtype)
     else:
-      result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim,
+      result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim,
+                                      keepdim)

   elif ord == 2:  # Euclidean norm
-    result = jnp.sqrt(
+    result = jnp.sqrt(
+        _with_reduction_scalar(jnp.sum,
+                               jnp.abs(self)**2, dim, keepdim))

   elif ord == float("inf"):
     result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim)
@@ -1498,12 +1542,14 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
     result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim)

   elif ord == "fro":  # Frobenius norm
-    result = jnp.sqrt(
+    result = jnp.sqrt(
+        _with_reduction_scalar(jnp.sum,
+                               jnp.abs(self)**2, dim, keepdim))

   else:  # General case (e.g., ord = 1, ord = 3)
-    result = _with_reduction_scalar(jnp.sum,
+    result = _with_reduction_scalar(jnp.sum,
+                                    jnp.abs(self)**ord, dim,
+                                    keepdim)**(1.0 / ord)

   # (Optional) dtype conversion
   if dtype is not None:
@@ -1539,9 +1585,12 @@ def _aten_sinh(self):

 # aten.native_layer_norm_backward
 @op(torch.ops.aten.native_layer_norm_backward)
-def _aten_native_layer_norm_backward(
+def _aten_native_layer_norm_backward(grad_out,
+                                     input,
+                                     normalized_shape,
+                                     weight,
+                                     bias,
+                                     eps=1e-5):
   """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`.

   Args:
@@ -1555,9 +1604,8 @@ def _aten_native_layer_norm_backward(
   Returns:
     A tuple of (grad_input, grad_weight, grad_bias).
   """
-  return jax.lax.native_layer_norm_backward(
-  )
+  return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape,
+                                            weight, bias, eps)


 # aten.reflection_pad3d_backward
@@ -1585,12 +1633,14 @@ def _aten_bitwise_not(self):


 # aten.bitwise_left_shift
+@op(torch.ops.aten.__lshift__)
 @op(torch.ops.aten.bitwise_left_shift)
 def _aten_bitwise_left_shift(input, other):
   return jnp.left_shift(input, other)


 # aten.bitwise_right_shift
+@op(torch.ops.aten.__rshift__)
 @op(torch.ops.aten.bitwise_right_shift)
 def _aten_bitwise_right_shift(input, other):
   return jnp.right_shift(input, other)
@@ -1671,10 +1721,8 @@ def _scatter_index(dim, index):
       target_shape = [1] * len(index_shape)
       target_shape[i] = index_shape[i]
       input_indexes.append(
-      )
-  )
+          jnp.broadcast_to(
+              jnp.arange(index_shape[i]).reshape(target_shape), index_shape))
   return tuple(input_indexes), tuple(source_indexes)


@@ -1686,6 +1734,7 @@ def _aten_scatter_add(input, dim, index, src):
   input_indexes, source_indexes = _scatter_index(dim, index)
   return input.at[input_indexes].add(src[source_indexes])

+
 # aten.masked_scatter
 @op(torch.ops.aten.masked_scatter)
 def _aten_masked_scatter(self, mask, source):
@@ -1707,6 +1756,7 @@ def _aten_masked_scatter(self, mask, source):

   return final_arr

+
 @op(torch.ops.aten.masked_select)
 def _aten_masked_select(self, mask, *args, **kwargs):
   broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape)
@@ -1722,6 +1772,7 @@ def _aten_masked_select(self, mask, *args, **kwargs):

   return self_flat[true_indices]

+
 # aten.logical_not


@@ -1730,11 +1781,13 @@ def _aten_masked_select(self, mask, *args, **kwargs):
 def _aten_sign(x):
   return jnp.sign(x)

+
 # aten.signbit
 @op(torch.ops.aten.signbit)
 def _aten_signbit(x):
   return jnp.signbit(x)

+
 # aten.sigmoid
 @op(torch.ops.aten.sigmoid)
 @op_base.promote_int_input
@@ -1760,7 +1813,13 @@ def _aten_atan(self):

 @op(torch.ops.aten.scatter_reduce)
 @op(torch.ops.aten.scatter)
-def _aten_scatter_reduce(input,
+def _aten_scatter_reduce(input,
+                         dim,
+                         index,
+                         src,
+                         reduce=None,
+                         *,
+                         include_self=True):
   if not isinstance(src, jnp.ndarray):
     src = jnp.array(src, dtype=input.dtype)
   input_indexes, source_indexes = _scatter_index(dim, index)
@@ -1817,41 +1876,6 @@ def _aten_gt(self, other):
   return self > other


-# aten.pixel_shuffle
-@op(torch.ops.aten.pixel_shuffle)
-def _aten_pixel_shuffle(x, upscale_factor):
-  """PixelShuffle implementation in JAX.
-
-  Args:
-    x: Input tensor. Typically a feature map.
-    upscale_factor: Integer by which to upscale the spatial dimensions.
-
-  Returns:
-    Tensor after PixelShuffle operation.
-  """
-
-  batch_size, channels, height, width = x.shape
-
-  if channels % (upscale_factor**2) != 0:
-    raise ValueError(
-        "Number of channels must be divisible by the square of the upscale factor."
-    )
-
-  new_channels = channels // (upscale_factor**2)
-  new_height = height * upscale_factor
-  new_width = width * upscale_factor
-
-  x = x.reshape(
-      batch_size, new_channels, upscale_factor, upscale_factor, height, width
-  )
-  x = jnp.transpose(
-      x, (0, 1, 2, 4, 3, 5)
-  )  # Move channels to spatial dimensions
-  x = x.reshape(batch_size, new_channels, new_height, new_width)
-
-  return x
-
-
 # aten.sym_stride
 # aten.lt
 @op(torch.ops.aten.lt)
@@ -1883,8 +1907,7 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
   num_batch_dims = inputs.ndim - (len(window_shape) + 1)
   strides = strides or (1,) * len(window_shape)
   assert len(window_shape) == len(
- …
-  ), f"len({window_shape}) must equal len({strides})"
+      strides), f"len({window_shape}) must equal len({strides})"
   strides = (1,) * (1 + num_batch_dims) + strides
   dims = (1,) * (1 + num_batch_dims) + window_shape
 
@@ -1901,23 +1924,22 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
   if not isinstance(padding, str):
     padding = tuple(map(tuple, padding))
     assert len(padding) == len(window_shape), (
- …
- …
-    )
- …
-        [len(x) == 2 for x in padding]
-    ), f"each entry in padding {padding} must be length 2"
+        f"padding {padding} must specify pads for same number of dims as "
+        f"window_shape {window_shape}")
+    assert all([len(x) == 2 for x in padding
+               ]), f"each entry in padding {padding} must be length 2"
     padding = ((0, 0), (0, 0)) + padding
   y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
   if is_single_input:
     y = jnp.squeeze(y, axis=0)
   return y
 
- …
+
 @op(torch.ops.aten._adaptive_avg_pool2d)
 @op(torch.ops.aten._adaptive_avg_pool3d)
-def adaptive_avg_pool2or3d(input: jnp.ndarray, …
- …
+def adaptive_avg_pool2or3d(input: jnp.ndarray,
+                           output_size: Tuple[int, int]) -> jnp.ndarray:
+  """
   Applies a 2/3D adaptive average pooling over an input signal composed of several input planes.
 
   See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.
@@ -1929,124 +1951,128 @@ def adaptive_avg_pool2or3d(input: jnp.ndarray, output_size: Tuple[int, int]) ->
   Context:
   https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401
   """
- …
+  shape = input.shape
+  ndim = len(shape)
+  out_dim = len(output_size)
+  num_spatial_dim = ndim - out_dim
+
+  # Preconditions
+
+  assert ndim in (
+      out_dim + 1, out_dim + 2
+  ), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}"
+  for d in input.shape[-2:]:
+    assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \
+                   f"non-batch dimensions, but input has shape {tuple(shape)}."
+
+  # Optimisation (we should also do this in the kernel implementation)
+  if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])):
+    stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size))
+    kernel = tuple(i - (o - 1) * s
+                   for i, o, s in zip(shape[-out_dim:], output_size, stride))
+    return _aten_avg_pool(
+        input,
+        kernel,
+        strides=stride,
+    )
+
+  def start_index(a, b, c):
+    return (a * c) // b
+
+  def end_index(a, b, c):
+    return ((a + 1) * c + b - 1) // b
+
+  def compute_idx(in_size, out_size):
+    orange = jnp.arange(out_size, dtype=jnp.int64)
+    i0 = start_index(orange, out_size, in_size)
+    # Let length = end_index - start_index, i.e. the length of the pooling kernels
+    # length.max() can be computed analytically as follows:
+    maxlength = in_size // out_size + 1
+    in_size_mod = in_size % out_size
+    # adaptive = True iff there are kernels with different lengths
+    adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
+    if adaptive:
+      maxlength += 1
+    elif in_size_mod == 0:
+      maxlength -= 1
+
+    range_max = jnp.arange(maxlength, dtype=jnp.int64)
+    idx = i0[:, None] + range_max
+    if adaptive:
+      # Need to clamp to avoid accessing out-of-bounds memory
+      idx = jnp.minimum(idx, in_size - 1)
+
+      # Compute the length
+      i1 = end_index(orange, out_size, in_size)
+      length = i1 - i0
+    else:
+      length = maxlength
+    return idx, length, range_max, adaptive
+
+  idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)]
+  # length is not None if it's constant, otherwise we'll need to compute it
+  for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)):
+    idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o)
+
+  def _unsqueeze_to_dim(x, dim):
+    ndim = len(x.shape)
+    return jax.lax.expand_dims(x, tuple(range(ndim, dim)))
+
+  if out_dim == 2:
+    # NOTE: unsqueeze to insert extra 1 in ranks; so they
+    # would broadcast
+    vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]]
+    reduce_axis = (-3, -1)
+  else:
+    assert out_dim == 3
+    vals = input[...,
+                 _unsqueeze_to_dim(idx[0], 6),
+                 _unsqueeze_to_dim(idx[1], 4), idx[2]]
+    reduce_axis = (-5, -3, -1)
+
+  # Shortcut for the simpler case
+  if not any(adaptive):
+    return jnp.mean(vals, axis=reduce_axis)
+
+  def maybe_mask(vals, length, range_max, adaptive, dim):
+    if isinstance(length, int):
+      return vals, length
     else:
- …
-      # Compute the length of each window
-      length = _unsqueeze_to_dim(length, -dim)
-      return vals, length
-
-  for i in range(len(length)):
-    vals, length[i] = maybe_mask(vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
-
-  # We unroll the sum as we assume that the kernels are going to be small
-  ret = jnp.sum(vals, axis=reduce_axis)
-  # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
-  # this is multiplication with broadcasting, not regular pointwise product
-  return ret / math.prod(length)
-
+      # zero-out the things we didn't really want to select
+      assert dim < 0
+      # hack
+      mask = range_max >= length[:, None]
+      if dim == -2:
+        mask = _unsqueeze_to_dim(mask, 4)
+      elif dim == -3:
+        mask = _unsqueeze_to_dim(mask, 6)
+      vals = jnp.where(mask, 0.0, vals)
+      # Compute the length of each window
+      length = _unsqueeze_to_dim(length, -dim)
+      return vals, length
+
+  for i in range(len(length)):
+    vals, length[i] = maybe_mask(
+        vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
+
+  # We unroll the sum as we assume that the kernels are going to be small
+  ret = jnp.sum(vals, axis=reduce_axis)
+  # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
+  # this is multiplication with broadcasting, not regular pointwise product
+  return ret / math.prod(length)
+
 
 @op(torch.ops.aten.avg_pool1d)
 @op(torch.ops.aten.avg_pool2d)
 @op(torch.ops.aten.avg_pool3d)
 def _aten_avg_pool(
- …
+    inputs,
+    kernel_size,
+    strides=None,
+    padding=0,
+    ceil_mode=False,
+    count_include_pad=True,
+    divisor_override=None,
 ):
   num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
   kernel_size = tuple(kernel_size)
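The start_index/end_index arithmetic in the new adaptive_avg_pool2or3d body is the standard adaptive-pooling window computation. A small standalone sketch (illustrative only, not part of the diff) of how it partitions an input of size 10 into 4 output windows:

# Hypothetical illustration of the window arithmetic used above.
def start_index(a, b, c):
  return (a * c) // b

def end_index(a, b, c):
  return ((a + 1) * c + b - 1) // b

in_size, out_size = 10, 4
windows = [(start_index(i, out_size, in_size), end_index(i, out_size, in_size))
           for i in range(out_size)]
# windows == [(0, 3), (2, 5), (5, 8), (7, 10)]; each output averages its window,
# and windows overlap exactly when the division is not even (the "adaptive" case).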
@@ -2060,7 +2086,7 @@ def _aten_avg_pool(
   if num_batch_dims == 0:
     input_shape = [1, *input_shape]
   padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
-                               ceil_mode)
+                               [1] * len(kernel_size), ceil_mode)
 
   y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding)
   if divisor_override is not None:
@@ -2102,9 +2128,11 @@ def _aten_avg_pool(
   )
   return y.astype(inputs.dtype)
 
+
 # helper function to generate all indices to iterate through ndarray
-def _generate_indices(dims, skip_dim_indices …
+def _generate_indices(dims, skip_dim_indices=[]):
   res = []
+
   def _helper(curr_dim_idx, sofar):
     if curr_dim_idx in skip_dim_indices:
       _helper(curr_dim_idx + 1, sofar[:])
@@ -2115,10 +2143,11 @@ def _generate_indices(dims, skip_dim_indices = []):
     for i in range(dims[curr_dim_idx]):
       sofar[curr_dim_idx] = i
       _helper(curr_dim_idx + 1, sofar[:])
- …
+
   _helper(0, [0 for _ in dims])
   return res
 
+
 # aten.sym_numel
 # aten.reciprocal
 @op(torch.ops.aten.reciprocal)
@@ -2174,10 +2203,14 @@ def _aten_round(input, decimals=0):
 @op(torch.ops.aten.max)
 def _aten_max(self, dim=None, keepdim=False):
   if dim is not None:
-    return _with_reduction_scalar(jnp.max, self, dim, …
+    return _with_reduction_scalar(jnp.max, self, dim,
+                                  keepdim), _with_reduction_scalar(
+                                      jnp.argmax, self, dim,
+                                      keepdim).astype(jnp.int64)
   else:
     return _with_reduction_scalar(jnp.max, self, dim, keepdim)
 
+
 # aten.maximum
 @op(torch.ops.aten.maximum)
 def _aten_maximum(self, other):
@@ -2216,27 +2249,28 @@ def _with_reduction_scalar(jax_func, self, dim, keepdim):
 def _aten_any(self, dim=None, keepdim=False):
   return _with_reduction_scalar(jnp.any, self, dim, keepdim)
 
+
 # aten.arange
 @op(torch.ops.aten.arange.start_step)
 @op(torch.ops.aten.arange.start)
 @op(torch.ops.aten.arange.default)
 @op_base.convert_dtype(use_default_dtype=False)
 def _aten_arange(
- …
+    start,
+    end=None,
+    step=None,
+    *,
+    dtype=None,
+    layout=None,
+    requires_grad=False,
+    device=None,
+    pin_memory=False,
 ):
   return jnp.arange(
- …
+      op_base.maybe_convert_constant_dtype(start, dtype),
+      op_base.maybe_convert_constant_dtype(end, dtype),
+      op_base.maybe_convert_constant_dtype(step, dtype),
+      dtype=dtype,
   )
 
 
@@ -2245,6 +2279,7 @@ def _aten_arange(
 def _aten_argmax(self, dim=None, keepdim=False):
   return _with_reduction_scalar(jnp.argmax, self, dim, keepdim)
 
+
 def _strided_index(sizes, strides, storage_offset=None):
   ind = jnp.zeros(sizes, dtype=jnp.int32)
 
@@ -2257,6 +2292,7 @@ def _strided_index(sizes, strides, storage_offset=None):
     ind += storage_offset
   return ind
 
+
 # aten.as_strided
 @op(torch.ops.aten.as_strided)
 @op(torch.ops.aten.as_strided_copy)
@@ -2311,7 +2347,7 @@ def _aten_broadcast_tensors(*tensors):
   Args:
     shapes: A list of tuples representing the shapes of the input tensors.
 
-  Returns:
+    Returns:
     A tuple representing the broadcasted output shape.
   """
 
@@ -2344,11 +2380,13 @@ def _aten_broadcast_tensors(*tensors):
     A tuple specifying which dimensions of the input tensor should be broadcasted.
   """
 
-    res = tuple( …
+    res = tuple(
+        i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)))
     return res
 
   # clean some function's previous wrap
-  if len(tensors)==1 and len(tensors[0])>=1 and isinstance( …
+  if len(tensors) == 1 and len(tensors[0]) >= 1 and isinstance(
+      tensors[0][0], jax.Array):
     tensors = tensors[0]
 
   # Get the shapes of all input tensors
@@ -2357,7 +2395,8 @@ def _aten_broadcast_tensors(*tensors):
   output_shape = _get_broadcast_shape(shapes)
   # Broadcast each tensor to the output shape
   broadcasted_tensors = [
-      jax.lax.broadcast_in_dim(t, output_shape, …
+      jax.lax.broadcast_in_dim(t, output_shape,
+                               _broadcast_dimensions(t.shape, output_shape))
       for t in tensors
   ]
 
@@ -2376,6 +2415,7 @@ def _aten_broadcast_to(input, shape):
 def _aten_clamp(self, min=None, max=None):
   return jnp.clip(self, min, max)
 
+
 @op(torch.ops.aten.clamp_min)
 def _aten_clamp_min(input, min):
   return jnp.clip(input, min=min)
@@ -2394,7 +2434,7 @@ def _aten_constant_pad_nd(input, padding, value=0):
   rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)]
   pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding)
   value_casted = jax.numpy.array(value, dtype=input.dtype)
-  return jax.lax.pad(input, padding_value=value_casted, padding_config …
+  return jax.lax.pad(input, padding_value=value_casted, padding_config=pad_dim)
 
 
 # aten.convolution_backward
@@ -2421,9 +2461,8 @@ def _aten_cdist_forward(x1, x2, p, compute_mode=""):
 @op(torch.ops.aten._pdist_forward)
 def _aten__pdist_forward(x, p=2):
   pairwise_dists = _aten_cdist_forward(x, x, p)
-  condensed_dists = pairwise_dists[ …
- …
-  ]
+  condensed_dists = pairwise_dists[jnp.triu_indices(
+      pairwise_dists.shape[0], k=1)]
   return condensed_dists
 
 
@@ -2449,25 +2488,33 @@ def _aten_cosh(input):
   return jnp.cosh(input)
 
 
+@op(torch.ops.aten.diag)
+def _aten_diag(input, diagonal=0):
+  return jnp.diag(input, diagonal)
+
+
 # aten.diagonal
 @op(torch.ops.aten.diagonal)
+@op(torch.ops.aten.diagonal_copy)
 def _aten_diagonal(input, offset=0, dim1=0, dim2=1):
   return jnp.diagonal(input, offset, dim1, dim2)
 
 
 def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1):
- …
+  input_len = len(input_shape)
+  if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len):
+    raise ValueError("dim1 and dim2 must be different and in range [0, " +
+                     str(input_len - 1) + "]")
+
+  size1, size2 = input_shape[dim1], input_shape[dim2]
+  if offset >= 0:
+    indices1 = jnp.arange(min(size1, size2 - offset))
+    indices2 = jnp.arange(offset, offset + len(indices1))
+  else:
+    indices2 = jnp.arange(min(size1 + offset, size2))
+    indices1 = jnp.arange(-offset, -offset + len(indices2))
+  return [indices1, indices2]
+
 
 @op(torch.ops.aten.diagonal_scatter)
 def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
@@ -2476,17 +2523,17 @@ def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
   if input.ndim == 2:
     return input.at[tuple(indexes)].set(src)
   else:
-    # src has the same shape as the output of …
+    # src has the same shape as the output of
     # jnp.diagonal(input, offset, dim1, dim2).
     # Last dimension always contains the diagonal elements,
     # while the preceding dimensions represent the "slices"
     # from which these diagonals are extracted. Thus,
     # we alter input axes to match this assumption, write src
     # and then move the axes back to the original state.
-    input = jnp.moveaxis(input, (dim1, dim2), (-2 …
-    multi_indexes = [slice(None)]*(input.ndim-2) + indexes
+    input = jnp.moveaxis(input, (dim1, dim2), (-2, -1))
+    multi_indexes = [slice(None)] * (input.ndim - 2) + indexes
     input = input.at[tuple(multi_indexes)].set(src)
-    return jnp.moveaxis(input, (-2 …
+    return jnp.moveaxis(input, (-2, -1), (dim1, dim2))
 
 
 # aten.diagflat
@@ -2507,9 +2554,9 @@ def _aten_eq(input1, input2):
 
 
 # aten.equal
-@op(torch.ops.aten.equal …
+@op(torch.ops.aten.equal)
 def _aten_equal(input, other):
-  res = jnp.array_equal(input …
+  res = jnp.array_equal(input, other)
   return bool(res)
 
 
@@ -2559,7 +2606,12 @@ def _aten_exp2(input):
 # aten.fill
 @op(torch.ops.aten.fill)
 @op(torch.ops.aten.full_like)
-def _aten_fill(x, …
+def _aten_fill(x,
+               value,
+               dtype=None,
+               pin_memory=None,
+               memory_format=None,
+               device=None):
   if dtype is None:
     dtype = x.dtype
   else:
@@ -2634,7 +2686,8 @@ def _aten_glu(x, dim=-1):
 # aten.hardtanh
 @op(torch.ops.aten.hardtanh)
 def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
-  if input.dtype == np.int64 and isinstance(max_val, float) and isinstance( …
+  if input.dtype == np.int64 and isinstance(max_val, float) and isinstance(
+      min_val, float):
     min_val = int(min_val)
     max_val = int(max_val)
   return jnp.clip(input, min_val, max_val)
@@ -2644,7 +2697,7 @@ def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
 @op(torch.ops.aten.histc)
 def _aten_histc(input, bins=100, min=0, max=0):
   # TODO(@manfei): this function might cause some uncertainty
-  if min==0 and max==0:
+  if min == 0 and max == 0:
     if isinstance(input, jnp.ndarray) and input.size == 0:
       min = 0
       max = 0
@@ -2652,7 +2705,8 @@ def _aten_histc(input, bins=100, min=0, max=0):
       min = jnp.min(input)
       max = jnp.max(input)
   range_value = (min, max)
-  hist, bin_edges = jnp.histogram( …
+  hist, bin_edges = jnp.histogram(
+      input, bins=bins, range=range_value, weights=None, density=None)
   return hist
 
 
@@ -2667,22 +2721,28 @@ def _aten_digamma(input, *, out=None):
   # replace indices where input == 0 with -inf in res
   return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res)
 
+
 @op(torch.ops.aten.igamma)
 def _aten_igamma(input, other):
   return jax.scipy.special.gammainc(input, other)
 
+
 @op(torch.ops.aten.lgamma)
 def _aten_lgamma(input, *, out=None):
   return jax.scipy.special.gammaln(input).astype(jnp.float32)
 
+
 @op(torch.ops.aten.mvlgamma)
 def _aten_mvlgamma(input, p, *, out=None):
- …
+  input = input.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return jax.scipy.special.multigammaln(input, p)
+
 
 @op(torch.ops.aten.linalg_eig)
 def _aten_linalg_eig(A):
   return jnp.linalg.eig(A)
 
+
 @op(torch.ops.aten._linalg_eigh)
 def _aten_linalg_eigh(A, UPLO='L'):
   return jnp.linalg.eigh(A, UPLO)
@@ -2704,7 +2764,9 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
     A_reshaped = A.reshape((batch_size,) + A.shape[-2:])
     B_reshaped = B.reshape((batch_size,) + B.shape[-2:])
 
-    X, residuals, rank, singular_values = jax.vmap( …
+    X, residuals, rank, singular_values = jax.vmap(
+        jnp.linalg.lstsq, in_axes=(0,
+                                   0))(A_reshaped, B_reshaped, rcond=rcond)
 
     X = X.reshape(batch_shape + X.shape[-2:])
 
@@ -2720,7 +2782,8 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
       residuals = residuals.reshape(batch_shape + residuals.shape[-1:])
 
     if driver in ['gelsd', 'gelss']:
-      singular_values = singular_values.reshape(batch_shape + …
+      singular_values = singular_values.reshape(batch_shape +
+                                                singular_values.shape[-1:])
     else:
       singular_values = jnp.array([], dtype=input_dtype)
 
@@ -2729,17 +2792,17 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
     X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond)
 
     if driver not in ['gelsd', 'gelsy', 'gelss']:
- …
+      rank = jnp.array([], dtype=jnp.int64)
 
     rank_value = None
     if rank.size > 0:
- …
- …
+      rank_value = int(rank.item())
+      rank = jnp.array(rank_value, dtype=jnp.int64)
 
     # When driver is ‘gels’, assume that A is full-rank.
-    full_rank = …
+    full_rank = driver == 'gels' or rank_value == n
     if driver == 'gelsy' or m <= n or (not full_rank):
- …
+      residuals = jnp.array([], dtype=input_dtype)
 
     if driver not in ['gelsd', 'gelss']:
       singular_values = jnp.array([], dtype=input_dtype)
@@ -2753,8 +2816,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
   # https://github.com/jax-ml/jax/issues/12779
   # TODO: Not tested for complex inputs. Does not support hermitian=True
   pivots = jnp.broadcast_to(
-      jnp.arange(1, A.shape[-1]+1, dtype=jnp.int32), A.shape[:-1]
-  )
+      jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1])
   info = jnp.zeros(A.shape[:-2], jnp.int32)
   C = jnp.linalg.cholesky(A)
   if C.size == 0:
@@ -2767,7 +2829,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
 
   D = C * jnp.eye(C.shape[-1], dtype=A.dtype)
   LD = C @ jnp.linalg.inv(D)
-  LD = fill_diagonal_batch(LD, D*D)
+  LD = fill_diagonal_batch(LD, D * D)
   return LD, pivots, info
 
 
@@ -2787,9 +2849,9 @@ def _aten_linalg_lu(A, pivot=True, out=None):
   U = jnp.triu(lu[..., :k, :])
 
   def perm_to_P(perm):
- …
- …
- …
+    m = perm.shape[-1]
+    P = jnp.eye(m, dtype=dtype)[perm].T
+    return P
 
   if permutation.ndim > 1:
     num_batch_dims = permutation.ndim - 1
@@ -2798,7 +2860,7 @@ def _aten_linalg_lu(A, pivot=True, out=None):
 
   P = perm_to_P(permutation)
 
-  return P,L,U
+  return P, L, U
 
 
 @op(torch.ops.aten.linalg_lu_factor_ex)
@@ -2810,6 +2872,21 @@ def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False):
   return lu, pivots, info
 
 
+@op(torch.ops.aten.linalg_lu_solve)
+def _aten_linalg_lu_solve(LU, pivots, B, left=True, adjoint=False):
+  # JAX pivots are offset by 1 compared to torch
+  pivots = pivots - 1
+  if not left:
+    # XA = B is same as A'X = B'
+    trans = 0 if adjoint else 2
+    x = jax.scipy.linalg.lu_solve((LU, pivots), jnp.matrix_transpose(B), trans)
+    x = jnp.matrix_transpose(x)
+  else:
+    trans = 2 if adjoint else 0
+    x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans)
+  return x
+
+
 @op(torch.ops.aten.gcd)
 def _aten_gcd(input, other):
   return jnp.gcd(input, other)
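The pivot offset handled in the new _aten_linalg_lu_solve comes from torch.linalg.lu_factor returning 1-based pivots while jax.scipy.linalg.lu_factor uses 0-based row indices. A quick standalone check of the underlying JAX call (illustrative only, not part of the diff):

import jax.numpy as jnp
from jax.scipy import linalg as jsl

a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([[9.0], [8.0]])
lu, piv = jsl.lu_factor(a)        # piv is already 0-based here
x = jsl.lu_solve((lu, piv), b)    # solves a @ x == b
# jnp.allclose(a @ x, b) -> True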
@@ -2874,12 +2951,14 @@ def _aten_log2(x):
 
 # aten.logical_and
 @op(torch.ops.aten.logical_and)
+@op(torch.ops.aten.__and__)
 def _aten_logical_and(self, other):
   return jnp.logical_and(self, other)
 
 
 # aten.logical_or
 @op(torch.ops.aten.logical_or)
+@op(torch.ops.aten.__or__)
 def _aten_logical_or(self, other):
   return jnp.logical_or(self, other)
 
@@ -2894,7 +2973,7 @@ def _aten_logical_not(self):
 @op(torch.ops.aten._log_softmax)
 def _aten_log_softmax(self, axis=-1, half_to_float=False):
   if self.shape == ():
- …
+    return jnp.astype(0.0, self.dtype)
   return jax.nn.log_softmax(self, axis)
 
 
@@ -2921,6 +3000,7 @@ def _aten_logcumsumexp(self, dim=None):
 # aten.max_pool3d_backward
 # aten.logical_xor
 @op(torch.ops.aten.logical_xor)
+@op(torch.ops.aten.__xor__)
 def _aten_logical_xor(self, other):
   return jnp.logical_xor(self, other)
 
@@ -2933,19 +3013,22 @@ def _aten_logical_xor(self, other):
 def _aten_neg(x):
   return -1 * x
 
+
 @op(torch.ops.aten.nextafter)
 def _aten_nextafter(input, other, *, out=None):
   return jnp.nextafter(input, other)
 
 
 @op(torch.ops.aten.nonzero_static)
-def _aten_nonzero_static(input, size, fill_value …
+def _aten_nonzero_static(input, size, fill_value=-1):
   indices = jnp.argwhere(input)
 
   if size < indices.shape[0]:
     indices = indices[:size]
   elif size > indices.shape[0]:
-    padding = jnp.full((size - indices.shape[0], indices.shape[1]), …
+    padding = jnp.full((size - indices.shape[0], indices.shape[1]),
+                       fill_value,
+                       dtype=indices.dtype)
     indices = jnp.concatenate((indices, padding))
 
   return indices
@@ -2954,9 +3037,11 @@ def _aten_nonzero_static(input, size, fill_value = -1):
 # aten.nonzero
 @op(torch.ops.aten.nonzero)
 def _aten_nonzero(x, as_tuple=False):
-  if jnp.ndim(x) == 0 and (as_tuple or x.item()==0):
+  if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0):
     return torch.empty(0, 0, dtype=torch.int64)
-  if jnp.ndim( …
+  if jnp.ndim(
+      x
+  ) == 0:  # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64)
     res = torch.empty(1, 0, dtype=torch.int64)
     return jnp.array(res.numpy())
   index_tuple = jnp.nonzero(x)
@@ -2997,15 +3082,15 @@ def _aten_put(self, index, source, accumulate=False):
 # aten.randperm
 # randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
 @op(torch.ops.aten.randperm, needs_env=True)
-def _aten_randperm( …
- …
+def _aten_randperm(n,
+                   *,
+                   generator=None,
+                   dtype=None,
+                   layout=None,
+                   device=None,
+                   pin_memory=None,
+                   env=None):
+  """
   Generates a random permutation of integers from 0 to n-1.
 
   Args:
@@ -3019,14 +3104,14 @@ def _aten_randperm(
   Returns:
     A DeviceArray containing a random permutation of integers from 0 to n-1.
   """
- …
+  if dtype:
+    dtype = mappings.t2j_dtype(dtype)
+  else:
+    dtype = jnp.int64.dtype
+  key = env.get_and_rotate_prng_key(generator)
+  indices = jnp.arange(n, dtype=dtype)
+  permutation = jax.random.permutation(key, indices)
+  return permutation
 
 
 # aten.reflection_pad3d
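Outside of torchax, the same permutation can be produced directly with JAX's functional RNG; a minimal sketch (illustrative only, using a hand-made key in place of the environment-managed one above):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)   # stand-in for env.get_and_rotate_prng_key(generator)
perm = jax.random.permutation(key, jnp.arange(8, dtype=jnp.int64))
# perm is a random ordering of 0..7; reusing the same key reproduces the same ordering,
# which is why the torchax environment rotates its key on every call.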
@@ -3071,8 +3156,8 @@ def _aten_sort(a, dim=-1, descending=False, stable=False):
   if a.shape == ():
     return (a, jnp.astype(0, 'int64'))
   return (
- …
- …
+      jnp.sort(a, axis=dim, stable=stable, descending=descending),
+      jnp.argsort(a, axis=dim, stable=stable, descending=descending),
   )
 
 
@@ -3114,8 +3199,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
   if dim != -1 and dim != len(input.shape) - 1:
     transpose_shape = list(range(len(input.shape)))
     transpose_shape[dim], transpose_shape[-1] = (
- …
- …
+        transpose_shape[-1],
+        transpose_shape[dim],
     )
     input = jnp.transpose(input, transpose_shape)
 
@@ -3124,8 +3209,7 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
   if sorted:
     values = jnp.sort(values, descending=True)
     indices = jnp.take_along_axis(
- …
-    )
+        indices, jnp.argsort(values, axis=-1, descending=True), axis=-1)
 
   if not largest:
     values = -values  # Negate values back if we found smallest
@@ -3140,21 +3224,39 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
 # aten.tril_indices
 #tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
 @op(torch.ops.aten.tril_indices)
-def _aten_tril_indices(row, …
+def _aten_tril_indices(row,
+                       col,
+                       offset=0,
+                       *,
+                       dtype=jnp.int64.dtype,
+                       layout=None,
+                       device=None,
+                       pin_memory=None):
   a, b = jnp.tril_indices(row, offset, col)
   return jnp.stack((a, b))
 
+
 # aten.tril_indices
 #tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
 @op(torch.ops.aten.triu_indices)
-def _aten_triu_indices(row, …
+def _aten_triu_indices(row,
+                       col,
+                       offset=0,
+                       *,
+                       dtype=jnp.int64.dtype,
+                       layout=None,
+                       device=None,
+                       pin_memory=None):
   a, b = jnp.triu_indices(row, offset, col)
   return jnp.stack((a, b))
 
 
 @op(torch.ops.aten.unbind_copy)
 def _aten_unbind(a, dim=0):
-  return [ …
+  return [
+      jax.lax.index_in_dim(a, i, dim, keepdims=False)
+      for i in range(a.shape[dim])
+  ]
 
 
 # aten.unique_dim
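Note that jnp.tril_indices returns a pair of 1-D index arrays while torch.tril_indices returns a single (2, N) tensor, which is why the implementations above stack the result. A tiny standalone comparison (illustrative only):

import jax.numpy as jnp
import torch

rows, cols = jnp.tril_indices(3, 0, 3)   # two arrays of row / column indices
stacked = jnp.stack((rows, cols))        # shape (2, 6), matching torch's layout
ref = torch.tril_indices(3, 3, 0)        # shape (2, 6)
# stacked and ref enumerate the same lower-triangular index pairs.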
@@ -3167,12 +3269,13 @@ def _aten_unique_dim(input_tensor,
                      sort=True,
                      return_inverse=False,
                      return_counts=False):
-  result_tensor_or_tuple = jnp.unique( …
+  result_tensor_or_tuple = jnp.unique(
+      input_tensor,
+      return_index=False,
+      return_inverse=return_inverse,
+      return_counts=return_counts,
+      axis=dim,
+      equal_nan=False)
   result_list = (
       list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple)
       else [result_tensor_or_tuple])
@@ -3197,15 +3300,14 @@ def _aten_unique_dim(input_tensor,
 # NOTE: Like the CUDA and CPU implementations, this implementation always sorts
 # the tensor regardless of the `sorted` argument passed to `torch.unique`.
 @op(torch.ops.aten._unique)
-def _aten_unique(input_tensor, …
- …
-                 equal_nan=False)
+def _aten_unique(input_tensor, sort=True, return_inverse=False):
+  result_tensor_or_tuple = jnp.unique(
+      input_tensor,
+      return_index=False,
+      return_inverse=return_inverse,
+      return_counts=False,
+      axis=None,
+      equal_nan=False)
   if return_inverse:
     return result_tensor_or_tuple
   else:
@@ -3221,11 +3323,12 @@ def _aten_unique2(input_tensor,
                   sort=True,
                   return_inverse=False,
                   return_counts=False):
-  return _aten_unique_dim( …
+  return _aten_unique_dim(
+      input_tensor=input_tensor,
+      dim=None,
+      sort=sort,
+      return_inverse=return_inverse,
+      return_counts=return_counts)
 
 
 # aten.unique_consecutive
@@ -3255,17 +3358,18 @@ def _aten_unique_consecutive(input_tensor,
   if dim < 0:
     dim += ndim
 
-  nd_slice_0 = tuple( …
- …
-  nd_slice_1 = tuple( …
- …
+  nd_slice_0 = tuple(
+      slice(None, -1) if d == dim else slice(None) for d in range(ndim))
+  nd_slice_1 = tuple(
+      slice(1, None) if d == dim else slice(None) for d in range(ndim))
 
   axes_to_reduce = tuple(d for d in range(ndim) if d != dim)
 
   does_not_equal_prior = (
-      jnp.any( …
- …
- …
+      jnp.any(
+          input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
+          axis=axes_to_reduce,
+          keepdims=False))
 
   if input_tensor.shape[dim] != 0:
     # Prepend `True` to represent the first element of the input.
@@ -3273,18 +3377,17 @@ def _aten_unique_consecutive(input_tensor,
 
     include_indices = jnp.argwhere(does_not_equal_prior)[:, 0]
 
-    output_tensor = input_tensor[ …
- …
+    output_tensor = input_tensor[tuple(
+        include_indices if d == dim else slice(None) for d in range(ndim))]
 
     if return_inverse or return_counts:
-      counts = ( …
- …
+      counts = (
+          jnp.append(include_indices[1:], input_tensor.shape[dim]) -
+          include_indices[:])
 
       inverse = (
           jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape)
-          if return_inverse
-          else None
-      )
+          if return_inverse else None)
 
       return output_tensor, inverse, counts
 
@@ -3302,25 +3405,33 @@ def _aten_unique_consecutive(input_tensor,
 @op(torch.ops.aten.where.ScalarSelf)
 @op(torch.ops.aten.where.ScalarOther)
 @op(torch.ops.aten.where.Scalar)
-def _aten_where(condition, x …
+def _aten_where(condition, x=None, y=None):
   return jnp.where(condition, x, y)
 
 
 # aten.to.dtype
 # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None
 @op(torch.ops.aten.to.dtype)
-def _aten_to_dtype( …
- …
- …
+def _aten_to_dtype(a,
+                   dtype,
+                   non_blocking=False,
+                   copy=False,
+                   memory_format=None):
   if dtype:
     jaxdtype = mappings.t2j_dtype(dtype)
     return a.astype(jaxdtype)
 
 
 @op(torch.ops.aten.to.dtype_layout)
-def _aten_to_dtype_layout( …
- …
- …
+def _aten_to_dtype_layout(a,
+                          *,
+                          dtype=None,
+                          layout=None,
+                          device=None,
+                          pin_memory=None,
+                          non_blocking=False,
+                          copy=False,
+                          memory_format=None):
   return _aten_to_dtype(
       a,
       dtype,
@@ -3328,6 +3439,7 @@ def _aten_to_dtype_layout(
       copy=copy,
       memory_format=memory_format)
 
+
 # aten.to.device
 
 
@@ -3348,9 +3460,11 @@ def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False):
 
 @op(torch.ops.aten.scalar_tensor)
 @op_base.convert_dtype()
-def _aten_scalar_tensor( …
- …
+def _aten_scalar_tensor(s,
+                        dtype=None,
+                        layout=None,
+                        device=None,
+                        pin_memory=None):
   return jnp.array(s, dtype=dtype)
 
 
@@ -3360,9 +3474,9 @@ def _aten_to_device(x, device, dtype):
 
 
 @op(torch.ops.aten.max_pool2d_with_indices_backward)
-def max_pool2d_with_indices_backward_custom( …
- …
-):
+def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size,
+                                            stride, padding, dilation,
+                                            ceil_mode, indices):
   """
   Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward.
 
@@ -3417,16 +3531,16 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0):
 
 @op(torch.ops.aten.randn, needs_env=True)
 @op_base.convert_dtype()
-def …
- …
+def _aten_randn(
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+    env=None,
 ):
   shape = size
   if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3437,13 +3551,14 @@ def _randn(
     res = res.astype(dtype)
   return res
 
+
 @op(torch.ops.aten.bernoulli.p, needs_env=True)
-def …
- …
+def _aten_bernoulli(
+    self,
+    p=0.5,
+    *,
+    generator=None,
+    env=None,
 ):
   key = env.get_and_rotate_prng_key(generator)
   res = jax.random.uniform(key, self.shape) < p
@@ -3460,14 +3575,14 @@ def geometric(self, p, *, generator=None, env=None):
 @op(torch.ops.aten.randn_like, needs_env=True)
 @op_base.convert_dtype()
 def _aten_randn_like(
- …
+    x,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    pin_memory=False,
+    memory_format=torch.preserve_format,
+    env=None,
 ):
   key = env.get_and_rotate_prng_key()
   return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape)
@@ -3476,15 +3591,15 @@ def _aten_randn_like(
 @op(torch.ops.aten.rand, needs_env=True)
 @op_base.convert_dtype()
 def _rand(
- …
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+    env=None,
 ):
   shape = size
   if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3505,32 +3620,48 @@ def _aten_outer(a, b):
 def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
   return jnp.allclose(input, other, rtol, atol, equal_nan)
 
+
 @op(torch.ops.aten.native_batch_norm)
-def _aten_native_batch_norm(input, …
+def _aten_native_batch_norm(input,
+                            weight,
+                            bias,
+                            running_mean,
+                            running_var,
+                            training=False,
+                            momentum=0.1,
+                            eps=1e-5):
 
   if running_mean is None:
-    running_mean = jnp.zeros( …
+    running_mean = jnp.zeros(
+        input.shape[1], dtype=input.dtype)  # Initialize running mean if None
   if running_var is None:
-    running_var = jnp.ones( …
+    running_var = jnp.ones(
+        input.shape[1],
+        dtype=input.dtype)  # Initialize running variance if None
 
   if training:
-    return _aten__native_batch_norm_legit(input, weight, bias, running_mean, …
+    return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
+                                          running_var, training, momentum, eps)
   else:
-    return _aten__native_batch_norm_legit_no_training(input, weight, bias, …
+    return _aten__native_batch_norm_legit_no_training(input, weight, bias,
+                                                      running_mean, running_var,
+                                                      momentum, eps)
 
 
 @op(torch.ops.aten.normal, needs_env=True)
 def _aten_normal(self, mean=0, std=1, generator=None, env=None):
   shape = self.shape
-  res = …
+  res = _aten_randn(*shape, generator=generator, env=env)
   return res * std + mean
 
+
 # TODO: not clear what this function should actually do
 # https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940
 @op(torch.ops.aten.lift_fresh)
 def _aten_lift_fresh(self):
   return self
 
+
 @op(torch.ops.aten.uniform, needs_env=True)
 def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
   assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})'
@@ -3538,16 +3669,18 @@ def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
   res = _rand(*shape, generator=generator, env=env)
   return res * (to - from_) + from_
 
+
 #func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
+
 @op(torch.ops.aten.randint, needs_env=True)
 @op_base.convert_dtype(use_default_dtype=False)
 def _aten_randint(
- …
+    *args,
+    generator=None,
+    dtype=None,
+    env=None,
+    **kwargs,
 ):
   if len(args) == 3:
     # low, high, size
@@ -3556,7 +3689,8 @@ def _aten_randint(
     high, size = args
     low = 0
   else:
-    raise AssertionError( …
+    raise AssertionError(
+        f'Expected at 2 or 3 args for Aten::randint, got {len(args)}')
 
   key = env.get_and_rotate_prng_key(generator)
   res = jax.random.randint(key, size, low, high)
@@ -3564,15 +3698,18 @@ def _aten_randint(
     res = res.astype(dtype)
   return res
 
- …
+
+
+@op(torch.ops.aten.randint_like,
+    torch.ops.aten.randint.generator,
+    needs_env=True)
 @op_base.convert_dtype(use_default_dtype=False)
 def _aten_randint_like(
- …
+    input,
+    *args,
+    generator=None,
+    dtype=None,
+    env=None,
+    **kwargs,
 ):
   if len(args) == 2:
     low, high = args
@@ -3580,7 +3717,8 @@ def _aten_randint_like(
     high = args[0]
     low = 0
   else:
-    raise AssertionError( …
+    raise AssertionError(
+        f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}')
 
   shape = input.shape
   dtype = dtype or input.dtype
@@ -3590,6 +3728,7 @@ def _aten_randint_like(
   res = res.astype(dtype)
   return res
 
+
 @op(torch.ops.aten.dim, is_jax_function=False)
 def _aten_dim(self):
   return len(self.shape)
@@ -3602,10 +3741,11 @@ def _aten_copysign(input, other, *, out=None):
   # regardless of their exact integer dtype, whereas jax.copysign returns
   # float64 when one or both of them is int64.
   if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype(
- …
-  ):
+      other.dtype, jnp.integer):
     result = result.astype(jnp.float32)
   return result
+
+
 @op(torch.ops.aten.i0)
 @op_base.promote_int_input
 def _aten_i0(self):
@@ -3637,6 +3777,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
 
   @jnp.vectorize
   def vectorized(x, n_i):
+
     def negative_n(x):
       return jnp.zeros_like(x)
 
@@ -3650,6 +3791,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
       return jnp.ones_like(x)
 
     def default(x):
+
       def f(k, carry):
         p, q = carry
         return (q, ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1))
@@ -3658,9 +3800,9 @@ def _aten_special_laguerre_polynomial_l(self, n):
       return q
 
     return jnp.piecewise(
-        x, [n_i == 1, n_i == 0, …
- …
- …
+        x, [n_i == 1, n_i == 0,
+            jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0],
+        [one_n, zero_n, zero_abs, negative_n, default])
 
   return vectorized(self, n.astype(jnp.int64))
 
@@ -3760,125 +3902,124 @@ def _aten_special_modified_bessel_i0(self):
     return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x)
 
   self = jnp.abs(self)
-  return jnp.piecewise(
- …
-  )
+  return jnp.piecewise(self, [self <= 8], [small, default])
+
 
 @op(torch.ops.aten.special_modified_bessel_i1)
 @op_base.promote_int_input
 def _aten_special_modified_bessel_i1(self):
- …
-  def small(x):
-    A = jnp.array(
-        [
- … (same coefficient list as in the replacement below)
-        ],
-        dtype=self.dtype,
-    )
-
-    def f(carry, val):
-      p, q, a = carry
-      p, q = q, a
-      return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
-
-    (p, _, a), _ = jax.lax.scan(
-        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
-
-    return jax.lax.cond(
-        x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))
-    )
+  # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364
 
- …
-      return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
-
-    (p, _, b), _ = jax.lax.scan(
-        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
-
-    return jax.lax.cond(
-        x < 0, lambda: -(jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))
-    )
+  def small(x):
+    A = jnp.array(
+        [
+            2.77791411276104639959e-18,
+            -2.11142121435816608115e-17,
+            1.55363195773620046921e-16,
+            -1.10559694773538630805e-15,
+            7.60068429473540693410e-15,
+            -5.04218550472791168711e-14,
+            3.22379336594557470981e-13,
+            -1.98397439776494371520e-12,
+            1.17361862988909016308e-11,
+            -6.66348972350202774223e-11,
+            3.62559028155211703701e-10,
+            -1.88724975172282928790e-09,
+            9.38153738649577178388e-09,
+            -4.44505912879632808065e-08,
+            2.00329475355213526229e-07,
+            -8.56872026469545474066e-07,
+            3.47025130813767847674e-06,
+            -1.32731636560394358279e-05,
+            4.78156510755005422638e-05,
+            -1.61760815825896745588e-04,
+            5.12285956168575772895e-04,
+            -1.51357245063125314899e-03,
+            4.15642294431288815669e-03,
+            -1.05640848946261981558e-02,
+            2.47264490306265168283e-02,
+            -5.29459812080949914269e-02,
+            1.02643658689847095384e-01,
+            -1.76416518357834055153e-01,
+            2.52587186443633654823e-01,
+        ],
+        dtype=self.dtype,
+    )
 
+    def f(carry, val):
+      p, q, a = carry
+      p, q = q, a
+      return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
+
+    (p, _, a), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
+
+    return jax.lax.cond(
+        x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))),
+        lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x)))
+
+  def default(x):
+    B = jnp.array(
+        [
+            7.51729631084210481353e-18,
+            4.41434832307170791151e-18,
+            -4.65030536848935832153e-17,
+            -3.20952592199342395980e-17,
+            2.96262899764595013876e-16,
+            3.30820231092092828324e-16,
+            -1.88035477551078244854e-15,
+            -3.81440307243700780478e-15,
+            1.04202769841288027642e-14,
+            4.27244001671195135429e-14,
+            -2.10154184277266431302e-14,
+            -4.08355111109219731823e-13,
+            -7.19855177624590851209e-13,
+            2.03562854414708950722e-12,
+            1.41258074366137813316e-11,
+            3.25260358301548823856e-11,
+            -1.89749581235054123450e-11,
+            -5.58974346219658380687e-10,
+            -3.83538038596423702205e-09,
+            -2.63146884688951950684e-08,
+            -2.51223623787020892529e-07,
+            -3.88256480887769039346e-06,
+            -1.10588938762623716291e-04,
+            -9.76109749136146840777e-03,
+            7.78576235018280120474e-01,
+        ],
+        dtype=self.dtype,
     )
 
+    def f(carry, val):
+      p, q, b = carry
+      p, q = q, b
+      return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
+
+    (p, _, b), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
+
+    return jax.lax.cond(
+        x < 0, lambda: -(jnp.exp(jnp.abs(x)) *
+                         (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))),
+        lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)))
+
+  return jnp.piecewise(self, [self <= 8], [small, default])
+
+
 @op(torch.ops.aten.special_modified_bessel_k0)
 @op_base.promote_int_input
 def _aten_special_modified_bessel_k0(self):
- …
+  # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441
 
- …
+  def zero(x):
+    return jnp.array(jnp.inf, x.dtype)
 
- …
+  def negative(x):
+    return jnp.array(jnp.nan, x.dtype)
 
- …
+  def small(x):
+    A = jnp.array(
+        [
             1.37446543561352307156e-16,
             4.25981614279661018399e-14,
             1.03496952576338420167e-11,
@@ -3889,23 +4030,24 @@ def _aten_special_modified_bessel_k0(self):
|
|
|
3889
4030
|
3.59799365153615016266e-02,
|
|
3890
4031
|
3.44289899924628486886e-01,
|
|
3891
4032
|
-5.35327393233902768720e-01,
|
|
3892
|
-
|
|
3893
|
-
|
|
3894
|
-
|
|
4033
|
+
],
|
|
4034
|
+
dtype=self.dtype,
|
|
4035
|
+
)
|
|
3895
4036
|
|
|
3896
|
-
|
|
3897
|
-
|
|
3898
|
-
|
|
3899
|
-
|
|
4037
|
+
def f(carry, val):
|
|
4038
|
+
p, q, a = carry
|
|
4039
|
+
p, q = q, a
|
|
4040
|
+
return (p, q, (x * x - 2.0) * q - p + val), None
|
|
4041
|
+
|
|
4042
|
+
(p, _, a), _ = jax.lax.scan(
|
|
4043
|
+
f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
|
|
3900
4044
|
|
|
3901
|
-
|
|
3902
|
-
|
|
3903
|
-
|
|
3904
|
-
return 0.5 * (a - p) - jnp.log(0.5 * x) * _aten_special_modified_bessel_i0(x)
|
|
4045
|
+
return 0.5 * (a - p) - jnp.log(
|
|
4046
|
+
0.5 * x) * _aten_special_modified_bessel_i0(x)
|
|
3905
4047
|
|
|
3906
|
-
|
|
3907
|
-
|
|
3908
|
-
|
|
4048
|
+
def default(x):
|
|
4049
|
+
B = jnp.array(
|
|
4050
|
+
[
|
|
3909
4051
|
5.30043377268626276149e-18,
|
|
3910
4052
|
-1.64758043015242134646e-17,
|
|
3911
4053
|
5.21039150503902756861e-17,
|
|
@@ -3931,38 +4073,38 @@ def _aten_special_modified_bessel_k0(self):
|
|
|
3931
4073
|
1.56988388573005337491e-03,
|
|
3932
4074
|
-3.14481013119645005427e-02,
|
|
3933
4075
|
2.44030308206595545468e+00,
|
|
3934
|
-
|
|
3935
|
-
|
|
3936
|
-
|
|
4076
|
+
],
|
|
4077
|
+
dtype=self.dtype,
|
|
4078
|
+
)
|
|
4079
|
+
|
|
4080
|
+
def f(carry, val):
|
|
4081
|
+
p, q, b = carry
|
|
4082
|
+
p, q = q, b
|
|
4083
|
+
return (p, q, (8.0 / x - 2.0) * q - p + val), None
|
|
3937
4084
|
|
|
3938
|
-
|
|
3939
|
-
|
|
3940
|
-
p, q = q, b
|
|
3941
|
-
return (p, q, (8.0 / x - 2.0) * q - p + val), None
|
|
4085
|
+
(p, _, b), _ = jax.lax.scan(
|
|
4086
|
+
f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
|
|
3942
4087
|
|
|
3943
|
-
|
|
3944
|
-
|
|
3945
|
-
|
|
3946
|
-
|
|
4088
|
+
return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
|
|
4089
|
+
|
|
4090
|
+
return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
|
|
4091
|
+
[small, negative, zero, default])
|
|
3947
4092
|
|
|
3948
|
-
return jnp.piecewise(
|
|
3949
|
-
self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
|
|
3950
|
-
)
|
|
3951
4093
|
|
|
3952
4094
|
@op(torch.ops.aten.special_modified_bessel_k1)
|
|
3953
4095
|
@op_base.promote_int_input
|
|
3954
4096
|
def _aten_special_modified_bessel_k1(self):
|
|
3955
|
-
|
|
4097
|
+
# Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519
|
|
3956
4098
|
|
|
3957
|
-
|
|
3958
|
-
|
|
4099
|
+
def zero(x):
|
|
4100
|
+
return jnp.array(jnp.inf, x.dtype)
|
|
3959
4101
|
|
|
3960
|
-
|
|
3961
|
-
|
|
4102
|
+
def negative(x):
|
|
4103
|
+
return jnp.array(jnp.nan, x.dtype)
|
|
3962
4104
|
|
|
3963
|
-
|
|
3964
|
-
|
|
3965
|
-
|
|
4105
|
+
def small(x):
|
|
4106
|
+
A = jnp.array(
|
|
4107
|
+
[
|
|
3966
4108
|
-7.02386347938628759343e-18,
|
|
3967
4109
|
-2.42744985051936593393e-15,
|
|
3968
4110
|
-6.66690169419932900609e-13,
|
|
@@ -3974,24 +4116,25 @@ def _aten_special_modified_bessel_k1(self):
|
|
|
3974
4116
|
-1.22611180822657148235e-01,
|
|
3975
4117
|
-3.53155960776544875667e-01,
|
|
3976
4118
|
1.52530022733894777053e+00,
|
|
3977
|
-
|
|
3978
|
-
|
|
3979
|
-
|
|
4119
|
+
],
|
|
4120
|
+
dtype=self.dtype,
|
|
4121
|
+
)
|
|
3980
4122
|
|
|
3981
|
-
|
|
3982
|
-
|
|
3983
|
-
|
|
3984
|
-
|
|
3985
|
-
|
|
4123
|
+
def f(carry, val):
|
|
4124
|
+
p, q, a = carry
|
|
4125
|
+
p, q = q, a
|
|
4126
|
+
a = (x * x - 2.0) * q - p + val
|
|
4127
|
+
return (p, q, a), None
|
|
4128
|
+
|
|
4129
|
+
(p, _, a), _ = jax.lax.scan(
|
|
4130
|
+
f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
|
|
3986
4131
|
|
|
3987
|
-
|
|
3988
|
-
|
|
3989
|
-
|
|
3990
|
-
return jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
|
|
4132
|
+
return jnp.log(
|
|
4133
|
+
0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
|
|
3991
4134
|
|
|
3992
|
-
|
|
3993
|
-
|
|
3994
|
-
|
|
4135
|
+
def default(x):
|
|
4136
|
+
B = jnp.array(
|
|
4137
|
+
[
|
|
3995
4138
|
-5.75674448366501715755e-18,
|
|
3996
4139
|
1.79405087314755922667e-17,
|
|
3997
4140
|
-5.68946255844285935196e-17,
|
|
@@ -4017,24 +4160,24 @@ def _aten_special_modified_bessel_k1(self):
|
|
|
4017
4160
|
-2.85781685962277938680e-03,
|
|
4018
4161
|
1.03923736576817238437e-01,
|
|
4019
4162
|
2.72062619048444266945e+00,
|
|
4020
|
-
|
|
4021
|
-
|
|
4022
|
-
|
|
4163
|
+
],
|
|
4164
|
+
dtype=self.dtype,
|
|
4165
|
+
)
|
|
4166
|
+
|
|
4167
|
+
def f(carry, val):
|
|
4168
|
+
p, q, b = carry
|
|
4169
|
+
p, q = q, b
|
|
4170
|
+
b = (8.0 / x - 2.0) * q - p + val
|
|
4171
|
+
return (p, q, b), None
|
|
4172
|
+
|
|
4173
|
+
(p, _, b), _ = jax.lax.scan(
|
|
4174
|
+
f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
|
|
4023
4175
|
|
|
4024
|
-
|
|
4025
|
-
p, q, b = carry
|
|
4026
|
-
p, q = q, b
|
|
4027
|
-
b = (8.0 / x - 2.0) * q - p + val
|
|
4028
|
-
return (p, q, b), None
|
|
4176
|
+
return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
|
|
4029
4177
|
|
|
4030
|
-
|
|
4031
|
-
|
|
4032
|
-
|
|
4033
|
-
return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
|
|
4178
|
+
return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
|
|
4179
|
+
[small, negative, zero, default])
|
|
4034
4180
|
|
|
4035
|
-
return jnp.piecewise(
|
|
4036
|
-
self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
|
|
4037
|
-
)
|
|
4038
4181
|
|
|
4039
4182
|
@op(torch.ops.aten.polygamma)
|
|
4040
4183
|
def _aten_polygamma(x, n):
|
|
@@ -4042,10 +4185,12 @@ def _aten_polygamma(x, n):
|
|
|
4042
4185
|
n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
|
|
4043
4186
|
return jax.lax.polygamma(jnp.float32(x), n)
|
|
4044
4187
|
|
|
4188
|
+
|
|
4045
4189
|
@op(torch.ops.aten.special_ndtri)
|
|
4046
4190
|
@op_base.promote_int_input
|
|
4047
4191
|
def _aten_special_ndtri(self):
|
|
4048
|
-
|
|
4192
|
+
return jax.scipy.special.ndtri(self)
|
|
4193
|
+
|
|
4049
4194
|
|
|
4050
4195
|
@op(torch.ops.aten.special_bessel_j0)
|
|
4051
4196
|
@op_base.promote_int_input
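
The Bessel ports above all share one evaluation pattern: a Cephes-style Chebyshev series driven by a three-term recurrence under jax.lax.scan, with jnp.piecewise picking the small-argument or default branch per element. The following is a minimal standalone sketch of that pattern; the coefficients and thresholds here are arbitrary placeholders, not the tables used in jaten.py.

import jax
import jax.numpy as jnp

def chebyshev_series(x, coeffs, transform):
  # Three-term recurrence b_k = transform(x) * b_{k-1} - b_{k-2} + c_k,
  # scanned over the coefficient table; the series value is 0.5 * (b - p).
  def step(carry, c):
    p, q, b = carry
    p, q = q, b
    return (p, q, transform(x) * q - p + c), None

  init = (jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x))
  (p, _, b), _ = jax.lax.scan(step, init=init, xs=coeffs)
  return 0.5 * (b - p)

coeffs = jnp.array([0.25, -0.5, 1.0])   # illustrative only
x = jnp.array([0.5, 1.5, 3.0])
small = lambda v: chebyshev_series(v, coeffs, lambda t: t * t - 2.0)
default = lambda v: chebyshev_series(v, coeffs, lambda t: 8.0 / t - 2.0)
# jnp.piecewise applies the matching branch; the last entry is the default.
y = jnp.piecewise(x, [x <= 2.0], [small, default])
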
|
|
@@ -4057,112 +4202,104 @@ def _aten_special_bessel_j0(self):
|
|
|
4057
4202
|
|
|
4058
4203
|
def small(x):
|
|
4059
4204
|
RP = jnp.array(
|
|
4060
|
-
|
|
4061
|
-
|
|
4062
|
-
|
|
4063
|
-
|
|
4064
|
-
|
|
4065
|
-
|
|
4066
|
-
|
|
4205
|
+
[
|
|
4206
|
+
-4.79443220978201773821e09,
|
|
4207
|
+
1.95617491946556577543e12,
|
|
4208
|
+
-2.49248344360967716204e14,
|
|
4209
|
+
9.70862251047306323952e15,
|
|
4210
|
+
],
|
|
4211
|
+
dtype=self.dtype,
|
|
4067
4212
|
)
|
|
4068
4213
|
RQ = jnp.array(
|
|
4069
|
-
|
|
4070
|
-
|
|
4071
|
-
|
|
4072
|
-
|
|
4073
|
-
|
|
4074
|
-
|
|
4075
|
-
|
|
4076
|
-
|
|
4077
|
-
|
|
4078
|
-
|
|
4079
|
-
|
|
4214
|
+
[
|
|
4215
|
+
4.99563147152651017219e02,
|
|
4216
|
+
1.73785401676374683123e05,
|
|
4217
|
+
4.84409658339962045305e07,
|
|
4218
|
+
1.11855537045356834862e10,
|
|
4219
|
+
2.11277520115489217587e12,
|
|
4220
|
+
3.10518229857422583814e14,
|
|
4221
|
+
3.18121955943204943306e16,
|
|
4222
|
+
1.71086294081043136091e18,
|
|
4223
|
+
],
|
|
4224
|
+
dtype=self.dtype,
|
|
4080
4225
|
)
|
|
4081
4226
|
|
|
4082
4227
|
rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
|
|
4083
4228
|
rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
|
|
4084
4229
|
|
|
4085
|
-
return (
|
|
4086
|
-
|
|
4087
|
-
* (x * x - 3.04712623436620863991e01)
|
|
4088
|
-
* rp
|
|
4089
|
-
/ rq
|
|
4090
|
-
)
|
|
4230
|
+
return ((x * x - 5.78318596294678452118e00) *
|
|
4231
|
+
(x * x - 3.04712623436620863991e01) * rp / rq)
|
|
4091
4232
|
|
|
4092
4233
|
def default(x):
|
|
4093
4234
|
PP = jnp.array(
|
|
4094
|
-
|
|
4095
|
-
|
|
4096
|
-
|
|
4097
|
-
|
|
4098
|
-
|
|
4099
|
-
|
|
4100
|
-
|
|
4101
|
-
|
|
4102
|
-
|
|
4103
|
-
|
|
4235
|
+
[
|
|
4236
|
+
7.96936729297347051624e-04,
|
|
4237
|
+
8.28352392107440799803e-02,
|
|
4238
|
+
1.23953371646414299388e00,
|
|
4239
|
+
5.44725003058768775090e00,
|
|
4240
|
+
8.74716500199817011941e00,
|
|
4241
|
+
5.30324038235394892183e00,
|
|
4242
|
+
9.99999999999999997821e-01,
|
|
4243
|
+
],
|
|
4244
|
+
dtype=self.dtype,
|
|
4104
4245
|
)
|
|
4105
4246
|
PQ = jnp.array(
|
|
4106
|
-
|
|
4107
|
-
|
|
4108
|
-
|
|
4109
|
-
|
|
4110
|
-
|
|
4111
|
-
|
|
4112
|
-
|
|
4113
|
-
|
|
4114
|
-
|
|
4115
|
-
|
|
4247
|
+
[
|
|
4248
|
+
9.24408810558863637013e-04,
|
|
4249
|
+
8.56288474354474431428e-02,
|
|
4250
|
+
1.25352743901058953537e00,
|
|
4251
|
+
5.47097740330417105182e00,
|
|
4252
|
+
8.76190883237069594232e00,
|
|
4253
|
+
5.30605288235394617618e00,
|
|
4254
|
+
1.00000000000000000218e00,
|
|
4255
|
+
],
|
|
4256
|
+
dtype=self.dtype,
|
|
4116
4257
|
)
|
|
4117
4258
|
QP = jnp.array(
|
|
4118
|
-
|
|
4119
|
-
|
|
4120
|
-
|
|
4121
|
-
|
|
4122
|
-
|
|
4123
|
-
|
|
4124
|
-
|
|
4125
|
-
|
|
4126
|
-
|
|
4127
|
-
|
|
4128
|
-
|
|
4259
|
+
[
|
|
4260
|
+
-1.13663838898469149931e-02,
|
|
4261
|
+
-1.28252718670509318512e00,
|
|
4262
|
+
-1.95539544257735972385e01,
|
|
4263
|
+
-9.32060152123768231369e01,
|
|
4264
|
+
-1.77681167980488050595e02,
|
|
4265
|
+
-1.47077505154951170175e02,
|
|
4266
|
+
-5.14105326766599330220e01,
|
|
4267
|
+
-6.05014350600728481186e00,
|
|
4268
|
+
],
|
|
4269
|
+
dtype=self.dtype,
|
|
4129
4270
|
)
|
|
4130
4271
|
QQ = jnp.array(
|
|
4131
|
-
|
|
4132
|
-
|
|
4133
|
-
|
|
4134
|
-
|
|
4135
|
-
|
|
4136
|
-
|
|
4137
|
-
|
|
4138
|
-
|
|
4139
|
-
|
|
4140
|
-
|
|
4272
|
+
[
|
|
4273
|
+
6.43178256118178023184e01,
|
|
4274
|
+
8.56430025976980587198e02,
|
|
4275
|
+
3.88240183605401609683e03,
|
|
4276
|
+
7.24046774195652478189e03,
|
|
4277
|
+
5.93072701187316984827e03,
|
|
4278
|
+
2.06209331660327847417e03,
|
|
4279
|
+
2.42005740240291393179e02,
|
|
4280
|
+
],
|
|
4281
|
+
dtype=self.dtype,
|
|
4141
4282
|
)
|
|
4142
4283
|
|
|
4143
|
-
pp = op_base.foreach_loop(
|
|
4144
|
-
|
|
4145
|
-
|
|
4146
|
-
|
|
4147
|
-
|
|
4148
|
-
|
|
4149
|
-
|
|
4150
|
-
|
|
4151
|
-
|
|
4152
|
-
|
|
4153
|
-
|
|
4154
|
-
|
|
4155
|
-
|
|
4156
|
-
* 0.797884560802865355879892119868763737
|
|
4157
|
-
/ jnp.sqrt(x)
|
|
4158
|
-
)
|
|
4284
|
+
pp = op_base.foreach_loop(
|
|
4285
|
+
PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
|
|
4286
|
+
pq = op_base.foreach_loop(
|
|
4287
|
+
PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
|
|
4288
|
+
qp = op_base.foreach_loop(
|
|
4289
|
+
QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
|
|
4290
|
+
qq = op_base.foreach_loop(
|
|
4291
|
+
QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
|
|
4292
|
+
|
|
4293
|
+
return ((pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) -
|
|
4294
|
+
5.0 / x *
|
|
4295
|
+
(qp / qq) * jnp.sin(x - 0.785398163397448309615660845819875721)) *
|
|
4296
|
+
0.797884560802865355879892119868763737 / jnp.sqrt(x))
|
|
4159
4297
|
|
|
4160
4298
|
self = jnp.abs(self)
|
|
4161
4299
|
# Last True condition in `piecewise` takes priority, but last function is
|
|
4162
4300
|
# default. See https://github.com/numpy/numpy/issues/16475
|
|
4163
|
-
return jnp.piecewise(
|
|
4164
|
-
|
|
4165
|
-
)
|
|
4301
|
+
return jnp.piecewise(self, [self <= 5.0, self < 0.00001],
|
|
4302
|
+
[small, very_small, default])
|
|
4166
4303
|
|
|
4167
4304
|
|
|
4168
4305
|
@op(torch.ops.aten.special_bessel_j1)
|
|
@@ -4172,114 +4309,106 @@ def _aten_special_bessel_j1(self):
|
|
|
4172
4309
|
|
|
4173
4310
|
def small(x):
|
|
4174
4311
|
RP = jnp.array(
|
|
4175
|
-
|
|
4176
|
-
|
|
4177
|
-
|
|
4178
|
-
|
|
4179
|
-
|
|
4180
|
-
|
|
4181
|
-
|
|
4312
|
+
[
|
|
4313
|
+
-8.99971225705559398224e08,
|
|
4314
|
+
4.52228297998194034323e11,
|
|
4315
|
+
-7.27494245221818276015e13,
|
|
4316
|
+
3.68295732863852883286e15,
|
|
4317
|
+
],
|
|
4318
|
+
dtype=self.dtype,
|
|
4182
4319
|
)
|
|
4183
4320
|
RQ = jnp.array(
|
|
4184
|
-
|
|
4185
|
-
|
|
4186
|
-
|
|
4187
|
-
|
|
4188
|
-
|
|
4189
|
-
|
|
4190
|
-
|
|
4191
|
-
|
|
4192
|
-
|
|
4193
|
-
|
|
4194
|
-
|
|
4321
|
+
[
|
|
4322
|
+
6.20836478118054335476e02,
|
|
4323
|
+
2.56987256757748830383e05,
|
|
4324
|
+
8.35146791431949253037e07,
|
|
4325
|
+
2.21511595479792499675e10,
|
|
4326
|
+
4.74914122079991414898e12,
|
|
4327
|
+
7.84369607876235854894e14,
|
|
4328
|
+
8.95222336184627338078e16,
|
|
4329
|
+
5.32278620332680085395e18,
|
|
4330
|
+
],
|
|
4331
|
+
dtype=self.dtype,
|
|
4195
4332
|
)
|
|
4196
4333
|
|
|
4197
4334
|
rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
|
|
4198
4335
|
rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
|
|
4199
4336
|
|
|
4200
|
-
return (
|
|
4201
|
-
|
|
4202
|
-
/ rq
|
|
4203
|
-
* x
|
|
4204
|
-
* (x * x - 1.46819706421238932572e01)
|
|
4205
|
-
* (x * x - 4.92184563216946036703e01)
|
|
4206
|
-
)
|
|
4337
|
+
return (rp / rq * x * (x * x - 1.46819706421238932572e01) *
|
|
4338
|
+
(x * x - 4.92184563216946036703e01))
|
|
4207
4339
|
|
|
4208
4340
|
def default(x):
|
|
4209
4341
|
PP = jnp.array(
|
|
4210
|
-
|
|
4211
|
-
|
|
4212
|
-
|
|
4213
|
-
|
|
4214
|
-
|
|
4215
|
-
|
|
4216
|
-
|
|
4217
|
-
|
|
4218
|
-
|
|
4219
|
-
|
|
4342
|
+
[
|
|
4343
|
+
7.62125616208173112003e-04,
|
|
4344
|
+
7.31397056940917570436e-02,
|
|
4345
|
+
1.12719608129684925192e00,
|
|
4346
|
+
5.11207951146807644818e00,
|
|
4347
|
+
8.42404590141772420927e00,
|
|
4348
|
+
5.21451598682361504063e00,
|
|
4349
|
+
1.00000000000000000254e00,
|
|
4350
|
+
],
|
|
4351
|
+
dtype=self.dtype,
|
|
4220
4352
|
)
|
|
4221
4353
|
PQ = jnp.array(
|
|
4222
|
-
|
|
4223
|
-
|
|
4224
|
-
|
|
4225
|
-
|
|
4226
|
-
|
|
4227
|
-
|
|
4228
|
-
|
|
4229
|
-
|
|
4230
|
-
|
|
4231
|
-
|
|
4354
|
+
[
|
|
4355
|
+
5.71323128072548699714e-04,
|
|
4356
|
+
6.88455908754495404082e-02,
|
|
4357
|
+
1.10514232634061696926e00,
|
|
4358
|
+
5.07386386128601488557e00,
|
|
4359
|
+
8.39985554327604159757e00,
|
|
4360
|
+
5.20982848682361821619e00,
|
|
4361
|
+
9.99999999999999997461e-01,
|
|
4362
|
+
],
|
|
4363
|
+
dtype=self.dtype,
|
|
4232
4364
|
)
|
|
4233
4365
|
QP = jnp.array(
|
|
4234
|
-
|
|
4235
|
-
|
|
4236
|
-
|
|
4237
|
-
|
|
4238
|
-
|
|
4239
|
-
|
|
4240
|
-
|
|
4241
|
-
|
|
4242
|
-
|
|
4243
|
-
|
|
4244
|
-
|
|
4366
|
+
[
|
|
4367
|
+
5.10862594750176621635e-02,
|
|
4368
|
+
4.98213872951233449420e00,
|
|
4369
|
+
7.58238284132545283818e01,
|
|
4370
|
+
3.66779609360150777800e02,
|
|
4371
|
+
7.10856304998926107277e02,
|
|
4372
|
+
5.97489612400613639965e02,
|
|
4373
|
+
2.11688757100572135698e02,
|
|
4374
|
+
2.52070205858023719784e01,
|
|
4375
|
+
],
|
|
4376
|
+
dtype=self.dtype,
|
|
4245
4377
|
)
|
|
4246
4378
|
QQ = jnp.array(
|
|
4247
|
-
|
|
4248
|
-
|
|
4249
|
-
|
|
4250
|
-
|
|
4251
|
-
|
|
4252
|
-
|
|
4253
|
-
|
|
4254
|
-
|
|
4255
|
-
|
|
4256
|
-
|
|
4379
|
+
[
|
|
4380
|
+
7.42373277035675149943e01,
|
|
4381
|
+
1.05644886038262816351e03,
|
|
4382
|
+
4.98641058337653607651e03,
|
|
4383
|
+
9.56231892404756170795e03,
|
|
4384
|
+
7.99704160447350683650e03,
|
|
4385
|
+
2.82619278517639096600e03,
|
|
4386
|
+
3.36093607810698293419e02,
|
|
4387
|
+
],
|
|
4388
|
+
dtype=self.dtype,
|
|
4257
4389
|
)
|
|
4258
4390
|
|
|
4259
|
-
pp = op_base.foreach_loop(
|
|
4260
|
-
|
|
4261
|
-
|
|
4262
|
-
|
|
4263
|
-
|
|
4264
|
-
|
|
4265
|
-
|
|
4266
|
-
|
|
4267
|
-
|
|
4268
|
-
|
|
4269
|
-
|
|
4270
|
-
|
|
4271
|
-
|
|
4272
|
-
* 0.797884560802865355879892119868763737
|
|
4273
|
-
/ jnp.sqrt(x)
|
|
4274
|
-
)
|
|
4391
|
+
pp = op_base.foreach_loop(
|
|
4392
|
+
PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
|
|
4393
|
+
pq = op_base.foreach_loop(
|
|
4394
|
+
PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
|
|
4395
|
+
qp = op_base.foreach_loop(
|
|
4396
|
+
QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
|
|
4397
|
+
qq = op_base.foreach_loop(
|
|
4398
|
+
QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
|
|
4399
|
+
|
|
4400
|
+
return ((pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) -
|
|
4401
|
+
5.0 / x *
|
|
4402
|
+
(qp / qq) * jnp.sin(x - 2.356194490192344928846982537459627163)) *
|
|
4403
|
+
0.797884560802865355879892119868763737 / jnp.sqrt(x))
|
|
4275
4404
|
|
|
4276
4405
|
# If x < 0, bessel_j1(x) = -bessel_j1(-x)
|
|
4277
4406
|
sign = jnp.sign(self)
|
|
4278
4407
|
self = jnp.abs(self)
|
|
4279
4408
|
return sign * jnp.piecewise(
|
|
4280
|
-
|
|
4281
|
-
|
|
4282
|
-
|
|
4409
|
+
self,
|
|
4410
|
+
[self <= 5.0],
|
|
4411
|
+
[small, default],
|
|
4283
4412
|
)
|
|
4284
4413
|
|
|
4285
4414
|
|
|
@@ -4296,85 +4425,86 @@ def _aten_special_bessel_y0(self):
|
|
|
4296
4425
|
|
|
4297
4426
|
def small(x):
|
|
4298
4427
|
YP = jnp.array(
|
|
4299
|
-
|
|
4300
|
-
|
|
4301
|
-
|
|
4302
|
-
|
|
4303
|
-
|
|
4304
|
-
|
|
4305
|
-
|
|
4306
|
-
|
|
4307
|
-
|
|
4308
|
-
|
|
4309
|
-
|
|
4428
|
+
[
|
|
4429
|
+
1.55924367855235737965e04,
|
|
4430
|
+
-1.46639295903971606143e07,
|
|
4431
|
+
5.43526477051876500413e09,
|
|
4432
|
+
-9.82136065717911466409e11,
|
|
4433
|
+
8.75906394395366999549e13,
|
|
4434
|
+
-3.46628303384729719441e15,
|
|
4435
|
+
4.42733268572569800351e16,
|
|
4436
|
+
-1.84950800436986690637e16,
|
|
4437
|
+
],
|
|
4438
|
+
dtype=self.dtype,
|
|
4310
4439
|
)
|
|
4311
4440
|
YQ = jnp.array(
|
|
4312
|
-
|
|
4313
|
-
|
|
4314
|
-
|
|
4315
|
-
|
|
4316
|
-
|
|
4317
|
-
|
|
4318
|
-
|
|
4319
|
-
|
|
4320
|
-
|
|
4321
|
-
|
|
4441
|
+
[
|
|
4442
|
+
1.04128353664259848412e03,
|
|
4443
|
+
6.26107330137134956842e05,
|
|
4444
|
+
2.68919633393814121987e08,
|
|
4445
|
+
8.64002487103935000337e10,
|
|
4446
|
+
2.02979612750105546709e13,
|
|
4447
|
+
3.17157752842975028269e15,
|
|
4448
|
+
2.50596256172653059228e17,
|
|
4449
|
+
],
|
|
4450
|
+
dtype=self.dtype,
|
|
4322
4451
|
)
|
|
4323
4452
|
|
|
4324
4453
|
yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
|
|
4325
4454
|
yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
|
|
4326
4455
|
|
|
4327
|
-
return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) *
|
|
4456
|
+
return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) *
|
|
4457
|
+
_aten_special_bessel_j0(x))
|
|
4328
4458
|
|
|
4329
4459
|
def default(x):
|
|
4330
4460
|
PP = jnp.array(
|
|
4331
|
-
|
|
4332
|
-
|
|
4333
|
-
|
|
4334
|
-
|
|
4335
|
-
|
|
4336
|
-
|
|
4337
|
-
|
|
4338
|
-
|
|
4339
|
-
|
|
4340
|
-
|
|
4461
|
+
[
|
|
4462
|
+
7.96936729297347051624e-04,
|
|
4463
|
+
8.28352392107440799803e-02,
|
|
4464
|
+
1.23953371646414299388e00,
|
|
4465
|
+
5.44725003058768775090e00,
|
|
4466
|
+
8.74716500199817011941e00,
|
|
4467
|
+
5.30324038235394892183e00,
|
|
4468
|
+
9.99999999999999997821e-01,
|
|
4469
|
+
],
|
|
4470
|
+
dtype=self.dtype,
|
|
4341
4471
|
)
|
|
4342
4472
|
PQ = jnp.array(
|
|
4343
|
-
|
|
4344
|
-
|
|
4345
|
-
|
|
4346
|
-
|
|
4347
|
-
|
|
4348
|
-
|
|
4349
|
-
|
|
4350
|
-
|
|
4351
|
-
|
|
4352
|
-
|
|
4473
|
+
[
|
|
4474
|
+
9.24408810558863637013e-04,
|
|
4475
|
+
8.56288474354474431428e-02,
|
|
4476
|
+
1.25352743901058953537e00,
|
|
4477
|
+
5.47097740330417105182e00,
|
|
4478
|
+
8.76190883237069594232e00,
|
|
4479
|
+
5.30605288235394617618e00,
|
|
4480
|
+
1.00000000000000000218e00,
|
|
4481
|
+
],
|
|
4482
|
+
dtype=self.dtype,
|
|
4353
4483
|
)
|
|
4354
4484
|
QP = jnp.array(
|
|
4355
|
-
|
|
4356
|
-
|
|
4357
|
-
|
|
4358
|
-
|
|
4359
|
-
|
|
4360
|
-
|
|
4361
|
-
|
|
4362
|
-
|
|
4363
|
-
|
|
4364
|
-
|
|
4365
|
-
|
|
4485
|
+
[
|
|
4486
|
+
-1.13663838898469149931e-02,
|
|
4487
|
+
-1.28252718670509318512e00,
|
|
4488
|
+
-1.95539544257735972385e01,
|
|
4489
|
+
-9.32060152123768231369e01,
|
|
4490
|
+
-1.77681167980488050595e02,
|
|
4491
|
+
-1.47077505154951170175e02,
|
|
4492
|
+
-5.14105326766599330220e01,
|
|
4493
|
+
-6.05014350600728481186e00,
|
|
4494
|
+
],
|
|
4495
|
+
dtype=self.dtype,
|
|
4366
4496
|
)
|
|
4367
4497
|
QQ = jnp.array(
|
|
4368
|
-
|
|
4369
|
-
|
|
4370
|
-
|
|
4371
|
-
|
|
4372
|
-
|
|
4373
|
-
|
|
4374
|
-
|
|
4375
|
-
|
|
4376
|
-
|
|
4377
|
-
|
|
4498
|
+
[
|
|
4499
|
+
6.43178256118178023184e01,
|
|
4500
|
+
8.56430025976980587198e02,
|
|
4501
|
+
3.88240183605401609683e03,
|
|
4502
|
+
7.24046774195652478189e03,
|
|
4503
|
+
5.93072701187316984827e03,
|
|
4504
|
+
2.06209331660327847417e03,
|
|
4505
|
+
2.42005740240291393179e02,
|
|
4506
|
+
],
|
|
4507
|
+
dtype=self.dtype,
|
|
4378
4508
|
)
|
|
4379
4509
|
|
|
4380
4510
|
factor = 25.0 / (x * x)
|
|
@@ -4383,22 +4513,15 @@ def _aten_special_bessel_y0(self):
|
|
|
4383
4513
|
qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
|
|
4384
4514
|
qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
|
|
4385
4515
|
|
|
4386
|
-
return (
|
|
4387
|
-
|
|
4388
|
-
|
|
4389
|
-
|
|
4390
|
-
/ x
|
|
4391
|
-
* (qp / qq)
|
|
4392
|
-
* jnp.cos(x - 0.785398163397448309615660845819875721)
|
|
4393
|
-
)
|
|
4394
|
-
* 0.797884560802865355879892119868763737
|
|
4395
|
-
/ jnp.sqrt(x)
|
|
4396
|
-
)
|
|
4516
|
+
return ((pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) +
|
|
4517
|
+
5.0 / x *
|
|
4518
|
+
(qp / qq) * jnp.cos(x - 0.785398163397448309615660845819875721)) *
|
|
4519
|
+
0.797884560802865355879892119868763737 / jnp.sqrt(x))
|
|
4397
4520
|
|
|
4398
4521
|
return jnp.piecewise(
|
|
4399
|
-
|
|
4400
|
-
|
|
4401
|
-
|
|
4522
|
+
self,
|
|
4523
|
+
[self <= 5.0, self < 0., self == 0.],
|
|
4524
|
+
[small, negative, zero, default],
|
|
4402
4525
|
)
|
|
4403
4526
|
|
|
4404
4527
|
|
|
@@ -4415,90 +4538,86 @@ def _aten_special_bessel_y1(self):
|
|
|
4415
4538
|
|
|
4416
4539
|
def small(x):
|
|
4417
4540
|
YP = jnp.array(
|
|
4418
|
-
|
|
4419
|
-
|
|
4420
|
-
|
|
4421
|
-
|
|
4422
|
-
|
|
4423
|
-
|
|
4424
|
-
|
|
4425
|
-
|
|
4426
|
-
|
|
4541
|
+
[
|
|
4542
|
+
1.26320474790178026440e09,
|
|
4543
|
+
-6.47355876379160291031e11,
|
|
4544
|
+
1.14509511541823727583e14,
|
|
4545
|
+
-8.12770255501325109621e15,
|
|
4546
|
+
2.02439475713594898196e17,
|
|
4547
|
+
-7.78877196265950026825e17,
|
|
4548
|
+
],
|
|
4549
|
+
dtype=self.dtype,
|
|
4427
4550
|
)
|
|
4428
4551
|
YQ = jnp.array(
|
|
4429
|
-
|
|
4430
|
-
|
|
4431
|
-
|
|
4432
|
-
|
|
4433
|
-
|
|
4434
|
-
|
|
4435
|
-
|
|
4436
|
-
|
|
4437
|
-
|
|
4438
|
-
|
|
4439
|
-
|
|
4552
|
+
[
|
|
4553
|
+
5.94301592346128195359e02,
|
|
4554
|
+
2.35564092943068577943e05,
|
|
4555
|
+
7.34811944459721705660e07,
|
|
4556
|
+
1.87601316108706159478e10,
|
|
4557
|
+
3.88231277496238566008e12,
|
|
4558
|
+
6.20557727146953693363e14,
|
|
4559
|
+
6.87141087355300489866e16,
|
|
4560
|
+
3.97270608116560655612e18,
|
|
4561
|
+
],
|
|
4562
|
+
dtype=self.dtype,
|
|
4440
4563
|
)
|
|
4441
4564
|
|
|
4442
4565
|
yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
|
|
4443
4566
|
yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
|
|
4444
4567
|
|
|
4445
|
-
return (
|
|
4446
|
-
|
|
4447
|
-
|
|
4448
|
-
0.636619772367581343075535053490057448
|
|
4449
|
-
* (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)
|
|
4450
|
-
)
|
|
4451
|
-
)
|
|
4568
|
+
return (x * (yp / yq) +
|
|
4569
|
+
(0.636619772367581343075535053490057448 *
|
|
4570
|
+
(_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)))
|
|
4452
4571
|
|
|
4453
4572
|
def default(x):
|
|
4454
4573
|
PP = jnp.array(
|
|
4455
|
-
|
|
4456
|
-
|
|
4457
|
-
|
|
4458
|
-
|
|
4459
|
-
|
|
4460
|
-
|
|
4461
|
-
|
|
4462
|
-
|
|
4463
|
-
|
|
4464
|
-
|
|
4574
|
+
[
|
|
4575
|
+
7.62125616208173112003e-04,
|
|
4576
|
+
7.31397056940917570436e-02,
|
|
4577
|
+
1.12719608129684925192e00,
|
|
4578
|
+
5.11207951146807644818e00,
|
|
4579
|
+
8.42404590141772420927e00,
|
|
4580
|
+
5.21451598682361504063e00,
|
|
4581
|
+
1.00000000000000000254e00,
|
|
4582
|
+
],
|
|
4583
|
+
dtype=self.dtype,
|
|
4465
4584
|
)
|
|
4466
4585
|
PQ = jnp.array(
|
|
4467
|
-
|
|
4468
|
-
|
|
4469
|
-
|
|
4470
|
-
|
|
4471
|
-
|
|
4472
|
-
|
|
4473
|
-
|
|
4474
|
-
|
|
4475
|
-
|
|
4476
|
-
|
|
4586
|
+
[
|
|
4587
|
+
5.71323128072548699714e-04,
|
|
4588
|
+
6.88455908754495404082e-02,
|
|
4589
|
+
1.10514232634061696926e00,
|
|
4590
|
+
5.07386386128601488557e00,
|
|
4591
|
+
8.39985554327604159757e00,
|
|
4592
|
+
5.20982848682361821619e00,
|
|
4593
|
+
9.99999999999999997461e-01,
|
|
4594
|
+
],
|
|
4595
|
+
dtype=self.dtype,
|
|
4477
4596
|
)
|
|
4478
4597
|
QP = jnp.array(
|
|
4479
|
-
|
|
4480
|
-
|
|
4481
|
-
|
|
4482
|
-
|
|
4483
|
-
|
|
4484
|
-
|
|
4485
|
-
|
|
4486
|
-
|
|
4487
|
-
|
|
4488
|
-
|
|
4489
|
-
|
|
4598
|
+
[
|
|
4599
|
+
5.10862594750176621635e-02,
|
|
4600
|
+
4.98213872951233449420e00,
|
|
4601
|
+
7.58238284132545283818e01,
|
|
4602
|
+
3.66779609360150777800e02,
|
|
4603
|
+
7.10856304998926107277e02,
|
|
4604
|
+
5.97489612400613639965e02,
|
|
4605
|
+
2.11688757100572135698e02,
|
|
4606
|
+
2.52070205858023719784e01,
|
|
4607
|
+
],
|
|
4608
|
+
dtype=self.dtype,
|
|
4490
4609
|
)
|
|
4491
4610
|
QQ = jnp.array(
|
|
4492
|
-
|
|
4493
|
-
|
|
4494
|
-
|
|
4495
|
-
|
|
4496
|
-
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
|
|
4501
|
-
|
|
4611
|
+
[
|
|
4612
|
+
7.42373277035675149943e01,
|
|
4613
|
+
1.05644886038262816351e03,
|
|
4614
|
+
4.98641058337653607651e03,
|
|
4615
|
+
9.56231892404756170795e03,
|
|
4616
|
+
7.99704160447350683650e03,
|
|
4617
|
+
2.82619278517639096600e03,
|
|
4618
|
+
3.36093607810698293419e02,
|
|
4619
|
+
],
|
|
4620
|
+
dtype=self.dtype,
|
|
4502
4621
|
)
|
|
4503
4622
|
|
|
4504
4623
|
factor = 25.0 / (x * x)
|
|
@@ -4507,22 +4626,15 @@ def _aten_special_bessel_y1(self):
|
|
|
4507
4626
|
qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
|
|
4508
4627
|
qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
|
|
4509
4628
|
|
|
4510
|
-
return (
|
|
4511
|
-
|
|
4512
|
-
|
|
4513
|
-
|
|
4514
|
-
/ x
|
|
4515
|
-
* (qp / qq)
|
|
4516
|
-
* jnp.cos(x - 2.356194490192344928846982537459627163)
|
|
4517
|
-
)
|
|
4518
|
-
* 0.797884560802865355879892119868763737
|
|
4519
|
-
/ jnp.sqrt(x)
|
|
4520
|
-
)
|
|
4629
|
+
return ((pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) +
|
|
4630
|
+
5.0 / x *
|
|
4631
|
+
(qp / qq) * jnp.cos(x - 2.356194490192344928846982537459627163)) *
|
|
4632
|
+
0.797884560802865355879892119868763737 / jnp.sqrt(x))
|
|
4521
4633
|
|
|
4522
4634
|
return jnp.piecewise(
|
|
4523
|
-
|
|
4524
|
-
|
|
4525
|
-
|
|
4635
|
+
self,
|
|
4636
|
+
[self <= 5.0, self < 0., self == 0.],
|
|
4637
|
+
[small, negative, zero, default],
|
|
4526
4638
|
)
|
|
4527
4639
|
|
|
4528
4640
|
|
|
@@ -4533,11 +4645,13 @@ def _aten_special_chebyshev_polynomial_t(self, n):
|
|
|
4533
4645
|
|
|
4534
4646
|
@jnp.vectorize
|
|
4535
4647
|
def vectorized(x, n_i):
|
|
4648
|
+
|
|
4536
4649
|
def negative_n(x):
|
|
4537
4650
|
return jnp.zeros_like(x)
|
|
4538
4651
|
|
|
4539
4652
|
def one_x(x):
|
|
4540
|
-
return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x),
|
|
4653
|
+
return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x),
|
|
4654
|
+
-jnp.ones_like(x))
|
|
4541
4655
|
|
|
4542
4656
|
def large_n_small_x(x):
|
|
4543
4657
|
return jnp.cos(n_i * jnp.acos(x))
|
|
@@ -4549,24 +4663,18 @@ def _aten_special_chebyshev_polynomial_t(self, n):
|
|
|
4549
4663
|
return x
|
|
4550
4664
|
|
|
4551
4665
|
def default(x):
|
|
4666
|
+
|
|
4552
4667
|
def f(_, carry):
|
|
4553
4668
|
p, q = carry
|
|
4554
4669
|
return (q, 2 * x * q - p)
|
|
4555
4670
|
|
|
4556
|
-
_, r
|
|
4671
|
+
_, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x))
|
|
4557
4672
|
return r
|
|
4558
4673
|
|
|
4559
|
-
return jnp.piecewise(
|
|
4560
|
-
|
|
4561
|
-
|
|
4562
|
-
|
|
4563
|
-
n_i == 0,
|
|
4564
|
-
(n_i == 6) & (jnp.abs(x) < 1),
|
|
4565
|
-
jnp.abs(x) == 1.,
|
|
4566
|
-
n_i < 0
|
|
4567
|
-
],
|
|
4568
|
-
[one_n, zero_n, large_n_small_x, one_x, negative_n, default]
|
|
4569
|
-
)
|
|
4674
|
+
return jnp.piecewise(x, [
|
|
4675
|
+
n_i == 1, n_i == 0, (n_i == 6) & (jnp.abs(x) < 1),
|
|
4676
|
+
jnp.abs(x) == 1., n_i < 0
|
|
4677
|
+
], [one_n, zero_n, large_n_small_x, one_x, negative_n, default])
|
|
4570
4678
|
|
|
4571
4679
|
# Explcicitly vectorize since we must vectorizes over both self and n
|
|
4572
4680
|
return vectorized(self, n.astype(jnp.int64))
|
|
@@ -4579,6 +4687,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
|
|
|
4579
4687
|
|
|
4580
4688
|
@jnp.vectorize
|
|
4581
4689
|
def vectorized(x, n_i):
|
|
4690
|
+
|
|
4582
4691
|
def negative_n(x):
|
|
4583
4692
|
return jnp.zeros_like(x)
|
|
4584
4693
|
|
|
@@ -4588,9 +4697,9 @@ def _aten_special_chebyshev_polynomial_u(self, n):
|
|
|
4588
4697
|
def large_n_small_x(x):
|
|
4589
4698
|
sin_acos_x = jnp.sin(jnp.acos(x))
|
|
4590
4699
|
return jnp.where(
|
|
4591
|
-
|
|
4592
|
-
|
|
4593
|
-
|
|
4700
|
+
sin_acos_x != 0,
|
|
4701
|
+
jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x,
|
|
4702
|
+
(n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x,
|
|
4594
4703
|
)
|
|
4595
4704
|
|
|
4596
4705
|
def zero_n(x):
|
|
@@ -4600,6 +4709,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
|
|
|
4600
4709
|
return 2 * x
|
|
4601
4710
|
|
|
4602
4711
|
def default(x):
|
|
4712
|
+
|
|
4603
4713
|
def f(_, carry):
|
|
4604
4714
|
p, q = carry
|
|
4605
4715
|
return (q, 2 * x * q - p)
|
|
@@ -4608,15 +4718,15 @@ def _aten_special_chebyshev_polynomial_u(self, n):
|
|
|
4608
4718
|
return r
|
|
4609
4719
|
|
|
4610
4720
|
return jnp.piecewise(
|
|
4611
|
-
|
|
4612
|
-
|
|
4613
|
-
|
|
4614
|
-
|
|
4615
|
-
|
|
4616
|
-
|
|
4617
|
-
|
|
4618
|
-
|
|
4619
|
-
|
|
4721
|
+
x,
|
|
4722
|
+
[
|
|
4723
|
+
n_i == 1,
|
|
4724
|
+
n_i == 0,
|
|
4725
|
+
(n_i > 8) & (jnp.abs(x) < 1),
|
|
4726
|
+
jnp.abs(x) == 1.0,
|
|
4727
|
+
n_i < 0,
|
|
4728
|
+
],
|
|
4729
|
+
[one_n, zero_n, large_n_small_x, one_x, negative_n, default],
|
|
4620
4730
|
)
|
|
4621
4731
|
|
|
4622
4732
|
return vectorized(self, n.astype(jnp.int64))
|
|
@@ -4627,6 +4737,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
|
|
|
4627
4737
|
def _aten_special_erfcx(x):
|
|
4628
4738
|
return jnp.exp(x * x) * jax.lax.erfc(x)
|
|
4629
4739
|
|
|
4740
|
+
|
|
4630
4741
|
@op(torch.ops.aten.erfc)
|
|
4631
4742
|
@op_base.promote_int_input
|
|
4632
4743
|
def _aten_erfcx(x):
|
|
@@ -4640,6 +4751,7 @@ def _aten_special_hermite_polynomial_h(self, n):
|
|
|
4640
4751
|
|
|
4641
4752
|
@jnp.vectorize
|
|
4642
4753
|
def vectorized(x, n_i):
|
|
4754
|
+
|
|
4643
4755
|
def negative_n(x):
|
|
4644
4756
|
return jnp.zeros_like(x)
|
|
4645
4757
|
|
|
@@ -4650,6 +4762,7 @@ def _aten_special_hermite_polynomial_h(self, n):
|
|
|
4650
4762
|
return 2 * x
|
|
4651
4763
|
|
|
4652
4764
|
def default(x):
|
|
4765
|
+
|
|
4653
4766
|
def f(k, carry):
|
|
4654
4767
|
p, q = carry
|
|
4655
4768
|
return (q, 2 * x * q - 2 * k * p)
|
|
@@ -4657,9 +4770,8 @@ def _aten_special_hermite_polynomial_h(self, n):
|
|
|
4657
4770
|
_, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x))
|
|
4658
4771
|
return r
|
|
4659
4772
|
|
|
4660
|
-
return jnp.piecewise(
|
|
4661
|
-
|
|
4662
|
-
)
|
|
4773
|
+
return jnp.piecewise(x, [n_i == 1, n_i == 0, n_i < 0],
|
|
4774
|
+
[one_n, zero_n, negative_n, default])
|
|
4663
4775
|
|
|
4664
4776
|
return vectorized(self, n.astype(jnp.int64))
|
|
4665
4777
|
|
|
@@ -4671,6 +4783,7 @@ def _aten_special_hermite_polynomial_he(self, n):
|
|
|
4671
4783
|
|
|
4672
4784
|
@jnp.vectorize
|
|
4673
4785
|
def vectorized(x, n_i):
|
|
4786
|
+
|
|
4674
4787
|
def negative_n(x):
|
|
4675
4788
|
return jnp.zeros_like(x)
|
|
4676
4789
|
|
|
@@ -4681,6 +4794,7 @@ def _aten_special_hermite_polynomial_he(self, n):
|
|
|
4681
4794
|
return x
|
|
4682
4795
|
|
|
4683
4796
|
def default(x):
|
|
4797
|
+
|
|
4684
4798
|
def f(k, carry):
|
|
4685
4799
|
p, q = carry
|
|
4686
4800
|
return (q, x * q - k * p)
|
|
@@ -4688,24 +4802,34 @@ def _aten_special_hermite_polynomial_he(self, n):
     _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x))
     return r
 
-    return jnp.piecewise(
-    )
+    return jnp.piecewise(x, [n_i == 1.0, n_i == 0.0, n_i < 0],
+                         [one_n, zero_n, negative_n, default])
 
   return vectorized(self, n.astype(jnp.int64))
 
 
 @op(torch.ops.aten.multinomial, needs_env=True)
-def _aten_multinomial(input,
+def _aten_multinomial(input,
+                      num_samples,
+                      replacement=False,
+                      *,
+                      generator=None,
+                      out=None,
+                      env=None):
+  assert num_samples <= input.shape[
+      -1] or replacement, "cannot take a larger sample than population when replacement=False"
   key = env.get_and_rotate_prng_key(generator)
   if input.ndim == 1:
+    return jax.random.choice(
+        key, input.shape[-1], (num_samples,), replace=replacement, p=input)
   else:
+    return jnp.array([
+        jax.random.choice(
+            key,
+            input.shape[-1], (num_samples,),
+            replace=replacement,
+            p=input[i, :]) for i in range(input.shape[0])
+    ])
 
 
 @op(torch.ops.aten.narrow)
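
In the 1-D case the new _aten_multinomial lowers directly to jax.random.choice over category indices, with the weight vector passed through p. A minimal standalone usage sketch (values are illustrative):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
weights = jnp.array([0.1, 0.2, 0.7])   # 1-D probability vector
num_samples = 5
# torch.multinomial(weights, 5, replacement=True) ~ index draws weighted by p
samples = jax.random.choice(
    key, weights.shape[-1], (num_samples,), replace=True, p=weights)
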
|
|
@@ -4738,7 +4862,12 @@ def _aten_flatten(x, start_dim=0, end_dim=-1):
 
 @op(torch.ops.aten.new_empty)
 def _new_empty(self, size, **kwargs):
+  dtype = kwargs.get('dtype')
+  if dtype is not None:
+    dtype = mappings.t2j_dtype(dtype)
+  else:
+    dtype = self.dtype
+  return jnp.empty(size, dtype=dtype)
 
 
 @op(torch.ops.aten.new_empty_strided)
|
|
@@ -4756,10 +4885,8 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False):
   return _aten_index_put(self, indices, values, accumulate)
 
 
-@op(torch.ops.aten.conj_physical,
-    torch.ops.aten.
-    torch.ops.aten._conj_physical,
-    torch.ops.aten._conj)
+@op(torch.ops.aten.conj_physical, torch.ops.aten.conj,
+    torch.ops.aten._conj_physical, torch.ops.aten._conj)
 def _aten_conj_physical(self):
   return jnp.conjugate(self)
 
@@ -4768,6 +4895,7 @@ def _aten_conj_physical(self):
 def _aten_log_sigmoid(x):
   return jax.nn.log_sigmoid(x)
 
+
 # torch.qr
 @op(torch.ops.aten.qr)
 def _aten_qr(input, *args, **kwargs):
@@ -4778,6 +4906,7 @@ def _aten_qr(input, *args, **kwargs):
     jax_mode = "complete"
   return jax.numpy.linalg.qr(input, mode=jax_mode)
 
+
 # torch.linalg.qr
 @op(torch.ops.aten.linalg_qr)
 def _aten_linalg_qr(input, *args, **kwargs):
|
|
@@ -4820,19 +4949,25 @@ def _aten__linalg_solve_ex(a, b):
   res = jnp.linalg.solve(a, b)
   if batched:
     res = res.squeeze(-1)
-  info_shape = a.shape[
+  info_shape = a.shape[:-2]
   info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
   return res, info
 
 
 # torch.linalg.solve_triangular
 @op(torch.ops.aten.linalg_solve_triangular)
-def _aten_linalg_solve_triangular(a,
+def _aten_linalg_solve_triangular(a,
+                                  b,
+                                  *,
+                                  upper=True,
+                                  left=True,
+                                  unitriangular=False):
   if left is False:
     a = jnp.matrix_transpose(a)
     b = jnp.matrix_transpose(b)
     upper = not upper
-  res = jax.scipy.linalg.solve_triangular(
+  res = jax.scipy.linalg.solve_triangular(
+      a, b, lower=not upper, unit_diagonal=unitriangular)
   if left is False:
     res = jnp.matrix_transpose(res)
   return res
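
The left=False branch above relies on the identity that X @ A = B can be solved as A.T @ X.T = B.T, with the triangle flipping from upper to lower under the transpose. A small standalone check of that identity (example matrices are arbitrary):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])   # upper triangular
B = jnp.array([[4.0, 9.0],
               [2.0, 6.0]])
# Solve X @ A = B by solving A.T @ X.T = B.T; A.T is lower triangular.
X = solve_triangular(A.T, B.T, lower=True).T
assert bool(jnp.allclose(X @ A, B))
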
|
|
@@ -4852,21 +4987,31 @@ def _aten__linalg_check_errors(*args, **kwargs):
|
|
|
4852
4987
|
|
|
4853
4988
|
@op(torch.ops.aten.median)
|
|
4854
4989
|
def _aten_median(self, dim=None, keepdim=False):
|
|
4855
|
-
output = _with_reduction_scalar(
|
|
4990
|
+
output = _with_reduction_scalar(
|
|
4991
|
+
functools.partial(jnp.quantile, q=0.5, method='lower'),
|
|
4992
|
+
self,
|
|
4993
|
+
dim=dim,
|
|
4994
|
+
keepdim=keepdim).astype(self.dtype)
|
|
4856
4995
|
if dim is None:
|
|
4857
4996
|
return output
|
|
4858
4997
|
else:
|
|
4859
|
-
index = _with_reduction_scalar(_get_median_index, self, dim,
|
|
4998
|
+
index = _with_reduction_scalar(_get_median_index, self, dim,
|
|
4999
|
+
keepdim).astype(jnp.int64)
|
|
4860
5000
|
return output, index
|
|
4861
5001
|
|
|
4862
5002
|
|
|
4863
5003
|
@op(torch.ops.aten.nanmedian)
|
|
4864
5004
|
def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
|
|
4865
|
-
output = _with_reduction_scalar(
|
|
5005
|
+
output = _with_reduction_scalar(
|
|
5006
|
+
functools.partial(jnp.nanquantile, q=0.5, method='lower'),
|
|
5007
|
+
input,
|
|
5008
|
+
dim=dim,
|
|
5009
|
+
keepdim=keepdim).astype(input.dtype)
|
|
4866
5010
|
if dim is None:
|
|
4867
5011
|
return output
|
|
4868
5012
|
else:
|
|
4869
|
-
index = _with_reduction_scalar(_get_median_index, input, dim,
|
|
5013
|
+
index = _with_reduction_scalar(_get_median_index, input, dim,
|
|
5014
|
+
keepdim).astype(jnp.int64)
|
|
4870
5015
|
return output, index
|
|
4871
5016
|
|
|
4872
5017
|
|
|
@@ -4874,20 +5019,31 @@ def _get_median_index(x, axis=None, keepdims=False):
|
|
|
4874
5019
|
sorted_arg = jnp.argsort(x, axis=axis)
|
|
4875
5020
|
n = x.shape[axis] if axis is not None else x.size
|
|
4876
5021
|
if n % 2 == 1:
|
|
4877
|
-
|
|
5022
|
+
index = n // 2
|
|
4878
5023
|
else:
|
|
4879
|
-
|
|
5024
|
+
index = (n // 2) - 1
|
|
4880
5025
|
if axis is None:
|
|
4881
|
-
|
|
5026
|
+
median_index = sorted_arg[index]
|
|
4882
5027
|
else:
|
|
4883
|
-
|
|
5028
|
+
median_index = jnp.take(sorted_arg, index, axis=axis)
|
|
4884
5029
|
if keepdims and axis is not None:
|
|
4885
|
-
|
|
5030
|
+
median_index = jnp.expand_dims(median_index, axis)
|
|
4886
5031
|
return median_index
|
|
4887
5032
|
|
|
5033
|
+
|
|
4888
5034
|
@op(torch.ops.aten.triangular_solve)
|
|
4889
|
-
def _aten_triangular_solve(b,
|
|
4890
|
-
|
|
5035
|
+
def _aten_triangular_solve(b,
|
|
5036
|
+
a,
|
|
5037
|
+
upper=True,
|
|
5038
|
+
transpose=False,
|
|
5039
|
+
unittriangular=False):
|
|
5040
|
+
return (jax.lax.linalg.triangular_solve(
|
|
5041
|
+
a,
|
|
5042
|
+
b,
|
|
5043
|
+
left_side=True,
|
|
5044
|
+
lower=not upper,
|
|
5045
|
+
transpose_a=transpose,
|
|
5046
|
+
unit_diagonal=unittriangular), a)
|
|
4891
5047
|
|
|
4892
5048
|
|
|
4893
5049
|
# func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
|
|
@@ -4895,16 +5051,16 @@ def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=Fal
 def _aten__fft_c2c(self, dim, normalization, forward):
   if forward:
     norm = [
+        'backward',
+        'ortho',
+        'forward',
     ][normalization]
     return jnp.fft.fftn(self, axes=dim, norm=norm)
   else:
     norm = [
+        'forward',
+        'ortho',
+        'backward',
     ][normalization]
     return jnp.fft.ifftn(self, axes=dim, norm=norm)
 
@@ -4912,21 +5068,22 @@ def _aten__fft_c2c(self, dim, normalization, forward):
 @op(torch.ops.aten._fft_r2c)
 def _aten__fft_r2c(self, dim, normalization, onesided):
   norm = [
+      'backward',
+      'ortho',
+      'forward',
   ][normalization]
   if onesided:
     return jnp.fft.rfftn(self, axes=dim, norm=norm)
   else:
     return jnp.fft.fftn(self, axes=dim, norm=norm)
 
+
 @op(torch.ops.aten._fft_c2r)
 def _aten__fft_c2r(self, dim, normalization, last_dim_size):
   norm = [
+      'forward',
+      'ortho',
+      'backward',
   ][normalization]
   if len(dim) == 1:
     s = [last_dim_size]
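
The integer normalization argument follows ATen's encoding (0: no scaling, 1: 1/sqrt(n), 2: 1/n applied by the call being made), which is why the forward and inverse paths above index reversed lists of jnp.fft norm names. A standalone sketch of the mapping, assuming that encoding:

import jax.numpy as jnp

x = jnp.arange(4.0).reshape(1, 4)
normalization = 2                                          # scale this call by 1/n
norm = ['backward', 'ortho', 'forward'][normalization]     # forward-transform table
y = jnp.fft.fftn(x, axes=(1,), norm=norm)                  # same as unscaled fftn / 4
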
|
|
@@ -4936,34 +5093,49 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
|
|
|
4936
5093
|
|
|
4937
5094
|
|
|
4938
5095
|
@op(torch.ops.aten._trilinear)
|
|
4939
|
-
def _aten_trilinear(i1,
|
|
4940
|
-
|
|
5096
|
+
def _aten_trilinear(i1,
|
|
5097
|
+
i2,
|
|
5098
|
+
i3,
|
|
5099
|
+
expand1,
|
|
5100
|
+
expand2,
|
|
5101
|
+
expand3,
|
|
5102
|
+
sumdim,
|
|
5103
|
+
unroll_dim=1):
|
|
5104
|
+
return _aten_sum(
|
|
5105
|
+
jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) *
|
|
5106
|
+
jnp.expand_dims(i3, expand3), sumdim)
|
|
4941
5107
|
|
|
4942
5108
|
|
|
4943
5109
|
@op(torch.ops.aten.max_unpool2d)
|
|
4944
5110
|
@op(torch.ops.aten.max_unpool3d)
|
|
4945
5111
|
def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
|
|
4946
5112
|
if output_size is None:
|
|
4947
|
-
raise ValueError(
|
|
5113
|
+
raise ValueError(
|
|
5114
|
+
"output_size value is not set correctly. It cannot be None or empty.")
|
|
4948
5115
|
|
|
4949
5116
|
output_size = [input.shape[0], input.shape[1]] + output_size
|
|
4950
5117
|
output = jnp.zeros(output_size, dtype=input.dtype)
|
|
4951
5118
|
|
|
4952
5119
|
for idx in np.ndindex(input.shape):
|
|
4953
|
-
|
|
4954
|
-
|
|
4955
|
-
|
|
4956
|
-
|
|
4957
|
-
|
|
5120
|
+
max_index = indices[idx]
|
|
5121
|
+
spatial_dims = output_size[2:] # (D, H, W)
|
|
5122
|
+
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
|
|
5123
|
+
full_idx = idx[:2] + unpooled_spatial_idx
|
|
5124
|
+
output = output.at[full_idx].set(input[idx])
|
|
4958
5125
|
|
|
4959
5126
|
return output
|
|
4960
5127
|
|
|
4961
|
-
|
|
4962
|
-
def
|
|
5128
|
+
|
|
5129
|
+
def _aten_upsample(input,
|
|
5130
|
+
output_size,
|
|
5131
|
+
align_corners,
|
|
5132
|
+
antialias,
|
|
5133
|
+
method,
|
|
5134
|
+
scale_factors=None,
|
|
5135
|
+
scales_h=None,
|
|
5136
|
+
scales_w=None):
|
|
4963
5137
|
# input: is of type jaxlib.xla_extension.ArrayImpl
|
|
4964
5138
|
image = input
|
|
4965
|
-
method = "bilinear"
|
|
4966
|
-
antialias = True # ignored for upsampling
|
|
4967
5139
|
|
|
4968
5140
|
# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
|
|
4969
5141
|
# Resize does not distinguish batch, channel size.
|
|
@@ -4977,12 +5149,12 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
|
|
|
4977
5149
|
shape = list(image.shape)
|
|
4978
5150
|
# overriding output_size
|
|
4979
5151
|
if scale_factors:
|
|
4980
|
-
shape[-1] = int(math.floor(shape[-1]*scale_factors[-1]))
|
|
4981
|
-
shape[-2] = int(math.floor(shape[-2]*scale_factors[-2]))
|
|
5152
|
+
shape[-1] = int(math.floor(shape[-1] * scale_factors[-1]))
|
|
5153
|
+
shape[-2] = int(math.floor(shape[-2] * scale_factors[-2]))
|
|
4982
5154
|
if scales_h:
|
|
4983
|
-
shape[-2] = int(math.floor(shape[-2]*scales_h))
|
|
5155
|
+
shape[-2] = int(math.floor(shape[-2] * scales_h))
|
|
4984
5156
|
if scales_w:
|
|
4985
|
-
shape[-1] = int(math.floor(shape[-1]*scales_w))
|
|
5157
|
+
shape[-1] = int(math.floor(shape[-1] * scales_w))
|
|
4986
5158
|
# output_size overrides scale_factors, scales_*
|
|
4987
5159
|
if output_size:
|
|
4988
5160
|
shape[-1] = output_size[-1]
|
|
@@ -4992,11 +5164,11 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
|
|
|
4992
5164
|
if shape == list(image.shape):
|
|
4993
5165
|
return image
|
|
4994
5166
|
|
|
4995
|
-
spatial_dims = (2,3)
|
|
5167
|
+
spatial_dims = (2, 3)
|
|
4996
5168
|
if len(shape) == 3:
|
|
4997
|
-
spatial_dims = (1,2)
|
|
5169
|
+
spatial_dims = (1, 2)
|
|
4998
5170
|
|
|
4999
|
-
scale = list([shape[i] / image.shape[i]
|
|
5171
|
+
scale = list([shape[i] / image.shape[i] for i in spatial_dims])
|
|
5000
5172
|
if scale_factors:
|
|
5001
5173
|
scale = scale_factors
|
|
5002
5174
|
if scales_h:
|
|
@@ -5008,7 +5180,9 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
|
|
|
5008
5180
|
# align_corners is not supported in resize()
|
|
5009
5181
|
# https://github.com/jax-ml/jax/issues/11206
|
|
5010
5182
|
if align_corners:
|
|
5011
|
-
scale = jnp.array([
|
|
5183
|
+
scale = jnp.array([
|
|
5184
|
+
(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims
|
|
5185
|
+
])
|
|
5012
5186
|
|
|
5013
5187
|
translation = jnp.array([0 for i in spatial_dims])
|
|
5014
5188
|
|
|
@@ -5022,12 +5196,53 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
|
|
|
5022
5196
|
antialias=antialias,
|
|
5023
5197
|
)
|
|
5024
5198
|
|
|
5199
|
+
|
|
5200
|
+
@op(torch.ops.aten._upsample_bilinear2d_aa)
|
|
5201
|
+
def _aten_upsample_billinear_aa(input,
|
|
5202
|
+
output_size,
|
|
5203
|
+
align_corners,
|
|
5204
|
+
scale_factors=None,
|
|
5205
|
+
scales_h=None,
|
|
5206
|
+
scales_w=None):
|
|
5207
|
+
return _aten_upsample(
|
|
5208
|
+
input,
|
|
5209
|
+
output_size,
|
|
5210
|
+
align_corners,
|
|
5211
|
+
True, # antialias
|
|
5212
|
+
"bilinear", # method
|
|
5213
|
+
scale_factors,
|
|
5214
|
+
scales_h,
|
|
5215
|
+
scales_w)
|
|
5216
|
+
|
|
5217
|
+
|
|
5218
|
+
@op(torch.ops.aten._upsample_bicubic2d_aa)
|
|
5219
|
+
def _aten_upsample_bicubic2d_aa(input,
|
|
5220
|
+
output_size,
|
|
5221
|
+
align_corners,
|
|
5222
|
+
scale_factors=None,
|
|
5223
|
+
scales_h=None,
|
|
5224
|
+
scales_w=None):
|
|
5225
|
+
return _aten_upsample(
|
|
5226
|
+
input,
|
|
5227
|
+
output_size,
|
|
5228
|
+
align_corners,
|
|
5229
|
+
True, # antialias
|
|
5230
|
+
"bicubic", # method
|
|
5231
|
+
scale_factors,
|
|
5232
|
+
scales_h,
|
|
5233
|
+
scales_w)
|
|
5234
|
+
|
|
5235
|
+
|
|
5025
5236
|
@op(torch.ops.aten.polar)
|
|
5026
5237
|
def _aten_polar(abs, angle, *, out=None):
|
|
5027
5238
|
return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))
|
|
5028
5239
|
|
|
5240
|
+
|
|
5029
5241
|
@op(torch.ops.aten.cdist)
|
|
5030
|
-
def _aten_cdist(x1,
|
|
5242
|
+
def _aten_cdist(x1,
|
|
5243
|
+
x2,
|
|
5244
|
+
p=2.0,
|
|
5245
|
+
compute_mode='use_mm_for_euclid_dist_if_necessary'):
|
|
5031
5246
|
x1 = x1.astype(jnp.float32)
|
|
5032
5247
|
x2 = x2.astype(jnp.float32)
|
|
5033
5248
|
|
|
@@ -5036,7 +5251,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
|
|
|
5036
5251
|
return _hamming_distance(x1, x2).astype(jnp.float32)
|
|
5037
5252
|
elif p == 2.0:
|
|
5038
5253
|
# Use optimized Euclidean distance calculation
|
|
5039
|
-
if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (
|
|
5254
|
+
if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (
|
|
5255
|
+
x1.shape[-2] > 25 or x2.shape[-2] > 25):
|
|
5040
5256
|
return _euclidean_mm(x1, x2)
|
|
5041
5257
|
elif compute_mode == 'use_mm_for_euclid_dist':
|
|
5042
5258
|
return _euclidean_mm(x1, x2)
|
|
@@ -5045,7 +5261,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
|
|
|
5045
5261
|
else:
|
|
5046
5262
|
# General p-norm distance calculation
|
|
5047
5263
|
diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3))
|
|
5048
|
-
return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)
|
|
5264
|
+
return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)**(1 / p)
|
|
5265
|
+
|
|
5049
5266
|
|
|
5050
5267
|
def _hamming_distance(x1, x2):
|
|
5051
5268
|
"""
|
|
@@ -5064,6 +5281,7 @@ def _hamming_distance(x1, x2):
|
|
|
5064
5281
|
|
|
5065
5282
|
return hamming_dist
|
|
5066
5283
|
|
|
5284
|
+
|
|
5067
5285
|
def _euclidean_mm(x1, x2):
|
|
5068
5286
|
"""
|
|
5069
5287
|
Computes the Euclidean distance using matrix multiplication.
|
|
@@ -5075,8 +5293,8 @@ def _euclidean_mm(x1, x2):
|
|
|
5075
5293
|
Returns:
|
|
5076
5294
|
JAX array of shape (..., P, R) representing pairwise Euclidean distances.
|
|
5077
5295
|
"""
|
|
5078
|
-
x1_sq = jnp.sum(x1
|
|
5079
|
-
x2_sq = jnp.sum(x2
|
|
5296
|
+
x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32)
|
|
5297
|
+
x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32)
|
|
5080
5298
|
|
|
5081
5299
|
x2_sq = jnp.swapaxes(x2_sq, -2, -1)
|
|
5082
5300
|
|
|
@@ -5088,6 +5306,7 @@ def _euclidean_mm(x1, x2):
|
|
|
5088
5306
|
|
|
5089
5307
|
return dist
|
|
5090
5308
|
|
|
5309
|
+
|
|
5091
5310
|
def _euclidean_direct(x1, x2):
|
|
5092
5311
|
"""
|
|
5093
5312
|
Computes the Euclidean distance directly without matrix multiplication.
|
|
@@ -5101,7 +5320,7 @@ def _euclidean_direct(x1, x2):
|
|
|
5101
5320
|
"""
|
|
5102
5321
|
diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)
|
|
5103
5322
|
|
|
5104
|
-
dist_sq = jnp.sum(diff
|
|
5323
|
+
dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32)
|
|
5105
5324
|
|
|
5106
5325
|
dist_sq = jnp.maximum(dist_sq, 0.0)
|
|
5107
5326
|
|
|
@@ -5109,13 +5328,14 @@ def _euclidean_direct(x1, x2):
|
|
|
5109
5328
|
|
|
5110
5329
|
return dist
|
|
5111
5330
|
|
|
5331
|
+
|
|
5112
5332
|
@op(torch.ops.aten.lu_unpack)
|
|
5113
5333
|
def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|
5114
5334
|
# lu_unpack doesnt exist in jax.
|
|
5115
5335
|
# Get commonly used data shape variables
|
|
5116
5336
|
n = LU_data.shape[-2]
|
|
5117
5337
|
m = LU_data.shape[-1]
|
|
5118
|
-
dim = min(n,m)
|
|
5338
|
+
dim = min(n, m)
|
|
5119
5339
|
|
|
5120
5340
|
### Compute the Lower and Upper triangle
|
|
5121
5341
|
if unpack_data:
|
|
@@ -5130,7 +5350,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|
|
5130
5350
|
start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
|
|
5131
5351
|
limit_indices = list(LU_data.shape)
|
|
5132
5352
|
limit_indices[-1] = dim
|
|
5133
|
-
L = jax.lax.slice(L, start_indices, limit_indices)
|
|
5353
|
+
L = jax.lax.slice(L, start_indices, limit_indices)
|
|
5134
5354
|
|
|
5135
5355
|
# Extract upper triangle
|
|
5136
5356
|
U = jnp.triu(LU_data)
|
|
@@ -5160,13 +5380,15 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|
|
5160
5380
|
|
|
5161
5381
|
# closure to be called for each input 2D matrix.
|
|
5162
5382
|
def _lu_unpack_2d(p, pivot):
|
|
5163
|
-
_pivot = pivot - 1
|
|
5383
|
+
_pivot = pivot - 1 # pivots are offset by 1 in jax
|
|
5164
5384
|
indices = jnp.array([*range(n)], dtype=jnp.int32)
|
|
5385
|
+
|
|
5165
5386
|
def update_indices(i, _indices):
|
|
5166
5387
|
tmp = _indices[i]
|
|
5167
5388
|
_indices = _indices.at[i].set(_indices[_pivot[i]])
|
|
5168
5389
|
_indices = _indices.at[_pivot[i]].set(tmp)
|
|
5169
5390
|
return _indices
|
|
5391
|
+
|
|
5170
5392
|
indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices)
|
|
5171
5393
|
p = p[jnp.array(indices)]
|
|
5172
5394
|
p = jnp.transpose(p)
|
|
@@ -5191,7 +5413,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
|
|
5191
5413
|
reshapedPivot = LU_pivots.reshape(newPivotshape)
|
|
5192
5414
|
|
|
5193
5415
|
# vmap the reshaped 3d tensors
|
|
5194
|
-
v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0,0))
|
|
5416
|
+
v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0))
|
|
5195
5417
|
unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot)
|
|
5196
5418
|
|
|
5197
5419
|
# reshape result back to P's shape
|
|
@@ -5210,3 +5432,207 @@ def linear(input, weight, bias=None):
|
|
|
5210
5432
|
if bias is not None:
|
|
5211
5433
|
res += bias
|
|
5212
5434
|
return res
|
|
5435
|
+
|
|
5436
|
+
|
|
5437
|
+
@op(torch.ops.aten.kthvalue)
|
|
5438
|
+
def kthvalue(input, k, dim=None, keepdim=False, *, out=None):
|
|
5439
|
+
if input.ndim == 0:
|
|
5440
|
+
return input, jnp.array(0)
|
|
5441
|
+
dimension = -1
|
|
5442
|
+
if dim is not None:
|
|
5443
|
+
dimension = dim
|
|
5444
|
+
while dimension < 0:
|
|
5445
|
+
dimension = dimension + input.ndim
|
|
5446
|
+
values = jax.lax.index_in_dim(
|
|
5447
|
+
jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim)
|
|
5448
|
+
indices = jax.lax.index_in_dim(
|
|
5449
|
+
jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1,
|
|
5450
|
+
dimension, keepdim)
|
|
5451
|
+
return values, indices
|
|
5452
|
+
|
|
5453
|
+
|
|
5454
|
+
+@op(torch.ops.aten.take)
+def _aten_take(self, index):
+  return self.flatten()[index]
+
+
+# func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor
+@op(torch.ops.aten.pad)
+def _aten_pad(self, pad, mode='constant', value=None):
+  if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0:
+    raise ValueError("Padding must be a sequence of even length.")
+
+  num_dims = self.ndim
+  if len(pad) > 2 * num_dims:
+    raise ValueError(
+        f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})."
+    )
+
+  # JAX's pad function expects padding for each dimension as a tuple of (low, high)
+  # We need to reverse the pad sequence and group them for JAX.
+  # pad = [p_l0, p_r0, p_l1, p_r1, ...]
+  # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0)))
+  jax_pad_width = []
+  # Iterate in reverse pairs
+  for i in range(len(pad) // 2):
+    jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)]))
+
+  # Pad any leading dimensions with (0, 0) if the pad sequence is shorter
+  # than the number of dimensions.
+  for _ in range(num_dims - len(pad) // 2):
+    jax_pad_width.append((0, 0))
+
+  # Reverse the jax_pad_width list to match the dimension order
+  jax_pad_width.reverse()
+
+  if mode == "constant":
+    if value is None:
+      value = 0.0
+    return jnp.pad(
+        self, pad_width=jax_pad_width, mode="constant", constant_values=value)
+  elif mode == "reflect":
+    return jnp.pad(self, pad_width=jax_pad_width, mode="reflect")
+  elif mode == "edge":
+    return jnp.pad(self, pad_width=jax_pad_width, mode="edge")
+  else:
+    raise ValueError(
+        f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'."
+    )
+
+
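
The conversion in `_aten_pad` is the usual torch-to-JAX reshuffle: torch pad lists start at the last dimension and give (left, right) pairs, while `jnp.pad` wants one (low, high) pair per dimension in normal order. A quick sketch of that mapping (shapes and values are illustrative):

    import jax.numpy as jnp

    x = jnp.ones((2, 3))
    pad = [1, 2]  # torch order: last dimension first, 1 on the left, 2 on the right
    pad_width = [(pad[2 * i], pad[2 * i + 1]) for i in range(len(pad) // 2)]
    pad_width += [(0, 0)] * (x.ndim - len(pad) // 2)  # untouched leading dims
    pad_width.reverse()  # -> [(0, 0), (1, 2)], i.e. jnp.pad's dimension order
    y = jnp.pad(x, pad_width, mode="constant", constant_values=0.0)
    # y.shape == (2, 6)
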
+@op(torch.ops.aten.is_nonzero)
+def _aten_is_nonzero(a):
+  a = jnp.squeeze(a)
+  if a.shape == (0,):
+    raise RuntimeError('bool value of Tensor with no values is ambiguous')
+  if a.ndim != 0:
+    raise RuntimeError(
+        'bool value of Tensor with more than one value is ambiguous')
+  return a.item() != 0
+
+
+@op(torch.ops.aten.logit)
+def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
+  """
+  Computes the logit function of the input tensor.
+
+  logit(p) = log(p / (1 - p))
+
+  Args:
+    self: Input tensor.
+    eps: A small value to clip the input tensor to avoid log(0) or division by zero.
+      If None, no clipping is performed.
+
+  Returns:
+    A tensor with the logit of each element of the input.
+  """
+  if eps is not None:
+    self = jnp.clip(self, eps, 1.0 - eps)
+  res = jnp.log(self / (1.0 - self))
+  res = res.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return res
+
+
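
The `eps` clamp in `_aten_logit` is what keeps the endpoints finite; without it, logit(0) and logit(1) come out as -inf and +inf. A tiny illustration (the eps value is chosen arbitrarily):

    import jax.numpy as jnp

    p = jnp.array([0.0, 0.25, 1.0])
    eps = 1e-6
    p_clipped = jnp.clip(p, eps, 1.0 - eps)
    out = jnp.log(p_clipped / (1.0 - p_clipped))
    # roughly [-13.8, -1.1, 13.8]; with eps=None the endpoints would be -inf and +inf
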
+@op(torch.ops.aten.floor_divide)
+def _aten_floor_divide(x, y):
+  res = jnp.floor_divide(x, y)
+  return res
+
+
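
One behavioural note on `floor_divide`: in current PyTorch releases `torch.floor_divide` rounds toward negative infinity, and so does `jnp.floor_divide`, which is why the lowering can be a direct one-liner. For example:

    import jax.numpy as jnp

    jnp.floor_divide(jnp.array([7, -7]), 2)  # -> [3, -4], matching Python's // operator
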
+@op(torch.ops.aten._assert_tensor_metadata)
+@op(torch.ops.aten._assert_scalar)
+def _aten__assert_tensor_metadata(*args, **kwargs):
+  pass
+
+
+mutation_ops_to_functional = {
+    torch.ops.aten.add_:
+        op_base.InplaceOp(torch.ops.aten.add),
+    torch.ops.aten.sub_:
+        op_base.InplaceOp(torch.ops.aten.sub),
+    torch.ops.aten.mul_:
+        op_base.InplaceOp(torch.ops.aten.mul),
+    torch.ops.aten.div_:
+        op_base.InplaceOp(torch.ops.aten.div),
+    torch.ops.aten.pow_:
+        op_base.InplaceOp(torch.ops.aten.pow),
+    torch.ops.aten.lt_:
+        op_base.InplaceOp(torch.ops.aten.lt),
+    torch.ops.aten.le_:
+        op_base.InplaceOp(torch.ops.aten.le),
+    torch.ops.aten.gt_:
+        op_base.InplaceOp(torch.ops.aten.gt),
+    torch.ops.aten.ge_:
+        op_base.InplaceOp(torch.ops.aten.ge),
+    torch.ops.aten.eq_:
+        op_base.InplaceOp(torch.ops.aten.eq),
+    torch.ops.aten.ne_:
+        op_base.InplaceOp(torch.ops.aten.ne),
+    torch.ops.aten.bernoulli_:
+        op_base.InplaceOp(torch.ops.aten.bernoulli.p),
+    torch.ops.aten.bernoulli_.float:
+        op_base.InplaceOp(_aten_bernoulli, is_jax_func=True),
+    torch.ops.aten.geometric_:
+        op_base.InplaceOp(torch.ops.aten.geometric),
+    torch.ops.aten.normal_:
+        op_base.InplaceOp(torch.ops.aten.normal),
+    torch.ops.aten.random_:
+        op_base.InplaceOp(torch.ops.aten.uniform),
+    torch.ops.aten.uniform_:
+        op_base.InplaceOp(torch.ops.aten.uniform),
+    torch.ops.aten.relu_:
+        op_base.InplaceOp(torch.ops.aten.relu),
+    # squeeze_ is expected to change tensor's shape. So replace with new value
+    torch.ops.aten.squeeze_:
+        op_base.InplaceOp(torch.ops.aten.squeeze, True),
+    torch.ops.aten.sqrt_:
+        op_base.InplaceOp(torch.ops.aten.sqrt),
+    torch.ops.aten.clamp_:
+        op_base.InplaceOp(torch.ops.aten.clamp),
+    torch.ops.aten.clamp_min_:
+        op_base.InplaceOp(torch.ops.aten.clamp_min),
+    torch.ops.aten.sigmoid_:
+        op_base.InplaceOp(torch.ops.aten.sigmoid),
+    torch.ops.aten.tanh_:
+        op_base.InplaceOp(torch.ops.aten.tanh),
+    torch.ops.aten.ceil_:
+        op_base.InplaceOp(torch.ops.aten.ceil),
+    torch.ops.aten.logical_not_:
+        op_base.InplaceOp(torch.ops.aten.logical_not),
+    torch.ops.aten.unsqueeze_:
+        op_base.InplaceOp(torch.ops.aten.unsqueeze),
+    torch.ops.aten.transpose_:
+        op_base.InplaceOp(torch.ops.aten.transpose),
+    torch.ops.aten.log_normal_:
+        op_base.InplaceOp(torch.ops.aten.log_normal),
+    torch.ops.aten.scatter_add_:
+        op_base.InplaceOp(torch.ops.aten.scatter_add),
+    torch.ops.aten.scatter_reduce_.two:
+        op_base.InplaceOp(torch.ops.aten.scatter_reduce),
+    torch.ops.aten.scatter_:
+        op_base.InplaceOp(torch.ops.aten.scatter),
+    torch.ops.aten.bitwise_or_:
+        op_base.InplaceOp(torch.ops.aten.bitwise_or),
+    torch.ops.aten.floor_divide_:
+        op_base.InplaceOp(torch.ops.aten.floor_divide),
+    torch.ops.aten.remainder_:
+        op_base.InplaceOp(torch.ops.aten.remainder),
+    torch.ops.aten.index_put_:
+        op_base.InplaceOp(torch.ops.aten.index_put),
+}
+
+# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
+_jax_version = tuple(int(v) for v in jax.version._version.split("."))
+
+mutation_needs_env = {
+    torch.ops.aten.bernoulli_,
+    torch.ops.aten.bernoulli_.float,
+}
+
+for operator, mutation in mutation_ops_to_functional.items():
+  ops_registry.register_torch_dispatch_op(
+      operator,
+      mutation,
+      is_jax_function=False,
+      is_view_op=True,
+      needs_env=(operator in mutation_needs_env))
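
The closing block is the functionalization table: each in-place aten op is now registered as a wrapper around its out-of-place counterpart via `op_base.InplaceOp`, so a mutation becomes "compute the functional result, then write it back into the target tensor". At the plain PyTorch level the rewrite amounts to roughly the following (illustrative only; the actual rewrite happens inside torchax's dispatch):

    import torch

    x = torch.zeros(3)
    y = torch.ones(3)
    x.add_(y)                          # in-place aten.add_
    z = torch.add(torch.zeros(3), y)   # functional aten.add producing a new tensor
    assert torch.equal(x, z)           # same values; only the aliasing behaviour differs

The `needs_env` flag on the two bernoulli variants presumably gives those ops access to the environment's random-number state, which the purely functional ops do not need.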