torchax: 0.0.4-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +57 -19
- torchax/amp.py +333 -0
- torchax/config.py +19 -12
- torchax/decompositions.py +663 -195
- torchax/device_module.py +7 -1
- torchax/distributed.py +55 -60
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +275 -141
- torchax/mesh_util.py +211 -0
- torchax/ops/jaten.py +1718 -1294
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +219 -78
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +417 -275
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/METADATA +111 -145
- torchax-0.0.5.dist-info/RECORD +32 -0
- torchax/environment.py +0 -2
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/licenses/LICENSE +0 -0
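
The largest change in this release is the rewrite of torchax/ops/jaten.py shown below, which maps ATen operators onto plain jax.numpy implementations. As a hedged illustration only (it assumes nothing beyond an installed JAX and mirrors the one-line `_aten_triu(m, k=0)` body visible in the diff), the mapping style reads like this:

import jax.numpy as jnp

x = jnp.arange(16).reshape(4, 4)
# torch.ops.aten.triu is registered in jaten.py as a thin wrapper over jnp.triu,
# so this JAX call mirrors torch.triu(x, diagonal=1) on a torchax tensor.
print(jnp.triu(x, 1))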
torchax/ops/jaten.py
CHANGED
|
@@ -15,75 +15,22 @@ from torchax.ops import ops_registry
|
|
|
15
15
|
from torchax.ops import op_base, mappings
|
|
16
16
|
from torchax import interop
|
|
17
17
|
from torchax.ops import jax_reimplement
|
|
18
|
-
|
|
18
|
+
from torchax.view import View
|
|
19
19
|
# Keys are OpOverload, value is a callable that takes
|
|
20
20
|
# Tensor
|
|
21
21
|
all_ops = {}
|
|
22
22
|
|
|
23
|
-
# list all Aten ops from pytorch that does mutation
|
|
24
|
-
# and need to be implemented in jax
|
|
25
|
-
|
|
26
|
-
mutation_ops_to_functional = {
|
|
27
|
-
torch.ops.aten.add_: torch.ops.aten.add,
|
|
28
|
-
torch.ops.aten.sub_: torch.ops.aten.sub,
|
|
29
|
-
torch.ops.aten.mul_: torch.ops.aten.mul,
|
|
30
|
-
torch.ops.aten.div_: torch.ops.aten.div,
|
|
31
|
-
torch.ops.aten.pow_: torch.ops.aten.pow,
|
|
32
|
-
torch.ops.aten.lt_: torch.ops.aten.lt,
|
|
33
|
-
torch.ops.aten.le_: torch.ops.aten.le,
|
|
34
|
-
torch.ops.aten.gt_: torch.ops.aten.gt,
|
|
35
|
-
torch.ops.aten.ge_: torch.ops.aten.ge,
|
|
36
|
-
torch.ops.aten.eq_: torch.ops.aten.eq,
|
|
37
|
-
torch.ops.aten.ne_: torch.ops.aten.ne,
|
|
38
|
-
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
|
|
39
|
-
torch.ops.aten.geometric_: torch.ops.aten.geometric,
|
|
40
|
-
torch.ops.aten.normal_: torch.ops.aten.normal,
|
|
41
|
-
torch.ops.aten.random_: torch.ops.aten.uniform,
|
|
42
|
-
torch.ops.aten.uniform_: torch.ops.aten.uniform,
|
|
43
|
-
torch.ops.aten.relu_: torch.ops.aten.relu,
|
|
44
|
-
# squeeze_ is expected to change tensor's shape. So replace with new value
|
|
45
|
-
torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
|
|
46
|
-
torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
|
|
47
|
-
torch.ops.aten.clamp_: torch.ops.aten.clamp,
|
|
48
|
-
torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
|
|
49
|
-
torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
|
|
50
|
-
torch.ops.aten.tanh_: torch.ops.aten.tanh,
|
|
51
|
-
torch.ops.aten.ceil_: torch.ops.aten.ceil,
|
|
52
|
-
torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
|
|
53
|
-
torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
|
|
54
|
-
torch.ops.aten.transpose_: torch.ops.aten.transpose,
|
|
55
|
-
torch.ops.aten.log_normal_: torch.ops.aten.log_normal,
|
|
56
|
-
torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
|
|
57
|
-
torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
|
|
58
|
-
torch.ops.aten.scatter_: torch.ops.aten.scatter,
|
|
59
|
-
}
|
|
60
|
-
|
|
61
|
-
# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
|
|
62
|
-
_jax_version = tuple(int(v) for v in jax.version._version.split("."))
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
def make_mutation(op):
|
|
66
|
-
if type(mutation_ops_to_functional[op]) is tuple:
|
|
67
|
-
return op_base.InplaceOp(mutation_ops_to_functional[op][0],
|
|
68
|
-
replace=mutation_ops_to_functional[op][1],
|
|
69
|
-
position_to_mutate=0)
|
|
70
|
-
return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
for op in mutation_ops_to_functional.keys():
|
|
74
|
-
ops_registry.register_torch_dispatch_op(
|
|
75
|
-
op, make_mutation(op), is_jax_function=False
|
|
76
|
-
)
|
|
77
|
-
|
|
78
23
|
|
|
79
24
|
def op(*aten, **kwargs):
|
|
25
|
+
|
|
80
26
|
def inner(func):
|
|
81
27
|
for a in aten:
|
|
82
28
|
ops_registry.register_torch_dispatch_op(a, func, **kwargs)
|
|
83
29
|
continue
|
|
84
30
|
|
|
85
31
|
if isinstance(a, torch._ops.OpOverloadPacket):
|
|
86
|
-
opname = a.default.name() if 'default' in a.overloads(
|
|
32
|
+
opname = a.default.name() if 'default' in a.overloads(
|
|
33
|
+
) else a._qualified_op_name
|
|
87
34
|
elif isinstance(a, torch._ops.OpOverload):
|
|
88
35
|
opname = a.name()
|
|
89
36
|
else:
|
|
@@ -91,17 +38,18 @@ def op(*aten, **kwargs):
|
|
|
91
38
|
|
|
92
39
|
torchfunc = functools.partial(interop.call_jax, func)
|
|
93
40
|
# HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
|
|
94
|
-
torch.library.impl(opname, 'privateuseone')(
|
|
41
|
+
torch.library.impl(opname, 'privateuseone')(
|
|
42
|
+
torchfunc if a != torch.ops.aten._to_copy else func)
|
|
95
43
|
return func
|
|
96
44
|
|
|
97
45
|
return inner
|
|
98
46
|
|
|
99
47
|
|
|
100
48
|
@op(
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
49
|
+
torch.ops.aten.view_copy,
|
|
50
|
+
torch.ops.aten.view,
|
|
51
|
+
torch.ops.aten._unsafe_view,
|
|
52
|
+
torch.ops.aten.reshape,
|
|
105
53
|
)
|
|
106
54
|
def _aten_unsafe_view(x, shape):
|
|
107
55
|
return jnp.reshape(x, shape)
|
|
@@ -121,8 +69,19 @@ def _aten_add(x, y, *, alpha=1):
|
|
|
121
69
|
return res
|
|
122
70
|
|
|
123
71
|
|
|
124
|
-
@op(torch.ops.aten.copy_,
|
|
125
|
-
|
|
72
|
+
@op(torch.ops.aten.copy_,
|
|
73
|
+
is_jax_function=False,
|
|
74
|
+
is_view_op=True,
|
|
75
|
+
needs_env=True)
|
|
76
|
+
def _aten_copy(x, y, memory_format=None, env=None):
|
|
77
|
+
|
|
78
|
+
if y.device.type == 'cpu':
|
|
79
|
+
y = env.to_xla(y)
|
|
80
|
+
|
|
81
|
+
if isinstance(x, View):
|
|
82
|
+
x.update(y)
|
|
83
|
+
return x
|
|
84
|
+
|
|
126
85
|
if x.ndim == 1 and y.ndim == 0:
|
|
127
86
|
# case of torch.empty((1,)).copy_(tensor(N))
|
|
128
87
|
# we need to return 0D tensor([N]) and not scalar tensor(N)
|
|
@@ -147,14 +106,20 @@ def _aten_trunc(x):
|
|
|
147
106
|
|
|
148
107
|
@op(torch.ops.aten.index_copy)
|
|
149
108
|
def _aten_index_copy(x, dim, indexes, source):
|
|
109
|
+
if x.ndim == 0:
|
|
110
|
+
return source
|
|
111
|
+
if x.ndim == 1:
|
|
112
|
+
source = jnp.squeeze(source)
|
|
150
113
|
# return jax.lax.scatter(x, index, dim)
|
|
114
|
+
if dim < 0:
|
|
115
|
+
dim = dim + x.ndim
|
|
151
116
|
dims = []
|
|
152
117
|
for i in range(len(x.shape)):
|
|
153
118
|
if i == dim:
|
|
154
119
|
dims.append(indexes)
|
|
155
120
|
else:
|
|
156
121
|
dims.append(slice(None, None, None))
|
|
157
|
-
return x.at[
|
|
122
|
+
return x.at[tuple(dims)].set(source)
|
|
158
123
|
|
|
159
124
|
|
|
160
125
|
# aten.cauchy_
|
|
@@ -199,7 +164,9 @@ def _aten_complex(real, imag):
|
|
|
199
164
|
Returns:
|
|
200
165
|
A complex array with the specified real and imaginary parts.
|
|
201
166
|
"""
|
|
202
|
-
return jnp.array(
|
|
167
|
+
return jnp.array(
|
|
168
|
+
real, dtype=jnp.float32) + 1j * jnp.array(
|
|
169
|
+
imag, dtype=jnp.float32)
|
|
203
170
|
|
|
204
171
|
|
|
205
172
|
# aten.exponential_
|
|
@@ -223,13 +190,14 @@ def _aten_exponential_(x, lambd=1.0):
|
|
|
223
190
|
# aten.linalg_householder_product
|
|
224
191
|
@op(torch.ops.aten.linalg_householder_product)
|
|
225
192
|
def _aten_linalg_householder_product(input, tau):
|
|
226
|
-
return jax.lax.linalg.householder_product(a
|
|
193
|
+
return jax.lax.linalg.householder_product(a=input, taus=tau)
|
|
227
194
|
|
|
228
195
|
|
|
229
196
|
@op(torch.ops.aten.select)
|
|
230
197
|
def _aten_select(x, dim, indexes):
|
|
231
198
|
return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False)
|
|
232
199
|
|
|
200
|
+
|
|
233
201
|
@op(torch.ops.aten.index_select)
|
|
234
202
|
@op(torch.ops.aten.select_copy)
|
|
235
203
|
def _aten_index_select(x, dim, index):
|
|
@@ -249,11 +217,10 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
|
|
|
249
217
|
raise NotImplementedError(
|
|
250
218
|
"check_errors=True is not supported in this JAX implementation. "
|
|
251
219
|
"Check for positive definiteness using jnp.linalg.eigvalsh before "
|
|
252
|
-
"calling this function."
|
|
253
|
-
)
|
|
220
|
+
"calling this function.")
|
|
254
221
|
|
|
255
222
|
L = jax.scipy.linalg.cholesky(input, lower=not upper)
|
|
256
|
-
if len(L.shape) >2:
|
|
223
|
+
if len(L.shape) > 2:
|
|
257
224
|
info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32)
|
|
258
225
|
else:
|
|
259
226
|
info = jnp.array(0, dtype=jnp.int32)
|
|
@@ -263,7 +230,7 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
|
|
|
263
230
|
@op(torch.ops.aten.cholesky_solve)
|
|
264
231
|
def _aten_cholesky_solve(input, input2, upper=False):
|
|
265
232
|
# Ensure input2 is lower triangular for cho_solve
|
|
266
|
-
L = input2 if not upper else input2.T
|
|
233
|
+
L = input2 if not upper else input2.T
|
|
267
234
|
# Use cho_solve to solve the linear system
|
|
268
235
|
solution = jax.scipy.linalg.cho_solve((L, True), input)
|
|
269
236
|
return solution
|
|
@@ -275,7 +242,7 @@ def _aten_special_zeta(x, q):
|
|
|
275
242
|
res = jax.scipy.special.zeta(x, q)
|
|
276
243
|
if isinstance(x, int) or isinstance(q, int):
|
|
277
244
|
res = res.astype(new_dtype)
|
|
278
|
-
return res
|
|
245
|
+
return res # jax.scipy.special.zeta(x, q)
|
|
279
246
|
|
|
280
247
|
|
|
281
248
|
# aten.igammac
|
|
@@ -286,7 +253,7 @@ def _aten_igammac(input, other):
|
|
|
286
253
|
if isinstance(other, jnp.ndarray):
|
|
287
254
|
other = jnp.where(other < 0, jnp.nan, other)
|
|
288
255
|
else:
|
|
289
|
-
if (input==0 and other==0) or (input < 0) or (other < 0):
|
|
256
|
+
if (input == 0 and other == 0) or (input < 0) or (other < 0):
|
|
290
257
|
other = jnp.nan
|
|
291
258
|
return jnp.array(jax.scipy.special.gammaincc(input, other))
|
|
292
259
|
|
|
@@ -294,7 +261,7 @@ def _aten_igammac(input, other):
|
|
|
294
261
|
@op(torch.ops.aten.mean)
|
|
295
262
|
def _aten_mean(x, dim=None, keepdim=False):
|
|
296
263
|
if x.shape == () and dim is not None:
|
|
297
|
-
dim = None
|
|
264
|
+
dim = None # disable dim for jax array without dim
|
|
298
265
|
return jnp.mean(x, dim, keepdims=keepdim)
|
|
299
266
|
|
|
300
267
|
|
|
@@ -310,13 +277,14 @@ def _torch_binary_scalar_type(scalar, tensor):
|
|
|
310
277
|
|
|
311
278
|
|
|
312
279
|
@op(torch.ops.aten.searchsorted.Tensor)
|
|
313
|
-
def _aten_searchsorted(sorted_sequence, values):
|
|
280
|
+
def _aten_searchsorted(sorted_sequence, values):
|
|
314
281
|
new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
|
|
315
282
|
res = jnp.searchsorted(sorted_sequence, values)
|
|
316
|
-
if sorted_sequence.dtype == np.dtype(
|
|
283
|
+
if sorted_sequence.dtype == np.dtype(
|
|
284
|
+
np.int32) or sorted_sequence.dtype == np.dtype(np.int32):
|
|
317
285
|
# res = res.astype(new_dtype)
|
|
318
286
|
res = res.astype(np.dtype(np.int64))
|
|
319
|
-
return res
|
|
287
|
+
return res # jnp.searchsorted(sorted_sequence, values)
|
|
320
288
|
|
|
321
289
|
|
|
322
290
|
@op(torch.ops.aten.sub.Tensor)
|
|
@@ -328,7 +296,7 @@ def _aten_sub(x, y, alpha=1):
|
|
|
328
296
|
if isinstance(y, float):
|
|
329
297
|
dtype = _torch_binary_scalar_type(y, x)
|
|
330
298
|
y = jnp.array(y, dtype=dtype)
|
|
331
|
-
return x - y*alpha
|
|
299
|
+
return x - y * alpha
|
|
332
300
|
|
|
333
301
|
|
|
334
302
|
@op(torch.ops.aten.numpy_T)
|
|
@@ -345,7 +313,6 @@ def _aten_numpy_T(input):
|
|
|
345
313
|
return jnp.transpose(input)
|
|
346
314
|
|
|
347
315
|
|
|
348
|
-
|
|
349
316
|
@op(torch.ops.aten.mm)
|
|
350
317
|
def _aten_mm(x, y):
|
|
351
318
|
res = x @ y
|
|
@@ -379,13 +346,15 @@ def _aten_t(x):
|
|
|
379
346
|
@op(torch.ops.aten.transpose)
|
|
380
347
|
@op(torch.ops.aten.transpose_copy)
|
|
381
348
|
def _aten_transpose(x, dim0, dim1):
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
349
|
+
if x.ndim == 0:
|
|
350
|
+
return x
|
|
351
|
+
dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim
|
|
352
|
+
dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim
|
|
353
|
+
return jnp.swapaxes(x, dim0, dim1)
|
|
385
354
|
|
|
386
355
|
|
|
387
356
|
@op(torch.ops.aten.triu)
|
|
388
|
-
def _aten_triu(m, k):
|
|
357
|
+
def _aten_triu(m, k=0):
|
|
389
358
|
return jnp.triu(m, k)
|
|
390
359
|
|
|
391
360
|
|
|
@@ -406,6 +375,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
|
|
|
406
375
|
return self[tuple(dims)]
|
|
407
376
|
|
|
408
377
|
|
|
378
|
+
@op(torch.ops.aten.positive)
|
|
409
379
|
@op(torch.ops.aten.detach)
|
|
410
380
|
def _aten_detach(self):
|
|
411
381
|
return self
|
|
@@ -439,7 +409,8 @@ def _aten_resize_as_(x, y):
|
|
|
439
409
|
|
|
440
410
|
@op(torch.ops.aten.repeat_interleave.Tensor)
|
|
441
411
|
def repeat_interleave(repeats, dim=0):
|
|
442
|
-
return jnp.repeat(
|
|
412
|
+
return jnp.repeat(np.arange(repeats.shape[dim]), repeats)
|
|
413
|
+
|
|
443
414
|
|
|
444
415
|
@op(torch.ops.aten.repeat_interleave.self_int)
|
|
445
416
|
@op(torch.ops.aten.repeat_interleave.self_Tensor)
|
|
@@ -451,12 +422,6 @@ def repeat_interleave(self, repeats, dim=0):
|
|
|
451
422
|
return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length)
|
|
452
423
|
|
|
453
424
|
|
|
454
|
-
# aten.upsample_bilinear2d
|
|
455
|
-
@op(torch.ops.aten.upsample_bilinear2d)
|
|
456
|
-
def _aten_upsample_bilinear2d(x, output_size, align_corners=False, scale_h=None, scale_w=None):
|
|
457
|
-
return _aten_upsample_bilinear2d_aa(x, output_size=output_size, align_corners=align_corners, scale_factors=None, scales_h=scale_h, scales_w=scale_w)
|
|
458
|
-
|
|
459
|
-
|
|
460
425
|
@op(torch.ops.aten.view_as_real)
|
|
461
426
|
def _aten_view_as_real(x):
|
|
462
427
|
real = jnp.real(x)
|
|
@@ -473,7 +438,7 @@ def _aten_stack(tensors, dim=0):
|
|
|
473
438
|
@op(torch.ops.aten._softmax)
|
|
474
439
|
@op(torch.ops.aten.softmax)
|
|
475
440
|
@op(torch.ops.aten.softmax.int)
|
|
476
|
-
def _aten_softmax(x, dim, halftofloat
|
|
441
|
+
def _aten_softmax(x, dim, halftofloat=False):
|
|
477
442
|
if x.shape == ():
|
|
478
443
|
return jax.nn.softmax(x.reshape([1]), axis=0).reshape([])
|
|
479
444
|
return jax.nn.softmax(x, dim)
|
|
@@ -482,10 +447,12 @@ def _aten_softmax(x, dim, halftofloat = False):
|
|
|
482
447
|
def _is_int(x):
|
|
483
448
|
if isinstance(x, int):
|
|
484
449
|
return True
|
|
485
|
-
if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or
|
|
450
|
+
if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or
|
|
451
|
+
x.dtype.name.startswith('uint')):
|
|
486
452
|
return True
|
|
487
453
|
return False
|
|
488
454
|
|
|
455
|
+
|
|
489
456
|
def highest_precision_int_dtype(tensor1, tensor2):
|
|
490
457
|
if isinstance(tensor1, int):
|
|
491
458
|
return tensor2.dtype
|
|
@@ -493,12 +460,20 @@ def highest_precision_int_dtype(tensor1, tensor2):
|
|
|
493
460
|
return tensor1.dtype
|
|
494
461
|
|
|
495
462
|
dtype_hierarchy = {
|
|
496
|
-
'uint8': 8,
|
|
497
|
-
'
|
|
498
|
-
'
|
|
499
|
-
'
|
|
463
|
+
'uint8': 8,
|
|
464
|
+
'int8': 8,
|
|
465
|
+
'uint16': 16,
|
|
466
|
+
'int16': 16,
|
|
467
|
+
'uint32': 32,
|
|
468
|
+
'int32': 32,
|
|
469
|
+
'uint64': 64,
|
|
470
|
+
'int64': 64,
|
|
500
471
|
}
|
|
501
|
-
return max(
|
|
472
|
+
return max(
|
|
473
|
+
tensor1.dtype,
|
|
474
|
+
tensor2.dtype,
|
|
475
|
+
key=lambda dtype: dtype_hierarchy[str(dtype)])
|
|
476
|
+
|
|
502
477
|
|
|
503
478
|
@op(torch.ops.aten.pow)
|
|
504
479
|
def _aten_pow(x, y):
|
|
@@ -553,11 +528,13 @@ def _aten_div(x, y, rounding_mode=""):
|
|
|
553
528
|
def _aten_true_divide(x, y):
|
|
554
529
|
return x / y
|
|
555
530
|
|
|
531
|
+
|
|
556
532
|
@op(torch.ops.aten.dist)
|
|
557
533
|
def _aten_dist(input, other, p=2):
|
|
558
534
|
diff = jnp.abs(jnp.subtract(input, other))
|
|
559
535
|
return _aten_linalg_vector_norm(diff, ord=p)
|
|
560
536
|
|
|
537
|
+
|
|
561
538
|
@op(torch.ops.aten.bmm)
|
|
562
539
|
def _aten_bmm(x, y):
|
|
563
540
|
res = x @ y
|
|
@@ -567,9 +544,14 @@ def _aten_bmm(x, y):
|
|
|
567
544
|
|
|
568
545
|
@op(torch.ops.aten.embedding)
|
|
569
546
|
# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
|
|
570
|
-
def _aten_embedding(a,
|
|
547
|
+
def _aten_embedding(a,
|
|
548
|
+
w,
|
|
549
|
+
padding_idx=-1,
|
|
550
|
+
scale_grad_by_freq=False,
|
|
551
|
+
sparse=False):
|
|
571
552
|
return jnp.take(a, w, axis=0)
|
|
572
553
|
|
|
554
|
+
|
|
573
555
|
@op(torch.ops.aten.embedding_renorm_)
|
|
574
556
|
def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
|
|
575
557
|
# Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
|
|
@@ -587,27 +569,26 @@ def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
|
|
|
587
569
|
|
|
588
570
|
indices_to_update = unique_indices[indice_idx]
|
|
589
571
|
|
|
590
|
-
weight = weight.at[indices_to_update].set(
|
|
591
|
-
|
|
592
|
-
)
|
|
572
|
+
weight = weight.at[indices_to_update].set(weight[indices_to_update] *
|
|
573
|
+
scale[:, None])
|
|
593
574
|
return weight
|
|
594
575
|
|
|
576
|
+
|
|
595
577
|
#- func: _embedding_bag_forward_only(
|
|
596
578
|
# Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
|
|
597
579
|
# int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
|
|
598
580
|
@op(torch.ops.aten._embedding_bag)
|
|
599
581
|
@op(torch.ops.aten._embedding_bag_forward_only)
|
|
600
|
-
def _aten__embedding_bag(
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
"""Jax implementation of the PyTorch _embedding_bag function.
|
|
582
|
+
def _aten__embedding_bag(weight,
|
|
583
|
+
indices,
|
|
584
|
+
offsets=None,
|
|
585
|
+
scale_grad_by_freq=False,
|
|
586
|
+
mode=0,
|
|
587
|
+
sparse=False,
|
|
588
|
+
per_sample_weights=None,
|
|
589
|
+
include_last_offset=False,
|
|
590
|
+
padding_idx=-1):
|
|
591
|
+
"""Jax implementation of the PyTorch _embedding_bag function.
|
|
611
592
|
|
|
612
593
|
Args:
|
|
613
594
|
weight: The learnable weights of the module of shape (num_embeddings, embedding_dim).
|
|
@@ -623,48 +604,50 @@ def _aten__embedding_bag(
|
|
|
623
604
|
Returns:
|
|
624
605
|
A tuple of (output, offset2bag, bag_size, max_indices).
|
|
625
606
|
"""
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
607
|
+
embedded = _aten_embedding(weight, indices, padding_idx)
|
|
608
|
+
|
|
609
|
+
if offsets is None:
|
|
610
|
+
# offsets is None only when indices.ndim > 1
|
|
611
|
+
if mode == 0: # sum
|
|
612
|
+
output = jnp.sum(embedded, axis=1)
|
|
613
|
+
elif mode == 1: # mean
|
|
614
|
+
output = jnp.mean(embedded, axis=1)
|
|
615
|
+
elif mode == 2: # max
|
|
616
|
+
output = jnp.max(embedded, axis=1)
|
|
617
|
+
return output, None, None, None
|
|
618
|
+
|
|
619
|
+
if isinstance(offsets, jax.Array):
|
|
620
|
+
offsets_np = np.array(offsets)
|
|
621
|
+
else:
|
|
622
|
+
offsets_np = offsets
|
|
623
|
+
offset2bag = np.zeros(indices.shape[0], dtype=np.int64)
|
|
624
|
+
bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64)
|
|
625
|
+
max_indices = jnp.full_like(indices, -1)
|
|
645
626
|
|
|
646
|
-
|
|
647
|
-
|
|
627
|
+
for bag in range(offsets_np.shape[0]):
|
|
628
|
+
start = int(offsets_np[bag])
|
|
648
629
|
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
630
|
+
end = int(indices.shape[0] if bag +
|
|
631
|
+
1 == offsets_np.shape[0] else offsets_np[bag + 1])
|
|
632
|
+
bag_size[bag] = end - start
|
|
633
|
+
offset2bag = offset2bag.at[start:end].set(bag)
|
|
652
634
|
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
635
|
+
if end - start > 0:
|
|
636
|
+
if mode == 0:
|
|
637
|
+
output_bag = jnp.sum(embedded[start:end], axis=0)
|
|
638
|
+
elif mode == 1:
|
|
639
|
+
output_bag = jnp.mean(embedded[start:end], axis=0)
|
|
640
|
+
elif mode == 2:
|
|
641
|
+
output_bag = jnp.max(embedded[start:end], axis=0)
|
|
642
|
+
max_indices = max_indices.at[start:end].set(
|
|
643
|
+
jnp.argmax(embedded[start:end], axis=0))
|
|
661
644
|
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
645
|
+
# The original code returned offset2bag, bag_size, and max_indices as numpy arrays.
|
|
646
|
+
# Converting them to JAX arrays for consistency.
|
|
647
|
+
offset2bag = jnp.array(offset2bag)
|
|
648
|
+
bag_size = jnp.array(bag_size)
|
|
666
649
|
|
|
667
|
-
|
|
650
|
+
return output_bag, offset2bag, bag_size, max_indices
|
|
668
651
|
|
|
669
652
|
|
|
670
653
|
@op(torch.ops.aten.rsqrt)
|
|
@@ -676,6 +659,7 @@ def _aten_rsqrt(x):
|
|
|
676
659
|
@op(torch.ops.aten.expand)
|
|
677
660
|
@op(torch.ops.aten.expand_copy)
|
|
678
661
|
def _aten_expand(x, dims):
|
|
662
|
+
|
|
679
663
|
def fix_dims(d, xs):
|
|
680
664
|
if d == -1:
|
|
681
665
|
return xs
|
|
@@ -683,7 +667,9 @@ def _aten_expand(x, dims):
|
|
|
683
667
|
|
|
684
668
|
shape = list(x.shape)
|
|
685
669
|
if len(shape) < len(dims):
|
|
686
|
-
shape = [
|
|
670
|
+
shape = [
|
|
671
|
+
1,
|
|
672
|
+
] * (len(dims) - len(shape)) + shape
|
|
687
673
|
# make sure that dims and shape is the same by
|
|
688
674
|
# left pad with 1s. Otherwise the zip below will
|
|
689
675
|
# truncate
|
|
@@ -705,15 +691,15 @@ def _aten__to_copy(self, **kwargs):
|
|
|
705
691
|
|
|
706
692
|
|
|
707
693
|
@op(torch.ops.aten.empty)
|
|
708
|
-
@op_base.convert_dtype()
|
|
694
|
+
@op_base.convert_dtype(use_default_dtype=False)
|
|
709
695
|
def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs):
|
|
710
696
|
return jnp.empty(size, dtype=dtype)
|
|
711
697
|
|
|
712
698
|
|
|
713
699
|
@op(torch.ops.aten.empty_like)
|
|
714
|
-
@op_base.convert_dtype()
|
|
700
|
+
@op_base.convert_dtype(use_default_dtype=False)
|
|
715
701
|
def _aten_empty_like(input, *, dtype=None, **kwargs):
|
|
716
|
-
return jnp.empty_like(input, dtype
|
|
702
|
+
return jnp.empty_like(input, dtype)
|
|
717
703
|
|
|
718
704
|
|
|
719
705
|
@op(torch.ops.aten.ones)
|
|
@@ -784,8 +770,8 @@ def split_with_sizes(x, sizes, dim=0):
|
|
|
784
770
|
A list of sub-arrays.
|
|
785
771
|
"""
|
|
786
772
|
if isinstance(sizes, int):
|
|
787
|
-
# split equal size
|
|
788
|
-
new_sizes = [sizes] * (x.shape[dim] // sizes)
|
|
773
|
+
# split equal size, round up
|
|
774
|
+
new_sizes = [sizes] * (-(-x.shape[dim] // sizes))
|
|
789
775
|
sizes = new_sizes
|
|
790
776
|
rank = x.ndim
|
|
791
777
|
splits = np.cumsum(sizes) # Cumulative sum for split points
|
|
@@ -796,14 +782,15 @@ def split_with_sizes(x, sizes, dim=0):
|
|
|
796
782
|
return tuple(res)
|
|
797
783
|
|
|
798
784
|
return [
|
|
799
|
-
|
|
800
|
-
|
|
785
|
+
x[make_range(rank, dim, start, end)]
|
|
786
|
+
for start, end in zip([0] + list(splits[:-1]), splits)
|
|
801
787
|
]
|
|
802
788
|
|
|
803
789
|
|
|
804
790
|
@op(torch.ops.aten.permute)
|
|
805
791
|
@op(torch.ops.aten.permute_copy)
|
|
806
792
|
def permute(t, dims):
|
|
793
|
+
# TODO: return a View instead
|
|
807
794
|
return jnp.transpose(t, dims)
|
|
808
795
|
|
|
809
796
|
|
|
@@ -819,6 +806,7 @@ def _aten_unsqueeze(self, dim):
|
|
|
819
806
|
def _aten_ne(x, y):
|
|
820
807
|
return jnp.not_equal(x, y)
|
|
821
808
|
|
|
809
|
+
|
|
822
810
|
# Create indices along a specific axis
|
|
823
811
|
#
|
|
824
812
|
# For example
|
|
@@ -832,14 +820,12 @@ def _aten_ne(x, y):
|
|
|
832
820
|
def _indices_along_axis(x, axis):
|
|
833
821
|
return jnp.expand_dims(
|
|
834
822
|
jnp.arange(x.shape[axis]),
|
|
835
|
-
axis
|
|
836
|
-
|
|
823
|
+
axis=[d for d in range(len(x.shape)) if d != axis])
|
|
824
|
+
|
|
837
825
|
|
|
838
826
|
def _broadcast_indices(indices, shape):
|
|
839
|
-
return jnp.broadcast_to(
|
|
840
|
-
|
|
841
|
-
shape
|
|
842
|
-
)
|
|
827
|
+
return jnp.broadcast_to(indices, shape)
|
|
828
|
+
|
|
843
829
|
|
|
844
830
|
@op(torch.ops.aten.cummax)
|
|
845
831
|
def _aten_cummax(x, dim):
|
|
@@ -851,36 +837,45 @@ def _aten_cummax(x, dim):
|
|
|
851
837
|
indice_along_axis = _indices_along_axis(x, axis)
|
|
852
838
|
indices = _broadcast_indices(indice_along_axis, x.shape)
|
|
853
839
|
|
|
854
|
-
|
|
855
840
|
def cummax_reduce_func(carry, elem):
|
|
856
|
-
v1, v2 = carry['val'], elem['val']
|
|
841
|
+
v1, v2 = carry['val'], elem['val']
|
|
857
842
|
i1, i2 = carry['idx'], elem['idx']
|
|
858
843
|
|
|
859
844
|
v = jnp.maximum(v1, v2)
|
|
860
845
|
i = jnp.where(v1 > v2, i1, i2)
|
|
861
846
|
return {'val': v, 'idx': i}
|
|
862
|
-
|
|
847
|
+
|
|
848
|
+
res = jax.lax.associative_scan(
|
|
849
|
+
cummax_reduce_func, {
|
|
850
|
+
'val': x,
|
|
851
|
+
'idx': indices
|
|
852
|
+
}, axis=axis)
|
|
863
853
|
return res['val'], res['idx']
|
|
864
854
|
|
|
855
|
+
|
|
865
856
|
@op(torch.ops.aten.cummin)
|
|
866
857
|
def _aten_cummin(x, dim):
|
|
867
858
|
if not x.shape:
|
|
868
859
|
return x, jnp.zeros_like(x, dtype=jnp.int64)
|
|
869
|
-
|
|
860
|
+
|
|
870
861
|
axis = dim
|
|
871
862
|
|
|
872
863
|
indice_along_axis = _indices_along_axis(x, axis)
|
|
873
864
|
indices = _broadcast_indices(indice_along_axis, x.shape)
|
|
874
865
|
|
|
875
866
|
def cummin_reduce_func(carry, elem):
|
|
876
|
-
v1, v2 = carry['val'], elem['val']
|
|
867
|
+
v1, v2 = carry['val'], elem['val']
|
|
877
868
|
i1, i2 = carry['idx'], elem['idx']
|
|
878
869
|
|
|
879
870
|
v = jnp.minimum(v1, v2)
|
|
880
871
|
i = jnp.where(v1 < v2, i1, i2)
|
|
881
872
|
return {'val': v, 'idx': i}
|
|
882
873
|
|
|
883
|
-
res = jax.lax.associative_scan(
|
|
874
|
+
res = jax.lax.associative_scan(
|
|
875
|
+
cummin_reduce_func, {
|
|
876
|
+
'val': x,
|
|
877
|
+
'idx': indices
|
|
878
|
+
}, axis=axis)
|
|
884
879
|
return res['val'], res['idx']
|
|
885
880
|
|
|
886
881
|
|
|
@@ -908,9 +903,11 @@ def _aten_cumprod(input, dim, dtype=None, out=None):
|
|
|
908
903
|
|
|
909
904
|
|
|
910
905
|
@op(torch.ops.aten.native_layer_norm)
|
|
911
|
-
def _aten_native_layer_norm(
|
|
912
|
-
|
|
913
|
-
|
|
906
|
+
def _aten_native_layer_norm(input,
|
|
907
|
+
normalized_shape,
|
|
908
|
+
weight=None,
|
|
909
|
+
bias=None,
|
|
910
|
+
eps=1e-5):
|
|
914
911
|
"""Implements layer normalization in Jax as defined by `aten::native_layer_norm`.
|
|
915
912
|
|
|
916
913
|
Args:
|
|
@@ -944,7 +941,7 @@ def _aten_native_layer_norm(
|
|
|
944
941
|
norm_x += bias
|
|
945
942
|
return norm_x, mean, rstd
|
|
946
943
|
|
|
947
|
-
|
|
944
|
+
|
|
948
945
|
@op(torch.ops.aten.matmul)
|
|
949
946
|
def _aten_matmul(x, y):
|
|
950
947
|
return x @ y
|
|
@@ -960,6 +957,7 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
|
|
|
960
957
|
self += alpha * jnp.matmul(mat1, mat2)
|
|
961
958
|
return self
|
|
962
959
|
|
|
960
|
+
|
|
963
961
|
@op(torch.ops.aten.sparse_sampled_addmm)
|
|
964
962
|
def _aten_sparse_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
|
|
965
963
|
alpha = jnp.array(alpha).astype(mat1.dtype)
|
|
@@ -974,9 +972,8 @@ def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
|
|
|
974
972
|
alpha = jnp.array(alpha).astype(batch1.dtype)
|
|
975
973
|
beta = jnp.array(beta).astype(batch1.dtype)
|
|
976
974
|
mm = jnp.einsum("bxy, byz -> xz", batch1, batch2)
|
|
977
|
-
return jax.lax.cond(
|
|
978
|
-
|
|
979
|
-
)
|
|
975
|
+
return jax.lax.cond(beta == 0, lambda: alpha * mm,
|
|
976
|
+
lambda: beta * input + alpha * mm)
|
|
980
977
|
|
|
981
978
|
|
|
982
979
|
@op(torch.ops.aten.gelu)
|
|
@@ -987,73 +984,69 @@ def _aten_gelu(self, *, approximate="none"):
|
|
|
987
984
|
|
|
988
985
|
@op(torch.ops.aten.squeeze)
|
|
989
986
|
@op(torch.ops.aten.squeeze_copy)
|
|
990
|
-
def _aten_squeeze_dim(self, dim):
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
Args:
|
|
994
|
-
self: The input tensor.
|
|
995
|
-
dim: The dimension to squeeze.
|
|
996
|
-
|
|
997
|
-
Returns:
|
|
998
|
-
The squeezed tensor with the specified dimension removed if it is 1,
|
|
999
|
-
otherwise the original tensor is returned.
|
|
1000
|
-
"""
|
|
1001
|
-
|
|
1002
|
-
# Validate input arguments
|
|
1003
|
-
if not isinstance(self, jnp.ndarray):
|
|
1004
|
-
raise TypeError(f"Expected a Jax tensor, got {type(self)}.")
|
|
1005
|
-
if isinstance(dim, int):
|
|
1006
|
-
dim = [dim]
|
|
1007
|
-
|
|
1008
|
-
# Check if the specified dimension has size 1
|
|
1009
|
-
if (len(self.shape) == 0) or all([self.shape[d] != 1 for d in dim]):
|
|
987
|
+
def _aten_squeeze_dim(self, dim=None):
|
|
988
|
+
if self.ndim == 0:
|
|
1010
989
|
return self
|
|
990
|
+
if dim is not None:
|
|
991
|
+
if isinstance(dim, int):
|
|
992
|
+
if self.shape[dim] != 1:
|
|
993
|
+
return self
|
|
994
|
+
if dim < 0:
|
|
995
|
+
dim += self.ndim
|
|
996
|
+
else:
|
|
997
|
+
# NOTE: torch leaves the dims that is not 1 unchanged,
|
|
998
|
+
# but jax raises error.
|
|
999
|
+
dim = [
|
|
1000
|
+
i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1
|
|
1001
|
+
]
|
|
1011
1002
|
|
|
1012
|
-
|
|
1013
|
-
new_shape = list(self.shape)
|
|
1014
|
-
|
|
1015
|
-
def fix_dim(p):
|
|
1016
|
-
if p < 0:
|
|
1017
|
-
return p + len(self.shape)
|
|
1018
|
-
return p
|
|
1003
|
+
return jnp.squeeze(self, dim)
|
|
1019
1004
|
|
|
1020
|
-
dim = [fix_dim(d) for d in dim]
|
|
1021
|
-
new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
|
|
1022
|
-
return self.reshape(new_shape)
|
|
1023
1005
|
|
|
1024
1006
|
@op(torch.ops.aten.bucketize)
|
|
1025
|
-
def _aten_bucketize(input,
|
|
1026
|
-
|
|
1007
|
+
def _aten_bucketize(input,
|
|
1008
|
+
boundaries,
|
|
1009
|
+
*,
|
|
1010
|
+
out_int32=False,
|
|
1011
|
+
right=False,
|
|
1012
|
+
out=None):
|
|
1027
1013
|
return_type = jnp.int32 if out_int32 else jnp.int64
|
|
1028
1014
|
return jnp.digitize(input, boundaries, right=not right).astype(return_type)
|
|
1029
1015
|
|
|
1030
1016
|
|
|
1031
1017
|
@op(torch.ops.aten.conv2d)
|
|
1032
1018
|
def _aten_conv2d(
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1019
|
+
input,
|
|
1020
|
+
weight,
|
|
1021
|
+
bias,
|
|
1022
|
+
stride,
|
|
1023
|
+
padding,
|
|
1024
|
+
dilation,
|
|
1025
|
+
groups,
|
|
1040
1026
|
):
|
|
1041
1027
|
return _aten_convolution(
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1028
|
+
input,
|
|
1029
|
+
weight,
|
|
1030
|
+
bias,
|
|
1031
|
+
stride,
|
|
1032
|
+
padding,
|
|
1033
|
+
dilation,
|
|
1034
|
+
transposed=False,
|
|
1035
|
+
output_padding=1,
|
|
1036
|
+
groups=groups)
|
|
1037
|
+
|
|
1045
1038
|
|
|
1046
1039
|
@op(torch.ops.aten.convolution)
|
|
1047
1040
|
def _aten_convolution(
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1041
|
+
input,
|
|
1042
|
+
weight,
|
|
1043
|
+
bias,
|
|
1044
|
+
stride,
|
|
1045
|
+
padding,
|
|
1046
|
+
dilation,
|
|
1047
|
+
transposed,
|
|
1048
|
+
output_padding,
|
|
1049
|
+
groups,
|
|
1057
1050
|
):
|
|
1058
1051
|
num_shape_dim = weight.ndim - 1
|
|
1059
1052
|
batch_dims = input.shape[:-num_shape_dim]
|
|
@@ -1068,7 +1061,7 @@ def _aten_convolution(
|
|
|
1068
1061
|
# See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
|
|
1069
1062
|
pad_out = []
|
|
1070
1063
|
for i in range(num_spatial_dims):
|
|
1071
|
-
front = dilation[i] * (weight.shape[i+2] - 1) - padding[i]
|
|
1064
|
+
front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i]
|
|
1072
1065
|
back = front + output_padding[i]
|
|
1073
1066
|
pad_out.append((front, back))
|
|
1074
1067
|
return pad_out
|
|
@@ -1089,39 +1082,38 @@ def _aten_convolution(
|
|
|
1089
1082
|
rhs_spec.append(i + 2)
|
|
1090
1083
|
out_spec.append(i + 2)
|
|
1091
1084
|
return jax.lax.ConvDimensionNumbers(
|
|
1092
|
-
|
|
1093
|
-
)
|
|
1085
|
+
*map(tuple, (lhs_spec, rhs_spec, out_spec)))
|
|
1094
1086
|
|
|
1095
1087
|
if transposed:
|
|
1096
|
-
rhs = jnp.flip(weight, range(2, 1+num_shape_dim))
|
|
1088
|
+
rhs = jnp.flip(weight, range(2, 1 + num_shape_dim))
|
|
1097
1089
|
if groups != 1:
|
|
1098
1090
|
# reshape filters for tranposed depthwise convolution
|
|
1099
1091
|
assert rhs.shape[0] % groups == 0
|
|
1100
|
-
rhs_shape = [rhs.shape[0]//groups, rhs.shape[1]*groups]
|
|
1092
|
+
rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups]
|
|
1101
1093
|
rhs_shape.extend(rhs.shape[2:])
|
|
1102
1094
|
rhs = jnp.reshape(rhs, rhs_shape)
|
|
1103
1095
|
res = jax.lax.conv_general_dilated(
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1096
|
+
input,
|
|
1097
|
+
rhs,
|
|
1098
|
+
(1,) * len(stride),
|
|
1099
|
+
make_padding(padding, len(stride)),
|
|
1100
|
+
lhs_dilation=stride,
|
|
1101
|
+
rhs_dilation=dilation,
|
|
1102
|
+
dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
|
|
1103
|
+
feature_group_count=groups,
|
|
1104
|
+
batch_group_count=1,
|
|
1113
1105
|
)
|
|
1114
1106
|
else:
|
|
1115
1107
|
res = jax.lax.conv_general_dilated(
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1108
|
+
input,
|
|
1109
|
+
weight,
|
|
1110
|
+
stride,
|
|
1111
|
+
make_padding(padding, len(stride)),
|
|
1112
|
+
lhs_dilation=(1,) * len(stride),
|
|
1113
|
+
rhs_dilation=dilation,
|
|
1114
|
+
dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
|
|
1115
|
+
feature_group_count=groups,
|
|
1116
|
+
batch_group_count=1,
|
|
1125
1117
|
)
|
|
1126
1118
|
|
|
1127
1119
|
if bias is not None:
|
|
@@ -1137,10 +1129,9 @@ def _aten_convolution(
|
|
|
1137
1129
|
|
|
1138
1130
|
|
|
1139
1131
|
# _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps)
|
|
1140
|
-
@op(torch.ops.aten._native_batch_norm_legit)
|
|
1141
|
-
def _aten__native_batch_norm_legit(
|
|
1142
|
-
|
|
1143
|
-
):
|
|
1132
|
+
@op(torch.ops.aten._native_batch_norm_legit.default)
|
|
1133
|
+
def _aten__native_batch_norm_legit(input, weight, bias, running_mean,
|
|
1134
|
+
running_var, training, momentum, eps):
|
|
1144
1135
|
"""JAX implementation of batch normalization with optional parameters.
|
|
1145
1136
|
Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.
|
|
1146
1137
|
|
|
@@ -1161,8 +1152,7 @@ def _aten__native_batch_norm_legit(
|
|
|
1161
1152
|
DeviceArray: Reversed batch variance (C,) or empty if training is False
|
|
1162
1153
|
"""
|
|
1163
1154
|
reduction_dims = [0] + list(range(2, input.ndim))
|
|
1164
|
-
reshape_dims = [1, -1] + [1]*(input.ndim-2)
|
|
1165
|
-
|
|
1155
|
+
reshape_dims = [1, -1] + [1] * (input.ndim - 2)
|
|
1166
1156
|
if training:
|
|
1167
1157
|
# Calculate batch mean and variance
|
|
1168
1158
|
mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
|
|
@@ -1175,7 +1165,9 @@ def _aten__native_batch_norm_legit(
|
|
|
1175
1165
|
saved_rstd = jnp.squeeze(rstd, reduction_dims)
|
|
1176
1166
|
else:
|
|
1177
1167
|
rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
|
|
1178
|
-
saved_mean = jnp.array(
|
|
1168
|
+
saved_mean = jnp.array(
|
|
1169
|
+
[], dtype=input.dtype
|
|
1170
|
+
) # No need to calculate batch statistics in inference mode
|
|
1179
1171
|
saved_rstd = jnp.array([], dtype=input.dtype)
|
|
1180
1172
|
|
|
1181
1173
|
# Normalize
|
|
@@ -1190,19 +1182,17 @@ def _aten__native_batch_norm_legit(
|
|
|
1190
1182
|
if weight is not None:
|
|
1191
1183
|
x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting
|
|
1192
1184
|
if bias is not None:
|
|
1193
|
-
x_hat += bias.reshape(reshape_dims)
|
|
1185
|
+
x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting
|
|
1194
1186
|
|
|
1195
1187
|
return x_hat, saved_mean, saved_rstd
|
|
1196
1188
|
|
|
1197
1189
|
|
|
1198
|
-
|
|
1199
1190
|
@op(torch.ops.aten._native_batch_norm_legit_no_training)
|
|
1200
|
-
def _aten__native_batch_norm_legit_no_training(
|
|
1201
|
-
|
|
1202
|
-
):
|
|
1203
|
-
return _aten__native_batch_norm_legit(
|
|
1204
|
-
|
|
1205
|
-
)
|
|
1191
|
+
def _aten__native_batch_norm_legit_no_training(input, weight, bias,
|
|
1192
|
+
running_mean, running_var,
|
|
1193
|
+
momentum, eps):
|
|
1194
|
+
return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
|
|
1195
|
+
running_var, False, momentum, eps)
|
|
1206
1196
|
|
|
1207
1197
|
|
|
1208
1198
|
@op(torch.ops.aten.relu)
|
|
@@ -1212,7 +1202,15 @@ def _aten_relu(self):
|
|
|
1212
1202
|
|
|
1213
1203
|
@op(torch.ops.aten.cat)
|
|
1214
1204
|
def _aten_cat(tensors, dims=0):
|
|
1215
|
-
|
|
1205
|
+
# handle empty tensors as a special case.
|
|
1206
|
+
# torch.cat will ignore the empty tensor, while jnp.concatenate
|
|
1207
|
+
# will error if the dims > 0.
|
|
1208
|
+
filtered_tensors = [
|
|
1209
|
+
t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0)
|
|
1210
|
+
]
|
|
1211
|
+
if filtered_tensors:
|
|
1212
|
+
return jnp.concatenate(filtered_tensors, dims)
|
|
1213
|
+
return tensors[0]
|
|
1216
1214
|
|
|
1217
1215
|
|
|
1218
1216
|
def _ceil_mode_padding(
|
|
@@ -1220,6 +1218,7 @@ def _ceil_mode_padding(
|
|
|
1220
1218
|
input_shape: list[int],
|
|
1221
1219
|
kernel_size: list[int],
|
|
1222
1220
|
stride: list[int],
|
|
1221
|
+
dilation: list[int],
|
|
1223
1222
|
ceil_mode: bool,
|
|
1224
1223
|
):
|
|
1225
1224
|
"""Creates low and high padding specification for the given padding (which is symmetric) and ceil mode.
|
|
@@ -1232,20 +1231,13 @@ def _ceil_mode_padding(
|
|
|
1232
1231
|
right_padding = left_padding
|
|
1233
1232
|
|
|
1234
1233
|
input_size = input_shape[2 + i]
|
|
1235
|
-
output_size_rem = (input_size + 2 * left_padding -
|
|
1236
|
-
|
|
1237
|
-
]
|
|
1234
|
+
output_size_rem = (input_size + 2 * left_padding -
|
|
1235
|
+
(kernel_size[i] - 1) * dilation[i] - 1) % stride[i]
|
|
1238
1236
|
if ceil_mode and output_size_rem != 0:
|
|
1239
1237
|
extra_padding = stride[i] - output_size_rem
|
|
1240
|
-
new_output_size = (
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
+ right_padding
|
|
1244
|
-
+ extra_padding
|
|
1245
|
-
- kernel_size[i]
|
|
1246
|
-
+ stride[i]
|
|
1247
|
-
- 1
|
|
1248
|
-
) // stride[i] + 1
|
|
1238
|
+
new_output_size = (input_size + left_padding + right_padding +
|
|
1239
|
+
extra_padding - (kernel_size[i] - 1) * dilation[i] -
|
|
1240
|
+
1 + stride[i] - 1) // stride[i] + 1
|
|
1249
1241
|
# Ensure that the last pooling starts inside the image.
|
|
1250
1242
|
size_to_compare = input_size + left_padding
|
|
1251
1243
|
|
|
@@ -1258,30 +1250,36 @@ def _ceil_mode_padding(
|
|
|
1258
1250
|
|
|
1259
1251
|
@op(torch.ops.aten.max_pool2d_with_indices)
|
|
1260
1252
|
@op(torch.ops.aten.max_pool3d_with_indices)
|
|
1261
|
-
def _aten_max_pool2d_with_indices(
|
|
1262
|
-
|
|
1263
|
-
|
|
1253
|
+
def _aten_max_pool2d_with_indices(inputs,
|
|
1254
|
+
kernel_size,
|
|
1255
|
+
strides=None,
|
|
1256
|
+
padding=0,
|
|
1257
|
+
dilation=1,
|
|
1258
|
+
ceil_mode=False):
|
|
1264
1259
|
num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
|
|
1265
1260
|
kernel_size = tuple(kernel_size)
|
|
1266
|
-
|
|
1261
|
+
# Default stride is kernel_size
|
|
1262
|
+
strides = tuple(strides) if strides else kernel_size
|
|
1267
1263
|
if isinstance(padding, int):
|
|
1268
1264
|
padding = [padding for _ in range(len(kernel_size))]
|
|
1265
|
+
if isinstance(dilation, int):
|
|
1266
|
+
dilation = tuple(dilation for _ in range(len(kernel_size)))
|
|
1267
|
+
elif isinstance(dilation, list):
|
|
1268
|
+
dilation = tuple(dilation)
|
|
1269
1269
|
|
|
1270
1270
|
input_shape = inputs.shape
|
|
1271
1271
|
if num_batch_dims == 0:
|
|
1272
1272
|
input_shape = [1, *input_shape]
|
|
1273
|
-
padding = _ceil_mode_padding(
|
|
1274
|
-
|
|
1275
|
-
)
|
|
1273
|
+
padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
|
|
1274
|
+
dilation, ceil_mode)
|
|
1276
1275
|
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
strides
|
|
1282
|
-
), f"len({window_shape}) must equal len({strides})"
|
|
1276
|
+
assert len(kernel_size) == len(
|
|
1277
|
+
strides), f"len({kernel_size=}) must equal len({strides=})"
|
|
1278
|
+
assert len(kernel_size) == len(
|
|
1279
|
+
dilation), f"len({kernel_size=}) must equal len({dilation=})"
|
|
1283
1280
|
strides = (1,) * (1 + num_batch_dims) + strides
|
|
1284
|
-
dims = (1,) * (1 + num_batch_dims) +
|
|
1281
|
+
dims = (1,) * (1 + num_batch_dims) + kernel_size
|
|
1282
|
+
dilation = (1,) * (1 + num_batch_dims) + dilation
|
|
1285
1283
|
|
|
1286
1284
|
is_single_input = False
|
|
1287
1285
|
if num_batch_dims == 0:
|
|
@@ -1290,26 +1288,27 @@ def _aten_max_pool2d_with_indices(
|
|
|
1290
1288
|
inputs = inputs[None]
|
|
1291
1289
|
strides = (1,) + strides
|
|
1292
1290
|
dims = (1,) + dims
|
|
1291
|
+
dilation = (1,) + dilation
|
|
1293
1292
|
is_single_input = True
|
|
1294
1293
|
|
|
1295
1294
|
assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
|
|
1296
1295
|
if not isinstance(padding, str):
|
|
1297
1296
|
padding = tuple(map(tuple, padding))
|
|
1298
|
-
assert len(padding) == len(
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
)
|
|
1302
|
-
|
|
1303
|
-
[len(x) == 2 for x in padding]
|
|
1304
|
-
), f"each entry in padding {padding} must be length 2"
|
|
1297
|
+
assert len(padding) == len(kernel_size), (
|
|
1298
|
+
f"padding {padding} must specify pads for same number of dims as "
|
|
1299
|
+
f"kernel_size {kernel_size}")
|
|
1300
|
+
assert all([len(x) == 2 for x in padding
|
|
1301
|
+
]), f"each entry in padding {padding} must be length 2"
|
|
1305
1302
|
padding = ((0, 0), (0, 0)) + padding
|
|
1306
1303
|
|
|
1307
|
-
indices = jnp.arange(np.prod(inputs.shape))
|
|
1304
|
+
indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size):]))
|
|
1305
|
+
indices = indices.reshape(inputs.shape[-len(kernel_size):])
|
|
1306
|
+
indices = jnp.broadcast_to(indices, inputs.shape)
|
|
1308
1307
|
|
|
1309
1308
|
def reduce_fn(a, b):
|
|
1310
1309
|
ai, av = a
|
|
1311
1310
|
bi, bv = b
|
|
1312
|
-
which = av
|
|
1311
|
+
which = av >= bv # torch breaks ties in favor of later indices
|
|
1313
1312
|
return jnp.where(which, ai, bi), jnp.where(which, av, bv)
|
|
1314
1313
|
|
|
1315
1314
|
init_val = -jnp.inf
|
|
@@ -1317,44 +1316,90 @@ def _aten_max_pool2d_with_indices(
|
|
|
1317
1316
|
init_val = -(1 << 31)
|
|
1318
1317
|
init_val = jnp.array(init_val).astype(inputs.dtype)
|
|
1319
1318
|
|
|
1320
|
-
# Separate maxpool result and indices into two reduce_window ops. Since
|
|
1321
|
-
# the indices tensor is usually unused in inference, separating the two
|
|
1319
|
+
# Separate maxpool result and indices into two reduce_window ops. Since
|
|
1320
|
+
# the indices tensor is usually unused in inference, separating the two
|
|
1322
1321
|
# can help DCE computations for argmax.
|
|
1323
1322
|
y = jax.lax.reduce_window(
|
|
1324
|
-
inputs,
|
|
1325
|
-
|
|
1323
|
+
inputs,
|
|
1324
|
+
init_val,
|
|
1325
|
+
jax.lax.max,
|
|
1326
|
+
dims,
|
|
1327
|
+
strides,
|
|
1328
|
+
padding,
|
|
1329
|
+
window_dilation=dilation)
|
|
1326
1330
|
indices, _ = jax.lax.reduce_window(
|
|
1327
|
-
(indices, inputs),
|
|
1331
|
+
(indices, inputs),
|
|
1332
|
+
(0, init_val),
|
|
1333
|
+
reduce_fn,
|
|
1334
|
+
dims,
|
|
1335
|
+
strides,
|
|
1336
|
+
padding,
|
|
1337
|
+
window_dilation=dilation,
|
|
1328
1338
|
)
|
|
1329
1339
|
if is_single_input:
|
|
1330
1340
|
indices = jnp.squeeze(indices, axis=0)
|
|
1331
1341
|
y = jnp.squeeze(y, axis=0)
|
|
1332
|
-
|
|
1342
|
+
|
|
1333
1343
|
return y, indices
|
|
1334
1344
|
|
|
1335
1345
|
|
|
1346
|
+
# Aten ops registered under the `xla` library.
|
|
1347
|
+
try:
|
|
1348
|
+
|
|
1349
|
+
@op(torch.ops.xla.max_pool2d_forward)
|
|
1350
|
+
def _xla_max_pool2d_forward(*args, **kwargs):
|
|
1351
|
+
return _aten_max_pool2d_with_indices(*args, **kwargs)[0]
|
|
1352
|
+
|
|
1353
|
+
@op(torch.ops.xla.aot_mark_sharding)
|
|
1354
|
+
def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str):
|
|
1355
|
+
from jax.sharding import PartitionSpec as P, NamedSharding
|
|
1356
|
+
import ast
|
|
1357
|
+
import torch_xla.distributed.spmd as xs
|
|
1358
|
+
pmesh = xs.Mesh.from_str(mesh)
|
|
1359
|
+
assert pmesh is not None
|
|
1360
|
+
partition_spec_eval = ast.literal_eval(partition_spec)
|
|
1361
|
+
jmesh = pmesh.get_jax_mesh()
|
|
1362
|
+
return jax.lax.with_sharding_constraint(
|
|
1363
|
+
t, NamedSharding(jmesh, P(*partition_spec_eval)))
|
|
1364
|
+
|
|
1365
|
+
@op(torch.ops.xla.einsum_linear_forward)
|
|
1366
|
+
def _xla_einsum_linear_forward(input, weight, bias):
|
|
1367
|
+
with jax.named_scope('einsum_linear_forward'):
|
|
1368
|
+
product = jax.numpy.einsum('...n,mn->...m', input, weight)
|
|
1369
|
+
if bias is not None:
|
|
1370
|
+
return product + bias
|
|
1371
|
+
return product
|
|
1372
|
+
|
|
1373
|
+
except AttributeError:
|
|
1374
|
+
pass
|
|
1375
|
+
|
|
1336
1376
|
# TODO add more ops
|
|
1337
1377
|
|
|
1338
1378
|
|
|
1339
1379
|
@op(torch.ops.aten.min)
|
|
1340
1380
|
def _aten_min(x, dim=None, keepdim=False):
|
|
1341
1381
|
if dim is not None:
|
|
1342
|
-
return _with_reduction_scalar(jnp.min, x, dim,
|
|
1382
|
+
return _with_reduction_scalar(jnp.min, x, dim,
|
|
1383
|
+
keepdim), _with_reduction_scalar(
|
|
1384
|
+
jnp.argmin, x, dim,
|
|
1385
|
+
keepdim).astype(jnp.int64)
|
|
1343
1386
|
else:
|
|
1344
1387
|
return _with_reduction_scalar(jnp.min, x, dim, keepdim)
|
|
1345
1388
|
|
|
1346
1389
|
|
|
1347
1390
|
@op(torch.ops.aten.mode)
|
|
1348
1391
|
def _aten_mode(input, dim=-1, keepdim=False, *, out=None):
|
|
1349
|
-
if input.ndim == 0:
|
|
1392
|
+
if input.ndim == 0: # single number
|
|
1350
1393
|
return input, jnp.array(0)
|
|
1351
|
-
dim = (input.ndim +
|
|
1394
|
+
dim = (input.ndim +
|
|
1395
|
+
dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim
|
|
1352
1396
|
# keepdims must be True for accurate broadcasting
|
|
1353
1397
|
mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True)
|
|
1354
1398
|
mode_broadcast = jnp.broadcast_to(mode, input.shape)
|
|
1355
1399
|
if not keepdim:
|
|
1356
1400
|
mode = mode.squeeze(axis=dim)
|
|
1357
|
-
indices = jnp.argmax(
|
|
1401
|
+
indices = jnp.argmax(
|
|
1402
|
+
jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim)
|
|
1358
1403
|
return mode, indices
|
|
1359
1404
|
|
|
1360
1405
|
|
|
@@ -1388,8 +1433,7 @@ def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None):
|
|
|
1388
1433
|
@op(torch.ops.prims.broadcast_in_dim)
|
|
1389
1434
|
def _prims_broadcast_in_dim(t, shape, broadcast_dimensions):
|
|
1390
1435
|
return jax.lax.broadcast_in_dim(
|
|
1391
|
-
|
|
1392
|
-
)
|
|
1436
|
+
t, shape, broadcast_dimensions=broadcast_dimensions)
|
|
1393
1437
|
|
|
1394
1438
|
|
|
1395
1439
|
# aten.native_group_norm -- should use decomp table
|
|
@@ -1432,17 +1476,15 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):
|
|
|
1432
1476
|
normalized = (x - mean) * rstd
|
|
1433
1477
|
return normalized, mean, rstd
|
|
1434
1478
|
|
|
1435
|
-
normalized, group_mean, group_rstd = jax.lax.map(
|
|
1436
|
-
|
|
1437
|
-
)
|
|
1479
|
+
normalized, group_mean, group_rstd = jax.lax.map(group_norm_body,
|
|
1480
|
+
reshaped_input)
|
|
1438
1481
|
|
|
1439
1482
|
# Reshape back to original input shape
|
|
1440
1483
|
output = jnp.reshape(normalized, input_shape)
|
|
1441
1484
|
|
|
1442
1485
|
# **Affine transformation**
|
|
1443
|
-
affine_shape = [
|
|
1444
|
-
|
|
1445
|
-
] # Shape for broadcasting
|
|
1486
|
+
affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim)
|
|
1487
|
+
] # Shape for broadcasting
|
|
1446
1488
|
if weight is not None and bias is not None:
|
|
1447
1489
|
output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape)
|
|
1448
1490
|
elif weight is not None:
|
|
@@ -1474,22 +1516,25 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
|
|
|
1474
1516
|
The tensor containing the calculated vector norms.
|
|
1475
1517
|
"""
|
|
1476
1518
|
|
|
1477
|
-
if ord not in {2, float("inf"), float("-inf"), "fro"
|
|
1519
|
+
if ord not in {2, float("inf"), float("-inf"), "fro"
|
|
1520
|
+
} and not isinstance(ord, (int, float)):
|
|
1478
1521
|
raise ValueError(
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1522
|
+
f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
|
|
1523
|
+
" 'fro'.")
|
|
1524
|
+
|
|
1483
1525
|
# Special cases (for efficiency and clarity)
|
|
1484
1526
|
if ord == 0:
|
|
1485
1527
|
if self.shape == ():
|
|
1486
1528
|
# float sets it to float64. set it back to input type
|
|
1487
1529
|
result = jnp.astype(jnp.array(float(self != 0)), self.dtype)
|
|
1488
1530
|
else:
|
|
1489
|
-
result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim,
|
|
1531
|
+
result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim,
|
|
1532
|
+
keepdim)
|
|
1490
1533
|
|
|
1491
1534
|
elif ord == 2: # Euclidean norm
|
|
1492
|
-
result = jnp.sqrt(
|
|
1535
|
+
result = jnp.sqrt(
|
|
1536
|
+
_with_reduction_scalar(jnp.sum,
|
|
1537
|
+
jnp.abs(self)**2, dim, keepdim))
|
|
1493
1538
|
|
|
1494
1539
|
elif ord == float("inf"):
|
|
1495
1540
|
result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim)
|
|
@@ -1498,12 +1543,14 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
|
|
|
1498
1543
|
result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim)
|
|
1499
1544
|
|
|
1500
1545
|
elif ord == "fro": # Frobenius norm
|
|
1501
|
-
result = jnp.sqrt(
|
|
1546
|
+
result = jnp.sqrt(
|
|
1547
|
+
_with_reduction_scalar(jnp.sum,
|
|
1548
|
+
jnp.abs(self)**2, dim, keepdim))
|
|
1502
1549
|
|
|
1503
1550
|
else: # General case (e.g., ord = 1, ord = 3)
|
|
1504
|
-
result = _with_reduction_scalar(jnp.sum,
|
|
1505
|
-
|
|
1506
|
-
|
|
1551
|
+
result = _with_reduction_scalar(jnp.sum,
|
|
1552
|
+
jnp.abs(self)**ord, dim,
|
|
1553
|
+
keepdim)**(1.0 / ord)
|
|
1507
1554
|
|
|
1508
1555
|
# (Optional) dtype conversion
|
|
1509
1556
|
if dtype is not None:
|
|
@@ -1539,9 +1586,12 @@ def _aten_sinh(self):
|
|
|
1539
1586
|
|
|
1540
1587
|
# aten.native_layer_norm_backward
|
|
1541
1588
|
@op(torch.ops.aten.native_layer_norm_backward)
|
|
1542
|
-
def _aten_native_layer_norm_backward(
|
|
1543
|
-
|
|
1544
|
-
|
|
1589
|
+
def _aten_native_layer_norm_backward(grad_out,
|
|
1590
|
+
input,
|
|
1591
|
+
normalized_shape,
|
|
1592
|
+
weight,
|
|
1593
|
+
bias,
|
|
1594
|
+
eps=1e-5):
|
|
1545
1595
|
"""Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`.
|
|
1546
1596
|
|
|
1547
1597
|
Args:
|
|
@@ -1555,9 +1605,8 @@ def _aten_native_layer_norm_backward(
|
|
|
1555
1605
|
Returns:
|
|
1556
1606
|
A tuple of (grad_input, grad_weight, grad_bias).
|
|
1557
1607
|
"""
|
|
1558
|
-
return jax.lax.native_layer_norm_backward(
|
|
1559
|
-
|
|
1560
|
-
)
|
|
1608
|
+
return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape,
|
|
1609
|
+
weight, bias, eps)
|
|
1561
1610
|
|
|
1562
1611
|
|
|
1563
1612
|
# aten.reflection_pad3d_backward
|
|
@@ -1585,12 +1634,14 @@ def _aten_bitwise_not(self):
|
|
|
1585
1634
|
|
|
1586
1635
|
|
|
1587
1636
|
# aten.bitwise_left_shift
|
|
1637
|
+
@op(torch.ops.aten.__lshift__)
|
|
1588
1638
|
@op(torch.ops.aten.bitwise_left_shift)
|
|
1589
1639
|
def _aten_bitwise_left_shift(input, other):
|
|
1590
1640
|
return jnp.left_shift(input, other)
|
|
1591
1641
|
|
|
1592
1642
|
|
|
1593
1643
|
# aten.bitwise_right_shift
|
|
1644
|
+
@op(torch.ops.aten.__rshift__)
|
|
1594
1645
|
@op(torch.ops.aten.bitwise_right_shift)
|
|
1595
1646
|
def _aten_bitwise_right_shift(input, other):
|
|
1596
1647
|
return jnp.right_shift(input, other)
|
|
@@ -1671,10 +1722,8 @@ def _scatter_index(dim, index):
|
|
|
1671
1722
|
target_shape = [1] * len(index_shape)
|
|
1672
1723
|
target_shape[i] = index_shape[i]
|
|
1673
1724
|
input_indexes.append(
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
)
|
|
1677
|
-
)
|
|
1725
|
+
jnp.broadcast_to(
|
|
1726
|
+
jnp.arange(index_shape[i]).reshape(target_shape), index_shape))
|
|
1678
1727
|
return tuple(input_indexes), tuple(source_indexes)
|
|
1679
1728
|
|
|
1680
1729
|
|
|
@@ -1686,6 +1735,7 @@ def _aten_scatter_add(input, dim, index, src):
|
|
|
1686
1735
|
input_indexes, source_indexes = _scatter_index(dim, index)
|
|
1687
1736
|
return input.at[input_indexes].add(src[source_indexes])
|
|
1688
1737
|
|
|
1738
|
+
|
|
1689
1739
|
# aten.masked_scatter
|
|
1690
1740
|
@op(torch.ops.aten.masked_scatter)
|
|
1691
1741
|
def _aten_masked_scatter(self, mask, source):
|
|
@@ -1707,6 +1757,7 @@ def _aten_masked_scatter(self, mask, source):
|
|
|
1707
1757
|
|
|
1708
1758
|
return final_arr
|
|
1709
1759
|
|
|
1760
|
+
|
|
1710
1761
|
@op(torch.ops.aten.masked_select)
|
|
1711
1762
|
def _aten_masked_select(self, mask, *args, **kwargs):
|
|
1712
1763
|
broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape)
|
|
@@ -1722,6 +1773,7 @@ def _aten_masked_select(self, mask, *args, **kwargs):
|
|
|
1722
1773
|
|
|
1723
1774
|
return self_flat[true_indices]
|
|
1724
1775
|
|
|
1776
|
+
|
|
1725
1777
|
# aten.logical_not
|
|
1726
1778
|
|
|
1727
1779
|
|
|
@@ -1730,11 +1782,13 @@ def _aten_masked_select(self, mask, *args, **kwargs):
|
|
|
1730
1782
|
def _aten_sign(x):
|
|
1731
1783
|
return jnp.sign(x)
|
|
1732
1784
|
|
|
1785
|
+
|
|
1733
1786
|
# aten.signbit
|
|
1734
1787
|
@op(torch.ops.aten.signbit)
|
|
1735
1788
|
def _aten_signbit(x):
|
|
1736
1789
|
return jnp.signbit(x)
|
|
1737
1790
|
|
|
1791
|
+
|
|
1738
1792
|
# aten.sigmoid
|
|
1739
1793
|
@op(torch.ops.aten.sigmoid)
|
|
1740
1794
|
@op_base.promote_int_input
|
|
@@ -1760,7 +1814,13 @@ def _aten_atan(self):
|
|
|
1760
1814
|
|
|
1761
1815
|
@op(torch.ops.aten.scatter_reduce)
|
|
1762
1816
|
@op(torch.ops.aten.scatter)
|
|
1763
|
-
def _aten_scatter_reduce(input,
|
|
1817
|
+
def _aten_scatter_reduce(input,
|
|
1818
|
+
dim,
|
|
1819
|
+
index,
|
|
1820
|
+
src,
|
|
1821
|
+
reduce=None,
|
|
1822
|
+
*,
|
|
1823
|
+
include_self=True):
|
|
1764
1824
|
if not isinstance(src, jnp.ndarray):
|
|
1765
1825
|
src = jnp.array(src, dtype=input.dtype)
|
|
1766
1826
|
input_indexes, source_indexes = _scatter_index(dim, index)
|
|
@@ -1817,41 +1877,6 @@ def _aten_gt(self, other):
   return self > other
 
 
-# aten.pixel_shuffle
-@op(torch.ops.aten.pixel_shuffle)
-def _aten_pixel_shuffle(x, upscale_factor):
-  """PixelShuffle implementation in JAX.
-
-  Args:
-    x: Input tensor. Typically a feature map.
-    upscale_factor: Integer by which to upscale the spatial dimensions.
-
-  Returns:
-    Tensor after PixelShuffle operation.
-  """
-
-  batch_size, channels, height, width = x.shape
-
-  if channels % (upscale_factor**2) != 0:
-    raise ValueError(
-        "Number of channels must be divisible by the square of the upscale factor."
-    )
-
-  new_channels = channels // (upscale_factor**2)
-  new_height = height * upscale_factor
-  new_width = width * upscale_factor
-
-  x = x.reshape(
-      batch_size, new_channels, upscale_factor, upscale_factor, height, width
-  )
-  x = jnp.transpose(
-      x, (0, 1, 2, 4, 3, 5)
-  )  # Move channels to spatial dimensions
-  x = x.reshape(batch_size, new_channels, new_height, new_width)
-
-  return x
-
-
 # aten.sym_stride
 # aten.lt
 @op(torch.ops.aten.lt)
@@ -1883,8 +1908,7 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
   num_batch_dims = inputs.ndim - (len(window_shape) + 1)
   strides = strides or (1,) * len(window_shape)
   assert len(window_shape) == len(
- (1 line not shown)
-  ), f"len({window_shape}) must equal len({strides})"
+      strides), f"len({window_shape}) must equal len({strides})"
   strides = (1,) * (1 + num_batch_dims) + strides
   dims = (1,) * (1 + num_batch_dims) + window_shape
 
@@ -1901,23 +1925,22 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
   if not isinstance(padding, str):
     padding = tuple(map(tuple, padding))
     assert len(padding) == len(window_shape), (
- (2 lines not shown)
-    )
- (1 line not shown)
-        [len(x) == 2 for x in padding]
-    ), f"each entry in padding {padding} must be length 2"
+        f"padding {padding} must specify pads for same number of dims as "
+        f"window_shape {window_shape}")
+    assert all([len(x) == 2 for x in padding
+               ]), f"each entry in padding {padding} must be length 2"
     padding = ((0, 0), (0, 0)) + padding
   y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
   if is_single_input:
     y = jnp.squeeze(y, axis=0)
   return y
 
- (1 line not shown)
+
 @op(torch.ops.aten._adaptive_avg_pool2d)
 @op(torch.ops.aten._adaptive_avg_pool3d)
-def adaptive_avg_pool2or3d(input: jnp.ndarray,
- (1 line not shown)
+def adaptive_avg_pool2or3d(input: jnp.ndarray,
+                           output_size: Tuple[int, int]) -> jnp.ndarray:
+  """
   Applies a 2/3D adaptive average pooling over an input signal composed of several input planes.
 
   See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.
@@ -1929,124 +1952,128 @@ def adaptive_avg_pool2or3d(input: jnp.ndarray, output_size: Tuple[int, int]) ->
   Context:
   https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401
   """
- (69 lines not shown)
+  shape = input.shape
+  ndim = len(shape)
+  out_dim = len(output_size)
+  num_spatial_dim = ndim - out_dim
+
+  # Preconditions
+
+  assert ndim in (
+      out_dim + 1, out_dim + 2
+  ), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}"
+  for d in input.shape[-2:]:
+    assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \
+                   f"non-batch dimensions, but input has shape {tuple(shape)}."
+
+  # Optimisation (we should also do this in the kernel implementation)
+  if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])):
+    stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size))
+    kernel = tuple(i - (o - 1) * s
+                   for i, o, s in zip(shape[-out_dim:], output_size, stride))
+    return _aten_avg_pool(
+        input,
+        kernel,
+        strides=stride,
+    )
+
+  def start_index(a, b, c):
+    return (a * c) // b
+
+  def end_index(a, b, c):
+    return ((a + 1) * c + b - 1) // b
+
+  def compute_idx(in_size, out_size):
+    orange = jnp.arange(out_size, dtype=jnp.int64)
+    i0 = start_index(orange, out_size, in_size)
+    # Let length = end_index - start_index, i.e. the length of the pooling kernels
+    # length.max() can be computed analytically as follows:
+    maxlength = in_size // out_size + 1
+    in_size_mod = in_size % out_size
+    # adaptive = True iff there are kernels with different lengths
+    adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
+    if adaptive:
+      maxlength += 1
+    elif in_size_mod == 0:
+      maxlength -= 1
+
+    range_max = jnp.arange(maxlength, dtype=jnp.int64)
+    idx = i0[:, None] + range_max
+    if adaptive:
+      # Need to clamp to avoid accessing out-of-bounds memory
+      idx = jnp.minimum(idx, in_size - 1)
+
+      # Compute the length
+      i1 = end_index(orange, out_size, in_size)
+      length = i1 - i0
+    else:
+      length = maxlength
+    return idx, length, range_max, adaptive
+
+  idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)]
+  # length is not None if it's constant, otherwise we'll need to compute it
+  for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)):
+    idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o)
+
+  def _unsqueeze_to_dim(x, dim):
+    ndim = len(x.shape)
+    return jax.lax.expand_dims(x, tuple(range(ndim, dim)))
+
+  if out_dim == 2:
+    # NOTE: unsqueeze to insert extra 1 in ranks; so they
+    # would broadcast
+    vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]]
+    reduce_axis = (-3, -1)
+  else:
+    assert out_dim == 3
+    vals = input[...,
+                 _unsqueeze_to_dim(idx[0], 6),
+                 _unsqueeze_to_dim(idx[1], 4), idx[2]]
+    reduce_axis = (-5, -3, -1)
+
+  # Shortcut for the simpler case
+  if not any(adaptive):
+    return jnp.mean(vals, axis=reduce_axis)
+
+  def maybe_mask(vals, length, range_max, adaptive, dim):
+    if isinstance(length, int):
+      return vals, length
     else:
- (23 lines not shown)
-      # Compute the length of each window
-      length = _unsqueeze_to_dim(length, -dim)
-      return vals, length
- (1 line not shown)
-  for i in range(len(length)):
-    vals, length[i] = maybe_mask(vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
- (1 line not shown)
-  # We unroll the sum as we assume that the kernels are going to be small
-  ret = jnp.sum(vals, axis=reduce_axis)
-  # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
-  # this is multiplication with broadcasting, not regular pointwise product
-  return ret / math.prod(length)
- (1 line not shown)
+      # zero-out the things we didn't really want to select
+      assert dim < 0
+      # hack
+      mask = range_max >= length[:, None]
+      if dim == -2:
+        mask = _unsqueeze_to_dim(mask, 4)
+      elif dim == -3:
+        mask = _unsqueeze_to_dim(mask, 6)
+      vals = jnp.where(mask, 0.0, vals)
+      # Compute the length of each window
+      length = _unsqueeze_to_dim(length, -dim)
+      return vals, length
+
+  for i in range(len(length)):
+    vals, length[i] = maybe_mask(
+        vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
+
+  # We unroll the sum as we assume that the kernels are going to be small
+  ret = jnp.sum(vals, axis=reduce_axis)
+  # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
+  # this is multiplication with broadcasting, not regular pointwise product
+  return ret / math.prod(length)
+
 
 @op(torch.ops.aten.avg_pool1d)
 @op(torch.ops.aten.avg_pool2d)
 @op(torch.ops.aten.avg_pool3d)
 def _aten_avg_pool(
- (7 lines not shown)
+    inputs,
+    kernel_size,
+    strides=None,
+    padding=0,
+    ceil_mode=False,
+    count_include_pad=True,
+    divisor_override=None,
 ):
   num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
   kernel_size = tuple(kernel_size)
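The start_index/end_index helpers above assign output cell a the window [a*c // b, ceil((a+1)*c / b)) of the input axis. A small self-contained check of those formulas with assumed example sizes:

# Window boundaries for adaptively pooling an input of length 10 into 4 outputs.
def start_index(a, b, c):
  return (a * c) // b

def end_index(a, b, c):
  return ((a + 1) * c + b - 1) // b

in_size, out_size = 10, 4
windows = [(start_index(i, out_size, in_size), end_index(i, out_size, in_size))
           for i in range(out_size)]
print(windows)  # [(0, 3), (2, 5), (5, 8), (7, 10)]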
@@ -2060,7 +2087,7 @@ def _aten_avg_pool(
   if num_batch_dims == 0:
     input_shape = [1, *input_shape]
   padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
-                               ceil_mode)
+                               [1] * len(kernel_size), ceil_mode)
 
   y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding)
   if divisor_override is not None:
@@ -2102,9 +2129,11 @@ def _aten_avg_pool(
   )
   return y.astype(inputs.dtype)
 
+
 # helper function to generate all indices to iterate through ndarray
-def _generate_indices(dims, skip_dim_indices
+def _generate_indices(dims, skip_dim_indices=[]):
   res = []
+
   def _helper(curr_dim_idx, sofar):
     if curr_dim_idx in skip_dim_indices:
       _helper(curr_dim_idx + 1, sofar[:])
@@ -2115,10 +2144,11 @@ def _generate_indices(dims, skip_dim_indices = []):
     for i in range(dims[curr_dim_idx]):
       sofar[curr_dim_idx] = i
       _helper(curr_dim_idx + 1, sofar[:])
- (1 line not shown)
+
   _helper(0, [0 for _ in dims])
   return res
 
+
 # aten.sym_numel
 # aten.reciprocal
 @op(torch.ops.aten.reciprocal)
@@ -2174,10 +2204,14 @@ def _aten_round(input, decimals=0):
 @op(torch.ops.aten.max)
 def _aten_max(self, dim=None, keepdim=False):
   if dim is not None:
-    return _with_reduction_scalar(jnp.max, self, dim,
+    return _with_reduction_scalar(jnp.max, self, dim,
+                                  keepdim), _with_reduction_scalar(
+                                      jnp.argmax, self, dim,
+                                      keepdim).astype(jnp.int64)
   else:
     return _with_reduction_scalar(jnp.max, self, dim, keepdim)
 
+
 # aten.maximum
 @op(torch.ops.aten.maximum)
 def _aten_maximum(self, other):
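For the dim form, torch.max returns a (values, indices) pair, which the handler above builds from two separate reductions (with the index cast to int64). A quick illustration in plain jax.numpy:

import jax.numpy as jnp

x = jnp.array([[3.0, 7.0, 5.0],
               [9.0, 1.0, 2.0]])
values = jnp.max(x, axis=1)
indices = jnp.argmax(x, axis=1)  # the wrapper additionally casts these to int64
print(values)   # [7. 9.]
print(indices)  # [1 0]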
@@ -2216,27 +2250,28 @@ def _with_reduction_scalar(jax_func, self, dim, keepdim):
 def _aten_any(self, dim=None, keepdim=False):
   return _with_reduction_scalar(jnp.any, self, dim, keepdim)
 
+
 # aten.arange
 @op(torch.ops.aten.arange.start_step)
 @op(torch.ops.aten.arange.start)
 @op(torch.ops.aten.arange.default)
 @op_base.convert_dtype(use_default_dtype=False)
 def _aten_arange(
- (9 lines not shown)
+    start,
+    end=None,
+    step=None,
+    *,
+    dtype=None,
+    layout=None,
+    requires_grad=False,
+    device=None,
+    pin_memory=False,
 ):
   return jnp.arange(
- (4 lines not shown)
+      op_base.maybe_convert_constant_dtype(start, dtype),
+      op_base.maybe_convert_constant_dtype(end, dtype),
+      op_base.maybe_convert_constant_dtype(step, dtype),
+      dtype=dtype,
   )
 
 
@@ -2245,6 +2280,7 @@ def _aten_arange(
 def _aten_argmax(self, dim=None, keepdim=False):
   return _with_reduction_scalar(jnp.argmax, self, dim, keepdim)
 
+
 def _strided_index(sizes, strides, storage_offset=None):
   ind = jnp.zeros(sizes, dtype=jnp.int32)
 
@@ -2257,6 +2293,7 @@ def _strided_index(sizes, strides, storage_offset=None):
     ind += storage_offset
   return ind
 
+
 # aten.as_strided
 @op(torch.ops.aten.as_strided)
 @op(torch.ops.aten.as_strided_copy)
@@ -2311,7 +2348,7 @@ def _aten_broadcast_tensors(*tensors):
     Args:
       shapes: A list of tuples representing the shapes of the input tensors.
 
-    Returns:
+    Returns:
       A tuple representing the broadcasted output shape.
     """
 
@@ -2344,11 +2381,13 @@ def _aten_broadcast_tensors(*tensors):
     A tuple specifying which dimensions of the input tensor should be broadcasted.
     """
 
-    res = tuple(
+    res = tuple(
+        i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)))
     return res
 
   # clean some function's previous wrap
-  if len(tensors)==1 and len(tensors[0])>=1 and isinstance(
+  if len(tensors) == 1 and len(tensors[0]) >= 1 and isinstance(
+      tensors[0][0], jax.Array):
     tensors = tensors[0]
 
   # Get the shapes of all input tensors
@@ -2357,7 +2396,8 @@ def _aten_broadcast_tensors(*tensors):
   output_shape = _get_broadcast_shape(shapes)
   # Broadcast each tensor to the output shape
   broadcasted_tensors = [
-      jax.lax.broadcast_in_dim(t, output_shape,
+      jax.lax.broadcast_in_dim(t, output_shape,
+                               _broadcast_dimensions(t.shape, output_shape))
      for t in tensors
   ]
 
@@ -2376,6 +2416,7 @@ def _aten_broadcast_to(input, shape):
 def _aten_clamp(self, min=None, max=None):
   return jnp.clip(self, min, max)
 
+
 @op(torch.ops.aten.clamp_min)
 def _aten_clamp_min(input, min):
   return jnp.clip(input, min=min)
@@ -2394,7 +2435,7 @@ def _aten_constant_pad_nd(input, padding, value=0):
   rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)]
   pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding)
   value_casted = jax.numpy.array(value, dtype=input.dtype)
-  return jax.lax.pad(input, padding_value=value_casted, padding_config
+  return jax.lax.pad(input, padding_value=value_casted, padding_config=pad_dim)
 
 
 # aten.convolution_backward
@@ -2421,9 +2462,8 @@ def _aten_cdist_forward(x1, x2, p, compute_mode=""):
 @op(torch.ops.aten._pdist_forward)
 def _aten__pdist_forward(x, p=2):
   pairwise_dists = _aten_cdist_forward(x, x, p)
-  condensed_dists = pairwise_dists[
- (1 line not shown)
-  ]
+  condensed_dists = pairwise_dists[jnp.triu_indices(
+      pairwise_dists.shape[0], k=1)]
   return condensed_dists
 
 
@@ -2449,25 +2489,33 @@ def _aten_cosh(input):
   return jnp.cosh(input)
 
 
+@op(torch.ops.aten.diag)
+def _aten_diag(input, diagonal=0):
+  return jnp.diag(input, diagonal)
+
+
 # aten.diagonal
 @op(torch.ops.aten.diagonal)
+@op(torch.ops.aten.diagonal_copy)
 def _aten_diagonal(input, offset=0, dim1=0, dim2=1):
   return jnp.diagonal(input, offset, dim1, dim2)
 
 
 def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1):
- (12 lines not shown)
+  input_len = len(input_shape)
+  if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len):
+    raise ValueError("dim1 and dim2 must be different and in range [0, " +
+                     str(input_len - 1) + "]")
+
+  size1, size2 = input_shape[dim1], input_shape[dim2]
+  if offset >= 0:
+    indices1 = jnp.arange(min(size1, size2 - offset))
+    indices2 = jnp.arange(offset, offset + len(indices1))
+  else:
+    indices2 = jnp.arange(min(size1 + offset, size2))
+    indices1 = jnp.arange(-offset, -offset + len(indices2))
+  return [indices1, indices2]
+
 
 @op(torch.ops.aten.diagonal_scatter)
 def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
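diag_indices_with_offset mirrors the index arithmetic behind jnp.diagonal with a nonzero offset. A small sanity sketch of that arithmetic (hypothetical 3x4 input):

import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)
# Row/column indices of the offset=1 diagonal, matching jnp.diagonal(x, 1).
rows = jnp.arange(min(3, 4 - 1))     # [0 1 2]
cols = jnp.arange(1, 1 + len(rows))  # [1 2 3]
print(x[rows, cols])                 # [ 1  6 11]
print(jnp.diagonal(x, offset=1))     # [ 1  6 11]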
@@ -2476,17 +2524,17 @@ def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
   if input.ndim == 2:
     return input.at[tuple(indexes)].set(src)
   else:
-    # src has the same shape as the output of
+    # src has the same shape as the output of
     # jnp.diagonal(input, offset, dim1, dim2).
     # Last dimension always contains the diagonal elements,
     # while the preceding dimensions represent the "slices"
     # from which these diagonals are extracted. Thus,
     # we alter input axes to match this assumption, write src
     # and then move the axes back to the original state.
-    input = jnp.moveaxis(input, (dim1, dim2), (-2
-    multi_indexes = [slice(None)]*(input.ndim-2) + indexes
+    input = jnp.moveaxis(input, (dim1, dim2), (-2, -1))
+    multi_indexes = [slice(None)] * (input.ndim - 2) + indexes
     input = input.at[tuple(multi_indexes)].set(src)
-    return jnp.moveaxis(input, (-2
+    return jnp.moveaxis(input, (-2, -1), (dim1, dim2))
 
 
 # aten.diagflat
@@ -2507,9 +2555,9 @@ def _aten_eq(input1, input2):
 
 
 # aten.equal
-@op(torch.ops.aten.equal
+@op(torch.ops.aten.equal)
 def _aten_equal(input, other):
-  res = jnp.array_equal(input
+  res = jnp.array_equal(input, other)
   return bool(res)
 
 
@@ -2559,7 +2607,12 @@ def _aten_exp2(input):
 # aten.fill
 @op(torch.ops.aten.fill)
 @op(torch.ops.aten.full_like)
-def _aten_fill(x,
+def _aten_fill(x,
+               value,
+               dtype=None,
+               pin_memory=None,
+               memory_format=None,
+               device=None):
   if dtype is None:
     dtype = x.dtype
   else:
@@ -2634,7 +2687,8 @@ def _aten_glu(x, dim=-1):
 # aten.hardtanh
 @op(torch.ops.aten.hardtanh)
 def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
-  if input.dtype == np.int64 and isinstance(max_val, float) and isinstance(
+  if input.dtype == np.int64 and isinstance(max_val, float) and isinstance(
+      min_val, float):
     min_val = int(min_val)
     max_val = int(max_val)
   return jnp.clip(input, min_val, max_val)
@@ -2644,7 +2698,7 @@ def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
 @op(torch.ops.aten.histc)
 def _aten_histc(input, bins=100, min=0, max=0):
   # TODO(@manfei): this function might cause some uncertainty
-  if min==0 and max==0:
+  if min == 0 and max == 0:
     if isinstance(input, jnp.ndarray) and input.size == 0:
       min = 0
       max = 0
@@ -2652,7 +2706,8 @@ def _aten_histc(input, bins=100, min=0, max=0):
       min = jnp.min(input)
       max = jnp.max(input)
   range_value = (min, max)
-  hist, bin_edges = jnp.histogram(
+  hist, bin_edges = jnp.histogram(
+      input, bins=bins, range=range_value, weights=None, density=None)
   return hist
 
 
@@ -2667,22 +2722,28 @@ def _aten_digamma(input, *, out=None):
   # replace indices where input == 0 with -inf in res
   return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res)
 
+
 @op(torch.ops.aten.igamma)
 def _aten_igamma(input, other):
   return jax.scipy.special.gammainc(input, other)
 
+
 @op(torch.ops.aten.lgamma)
 def _aten_lgamma(input, *, out=None):
   return jax.scipy.special.gammaln(input).astype(jnp.float32)
 
+
 @op(torch.ops.aten.mvlgamma)
 def _aten_mvlgamma(input, p, *, out=None):
- (1 line not shown)
+  input = input.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return jax.scipy.special.multigammaln(input, p)
+
 
 @op(torch.ops.aten.linalg_eig)
 def _aten_linalg_eig(A):
   return jnp.linalg.eig(A)
 
+
 @op(torch.ops.aten._linalg_eigh)
 def _aten_linalg_eigh(A, UPLO='L'):
   return jnp.linalg.eigh(A, UPLO)
@@ -2704,7 +2765,9 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
     A_reshaped = A.reshape((batch_size,) + A.shape[-2:])
     B_reshaped = B.reshape((batch_size,) + B.shape[-2:])
 
-    X, residuals, rank, singular_values = jax.vmap(
+    X, residuals, rank, singular_values = jax.vmap(
+        jnp.linalg.lstsq, in_axes=(0,
+                                   0))(A_reshaped, B_reshaped, rcond=rcond)
 
     X = X.reshape(batch_shape + X.shape[-2:])
 
@@ -2720,7 +2783,8 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
       residuals = residuals.reshape(batch_shape + residuals.shape[-1:])
 
     if driver in ['gelsd', 'gelss']:
-      singular_values = singular_values.reshape(batch_shape +
+      singular_values = singular_values.reshape(batch_shape +
+                                                singular_values.shape[-1:])
     else:
       singular_values = jnp.array([], dtype=input_dtype)
 
@@ -2729,17 +2793,17 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
     X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond)
 
     if driver not in ['gelsd', 'gelsy', 'gelss']:
- (1 line not shown)
+      rank = jnp.array([], dtype=jnp.int64)
 
   rank_value = None
   if rank.size > 0:
- (2 lines not shown)
+    rank_value = int(rank.item())
+    rank = jnp.array(rank_value, dtype=jnp.int64)
 
   # When driver is ‘gels’, assume that A is full-rank.
-  full_rank =
+  full_rank = driver == 'gels' or rank_value == n
   if driver == 'gelsy' or m <= n or (not full_rank):
- (1 line not shown)
+    residuals = jnp.array([], dtype=input_dtype)
 
   if driver not in ['gelsd', 'gelss']:
     singular_values = jnp.array([], dtype=input_dtype)
@@ -2753,8 +2817,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
   # https://github.com/jax-ml/jax/issues/12779
   # TODO: Not tested for complex inputs. Does not support hermitian=True
   pivots = jnp.broadcast_to(
-      jnp.arange(1, A.shape[-1]+1, dtype=jnp.int32), A.shape[:-1]
-  )
+      jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1])
   info = jnp.zeros(A.shape[:-2], jnp.int32)
   C = jnp.linalg.cholesky(A)
   if C.size == 0:
@@ -2767,7 +2830,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
 
   D = C * jnp.eye(C.shape[-1], dtype=A.dtype)
   LD = C @ jnp.linalg.inv(D)
-  LD = fill_diagonal_batch(LD, D*D)
+  LD = fill_diagonal_batch(LD, D * D)
   return LD, pivots, info
 
 
@@ -2787,9 +2850,9 @@ def _aten_linalg_lu(A, pivot=True, out=None):
   U = jnp.triu(lu[..., :k, :])
 
   def perm_to_P(perm):
- (3 lines not shown)
+    m = perm.shape[-1]
+    P = jnp.eye(m, dtype=dtype)[perm].T
+    return P
 
   if permutation.ndim > 1:
     num_batch_dims = permutation.ndim - 1
@@ -2798,7 +2861,7 @@ def _aten_linalg_lu(A, pivot=True, out=None):
 
   P = perm_to_P(permutation)
 
-  return P,L,U
+  return P, L, U
 
 
 @op(torch.ops.aten.linalg_lu_factor_ex)
@@ -2810,6 +2873,21 @@ def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False):
   return lu, pivots, info
 
 
+@op(torch.ops.aten.linalg_lu_solve)
+def _aten_linalg_lu_solve(LU, pivots, B, left=True, adjoint=False):
+  # JAX pivots are offset by 1 compared to torch
+  pivots = pivots - 1
+  if not left:
+    # XA = B is same as A'X = B'
+    trans = 0 if adjoint else 2
+    x = jax.scipy.linalg.lu_solve((LU, pivots), jnp.matrix_transpose(B), trans)
+    x = jnp.matrix_transpose(x)
+  else:
+    trans = 2 if adjoint else 0
+    x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans)
+  return x
+
+
 @op(torch.ops.aten.gcd)
 def _aten_gcd(input, other):
   return jnp.gcd(input, other)
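The new linalg_lu_solve shim bridges a convention gap: torch LU pivots are 1-based while jax.scipy.linalg.lu_solve expects 0-based pivots, hence the `pivots - 1` above. A standalone sketch of the underlying JAX call, using a factorization from jax.scipy.linalg.lu_factor (which already yields 0-based pivots):

import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

A = jnp.array([[4.0, 3.0], [6.0, 3.0]])
b = jnp.array([10.0, 12.0])
lu, piv = lu_factor(A)          # piv is already 0-based here
x = lu_solve((lu, piv), b)      # solves A @ x = b
print(jnp.allclose(A @ x, b))   # True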
@@ -2874,12 +2952,14 @@ def _aten_log2(x):
 
 # aten.logical_and
 @op(torch.ops.aten.logical_and)
+@op(torch.ops.aten.__and__)
 def _aten_logical_and(self, other):
   return jnp.logical_and(self, other)
 
 
 # aten.logical_or
 @op(torch.ops.aten.logical_or)
+@op(torch.ops.aten.__or__)
 def _aten_logical_or(self, other):
   return jnp.logical_or(self, other)
 
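Routing aten.__and__ and aten.__or__ to the same handlers means these operator overloads fall through to jnp.logical_and / jnp.logical_or; a tiny illustration of the underlying JAX calls:

import jax.numpy as jnp

a = jnp.array([True, True, False])
b = jnp.array([True, False, False])
print(jnp.logical_and(a, b))  # [ True False False]
print(jnp.logical_or(a, b))   # [ True  True False]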
@@ -2894,7 +2974,7 @@ def _aten_logical_not(self):
 @op(torch.ops.aten._log_softmax)
 def _aten_log_softmax(self, axis=-1, half_to_float=False):
   if self.shape == ():
- (1 line not shown)
+    return jnp.astype(0.0, self.dtype)
   return jax.nn.log_softmax(self, axis)
 
 
@@ -2921,6 +3001,7 @@ def _aten_logcumsumexp(self, dim=None):
 # aten.max_pool3d_backward
 # aten.logical_xor
 @op(torch.ops.aten.logical_xor)
+@op(torch.ops.aten.__xor__)
 def _aten_logical_xor(self, other):
   return jnp.logical_xor(self, other)
 
@@ -2933,19 +3014,22 @@ def _aten_logical_xor(self, other):
 def _aten_neg(x):
   return -1 * x
 
+
 @op(torch.ops.aten.nextafter)
 def _aten_nextafter(input, other, *, out=None):
   return jnp.nextafter(input, other)
 
 
 @op(torch.ops.aten.nonzero_static)
-def _aten_nonzero_static(input, size, fill_value
+def _aten_nonzero_static(input, size, fill_value=-1):
   indices = jnp.argwhere(input)
 
   if size < indices.shape[0]:
     indices = indices[:size]
   elif size > indices.shape[0]:
-    padding = jnp.full((size - indices.shape[0], indices.shape[1]),
+    padding = jnp.full((size - indices.shape[0], indices.shape[1]),
+                       fill_value,
+                       dtype=indices.dtype)
     indices = jnp.concatenate((indices, padding))
 
   return indices
@@ -2954,9 +3038,11 @@ def _aten_nonzero_static(input, size, fill_value = -1):
 # aten.nonzero
 @op(torch.ops.aten.nonzero)
 def _aten_nonzero(x, as_tuple=False):
-  if jnp.ndim(x) == 0 and (as_tuple or x.item()==0):
+  if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0):
     return torch.empty(0, 0, dtype=torch.int64)
-  if jnp.ndim(
+  if jnp.ndim(
+      x
+  ) == 0:  # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64)
     res = torch.empty(1, 0, dtype=torch.int64)
     return jnp.array(res.numpy())
   index_tuple = jnp.nonzero(x)
@@ -2997,15 +3083,15 @@ def _aten_put(self, index, source, accumulate=False):
 # aten.randperm
 # randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
 @op(torch.ops.aten.randperm, needs_env=True)
-def _aten_randperm(
- (8 lines not shown)
+def _aten_randperm(n,
+                   *,
+                   generator=None,
+                   dtype=None,
+                   layout=None,
+                   device=None,
+                   pin_memory=None,
+                   env=None):
+  """
   Generates a random permutation of integers from 0 to n-1.
 
   Args:
@@ -3019,14 +3105,14 @@ def _aten_randperm(
   Returns:
     A DeviceArray containing a random permutation of integers from 0 to n-1.
   """
- (8 lines not shown)
+  if dtype:
+    dtype = mappings.t2j_dtype(dtype)
+  else:
+    dtype = jnp.int64.dtype
+  key = env.get_and_rotate_prng_key(generator)
+  indices = jnp.arange(n, dtype=dtype)
+  permutation = jax.random.permutation(key, indices)
+  return permutation
 
 
 # aten.reflection_pad3d
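The permutation itself comes from jax.random.permutation under an explicitly threaded PRNG key; a minimal sketch without the torchax env plumbing (key value is illustrative):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
perm = jax.random.permutation(key, jnp.arange(8))
print(perm)  # a permutation of 0..7, deterministic for this key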
@@ -3071,8 +3157,8 @@ def _aten_sort(a, dim=-1, descending=False, stable=False):
   if a.shape == ():
     return (a, jnp.astype(0, 'int64'))
   return (
- (2 lines not shown)
+      jnp.sort(a, axis=dim, stable=stable, descending=descending),
+      jnp.argsort(a, axis=dim, stable=stable, descending=descending),
   )
 
 
@@ -3114,8 +3200,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
   if dim != -1 and dim != len(input.shape) - 1:
     transpose_shape = list(range(len(input.shape)))
     transpose_shape[dim], transpose_shape[-1] = (
- (2 lines not shown)
+        transpose_shape[-1],
+        transpose_shape[dim],
     )
     input = jnp.transpose(input, transpose_shape)
 
@@ -3124,8 +3210,7 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
   if sorted:
     values = jnp.sort(values, descending=True)
     indices = jnp.take_along_axis(
- (1 line not shown)
-    )
+        indices, jnp.argsort(values, axis=-1, descending=True), axis=-1)
 
   if not largest:
     values = -values  # Negate values back if we found smallest
@@ -3140,21 +3225,39 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
 # aten.tril_indices
 #tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
 @op(torch.ops.aten.tril_indices)
-def _aten_tril_indices(row,
+def _aten_tril_indices(row,
+                       col,
+                       offset=0,
+                       *,
+                       dtype=jnp.int64.dtype,
+                       layout=None,
+                       device=None,
+                       pin_memory=None):
   a, b = jnp.tril_indices(row, offset, col)
   return jnp.stack((a, b))
 
+
 # aten.tril_indices
 #tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
 @op(torch.ops.aten.triu_indices)
-def _aten_triu_indices(row,
+def _aten_triu_indices(row,
+                       col,
+                       offset=0,
+                       *,
+                       dtype=jnp.int64.dtype,
+                       layout=None,
+                       device=None,
+                       pin_memory=None):
   a, b = jnp.triu_indices(row, offset, col)
   return jnp.stack((a, b))
 
 
 @op(torch.ops.aten.unbind_copy)
 def _aten_unbind(a, dim=0):
-  return [
+  return [
+      jax.lax.index_in_dim(a, i, dim, keepdims=False)
+      for i in range(a.shape[dim])
+  ]
 
 
 # aten.unique_dim
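Both index helpers simply stack the row and column index arrays returned by jnp into torch's 2 x N layout; for example:

import jax.numpy as jnp

rows, cols = jnp.tril_indices(3, 0, 3)
print(jnp.stack((rows, cols)))
# [[0 1 1 2 2 2]
#  [0 0 1 0 1 2]]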
@@ -3167,12 +3270,13 @@ def _aten_unique_dim(input_tensor,
                      sort=True,
                      return_inverse=False,
                      return_counts=False):
-  result_tensor_or_tuple = jnp.unique(
- (5 lines not shown)
+  result_tensor_or_tuple = jnp.unique(
+      input_tensor,
+      return_index=False,
+      return_inverse=return_inverse,
+      return_counts=return_counts,
+      axis=dim,
+      equal_nan=False)
   result_list = (
       list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple)
       else [result_tensor_or_tuple])
@@ -3197,15 +3301,14 @@ def _aten_unique_dim(input_tensor,
 # NOTE: Like the CUDA and CPU implementations, this implementation always sorts
 # the tensor regardless of the `sorted` argument passed to `torch.unique`.
 @op(torch.ops.aten._unique)
-def _aten_unique(input_tensor,
- (7 lines not shown)
-      equal_nan=False)
+def _aten_unique(input_tensor, sort=True, return_inverse=False):
+  result_tensor_or_tuple = jnp.unique(
+      input_tensor,
+      return_index=False,
+      return_inverse=return_inverse,
+      return_counts=False,
+      axis=None,
+      equal_nan=False)
   if return_inverse:
     return result_tensor_or_tuple
   else:
@@ -3221,11 +3324,12 @@ def _aten_unique2(input_tensor,
                   sort=True,
                   return_inverse=False,
                   return_counts=False):
-  return _aten_unique_dim(
- (4 lines not shown)
+  return _aten_unique_dim(
+      input_tensor=input_tensor,
+      dim=None,
+      sort=sort,
+      return_inverse=return_inverse,
+      return_counts=return_counts)
 
 
 # aten.unique_consecutive
@@ -3255,17 +3359,18 @@ def _aten_unique_consecutive(input_tensor,
   if dim < 0:
     dim += ndim
 
-  nd_slice_0 = tuple(
- (1 line not shown)
-  nd_slice_1 = tuple(
- (1 line not shown)
+  nd_slice_0 = tuple(
+      slice(None, -1) if d == dim else slice(None) for d in range(ndim))
+  nd_slice_1 = tuple(
+      slice(1, None) if d == dim else slice(None) for d in range(ndim))
 
   axes_to_reduce = tuple(d for d in range(ndim) if d != dim)
 
   does_not_equal_prior = (
-      jnp.any(
- (2 lines not shown)
+      jnp.any(
+          input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
+          axis=axes_to_reduce,
+          keepdims=False))
 
   if input_tensor.shape[dim] != 0:
     # Prepend `True` to represent the first element of the input.
@@ -3273,18 +3378,17 @@ def _aten_unique_consecutive(input_tensor,
 
   include_indices = jnp.argwhere(does_not_equal_prior)[:, 0]
 
-  output_tensor = input_tensor[
- (1 line not shown)
+  output_tensor = input_tensor[tuple(
+      include_indices if d == dim else slice(None) for d in range(ndim))]
 
   if return_inverse or return_counts:
-    counts = (
- (1 line not shown)
+    counts = (
+        jnp.append(include_indices[1:], input_tensor.shape[dim]) -
+        include_indices[:])
 
     inverse = (
         jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape)
-        if return_inverse
-        else None
-    )
+        if return_inverse else None)
 
     return output_tensor, inverse, counts
 
@@ -3302,25 +3406,33 @@ def _aten_unique_consecutive(input_tensor,
 @op(torch.ops.aten.where.ScalarSelf)
 @op(torch.ops.aten.where.ScalarOther)
 @op(torch.ops.aten.where.Scalar)
-def _aten_where(condition, x
+def _aten_where(condition, x=None, y=None):
   return jnp.where(condition, x, y)
 
 
 # aten.to.dtype
 # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None
 @op(torch.ops.aten.to.dtype)
-def _aten_to_dtype(
- (2 lines not shown)
+def _aten_to_dtype(a,
+                   dtype,
+                   non_blocking=False,
+                   copy=False,
+                   memory_format=None):
   if dtype:
     jaxdtype = mappings.t2j_dtype(dtype)
     return a.astype(jaxdtype)
 
 
 @op(torch.ops.aten.to.dtype_layout)
-def _aten_to_dtype_layout(
- (2 lines not shown)
+def _aten_to_dtype_layout(a,
+                          *,
+                          dtype=None,
+                          layout=None,
+                          device=None,
+                          pin_memory=None,
+                          non_blocking=False,
+                          copy=False,
+                          memory_format=None):
   return _aten_to_dtype(
       a,
       dtype,
@@ -3328,6 +3440,7 @@ def _aten_to_dtype_layout(
       copy=copy,
       memory_format=memory_format)
 
+
 # aten.to.device
 
 
@@ -3348,9 +3461,11 @@ def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False):
 
 @op(torch.ops.aten.scalar_tensor)
 @op_base.convert_dtype()
-def _aten_scalar_tensor(
- (2 lines not shown)
+def _aten_scalar_tensor(s,
+                        dtype=None,
+                        layout=None,
+                        device=None,
+                        pin_memory=None):
   return jnp.array(s, dtype=dtype)
 
 
@@ -3360,9 +3475,9 @@ def _aten_to_device(x, device, dtype):
 
 
 @op(torch.ops.aten.max_pool2d_with_indices_backward)
-def max_pool2d_with_indices_backward_custom(
- (1 line not shown)
-):
+def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size,
+                                            stride, padding, dilation,
+                                            ceil_mode, indices):
   """
   Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward.
 
@@ -3418,15 +3533,15 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0):
 @op(torch.ops.aten.randn, needs_env=True)
 @op_base.convert_dtype()
 def _randn(
- (9 lines not shown)
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+    env=None,
 ):
   shape = size
   if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3437,13 +3552,14 @@ def _randn(
     res = res.astype(dtype)
   return res
 
+
 @op(torch.ops.aten.bernoulli.p, needs_env=True)
-def
- (5 lines not shown)
+def _aten_bernoulli(
+    self,
+    p=0.5,
+    *,
+    generator=None,
+    env=None,
 ):
   key = env.get_and_rotate_prng_key(generator)
   res = jax.random.uniform(key, self.shape) < p
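The Bernoulli sampler is a thresholded uniform draw; a standalone sketch with an explicit key (illustrative values only, without the torchax env plumbing):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
p = 0.3
samples = jax.random.uniform(key, (5,)) < p  # boolean Bernoulli(p) draws
print(samples.astype(jnp.float32))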
@@ -3460,14 +3576,14 @@ def geometric(self, p, *, generator=None, env=None):
 @op(torch.ops.aten.randn_like, needs_env=True)
 @op_base.convert_dtype()
 def _aten_randn_like(
- (8 lines not shown)
+    x,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    pin_memory=False,
+    memory_format=torch.preserve_format,
+    env=None,
 ):
   key = env.get_and_rotate_prng_key()
   return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape)
@@ -3476,15 +3592,15 @@ def _aten_randn_like(
 @op(torch.ops.aten.rand, needs_env=True)
 @op_base.convert_dtype()
 def _rand(
- (9 lines not shown)
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+    env=None,
 ):
   shape = size
   if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3505,18 +3621,32 @@ def _aten_outer(a, b):
 def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
   return jnp.allclose(input, other, rtol, atol, equal_nan)
 
+
 @op(torch.ops.aten.native_batch_norm)
-def _aten_native_batch_norm(input,
+def _aten_native_batch_norm(input,
+                            weight,
+                            bias,
+                            running_mean,
+                            running_var,
+                            training=False,
+                            momentum=0.1,
+                            eps=1e-5):
 
   if running_mean is None:
-    running_mean = jnp.zeros(
+    running_mean = jnp.zeros(
+        input.shape[1], dtype=input.dtype)  # Initialize running mean if None
   if running_var is None:
-    running_var = jnp.ones(
+    running_var = jnp.ones(
+        input.shape[1],
+        dtype=input.dtype)  # Initialize running variance if None
 
   if training:
-    return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
+    return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
+                                          running_var, training, momentum, eps)
   else:
-    return _aten__native_batch_norm_legit_no_training(input, weight, bias,
+    return _aten__native_batch_norm_legit_no_training(input, weight, bias,
+                                                      running_mean, running_var,
+                                                      momentum, eps)
 
 
 @op(torch.ops.aten.normal, needs_env=True)
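In the no-training path, batch norm reduces to per-channel normalization with the running statistics. A minimal numeric sketch of that formula (shapes and names are illustrative, not the torchax internals):

import jax.numpy as jnp

x = jnp.ones((2, 3, 4, 4))   # NCHW
mean = jnp.zeros(3)
var = jnp.ones(3)
weight = jnp.ones(3)
bias = jnp.zeros(3)
eps = 1e-5
# Broadcast per-channel statistics over N, H, W.
shape = (1, 3, 1, 1)
y = (x - mean.reshape(shape)) / jnp.sqrt(var.reshape(shape) + eps)
y = y * weight.reshape(shape) + bias.reshape(shape)
print(y.shape)  # (2, 3, 4, 4)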
@@ -3525,12 +3655,14 @@ def _aten_normal(self, mean=0, std=1, generator=None, env=None):
|
|
|
3525
3655
|
res = _randn(*shape, generator=generator, env=env)
|
|
3526
3656
|
return res * std + mean
|
|
3527
3657
|
|
|
3658
|
+
|
|
3528
3659
|
# TODO: not clear what this function should actually do
|
|
3529
3660
|
# https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940
|
|
3530
3661
|
@op(torch.ops.aten.lift_fresh)
|
|
3531
3662
|
def _aten_lift_fresh(self):
|
|
3532
3663
|
return self
|
|
3533
3664
|
|
|
3665
|
+
|
|
3534
3666
|
@op(torch.ops.aten.uniform, needs_env=True)
|
|
3535
3667
|
def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
|
|
3536
3668
|
assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})'
|
|
@@ -3538,16 +3670,18 @@ def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
|
|
|
3538
3670
|
res = _rand(*shape, generator=generator, env=env)
|
|
3539
3671
|
return res * (to - from_) + from_
|
|
3540
3672
|
|
|
3673
|
+
|
|
3541
3674
|
#func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
|
3542
3675
|
|
|
3676
|
+
|
|
3543
3677
|
@op(torch.ops.aten.randint, needs_env=True)
|
|
3544
3678
|
@op_base.convert_dtype(use_default_dtype=False)
|
|
3545
3679
|
def _aten_randint(
|
|
3546
|
-
|
|
3547
|
-
|
|
3548
|
-
|
|
3549
|
-
|
|
3550
|
-
|
|
3680
|
+
*args,
|
|
3681
|
+
generator=None,
|
|
3682
|
+
dtype=None,
|
|
3683
|
+
env=None,
|
|
3684
|
+
**kwargs,
|
|
3551
3685
|
):
|
|
3552
3686
|
if len(args) == 3:
|
|
3553
3687
|
# low, high, size
|
|
@@ -3556,7 +3690,8 @@ def _aten_randint(
|
|
|
3556
3690
|
high, size = args
|
|
3557
3691
|
low = 0
|
|
3558
3692
|
else:
|
|
3559
|
-
raise AssertionError(
|
|
3693
|
+
raise AssertionError(
|
|
3694
|
+
f'Expected at 2 or 3 args for Aten::randint, got {len(args)}')
|
|
3560
3695
|
|
|
3561
3696
|
key = env.get_and_rotate_prng_key(generator)
|
|
3562
3697
|
res = jax.random.randint(key, size, low, high)
|
|
@@ -3564,15 +3699,18 @@ def _aten_randint(
|
|
|
3564
3699
|
res = res.astype(dtype)
|
|
3565
3700
|
return res
|
|
3566
3701
|
|
|
3567
|
-
|
|
3702
|
+
|
|
3703
|
+
@op(torch.ops.aten.randint_like,
|
|
3704
|
+
torch.ops.aten.randint.generator,
|
|
3705
|
+
needs_env=True)
|
|
3568
3706
|
@op_base.convert_dtype(use_default_dtype=False)
|
|
3569
3707
|
def _aten_randint_like(
|
|
3570
|
-
|
|
3571
|
-
|
|
3572
|
-
|
|
3573
|
-
|
|
3574
|
-
|
|
3575
|
-
|
|
3708
|
+
input,
|
|
3709
|
+
*args,
|
|
3710
|
+
generator=None,
|
|
3711
|
+
dtype=None,
|
|
3712
|
+
env=None,
|
|
3713
|
+
**kwargs,
|
|
3576
3714
|
):
|
|
3577
3715
|
if len(args) == 2:
|
|
3578
3716
|
low, high = args
|
|
@@ -3580,7 +3718,8 @@ def _aten_randint_like(
|
|
|
3580
3718
|
high = args[0]
|
|
3581
3719
|
low = 0
|
|
3582
3720
|
else:
|
|
3583
|
-
raise AssertionError(
|
|
3721
|
+
raise AssertionError(
|
|
3722
|
+
f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}')
|
|
3584
3723
|
|
|
3585
3724
|
shape = input.shape
|
|
3586
3725
|
dtype = dtype or input.dtype
|
|
@@ -3590,6 +3729,7 @@ def _aten_randint_like(
|
|
|
3590
3729
|
res = res.astype(dtype)
|
|
3591
3730
|
return res
|
|
3592
3731
|
|
|
3732
|
+
|
|
3593
3733
|
@op(torch.ops.aten.dim, is_jax_function=False)
|
|
3594
3734
|
def _aten_dim(self):
|
|
3595
3735
|
return len(self.shape)
|
|
@@ -3602,10 +3742,11 @@ def _aten_copysign(input, other, *, out=None):
|
|
|
3602
3742
|
# regardless of their exact integer dtype, whereas jax.copysign returns
|
|
3603
3743
|
# float64 when one or both of them is int64.
|
|
3604
3744
|
if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype(
|
|
3605
|
-
|
|
3606
|
-
):
|
|
3745
|
+
other.dtype, jnp.integer):
|
|
3607
3746
|
result = result.astype(jnp.float32)
|
|
3608
3747
|
return result
|
|
3748
|
+
|
|
3749
|
+
|
|
3609
3750
|
@op(torch.ops.aten.i0)
|
|
3610
3751
|
@op_base.promote_int_input
|
|
3611
3752
|
def _aten_i0(self):
|
|
@@ -3637,6 +3778,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
|
|
|
3637
3778
|
|
|
3638
3779
|
@jnp.vectorize
|
|
3639
3780
|
def vectorized(x, n_i):
|
|
3781
|
+
|
|
3640
3782
|
def negative_n(x):
|
|
3641
3783
|
return jnp.zeros_like(x)
|
|
3642
3784
|
|
|
@@ -3650,6 +3792,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
|
|
|
3650
3792
|
return jnp.ones_like(x)
|
|
3651
3793
|
|
|
3652
3794
|
def default(x):
|
|
3795
|
+
|
|
3653
3796
|
def f(k, carry):
|
|
3654
3797
|
p, q = carry
|
|
3655
3798
|
return (q, ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1))
|
|
@@ -3658,9 +3801,9 @@ def _aten_special_laguerre_polynomial_l(self, n):
|
|
|
3658
3801
|
return q
|
|
3659
3802
|
|
|
3660
3803
|
return jnp.piecewise(
|
|
3661
|
-
x, [n_i == 1, n_i == 0,
|
|
3662
|
-
|
|
3663
|
-
|
|
3804
|
+
x, [n_i == 1, n_i == 0,
|
|
3805
|
+
jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0],
|
|
3806
|
+
[one_n, zero_n, zero_abs, negative_n, default])
|
|
3664
3807
|
|
|
3665
3808
|
return vectorized(self, n.astype(jnp.int64))
|
|
3666
3809
|
|
|
@@ -3760,125 +3903,124 @@ def _aten_special_modified_bessel_i0(self):
|
|
|
3760
3903
|
return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x)
|
|
3761
3904
|
|
|
3762
3905
|
self = jnp.abs(self)
|
|
3763
|
-
return jnp.piecewise(
|
|
3764
|
-
|
|
3765
|
-
)
|
|
3906
|
+
return jnp.piecewise(self, [self <= 8], [small, default])
|
|
3907
|
+
|
|
3766
3908
|
|
|
3767
3909
|
@op(torch.ops.aten.special_modified_bessel_i1)
|
|
3768
3910
|
@op_base.promote_int_input
|
|
3769
3911
|
def _aten_special_modified_bessel_i1(self):
|
|
3770
|
-
|
|
3771
|
-
|
|
3772
|
-
def small(x):
|
|
3773
|
-
A = jnp.array(
|
|
3774
|
-
[
|
|
3775
|
-
2.77791411276104639959e-18,
|
|
3776
|
-
-2.11142121435816608115e-17,
|
|
3777
|
-
1.55363195773620046921e-16,
|
|
3778
|
-
-1.10559694773538630805e-15,
|
|
3779
|
-
7.60068429473540693410e-15,
|
|
3780
|
-
-5.04218550472791168711e-14,
|
|
3781
|
-
3.22379336594557470981e-13,
|
|
3782
|
-
-1.98397439776494371520e-12,
|
|
3783
|
-
1.17361862988909016308e-11,
|
|
3784
|
-
-6.66348972350202774223e-11,
|
|
3785
|
-
3.62559028155211703701e-10,
|
|
3786
|
-
-1.88724975172282928790e-09,
|
|
3787
|
-
9.38153738649577178388e-09,
|
|
3788
|
-
-4.44505912879632808065e-08,
|
|
3789
|
-
2.00329475355213526229e-07,
|
|
3790
|
-
-8.56872026469545474066e-07,
|
|
3791
|
-
3.47025130813767847674e-06,
|
|
3792
|
-
-1.32731636560394358279e-05,
|
|
3793
|
-
4.78156510755005422638e-05,
|
|
3794
|
-
-1.61760815825896745588e-04,
|
|
3795
|
-
5.12285956168575772895e-04,
|
|
3796
|
-
-1.51357245063125314899e-03,
|
|
3797
|
-
4.15642294431288815669e-03,
|
|
3798
|
-
-1.05640848946261981558e-02,
|
|
3799
|
-
2.47264490306265168283e-02,
|
|
3800
|
-
-5.29459812080949914269e-02,
|
|
3801
|
-
1.02643658689847095384e-01,
|
|
3802
|
-
-1.76416518357834055153e-01,
|
|
3803
|
-
-            2.52587186443633654823e-01,
-        ],
-        dtype=self.dtype,
-    )
-
-    def f(carry, val):
-      p, q, a = carry
-      p, q = q, a
-      return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
-
-    (p, _, a), _ = jax.lax.scan(
-        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
-
-    return jax.lax.cond(
-        x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))
-    )
+  # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364
 
-      return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
-    (p, _, b), _ = jax.lax.scan(
-        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
-    return jax.lax.cond(
-        x < 0, lambda: -(jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))
-    )
+  def small(x):
+    A = jnp.array(
+        [
+            2.77791411276104639959e-18,
+            -2.11142121435816608115e-17,
+            1.55363195773620046921e-16,
+            -1.10559694773538630805e-15,
+            7.60068429473540693410e-15,
+            -5.04218550472791168711e-14,
+            3.22379336594557470981e-13,
+            -1.98397439776494371520e-12,
+            1.17361862988909016308e-11,
+            -6.66348972350202774223e-11,
+            3.62559028155211703701e-10,
+            -1.88724975172282928790e-09,
+            9.38153738649577178388e-09,
+            -4.44505912879632808065e-08,
+            2.00329475355213526229e-07,
+            -8.56872026469545474066e-07,
+            3.47025130813767847674e-06,
+            -1.32731636560394358279e-05,
+            4.78156510755005422638e-05,
+            -1.61760815825896745588e-04,
+            5.12285956168575772895e-04,
+            -1.51357245063125314899e-03,
+            4.15642294431288815669e-03,
+            -1.05640848946261981558e-02,
+            2.47264490306265168283e-02,
+            -5.29459812080949914269e-02,
+            1.02643658689847095384e-01,
+            -1.76416518357834055153e-01,
+            2.52587186443633654823e-01,
+        ],
+        dtype=self.dtype,
+    )
 
+    def f(carry, val):
+      p, q, a = carry
+      p, q = q, a
+      return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
+
+    (p, _, a), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
+
+    return jax.lax.cond(
+        x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))),
+        lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x)))
+
+  def default(x):
+    B = jnp.array(
+        [
+            7.51729631084210481353e-18,
+            4.41434832307170791151e-18,
+            -4.65030536848935832153e-17,
+            -3.20952592199342395980e-17,
+            2.96262899764595013876e-16,
+            3.30820231092092828324e-16,
+            -1.88035477551078244854e-15,
+            -3.81440307243700780478e-15,
+            1.04202769841288027642e-14,
+            4.27244001671195135429e-14,
+            -2.10154184277266431302e-14,
+            -4.08355111109219731823e-13,
+            -7.19855177624590851209e-13,
+            2.03562854414708950722e-12,
+            1.41258074366137813316e-11,
+            3.25260358301548823856e-11,
+            -1.89749581235054123450e-11,
+            -5.58974346219658380687e-10,
+            -3.83538038596423702205e-09,
+            -2.63146884688951950684e-08,
+            -2.51223623787020892529e-07,
+            -3.88256480887769039346e-06,
+            -1.10588938762623716291e-04,
+            -9.76109749136146840777e-03,
+            7.78576235018280120474e-01,
+        ],
+        dtype=self.dtype,
     )
 
+    def f(carry, val):
+      p, q, b = carry
+      p, q = q, b
+      return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
+
+    (p, _, b), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
+
+    return jax.lax.cond(
+        x < 0, lambda: -(jnp.exp(jnp.abs(x)) *
+                         (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))),
+        lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)))
+
+  return jnp.piecewise(self, [self <= 8], [small, default])
+
+
 @op(torch.ops.aten.special_modified_bessel_k0)
 @op_base.promote_int_input
 def _aten_special_modified_bessel_k0(self):
+  # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441
 
+  def zero(x):
+    return jnp.array(jnp.inf, x.dtype)
 
+  def negative(x):
+    return jnp.array(jnp.nan, x.dtype)
 
+  def small(x):
+    A = jnp.array(
+        [
             1.37446543561352307156e-16,
             4.25981614279661018399e-14,
             1.03496952576338420167e-11,
@@ -3889,23 +4031,24 @@ def _aten_special_modified_bessel_k0(self):
             3.59799365153615016266e-02,
             3.44289899924628486886e-01,
             -5.35327393233902768720e-01,
+        ],
+        dtype=self.dtype,
+    )
 
+    def f(carry, val):
+      p, q, a = carry
+      p, q = q, a
+      return (p, q, (x * x - 2.0) * q - p + val), None
+
+    (p, _, a), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
 
-    return 0.5 * (a - p) - jnp.log(0.5 * x) * _aten_special_modified_bessel_i0(x)
+    return 0.5 * (a - p) - jnp.log(
+        0.5 * x) * _aten_special_modified_bessel_i0(x)
 
+  def default(x):
+    B = jnp.array(
+        [
             5.30043377268626276149e-18,
             -1.64758043015242134646e-17,
             5.21039150503902756861e-17,
@@ -3931,38 +4074,38 @@ def _aten_special_modified_bessel_k0(self):
             1.56988388573005337491e-03,
             -3.14481013119645005427e-02,
             2.44030308206595545468e+00,
+        ],
+        dtype=self.dtype,
+    )
+
+    def f(carry, val):
+      p, q, b = carry
+      p, q = q, b
+      return (p, q, (8.0 / x - 2.0) * q - p + val), None
 
-      p, q = q, b
-      return (p, q, (8.0 / x - 2.0) * q - p + val), None
+    (p, _, b), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
 
+    return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
+
+  return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
+                       [small, negative, zero, default])
 
-  return jnp.piecewise(
-      self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
-  )
 
 @op(torch.ops.aten.special_modified_bessel_k1)
 @op_base.promote_int_input
 def _aten_special_modified_bessel_k1(self):
+  # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519
 
+  def zero(x):
+    return jnp.array(jnp.inf, x.dtype)
 
+  def negative(x):
+    return jnp.array(jnp.nan, x.dtype)
 
+  def small(x):
+    A = jnp.array(
+        [
             -7.02386347938628759343e-18,
             -2.42744985051936593393e-15,
             -6.66690169419932900609e-13,
@@ -3974,24 +4117,25 @@ def _aten_special_modified_bessel_k1(self):
             -1.22611180822657148235e-01,
             -3.53155960776544875667e-01,
             1.52530022733894777053e+00,
+        ],
+        dtype=self.dtype,
+    )
 
+    def f(carry, val):
+      p, q, a = carry
+      p, q = q, a
+      a = (x * x - 2.0) * q - p + val
+      return (p, q, a), None
+
+    (p, _, a), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
 
-    return jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
+    return jnp.log(
+        0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
 
+  def default(x):
+    B = jnp.array(
+        [
             -5.75674448366501715755e-18,
             1.79405087314755922667e-17,
             -5.68946255844285935196e-17,
@@ -4017,24 +4161,24 @@ def _aten_special_modified_bessel_k1(self):
             -2.85781685962277938680e-03,
             1.03923736576817238437e-01,
             2.72062619048444266945e+00,
+        ],
+        dtype=self.dtype,
+    )
+
+    def f(carry, val):
+      p, q, b = carry
+      p, q = q, b
+      b = (8.0 / x - 2.0) * q - p + val
+      return (p, q, b), None
+
+    (p, _, b), _ = jax.lax.scan(
+        f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
 
-      p, q, b = carry
-      p, q = q, b
-      b = (8.0 / x - 2.0) * q - p + val
-      return (p, q, b), None
+    return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
 
-    return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
+  return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
+                       [small, negative, zero, default])
 
-  return jnp.piecewise(
-      self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
-  )
 
 @op(torch.ops.aten.polygamma)
 def _aten_polygamma(x, n):
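Note, not part of the diff: the modified-Bessel kernels above all evaluate a Chebyshev-style series with the same carry pattern, shift the last two partial sums and fold in one coefficient per jax.lax.scan step, then dispatch with jnp.piecewise. A minimal standalone sketch of that recurrence; `chebyshev_series`, `coeffs` and `x` are made-up names, not values from this release.

    import jax
    import jax.numpy as jnp

    def chebyshev_series(x, coeffs):
      # Same carry pattern as the kernels above: keep the last two partial
      # sums (p, q) and fold in one coefficient per scan step.
      def step(carry, c):
        p, q, a = carry
        p, q = q, a
        return (p, q, (x * x - 2.0) * q - p + c), None

      init = (jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x))
      (p, _, a), _ = jax.lax.scan(step, init, xs=coeffs)
      return 0.5 * (a - p)
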
@@ -4042,10 +4186,12 @@ def _aten_polygamma(x, n):
   n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
   return jax.lax.polygamma(jnp.float32(x), n)
 
+
 @op(torch.ops.aten.special_ndtri)
 @op_base.promote_int_input
 def _aten_special_ndtri(self):
+  return jax.scipy.special.ndtri(self)
+
 
 @op(torch.ops.aten.special_bessel_j0)
 @op_base.promote_int_input
@@ -4057,112 +4203,104 @@ def _aten_special_bessel_j0(self):
 
   def small(x):
     RP = jnp.array(
+        [
+            -4.79443220978201773821e09,
+            1.95617491946556577543e12,
+            -2.49248344360967716204e14,
+            9.70862251047306323952e15,
+        ],
+        dtype=self.dtype,
     )
     RQ = jnp.array(
+        [
+            4.99563147152651017219e02,
+            1.73785401676374683123e05,
+            4.84409658339962045305e07,
+            1.11855537045356834862e10,
+            2.11277520115489217587e12,
+            3.10518229857422583814e14,
+            3.18121955943204943306e16,
+            1.71086294081043136091e18,
+        ],
+        dtype=self.dtype,
     )
 
     rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
     rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
 
-    return (
-        * (x * x - 3.04712623436620863991e01)
-        * rp
-        / rq
-    )
+    return ((x * x - 5.78318596294678452118e00) *
+            (x * x - 3.04712623436620863991e01) * rp / rq)
 
   def default(x):
     PP = jnp.array(
+        [
+            7.96936729297347051624e-04,
+            8.28352392107440799803e-02,
+            1.23953371646414299388e00,
+            5.44725003058768775090e00,
+            8.74716500199817011941e00,
+            5.30324038235394892183e00,
+            9.99999999999999997821e-01,
+        ],
+        dtype=self.dtype,
     )
     PQ = jnp.array(
+        [
+            9.24408810558863637013e-04,
+            8.56288474354474431428e-02,
+            1.25352743901058953537e00,
+            5.47097740330417105182e00,
+            8.76190883237069594232e00,
+            5.30605288235394617618e00,
+            1.00000000000000000218e00,
+        ],
+        dtype=self.dtype,
     )
     QP = jnp.array(
+        [
+            -1.13663838898469149931e-02,
+            -1.28252718670509318512e00,
+            -1.95539544257735972385e01,
+            -9.32060152123768231369e01,
+            -1.77681167980488050595e02,
+            -1.47077505154951170175e02,
+            -5.14105326766599330220e01,
+            -6.05014350600728481186e00,
+        ],
+        dtype=self.dtype,
     )
     QQ = jnp.array(
+        [
+            6.43178256118178023184e01,
+            8.56430025976980587198e02,
+            3.88240183605401609683e03,
+            7.24046774195652478189e03,
+            5.93072701187316984827e03,
+            2.06209331660327847417e03,
+            2.42005740240291393179e02,
+        ],
+        dtype=self.dtype,
     )
 
-    pp = op_base.foreach_loop(
-        * 0.797884560802865355879892119868763737
-        / jnp.sqrt(x)
-    )
+    pp = op_base.foreach_loop(
+        PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
+    pq = op_base.foreach_loop(
+        PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
+    qp = op_base.foreach_loop(
+        QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
+    qq = op_base.foreach_loop(
+        QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
+
+    return ((pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) -
+             5.0 / x *
+             (qp / qq) * jnp.sin(x - 0.785398163397448309615660845819875721)) *
+            0.797884560802865355879892119868763737 / jnp.sqrt(x))
 
   self = jnp.abs(self)
   # Last True condition in `piecewise` takes priority, but last function is
   # default. See https://github.com/numpy/numpy/issues/16475
-  return jnp.piecewise(
-  )
+  return jnp.piecewise(self, [self <= 5.0, self < 0.00001],
+                       [small, very_small, default])
 
 
 @op(torch.ops.aten.special_bessel_j1)
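Note, not part of the diff: the rational approximations above are evaluated Horner-style, each op_base.foreach_loop call folds `carry * z + c_i` over a coefficient table with the highest-order coefficient first. A rough plain-jnp equivalent of that fold; `horner` is a made-up helper name and `coeffs` a plain Python list, only the folding pattern is taken from the code above.

    import jax.numpy as jnp

    def horner(coeffs, z):
      # Fold `carry * z + c` over the table, highest-order coefficient first,
      # the same shape as the lambdas passed to op_base.foreach_loop above.
      acc = jnp.zeros_like(z)
      for c in coeffs:
        acc = acc * z + c
      return acc

    # e.g. the small-|x| branch of j0 is, up to the two quadratic factors,
    # horner(RP, x * x) / horner(RQ, x * x)
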
@@ -4172,114 +4310,106 @@ def _aten_special_bessel_j1(self):
 
   def small(x):
     RP = jnp.array(
+        [
+            -8.99971225705559398224e08,
+            4.52228297998194034323e11,
+            -7.27494245221818276015e13,
+            3.68295732863852883286e15,
+        ],
+        dtype=self.dtype,
     )
     RQ = jnp.array(
+        [
+            6.20836478118054335476e02,
+            2.56987256757748830383e05,
+            8.35146791431949253037e07,
+            2.21511595479792499675e10,
+            4.74914122079991414898e12,
+            7.84369607876235854894e14,
+            8.95222336184627338078e16,
+            5.32278620332680085395e18,
+        ],
+        dtype=self.dtype,
     )
 
     rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
     rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
 
-    return (
-        / rq
-        * x
-        * (x * x - 1.46819706421238932572e01)
-        * (x * x - 4.92184563216946036703e01)
-    )
+    return (rp / rq * x * (x * x - 1.46819706421238932572e01) *
+            (x * x - 4.92184563216946036703e01))
 
   def default(x):
     PP = jnp.array(
+        [
+            7.62125616208173112003e-04,
+            7.31397056940917570436e-02,
+            1.12719608129684925192e00,
+            5.11207951146807644818e00,
+            8.42404590141772420927e00,
+            5.21451598682361504063e00,
+            1.00000000000000000254e00,
+        ],
+        dtype=self.dtype,
     )
     PQ = jnp.array(
+        [
+            5.71323128072548699714e-04,
+            6.88455908754495404082e-02,
+            1.10514232634061696926e00,
+            5.07386386128601488557e00,
+            8.39985554327604159757e00,
+            5.20982848682361821619e00,
+            9.99999999999999997461e-01,
+        ],
+        dtype=self.dtype,
     )
     QP = jnp.array(
+        [
+            5.10862594750176621635e-02,
+            4.98213872951233449420e00,
+            7.58238284132545283818e01,
+            3.66779609360150777800e02,
+            7.10856304998926107277e02,
+            5.97489612400613639965e02,
+            2.11688757100572135698e02,
+            2.52070205858023719784e01,
+        ],
+        dtype=self.dtype,
     )
     QQ = jnp.array(
+        [
+            7.42373277035675149943e01,
+            1.05644886038262816351e03,
+            4.98641058337653607651e03,
+            9.56231892404756170795e03,
+            7.99704160447350683650e03,
+            2.82619278517639096600e03,
+            3.36093607810698293419e02,
+        ],
+        dtype=self.dtype,
     )
 
-    pp = op_base.foreach_loop(
-        * 0.797884560802865355879892119868763737
-        / jnp.sqrt(x)
-    )
+    pp = op_base.foreach_loop(
+        PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
+    pq = op_base.foreach_loop(
+        PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
+    qp = op_base.foreach_loop(
+        QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
+    qq = op_base.foreach_loop(
+        QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
+
+    return ((pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) -
+             5.0 / x *
+             (qp / qq) * jnp.sin(x - 2.356194490192344928846982537459627163)) *
+            0.797884560802865355879892119868763737 / jnp.sqrt(x))
 
   # If x < 0, bessel_j1(x) = -bessel_j1(-x)
   sign = jnp.sign(self)
   self = jnp.abs(self)
   return sign * jnp.piecewise(
+      self,
+      [self <= 5.0],
+      [small, default],
   )
 
 
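Note, not part of the diff: J1 is an odd function, which is why the kernel above evaluates on |x| and multiplies the sign back in. A tiny sketch of that wrapper; `as_odd` and `f` are made-up names.

    import jax.numpy as jnp

    def as_odd(f, x):
      # J1(-x) = -J1(x): evaluate on |x| and restore the sign, as above.
      return jnp.sign(x) * f(jnp.abs(x))
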
@@ -4296,85 +4426,86 @@ def _aten_special_bessel_y0(self):
 
   def small(x):
     YP = jnp.array(
+        [
+            1.55924367855235737965e04,
+            -1.46639295903971606143e07,
+            5.43526477051876500413e09,
+            -9.82136065717911466409e11,
+            8.75906394395366999549e13,
+            -3.46628303384729719441e15,
+            4.42733268572569800351e16,
+            -1.84950800436986690637e16,
+        ],
+        dtype=self.dtype,
     )
     YQ = jnp.array(
+        [
+            1.04128353664259848412e03,
+            6.26107330137134956842e05,
+            2.68919633393814121987e08,
+            8.64002487103935000337e10,
+            2.02979612750105546709e13,
+            3.17157752842975028269e15,
+            2.50596256172653059228e17,
+        ],
+        dtype=self.dtype,
     )
 
     yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
     yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
 
-    return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) *
+    return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) *
+                      _aten_special_bessel_j0(x))
 
   def default(x):
     PP = jnp.array(
+        [
+            7.96936729297347051624e-04,
+            8.28352392107440799803e-02,
+            1.23953371646414299388e00,
+            5.44725003058768775090e00,
+            8.74716500199817011941e00,
+            5.30324038235394892183e00,
+            9.99999999999999997821e-01,
+        ],
+        dtype=self.dtype,
     )
     PQ = jnp.array(
+        [
+            9.24408810558863637013e-04,
+            8.56288474354474431428e-02,
+            1.25352743901058953537e00,
+            5.47097740330417105182e00,
+            8.76190883237069594232e00,
+            5.30605288235394617618e00,
+            1.00000000000000000218e00,
+        ],
+        dtype=self.dtype,
     )
     QP = jnp.array(
+        [
+            -1.13663838898469149931e-02,
+            -1.28252718670509318512e00,
+            -1.95539544257735972385e01,
+            -9.32060152123768231369e01,
+            -1.77681167980488050595e02,
+            -1.47077505154951170175e02,
+            -5.14105326766599330220e01,
+            -6.05014350600728481186e00,
+        ],
+        dtype=self.dtype,
     )
     QQ = jnp.array(
+        [
+            6.43178256118178023184e01,
+            8.56430025976980587198e02,
+            3.88240183605401609683e03,
+            7.24046774195652478189e03,
+            5.93072701187316984827e03,
+            2.06209331660327847417e03,
+            2.42005740240291393179e02,
+        ],
+        dtype=self.dtype,
     )
 
     factor = 25.0 / (x * x)
@@ -4383,22 +4514,15 @@ def _aten_special_bessel_y0(self):
     qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
     qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
 
-    return (
-        / x
-        * (qp / qq)
-        * jnp.cos(x - 0.785398163397448309615660845819875721)
-    )
-        * 0.797884560802865355879892119868763737
-        / jnp.sqrt(x)
-    )
+    return ((pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) +
+             5.0 / x *
+             (qp / qq) * jnp.cos(x - 0.785398163397448309615660845819875721)) *
+            0.797884560802865355879892119868763737 / jnp.sqrt(x))
 
   return jnp.piecewise(
+      self,
+      [self <= 5.0, self < 0., self == 0.],
+      [small, negative, zero, default],
   )
 
 
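Note, not part of the diff: the piecewise dispatches above rely on the numpy/jax convention that the last matching condition wins and the extra trailing function is the default, as the comment in the j0 hunk points out. A small illustration with made-up branch functions.

    import jax.numpy as jnp

    x = jnp.array([-2.0, 0.0, 3.0, 9.0])
    out = jnp.piecewise(
        x,
        [x <= 5.0, x < 0.0, x == 0.0],
        [lambda v: v,                   # "small": used where no later condition matches
         lambda v: -jnp.ones_like(v),   # "negative": overrides "small" where x < 0
         lambda v: jnp.zeros_like(v),   # "zero": overrides both where x == 0
         lambda v: v * 10.0])           # default: where no condition is True
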
@@ -4415,90 +4539,86 @@ def _aten_special_bessel_y1(self):
 
   def small(x):
     YP = jnp.array(
+        [
+            1.26320474790178026440e09,
+            -6.47355876379160291031e11,
+            1.14509511541823727583e14,
+            -8.12770255501325109621e15,
+            2.02439475713594898196e17,
+            -7.78877196265950026825e17,
+        ],
+        dtype=self.dtype,
     )
     YQ = jnp.array(
+        [
+            5.94301592346128195359e02,
+            2.35564092943068577943e05,
+            7.34811944459721705660e07,
+            1.87601316108706159478e10,
+            3.88231277496238566008e12,
+            6.20557727146953693363e14,
+            6.87141087355300489866e16,
+            3.97270608116560655612e18,
+        ],
+        dtype=self.dtype,
     )
 
     yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
     yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
 
-    return (
-        0.636619772367581343075535053490057448
-        * (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)
-    )
-    )
+    return (x * (yp / yq) +
+            (0.636619772367581343075535053490057448 *
+             (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)))
 
   def default(x):
     PP = jnp.array(
+        [
+            7.62125616208173112003e-04,
+            7.31397056940917570436e-02,
+            1.12719608129684925192e00,
+            5.11207951146807644818e00,
+            8.42404590141772420927e00,
+            5.21451598682361504063e00,
+            1.00000000000000000254e00,
+        ],
+        dtype=self.dtype,
     )
     PQ = jnp.array(
+        [
+            5.71323128072548699714e-04,
+            6.88455908754495404082e-02,
+            1.10514232634061696926e00,
+            5.07386386128601488557e00,
+            8.39985554327604159757e00,
+            5.20982848682361821619e00,
+            9.99999999999999997461e-01,
+        ],
+        dtype=self.dtype,
     )
     QP = jnp.array(
+        [
+            5.10862594750176621635e-02,
+            4.98213872951233449420e00,
+            7.58238284132545283818e01,
+            3.66779609360150777800e02,
+            7.10856304998926107277e02,
+            5.97489612400613639965e02,
+            2.11688757100572135698e02,
+            2.52070205858023719784e01,
+        ],
+        dtype=self.dtype,
     )
     QQ = jnp.array(
+        [
+            7.42373277035675149943e01,
+            1.05644886038262816351e03,
+            4.98641058337653607651e03,
+            9.56231892404756170795e03,
+            7.99704160447350683650e03,
+            2.82619278517639096600e03,
+            3.36093607810698293419e02,
+        ],
+        dtype=self.dtype,
     )
 
     factor = 25.0 / (x * x)
@@ -4507,22 +4627,15 @@ def _aten_special_bessel_y1(self):
     qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
     qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
 
-    return (
-        / x
-        * (qp / qq)
-        * jnp.cos(x - 2.356194490192344928846982537459627163)
-    )
-        * 0.797884560802865355879892119868763737
-        / jnp.sqrt(x)
-    )
+    return ((pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) +
+             5.0 / x *
+             (qp / qq) * jnp.cos(x - 2.356194490192344928846982537459627163)) *
+            0.797884560802865355879892119868763737 / jnp.sqrt(x))
 
   return jnp.piecewise(
+      self,
+      [self <= 5.0, self < 0., self == 0.],
+      [small, negative, zero, default],
   )
 
 
@@ -4533,11 +4646,13 @@ def _aten_special_chebyshev_polynomial_t(self, n):
 
   @jnp.vectorize
   def vectorized(x, n_i):
+
     def negative_n(x):
       return jnp.zeros_like(x)
 
     def one_x(x):
-      return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x),
+      return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x),
+                       -jnp.ones_like(x))
 
     def large_n_small_x(x):
       return jnp.cos(n_i * jnp.acos(x))
@@ -4549,24 +4664,18 @@ def _aten_special_chebyshev_polynomial_t(self, n):
       return x
 
     def default(x):
+
       def f(_, carry):
        p, q = carry
        return (q, 2 * x * q - p)
 
-      _, r
+      _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x))
      return r
 
-    return jnp.piecewise(
-        n_i == 0,
-        (n_i == 6) & (jnp.abs(x) < 1),
-        jnp.abs(x) == 1.,
-        n_i < 0
-        ],
-        [one_n, zero_n, large_n_small_x, one_x, negative_n, default]
-    )
+    return jnp.piecewise(x, [
+        n_i == 1, n_i == 0, (n_i == 6) & (jnp.abs(x) < 1),
+        jnp.abs(x) == 1., n_i < 0
+    ], [one_n, zero_n, large_n_small_x, one_x, negative_n, default])
 
   # Explcicitly vectorize since we must vectorizes over both self and n
   return vectorized(self, n.astype(jnp.int64))
@@ -4579,6 +4688,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
 
   @jnp.vectorize
   def vectorized(x, n_i):
+
     def negative_n(x):
       return jnp.zeros_like(x)
 
@@ -4588,9 +4698,9 @@ def _aten_special_chebyshev_polynomial_u(self, n):
     def large_n_small_x(x):
       sin_acos_x = jnp.sin(jnp.acos(x))
       return jnp.where(
+          sin_acos_x != 0,
+          jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x,
+          (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x,
       )
 
     def zero_n(x):
@@ -4600,6 +4710,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
       return 2 * x
 
     def default(x):
+
       def f(_, carry):
        p, q = carry
        return (q, 2 * x * q - p)
@@ -4608,15 +4719,15 @@ def _aten_special_chebyshev_polynomial_u(self, n):
      return r
 
     return jnp.piecewise(
+        x,
+        [
+            n_i == 1,
+            n_i == 0,
+            (n_i > 8) & (jnp.abs(x) < 1),
+            jnp.abs(x) == 1.0,
+            n_i < 0,
+        ],
+        [one_n, zero_n, large_n_small_x, one_x, negative_n, default],
     )
 
   return vectorized(self, n.astype(jnp.int64))
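Note, not part of the diff: both Chebyshev kernels drive the three-term recurrence T_{k+1} = 2·x·T_k − T_{k−1} with jax.lax.fori_loop, exactly as the `f` bodies above do. A standalone sketch for polynomials of the first kind; `chebyshev_t` is a made-up helper and `n` is assumed to be a Python int ≥ 1.

    import jax
    import jax.numpy as jnp

    def chebyshev_t(x, n):
      # T_0 = 1, T_1 = x, T_{k+1} = 2 x T_k - T_{k-1}; same update as `f` above.
      def body(_, carry):
        p, q = carry
        return (q, 2 * x * q - p)

      _, t_n = jax.lax.fori_loop(0, n - 1, body, (jnp.ones_like(x), x))
      return t_n  # valid for n >= 1
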
@@ -4627,6 +4738,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
 def _aten_special_erfcx(x):
   return jnp.exp(x * x) * jax.lax.erfc(x)
 
+
 @op(torch.ops.aten.erfc)
 @op_base.promote_int_input
 def _aten_erfcx(x):
@@ -4640,6 +4752,7 @@ def _aten_special_hermite_polynomial_h(self, n):
 
   @jnp.vectorize
   def vectorized(x, n_i):
+
     def negative_n(x):
       return jnp.zeros_like(x)
 
@@ -4650,6 +4763,7 @@ def _aten_special_hermite_polynomial_h(self, n):
       return 2 * x
 
     def default(x):
+
       def f(k, carry):
        p, q = carry
        return (q, 2 * x * q - 2 * k * p)
@@ -4657,9 +4771,8 @@ def _aten_special_hermite_polynomial_h(self, n):
      _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x))
      return r
 
-    return jnp.piecewise(
-    )
+    return jnp.piecewise(x, [n_i == 1, n_i == 0, n_i < 0],
+                         [one_n, zero_n, negative_n, default])
 
   return vectorized(self, n.astype(jnp.int64))
 
@@ -4671,6 +4784,7 @@ def _aten_special_hermite_polynomial_he(self, n):
 
   @jnp.vectorize
   def vectorized(x, n_i):
+
    def negative_n(x):
      return jnp.zeros_like(x)
 
@@ -4681,6 +4795,7 @@ def _aten_special_hermite_polynomial_he(self, n):
      return x
 
    def default(x):
+
      def f(k, carry):
        p, q = carry
        return (q, x * q - k * p)
@@ -4688,24 +4803,34 @@ def _aten_special_hermite_polynomial_he(self, n):
      _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x))
      return r
 
-    return jnp.piecewise(
-    )
+    return jnp.piecewise(x, [n_i == 1.0, n_i == 0.0, n_i < 0],
+                         [one_n, zero_n, negative_n, default])
 
   return vectorized(self, n.astype(jnp.int64))
 
 
 @op(torch.ops.aten.multinomial, needs_env=True)
-def _aten_multinomial(input,
+def _aten_multinomial(input,
+                      num_samples,
+                      replacement=False,
+                      *,
+                      generator=None,
+                      out=None,
+                      env=None):
+  assert num_samples <= input.shape[
+      -1] or replacement, "cannot take a larger sample than population when replacement=False"
   key = env.get_and_rotate_prng_key(generator)
   if input.ndim == 1:
+    return jax.random.choice(
+        key, input.shape[-1], (num_samples,), replace=replacement, p=input)
   else:
+    return jnp.array([
+        jax.random.choice(
+            key,
+            input.shape[-1], (num_samples,),
+            replace=replacement,
+            p=input[i, :]) for i in range(input.shape[0])
+    ])
 
 
 @op(torch.ops.aten.narrow)
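Note, not part of the diff: the rewritten multinomial above defers the actual sampling to jax.random.choice. A minimal usage sketch; the PRNGKey here is a stand-in, the op itself draws its key from the torchax environment via env.get_and_rotate_prng_key.

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)  # stand-in key for illustration only
    probs = jnp.array([0.1, 0.2, 0.7])
    draws = jax.random.choice(key, probs.shape[-1], (5,), replace=True, p=probs)
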
@@ -4738,7 +4863,12 @@ def _aten_flatten(x, start_dim=0, end_dim=-1):
 
 @op(torch.ops.aten.new_empty)
 def _new_empty(self, size, **kwargs):
+  dtype = kwargs.get('dtype')
+  if dtype is not None:
+    dtype = mappings.t2j_dtype(dtype)
+  else:
+    dtype = self.dtype
+  return jnp.empty(size, dtype=dtype)
 
 
 @op(torch.ops.aten.new_empty_strided)
@@ -4756,10 +4886,8 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False):
   return _aten_index_put(self, indices, values, accumulate)
 
 
-@op(torch.ops.aten.conj_physical,
-    torch.ops.aten.
-    torch.ops.aten._conj_physical,
-    torch.ops.aten._conj)
+@op(torch.ops.aten.conj_physical, torch.ops.aten.conj,
+    torch.ops.aten._conj_physical, torch.ops.aten._conj)
 def _aten_conj_physical(self):
   return jnp.conjugate(self)
 
@@ -4768,6 +4896,7 @@ def _aten_conj_physical(self):
 def _aten_log_sigmoid(x):
   return jax.nn.log_sigmoid(x)
 
+
 # torch.qr
 @op(torch.ops.aten.qr)
 def _aten_qr(input, *args, **kwargs):
@@ -4778,6 +4907,7 @@ def _aten_qr(input, *args, **kwargs):
     jax_mode = "complete"
   return jax.numpy.linalg.qr(input, mode=jax_mode)
 
+
 # torch.linalg.qr
 @op(torch.ops.aten.linalg_qr)
 def _aten_linalg_qr(input, *args, **kwargs):
@@ -4820,19 +4950,25 @@ def _aten__linalg_solve_ex(a, b):
   res = jnp.linalg.solve(a, b)
   if batched:
     res = res.squeeze(-1)
-  info_shape = a.shape[
+  info_shape = a.shape[:-2]
   info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
   return res, info
 
 
 # torch.linalg.solve_triangular
 @op(torch.ops.aten.linalg_solve_triangular)
-def _aten_linalg_solve_triangular(a,
+def _aten_linalg_solve_triangular(a,
+                                  b,
+                                  *,
+                                  upper=True,
+                                  left=True,
+                                  unitriangular=False):
   if left is False:
     a = jnp.matrix_transpose(a)
     b = jnp.matrix_transpose(b)
     upper = not upper
-  res = jax.scipy.linalg.solve_triangular(
+  res = jax.scipy.linalg.solve_triangular(
+      a, b, lower=not upper, unit_diagonal=unitriangular)
   if left is False:
     res = jnp.matrix_transpose(res)
   return res
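Note, not part of the diff: the solve_triangular rewrite handles left=False (solve X·A = B) by transposing both operands, which also flips upper and lower. A small sketch of that identity; `solve_right` is a made-up helper name.

    import jax.numpy as jnp
    import jax.scipy.linalg

    # X A = B  <=>  A^T X^T = B^T; transposing a triangular A swaps upper/lower.
    def solve_right(a, b, upper=True, unitriangular=False):
      at = jnp.matrix_transpose(a)
      bt = jnp.matrix_transpose(b)
      xt = jax.scipy.linalg.solve_triangular(
          at, bt, lower=upper, unit_diagonal=unitriangular)
      return jnp.matrix_transpose(xt)
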
@@ -4852,21 +4988,31 @@ def _aten__linalg_check_errors(*args, **kwargs):
 
 @op(torch.ops.aten.median)
 def _aten_median(self, dim=None, keepdim=False):
-  output = _with_reduction_scalar(
+  output = _with_reduction_scalar(
+      functools.partial(jnp.quantile, q=0.5, method='lower'),
+      self,
+      dim=dim,
+      keepdim=keepdim).astype(self.dtype)
   if dim is None:
     return output
   else:
-    index = _with_reduction_scalar(_get_median_index, self, dim,
+    index = _with_reduction_scalar(_get_median_index, self, dim,
+                                   keepdim).astype(jnp.int64)
     return output, index
 
 
 @op(torch.ops.aten.nanmedian)
 def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
-  output = _with_reduction_scalar(
+  output = _with_reduction_scalar(
+      functools.partial(jnp.nanquantile, q=0.5, method='lower'),
+      input,
+      dim=dim,
+      keepdim=keepdim).astype(input.dtype)
   if dim is None:
     return output
   else:
-    index = _with_reduction_scalar(_get_median_index, input, dim,
+    index = _with_reduction_scalar(_get_median_index, input, dim,
+                                   keepdim).astype(jnp.int64)
     return output, index
 
 
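Note, not part of the diff: median and nanmedian above use the 'lower' quantile method so the result is an actual element of the input rather than an interpolated value, which matches torch.median's convention. A quick check of that behaviour on made-up data.

    import functools
    import jax.numpy as jnp

    x = jnp.array([1., 2., 3., 4.])
    lower_median = functools.partial(jnp.quantile, q=0.5, method='lower')(x)
    # -> 2.0, an element of x; linear interpolation would give 2.5 here.
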
@@ -4874,20 +5020,31 @@ def _get_median_index(x, axis=None, keepdims=False):
   sorted_arg = jnp.argsort(x, axis=axis)
   n = x.shape[axis] if axis is not None else x.size
   if n % 2 == 1:
+    index = n // 2
   else:
+    index = (n // 2) - 1
   if axis is None:
+    median_index = sorted_arg[index]
   else:
+    median_index = jnp.take(sorted_arg, index, axis=axis)
   if keepdims and axis is not None:
+    median_index = jnp.expand_dims(median_index, axis)
   return median_index
 
+
 @op(torch.ops.aten.triangular_solve)
-def _aten_triangular_solve(b,
+def _aten_triangular_solve(b,
+                           a,
+                           upper=True,
+                           transpose=False,
+                           unittriangular=False):
+  return (jax.lax.linalg.triangular_solve(
+      a,
+      b,
+      left_side=True,
+      lower=not upper,
+      transpose_a=transpose,
+      unit_diagonal=unittriangular), a)
 
 
 # func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
@@ -4895,16 +5052,16 @@ def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=Fal
 def _aten__fft_c2c(self, dim, normalization, forward):
   if forward:
     norm = [
+        'backward',
+        'ortho',
+        'forward',
     ][normalization]
     return jnp.fft.fftn(self, axes=dim, norm=norm)
   else:
     norm = [
+        'forward',
+        'ortho',
+        'backward',
     ][normalization]
     return jnp.fft.ifftn(self, axes=dim, norm=norm)
 
@@ -4912,21 +5069,22 @@ def _aten__fft_c2c(self, dim, normalization, forward):
 @op(torch.ops.aten._fft_r2c)
 def _aten__fft_r2c(self, dim, normalization, onesided):
   norm = [
+      'backward',
+      'ortho',
+      'forward',
   ][normalization]
   if onesided:
     return jnp.fft.rfftn(self, axes=dim, norm=norm)
   else:
     return jnp.fft.fftn(self, axes=dim, norm=norm)
 
+
 @op(torch.ops.aten._fft_c2r)
 def _aten__fft_c2r(self, dim, normalization, last_dim_size):
   norm = [
+      'forward',
+      'ortho',
+      'backward',
   ][normalization]
   if len(dim) == 1:
     s = [last_dim_size]
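Note, not part of the diff: the integer `normalization` argument of these FFT ops simply indexes into the string lists above. A tiny sanity check of the forward-direction mapping; `FORWARD_NORM` is a made-up name for the same list.

    import jax.numpy as jnp

    FORWARD_NORM = ['backward', 'ortho', 'forward']
    x = jnp.arange(8.0)
    assert jnp.allclose(
        jnp.fft.fftn(x, norm=FORWARD_NORM[1]),
        jnp.fft.fftn(x, norm='ortho'))
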
@@ -4936,34 +5094,49 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
 
 
 @op(torch.ops.aten._trilinear)
-def _aten_trilinear(i1,
+def _aten_trilinear(i1,
+                    i2,
+                    i3,
+                    expand1,
+                    expand2,
+                    expand3,
+                    sumdim,
+                    unroll_dim=1):
+  return _aten_sum(
+      jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) *
+      jnp.expand_dims(i3, expand3), sumdim)
 
 
 @op(torch.ops.aten.max_unpool2d)
 @op(torch.ops.aten.max_unpool3d)
 def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
   if output_size is None:
-    raise ValueError(
+    raise ValueError(
+        "output_size value is not set correctly. It cannot be None or empty.")
 
   output_size = [input.shape[0], input.shape[1]] + output_size
   output = jnp.zeros(output_size, dtype=input.dtype)
 
   for idx in np.ndindex(input.shape):
+    max_index = indices[idx]
+    spatial_dims = output_size[2:]  # (D, H, W)
+    unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
+    full_idx = idx[:2] + unpooled_spatial_idx
+    output = output.at[full_idx].set(input[idx])
 
   return output
 
-def
+
+def _aten_upsample(input,
+                   output_size,
+                   align_corners,
+                   antialias,
+                   method,
+                   scale_factors=None,
+                   scales_h=None,
+                   scales_w=None):
   # input: is of type jaxlib.xla_extension.ArrayImpl
   image = input
-  method = "bilinear"
-  antialias = True  # ignored for upsampling
 
   # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
   # Resize does not distinguish batch, channel size.
@@ -4977,12 +5150,12 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
   shape = list(image.shape)
   # overriding output_size
   if scale_factors:
-    shape[-1] = int(math.floor(shape[-1]*scale_factors[-1]))
-    shape[-2] = int(math.floor(shape[-2]*scale_factors[-2]))
+    shape[-1] = int(math.floor(shape[-1] * scale_factors[-1]))
+    shape[-2] = int(math.floor(shape[-2] * scale_factors[-2]))
   if scales_h:
-    shape[-2] = int(math.floor(shape[-2]*scales_h))
+    shape[-2] = int(math.floor(shape[-2] * scales_h))
   if scales_w:
-    shape[-1] = int(math.floor(shape[-1]*scales_w))
+    shape[-1] = int(math.floor(shape[-1] * scales_w))
   # output_size overrides scale_factors, scales_*
   if output_size:
     shape[-1] = output_size[-1]
@@ -4992,11 +5165,11 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
   if shape == list(image.shape):
     return image
 
-  spatial_dims = (2,3)
+  spatial_dims = (2, 3)
   if len(shape) == 3:
-    spatial_dims = (1,2)
+    spatial_dims = (1, 2)
 
-  scale = list([shape[i] / image.shape[i]
+  scale = list([shape[i] / image.shape[i] for i in spatial_dims])
   if scale_factors:
     scale = scale_factors
   if scales_h:
@@ -5008,7 +5181,9 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
   # align_corners is not supported in resize()
   # https://github.com/jax-ml/jax/issues/11206
   if align_corners:
-    scale = jnp.array([
+    scale = jnp.array([
+        (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims
+    ])
 
   translation = jnp.array([0 for i in spatial_dims])
 
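Note, not part of the diff: the consolidated `_aten_upsample` above computes a target spatial shape and then resizes with JAX's image utilities. A simpler, related sketch with jax.image.resize, not the exact call used in the function above; the NCHW shape is made up.

    import jax
    import jax.numpy as jnp

    image = jnp.ones((1, 3, 8, 8))                    # made-up NCHW input
    out = jax.image.resize(image, (1, 3, 16, 16),     # batch/channel dims kept as-is
                           method="bilinear", antialias=True)
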
@@ -5022,12 +5197,53 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
       antialias=antialias,
   )
 
+
+@op(torch.ops.aten._upsample_bilinear2d_aa)
+def _aten_upsample_billinear_aa(input,
+                                output_size,
+                                align_corners,
+                                scale_factors=None,
+                                scales_h=None,
+                                scales_w=None):
+  return _aten_upsample(
+      input,
+      output_size,
+      align_corners,
+      True,  # antialias
+      "bilinear",  # method
+      scale_factors,
+      scales_h,
+      scales_w)
+
+
+@op(torch.ops.aten._upsample_bicubic2d_aa)
+def _aten_upsample_bicubic2d_aa(input,
+                                output_size,
+                                align_corners,
+                                scale_factors=None,
+                                scales_h=None,
+                                scales_w=None):
+  return _aten_upsample(
+      input,
+      output_size,
+      align_corners,
+      True,  # antialias
+      "bicubic",  # method
+      scale_factors,
+      scales_h,
+      scales_w)
+
+
 @op(torch.ops.aten.polar)
 def _aten_polar(abs, angle, *, out=None):
   return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))
 
+
 @op(torch.ops.aten.cdist)
-def _aten_cdist(x1,
+def _aten_cdist(x1,
+                x2,
+                p=2.0,
+                compute_mode='use_mm_for_euclid_dist_if_necessary'):
   x1 = x1.astype(jnp.float32)
   x2 = x2.astype(jnp.float32)
 
@@ -5036,7 +5252,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
     return _hamming_distance(x1, x2).astype(jnp.float32)
   elif p == 2.0:
     # Use optimized Euclidean distance calculation
-    if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (
+    if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (
+        x1.shape[-2] > 25 or x2.shape[-2] > 25):
       return _euclidean_mm(x1, x2)
     elif compute_mode == 'use_mm_for_euclid_dist':
       return _euclidean_mm(x1, x2)
@@ -5045,7 +5262,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
   else:
     # General p-norm distance calculation
     diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3))
-    return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)
+    return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)**(1 / p)
+
 
 def _hamming_distance(x1, x2):
   """
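Note, not part of the diff: the general-p branch of cdist now takes the final 1/p root, giving a true Minkowski distance. A compact standalone version of that formula; `pnorm_cdist` is a made-up helper name.

    import jax.numpy as jnp

    def pnorm_cdist(x1, x2, p):
      # Pairwise Minkowski distance: (sum_k |x1_ik - x2_jk|^p) ** (1/p),
      # i.e. the general-p branch above including the 1/p root.
      diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3))
      return jnp.sum(diff**p, axis=-1)**(1.0 / p)
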
@@ -5064,6 +5282,7 @@ def _hamming_distance(x1, x2):
 
   return hamming_dist
 
+
 def _euclidean_mm(x1, x2):
   """
   Computes the Euclidean distance using matrix multiplication.
@@ -5075,8 +5294,8 @@ def _euclidean_mm(x1, x2):
   Returns:
     JAX array of shape (..., P, R) representing pairwise Euclidean distances.
   """
-  x1_sq = jnp.sum(x1
-  x2_sq = jnp.sum(x2
+  x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32)
+  x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32)
 
   x2_sq = jnp.swapaxes(x2_sq, -2, -1)
 
@@ -5088,6 +5307,7 @@ def _euclidean_mm(x1, x2):
 
   return dist
 
+
 def _euclidean_direct(x1, x2):
   """
   Computes the Euclidean distance directly without matrix multiplication.
@@ -5101,7 +5321,7 @@ def _euclidean_direct(x1, x2):
   """
   diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)
 
-  dist_sq = jnp.sum(diff
+  dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32)
 
   dist_sq = jnp.maximum(dist_sq, 0.0)
 
@@ -5109,13 +5329,14 @@ def _euclidean_direct(x1, x2):
 
   return dist
 
+
 @op(torch.ops.aten.lu_unpack)
 def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
   # lu_unpack doesnt exist in jax.
   # Get commonly used data shape variables
   n = LU_data.shape[-2]
   m = LU_data.shape[-1]
-  dim = min(n,m)
+  dim = min(n, m)
 
   ### Compute the Lower and Upper triangle
   if unpack_data:
@@ -5130,7 +5351,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
     limit_indices = list(LU_data.shape)
     limit_indices[-1] = dim
-    L = jax.lax.slice(L, start_indices, limit_indices)
+    L = jax.lax.slice(L, start_indices, limit_indices)
 
     # Extract upper triangle
     U = jnp.triu(LU_data)
@@ -5160,13 +5381,15 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
 
     # closure to be called for each input 2D matrix.
     def _lu_unpack_2d(p, pivot):
-      _pivot = pivot - 1
+      _pivot = pivot - 1  # pivots are offset by 1 in jax
       indices = jnp.array([*range(n)], dtype=jnp.int32)
+
       def update_indices(i, _indices):
         tmp = _indices[i]
         _indices = _indices.at[i].set(_indices[_pivot[i]])
         _indices = _indices.at[_pivot[i]].set(tmp)
         return _indices
+
       indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices)
       p = p[jnp.array(indices)]
       p = jnp.transpose(p)
@@ -5191,7 +5414,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     reshapedPivot = LU_pivots.reshape(newPivotshape)
 
     # vmap the reshaped 3d tensors
-    v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0,0))
+    v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0))
     unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot)
 
     # reshape result back to P's shape
@@ -5210,3 +5433,204 @@ def linear(input, weight, bias=None):
|
|
|
5210
5433
|
if bias is not None:
|
|
5211
5434
|
res += bias
|
|
5212
5435
|
return res
|
|
5436
|
+
|
|
5437
|
+
|
|
5438
|
+
@op(torch.ops.aten.kthvalue)
|
|
5439
|
+
def kthvalue(input, k, dim=None, keepdim=False, *, out=None):
|
|
5440
|
+
if input.ndim == 0:
|
|
5441
|
+
return input, jnp.array(0)
|
|
5442
|
+
dimension = -1
|
|
5443
|
+
if dim is not None:
|
|
5444
|
+
dimension = dim
|
|
5445
|
+
while dimension < 0:
|
|
5446
|
+
dimension = dimension + input.ndim
|
|
5447
|
+
values = jax.lax.index_in_dim(
|
|
5448
|
+
jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim)
|
|
5449
|
+
indices = jax.lax.index_in_dim(
|
|
5450
|
+
jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1,
|
|
5451
|
+
dimension, keepdim)
|
|
5452
|
+
return values, indices
+
+
+@op(torch.ops.aten.take)
+def _aten_take(self, index):
+  return self.flatten()[index]
+
+
+# func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor
+@op(torch.ops.aten.pad)
+def _aten_pad(self, pad, mode='constant', value=None):
+  if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0:
+    raise ValueError("Padding must be a sequence of even length.")
+
+  num_dims = self.ndim
+  if len(pad) > 2 * num_dims:
+    raise ValueError(
+        f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})."
+    )
+
+  # JAX's pad function expects padding for each dimension as a tuple of (low, high)
+  # We need to reverse the pad sequence and group them for JAX.
+  # pad = [p_l0, p_r0, p_l1, p_r1, ...]
+  # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0)))
+  jax_pad_width = []
+  # Iterate in reverse pairs
+  for i in range(len(pad) // 2):
+    jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)]))
+
+  # Pad any leading dimensions with (0, 0) if the pad sequence is shorter
+  # than the number of dimensions.
+  for _ in range(num_dims - len(pad) // 2):
+    jax_pad_width.append((0, 0))
+
+  # Reverse the jax_pad_width list to match the dimension order
+  jax_pad_width.reverse()
+
+  if mode == "constant":
+    if value is None:
+      value = 0.0
+    return jnp.pad(
+        self, pad_width=jax_pad_width, mode="constant", constant_values=value)
+  elif mode == "reflect":
+    return jnp.pad(self, pad_width=jax_pad_width, mode="reflect")
+  elif mode == "edge":
+    return jnp.pad(self, pad_width=jax_pad_width, mode="edge")
+  else:
+    raise ValueError(
+        f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'."
+    )
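The `_aten_pad` lowering above translates PyTorch's flat, last-dimension-first pad list into the per-dimension `(low, high)` pairs expected by `jnp.pad`. A small sketch of just that translation, with an illustrative helper name:

import jax.numpy as jnp

def torch_pad_to_jax(pad, ndim):
  # torch order: [left_lastdim, right_lastdim, left_secondlast, right_secondlast, ...]
  width = [(pad[2 * i], pad[2 * i + 1]) for i in range(len(pad) // 2)]
  width += [(0, 0)] * (ndim - len(pad) // 2)  # untouched leading dims
  width.reverse()                             # jnp.pad wants dim 0 first
  return width

x = jnp.ones((2, 3))
print(torch_pad_to_jax([1, 2], x.ndim))                     # [(0, 0), (1, 2)]
print(jnp.pad(x, torch_pad_to_jax([1, 2], x.ndim)).shape)   # (2, 6)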
+
+
+@op(torch.ops.aten.is_nonzero)
+def _aten_is_nonzero(a):
+  a = jnp.squeeze(a)
+  if a.shape == (0,):
+    raise RuntimeError('bool value of Tensor with no values is ambiguous')
+  if a.ndim != 0:
+    raise RuntimeError(
+        'bool value of Tensor with more than one value is ambiguous')
+  return a.item() != 0
+
+
+@op(torch.ops.aten.logit)
+def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
+  """
+  Computes the logit function of the input tensor.
+
+  logit(p) = log(p / (1 - p))
+
+  Args:
+    self: Input tensor.
+    eps: A small value to clip the input tensor to avoid log(0) or division by zero.
+      If None, no clipping is performed.
+
+  Returns:
+    A tensor with the logit of each element of the input.
+  """
+  if eps is not None:
+    self = jnp.clip(self, eps, 1.0 - eps)
+  res = jnp.log(self / (1.0 - self))
+  res = res.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return res
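`_aten_logit` above evaluates log(p / (1 - p)), optionally clipping the input to [eps, 1 - eps] first. A quick sanity check of the same expression against `jax.scipy.special.logit` (used here purely for comparison):

import jax.numpy as jnp
from jax.scipy.special import logit

p = jnp.array([0.1, 0.5, 0.9])
manual = jnp.log(p / (1.0 - p))
print(manual)    # approx [-2.197, 0., 2.197]
print(logit(p))  # matches the manual computation
# With eps given, the input would first be clipped: jnp.clip(p, eps, 1 - eps)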
+
+
+@op(torch.ops.aten.floor_divide)
+def _aten_floor_divide(x, y):
+  res = jnp.floor_divide(x, y)
+  return res
+
+
+@op(torch.ops.aten._assert_tensor_metadata)
+def _aten__assert_tensor_metadata(*args, **kwargs):
+  pass
+
+
+mutation_ops_to_functional = {
+    torch.ops.aten.add_:
+        op_base.InplaceOp(torch.ops.aten.add),
+    torch.ops.aten.sub_:
+        op_base.InplaceOp(torch.ops.aten.sub),
+    torch.ops.aten.mul_:
+        op_base.InplaceOp(torch.ops.aten.mul),
+    torch.ops.aten.div_:
+        op_base.InplaceOp(torch.ops.aten.div),
+    torch.ops.aten.pow_:
+        op_base.InplaceOp(torch.ops.aten.pow),
+    torch.ops.aten.lt_:
+        op_base.InplaceOp(torch.ops.aten.lt),
+    torch.ops.aten.le_:
+        op_base.InplaceOp(torch.ops.aten.le),
+    torch.ops.aten.gt_:
+        op_base.InplaceOp(torch.ops.aten.gt),
+    torch.ops.aten.ge_:
+        op_base.InplaceOp(torch.ops.aten.ge),
+    torch.ops.aten.eq_:
+        op_base.InplaceOp(torch.ops.aten.eq),
+    torch.ops.aten.ne_:
+        op_base.InplaceOp(torch.ops.aten.ne),
+    torch.ops.aten.bernoulli_:
+        op_base.InplaceOp(torch.ops.aten.bernoulli.p),
+    torch.ops.aten.bernoulli_.float:
+        op_base.InplaceOp(_aten_bernoulli, is_jax_func=True),
+    torch.ops.aten.geometric_:
+        op_base.InplaceOp(torch.ops.aten.geometric),
+    torch.ops.aten.normal_:
+        op_base.InplaceOp(torch.ops.aten.normal),
+    torch.ops.aten.random_:
+        op_base.InplaceOp(torch.ops.aten.uniform),
+    torch.ops.aten.uniform_:
+        op_base.InplaceOp(torch.ops.aten.uniform),
+    torch.ops.aten.relu_:
+        op_base.InplaceOp(torch.ops.aten.relu),
+    # squeeze_ is expected to change tensor's shape. So replace with new value
+    torch.ops.aten.squeeze_:
+        op_base.InplaceOp(torch.ops.aten.squeeze, True),
+    torch.ops.aten.sqrt_:
+        op_base.InplaceOp(torch.ops.aten.sqrt),
+    torch.ops.aten.clamp_:
+        op_base.InplaceOp(torch.ops.aten.clamp),
+    torch.ops.aten.clamp_min_:
+        op_base.InplaceOp(torch.ops.aten.clamp_min),
+    torch.ops.aten.sigmoid_:
+        op_base.InplaceOp(torch.ops.aten.sigmoid),
+    torch.ops.aten.tanh_:
+        op_base.InplaceOp(torch.ops.aten.tanh),
+    torch.ops.aten.ceil_:
+        op_base.InplaceOp(torch.ops.aten.ceil),
+    torch.ops.aten.logical_not_:
+        op_base.InplaceOp(torch.ops.aten.logical_not),
+    torch.ops.aten.unsqueeze_:
+        op_base.InplaceOp(torch.ops.aten.unsqueeze),
+    torch.ops.aten.transpose_:
+        op_base.InplaceOp(torch.ops.aten.transpose),
+    torch.ops.aten.log_normal_:
+        op_base.InplaceOp(torch.ops.aten.log_normal),
+    torch.ops.aten.scatter_add_:
+        op_base.InplaceOp(torch.ops.aten.scatter_add),
+    torch.ops.aten.scatter_reduce_.two:
+        op_base.InplaceOp(torch.ops.aten.scatter_reduce),
+    torch.ops.aten.scatter_:
+        op_base.InplaceOp(torch.ops.aten.scatter),
+    torch.ops.aten.bitwise_or_:
+        op_base.InplaceOp(torch.ops.aten.bitwise_or),
+    torch.ops.aten.floor_divide_:
+        op_base.InplaceOp(torch.ops.aten.floor_divide),
+    torch.ops.aten.remainder_:
+        op_base.InplaceOp(torch.ops.aten.remainder),
+}
+
+# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
+_jax_version = tuple(int(v) for v in jax.version._version.split("."))
+
+mutation_needs_env = {
+    torch.ops.aten.bernoulli_,
+    torch.ops.aten.bernoulli_.float,
+}
+
+for operator, mutation in mutation_ops_to_functional.items():
+  ops_registry.register_torch_dispatch_op(
+      operator,
+      mutation,
+      is_jax_function=False,
+      is_view_op=True,
+      needs_env=(operator in mutation_needs_env))