torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. Click here for more details.

torchax/ops/jaten.py CHANGED
@@ -15,75 +15,22 @@ from torchax.ops import ops_registry
15
15
  from torchax.ops import op_base, mappings
16
16
  from torchax import interop
17
17
  from torchax.ops import jax_reimplement
18
-
18
+ from torchax.view import View
19
19
  # Keys are OpOverload, value is a callable that takes
20
20
  # Tensor
21
21
  all_ops = {}
22
22
 
23
- # list all Aten ops from pytorch that does mutation
24
- # and need to be implemented in jax
25
-
26
- mutation_ops_to_functional = {
27
- torch.ops.aten.add_: torch.ops.aten.add,
28
- torch.ops.aten.sub_: torch.ops.aten.sub,
29
- torch.ops.aten.mul_: torch.ops.aten.mul,
30
- torch.ops.aten.div_: torch.ops.aten.div,
31
- torch.ops.aten.pow_: torch.ops.aten.pow,
32
- torch.ops.aten.lt_: torch.ops.aten.lt,
33
- torch.ops.aten.le_: torch.ops.aten.le,
34
- torch.ops.aten.gt_: torch.ops.aten.gt,
35
- torch.ops.aten.ge_: torch.ops.aten.ge,
36
- torch.ops.aten.eq_: torch.ops.aten.eq,
37
- torch.ops.aten.ne_: torch.ops.aten.ne,
38
- torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
39
- torch.ops.aten.geometric_: torch.ops.aten.geometric,
40
- torch.ops.aten.normal_: torch.ops.aten.normal,
41
- torch.ops.aten.random_: torch.ops.aten.uniform,
42
- torch.ops.aten.uniform_: torch.ops.aten.uniform,
43
- torch.ops.aten.relu_: torch.ops.aten.relu,
44
- # squeeze_ is expected to change tensor's shape. So replace with new value
45
- torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
46
- torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
47
- torch.ops.aten.clamp_: torch.ops.aten.clamp,
48
- torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
49
- torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
50
- torch.ops.aten.tanh_: torch.ops.aten.tanh,
51
- torch.ops.aten.ceil_: torch.ops.aten.ceil,
52
- torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
53
- torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
54
- torch.ops.aten.transpose_: torch.ops.aten.transpose,
55
- torch.ops.aten.log_normal_: torch.ops.aten.log_normal,
56
- torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
57
- torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
58
- torch.ops.aten.scatter_: torch.ops.aten.scatter,
59
- }
60
-
61
- # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
62
- _jax_version = tuple(int(v) for v in jax.version._version.split("."))
63
-
64
-
65
- def make_mutation(op):
66
- if type(mutation_ops_to_functional[op]) is tuple:
67
- return op_base.InplaceOp(mutation_ops_to_functional[op][0],
68
- replace=mutation_ops_to_functional[op][1],
69
- position_to_mutate=0)
70
- return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0)
71
-
72
-
73
- for op in mutation_ops_to_functional.keys():
74
- ops_registry.register_torch_dispatch_op(
75
- op, make_mutation(op), is_jax_function=False
76
- )
77
-
78
23
 
79
24
  def op(*aten, **kwargs):
25
+
80
26
  def inner(func):
81
27
  for a in aten:
82
28
  ops_registry.register_torch_dispatch_op(a, func, **kwargs)
83
29
  continue
84
30
 
85
31
  if isinstance(a, torch._ops.OpOverloadPacket):
86
- opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name
32
+ opname = a.default.name() if 'default' in a.overloads(
33
+ ) else a._qualified_op_name
87
34
  elif isinstance(a, torch._ops.OpOverload):
88
35
  opname = a.name()
89
36
  else:
@@ -91,17 +38,18 @@ def op(*aten, **kwargs):
91
38
 
92
39
  torchfunc = functools.partial(interop.call_jax, func)
93
40
  # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
94
- torch.library.impl(opname, 'privateuseone')(torchfunc if a != torch.ops.aten._to_copy else func)
41
+ torch.library.impl(opname, 'privateuseone')(
42
+ torchfunc if a != torch.ops.aten._to_copy else func)
95
43
  return func
96
44
 
97
45
  return inner
98
46
 
99
47
 
100
48
  @op(
101
- torch.ops.aten.view_copy,
102
- torch.ops.aten.view,
103
- torch.ops.aten._unsafe_view,
104
- torch.ops.aten.reshape,
49
+ torch.ops.aten.view_copy,
50
+ torch.ops.aten.view,
51
+ torch.ops.aten._unsafe_view,
52
+ torch.ops.aten.reshape,
105
53
  )
106
54
  def _aten_unsafe_view(x, shape):
107
55
  return jnp.reshape(x, shape)
@@ -121,8 +69,19 @@ def _aten_add(x, y, *, alpha=1):
121
69
  return res
122
70
 
123
71
 
124
- @op(torch.ops.aten.copy_, is_jax_function=False)
125
- def _aten_copy(x, y, memory_format=None):
72
+ @op(torch.ops.aten.copy_,
73
+ is_jax_function=False,
74
+ is_view_op=True,
75
+ needs_env=True)
76
+ def _aten_copy(x, y, memory_format=None, env=None):
77
+
78
+ if y.device.type == 'cpu':
79
+ y = env.to_xla(y)
80
+
81
+ if isinstance(x, View):
82
+ x.update(y)
83
+ return x
84
+
126
85
  if x.ndim == 1 and y.ndim == 0:
127
86
  # case of torch.empty((1,)).copy_(tensor(N))
128
87
  # we need to return 0D tensor([N]) and not scalar tensor(N)
@@ -147,14 +106,20 @@ def _aten_trunc(x):
147
106
 
148
107
  @op(torch.ops.aten.index_copy)
149
108
  def _aten_index_copy(x, dim, indexes, source):
109
+ if x.ndim == 0:
110
+ return source
111
+ if x.ndim == 1:
112
+ source = jnp.squeeze(source)
150
113
  # return jax.lax.scatter(x, index, dim)
114
+ if dim < 0:
115
+ dim = dim + x.ndim
151
116
  dims = []
152
117
  for i in range(len(x.shape)):
153
118
  if i == dim:
154
119
  dims.append(indexes)
155
120
  else:
156
121
  dims.append(slice(None, None, None))
157
- return x.at[dim].set(source)
122
+ return x.at[tuple(dims)].set(source)
158
123
 
159
124
 
160
125
  # aten.cauchy_
@@ -199,7 +164,9 @@ def _aten_complex(real, imag):
199
164
  Returns:
200
165
  A complex array with the specified real and imaginary parts.
201
166
  """
202
- return jnp.array(real, dtype=jnp.float32) + 1j * jnp.array(imag, dtype=jnp.float32)
167
+ return jnp.array(
168
+ real, dtype=jnp.float32) + 1j * jnp.array(
169
+ imag, dtype=jnp.float32)
203
170
 
204
171
 
205
172
  # aten.exponential_
@@ -223,13 +190,14 @@ def _aten_exponential_(x, lambd=1.0):
223
190
  # aten.linalg_householder_product
224
191
  @op(torch.ops.aten.linalg_householder_product)
225
192
  def _aten_linalg_householder_product(input, tau):
226
- return jax.lax.linalg.householder_product(a = input, taus = tau)
193
+ return jax.lax.linalg.householder_product(a=input, taus=tau)
227
194
 
228
195
 
229
196
  @op(torch.ops.aten.select)
230
197
  def _aten_select(x, dim, indexes):
231
198
  return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False)
232
199
 
200
+
233
201
  @op(torch.ops.aten.index_select)
234
202
  @op(torch.ops.aten.select_copy)
235
203
  def _aten_index_select(x, dim, index):
@@ -249,11 +217,10 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
249
217
  raise NotImplementedError(
250
218
  "check_errors=True is not supported in this JAX implementation. "
251
219
  "Check for positive definiteness using jnp.linalg.eigvalsh before "
252
- "calling this function."
253
- )
220
+ "calling this function.")
254
221
 
255
222
  L = jax.scipy.linalg.cholesky(input, lower=not upper)
256
- if len(L.shape) >2:
223
+ if len(L.shape) > 2:
257
224
  info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32)
258
225
  else:
259
226
  info = jnp.array(0, dtype=jnp.int32)
@@ -263,7 +230,7 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
263
230
  @op(torch.ops.aten.cholesky_solve)
264
231
  def _aten_cholesky_solve(input, input2, upper=False):
265
232
  # Ensure input2 is lower triangular for cho_solve
266
- L = input2 if not upper else input2.T
233
+ L = input2 if not upper else input2.T
267
234
  # Use cho_solve to solve the linear system
268
235
  solution = jax.scipy.linalg.cho_solve((L, True), input)
269
236
  return solution
@@ -275,7 +242,7 @@ def _aten_special_zeta(x, q):
275
242
  res = jax.scipy.special.zeta(x, q)
276
243
  if isinstance(x, int) or isinstance(q, int):
277
244
  res = res.astype(new_dtype)
278
- return res # jax.scipy.special.zeta(x, q)
245
+ return res # jax.scipy.special.zeta(x, q)
279
246
 
280
247
 
281
248
  # aten.igammac
@@ -286,7 +253,7 @@ def _aten_igammac(input, other):
286
253
  if isinstance(other, jnp.ndarray):
287
254
  other = jnp.where(other < 0, jnp.nan, other)
288
255
  else:
289
- if (input==0 and other==0) or (input < 0) or (other < 0):
256
+ if (input == 0 and other == 0) or (input < 0) or (other < 0):
290
257
  other = jnp.nan
291
258
  return jnp.array(jax.scipy.special.gammaincc(input, other))
292
259
 
@@ -294,7 +261,7 @@ def _aten_igammac(input, other):
294
261
  @op(torch.ops.aten.mean)
295
262
  def _aten_mean(x, dim=None, keepdim=False):
296
263
  if x.shape == () and dim is not None:
297
- dim = None # disable dim for jax array without dim
264
+ dim = None # disable dim for jax array without dim
298
265
  return jnp.mean(x, dim, keepdims=keepdim)
299
266
 
300
267
 
@@ -310,13 +277,14 @@ def _torch_binary_scalar_type(scalar, tensor):
310
277
 
311
278
 
312
279
  @op(torch.ops.aten.searchsorted.Tensor)
313
- def _aten_searchsorted(sorted_sequence, values):
280
+ def _aten_searchsorted(sorted_sequence, values):
314
281
  new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
315
282
  res = jnp.searchsorted(sorted_sequence, values)
316
- if sorted_sequence.dtype == np.dtype(np.int32) or sorted_sequence.dtype == np.dtype(np.int32):
283
+ if sorted_sequence.dtype == np.dtype(
284
+ np.int32) or sorted_sequence.dtype == np.dtype(np.int32):
317
285
  # res = res.astype(new_dtype)
318
286
  res = res.astype(np.dtype(np.int64))
319
- return res # jnp.searchsorted(sorted_sequence, values)
287
+ return res # jnp.searchsorted(sorted_sequence, values)
320
288
 
321
289
 
322
290
  @op(torch.ops.aten.sub.Tensor)
@@ -328,7 +296,7 @@ def _aten_sub(x, y, alpha=1):
328
296
  if isinstance(y, float):
329
297
  dtype = _torch_binary_scalar_type(y, x)
330
298
  y = jnp.array(y, dtype=dtype)
331
- return x - y*alpha
299
+ return x - y * alpha
332
300
 
333
301
 
334
302
  @op(torch.ops.aten.numpy_T)
@@ -345,7 +313,6 @@ def _aten_numpy_T(input):
345
313
  return jnp.transpose(input)
346
314
 
347
315
 
348
-
349
316
  @op(torch.ops.aten.mm)
350
317
  def _aten_mm(x, y):
351
318
  res = x @ y
@@ -379,13 +346,15 @@ def _aten_t(x):
379
346
  @op(torch.ops.aten.transpose)
380
347
  @op(torch.ops.aten.transpose_copy)
381
348
  def _aten_transpose(x, dim0, dim1):
382
- shape = list(range(len(x.shape)))
383
- shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
384
- return jnp.transpose(x, shape)
349
+ if x.ndim == 0:
350
+ return x
351
+ dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim
352
+ dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim
353
+ return jnp.swapaxes(x, dim0, dim1)
385
354
 
386
355
 
387
356
  @op(torch.ops.aten.triu)
388
- def _aten_triu(m, k):
357
+ def _aten_triu(m, k=0):
389
358
  return jnp.triu(m, k)
390
359
 
391
360
 
@@ -406,6 +375,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
406
375
  return self[tuple(dims)]
407
376
 
408
377
 
378
+ @op(torch.ops.aten.positive)
409
379
  @op(torch.ops.aten.detach)
410
380
  def _aten_detach(self):
411
381
  return self
@@ -439,7 +409,8 @@ def _aten_resize_as_(x, y):
439
409
 
440
410
  @op(torch.ops.aten.repeat_interleave.Tensor)
441
411
  def repeat_interleave(repeats, dim=0):
442
- return jnp.repeat(jnp.arange(repeats.shape[dim]), repeats)
412
+ return jnp.repeat(np.arange(repeats.shape[dim]), repeats)
413
+
443
414
 
444
415
  @op(torch.ops.aten.repeat_interleave.self_int)
445
416
  @op(torch.ops.aten.repeat_interleave.self_Tensor)
@@ -451,12 +422,6 @@ def repeat_interleave(self, repeats, dim=0):
451
422
  return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length)
452
423
 
453
424
 
454
- # aten.upsample_bilinear2d
455
- @op(torch.ops.aten.upsample_bilinear2d)
456
- def _aten_upsample_bilinear2d(x, output_size, align_corners=False, scale_h=None, scale_w=None):
457
- return _aten_upsample_bilinear2d_aa(x, output_size=output_size, align_corners=align_corners, scale_factors=None, scales_h=scale_h, scales_w=scale_w)
458
-
459
-
460
425
  @op(torch.ops.aten.view_as_real)
461
426
  def _aten_view_as_real(x):
462
427
  real = jnp.real(x)
@@ -473,7 +438,7 @@ def _aten_stack(tensors, dim=0):
473
438
  @op(torch.ops.aten._softmax)
474
439
  @op(torch.ops.aten.softmax)
475
440
  @op(torch.ops.aten.softmax.int)
476
- def _aten_softmax(x, dim, halftofloat = False):
441
+ def _aten_softmax(x, dim, halftofloat=False):
477
442
  if x.shape == ():
478
443
  return jax.nn.softmax(x.reshape([1]), axis=0).reshape([])
479
444
  return jax.nn.softmax(x, dim)
@@ -482,10 +447,12 @@ def _aten_softmax(x, dim, halftofloat = False):
482
447
  def _is_int(x):
483
448
  if isinstance(x, int):
484
449
  return True
485
- if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or x.dtype.name.startswith('uint')):
450
+ if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or
451
+ x.dtype.name.startswith('uint')):
486
452
  return True
487
453
  return False
488
454
 
455
+
489
456
  def highest_precision_int_dtype(tensor1, tensor2):
490
457
  if isinstance(tensor1, int):
491
458
  return tensor2.dtype
@@ -493,12 +460,20 @@ def highest_precision_int_dtype(tensor1, tensor2):
493
460
  return tensor1.dtype
494
461
 
495
462
  dtype_hierarchy = {
496
- 'uint8': 8, 'int8': 8,
497
- 'uint16': 16, 'int16': 16,
498
- 'uint32': 32, 'int32': 32,
499
- 'uint64': 64, 'int64': 64,
463
+ 'uint8': 8,
464
+ 'int8': 8,
465
+ 'uint16': 16,
466
+ 'int16': 16,
467
+ 'uint32': 32,
468
+ 'int32': 32,
469
+ 'uint64': 64,
470
+ 'int64': 64,
500
471
  }
501
- return max(tensor1.dtype, tensor2.dtype, key=lambda dtype: dtype_hierarchy[str(dtype)])
472
+ return max(
473
+ tensor1.dtype,
474
+ tensor2.dtype,
475
+ key=lambda dtype: dtype_hierarchy[str(dtype)])
476
+
502
477
 
503
478
  @op(torch.ops.aten.pow)
504
479
  def _aten_pow(x, y):
@@ -553,11 +528,13 @@ def _aten_div(x, y, rounding_mode=""):
553
528
  def _aten_true_divide(x, y):
554
529
  return x / y
555
530
 
531
+
556
532
  @op(torch.ops.aten.dist)
557
533
  def _aten_dist(input, other, p=2):
558
534
  diff = jnp.abs(jnp.subtract(input, other))
559
535
  return _aten_linalg_vector_norm(diff, ord=p)
560
536
 
537
+
561
538
  @op(torch.ops.aten.bmm)
562
539
  def _aten_bmm(x, y):
563
540
  res = x @ y
@@ -567,9 +544,14 @@ def _aten_bmm(x, y):
567
544
 
568
545
  @op(torch.ops.aten.embedding)
569
546
  # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
570
- def _aten_embedding(a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
547
+ def _aten_embedding(a,
548
+ w,
549
+ padding_idx=-1,
550
+ scale_grad_by_freq=False,
551
+ sparse=False):
571
552
  return jnp.take(a, w, axis=0)
572
553
 
554
+
573
555
  @op(torch.ops.aten.embedding_renorm_)
574
556
  def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
575
557
  # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
@@ -587,27 +569,26 @@ def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
587
569
 
588
570
  indices_to_update = unique_indices[indice_idx]
589
571
 
590
- weight = weight.at[indices_to_update].set(
591
- weight[indices_to_update] * scale[:, None]
592
- )
572
+ weight = weight.at[indices_to_update].set(weight[indices_to_update] *
573
+ scale[:, None])
593
574
  return weight
594
575
 
576
+
595
577
  #- func: _embedding_bag_forward_only(
596
578
  # Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
597
579
  # int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
598
580
  @op(torch.ops.aten._embedding_bag)
599
581
  @op(torch.ops.aten._embedding_bag_forward_only)
600
- def _aten__embedding_bag(
601
- weight,
602
- indices,
603
- offsets=None,
604
- scale_grad_by_freq=False,
605
- mode=0,
606
- sparse=False,
607
- per_sample_weights=None,
608
- include_last_offset=False,
609
- padding_idx=-1):
610
- """Jax implementation of the PyTorch _embedding_bag function.
582
+ def _aten__embedding_bag(weight,
583
+ indices,
584
+ offsets=None,
585
+ scale_grad_by_freq=False,
586
+ mode=0,
587
+ sparse=False,
588
+ per_sample_weights=None,
589
+ include_last_offset=False,
590
+ padding_idx=-1):
591
+ """Jax implementation of the PyTorch _embedding_bag function.
611
592
 
612
593
  Args:
613
594
  weight: The learnable weights of the module of shape (num_embeddings, embedding_dim).
@@ -623,48 +604,50 @@ def _aten__embedding_bag(
623
604
  Returns:
624
605
  A tuple of (output, offset2bag, bag_size, max_indices).
625
606
  """
626
- embedded = _aten_embedding(weight, indices, padding_idx)
627
-
628
- if offsets is None:
629
- # offsets is None only when indices.ndim > 1
630
- if mode == 0: # sum
631
- output = jnp.sum(embedded, axis=1)
632
- elif mode == 1: # mean
633
- output = jnp.mean(embedded, axis=1)
634
- elif mode == 2: # max
635
- output = jnp.max(embedded, axis=1)
636
- return output, None, None, None
637
-
638
- if isinstance(offsets, jax.Array):
639
- offsets_np = np.array(offsets)
640
- else:
641
- offsets_np = offsets
642
- offset2bag = np.zeros(indices.shape[0], dtype=np.int64)
643
- bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64)
644
- max_indices = jnp.full_like(indices, -1)
607
+ embedded = _aten_embedding(weight, indices, padding_idx)
608
+
609
+ if offsets is None:
610
+ # offsets is None only when indices.ndim > 1
611
+ if mode == 0: # sum
612
+ output = jnp.sum(embedded, axis=1)
613
+ elif mode == 1: # mean
614
+ output = jnp.mean(embedded, axis=1)
615
+ elif mode == 2: # max
616
+ output = jnp.max(embedded, axis=1)
617
+ return output, None, None, None
618
+
619
+ if isinstance(offsets, jax.Array):
620
+ offsets_np = np.array(offsets)
621
+ else:
622
+ offsets_np = offsets
623
+ offset2bag = np.zeros(indices.shape[0], dtype=np.int64)
624
+ bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64)
625
+ max_indices = jnp.full_like(indices, -1)
645
626
 
646
- for bag in range(offsets_np.shape[0]):
647
- start = int(offsets_np[bag])
627
+ for bag in range(offsets_np.shape[0]):
628
+ start = int(offsets_np[bag])
648
629
 
649
- end = int(indices.shape[0] if bag + 1 == offsets_np.shape[0] else offsets_np[bag + 1])
650
- bag_size[bag] = end - start
651
- offset2bag = offset2bag.at[start:end].set(bag)
630
+ end = int(indices.shape[0] if bag +
631
+ 1 == offsets_np.shape[0] else offsets_np[bag + 1])
632
+ bag_size[bag] = end - start
633
+ offset2bag = offset2bag.at[start:end].set(bag)
652
634
 
653
- if end - start > 0:
654
- if mode == 0:
655
- output_bag = jnp.sum(embedded[start:end], axis=0)
656
- elif mode == 1:
657
- output_bag = jnp.mean(embedded[start:end], axis=0)
658
- elif mode == 2:
659
- output_bag = jnp.max(embedded[start:end], axis=0)
660
- max_indices = max_indices.at[start:end].set(jnp.argmax(embedded[start:end], axis=0))
635
+ if end - start > 0:
636
+ if mode == 0:
637
+ output_bag = jnp.sum(embedded[start:end], axis=0)
638
+ elif mode == 1:
639
+ output_bag = jnp.mean(embedded[start:end], axis=0)
640
+ elif mode == 2:
641
+ output_bag = jnp.max(embedded[start:end], axis=0)
642
+ max_indices = max_indices.at[start:end].set(
643
+ jnp.argmax(embedded[start:end], axis=0))
661
644
 
662
- # The original code returned offset2bag, bag_size, and max_indices as numpy arrays.
663
- # Converting them to JAX arrays for consistency.
664
- offset2bag = jnp.array(offset2bag)
665
- bag_size = jnp.array(bag_size)
645
+ # The original code returned offset2bag, bag_size, and max_indices as numpy arrays.
646
+ # Converting them to JAX arrays for consistency.
647
+ offset2bag = jnp.array(offset2bag)
648
+ bag_size = jnp.array(bag_size)
666
649
 
667
- return output_bag, offset2bag, bag_size, max_indices
650
+ return output_bag, offset2bag, bag_size, max_indices
668
651
 
669
652
 
670
653
  @op(torch.ops.aten.rsqrt)
@@ -676,6 +659,7 @@ def _aten_rsqrt(x):
676
659
  @op(torch.ops.aten.expand)
677
660
  @op(torch.ops.aten.expand_copy)
678
661
  def _aten_expand(x, dims):
662
+
679
663
  def fix_dims(d, xs):
680
664
  if d == -1:
681
665
  return xs
@@ -683,7 +667,9 @@ def _aten_expand(x, dims):
683
667
 
684
668
  shape = list(x.shape)
685
669
  if len(shape) < len(dims):
686
- shape = [1, ] * (len(dims) - len(shape)) + shape
670
+ shape = [
671
+ 1,
672
+ ] * (len(dims) - len(shape)) + shape
687
673
  # make sure that dims and shape is the same by
688
674
  # left pad with 1s. Otherwise the zip below will
689
675
  # truncate
@@ -705,15 +691,15 @@ def _aten__to_copy(self, **kwargs):
705
691
 
706
692
 
707
693
  @op(torch.ops.aten.empty)
708
- @op_base.convert_dtype()
694
+ @op_base.convert_dtype(use_default_dtype=False)
709
695
  def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs):
710
696
  return jnp.empty(size, dtype=dtype)
711
697
 
712
698
 
713
699
  @op(torch.ops.aten.empty_like)
714
- @op_base.convert_dtype()
700
+ @op_base.convert_dtype(use_default_dtype=False)
715
701
  def _aten_empty_like(input, *, dtype=None, **kwargs):
716
- return jnp.empty_like(input, dtype=dtype)
702
+ return jnp.empty_like(input, dtype)
717
703
 
718
704
 
719
705
  @op(torch.ops.aten.ones)
@@ -750,7 +736,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
750
736
  return jnp.empty(sizes, dtype=dtype)
751
737
 
752
738
 
753
- @op(torch.ops.aten.index_put_)
754
739
  @op(torch.ops.aten.index_put)
755
740
  def _aten_index_put(self, indexes, values, accumulate=False):
756
741
  indexes = [slice(None, None, None) if i is None else i for i in indexes]
@@ -784,8 +769,8 @@ def split_with_sizes(x, sizes, dim=0):
784
769
  A list of sub-arrays.
785
770
  """
786
771
  if isinstance(sizes, int):
787
- # split equal size
788
- new_sizes = [sizes] * (x.shape[dim] // sizes)
772
+ # split equal size, round up
773
+ new_sizes = [sizes] * (-(-x.shape[dim] // sizes))
789
774
  sizes = new_sizes
790
775
  rank = x.ndim
791
776
  splits = np.cumsum(sizes) # Cumulative sum for split points
@@ -796,14 +781,15 @@ def split_with_sizes(x, sizes, dim=0):
796
781
  return tuple(res)
797
782
 
798
783
  return [
799
- x[make_range(rank, dim, start, end)]
800
- for start, end in zip([0] + list(splits[:-1]), splits)
784
+ x[make_range(rank, dim, start, end)]
785
+ for start, end in zip([0] + list(splits[:-1]), splits)
801
786
  ]
802
787
 
803
788
 
804
789
  @op(torch.ops.aten.permute)
805
790
  @op(torch.ops.aten.permute_copy)
806
791
  def permute(t, dims):
792
+ # TODO: return a View instead
807
793
  return jnp.transpose(t, dims)
808
794
 
809
795
 
@@ -819,6 +805,7 @@ def _aten_unsqueeze(self, dim):
819
805
  def _aten_ne(x, y):
820
806
  return jnp.not_equal(x, y)
821
807
 
808
+
822
809
  # Create indices along a specific axis
823
810
  #
824
811
  # For example
@@ -832,14 +819,12 @@ def _aten_ne(x, y):
832
819
  def _indices_along_axis(x, axis):
833
820
  return jnp.expand_dims(
834
821
  jnp.arange(x.shape[axis]),
835
- axis = [d for d in range(len(x.shape)) if d != axis]
836
- )
822
+ axis=[d for d in range(len(x.shape)) if d != axis])
823
+
837
824
 
838
825
  def _broadcast_indices(indices, shape):
839
- return jnp.broadcast_to(
840
- indices,
841
- shape
842
- )
826
+ return jnp.broadcast_to(indices, shape)
827
+
843
828
 
844
829
  @op(torch.ops.aten.cummax)
845
830
  def _aten_cummax(x, dim):
@@ -851,36 +836,45 @@ def _aten_cummax(x, dim):
851
836
  indice_along_axis = _indices_along_axis(x, axis)
852
837
  indices = _broadcast_indices(indice_along_axis, x.shape)
853
838
 
854
-
855
839
  def cummax_reduce_func(carry, elem):
856
- v1, v2 = carry['val'], elem['val']
840
+ v1, v2 = carry['val'], elem['val']
857
841
  i1, i2 = carry['idx'], elem['idx']
858
842
 
859
843
  v = jnp.maximum(v1, v2)
860
844
  i = jnp.where(v1 > v2, i1, i2)
861
845
  return {'val': v, 'idx': i}
862
- res = jax.lax.associative_scan(cummax_reduce_func, {'val': x, 'idx': indices}, axis=axis)
846
+
847
+ res = jax.lax.associative_scan(
848
+ cummax_reduce_func, {
849
+ 'val': x,
850
+ 'idx': indices
851
+ }, axis=axis)
863
852
  return res['val'], res['idx']
864
853
 
854
+
865
855
  @op(torch.ops.aten.cummin)
866
856
  def _aten_cummin(x, dim):
867
857
  if not x.shape:
868
858
  return x, jnp.zeros_like(x, dtype=jnp.int64)
869
-
859
+
870
860
  axis = dim
871
861
 
872
862
  indice_along_axis = _indices_along_axis(x, axis)
873
863
  indices = _broadcast_indices(indice_along_axis, x.shape)
874
864
 
875
865
  def cummin_reduce_func(carry, elem):
876
- v1, v2 = carry['val'], elem['val']
866
+ v1, v2 = carry['val'], elem['val']
877
867
  i1, i2 = carry['idx'], elem['idx']
878
868
 
879
869
  v = jnp.minimum(v1, v2)
880
870
  i = jnp.where(v1 < v2, i1, i2)
881
871
  return {'val': v, 'idx': i}
882
872
 
883
- res = jax.lax.associative_scan(cummin_reduce_func, {'val': x, 'idx': indices}, axis=axis)
873
+ res = jax.lax.associative_scan(
874
+ cummin_reduce_func, {
875
+ 'val': x,
876
+ 'idx': indices
877
+ }, axis=axis)
884
878
  return res['val'], res['idx']
885
879
 
886
880
 
@@ -908,9 +902,11 @@ def _aten_cumprod(input, dim, dtype=None, out=None):
908
902
 
909
903
 
910
904
  @op(torch.ops.aten.native_layer_norm)
911
- def _aten_native_layer_norm(
912
- input, normalized_shape, weight=None, bias=None, eps=1e-5
913
- ):
905
+ def _aten_native_layer_norm(input,
906
+ normalized_shape,
907
+ weight=None,
908
+ bias=None,
909
+ eps=1e-5):
914
910
  """Implements layer normalization in Jax as defined by `aten::native_layer_norm`.
915
911
 
916
912
  Args:
@@ -944,7 +940,7 @@ def _aten_native_layer_norm(
944
940
  norm_x += bias
945
941
  return norm_x, mean, rstd
946
942
 
947
-
943
+
948
944
  @op(torch.ops.aten.matmul)
949
945
  def _aten_matmul(x, y):
950
946
  return x @ y
@@ -960,6 +956,7 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
960
956
  self += alpha * jnp.matmul(mat1, mat2)
961
957
  return self
962
958
 
959
+
963
960
  @op(torch.ops.aten.sparse_sampled_addmm)
964
961
  def _aten_sparse_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
965
962
  alpha = jnp.array(alpha).astype(mat1.dtype)
@@ -974,9 +971,8 @@ def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
974
971
  alpha = jnp.array(alpha).astype(batch1.dtype)
975
972
  beta = jnp.array(beta).astype(batch1.dtype)
976
973
  mm = jnp.einsum("bxy, byz -> xz", batch1, batch2)
977
- return jax.lax.cond(
978
- beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm
979
- )
974
+ return jax.lax.cond(beta == 0, lambda: alpha * mm,
975
+ lambda: beta * input + alpha * mm)
980
976
 
981
977
 
982
978
  @op(torch.ops.aten.gelu)
@@ -987,73 +983,69 @@ def _aten_gelu(self, *, approximate="none"):
987
983
 
988
984
  @op(torch.ops.aten.squeeze)
989
985
  @op(torch.ops.aten.squeeze_copy)
990
- def _aten_squeeze_dim(self, dim):
991
- """Squeezes a Jax tensor by removing a single dimension of size 1.
992
-
993
- Args:
994
- self: The input tensor.
995
- dim: The dimension to squeeze.
996
-
997
- Returns:
998
- The squeezed tensor with the specified dimension removed if it is 1,
999
- otherwise the original tensor is returned.
1000
- """
1001
-
1002
- # Validate input arguments
1003
- if not isinstance(self, jnp.ndarray):
1004
- raise TypeError(f"Expected a Jax tensor, got {type(self)}.")
1005
- if isinstance(dim, int):
1006
- dim = [dim]
1007
-
1008
- # Check if the specified dimension has size 1
1009
- if (len(self.shape) == 0) or all([self.shape[d] != 1 for d in dim]):
986
+ def _aten_squeeze_dim(self, dim=None):
987
+ if self.ndim == 0:
1010
988
  return self
989
+ if dim is not None:
990
+ if isinstance(dim, int):
991
+ if self.shape[dim] != 1:
992
+ return self
993
+ if dim < 0:
994
+ dim += self.ndim
995
+ else:
996
+ # NOTE: torch leaves the dims that is not 1 unchanged,
997
+ # but jax raises error.
998
+ dim = [
999
+ i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1
1000
+ ]
1011
1001
 
1012
- # Use slicing to remove the dimension if it is 1
1013
- new_shape = list(self.shape)
1014
-
1015
- def fix_dim(p):
1016
- if p < 0:
1017
- return p + len(self.shape)
1018
- return p
1002
+ return jnp.squeeze(self, dim)
1019
1003
 
1020
- dim = [fix_dim(d) for d in dim]
1021
- new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
1022
- return self.reshape(new_shape)
1023
1004
 
1024
1005
  @op(torch.ops.aten.bucketize)
1025
- def _aten_bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
1026
- assert boundaries[0] < boundaries[-1], "boundaries must contain a strictly increasing sequence"
1006
+ def _aten_bucketize(input,
1007
+ boundaries,
1008
+ *,
1009
+ out_int32=False,
1010
+ right=False,
1011
+ out=None):
1027
1012
  return_type = jnp.int32 if out_int32 else jnp.int64
1028
1013
  return jnp.digitize(input, boundaries, right=not right).astype(return_type)
1029
1014
 
1030
1015
 
1031
1016
  @op(torch.ops.aten.conv2d)
1032
1017
  def _aten_conv2d(
1033
- input,
1034
- weight,
1035
- bias,
1036
- stride,
1037
- padding,
1038
- dilation,
1039
- groups,
1018
+ input,
1019
+ weight,
1020
+ bias,
1021
+ stride,
1022
+ padding,
1023
+ dilation,
1024
+ groups,
1040
1025
  ):
1041
1026
  return _aten_convolution(
1042
- input, weight, bias, stride, padding,
1043
- dilation, transposed=False,
1044
- output_padding=1, groups=groups)
1027
+ input,
1028
+ weight,
1029
+ bias,
1030
+ stride,
1031
+ padding,
1032
+ dilation,
1033
+ transposed=False,
1034
+ output_padding=1,
1035
+ groups=groups)
1036
+
1045
1037
 
1046
1038
  @op(torch.ops.aten.convolution)
1047
1039
  def _aten_convolution(
1048
- input,
1049
- weight,
1050
- bias,
1051
- stride,
1052
- padding,
1053
- dilation,
1054
- transposed,
1055
- output_padding,
1056
- groups,
1040
+ input,
1041
+ weight,
1042
+ bias,
1043
+ stride,
1044
+ padding,
1045
+ dilation,
1046
+ transposed,
1047
+ output_padding,
1048
+ groups,
1057
1049
  ):
1058
1050
  num_shape_dim = weight.ndim - 1
1059
1051
  batch_dims = input.shape[:-num_shape_dim]
@@ -1068,7 +1060,7 @@ def _aten_convolution(
1068
1060
  # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
1069
1061
  pad_out = []
1070
1062
  for i in range(num_spatial_dims):
1071
- front = dilation[i] * (weight.shape[i+2] - 1) - padding[i]
1063
+ front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i]
1072
1064
  back = front + output_padding[i]
1073
1065
  pad_out.append((front, back))
1074
1066
  return pad_out
@@ -1089,39 +1081,38 @@ def _aten_convolution(
1089
1081
  rhs_spec.append(i + 2)
1090
1082
  out_spec.append(i + 2)
1091
1083
  return jax.lax.ConvDimensionNumbers(
1092
- *map(tuple, (lhs_spec, rhs_spec, out_spec))
1093
- )
1084
+ *map(tuple, (lhs_spec, rhs_spec, out_spec)))
1094
1085
 
1095
1086
  if transposed:
1096
- rhs = jnp.flip(weight, range(2, 1+num_shape_dim))
1087
+ rhs = jnp.flip(weight, range(2, 1 + num_shape_dim))
1097
1088
  if groups != 1:
1098
1089
  # reshape filters for tranposed depthwise convolution
1099
1090
  assert rhs.shape[0] % groups == 0
1100
- rhs_shape = [rhs.shape[0]//groups, rhs.shape[1]*groups]
1091
+ rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups]
1101
1092
  rhs_shape.extend(rhs.shape[2:])
1102
1093
  rhs = jnp.reshape(rhs, rhs_shape)
1103
1094
  res = jax.lax.conv_general_dilated(
1104
- input,
1105
- rhs,
1106
- (1,) * len(stride),
1107
- make_padding(padding, len(stride)),
1108
- lhs_dilation=stride,
1109
- rhs_dilation=dilation,
1110
- dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1111
- feature_group_count=groups,
1112
- batch_group_count=1,
1095
+ input,
1096
+ rhs,
1097
+ (1,) * len(stride),
1098
+ make_padding(padding, len(stride)),
1099
+ lhs_dilation=stride,
1100
+ rhs_dilation=dilation,
1101
+ dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1102
+ feature_group_count=groups,
1103
+ batch_group_count=1,
1113
1104
  )
1114
1105
  else:
1115
1106
  res = jax.lax.conv_general_dilated(
1116
- input,
1117
- weight,
1118
- stride,
1119
- make_padding(padding, len(stride)),
1120
- lhs_dilation=(1,) * len(stride),
1121
- rhs_dilation=dilation,
1122
- dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1123
- feature_group_count=groups,
1124
- batch_group_count=1,
1107
+ input,
1108
+ weight,
1109
+ stride,
1110
+ make_padding(padding, len(stride)),
1111
+ lhs_dilation=(1,) * len(stride),
1112
+ rhs_dilation=dilation,
1113
+ dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1114
+ feature_group_count=groups,
1115
+ batch_group_count=1,
1125
1116
  )
1126
1117
 
1127
1118
  if bias is not None:
@@ -1137,10 +1128,9 @@ def _aten_convolution(
1137
1128
 
1138
1129
 
1139
1130
  # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps)
1140
- @op(torch.ops.aten._native_batch_norm_legit)
1141
- def _aten__native_batch_norm_legit(
1142
- input, weight, bias, running_mean, running_var, training, momentum, eps
1143
- ):
1131
+ @op(torch.ops.aten._native_batch_norm_legit.default)
1132
+ def _aten__native_batch_norm_legit(input, weight, bias, running_mean,
1133
+ running_var, training, momentum, eps):
1144
1134
  """JAX implementation of batch normalization with optional parameters.
1145
1135
  Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.
1146
1136
 
@@ -1161,8 +1151,7 @@ def _aten__native_batch_norm_legit(
1161
1151
  DeviceArray: Reversed batch variance (C,) or empty if training is False
1162
1152
  """
1163
1153
  reduction_dims = [0] + list(range(2, input.ndim))
1164
- reshape_dims = [1, -1] + [1]*(input.ndim-2)
1165
-
1154
+ reshape_dims = [1, -1] + [1] * (input.ndim - 2)
1166
1155
  if training:
1167
1156
  # Calculate batch mean and variance
1168
1157
  mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
@@ -1175,7 +1164,9 @@ def _aten__native_batch_norm_legit(
1175
1164
  saved_rstd = jnp.squeeze(rstd, reduction_dims)
1176
1165
  else:
1177
1166
  rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
1178
- saved_mean = jnp.array([], dtype=input.dtype) # No need to calculate batch statistics in inference mode
1167
+ saved_mean = jnp.array(
1168
+ [], dtype=input.dtype
1169
+ ) # No need to calculate batch statistics in inference mode
1179
1170
  saved_rstd = jnp.array([], dtype=input.dtype)
1180
1171
 
1181
1172
  # Normalize
@@ -1190,19 +1181,17 @@ def _aten__native_batch_norm_legit(
1190
1181
  if weight is not None:
1191
1182
  x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting
1192
1183
  if bias is not None:
1193
- x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting
1184
+ x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting
1194
1185
 
1195
1186
  return x_hat, saved_mean, saved_rstd
1196
1187
 
1197
1188
 
1198
-
1199
1189
  @op(torch.ops.aten._native_batch_norm_legit_no_training)
1200
- def _aten__native_batch_norm_legit_no_training(
1201
- input, weight, bias, running_mean, running_var, momentum, eps
1202
- ):
1203
- return _aten__native_batch_norm_legit(
1204
- input, weight, bias, running_mean, running_var, False, momentum, eps
1205
- )
1190
+ def _aten__native_batch_norm_legit_no_training(input, weight, bias,
1191
+ running_mean, running_var,
1192
+ momentum, eps):
1193
+ return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
1194
+ running_var, False, momentum, eps)
1206
1195
 
1207
1196
 
1208
1197
  @op(torch.ops.aten.relu)
@@ -1212,7 +1201,15 @@ def _aten_relu(self):
1212
1201
 
1213
1202
  @op(torch.ops.aten.cat)
1214
1203
  def _aten_cat(tensors, dims=0):
1215
- return jnp.concatenate(tensors, dims)
1204
+ # handle empty tensors as a special case.
1205
+ # torch.cat will ignore the empty tensor, while jnp.concatenate
1206
+ # will error if the dims > 0.
1207
+ filtered_tensors = [
1208
+ t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0)
1209
+ ]
1210
+ if filtered_tensors:
1211
+ return jnp.concatenate(filtered_tensors, dims)
1212
+ return tensors[0]
1216
1213
 
1217
1214
 
1218
1215
  def _ceil_mode_padding(
@@ -1220,6 +1217,7 @@ def _ceil_mode_padding(
1220
1217
  input_shape: list[int],
1221
1218
  kernel_size: list[int],
1222
1219
  stride: list[int],
1220
+ dilation: list[int],
1223
1221
  ceil_mode: bool,
1224
1222
  ):
1225
1223
  """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode.
@@ -1232,20 +1230,13 @@ def _ceil_mode_padding(
1232
1230
  right_padding = left_padding
1233
1231
 
1234
1232
  input_size = input_shape[2 + i]
1235
- output_size_rem = (input_size + 2 * left_padding - kernel_size[i]) % stride[
1236
- i
1237
- ]
1233
+ output_size_rem = (input_size + 2 * left_padding -
1234
+ (kernel_size[i] - 1) * dilation[i] - 1) % stride[i]
1238
1235
  if ceil_mode and output_size_rem != 0:
1239
1236
  extra_padding = stride[i] - output_size_rem
1240
- new_output_size = (
1241
- input_size
1242
- + left_padding
1243
- + right_padding
1244
- + extra_padding
1245
- - kernel_size[i]
1246
- + stride[i]
1247
- - 1
1248
- ) // stride[i] + 1
1237
+ new_output_size = (input_size + left_padding + right_padding +
1238
+ extra_padding - (kernel_size[i] - 1) * dilation[i] -
1239
+ 1 + stride[i] - 1) // stride[i] + 1
1249
1240
  # Ensure that the last pooling starts inside the image.
1250
1241
  size_to_compare = input_size + left_padding
1251
1242
 
@@ -1258,30 +1249,36 @@ def _ceil_mode_padding(
1258
1249
 
1259
1250
  @op(torch.ops.aten.max_pool2d_with_indices)
1260
1251
  @op(torch.ops.aten.max_pool3d_with_indices)
1261
- def _aten_max_pool2d_with_indices(
1262
- inputs, kernel_size, strides, padding=0, dilation=1, ceil_mode=False
1263
- ):
1252
+ def _aten_max_pool2d_with_indices(inputs,
1253
+ kernel_size,
1254
+ strides=None,
1255
+ padding=0,
1256
+ dilation=1,
1257
+ ceil_mode=False):
1264
1258
  num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
1265
1259
  kernel_size = tuple(kernel_size)
1266
- strides = tuple(strides)
1260
+ # Default stride is kernel_size
1261
+ strides = tuple(strides) if strides else kernel_size
1267
1262
  if isinstance(padding, int):
1268
1263
  padding = [padding for _ in range(len(kernel_size))]
1264
+ if isinstance(dilation, int):
1265
+ dilation = tuple(dilation for _ in range(len(kernel_size)))
1266
+ elif isinstance(dilation, list):
1267
+ dilation = tuple(dilation)
1269
1268
 
1270
1269
  input_shape = inputs.shape
1271
1270
  if num_batch_dims == 0:
1272
1271
  input_shape = [1, *input_shape]
1273
- padding = _ceil_mode_padding(
1274
- padding, input_shape, kernel_size, strides, ceil_mode
1275
- )
1272
+ padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
1273
+ dilation, ceil_mode)
1276
1274
 
1277
- window_shape = kernel_size
1278
- num_batch_dims = inputs.ndim - (len(window_shape) + 1)
1279
- strides = strides or (1,) * len(window_shape)
1280
- assert len(window_shape) == len(
1281
- strides
1282
- ), f"len({window_shape}) must equal len({strides})"
1275
+ assert len(kernel_size) == len(
1276
+ strides), f"len({kernel_size=}) must equal len({strides=})"
1277
+ assert len(kernel_size) == len(
1278
+ dilation), f"len({kernel_size=}) must equal len({dilation=})"
1283
1279
  strides = (1,) * (1 + num_batch_dims) + strides
1284
- dims = (1,) * (1 + num_batch_dims) + window_shape
1280
+ dims = (1,) * (1 + num_batch_dims) + kernel_size
1281
+ dilation = (1,) * (1 + num_batch_dims) + dilation
1285
1282
 
1286
1283
  is_single_input = False
1287
1284
  if num_batch_dims == 0:
@@ -1290,26 +1287,27 @@ def _aten_max_pool2d_with_indices(
1290
1287
  inputs = inputs[None]
1291
1288
  strides = (1,) + strides
1292
1289
  dims = (1,) + dims
1290
+ dilation = (1,) + dilation
1293
1291
  is_single_input = True
1294
1292
 
1295
1293
  assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
1296
1294
  if not isinstance(padding, str):
1297
1295
  padding = tuple(map(tuple, padding))
1298
- assert len(padding) == len(window_shape), (
1299
- f"padding {padding} must specify pads for same number of dims as "
1300
- f"window_shape {window_shape}"
1301
- )
1302
- assert all(
1303
- [len(x) == 2 for x in padding]
1304
- ), f"each entry in padding {padding} must be length 2"
1296
+ assert len(padding) == len(kernel_size), (
1297
+ f"padding {padding} must specify pads for same number of dims as "
1298
+ f"kernel_size {kernel_size}")
1299
+ assert all([len(x) == 2 for x in padding
1300
+ ]), f"each entry in padding {padding} must be length 2"
1305
1301
  padding = ((0, 0), (0, 0)) + padding
1306
1302
 
1307
- indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape)
1303
+ indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size):]))
1304
+ indices = indices.reshape(inputs.shape[-len(kernel_size):])
1305
+ indices = jnp.broadcast_to(indices, inputs.shape)
1308
1306
 
1309
1307
  def reduce_fn(a, b):
1310
1308
  ai, av = a
1311
1309
  bi, bv = b
1312
- which = av > bv
1310
+ which = av >= bv # torch breaks ties in favor of later indices
1313
1311
  return jnp.where(which, ai, bi), jnp.where(which, av, bv)
1314
1312
 
1315
1313
  init_val = -jnp.inf
@@ -1317,44 +1315,90 @@ def _aten_max_pool2d_with_indices(
1317
1315
  init_val = -(1 << 31)
1318
1316
  init_val = jnp.array(init_val).astype(inputs.dtype)
1319
1317
 
1320
- # Separate maxpool result and indices into two reduce_window ops. Since
1321
- # the indices tensor is usually unused in inference, separating the two
1318
+ # Separate maxpool result and indices into two reduce_window ops. Since
1319
+ # the indices tensor is usually unused in inference, separating the two
1322
1320
  # can help DCE computations for argmax.
1323
1321
  y = jax.lax.reduce_window(
1324
- inputs, init_val, jax.lax.max, dims, strides, padding
1325
- )
1322
+ inputs,
1323
+ init_val,
1324
+ jax.lax.max,
1325
+ dims,
1326
+ strides,
1327
+ padding,
1328
+ window_dilation=dilation)
1326
1329
  indices, _ = jax.lax.reduce_window(
1327
- (indices, inputs), (0, init_val), reduce_fn, dims, strides, padding
1330
+ (indices, inputs),
1331
+ (0, init_val),
1332
+ reduce_fn,
1333
+ dims,
1334
+ strides,
1335
+ padding,
1336
+ window_dilation=dilation,
1328
1337
  )
1329
1338
  if is_single_input:
1330
1339
  indices = jnp.squeeze(indices, axis=0)
1331
1340
  y = jnp.squeeze(y, axis=0)
1332
-
1341
+
1333
1342
  return y, indices
1334
1343
 
1335
1344
 
1345
+ # Aten ops registered under the `xla` library.
1346
+ try:
1347
+
1348
+ @op(torch.ops.xla.max_pool2d_forward)
1349
+ def _xla_max_pool2d_forward(*args, **kwargs):
1350
+ return _aten_max_pool2d_with_indices(*args, **kwargs)[0]
1351
+
1352
+ @op(torch.ops.xla.aot_mark_sharding)
1353
+ def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str):
1354
+ from jax.sharding import PartitionSpec as P, NamedSharding
1355
+ import ast
1356
+ import torch_xla.distributed.spmd as xs
1357
+ pmesh = xs.Mesh.from_str(mesh)
1358
+ assert pmesh is not None
1359
+ partition_spec_eval = ast.literal_eval(partition_spec)
1360
+ jmesh = pmesh.get_jax_mesh()
1361
+ return jax.lax.with_sharding_constraint(
1362
+ t, NamedSharding(jmesh, P(*partition_spec_eval)))
1363
+
1364
+ @op(torch.ops.xla.einsum_linear_forward)
1365
+ def _xla_einsum_linear_forward(input, weight, bias):
1366
+ with jax.named_scope('einsum_linear_forward'):
1367
+ product = jax.numpy.einsum('...n,mn->...m', input, weight)
1368
+ if bias is not None:
1369
+ return product + bias
1370
+ return product
1371
+
1372
+ except AttributeError:
1373
+ pass
1374
+
1336
1375
  # TODO add more ops
1337
1376
 
1338
1377
 
1339
1378
  @op(torch.ops.aten.min)
1340
1379
  def _aten_min(x, dim=None, keepdim=False):
1341
1380
  if dim is not None:
1342
- return _with_reduction_scalar(jnp.min, x, dim, keepdim), _with_reduction_scalar(jnp.argmin, x, dim, keepdim).astype(jnp.int64)
1381
+ return _with_reduction_scalar(jnp.min, x, dim,
1382
+ keepdim), _with_reduction_scalar(
1383
+ jnp.argmin, x, dim,
1384
+ keepdim).astype(jnp.int64)
1343
1385
  else:
1344
1386
  return _with_reduction_scalar(jnp.min, x, dim, keepdim)
1345
1387
 
1346
1388
 
1347
1389
  @op(torch.ops.aten.mode)
1348
1390
  def _aten_mode(input, dim=-1, keepdim=False, *, out=None):
1349
- if input.ndim == 0: # single number
1391
+ if input.ndim == 0: # single number
1350
1392
  return input, jnp.array(0)
1351
- dim = (input.ndim + dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim
1393
+ dim = (input.ndim +
1394
+ dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim
1352
1395
  # keepdims must be True for accurate broadcasting
1353
1396
  mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True)
1354
1397
  mode_broadcast = jnp.broadcast_to(mode, input.shape)
1355
1398
  if not keepdim:
1356
1399
  mode = mode.squeeze(axis=dim)
1357
- indices = jnp.argmax(jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim)
1400
+ indices = jnp.argmax(
1401
+ jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim)
1358
1402
  return mode, indices
1359
1403
 
1360
1404
 
@@ -1388,8 +1432,7 @@ def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None):
1388
1432
  @op(torch.ops.prims.broadcast_in_dim)
1389
1433
  def _prims_broadcast_in_dim(t, shape, broadcast_dimensions):
1390
1434
  return jax.lax.broadcast_in_dim(
1391
- t, shape, broadcast_dimensions=broadcast_dimensions
1392
- )
1435
+ t, shape, broadcast_dimensions=broadcast_dimensions)
1393
1436
 
1394
1437
 
1395
1438
  # aten.native_group_norm -- should use decomp table
@@ -1432,17 +1475,15 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):
1432
1475
  normalized = (x - mean) * rstd
1433
1476
  return normalized, mean, rstd
1434
1477
 
1435
- normalized, group_mean, group_rstd = jax.lax.map(
1436
- group_norm_body, reshaped_input
1437
- )
1478
+ normalized, group_mean, group_rstd = jax.lax.map(group_norm_body,
1479
+ reshaped_input)
1438
1480
 
1439
1481
  # Reshape back to original input shape
1440
1482
  output = jnp.reshape(normalized, input_shape)
1441
1483
 
1442
1484
  # **Affine transformation**
1443
- affine_shape = [
1444
- -1 if i == 1 else 1 for i in range(input.ndim)
1445
- ] # Shape for broadcasting
1485
+ affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim)
1486
+ ] # Shape for broadcasting
1446
1487
  if weight is not None and bias is not None:
1447
1488
  output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape)
1448
1489
  elif weight is not None:
@@ -1474,22 +1515,25 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
1474
1515
  The tensor containing the calculated vector norms.
1475
1516
  """
1476
1517
 
1477
- if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance(ord, (int, float)):
1518
+ if ord not in {2, float("inf"), float("-inf"), "fro"
1519
+ } and not isinstance(ord, (int, float)):
1478
1520
  raise ValueError(
1479
- f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
1480
- " 'fro'."
1481
- )
1482
-
1521
+ f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
1522
+ " 'fro'.")
1523
+
1483
1524
  # Special cases (for efficiency and clarity)
1484
1525
  if ord == 0:
1485
1526
  if self.shape == ():
1486
1527
  # float sets it to float64. set it back to input type
1487
1528
  result = jnp.astype(jnp.array(float(self != 0)), self.dtype)
1488
1529
  else:
1489
- result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim)
1530
+ result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim,
1531
+ keepdim)
1490
1532
 
1491
1533
  elif ord == 2: # Euclidean norm
1492
- result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))
1534
+ result = jnp.sqrt(
1535
+ _with_reduction_scalar(jnp.sum,
1536
+ jnp.abs(self)**2, dim, keepdim))
1493
1537
 
1494
1538
  elif ord == float("inf"):
1495
1539
  result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim)
@@ -1498,12 +1542,14 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
1498
1542
  result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim)
1499
1543
 
1500
1544
  elif ord == "fro": # Frobenius norm
1501
- result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))
1545
+ result = jnp.sqrt(
1546
+ _with_reduction_scalar(jnp.sum,
1547
+ jnp.abs(self)**2, dim, keepdim))
1502
1548
 
1503
1549
  else: # General case (e.g., ord = 1, ord = 3)
1504
- result = _with_reduction_scalar(jnp.sum, jnp.abs(self) ** ord, dim, keepdim) ** (
1505
- 1.0 / ord
1506
- )
1550
+ result = _with_reduction_scalar(jnp.sum,
1551
+ jnp.abs(self)**ord, dim,
1552
+ keepdim)**(1.0 / ord)
1507
1553
 
1508
1554
  # (Optional) dtype conversion
1509
1555
  if dtype is not None:
@@ -1539,9 +1585,12 @@ def _aten_sinh(self):
1539
1585
 
1540
1586
  # aten.native_layer_norm_backward
1541
1587
  @op(torch.ops.aten.native_layer_norm_backward)
1542
- def _aten_native_layer_norm_backward(
1543
- grad_out, input, normalized_shape, weight, bias, eps=1e-5
1544
- ):
1588
+ def _aten_native_layer_norm_backward(grad_out,
1589
+ input,
1590
+ normalized_shape,
1591
+ weight,
1592
+ bias,
1593
+ eps=1e-5):
1545
1594
  """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`.
1546
1595
 
1547
1596
  Args:
@@ -1555,9 +1604,8 @@ def _aten_native_layer_norm_backward(
1555
1604
  Returns:
1556
1605
  A tuple of (grad_input, grad_weight, grad_bias).
1557
1606
  """
1558
- return jax.lax.native_layer_norm_backward(
1559
- grad_out, input, normalized_shape, weight, bias, eps
1560
- )
1607
+ return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape,
1608
+ weight, bias, eps)
1561
1609
 
1562
1610
 
1563
1611
  # aten.reflection_pad3d_backward
@@ -1585,12 +1633,14 @@ def _aten_bitwise_not(self):
1585
1633
 
1586
1634
 
1587
1635
  # aten.bitwise_left_shift
1636
+ @op(torch.ops.aten.__lshift__)
1588
1637
  @op(torch.ops.aten.bitwise_left_shift)
1589
1638
  def _aten_bitwise_left_shift(input, other):
1590
1639
  return jnp.left_shift(input, other)
1591
1640
 
1592
1641
 
1593
1642
  # aten.bitwise_right_shift
1643
+ @op(torch.ops.aten.__rshift__)
1594
1644
  @op(torch.ops.aten.bitwise_right_shift)
1595
1645
  def _aten_bitwise_right_shift(input, other):
1596
1646
  return jnp.right_shift(input, other)
@@ -1671,10 +1721,8 @@ def _scatter_index(dim, index):
1671
1721
  target_shape = [1] * len(index_shape)
1672
1722
  target_shape[i] = index_shape[i]
1673
1723
  input_indexes.append(
1674
- jnp.broadcast_to(
1675
- jnp.arange(index_shape[i]).reshape(target_shape), index_shape
1676
- )
1677
- )
1724
+ jnp.broadcast_to(
1725
+ jnp.arange(index_shape[i]).reshape(target_shape), index_shape))
1678
1726
  return tuple(input_indexes), tuple(source_indexes)
1679
1727
 
1680
1728
 
@@ -1686,6 +1734,7 @@ def _aten_scatter_add(input, dim, index, src):
1686
1734
  input_indexes, source_indexes = _scatter_index(dim, index)
1687
1735
  return input.at[input_indexes].add(src[source_indexes])
1688
1736
 
1737
+
1689
1738
  # aten.masked_scatter
1690
1739
  @op(torch.ops.aten.masked_scatter)
1691
1740
  def _aten_masked_scatter(self, mask, source):
@@ -1707,6 +1756,7 @@ def _aten_masked_scatter(self, mask, source):
1707
1756
 
1708
1757
  return final_arr
1709
1758
 
1759
+
1710
1760
  @op(torch.ops.aten.masked_select)
1711
1761
  def _aten_masked_select(self, mask, *args, **kwargs):
1712
1762
  broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape)
@@ -1722,6 +1772,7 @@ def _aten_masked_select(self, mask, *args, **kwargs):
1722
1772
 
1723
1773
  return self_flat[true_indices]
1724
1774
 
1775
+
1725
1776
  # aten.logical_not
1726
1777
 
1727
1778
 
@@ -1730,11 +1781,13 @@ def _aten_masked_select(self, mask, *args, **kwargs):
1730
1781
  def _aten_sign(x):
1731
1782
  return jnp.sign(x)
1732
1783
 
1784
+
1733
1785
  # aten.signbit
1734
1786
  @op(torch.ops.aten.signbit)
1735
1787
  def _aten_signbit(x):
1736
1788
  return jnp.signbit(x)
1737
1789
 
1790
+
1738
1791
  # aten.sigmoid
1739
1792
  @op(torch.ops.aten.sigmoid)
1740
1793
  @op_base.promote_int_input
@@ -1760,7 +1813,13 @@ def _aten_atan(self):
1760
1813
 
1761
1814
  @op(torch.ops.aten.scatter_reduce)
1762
1815
  @op(torch.ops.aten.scatter)
1763
- def _aten_scatter_reduce(input, dim, index, src, reduce=None, *, include_self=True):
1816
+ def _aten_scatter_reduce(input,
1817
+ dim,
1818
+ index,
1819
+ src,
1820
+ reduce=None,
1821
+ *,
1822
+ include_self=True):
1764
1823
  if not isinstance(src, jnp.ndarray):
1765
1824
  src = jnp.array(src, dtype=input.dtype)
1766
1825
  input_indexes, source_indexes = _scatter_index(dim, index)
@@ -1817,41 +1876,6 @@ def _aten_gt(self, other):
1817
1876
  return self > other
1818
1877
 
1819
1878
 
1820
- # aten.pixel_shuffle
1821
- @op(torch.ops.aten.pixel_shuffle)
1822
- def _aten_pixel_shuffle(x, upscale_factor):
1823
- """PixelShuffle implementation in JAX.
1824
-
1825
- Args:
1826
- x: Input tensor. Typically a feature map.
1827
- upscale_factor: Integer by which to upscale the spatial dimensions.
1828
-
1829
- Returns:
1830
- Tensor after PixelShuffle operation.
1831
- """
1832
-
1833
- batch_size, channels, height, width = x.shape
1834
-
1835
- if channels % (upscale_factor**2) != 0:
1836
- raise ValueError(
1837
- "Number of channels must be divisible by the square of the upscale factor."
1838
- )
1839
-
1840
- new_channels = channels // (upscale_factor**2)
1841
- new_height = height * upscale_factor
1842
- new_width = width * upscale_factor
1843
-
1844
- x = x.reshape(
1845
- batch_size, new_channels, upscale_factor, upscale_factor, height, width
1846
- )
1847
- x = jnp.transpose(
1848
- x, (0, 1, 2, 4, 3, 5)
1849
- ) # Move channels to spatial dimensions
1850
- x = x.reshape(batch_size, new_channels, new_height, new_width)
1851
-
1852
- return x
1853
-
1854
-
1855
1879
  # aten.sym_stride
1856
1880
  # aten.lt
1857
1881
  @op(torch.ops.aten.lt)
@@ -1883,8 +1907,7 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
1883
1907
  num_batch_dims = inputs.ndim - (len(window_shape) + 1)
1884
1908
  strides = strides or (1,) * len(window_shape)
1885
1909
  assert len(window_shape) == len(
1886
- strides
1887
- ), f"len({window_shape}) must equal len({strides})"
1910
+ strides), f"len({window_shape}) must equal len({strides})"
1888
1911
  strides = (1,) * (1 + num_batch_dims) + strides
1889
1912
  dims = (1,) * (1 + num_batch_dims) + window_shape
1890
1913
 
@@ -1901,23 +1924,22 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
1901
1924
  if not isinstance(padding, str):
1902
1925
  padding = tuple(map(tuple, padding))
1903
1926
  assert len(padding) == len(window_shape), (
1904
- f"padding {padding} must specify pads for same number of dims as "
1905
- f"window_shape {window_shape}"
1906
- )
1907
- assert all(
1908
- [len(x) == 2 for x in padding]
1909
- ), f"each entry in padding {padding} must be length 2"
1927
+ f"padding {padding} must specify pads for same number of dims as "
1928
+ f"window_shape {window_shape}")
1929
+ assert all([len(x) == 2 for x in padding
1930
+ ]), f"each entry in padding {padding} must be length 2"
1910
1931
  padding = ((0, 0), (0, 0)) + padding
1911
1932
  y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
1912
1933
  if is_single_input:
1913
1934
  y = jnp.squeeze(y, axis=0)
1914
1935
  return y
1915
1936
 
1916
-
1937
+
1917
1938
  @op(torch.ops.aten._adaptive_avg_pool2d)
1918
1939
  @op(torch.ops.aten._adaptive_avg_pool3d)
1919
- def adaptive_avg_pool2or3d(input: jnp.ndarray, output_size: Tuple[int, int]) -> jnp.ndarray:
1920
- """
1940
+ def adaptive_avg_pool2or3d(input: jnp.ndarray,
1941
+ output_size: Tuple[int, int]) -> jnp.ndarray:
1942
+ """
1921
1943
  Applies a 2/3D adaptive average pooling over an input signal composed of several input planes.
1922
1944
 
1923
1945
  See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.
@@ -1929,124 +1951,128 @@ def adaptive_avg_pool2or3d(input: jnp.ndarray, output_size: Tuple[int, int]) ->
1929
1951
  Context:
1930
1952
  https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401
1931
1953
  """
1932
- shape = input.shape
1933
- ndim = len(shape)
1934
- out_dim = len(output_size)
1935
- num_spatial_dim = ndim - out_dim
1936
-
1937
- # Preconditions
1938
-
1939
- assert ndim in (out_dim + 1, out_dim + 2), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}"
1940
- for d in input.shape[-2:]:
1941
- assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \
1942
- f"non-batch dimensions, but input has shape {tuple(shape)}."
1943
-
1944
- # Optimisation (we should also do this in the kernel implementation)
1945
- if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])):
1946
- stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size))
1947
- kernel = tuple(i - (o - 1) * s for i, o, s in zip(shape[-out_dim:], output_size, stride))
1948
- return _aten_avg_pool(
1949
- input,
1950
- kernel,
1951
- strides=stride,
1952
- )
1953
-
1954
- def start_index(a, b, c):
1955
- return (a * c) // b
1956
-
1957
- def end_index(a, b, c):
1958
- return ((a + 1) * c + b - 1) // b
1959
-
1960
- def compute_idx(in_size, out_size):
1961
- orange = jnp.arange(out_size, dtype=jnp.int64)
1962
- i0 = start_index(orange, out_size, in_size)
1963
- # Let length = end_index - start_index, i.e. the length of the pooling kernels
1964
- # length.max() can be computed analytically as follows:
1965
- maxlength = in_size // out_size + 1
1966
- in_size_mod = in_size % out_size
1967
- # adaptive = True iff there are kernels with different lengths
1968
- adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
1969
- if adaptive:
1970
- maxlength += 1
1971
- elif in_size_mod == 0:
1972
- maxlength -= 1
1973
-
1974
- range_max = jnp.arange(maxlength, dtype=jnp.int64)
1975
- idx = i0[:, None] + range_max
1976
- if adaptive:
1977
- # Need to clamp to avoid accessing out-of-bounds memory
1978
- idx = jnp.minimum(idx, in_size - 1)
1979
-
1980
- # Compute the length
1981
- i1 = end_index(orange, out_size, in_size)
1982
- length = i1 - i0
1983
- else:
1984
- length = maxlength
1985
- return idx, length, range_max, adaptive
1986
-
1987
- idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)]
1988
- # length is not None if it's constant, otherwise we'll need to compute it
1989
- for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)):
1990
- idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o)
1991
-
1992
- def _unsqueeze_to_dim(x, dim):
1993
- ndim = len(x.shape)
1994
- return jax.lax.expand_dims(x, tuple(range(ndim, dim)))
1995
-
1996
- if out_dim == 2:
1997
- # NOTE: unsqueeze to insert extra 1 in ranks; so they
1998
- # would broadcast
1999
- vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]]
2000
- reduce_axis = (-3, -1)
1954
+ shape = input.shape
1955
+ ndim = len(shape)
1956
+ out_dim = len(output_size)
1957
+ num_spatial_dim = ndim - out_dim
1958
+
1959
+ # Preconditions
1960
+
1961
+ assert ndim in (
1962
+ out_dim + 1, out_dim + 2
1963
+ ), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}"
1964
+ for d in input.shape[-2:]:
1965
+ assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \
1966
+ f"non-batch dimensions, but input has shape {tuple(shape)}."
1967
+
1968
+ # Optimisation (we should also do this in the kernel implementation)
1969
+ if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])):
1970
+ stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size))
1971
+ kernel = tuple(i - (o - 1) * s
1972
+ for i, o, s in zip(shape[-out_dim:], output_size, stride))
1973
+ return _aten_avg_pool(
1974
+ input,
1975
+ kernel,
1976
+ strides=stride,
1977
+ )
1978
+
1979
+ def start_index(a, b, c):
1980
+ return (a * c) // b
1981
+
1982
+ def end_index(a, b, c):
1983
+ return ((a + 1) * c + b - 1) // b
1984
+
1985
+ def compute_idx(in_size, out_size):
1986
+ orange = jnp.arange(out_size, dtype=jnp.int64)
1987
+ i0 = start_index(orange, out_size, in_size)
1988
+ # Let length = end_index - start_index, i.e. the length of the pooling kernels
1989
+ # length.max() can be computed analytically as follows:
1990
+ maxlength = in_size // out_size + 1
1991
+ in_size_mod = in_size % out_size
1992
+ # adaptive = True iff there are kernels with different lengths
1993
+ adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
1994
+ if adaptive:
1995
+ maxlength += 1
1996
+ elif in_size_mod == 0:
1997
+ maxlength -= 1
1998
+
1999
+ range_max = jnp.arange(maxlength, dtype=jnp.int64)
2000
+ idx = i0[:, None] + range_max
2001
+ if adaptive:
2002
+ # Need to clamp to avoid accessing out-of-bounds memory
2003
+ idx = jnp.minimum(idx, in_size - 1)
2004
+
2005
+ # Compute the length
2006
+ i1 = end_index(orange, out_size, in_size)
2007
+ length = i1 - i0
2008
+ else:
2009
+ length = maxlength
2010
+ return idx, length, range_max, adaptive
2011
+
2012
+ idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)]
2013
+ # length is not None if it's constant, otherwise we'll need to compute it
2014
+ for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)):
2015
+ idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o)
2016
+
2017
+ def _unsqueeze_to_dim(x, dim):
2018
+ ndim = len(x.shape)
2019
+ return jax.lax.expand_dims(x, tuple(range(ndim, dim)))
2020
+
2021
+ if out_dim == 2:
2022
+ # NOTE: unsqueeze to insert extra 1 in ranks; so they
2023
+ # would broadcast
2024
+ vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]]
2025
+ reduce_axis = (-3, -1)
2026
+ else:
2027
+ assert out_dim == 3
2028
+ vals = input[...,
2029
+ _unsqueeze_to_dim(idx[0], 6),
2030
+ _unsqueeze_to_dim(idx[1], 4), idx[2]]
2031
+ reduce_axis = (-5, -3, -1)
2032
+
2033
+ # Shortcut for the simpler case
2034
+ if not any(adaptive):
2035
+ return jnp.mean(vals, axis=reduce_axis)
2036
+
2037
+ def maybe_mask(vals, length, range_max, adaptive, dim):
2038
+ if isinstance(length, int):
2039
+ return vals, length
2001
2040
  else:
2002
- assert out_dim == 3
2003
- vals = input[..., _unsqueeze_to_dim(idx[0], 6),
2004
- _unsqueeze_to_dim(idx[1], 4),
2005
- idx[2]]
2006
- reduce_axis = (-5, -3, -1)
2007
-
2008
- # Shortcut for the simpler case
2009
- if not any(adaptive):
2010
- return jnp.mean(vals, axis=reduce_axis)
2011
-
2012
- def maybe_mask(vals, length, range_max, adaptive, dim):
2013
- if isinstance(length, int):
2014
- return vals, length
2015
- else:
2016
- # zero-out the things we didn't really want to select
2017
- assert dim < 0
2018
- # hack
2019
- mask = range_max >= length[:, None]
2020
- if dim == -2:
2021
- mask = _unsqueeze_to_dim(mask, 4)
2022
- elif dim == -3:
2023
- mask = _unsqueeze_to_dim(mask, 6)
2024
- vals = jnp.where(mask, 0.0, vals)
2025
- # Compute the length of each window
2026
- length = _unsqueeze_to_dim(length, -dim)
2027
- return vals, length
2028
-
2029
- for i in range(len(length)):
2030
- vals, length[i] = maybe_mask(vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
2031
-
2032
- # We unroll the sum as we assume that the kernels are going to be small
2033
- ret = jnp.sum(vals, axis=reduce_axis)
2034
- # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
2035
- # this is multiplication with broadcasting, not regular pointwise product
2036
- return ret / math.prod(length)
2037
-
2041
+ # zero-out the things we didn't really want to select
2042
+ assert dim < 0
2043
+ # hack
2044
+ mask = range_max >= length[:, None]
2045
+ if dim == -2:
2046
+ mask = _unsqueeze_to_dim(mask, 4)
2047
+ elif dim == -3:
2048
+ mask = _unsqueeze_to_dim(mask, 6)
2049
+ vals = jnp.where(mask, 0.0, vals)
2050
+ # Compute the length of each window
2051
+ length = _unsqueeze_to_dim(length, -dim)
2052
+ return vals, length
2053
+
2054
+ for i in range(len(length)):
2055
+ vals, length[i] = maybe_mask(
2056
+ vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
2057
+
2058
+ # We unroll the sum as we assume that the kernels are going to be small
2059
+ ret = jnp.sum(vals, axis=reduce_axis)
2060
+ # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
2061
+ # this is multiplication with broadcasting, not regular pointwise product
2062
+ return ret / math.prod(length)
2063
+
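When every spatial size divides its target, the branch above reduces adaptive pooling to a plain average pool: for an 8x8 input and output_size (4, 4), stride = 8 // 4 = 2 and kernel = 8 - (4 - 1) * 2 = 2, so the call is equivalent to a 2x2, stride-2 avg_pool2d (worked numbers, illustrative only).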
2038
2064
 
2039
2065
  @op(torch.ops.aten.avg_pool1d)
2040
2066
  @op(torch.ops.aten.avg_pool2d)
2041
2067
  @op(torch.ops.aten.avg_pool3d)
2042
2068
  def _aten_avg_pool(
2043
- inputs,
2044
- kernel_size,
2045
- strides=None,
2046
- padding=0,
2047
- ceil_mode=False,
2048
- count_include_pad=True,
2049
- divisor_override=None,
2069
+ inputs,
2070
+ kernel_size,
2071
+ strides=None,
2072
+ padding=0,
2073
+ ceil_mode=False,
2074
+ count_include_pad=True,
2075
+ divisor_override=None,
2050
2076
  ):
2051
2077
  num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
2052
2078
  kernel_size = tuple(kernel_size)
@@ -2060,7 +2086,7 @@ def _aten_avg_pool(
2060
2086
  if num_batch_dims == 0:
2061
2087
  input_shape = [1, *input_shape]
2062
2088
  padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
2063
- ceil_mode)
2089
+ [1] * len(kernel_size), ceil_mode)
2064
2090
 
2065
2091
  y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding)
2066
2092
  if divisor_override is not None:
@@ -2102,9 +2128,11 @@ def _aten_avg_pool(
2102
2128
  )
2103
2129
  return y.astype(inputs.dtype)
2104
2130
 
2131
+
2105
2132
  # helper function to generate all indices to iterate through ndarray
2106
- def _generate_indices(dims, skip_dim_indices = []):
2133
+ def _generate_indices(dims, skip_dim_indices=[]):
2107
2134
  res = []
2135
+
2108
2136
  def _helper(curr_dim_idx, sofar):
2109
2137
  if curr_dim_idx in skip_dim_indices:
2110
2138
  _helper(curr_dim_idx + 1, sofar[:])
@@ -2115,10 +2143,11 @@ def _generate_indices(dims, skip_dim_indices = []):
2115
2143
  for i in range(dims[curr_dim_idx]):
2116
2144
  sofar[curr_dim_idx] = i
2117
2145
  _helper(curr_dim_idx + 1, sofar[:])
2118
-
2146
+
2119
2147
  _helper(0, [0 for _ in dims])
2120
2148
  return res
2121
2149
 
2150
+
2122
2151
  # aten.sym_numel
2123
2152
  # aten.reciprocal
2124
2153
  @op(torch.ops.aten.reciprocal)
@@ -2174,10 +2203,14 @@ def _aten_round(input, decimals=0):
2174
2203
  @op(torch.ops.aten.max)
2175
2204
  def _aten_max(self, dim=None, keepdim=False):
2176
2205
  if dim is not None:
2177
- return _with_reduction_scalar(jnp.max, self, dim, keepdim), _with_reduction_scalar(jnp.argmax, self, dim, keepdim).astype(jnp.int64)
2206
+ return _with_reduction_scalar(jnp.max, self, dim,
2207
+ keepdim), _with_reduction_scalar(
2208
+ jnp.argmax, self, dim,
2209
+ keepdim).astype(jnp.int64)
2178
2210
  else:
2179
2211
  return _with_reduction_scalar(jnp.max, self, dim, keepdim)
2180
2212
 
2213
+
2181
2214
  # aten.maximum
2182
2215
  @op(torch.ops.aten.maximum)
2183
2216
  def _aten_maximum(self, other):
@@ -2216,27 +2249,28 @@ def _with_reduction_scalar(jax_func, self, dim, keepdim):
2216
2249
  def _aten_any(self, dim=None, keepdim=False):
2217
2250
  return _with_reduction_scalar(jnp.any, self, dim, keepdim)
2218
2251
 
2252
+
2219
2253
  # aten.arange
2220
2254
  @op(torch.ops.aten.arange.start_step)
2221
2255
  @op(torch.ops.aten.arange.start)
2222
2256
  @op(torch.ops.aten.arange.default)
2223
2257
  @op_base.convert_dtype(use_default_dtype=False)
2224
2258
  def _aten_arange(
2225
- start,
2226
- end=None,
2227
- step=None,
2228
- *,
2229
- dtype=None,
2230
- layout=None,
2231
- requires_grad=False,
2232
- device=None,
2233
- pin_memory=False,
2259
+ start,
2260
+ end=None,
2261
+ step=None,
2262
+ *,
2263
+ dtype=None,
2264
+ layout=None,
2265
+ requires_grad=False,
2266
+ device=None,
2267
+ pin_memory=False,
2234
2268
  ):
2235
2269
  return jnp.arange(
2236
- op_base.maybe_convert_constant_dtype(start, dtype),
2237
- op_base.maybe_convert_constant_dtype(end, dtype),
2238
- op_base.maybe_convert_constant_dtype(step, dtype),
2239
- dtype=dtype,
2270
+ op_base.maybe_convert_constant_dtype(start, dtype),
2271
+ op_base.maybe_convert_constant_dtype(end, dtype),
2272
+ op_base.maybe_convert_constant_dtype(step, dtype),
2273
+ dtype=dtype,
2240
2274
  )
2241
2275
 
2242
2276
 
@@ -2245,6 +2279,7 @@ def _aten_arange(
2245
2279
  def _aten_argmax(self, dim=None, keepdim=False):
2246
2280
  return _with_reduction_scalar(jnp.argmax, self, dim, keepdim)
2247
2281
 
2282
+
2248
2283
  def _strided_index(sizes, strides, storage_offset=None):
2249
2284
  ind = jnp.zeros(sizes, dtype=jnp.int32)
2250
2285
 
@@ -2257,6 +2292,7 @@ def _strided_index(sizes, strides, storage_offset=None):
2257
2292
  ind += storage_offset
2258
2293
  return ind
2259
2294
 
2295
+
2260
2296
  # aten.as_strided
2261
2297
  @op(torch.ops.aten.as_strided)
2262
2298
  @op(torch.ops.aten.as_strided_copy)
@@ -2311,7 +2347,7 @@ def _aten_broadcast_tensors(*tensors):
2311
2347
  Args:
2312
2348
  shapes: A list of tuples representing the shapes of the input tensors.
2313
2349
 
2314
- Returns:
2350
+ Returns:
2315
2351
  A tuple representing the broadcasted output shape.
2316
2352
  """
2317
2353
 
@@ -2344,11 +2380,13 @@ def _aten_broadcast_tensors(*tensors):
2344
2380
  A tuple specifying which dimensions of the input tensor should be broadcasted.
2345
2381
  """
2346
2382
 
2347
- res = tuple(i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)))
2383
+ res = tuple(
2384
+ i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)))
2348
2385
  return res
2349
2386
 
2350
2387
  # clean some function's previous wrap
2351
- if len(tensors)==1 and len(tensors[0])>=1 and isinstance(tensors[0][0], jax.Array):
2388
+ if len(tensors) == 1 and len(tensors[0]) >= 1 and isinstance(
2389
+ tensors[0][0], jax.Array):
2352
2390
  tensors = tensors[0]
2353
2391
 
2354
2392
  # Get the shapes of all input tensors
@@ -2357,7 +2395,8 @@ def _aten_broadcast_tensors(*tensors):
2357
2395
  output_shape = _get_broadcast_shape(shapes)
2358
2396
  # Broadcast each tensor to the output shape
2359
2397
  broadcasted_tensors = [
2360
- jax.lax.broadcast_in_dim(t, output_shape, _broadcast_dimensions(t.shape, output_shape))
2398
+ jax.lax.broadcast_in_dim(t, output_shape,
2399
+ _broadcast_dimensions(t.shape, output_shape))
2361
2400
  for t in tensors
2362
2401
  ]
2363
2402
 
@@ -2376,6 +2415,7 @@ def _aten_broadcast_to(input, shape):
2376
2415
  def _aten_clamp(self, min=None, max=None):
2377
2416
  return jnp.clip(self, min, max)
2378
2417
 
2418
+
2379
2419
  @op(torch.ops.aten.clamp_min)
2380
2420
  def _aten_clamp_min(input, min):
2381
2421
  return jnp.clip(input, min=min)
@@ -2394,7 +2434,7 @@ def _aten_constant_pad_nd(input, padding, value=0):
2394
2434
  rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)]
2395
2435
  pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding)
2396
2436
  value_casted = jax.numpy.array(value, dtype=input.dtype)
2397
- return jax.lax.pad(input, padding_value=value_casted, padding_config = pad_dim)
2437
+ return jax.lax.pad(input, padding_value=value_casted, padding_config=pad_dim)
2398
2438
 
2399
2439
 
2400
2440
  # aten.convolution_backward
@@ -2421,9 +2461,8 @@ def _aten_cdist_forward(x1, x2, p, compute_mode=""):
2421
2461
  @op(torch.ops.aten._pdist_forward)
2422
2462
  def _aten__pdist_forward(x, p=2):
2423
2463
  pairwise_dists = _aten_cdist_forward(x, x, p)
2424
- condensed_dists = pairwise_dists[
2425
- jnp.triu_indices(pairwise_dists.shape[0], k=1)
2426
- ]
2464
+ condensed_dists = pairwise_dists[jnp.triu_indices(
2465
+ pairwise_dists.shape[0], k=1)]
2427
2466
  return condensed_dists
2428
2467
 
2429
2468
 
@@ -2449,25 +2488,33 @@ def _aten_cosh(input):
2449
2488
  return jnp.cosh(input)
2450
2489
 
2451
2490
 
2491
+ @op(torch.ops.aten.diag)
2492
+ def _aten_diag(input, diagonal=0):
2493
+ return jnp.diag(input, diagonal)
2494
+
2495
+
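Like torch.diag, jnp.diag builds a diagonal matrix from a 1-D argument and extracts the k-th diagonal from a 2-D one; a quick illustrative check:

    import jax.numpy as jnp

    jnp.diag(jnp.array([1, 2, 3]))              # 3x3 matrix with 1, 2, 3 on the main diagonal
    jnp.diag(jnp.arange(9).reshape(3, 3), 1)    # first super-diagonal -> [1, 5]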
2452
2496
  # aten.diagonal
2453
2497
  @op(torch.ops.aten.diagonal)
2498
+ @op(torch.ops.aten.diagonal_copy)
2454
2499
  def _aten_diagonal(input, offset=0, dim1=0, dim2=1):
2455
2500
  return jnp.diagonal(input, offset, dim1, dim2)
2456
2501
 
2457
2502
 
2458
2503
  def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1):
2459
- input_len = len(input_shape)
2460
- if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len):
2461
- raise ValueError("dim1 and dim2 must be different and in range [0, " + str(input_len-1)+ "]")
2462
-
2463
- size1, size2 = input_shape[dim1], input_shape[dim2]
2464
- if offset >= 0:
2465
- indices1 = jnp.arange(min(size1, size2 - offset))
2466
- indices2 = jnp.arange(offset, offset + len(indices1))
2467
- else:
2468
- indices2 = jnp.arange(min(size1 + offset, size2 ))
2469
- indices1 = jnp.arange(-offset, -offset + len(indices2))
2470
- return [indices1, indices2]
2504
+ input_len = len(input_shape)
2505
+ if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len):
2506
+ raise ValueError("dim1 and dim2 must be different and in range [0, " +
2507
+ str(input_len - 1) + "]")
2508
+
2509
+ size1, size2 = input_shape[dim1], input_shape[dim2]
2510
+ if offset >= 0:
2511
+ indices1 = jnp.arange(min(size1, size2 - offset))
2512
+ indices2 = jnp.arange(offset, offset + len(indices1))
2513
+ else:
2514
+ indices2 = jnp.arange(min(size1 + offset, size2))
2515
+ indices1 = jnp.arange(-offset, -offset + len(indices2))
2516
+ return [indices1, indices2]
2517
+
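For example, for input_shape (3, 4) with offset=1 the helper returns indices1 = [0, 1, 2] and indices2 = [1, 2, 3] (the first super-diagonal), and with offset=-1 it returns indices1 = [1, 2] and indices2 = [0, 1] (worked values, illustrative only).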
2471
2518
 
2472
2519
  @op(torch.ops.aten.diagonal_scatter)
2473
2520
  def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
@@ -2476,17 +2523,17 @@ def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
2476
2523
  if input.ndim == 2:
2477
2524
  return input.at[tuple(indexes)].set(src)
2478
2525
  else:
2479
- # src has the same shape as the output of
2526
+ # src has the same shape as the output of
2480
2527
  # jnp.diagonal(input, offset, dim1, dim2).
2481
2528
  # Last dimension always contains the diagonal elements,
2482
2529
  # while the preceding dimensions represent the "slices"
2483
2530
  # from which these diagonals are extracted. Thus,
2484
2531
  # we alter input axes to match this assumption, write src
2485
2532
  # and then move the axes back to the original state.
2486
- input = jnp.moveaxis(input, (dim1, dim2), (-2,-1))
2487
- multi_indexes = [slice(None)]*(input.ndim-2) + indexes
2533
+ input = jnp.moveaxis(input, (dim1, dim2), (-2, -1))
2534
+ multi_indexes = [slice(None)] * (input.ndim - 2) + indexes
2488
2535
  input = input.at[tuple(multi_indexes)].set(src)
2489
- return jnp.moveaxis(input, (-2,-1), (dim1, dim2))
2536
+ return jnp.moveaxis(input, (-2, -1), (dim1, dim2))
2490
2537
 
2491
2538
 
2492
2539
  # aten.diagflat
@@ -2507,9 +2554,9 @@ def _aten_eq(input1, input2):
2507
2554
 
2508
2555
 
2509
2556
  # aten.equal
2510
- @op(torch.ops.aten.equal, is_jax_function=False)
2557
+ @op(torch.ops.aten.equal)
2511
2558
  def _aten_equal(input, other):
2512
- res = jnp.array_equal(input._elem, other._elem)
2559
+ res = jnp.array_equal(input, other)
2513
2560
  return bool(res)
2514
2561
 
2515
2562
 
@@ -2559,7 +2606,12 @@ def _aten_exp2(input):
2559
2606
  # aten.fill
2560
2607
  @op(torch.ops.aten.fill)
2561
2608
  @op(torch.ops.aten.full_like)
2562
- def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None, device=None):
2609
+ def _aten_fill(x,
2610
+ value,
2611
+ dtype=None,
2612
+ pin_memory=None,
2613
+ memory_format=None,
2614
+ device=None):
2563
2615
  if dtype is None:
2564
2616
  dtype = x.dtype
2565
2617
  else:
@@ -2634,7 +2686,8 @@ def _aten_glu(x, dim=-1):
2634
2686
  # aten.hardtanh
2635
2687
  @op(torch.ops.aten.hardtanh)
2636
2688
  def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
2637
- if input.dtype == np.int64 and isinstance(max_val, float) and isinstance(min_val, float):
2689
+ if input.dtype == np.int64 and isinstance(max_val, float) and isinstance(
2690
+ min_val, float):
2638
2691
  min_val = int(min_val)
2639
2692
  max_val = int(max_val)
2640
2693
  return jnp.clip(input, min_val, max_val)
@@ -2644,7 +2697,7 @@ def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
2644
2697
  @op(torch.ops.aten.histc)
2645
2698
  def _aten_histc(input, bins=100, min=0, max=0):
2646
2699
  # TODO(@manfei): this function might cause some uncertainty
2647
- if min==0 and max==0:
2700
+ if min == 0 and max == 0:
2648
2701
  if isinstance(input, jnp.ndarray) and input.size == 0:
2649
2702
  min = 0
2650
2703
  max = 0
@@ -2652,7 +2705,8 @@ def _aten_histc(input, bins=100, min=0, max=0):
2652
2705
  min = jnp.min(input)
2653
2706
  max = jnp.max(input)
2654
2707
  range_value = (min, max)
2655
- hist, bin_edges = jnp.histogram(input, bins=bins, range=range_value, weights=None, density=None)
2708
+ hist, bin_edges = jnp.histogram(
2709
+ input, bins=bins, range=range_value, weights=None, density=None)
2656
2710
  return hist
2657
2711
 
2658
2712
 
@@ -2667,22 +2721,28 @@ def _aten_digamma(input, *, out=None):
2667
2721
  # replace indices where input == 0 with -inf in res
2668
2722
  return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res)
2669
2723
 
2724
+
2670
2725
  @op(torch.ops.aten.igamma)
2671
2726
  def _aten_igamma(input, other):
2672
2727
  return jax.scipy.special.gammainc(input, other)
2673
2728
 
2729
+
2674
2730
  @op(torch.ops.aten.lgamma)
2675
2731
  def _aten_lgamma(input, *, out=None):
2676
2732
  return jax.scipy.special.gammaln(input).astype(jnp.float32)
2677
2733
 
2734
+
2678
2735
  @op(torch.ops.aten.mvlgamma)
2679
2736
  def _aten_mvlgamma(input, p, *, out=None):
2680
- return jax.scipy.special.multigammaln(input, d)
2737
+ input = input.astype(mappings.t2j_dtype(torch.get_default_dtype()))
2738
+ return jax.scipy.special.multigammaln(input, p)
2739
+
2681
2740
 
2682
2741
  @op(torch.ops.aten.linalg_eig)
2683
2742
  def _aten_linalg_eig(A):
2684
2743
  return jnp.linalg.eig(A)
2685
2744
 
2745
+
2686
2746
  @op(torch.ops.aten._linalg_eigh)
2687
2747
  def _aten_linalg_eigh(A, UPLO='L'):
2688
2748
  return jnp.linalg.eigh(A, UPLO)
@@ -2704,7 +2764,9 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
2704
2764
  A_reshaped = A.reshape((batch_size,) + A.shape[-2:])
2705
2765
  B_reshaped = B.reshape((batch_size,) + B.shape[-2:])
2706
2766
 
2707
- X, residuals, rank, singular_values = jax.vmap(jnp.linalg.lstsq, in_axes=(0, 0))(A_reshaped, B_reshaped, rcond=rcond)
2767
+ X, residuals, rank, singular_values = jax.vmap(
2768
+ jnp.linalg.lstsq, in_axes=(0,
2769
+ 0))(A_reshaped, B_reshaped, rcond=rcond)
2708
2770
 
2709
2771
  X = X.reshape(batch_shape + X.shape[-2:])
2710
2772
 
@@ -2720,7 +2782,8 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
2720
2782
  residuals = residuals.reshape(batch_shape + residuals.shape[-1:])
2721
2783
 
2722
2784
  if driver in ['gelsd', 'gelss']:
2723
- singular_values = singular_values.reshape(batch_shape + singular_values.shape[-1:])
2785
+ singular_values = singular_values.reshape(batch_shape +
2786
+ singular_values.shape[-1:])
2724
2787
  else:
2725
2788
  singular_values = jnp.array([], dtype=input_dtype)
2726
2789
 
@@ -2729,17 +2792,17 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
2729
2792
  X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond)
2730
2793
 
2731
2794
  if driver not in ['gelsd', 'gelsy', 'gelss']:
2732
- rank = jnp.array([], dtype=jnp.int64)
2795
+ rank = jnp.array([], dtype=jnp.int64)
2733
2796
 
2734
2797
  rank_value = None
2735
2798
  if rank.size > 0:
2736
- rank_value = int(rank.item())
2737
- rank = jnp.array(rank_value, dtype=jnp.int64)
2799
+ rank_value = int(rank.item())
2800
+ rank = jnp.array(rank_value, dtype=jnp.int64)
2738
2801
 
2739
2802
  # When driver is ‘gels’, assume that A is full-rank.
2740
- full_rank = driver == 'gels' or rank_value == n
2803
+ full_rank = driver == 'gels' or rank_value == n
2741
2804
  if driver == 'gelsy' or m <= n or (not full_rank):
2742
- residuals = jnp.array([], dtype=input_dtype)
2805
+ residuals = jnp.array([], dtype=input_dtype)
2743
2806
 
2744
2807
  if driver not in ['gelsd', 'gelss']:
2745
2808
  singular_values = jnp.array([], dtype=input_dtype)
@@ -2753,8 +2816,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
2753
2816
  # https://github.com/jax-ml/jax/issues/12779
2754
2817
  # TODO: Not tested for complex inputs. Does not support hermitian=True
2755
2818
  pivots = jnp.broadcast_to(
2756
- jnp.arange(1, A.shape[-1]+1, dtype=jnp.int32), A.shape[:-1]
2757
- )
2819
+ jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1])
2758
2820
  info = jnp.zeros(A.shape[:-2], jnp.int32)
2759
2821
  C = jnp.linalg.cholesky(A)
2760
2822
  if C.size == 0:
@@ -2767,7 +2829,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
2767
2829
 
2768
2830
  D = C * jnp.eye(C.shape[-1], dtype=A.dtype)
2769
2831
  LD = C @ jnp.linalg.inv(D)
2770
- LD = fill_diagonal_batch(LD, D*D)
2832
+ LD = fill_diagonal_batch(LD, D * D)
2771
2833
  return LD, pivots, info
2772
2834
 
2773
2835
 
@@ -2787,9 +2849,9 @@ def _aten_linalg_lu(A, pivot=True, out=None):
2787
2849
  U = jnp.triu(lu[..., :k, :])
2788
2850
 
2789
2851
  def perm_to_P(perm):
2790
- m = perm.shape[-1]
2791
- P = jnp.eye(m, dtype=dtype)[perm].T
2792
- return P
2852
+ m = perm.shape[-1]
2853
+ P = jnp.eye(m, dtype=dtype)[perm].T
2854
+ return P
2793
2855
 
2794
2856
  if permutation.ndim > 1:
2795
2857
  num_batch_dims = permutation.ndim - 1
@@ -2798,7 +2860,7 @@ def _aten_linalg_lu(A, pivot=True, out=None):
2798
2860
 
2799
2861
  P = perm_to_P(permutation)
2800
2862
 
2801
- return P,L,U
2863
+ return P, L, U
2802
2864
 
2803
2865
 
2804
2866
  @op(torch.ops.aten.linalg_lu_factor_ex)
@@ -2810,6 +2872,21 @@ def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False):
2810
2872
  return lu, pivots, info
2811
2873
 
2812
2874
 
2875
+ @op(torch.ops.aten.linalg_lu_solve)
2876
+ def _aten_linalg_lu_solve(LU, pivots, B, left=True, adjoint=False):
2877
+ # JAX pivots are offset by 1 compared to torch
2878
+ pivots = pivots - 1
2879
+ if not left:
2880
+ # XA = B is the same as A^T X^T = B^T
2881
+ trans = 0 if adjoint else 2
2882
+ x = jax.scipy.linalg.lu_solve((LU, pivots), jnp.matrix_transpose(B), trans)
2883
+ x = jnp.matrix_transpose(x)
2884
+ else:
2885
+ trans = 2 if adjoint else 0
2886
+ x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans)
2887
+ return x
2888
+
2889
+
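The pivots - 1 shift above exists because torch.linalg.lu_factor returns LAPACK-style 1-based pivots while jax.scipy.linalg expects 0-based ones; a minimal pure-JAX round trip (matrix values illustrative):

    import jax.numpy as jnp
    from jax.scipy import linalg as jsl

    a = jnp.array([[4., 3.], [6., 3.]])
    b = jnp.array([[1.], [0.]])
    lu, piv = jsl.lu_factor(a)          # piv is already 0-based here
    x = jsl.lu_solve((lu, piv), b)      # solves a @ x == b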
2813
2890
  @op(torch.ops.aten.gcd)
2814
2891
  def _aten_gcd(input, other):
2815
2892
  return jnp.gcd(input, other)
@@ -2874,12 +2951,14 @@ def _aten_log2(x):
2874
2951
 
2875
2952
  # aten.logical_and
2876
2953
  @op(torch.ops.aten.logical_and)
2954
+ @op(torch.ops.aten.__and__)
2877
2955
  def _aten_logical_and(self, other):
2878
2956
  return jnp.logical_and(self, other)
2879
2957
 
2880
2958
 
2881
2959
  # aten.logical_or
2882
2960
  @op(torch.ops.aten.logical_or)
2961
+ @op(torch.ops.aten.__or__)
2883
2962
  def _aten_logical_or(self, other):
2884
2963
  return jnp.logical_or(self, other)
2885
2964
 
@@ -2894,7 +2973,7 @@ def _aten_logical_not(self):
2894
2973
  @op(torch.ops.aten._log_softmax)
2895
2974
  def _aten_log_softmax(self, axis=-1, half_to_float=False):
2896
2975
  if self.shape == ():
2897
- return jnp.astype(0.0, self.dtype)
2976
+ return jnp.astype(0.0, self.dtype)
2898
2977
  return jax.nn.log_softmax(self, axis)
2899
2978
 
2900
2979
 
@@ -2921,6 +3000,7 @@ def _aten_logcumsumexp(self, dim=None):
2921
3000
  # aten.max_pool3d_backward
2922
3001
  # aten.logical_xor
2923
3002
  @op(torch.ops.aten.logical_xor)
3003
+ @op(torch.ops.aten.__xor__)
2924
3004
  def _aten_logical_xor(self, other):
2925
3005
  return jnp.logical_xor(self, other)
2926
3006
 
@@ -2933,19 +3013,22 @@ def _aten_logical_xor(self, other):
2933
3013
  def _aten_neg(x):
2934
3014
  return -1 * x
2935
3015
 
3016
+
2936
3017
  @op(torch.ops.aten.nextafter)
2937
3018
  def _aten_nextafter(input, other, *, out=None):
2938
3019
  return jnp.nextafter(input, other)
2939
3020
 
2940
3021
 
2941
3022
  @op(torch.ops.aten.nonzero_static)
2942
- def _aten_nonzero_static(input, size, fill_value = -1):
3023
+ def _aten_nonzero_static(input, size, fill_value=-1):
2943
3024
  indices = jnp.argwhere(input)
2944
3025
 
2945
3026
  if size < indices.shape[0]:
2946
3027
  indices = indices[:size]
2947
3028
  elif size > indices.shape[0]:
2948
- padding = jnp.full((size - indices.shape[0], indices.shape[1]), fill_value, dtype=indices.dtype)
3029
+ padding = jnp.full((size - indices.shape[0], indices.shape[1]),
3030
+ fill_value,
3031
+ dtype=indices.dtype)
2949
3032
  indices = jnp.concatenate((indices, padding))
2950
3033
 
2951
3034
  return indices
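For example, nonzero_static on [0, 3, 0, 5] with size=3 and the default fill_value=-1 finds two nonzero positions and pads the result to [[1], [3], [-1]]; with size=1 it truncates to [[1]] (worked values, illustrative only).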
@@ -2954,9 +3037,11 @@ def _aten_nonzero_static(input, size, fill_value = -1):
2954
3037
  # aten.nonzero
2955
3038
  @op(torch.ops.aten.nonzero)
2956
3039
  def _aten_nonzero(x, as_tuple=False):
2957
- if jnp.ndim(x) == 0 and (as_tuple or x.item()==0):
3040
+ if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0):
2958
3041
  return torch.empty(0, 0, dtype=torch.int64)
2959
- if jnp.ndim(x) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64)
3042
+ if jnp.ndim(
3043
+ x
3044
+ ) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64)
2960
3045
  res = torch.empty(1, 0, dtype=torch.int64)
2961
3046
  return jnp.array(res.numpy())
2962
3047
  index_tuple = jnp.nonzero(x)
@@ -2997,15 +3082,15 @@ def _aten_put(self, index, source, accumulate=False):
2997
3082
  # aten.randperm
2998
3083
  # randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
2999
3084
  @op(torch.ops.aten.randperm, needs_env=True)
3000
- def _aten_randperm(
3001
- n, *,
3002
- generator=None,
3003
- dtype=None,
3004
- layout=None,
3005
- device=None,
3006
- pin_memory=None,
3007
- env=None):
3008
- """
3085
+ def _aten_randperm(n,
3086
+ *,
3087
+ generator=None,
3088
+ dtype=None,
3089
+ layout=None,
3090
+ device=None,
3091
+ pin_memory=None,
3092
+ env=None):
3093
+ """
3009
3094
  Generates a random permutation of integers from 0 to n-1.
3010
3095
 
3011
3096
  Args:
@@ -3019,14 +3104,14 @@ def _aten_randperm(
3019
3104
  Returns:
3020
3105
  A DeviceArray containing a random permutation of integers from 0 to n-1.
3021
3106
  """
3022
- if dtype:
3023
- dtype = mappings.t2j_dtype(dtype)
3024
- else:
3025
- dtype = jnp.int64.dtype
3026
- key = env.get_and_rotate_prng_key(generator)
3027
- indices = jnp.arange(n, dtype=dtype)
3028
- permutation = jax.random.permutation(key, indices)
3029
- return permutation
3107
+ if dtype:
3108
+ dtype = mappings.t2j_dtype(dtype)
3109
+ else:
3110
+ dtype = jnp.int64.dtype
3111
+ key = env.get_and_rotate_prng_key(generator)
3112
+ indices = jnp.arange(n, dtype=dtype)
3113
+ permutation = jax.random.permutation(key, indices)
3114
+ return permutation
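Outside the torchax runtime, which supplies the PRNG key through env, the same permutation can be sketched with an explicit key (seed value illustrative):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)                         # illustrative seed
    perm = jax.random.permutation(key, jnp.arange(8))   # a shuffled arrangement of 0..7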
3030
3115
 
3031
3116
 
3032
3117
  # aten.reflection_pad3d
@@ -3071,8 +3156,8 @@ def _aten_sort(a, dim=-1, descending=False, stable=False):
3071
3156
  if a.shape == ():
3072
3157
  return (a, jnp.astype(0, 'int64'))
3073
3158
  return (
3074
- jnp.sort(a, axis=dim, stable=stable, descending=descending),
3075
- jnp.argsort(a, axis=dim, stable=stable, descending=descending),
3159
+ jnp.sort(a, axis=dim, stable=stable, descending=descending),
3160
+ jnp.argsort(a, axis=dim, stable=stable, descending=descending),
3076
3161
  )
3077
3162
 
3078
3163
 
@@ -3114,8 +3199,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
3114
3199
  if dim != -1 and dim != len(input.shape) - 1:
3115
3200
  transpose_shape = list(range(len(input.shape)))
3116
3201
  transpose_shape[dim], transpose_shape[-1] = (
3117
- transpose_shape[-1],
3118
- transpose_shape[dim],
3202
+ transpose_shape[-1],
3203
+ transpose_shape[dim],
3119
3204
  )
3120
3205
  input = jnp.transpose(input, transpose_shape)
3121
3206
 
@@ -3124,8 +3209,7 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
3124
3209
  if sorted:
3125
3210
  values = jnp.sort(values, descending=True)
3126
3211
  indices = jnp.take_along_axis(
3127
- indices, jnp.argsort(values, axis=-1, descending=True), axis=-1
3128
- )
3212
+ indices, jnp.argsort(values, axis=-1, descending=True), axis=-1)
3129
3213
 
3130
3214
  if not largest:
3131
3215
  values = -values # Negate values back if we found smallest
@@ -3140,21 +3224,39 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
3140
3224
  # aten.tril_indices
3141
3225
  #tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
3142
3226
  @op(torch.ops.aten.tril_indices)
3143
- def _aten_tril_indices(row, col, offset=0, *, dtype=jnp.int64.dtype, layout=None, device=None, pin_memory=None):
3227
+ def _aten_tril_indices(row,
3228
+ col,
3229
+ offset=0,
3230
+ *,
3231
+ dtype=jnp.int64.dtype,
3232
+ layout=None,
3233
+ device=None,
3234
+ pin_memory=None):
3144
3235
  a, b = jnp.tril_indices(row, offset, col)
3145
3236
  return jnp.stack((a, b))
3146
3237
 
3238
+
3147
3239
  # aten.triu_indices
3148
3240
  #triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
3149
3241
  @op(torch.ops.aten.triu_indices)
3150
- def _aten_triu_indices(row, col, offset=0, *, dtype=jnp.int64.dtype, layout=None, device=None, pin_memory=None):
3242
+ def _aten_triu_indices(row,
3243
+ col,
3244
+ offset=0,
3245
+ *,
3246
+ dtype=jnp.int64.dtype,
3247
+ layout=None,
3248
+ device=None,
3249
+ pin_memory=None):
3151
3250
  a, b = jnp.triu_indices(row, offset, col)
3152
3251
  return jnp.stack((a, b))
3153
3252
 
3154
3253
 
3155
3254
  @op(torch.ops.aten.unbind_copy)
3156
3255
  def _aten_unbind(a, dim=0):
3157
- return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]
3256
+ return [
3257
+ jax.lax.index_in_dim(a, i, dim, keepdims=False)
3258
+ for i in range(a.shape[dim])
3259
+ ]
3158
3260
 
3159
3261
 
3160
3262
  # aten.unique_dim
@@ -3167,12 +3269,13 @@ def _aten_unique_dim(input_tensor,
3167
3269
  sort=True,
3168
3270
  return_inverse=False,
3169
3271
  return_counts=False):
3170
- result_tensor_or_tuple = jnp.unique(input_tensor,
3171
- return_index=False,
3172
- return_inverse=return_inverse,
3173
- return_counts=return_counts,
3174
- axis=dim,
3175
- equal_nan=False)
3272
+ result_tensor_or_tuple = jnp.unique(
3273
+ input_tensor,
3274
+ return_index=False,
3275
+ return_inverse=return_inverse,
3276
+ return_counts=return_counts,
3277
+ axis=dim,
3278
+ equal_nan=False)
3176
3279
  result_list = (
3177
3280
  list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple)
3178
3281
  else [result_tensor_or_tuple])
@@ -3197,15 +3300,14 @@ def _aten_unique_dim(input_tensor,
3197
3300
  # NOTE: Like the CUDA and CPU implementations, this implementation always sorts
3198
3301
  # the tensor regardless of the `sorted` argument passed to `torch.unique`.
3199
3302
  @op(torch.ops.aten._unique)
3200
- def _aten_unique(input_tensor,
3201
- sort=True,
3202
- return_inverse=False):
3203
- result_tensor_or_tuple = jnp.unique(input_tensor,
3204
- return_index=False,
3205
- return_inverse=return_inverse,
3206
- return_counts=False,
3207
- axis=None,
3208
- equal_nan=False)
3303
+ def _aten_unique(input_tensor, sort=True, return_inverse=False):
3304
+ result_tensor_or_tuple = jnp.unique(
3305
+ input_tensor,
3306
+ return_index=False,
3307
+ return_inverse=return_inverse,
3308
+ return_counts=False,
3309
+ axis=None,
3310
+ equal_nan=False)
3209
3311
  if return_inverse:
3210
3312
  return result_tensor_or_tuple
3211
3313
  else:
@@ -3221,11 +3323,12 @@ def _aten_unique2(input_tensor,
3221
3323
  sort=True,
3222
3324
  return_inverse=False,
3223
3325
  return_counts=False):
3224
- return _aten_unique_dim(input_tensor=input_tensor,
3225
- dim=None,
3226
- sort=sort,
3227
- return_inverse=return_inverse,
3228
- return_counts=return_counts)
3326
+ return _aten_unique_dim(
3327
+ input_tensor=input_tensor,
3328
+ dim=None,
3329
+ sort=sort,
3330
+ return_inverse=return_inverse,
3331
+ return_counts=return_counts)
3229
3332
 
3230
3333
 
3231
3334
  # aten.unique_consecutive
@@ -3255,17 +3358,18 @@ def _aten_unique_consecutive(input_tensor,
3255
3358
  if dim < 0:
3256
3359
  dim += ndim
3257
3360
 
3258
- nd_slice_0 = tuple(slice(None, -1) if d == dim else slice(None)
3259
- for d in range(ndim))
3260
- nd_slice_1 = tuple(slice(1, None) if d == dim else slice(None)
3261
- for d in range(ndim))
3361
+ nd_slice_0 = tuple(
3362
+ slice(None, -1) if d == dim else slice(None) for d in range(ndim))
3363
+ nd_slice_1 = tuple(
3364
+ slice(1, None) if d == dim else slice(None) for d in range(ndim))
3262
3365
 
3263
3366
  axes_to_reduce = tuple(d for d in range(ndim) if d != dim)
3264
3367
 
3265
3368
  does_not_equal_prior = (
3266
- jnp.any(input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
3267
- axis=axes_to_reduce,
3268
- keepdims=False))
3369
+ jnp.any(
3370
+ input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
3371
+ axis=axes_to_reduce,
3372
+ keepdims=False))
3269
3373
 
3270
3374
  if input_tensor.shape[dim] != 0:
3271
3375
  # Prepend `True` to represent the first element of the input.
@@ -3273,18 +3377,17 @@ def _aten_unique_consecutive(input_tensor,
3273
3377
 
3274
3378
  include_indices = jnp.argwhere(does_not_equal_prior)[:, 0]
3275
3379
 
3276
- output_tensor = input_tensor[
3277
- tuple(include_indices if d == dim else slice(None) for d in range(ndim))]
3380
+ output_tensor = input_tensor[tuple(
3381
+ include_indices if d == dim else slice(None) for d in range(ndim))]
3278
3382
 
3279
3383
  if return_inverse or return_counts:
3280
- counts = (jnp.append(include_indices[1:], input_tensor.shape[dim]) -
3281
- include_indices[:])
3384
+ counts = (
3385
+ jnp.append(include_indices[1:], input_tensor.shape[dim]) -
3386
+ include_indices[:])
3282
3387
 
3283
3388
  inverse = (
3284
3389
  jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape)
3285
- if return_inverse
3286
- else None
3287
- )
3390
+ if return_inverse else None)
3288
3391
 
3289
3392
  return output_tensor, inverse, counts
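As a concrete check of the slicing logic above: for a 1-D input [1, 1, 2, 2, 3, 1] along dim 0, the output is [1, 2, 3, 1], counts are [2, 2, 1, 1], and the inverse indices are [0, 0, 1, 1, 2, 3], matching torch.unique_consecutive with return_inverse and return_counts set (worked values, illustrative only).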
3290
3393
 
@@ -3302,25 +3405,33 @@ def _aten_unique_consecutive(input_tensor,
3302
3405
  @op(torch.ops.aten.where.ScalarSelf)
3303
3406
  @op(torch.ops.aten.where.ScalarOther)
3304
3407
  @op(torch.ops.aten.where.Scalar)
3305
- def _aten_where(condition, x = None, y = None):
3408
+ def _aten_where(condition, x=None, y=None):
3306
3409
  return jnp.where(condition, x, y)
3307
3410
 
3308
3411
 
3309
3412
  # aten.to.dtype
3310
3413
  # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None
3311
3414
  @op(torch.ops.aten.to.dtype)
3312
- def _aten_to_dtype(
3313
- a, dtype, non_blocking=False, copy=False, memory_format=None
3314
- ):
3415
+ def _aten_to_dtype(a,
3416
+ dtype,
3417
+ non_blocking=False,
3418
+ copy=False,
3419
+ memory_format=None):
3315
3420
  if dtype:
3316
3421
  jaxdtype = mappings.t2j_dtype(dtype)
3317
3422
  return a.astype(jaxdtype)
3318
3423
 
3319
3424
 
3320
3425
  @op(torch.ops.aten.to.dtype_layout)
3321
- def _aten_to_dtype_layout(
3322
- a, *, dtype=None, layout=None, device=None, pin_memory=None, non_blocking=False, copy=False, memory_format=None
3323
- ):
3426
+ def _aten_to_dtype_layout(a,
3427
+ *,
3428
+ dtype=None,
3429
+ layout=None,
3430
+ device=None,
3431
+ pin_memory=None,
3432
+ non_blocking=False,
3433
+ copy=False,
3434
+ memory_format=None):
3324
3435
  return _aten_to_dtype(
3325
3436
  a,
3326
3437
  dtype,
@@ -3328,6 +3439,7 @@ def _aten_to_dtype_layout(
3328
3439
  copy=copy,
3329
3440
  memory_format=memory_format)
3330
3441
 
3442
+
3331
3443
  # aten.to.device
3332
3444
 
3333
3445
 
@@ -3348,9 +3460,11 @@ def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False):
3348
3460
 
3349
3461
  @op(torch.ops.aten.scalar_tensor)
3350
3462
  @op_base.convert_dtype()
3351
- def _aten_scalar_tensor(
3352
- s, dtype=None, layout=None, device=None, pin_memory=None
3353
- ):
3463
+ def _aten_scalar_tensor(s,
3464
+ dtype=None,
3465
+ layout=None,
3466
+ device=None,
3467
+ pin_memory=None):
3354
3468
  return jnp.array(s, dtype=dtype)
3355
3469
 
3356
3470
 
@@ -3360,9 +3474,9 @@ def _aten_to_device(x, device, dtype):
3360
3474
 
3361
3475
 
3362
3476
  @op(torch.ops.aten.max_pool2d_with_indices_backward)
3363
- def max_pool2d_with_indices_backward_custom(
3364
- grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
3365
- ):
3477
+ def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size,
3478
+ stride, padding, dilation,
3479
+ ceil_mode, indices):
3366
3480
  """
3367
3481
  Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward.
3368
3482
 
@@ -3417,16 +3531,16 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0):
3417
3531
 
3418
3532
  @op(torch.ops.aten.randn, needs_env=True)
3419
3533
  @op_base.convert_dtype()
3420
- def _randn(
3421
- *size,
3422
- generator=None,
3423
- out=None,
3424
- dtype=None,
3425
- layout=torch.strided,
3426
- device=None,
3427
- requires_grad=False,
3428
- pin_memory=False,
3429
- env=None,
3534
+ def _aten_randn(
3535
+ *size,
3536
+ generator=None,
3537
+ out=None,
3538
+ dtype=None,
3539
+ layout=torch.strided,
3540
+ device=None,
3541
+ requires_grad=False,
3542
+ pin_memory=False,
3543
+ env=None,
3430
3544
  ):
3431
3545
  shape = size
3432
3546
  if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3437,13 +3551,14 @@ def _randn(
3437
3551
  res = res.astype(dtype)
3438
3552
  return res
3439
3553
 
3554
+
3440
3555
  @op(torch.ops.aten.bernoulli.p, needs_env=True)
3441
- def _bernoulli(
3442
- self,
3443
- p = 0.5,
3444
- *,
3445
- generator=None,
3446
- env=None,
3556
+ def _aten_bernoulli(
3557
+ self,
3558
+ p=0.5,
3559
+ *,
3560
+ generator=None,
3561
+ env=None,
3447
3562
  ):
3448
3563
  key = env.get_and_rotate_prng_key(generator)
3449
3564
  res = jax.random.uniform(key, self.shape) < p
@@ -3460,14 +3575,14 @@ def geometric(self, p, *, generator=None, env=None):
3460
3575
  @op(torch.ops.aten.randn_like, needs_env=True)
3461
3576
  @op_base.convert_dtype()
3462
3577
  def _aten_randn_like(
3463
- x,
3464
- *,
3465
- dtype=None,
3466
- layout=None,
3467
- device=None,
3468
- pin_memory=False,
3469
- memory_format=torch.preserve_format,
3470
- env=None,
3578
+ x,
3579
+ *,
3580
+ dtype=None,
3581
+ layout=None,
3582
+ device=None,
3583
+ pin_memory=False,
3584
+ memory_format=torch.preserve_format,
3585
+ env=None,
3471
3586
  ):
3472
3587
  key = env.get_and_rotate_prng_key()
3473
3588
  return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape)
@@ -3476,15 +3591,15 @@ def _aten_randn_like(
3476
3591
  @op(torch.ops.aten.rand, needs_env=True)
3477
3592
  @op_base.convert_dtype()
3478
3593
  def _rand(
3479
- *size,
3480
- generator=None,
3481
- out=None,
3482
- dtype=None,
3483
- layout=torch.strided,
3484
- device=None,
3485
- requires_grad=False,
3486
- pin_memory=False,
3487
- env=None,
3594
+ *size,
3595
+ generator=None,
3596
+ out=None,
3597
+ dtype=None,
3598
+ layout=torch.strided,
3599
+ device=None,
3600
+ requires_grad=False,
3601
+ pin_memory=False,
3602
+ env=None,
3488
3603
  ):
3489
3604
  shape = size
3490
3605
  if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3505,32 +3620,48 @@ def _aten_outer(a, b):
3505
3620
  def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
3506
3621
  return jnp.allclose(input, other, rtol, atol, equal_nan)
3507
3622
 
3623
+
3508
3624
  @op(torch.ops.aten.native_batch_norm)
3509
- def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):
3625
+ def _aten_native_batch_norm(input,
3626
+ weight,
3627
+ bias,
3628
+ running_mean,
3629
+ running_var,
3630
+ training=False,
3631
+ momentum=0.1,
3632
+ eps=1e-5):
3510
3633
 
3511
3634
  if running_mean is None:
3512
- running_mean = jnp.zeros(input.shape[1], dtype=input.dtype) # Initialize running mean if None
3635
+ running_mean = jnp.zeros(
3636
+ input.shape[1], dtype=input.dtype) # Initialize running mean if None
3513
3637
  if running_var is None:
3514
- running_var = jnp.ones(input.shape[1], dtype=input.dtype) # Initialize running variance if None
3638
+ running_var = jnp.ones(
3639
+ input.shape[1],
3640
+ dtype=input.dtype) # Initialize running variance if None
3515
3641
 
3516
3642
  if training:
3517
- return _aten__native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
3643
+ return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
3644
+ running_var, training, momentum, eps)
3518
3645
  else:
3519
- return _aten__native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)
3646
+ return _aten__native_batch_norm_legit_no_training(input, weight, bias,
3647
+ running_mean, running_var,
3648
+ momentum, eps)
3520
3649
 
3521
3650
 
3522
3651
  @op(torch.ops.aten.normal, needs_env=True)
3523
3652
  def _aten_normal(self, mean=0, std=1, generator=None, env=None):
3524
3653
  shape = self.shape
3525
- res = _randn(*shape, generator=generator, env=env)
3654
+ res = _aten_randn(*shape, generator=generator, env=env)
3526
3655
  return res * std + mean
3527
3656
 
3657
+
3528
3658
  # TODO: not clear what this function should actually do
3529
3659
  # https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940
3530
3660
  @op(torch.ops.aten.lift_fresh)
3531
3661
  def _aten_lift_fresh(self):
3532
3662
  return self
3533
3663
 
3664
+
3534
3665
  @op(torch.ops.aten.uniform, needs_env=True)
3535
3666
  def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
3536
3667
  assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})'
@@ -3538,16 +3669,18 @@ def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
3538
3669
  res = _rand(*shape, generator=generator, env=env)
3539
3670
  return res * (to - from_) + from_
3540
3671
 
3672
+
3541
3673
  #func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
3542
3674
 
3675
+
3543
3676
  @op(torch.ops.aten.randint, needs_env=True)
3544
3677
  @op_base.convert_dtype(use_default_dtype=False)
3545
3678
  def _aten_randint(
3546
- *args,
3547
- generator=None,
3548
- dtype=None,
3549
- env=None,
3550
- **kwargs,
3679
+ *args,
3680
+ generator=None,
3681
+ dtype=None,
3682
+ env=None,
3683
+ **kwargs,
3551
3684
  ):
3552
3685
  if len(args) == 3:
3553
3686
  # low, high, size
@@ -3556,7 +3689,8 @@ def _aten_randint(
3556
3689
  high, size = args
3557
3690
  low = 0
3558
3691
  else:
3559
- raise AssertionError(f'Expected at 2 or 3 args for Aten::randint, got {len(args)}')
3692
+ raise AssertionError(
3693
+ f'Expected 2 or 3 args for Aten::randint, got {len(args)}')
3560
3694
 
3561
3695
  key = env.get_and_rotate_prng_key(generator)
3562
3696
  res = jax.random.randint(key, size, low, high)
@@ -3564,15 +3698,18 @@ def _aten_randint(
3564
3698
  res = res.astype(dtype)
3565
3699
  return res
3566
3700
 
3567
- @op(torch.ops.aten.randint_like, torch.ops.aten.randint.generator, needs_env=True)
3701
+
3702
+ @op(torch.ops.aten.randint_like,
3703
+ torch.ops.aten.randint.generator,
3704
+ needs_env=True)
3568
3705
  @op_base.convert_dtype(use_default_dtype=False)
3569
3706
  def _aten_randint_like(
3570
- input,
3571
- *args,
3572
- generator=None,
3573
- dtype=None,
3574
- env=None,
3575
- **kwargs,
3707
+ input,
3708
+ *args,
3709
+ generator=None,
3710
+ dtype=None,
3711
+ env=None,
3712
+ **kwargs,
3576
3713
  ):
3577
3714
  if len(args) == 2:
3578
3715
  low, high = args
@@ -3580,7 +3717,8 @@ def _aten_randint_like(
3580
3717
  high = args[0]
3581
3718
  low = 0
3582
3719
  else:
3583
- raise AssertionError(f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}')
3720
+ raise AssertionError(
3721
+ f'Expected 1 or 2 args for Aten::randint_like, got {len(args)}')
3584
3722
 
3585
3723
  shape = input.shape
3586
3724
  dtype = dtype or input.dtype
@@ -3590,6 +3728,7 @@ def _aten_randint_like(
3590
3728
  res = res.astype(dtype)
3591
3729
  return res
3592
3730
 
3731
+
3593
3732
  @op(torch.ops.aten.dim, is_jax_function=False)
3594
3733
  def _aten_dim(self):
3595
3734
  return len(self.shape)
@@ -3602,10 +3741,11 @@ def _aten_copysign(input, other, *, out=None):
3602
3741
  # regardless of their exact integer dtype, whereas jax.copysign returns
3603
3742
  # float64 when one or both of them is int64.
3604
3743
  if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype(
3605
- other.dtype, jnp.integer
3606
- ):
3744
+ other.dtype, jnp.integer):
3607
3745
  result = result.astype(jnp.float32)
3608
3746
  return result
3747
+
3748
+
3609
3749
  @op(torch.ops.aten.i0)
3610
3750
  @op_base.promote_int_input
3611
3751
  def _aten_i0(self):
@@ -3637,6 +3777,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
3637
3777
 
3638
3778
  @jnp.vectorize
3639
3779
  def vectorized(x, n_i):
3780
+
3640
3781
  def negative_n(x):
3641
3782
  return jnp.zeros_like(x)
3642
3783
 
@@ -3650,6 +3791,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
3650
3791
  return jnp.ones_like(x)
3651
3792
 
3652
3793
  def default(x):
3794
+
3653
3795
  def f(k, carry):
3654
3796
  p, q = carry
3655
3797
  return (q, ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1))
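The loop body above is the standard three-term Laguerre recurrence, (k + 1) * L_{k+1}(x) = (2k + 1 - x) * L_k(x) - k * L_{k-1}(x), with the carry (p, q) holding (L_{k-1}, L_k).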
@@ -3658,9 +3800,9 @@ def _aten_special_laguerre_polynomial_l(self, n):
3658
3800
  return q
3659
3801
 
3660
3802
  return jnp.piecewise(
3661
- x, [n_i == 1, n_i == 0, jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], [
3662
- one_n, zero_n, zero_abs, negative_n, default]
3663
- )
3803
+ x, [n_i == 1, n_i == 0,
3804
+ jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0],
3805
+ [one_n, zero_n, zero_abs, negative_n, default])
3664
3806
 
3665
3807
  return vectorized(self, n.astype(jnp.int64))
3666
3808
 
@@ -3760,125 +3902,124 @@ def _aten_special_modified_bessel_i0(self):
3760
3902
  return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x)
3761
3903
 
3762
3904
  self = jnp.abs(self)
3763
- return jnp.piecewise(
3764
- self, [self <= 8], [small, default]
3765
- )
3905
+ return jnp.piecewise(self, [self <= 8], [small, default])
3906
+
3766
3907
 
3767
3908
  @op(torch.ops.aten.special_modified_bessel_i1)
3768
3909
  @op_base.promote_int_input
3769
3910
  def _aten_special_modified_bessel_i1(self):
3770
- # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364
3771
-
3772
- def small(x):
3773
- A = jnp.array(
3774
- [
3775
- 2.77791411276104639959e-18,
3776
- -2.11142121435816608115e-17,
3777
- 1.55363195773620046921e-16,
3778
- -1.10559694773538630805e-15,
3779
- 7.60068429473540693410e-15,
3780
- -5.04218550472791168711e-14,
3781
- 3.22379336594557470981e-13,
3782
- -1.98397439776494371520e-12,
3783
- 1.17361862988909016308e-11,
3784
- -6.66348972350202774223e-11,
3785
- 3.62559028155211703701e-10,
3786
- -1.88724975172282928790e-09,
3787
- 9.38153738649577178388e-09,
3788
- -4.44505912879632808065e-08,
3789
- 2.00329475355213526229e-07,
3790
- -8.56872026469545474066e-07,
3791
- 3.47025130813767847674e-06,
3792
- -1.32731636560394358279e-05,
3793
- 4.78156510755005422638e-05,
3794
- -1.61760815825896745588e-04,
3795
- 5.12285956168575772895e-04,
3796
- -1.51357245063125314899e-03,
3797
- 4.15642294431288815669e-03,
3798
- -1.05640848946261981558e-02,
3799
- 2.47264490306265168283e-02,
3800
- -5.29459812080949914269e-02,
3801
- 1.02643658689847095384e-01,
3802
- -1.76416518357834055153e-01,
3803
- 2.52587186443633654823e-01,
3804
- ],
3805
- dtype=self.dtype,
3806
- )
3807
-
3808
- def f(carry, val):
3809
- p, q, a = carry
3810
- p, q = q, a
3811
- return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
3812
-
3813
- (p, _, a), _ = jax.lax.scan(
3814
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3815
-
3816
- return jax.lax.cond(
3817
- x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))
3818
- )
3911
+ # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364
3819
3912
 
3820
- def default(x):
3821
- B = jnp.array(
3822
- [
3823
- 7.51729631084210481353e-18,
3824
- 4.41434832307170791151e-18,
3825
- -4.65030536848935832153e-17,
3826
- -3.20952592199342395980e-17,
3827
- 2.96262899764595013876e-16,
3828
- 3.30820231092092828324e-16,
3829
- -1.88035477551078244854e-15,
3830
- -3.81440307243700780478e-15,
3831
- 1.04202769841288027642e-14,
3832
- 4.27244001671195135429e-14,
3833
- -2.10154184277266431302e-14,
3834
- -4.08355111109219731823e-13,
3835
- -7.19855177624590851209e-13,
3836
- 2.03562854414708950722e-12,
3837
- 1.41258074366137813316e-11,
3838
- 3.25260358301548823856e-11,
3839
- -1.89749581235054123450e-11,
3840
- -5.58974346219658380687e-10,
3841
- -3.83538038596423702205e-09,
3842
- -2.63146884688951950684e-08,
3843
- -2.51223623787020892529e-07,
3844
- -3.88256480887769039346e-06,
3845
- -1.10588938762623716291e-04,
3846
- -9.76109749136146840777e-03,
3847
- 7.78576235018280120474e-01,
3848
- ],
3849
- dtype=self.dtype,
3850
- )
3851
-
3852
- def f(carry, val):
3853
- p, q, b = carry
3854
- p, q = q, b
3855
- return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
3856
-
3857
- (p, _, b), _ = jax.lax.scan(
3858
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
3859
-
3860
- return jax.lax.cond(
3861
- x < 0, lambda: -(jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))
3862
- )
3913
+ def small(x):
3914
+ A = jnp.array(
3915
+ [
3916
+ 2.77791411276104639959e-18,
3917
+ -2.11142121435816608115e-17,
3918
+ 1.55363195773620046921e-16,
3919
+ -1.10559694773538630805e-15,
3920
+ 7.60068429473540693410e-15,
3921
+ -5.04218550472791168711e-14,
3922
+ 3.22379336594557470981e-13,
3923
+ -1.98397439776494371520e-12,
3924
+ 1.17361862988909016308e-11,
3925
+ -6.66348972350202774223e-11,
3926
+ 3.62559028155211703701e-10,
3927
+ -1.88724975172282928790e-09,
3928
+ 9.38153738649577178388e-09,
3929
+ -4.44505912879632808065e-08,
3930
+ 2.00329475355213526229e-07,
3931
+ -8.56872026469545474066e-07,
3932
+ 3.47025130813767847674e-06,
3933
+ -1.32731636560394358279e-05,
3934
+ 4.78156510755005422638e-05,
3935
+ -1.61760815825896745588e-04,
3936
+ 5.12285956168575772895e-04,
3937
+ -1.51357245063125314899e-03,
3938
+ 4.15642294431288815669e-03,
3939
+ -1.05640848946261981558e-02,
3940
+ 2.47264490306265168283e-02,
3941
+ -5.29459812080949914269e-02,
3942
+ 1.02643658689847095384e-01,
3943
+ -1.76416518357834055153e-01,
3944
+ 2.52587186443633654823e-01,
3945
+ ],
3946
+ dtype=self.dtype,
3947
+ )
3863
3948
 
3864
- return jnp.piecewise(
3865
- self, [self <= 8], [small, default]
3949
+ def f(carry, val):
3950
+ p, q, a = carry
3951
+ p, q = q, a
3952
+ return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
3953
+
3954
+ (p, _, a), _ = jax.lax.scan(
3955
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3956
+
3957
+ return jax.lax.cond(
3958
+ x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))),
3959
+ lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x)))
3960
+
3961
+ def default(x):
3962
+ B = jnp.array(
3963
+ [
3964
+ 7.51729631084210481353e-18,
3965
+ 4.41434832307170791151e-18,
3966
+ -4.65030536848935832153e-17,
3967
+ -3.20952592199342395980e-17,
3968
+ 2.96262899764595013876e-16,
3969
+ 3.30820231092092828324e-16,
3970
+ -1.88035477551078244854e-15,
3971
+ -3.81440307243700780478e-15,
3972
+ 1.04202769841288027642e-14,
3973
+ 4.27244001671195135429e-14,
3974
+ -2.10154184277266431302e-14,
3975
+ -4.08355111109219731823e-13,
3976
+ -7.19855177624590851209e-13,
3977
+ 2.03562854414708950722e-12,
3978
+ 1.41258074366137813316e-11,
3979
+ 3.25260358301548823856e-11,
3980
+ -1.89749581235054123450e-11,
3981
+ -5.58974346219658380687e-10,
3982
+ -3.83538038596423702205e-09,
3983
+ -2.63146884688951950684e-08,
3984
+ -2.51223623787020892529e-07,
3985
+ -3.88256480887769039346e-06,
3986
+ -1.10588938762623716291e-04,
3987
+ -9.76109749136146840777e-03,
3988
+ 7.78576235018280120474e-01,
3989
+ ],
3990
+ dtype=self.dtype,
3866
3991
  )
3867
3992
 
3993
+ def f(carry, val):
3994
+ p, q, b = carry
3995
+ p, q = q, b
3996
+ return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
3997
+
3998
+ (p, _, b), _ = jax.lax.scan(
3999
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
4000
+
4001
+ return jax.lax.cond(
4002
+ x < 0, lambda: -(jnp.exp(jnp.abs(x)) *
4003
+ (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))),
4004
+ lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)))
4005
+
4006
+ return jnp.piecewise(self, [self <= 8], [small, default])
4007
+
4008
+
3868
4009
  @op(torch.ops.aten.special_modified_bessel_k0)
3869
4010
  @op_base.promote_int_input
3870
4011
  def _aten_special_modified_bessel_k0(self):
3871
- # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441
4012
+ # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441
3872
4013
 
3873
- def zero(x):
3874
- return jnp.array(jnp.inf, x.dtype)
4014
+ def zero(x):
4015
+ return jnp.array(jnp.inf, x.dtype)
3875
4016
 
3876
- def negative(x):
3877
- return jnp.array(jnp.nan, x.dtype)
4017
+ def negative(x):
4018
+ return jnp.array(jnp.nan, x.dtype)
3878
4019
 
3879
- def small(x):
3880
- A = jnp.array(
3881
- [
4020
+ def small(x):
4021
+ A = jnp.array(
4022
+ [
3882
4023
  1.37446543561352307156e-16,
3883
4024
  4.25981614279661018399e-14,
3884
4025
  1.03496952576338420167e-11,
@@ -3889,23 +4030,24 @@ def _aten_special_modified_bessel_k0(self):
3889
4030
  3.59799365153615016266e-02,
3890
4031
  3.44289899924628486886e-01,
3891
4032
  -5.35327393233902768720e-01,
3892
- ],
3893
- dtype=self.dtype,
3894
- )
4033
+ ],
4034
+ dtype=self.dtype,
4035
+ )
3895
4036
 
3896
- def f(carry, val):
3897
- p, q, a = carry
3898
- p, q = q, a
3899
- return (p, q, (x * x - 2.0) * q - p + val), None
4037
+ def f(carry, val):
4038
+ p, q, a = carry
4039
+ p, q = q, a
4040
+ return (p, q, (x * x - 2.0) * q - p + val), None
4041
+
4042
+ (p, _, a), _ = jax.lax.scan(
4043
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3900
4044
 
3901
- (p, _, a), _ = jax.lax.scan(
3902
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3903
-
3904
- return 0.5 * (a - p) - jnp.log(0.5 * x) * _aten_special_modified_bessel_i0(x)
4045
+ return 0.5 * (a - p) - jnp.log(
4046
+ 0.5 * x) * _aten_special_modified_bessel_i0(x)
3905
4047
 
3906
- def default(x):
3907
- B = jnp.array(
3908
- [
4048
+ def default(x):
4049
+ B = jnp.array(
4050
+ [
3909
4051
  5.30043377268626276149e-18,
3910
4052
  -1.64758043015242134646e-17,
3911
4053
  5.21039150503902756861e-17,
@@ -3931,38 +4073,38 @@ def _aten_special_modified_bessel_k0(self):
3931
4073
  1.56988388573005337491e-03,
3932
4074
  -3.14481013119645005427e-02,
3933
4075
  2.44030308206595545468e+00,
3934
- ],
3935
- dtype=self.dtype,
3936
- )
4076
+ ],
4077
+ dtype=self.dtype,
4078
+ )
4079
+
4080
+ def f(carry, val):
4081
+ p, q, b = carry
4082
+ p, q = q, b
4083
+ return (p, q, (8.0 / x - 2.0) * q - p + val), None
3937
4084
 
3938
- def f(carry, val):
3939
- p, q, b = carry
3940
- p, q = q, b
3941
- return (p, q, (8.0 / x - 2.0) * q - p + val), None
4085
+ (p, _, b), _ = jax.lax.scan(
4086
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
3942
4087
 
3943
- (p, _, b), _ = jax.lax.scan(
3944
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
3945
-
3946
- return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4088
+ return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4089
+
4090
+ return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
4091
+ [small, negative, zero, default])
3947
4092
 
3948
- return jnp.piecewise(
3949
- self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
3950
- )
3951
4093
 
3952
4094
  @op(torch.ops.aten.special_modified_bessel_k1)
3953
4095
  @op_base.promote_int_input
3954
4096
  def _aten_special_modified_bessel_k1(self):
3955
- # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519
4097
+ # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519
3956
4098
 
3957
- def zero(x):
3958
- return jnp.array(jnp.inf, x.dtype)
4099
+ def zero(x):
4100
+ return jnp.array(jnp.inf, x.dtype)
3959
4101
 
3960
- def negative(x):
3961
- return jnp.array(jnp.nan, x.dtype)
4102
+ def negative(x):
4103
+ return jnp.array(jnp.nan, x.dtype)
3962
4104
 
3963
- def small(x):
3964
- A = jnp.array(
3965
- [
4105
+ def small(x):
4106
+ A = jnp.array(
4107
+ [
3966
4108
  -7.02386347938628759343e-18,
3967
4109
  -2.42744985051936593393e-15,
3968
4110
  -6.66690169419932900609e-13,
@@ -3974,24 +4116,25 @@ def _aten_special_modified_bessel_k1(self):
3974
4116
  -1.22611180822657148235e-01,
3975
4117
  -3.53155960776544875667e-01,
3976
4118
  1.52530022733894777053e+00,
3977
- ],
3978
- dtype=self.dtype,
3979
- )
4119
+ ],
4120
+ dtype=self.dtype,
4121
+ )
3980
4122
 
3981
- def f(carry, val):
3982
- p, q, a = carry
3983
- p, q = q, a
3984
- a = (x * x - 2.0) * q - p + val
3985
- return (p, q, a), None
4123
+ def f(carry, val):
4124
+ p, q, a = carry
4125
+ p, q = q, a
4126
+ a = (x * x - 2.0) * q - p + val
4127
+ return (p, q, a), None
4128
+
4129
+ (p, _, a), _ = jax.lax.scan(
4130
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3986
4131
 
3987
- (p, _, a), _ = jax.lax.scan(
3988
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3989
-
3990
- return jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
4132
+ return jnp.log(
4133
+ 0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
3991
4134
 
3992
- def default(x):
3993
- B = jnp.array(
3994
- [
4135
+ def default(x):
4136
+ B = jnp.array(
4137
+ [
3995
4138
  -5.75674448366501715755e-18,
3996
4139
  1.79405087314755922667e-17,
3997
4140
  -5.68946255844285935196e-17,
@@ -4017,24 +4160,24 @@ def _aten_special_modified_bessel_k1(self):
4017
4160
  -2.85781685962277938680e-03,
4018
4161
  1.03923736576817238437e-01,
4019
4162
  2.72062619048444266945e+00,
4020
- ],
4021
- dtype=self.dtype,
4022
- )
4163
+ ],
4164
+ dtype=self.dtype,
4165
+ )
4166
+
4167
+ def f(carry, val):
4168
+ p, q, b = carry
4169
+ p, q = q, b
4170
+ b = (8.0 / x - 2.0) * q - p + val
4171
+ return (p, q, b), None
4172
+
4173
+ (p, _, b), _ = jax.lax.scan(
4174
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
4023
4175
 
4024
- def f(carry, val):
4025
- p, q, b = carry
4026
- p, q = q, b
4027
- b = (8.0 / x - 2.0) * q - p + val
4028
- return (p, q, b), None
4176
+ return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4029
4177
 
4030
- (p, _, b), _ = jax.lax.scan(
4031
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
4032
-
4033
- return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4178
+ return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
4179
+ [small, negative, zero, default])
4034
4180
 
4035
- return jnp.piecewise(
4036
- self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
4037
- )
4038
4181
 
4039
4182
  @op(torch.ops.aten.polygamma)
4040
4183
  def _aten_polygamma(x, n):
@@ -4042,10 +4185,12 @@ def _aten_polygamma(x, n):
4042
4185
  n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
4043
4186
  return jax.lax.polygamma(jnp.float32(x), n)
4044
4187
 
4188
+
4045
4189
  @op(torch.ops.aten.special_ndtri)
4046
4190
  @op_base.promote_int_input
4047
4191
  def _aten_special_ndtri(self):
4048
- return jax.scipy.special.ndtri(self)
4192
+ return jax.scipy.special.ndtri(self)
4193
+
4049
4194
 
4050
4195
  @op(torch.ops.aten.special_bessel_j0)
4051
4196
  @op_base.promote_int_input
@@ -4057,112 +4202,104 @@ def _aten_special_bessel_j0(self):
4057
4202
 
4058
4203
  def small(x):
4059
4204
  RP = jnp.array(
4060
- [
4061
- -4.79443220978201773821e09,
4062
- 1.95617491946556577543e12,
4063
- -2.49248344360967716204e14,
4064
- 9.70862251047306323952e15,
4065
- ],
4066
- dtype=self.dtype,
4205
+ [
4206
+ -4.79443220978201773821e09,
4207
+ 1.95617491946556577543e12,
4208
+ -2.49248344360967716204e14,
4209
+ 9.70862251047306323952e15,
4210
+ ],
4211
+ dtype=self.dtype,
4067
4212
  )
4068
4213
  RQ = jnp.array(
4069
- [
4070
- 4.99563147152651017219e02,
4071
- 1.73785401676374683123e05,
4072
- 4.84409658339962045305e07,
4073
- 1.11855537045356834862e10,
4074
- 2.11277520115489217587e12,
4075
- 3.10518229857422583814e14,
4076
- 3.18121955943204943306e16,
4077
- 1.71086294081043136091e18,
4078
- ],
4079
- dtype=self.dtype,
4214
+ [
4215
+ 4.99563147152651017219e02,
4216
+ 1.73785401676374683123e05,
4217
+ 4.84409658339962045305e07,
4218
+ 1.11855537045356834862e10,
4219
+ 2.11277520115489217587e12,
4220
+ 3.10518229857422583814e14,
4221
+ 3.18121955943204943306e16,
4222
+ 1.71086294081043136091e18,
4223
+ ],
4224
+ dtype=self.dtype,
4080
4225
  )
4081
4226
 
4082
4227
  rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
4083
4228
  rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
4084
4229
 
4085
- return (
4086
- (x * x - 5.78318596294678452118e00)
4087
- * (x * x - 3.04712623436620863991e01)
4088
- * rp
4089
- / rq
4090
- )
4230
+ return ((x * x - 5.78318596294678452118e00) *
4231
+ (x * x - 3.04712623436620863991e01) * rp / rq)
4091
4232
 
4092
4233
  def default(x):
4093
4234
  PP = jnp.array(
4094
- [
4095
- 7.96936729297347051624e-04,
4096
- 8.28352392107440799803e-02,
4097
- 1.23953371646414299388e00,
4098
- 5.44725003058768775090e00,
4099
- 8.74716500199817011941e00,
4100
- 5.30324038235394892183e00,
4101
- 9.99999999999999997821e-01,
4102
- ],
4103
- dtype=self.dtype,
4235
+ [
4236
+ 7.96936729297347051624e-04,
4237
+ 8.28352392107440799803e-02,
4238
+ 1.23953371646414299388e00,
4239
+ 5.44725003058768775090e00,
4240
+ 8.74716500199817011941e00,
4241
+ 5.30324038235394892183e00,
4242
+ 9.99999999999999997821e-01,
4243
+ ],
4244
+ dtype=self.dtype,
4104
4245
  )
4105
4246
  PQ = jnp.array(
4106
- [
4107
- 9.24408810558863637013e-04,
4108
- 8.56288474354474431428e-02,
4109
- 1.25352743901058953537e00,
4110
- 5.47097740330417105182e00,
4111
- 8.76190883237069594232e00,
4112
- 5.30605288235394617618e00,
4113
- 1.00000000000000000218e00,
4114
- ],
4115
- dtype=self.dtype,
4247
+ [
4248
+ 9.24408810558863637013e-04,
4249
+ 8.56288474354474431428e-02,
4250
+ 1.25352743901058953537e00,
4251
+ 5.47097740330417105182e00,
4252
+ 8.76190883237069594232e00,
4253
+ 5.30605288235394617618e00,
4254
+ 1.00000000000000000218e00,
4255
+ ],
4256
+ dtype=self.dtype,
4116
4257
  )
4117
4258
  QP = jnp.array(
4118
- [
4119
- -1.13663838898469149931e-02,
4120
- -1.28252718670509318512e00,
4121
- -1.95539544257735972385e01,
4122
- -9.32060152123768231369e01,
4123
- -1.77681167980488050595e02,
4124
- -1.47077505154951170175e02,
4125
- -5.14105326766599330220e01,
4126
- -6.05014350600728481186e00,
4127
- ],
4128
- dtype=self.dtype,
4259
+ [
4260
+ -1.13663838898469149931e-02,
4261
+ -1.28252718670509318512e00,
4262
+ -1.95539544257735972385e01,
4263
+ -9.32060152123768231369e01,
4264
+ -1.77681167980488050595e02,
4265
+ -1.47077505154951170175e02,
4266
+ -5.14105326766599330220e01,
4267
+ -6.05014350600728481186e00,
4268
+ ],
4269
+ dtype=self.dtype,
4129
4270
  )
4130
4271
  QQ = jnp.array(
4131
- [
4132
- 6.43178256118178023184e01,
4133
- 8.56430025976980587198e02,
4134
- 3.88240183605401609683e03,
4135
- 7.24046774195652478189e03,
4136
- 5.93072701187316984827e03,
4137
- 2.06209331660327847417e03,
4138
- 2.42005740240291393179e02,
4139
- ],
4140
- dtype=self.dtype,
4272
+ [
4273
+ 6.43178256118178023184e01,
4274
+ 8.56430025976980587198e02,
4275
+ 3.88240183605401609683e03,
4276
+ 7.24046774195652478189e03,
4277
+ 5.93072701187316984827e03,
4278
+ 2.06209331660327847417e03,
4279
+ 2.42005740240291393179e02,
4280
+ ],
4281
+ dtype=self.dtype,
4141
4282
  )
4142
4283
 
4143
- pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4144
- pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4145
- qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4146
- qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4147
-
4148
- return (
4149
- (
4150
- pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721)
4151
- - 5.0
4152
- / x
4153
- * (qp / qq)
4154
- * jnp.sin(x - 0.785398163397448309615660845819875721)
4155
- )
4156
- * 0.797884560802865355879892119868763737
4157
- / jnp.sqrt(x)
4158
- )
4284
+ pp = op_base.foreach_loop(
4285
+ PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4286
+ pq = op_base.foreach_loop(
4287
+ PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4288
+ qp = op_base.foreach_loop(
4289
+ QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4290
+ qq = op_base.foreach_loop(
4291
+ QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4292
+
4293
+ return ((pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) -
4294
+ 5.0 / x *
4295
+ (qp / qq) * jnp.sin(x - 0.785398163397448309615660845819875721)) *
4296
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4159
4297
 
4160
4298
  self = jnp.abs(self)
4161
4299
  # Last True condition in `piecewise` takes priority, but last function is
4162
4300
  # default. See https://github.com/numpy/numpy/issues/16475
4163
- return jnp.piecewise(
4164
- self, [self <= 5.0, self < 0.00001], [small, very_small, default]
4165
- )
4301
+ return jnp.piecewise(self, [self <= 5.0, self < 0.00001],
4302
+ [small, very_small, default])
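
For readers unfamiliar with the condlist semantics relied on in the comment above: in jnp.piecewise the last True condition wins when conditions overlap, and the final function acts as the default. A minimal standalone sketch with toy values (independent of this op):

import jax.numpy as jnp

x = jnp.array([-3.0, 0.5, 4.0, 10.0])
out = jnp.piecewise(
    x,
    [x <= 5.0, x < 1.0],       # for -3.0 and 0.5 both are True; the later (x < 1.0) wins
    [lambda v: v + 100.0,      # used where x <= 5.0 is the last True condition
     lambda v: v - 100.0,      # used where x < 1.0 is the last True condition
     lambda v: v],             # default for elements with no True condition (10.0)
)
# out == [-103., -99.5, 104., 10.]
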
4166
4303
 
4167
4304
 
4168
4305
  @op(torch.ops.aten.special_bessel_j1)
@@ -4172,114 +4309,106 @@ def _aten_special_bessel_j1(self):
4172
4309
 
4173
4310
  def small(x):
4174
4311
  RP = jnp.array(
4175
- [
4176
- -8.99971225705559398224e08,
4177
- 4.52228297998194034323e11,
4178
- -7.27494245221818276015e13,
4179
- 3.68295732863852883286e15,
4180
- ],
4181
- dtype=self.dtype,
4312
+ [
4313
+ -8.99971225705559398224e08,
4314
+ 4.52228297998194034323e11,
4315
+ -7.27494245221818276015e13,
4316
+ 3.68295732863852883286e15,
4317
+ ],
4318
+ dtype=self.dtype,
4182
4319
  )
4183
4320
  RQ = jnp.array(
4184
- [
4185
- 6.20836478118054335476e02,
4186
- 2.56987256757748830383e05,
4187
- 8.35146791431949253037e07,
4188
- 2.21511595479792499675e10,
4189
- 4.74914122079991414898e12,
4190
- 7.84369607876235854894e14,
4191
- 8.95222336184627338078e16,
4192
- 5.32278620332680085395e18,
4193
- ],
4194
- dtype=self.dtype,
4321
+ [
4322
+ 6.20836478118054335476e02,
4323
+ 2.56987256757748830383e05,
4324
+ 8.35146791431949253037e07,
4325
+ 2.21511595479792499675e10,
4326
+ 4.74914122079991414898e12,
4327
+ 7.84369607876235854894e14,
4328
+ 8.95222336184627338078e16,
4329
+ 5.32278620332680085395e18,
4330
+ ],
4331
+ dtype=self.dtype,
4195
4332
  )
4196
4333
 
4197
4334
  rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
4198
4335
  rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
4199
4336
 
4200
- return (
4201
- rp
4202
- / rq
4203
- * x
4204
- * (x * x - 1.46819706421238932572e01)
4205
- * (x * x - 4.92184563216946036703e01)
4206
- )
4337
+ return (rp / rq * x * (x * x - 1.46819706421238932572e01) *
4338
+ (x * x - 4.92184563216946036703e01))
4207
4339
 
4208
4340
  def default(x):
4209
4341
  PP = jnp.array(
4210
- [
4211
- 7.62125616208173112003e-04,
4212
- 7.31397056940917570436e-02,
4213
- 1.12719608129684925192e00,
4214
- 5.11207951146807644818e00,
4215
- 8.42404590141772420927e00,
4216
- 5.21451598682361504063e00,
4217
- 1.00000000000000000254e00,
4218
- ],
4219
- dtype=self.dtype,
4342
+ [
4343
+ 7.62125616208173112003e-04,
4344
+ 7.31397056940917570436e-02,
4345
+ 1.12719608129684925192e00,
4346
+ 5.11207951146807644818e00,
4347
+ 8.42404590141772420927e00,
4348
+ 5.21451598682361504063e00,
4349
+ 1.00000000000000000254e00,
4350
+ ],
4351
+ dtype=self.dtype,
4220
4352
  )
4221
4353
  PQ = jnp.array(
4222
- [
4223
- 5.71323128072548699714e-04,
4224
- 6.88455908754495404082e-02,
4225
- 1.10514232634061696926e00,
4226
- 5.07386386128601488557e00,
4227
- 8.39985554327604159757e00,
4228
- 5.20982848682361821619e00,
4229
- 9.99999999999999997461e-01,
4230
- ],
4231
- dtype=self.dtype,
4354
+ [
4355
+ 5.71323128072548699714e-04,
4356
+ 6.88455908754495404082e-02,
4357
+ 1.10514232634061696926e00,
4358
+ 5.07386386128601488557e00,
4359
+ 8.39985554327604159757e00,
4360
+ 5.20982848682361821619e00,
4361
+ 9.99999999999999997461e-01,
4362
+ ],
4363
+ dtype=self.dtype,
4232
4364
  )
4233
4365
  QP = jnp.array(
4234
- [
4235
- 5.10862594750176621635e-02,
4236
- 4.98213872951233449420e00,
4237
- 7.58238284132545283818e01,
4238
- 3.66779609360150777800e02,
4239
- 7.10856304998926107277e02,
4240
- 5.97489612400613639965e02,
4241
- 2.11688757100572135698e02,
4242
- 2.52070205858023719784e01,
4243
- ],
4244
- dtype=self.dtype,
4366
+ [
4367
+ 5.10862594750176621635e-02,
4368
+ 4.98213872951233449420e00,
4369
+ 7.58238284132545283818e01,
4370
+ 3.66779609360150777800e02,
4371
+ 7.10856304998926107277e02,
4372
+ 5.97489612400613639965e02,
4373
+ 2.11688757100572135698e02,
4374
+ 2.52070205858023719784e01,
4375
+ ],
4376
+ dtype=self.dtype,
4245
4377
  )
4246
4378
  QQ = jnp.array(
4247
- [
4248
- 7.42373277035675149943e01,
4249
- 1.05644886038262816351e03,
4250
- 4.98641058337653607651e03,
4251
- 9.56231892404756170795e03,
4252
- 7.99704160447350683650e03,
4253
- 2.82619278517639096600e03,
4254
- 3.36093607810698293419e02,
4255
- ],
4256
- dtype=self.dtype,
4379
+ [
4380
+ 7.42373277035675149943e01,
4381
+ 1.05644886038262816351e03,
4382
+ 4.98641058337653607651e03,
4383
+ 9.56231892404756170795e03,
4384
+ 7.99704160447350683650e03,
4385
+ 2.82619278517639096600e03,
4386
+ 3.36093607810698293419e02,
4387
+ ],
4388
+ dtype=self.dtype,
4257
4389
  )
4258
4390
 
4259
- pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4260
- pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4261
- qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4262
- qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4263
-
4264
- return (
4265
- (
4266
- pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163)
4267
- - 5.0
4268
- / x
4269
- * (qp / qq)
4270
- * jnp.sin(x - 2.356194490192344928846982537459627163)
4271
- )
4272
- * 0.797884560802865355879892119868763737
4273
- / jnp.sqrt(x)
4274
- )
4391
+ pp = op_base.foreach_loop(
4392
+ PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4393
+ pq = op_base.foreach_loop(
4394
+ PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4395
+ qp = op_base.foreach_loop(
4396
+ QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4397
+ qq = op_base.foreach_loop(
4398
+ QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4399
+
4400
+ return ((pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) -
4401
+ 5.0 / x *
4402
+ (qp / qq) * jnp.sin(x - 2.356194490192344928846982537459627163)) *
4403
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4275
4404
 
4276
4405
  # If x < 0, bessel_j1(x) = -bessel_j1(-x)
4277
4406
  sign = jnp.sign(self)
4278
4407
  self = jnp.abs(self)
4279
4408
  return sign * jnp.piecewise(
4280
- self,
4281
- [self <= 5.0],
4282
- [small, default],
4409
+ self,
4410
+ [self <= 5.0],
4411
+ [small, default],
4283
4412
  )
4284
4413
 
4285
4414
 
@@ -4296,85 +4425,86 @@ def _aten_special_bessel_y0(self):
4296
4425
 
4297
4426
  def small(x):
4298
4427
  YP = jnp.array(
4299
- [
4300
- 1.55924367855235737965e04,
4301
- -1.46639295903971606143e07,
4302
- 5.43526477051876500413e09,
4303
- -9.82136065717911466409e11,
4304
- 8.75906394395366999549e13,
4305
- -3.46628303384729719441e15,
4306
- 4.42733268572569800351e16,
4307
- -1.84950800436986690637e16,
4308
- ],
4309
- dtype=self.dtype,
4428
+ [
4429
+ 1.55924367855235737965e04,
4430
+ -1.46639295903971606143e07,
4431
+ 5.43526477051876500413e09,
4432
+ -9.82136065717911466409e11,
4433
+ 8.75906394395366999549e13,
4434
+ -3.46628303384729719441e15,
4435
+ 4.42733268572569800351e16,
4436
+ -1.84950800436986690637e16,
4437
+ ],
4438
+ dtype=self.dtype,
4310
4439
  )
4311
4440
  YQ = jnp.array(
4312
- [
4313
- 1.04128353664259848412e03,
4314
- 6.26107330137134956842e05,
4315
- 2.68919633393814121987e08,
4316
- 8.64002487103935000337e10,
4317
- 2.02979612750105546709e13,
4318
- 3.17157752842975028269e15,
4319
- 2.50596256172653059228e17,
4320
- ],
4321
- dtype=self.dtype,
4441
+ [
4442
+ 1.04128353664259848412e03,
4443
+ 6.26107330137134956842e05,
4444
+ 2.68919633393814121987e08,
4445
+ 8.64002487103935000337e10,
4446
+ 2.02979612750105546709e13,
4447
+ 3.17157752842975028269e15,
4448
+ 2.50596256172653059228e17,
4449
+ ],
4450
+ dtype=self.dtype,
4322
4451
  )
4323
4452
 
4324
4453
  yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
4325
4454
  yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
4326
4455
 
4327
- return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) * _aten_special_bessel_j0(x))
4456
+ return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) *
4457
+ _aten_special_bessel_j0(x))
4328
4458
 
4329
4459
  def default(x):
4330
4460
  PP = jnp.array(
4331
- [
4332
- 7.96936729297347051624e-04,
4333
- 8.28352392107440799803e-02,
4334
- 1.23953371646414299388e00,
4335
- 5.44725003058768775090e00,
4336
- 8.74716500199817011941e00,
4337
- 5.30324038235394892183e00,
4338
- 9.99999999999999997821e-01,
4339
- ],
4340
- dtype=self.dtype,
4461
+ [
4462
+ 7.96936729297347051624e-04,
4463
+ 8.28352392107440799803e-02,
4464
+ 1.23953371646414299388e00,
4465
+ 5.44725003058768775090e00,
4466
+ 8.74716500199817011941e00,
4467
+ 5.30324038235394892183e00,
4468
+ 9.99999999999999997821e-01,
4469
+ ],
4470
+ dtype=self.dtype,
4341
4471
  )
4342
4472
  PQ = jnp.array(
4343
- [
4344
- 9.24408810558863637013e-04,
4345
- 8.56288474354474431428e-02,
4346
- 1.25352743901058953537e00,
4347
- 5.47097740330417105182e00,
4348
- 8.76190883237069594232e00,
4349
- 5.30605288235394617618e00,
4350
- 1.00000000000000000218e00,
4351
- ],
4352
- dtype=self.dtype,
4473
+ [
4474
+ 9.24408810558863637013e-04,
4475
+ 8.56288474354474431428e-02,
4476
+ 1.25352743901058953537e00,
4477
+ 5.47097740330417105182e00,
4478
+ 8.76190883237069594232e00,
4479
+ 5.30605288235394617618e00,
4480
+ 1.00000000000000000218e00,
4481
+ ],
4482
+ dtype=self.dtype,
4353
4483
  )
4354
4484
  QP = jnp.array(
4355
- [
4356
- -1.13663838898469149931e-02,
4357
- -1.28252718670509318512e00,
4358
- -1.95539544257735972385e01,
4359
- -9.32060152123768231369e01,
4360
- -1.77681167980488050595e02,
4361
- -1.47077505154951170175e02,
4362
- -5.14105326766599330220e01,
4363
- -6.05014350600728481186e00,
4364
- ],
4365
- dtype=self.dtype,
4485
+ [
4486
+ -1.13663838898469149931e-02,
4487
+ -1.28252718670509318512e00,
4488
+ -1.95539544257735972385e01,
4489
+ -9.32060152123768231369e01,
4490
+ -1.77681167980488050595e02,
4491
+ -1.47077505154951170175e02,
4492
+ -5.14105326766599330220e01,
4493
+ -6.05014350600728481186e00,
4494
+ ],
4495
+ dtype=self.dtype,
4366
4496
  )
4367
4497
  QQ = jnp.array(
4368
- [
4369
- 6.43178256118178023184e01,
4370
- 8.56430025976980587198e02,
4371
- 3.88240183605401609683e03,
4372
- 7.24046774195652478189e03,
4373
- 5.93072701187316984827e03,
4374
- 2.06209331660327847417e03,
4375
- 2.42005740240291393179e02,
4376
- ],
4377
- dtype=self.dtype,
4498
+ [
4499
+ 6.43178256118178023184e01,
4500
+ 8.56430025976980587198e02,
4501
+ 3.88240183605401609683e03,
4502
+ 7.24046774195652478189e03,
4503
+ 5.93072701187316984827e03,
4504
+ 2.06209331660327847417e03,
4505
+ 2.42005740240291393179e02,
4506
+ ],
4507
+ dtype=self.dtype,
4378
4508
  )
4379
4509
 
4380
4510
  factor = 25.0 / (x * x)
@@ -4383,22 +4513,15 @@ def _aten_special_bessel_y0(self):
4383
4513
  qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
4384
4514
  qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
4385
4515
 
4386
- return (
4387
- (
4388
- pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721)
4389
- + 5.0
4390
- / x
4391
- * (qp / qq)
4392
- * jnp.cos(x - 0.785398163397448309615660845819875721)
4393
- )
4394
- * 0.797884560802865355879892119868763737
4395
- / jnp.sqrt(x)
4396
- )
4516
+ return ((pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) +
4517
+ 5.0 / x *
4518
+ (qp / qq) * jnp.cos(x - 0.785398163397448309615660845819875721)) *
4519
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4397
4520
 
4398
4521
  return jnp.piecewise(
4399
- self,
4400
- [self <= 5.0, self < 0., self == 0.],
4401
- [small, negative, zero, default],
4522
+ self,
4523
+ [self <= 5.0, self < 0., self == 0.],
4524
+ [small, negative, zero, default],
4402
4525
  )
4403
4526
 
4404
4527
 
@@ -4415,90 +4538,86 @@ def _aten_special_bessel_y1(self):
4415
4538
 
4416
4539
  def small(x):
4417
4540
  YP = jnp.array(
4418
- [
4419
- 1.26320474790178026440e09,
4420
- -6.47355876379160291031e11,
4421
- 1.14509511541823727583e14,
4422
- -8.12770255501325109621e15,
4423
- 2.02439475713594898196e17,
4424
- -7.78877196265950026825e17,
4425
- ],
4426
- dtype=self.dtype,
4541
+ [
4542
+ 1.26320474790178026440e09,
4543
+ -6.47355876379160291031e11,
4544
+ 1.14509511541823727583e14,
4545
+ -8.12770255501325109621e15,
4546
+ 2.02439475713594898196e17,
4547
+ -7.78877196265950026825e17,
4548
+ ],
4549
+ dtype=self.dtype,
4427
4550
  )
4428
4551
  YQ = jnp.array(
4429
- [
4430
- 5.94301592346128195359e02,
4431
- 2.35564092943068577943e05,
4432
- 7.34811944459721705660e07,
4433
- 1.87601316108706159478e10,
4434
- 3.88231277496238566008e12,
4435
- 6.20557727146953693363e14,
4436
- 6.87141087355300489866e16,
4437
- 3.97270608116560655612e18,
4438
- ],
4439
- dtype=self.dtype,
4552
+ [
4553
+ 5.94301592346128195359e02,
4554
+ 2.35564092943068577943e05,
4555
+ 7.34811944459721705660e07,
4556
+ 1.87601316108706159478e10,
4557
+ 3.88231277496238566008e12,
4558
+ 6.20557727146953693363e14,
4559
+ 6.87141087355300489866e16,
4560
+ 3.97270608116560655612e18,
4561
+ ],
4562
+ dtype=self.dtype,
4440
4563
  )
4441
4564
 
4442
4565
  yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
4443
4566
  yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
4444
4567
 
4445
- return (
4446
- x * (yp / yq)
4447
- + (
4448
- 0.636619772367581343075535053490057448
4449
- * (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)
4450
- )
4451
- )
4568
+ return (x * (yp / yq) +
4569
+ (0.636619772367581343075535053490057448 *
4570
+ (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)))
4452
4571
 
4453
4572
  def default(x):
4454
4573
  PP = jnp.array(
4455
- [
4456
- 7.62125616208173112003e-04,
4457
- 7.31397056940917570436e-02,
4458
- 1.12719608129684925192e00,
4459
- 5.11207951146807644818e00,
4460
- 8.42404590141772420927e00,
4461
- 5.21451598682361504063e00,
4462
- 1.00000000000000000254e00,
4463
- ],
4464
- dtype=self.dtype,
4574
+ [
4575
+ 7.62125616208173112003e-04,
4576
+ 7.31397056940917570436e-02,
4577
+ 1.12719608129684925192e00,
4578
+ 5.11207951146807644818e00,
4579
+ 8.42404590141772420927e00,
4580
+ 5.21451598682361504063e00,
4581
+ 1.00000000000000000254e00,
4582
+ ],
4583
+ dtype=self.dtype,
4465
4584
  )
4466
4585
  PQ = jnp.array(
4467
- [
4468
- 5.71323128072548699714e-04,
4469
- 6.88455908754495404082e-02,
4470
- 1.10514232634061696926e00,
4471
- 5.07386386128601488557e00,
4472
- 8.39985554327604159757e00,
4473
- 5.20982848682361821619e00,
4474
- 9.99999999999999997461e-01,
4475
- ],
4476
- dtype=self.dtype,
4586
+ [
4587
+ 5.71323128072548699714e-04,
4588
+ 6.88455908754495404082e-02,
4589
+ 1.10514232634061696926e00,
4590
+ 5.07386386128601488557e00,
4591
+ 8.39985554327604159757e00,
4592
+ 5.20982848682361821619e00,
4593
+ 9.99999999999999997461e-01,
4594
+ ],
4595
+ dtype=self.dtype,
4477
4596
  )
4478
4597
  QP = jnp.array(
4479
- [
4480
- 5.10862594750176621635e-02,
4481
- 4.98213872951233449420e00,
4482
- 7.58238284132545283818e01,
4483
- 3.66779609360150777800e02,
4484
- 7.10856304998926107277e02,
4485
- 5.97489612400613639965e02,
4486
- 2.11688757100572135698e02,
4487
- 2.52070205858023719784e01,
4488
- ],
4489
- dtype=self.dtype,
4598
+ [
4599
+ 5.10862594750176621635e-02,
4600
+ 4.98213872951233449420e00,
4601
+ 7.58238284132545283818e01,
4602
+ 3.66779609360150777800e02,
4603
+ 7.10856304998926107277e02,
4604
+ 5.97489612400613639965e02,
4605
+ 2.11688757100572135698e02,
4606
+ 2.52070205858023719784e01,
4607
+ ],
4608
+ dtype=self.dtype,
4490
4609
  )
4491
4610
  QQ = jnp.array(
4492
- [
4493
- 7.42373277035675149943e01,
4494
- 1.05644886038262816351e03,
4495
- 4.98641058337653607651e03,
4496
- 9.56231892404756170795e03,
4497
- 7.99704160447350683650e03,
4498
- 2.82619278517639096600e03,
4499
- 3.36093607810698293419e02,
4500
- ],
4501
- dtype=self.dtype,
4611
+ [
4612
+ 7.42373277035675149943e01,
4613
+ 1.05644886038262816351e03,
4614
+ 4.98641058337653607651e03,
4615
+ 9.56231892404756170795e03,
4616
+ 7.99704160447350683650e03,
4617
+ 2.82619278517639096600e03,
4618
+ 3.36093607810698293419e02,
4619
+ ],
4620
+ dtype=self.dtype,
4502
4621
  )
4503
4622
 
4504
4623
  factor = 25.0 / (x * x)
@@ -4507,22 +4626,15 @@ def _aten_special_bessel_y1(self):
4507
4626
  qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
4508
4627
  qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
4509
4628
 
4510
- return (
4511
- (
4512
- pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163)
4513
- + 5.0
4514
- / x
4515
- * (qp / qq)
4516
- * jnp.cos(x - 2.356194490192344928846982537459627163)
4517
- )
4518
- * 0.797884560802865355879892119868763737
4519
- / jnp.sqrt(x)
4520
- )
4629
+ return ((pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) +
4630
+ 5.0 / x *
4631
+ (qp / qq) * jnp.cos(x - 2.356194490192344928846982537459627163)) *
4632
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4521
4633
 
4522
4634
  return jnp.piecewise(
4523
- self,
4524
- [self <= 5.0, self < 0., self == 0.],
4525
- [small, negative, zero, default],
4635
+ self,
4636
+ [self <= 5.0, self < 0., self == 0.],
4637
+ [small, negative, zero, default],
4526
4638
  )
4527
4639
 
4528
4640
 
@@ -4533,11 +4645,13 @@ def _aten_special_chebyshev_polynomial_t(self, n):
4533
4645
 
4534
4646
  @jnp.vectorize
4535
4647
  def vectorized(x, n_i):
4648
+
4536
4649
  def negative_n(x):
4537
4650
  return jnp.zeros_like(x)
4538
4651
 
4539
4652
  def one_x(x):
4540
- return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x), -jnp.ones_like(x))
4653
+ return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x),
4654
+ -jnp.ones_like(x))
4541
4655
 
4542
4656
  def large_n_small_x(x):
4543
4657
  return jnp.cos(n_i * jnp.acos(x))
@@ -4549,24 +4663,18 @@ def _aten_special_chebyshev_polynomial_t(self, n):
4549
4663
  return x
4550
4664
 
4551
4665
  def default(x):
4666
+
4552
4667
  def f(_, carry):
4553
4668
  p, q = carry
4554
4669
  return (q, 2 * x * q - p)
4555
4670
 
4556
- _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x))
4671
+ _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x))
4557
4672
  return r
4558
4673
 
4559
- return jnp.piecewise(
4560
- x,
4561
- [
4562
- n_i == 1,
4563
- n_i == 0,
4564
- (n_i == 6) & (jnp.abs(x) < 1),
4565
- jnp.abs(x) == 1.,
4566
- n_i < 0
4567
- ],
4568
- [one_n, zero_n, large_n_small_x, one_x, negative_n, default]
4569
- )
4674
+ return jnp.piecewise(x, [
4675
+ n_i == 1, n_i == 0, (n_i == 6) & (jnp.abs(x) < 1),
4676
+ jnp.abs(x) == 1., n_i < 0
4677
+ ], [one_n, zero_n, large_n_small_x, one_x, negative_n, default])
4570
4678
 
4571
4679
  # Explicitly vectorize since we must vectorize over both self and n
4572
4680
  return vectorized(self, n.astype(jnp.int64))
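
The default branch above advances the Chebyshev recurrence T_k(x) = 2x*T_{k-1}(x) - T_{k-2}(x) with a fori_loop starting from (T_0, T_1) = (1, x). A small standalone check against the closed form cos(n*arccos(x)) for |x| <= 1, with toy values and not routed through the torch dispatcher:

import jax
import jax.numpy as jnp

x = jnp.float32(0.3)
n = 5

def step(_, carry):
    # carry = (T_{k-2}, T_{k-1}); produce (T_{k-1}, T_k) via the recurrence.
    p, q = carry
    return (q, 2 * x * q - p)

_, t_n = jax.lax.fori_loop(0, n - 1, step, init_val=(jnp.float32(1.0), x))
closed_form = jnp.cos(n * jnp.arccos(x))   # equals t_n up to float error
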
@@ -4579,6 +4687,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4579
4687
 
4580
4688
  @jnp.vectorize
4581
4689
  def vectorized(x, n_i):
4690
+
4582
4691
  def negative_n(x):
4583
4692
  return jnp.zeros_like(x)
4584
4693
 
@@ -4588,9 +4697,9 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4588
4697
  def large_n_small_x(x):
4589
4698
  sin_acos_x = jnp.sin(jnp.acos(x))
4590
4699
  return jnp.where(
4591
- sin_acos_x != 0,
4592
- jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x,
4593
- (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x,
4700
+ sin_acos_x != 0,
4701
+ jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x,
4702
+ (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x,
4594
4703
  )
4595
4704
 
4596
4705
  def zero_n(x):
@@ -4600,6 +4709,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4600
4709
  return 2 * x
4601
4710
 
4602
4711
  def default(x):
4712
+
4603
4713
  def f(_, carry):
4604
4714
  p, q = carry
4605
4715
  return (q, 2 * x * q - p)
@@ -4608,15 +4718,15 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4608
4718
  return r
4609
4719
 
4610
4720
  return jnp.piecewise(
4611
- x,
4612
- [
4613
- n_i == 1,
4614
- n_i == 0,
4615
- (n_i > 8) & (jnp.abs(x) < 1),
4616
- jnp.abs(x) == 1.0,
4617
- n_i < 0,
4618
- ],
4619
- [one_n, zero_n, large_n_small_x, one_x, negative_n, default],
4721
+ x,
4722
+ [
4723
+ n_i == 1,
4724
+ n_i == 0,
4725
+ (n_i > 8) & (jnp.abs(x) < 1),
4726
+ jnp.abs(x) == 1.0,
4727
+ n_i < 0,
4728
+ ],
4729
+ [one_n, zero_n, large_n_small_x, one_x, negative_n, default],
4620
4730
  )
4621
4731
 
4622
4732
  return vectorized(self, n.astype(jnp.int64))
@@ -4627,6 +4737,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4627
4737
  def _aten_special_erfcx(x):
4628
4738
  return jnp.exp(x * x) * jax.lax.erfc(x)
4629
4739
 
4740
+
4630
4741
  @op(torch.ops.aten.erfc)
4631
4742
  @op_base.promote_int_input
4632
4743
  def _aten_erfcx(x):
@@ -4640,6 +4751,7 @@ def _aten_special_hermite_polynomial_h(self, n):
4640
4751
 
4641
4752
  @jnp.vectorize
4642
4753
  def vectorized(x, n_i):
4754
+
4643
4755
  def negative_n(x):
4644
4756
  return jnp.zeros_like(x)
4645
4757
 
@@ -4650,6 +4762,7 @@ def _aten_special_hermite_polynomial_h(self, n):
4650
4762
  return 2 * x
4651
4763
 
4652
4764
  def default(x):
4765
+
4653
4766
  def f(k, carry):
4654
4767
  p, q = carry
4655
4768
  return (q, 2 * x * q - 2 * k * p)
@@ -4657,9 +4770,8 @@ def _aten_special_hermite_polynomial_h(self, n):
4657
4770
  _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x))
4658
4771
  return r
4659
4772
 
4660
- return jnp.piecewise(
4661
- x, [n_i == 1, n_i == 0, n_i < 0], [one_n, zero_n, negative_n, default]
4662
- )
4773
+ return jnp.piecewise(x, [n_i == 1, n_i == 0, n_i < 0],
4774
+ [one_n, zero_n, negative_n, default])
4663
4775
 
4664
4776
  return vectorized(self, n.astype(jnp.int64))
4665
4777
 
@@ -4671,6 +4783,7 @@ def _aten_special_hermite_polynomial_he(self, n):
4671
4783
 
4672
4784
  @jnp.vectorize
4673
4785
  def vectorized(x, n_i):
4786
+
4674
4787
  def negative_n(x):
4675
4788
  return jnp.zeros_like(x)
4676
4789
 
@@ -4681,6 +4794,7 @@ def _aten_special_hermite_polynomial_he(self, n):
4681
4794
  return x
4682
4795
 
4683
4796
  def default(x):
4797
+
4684
4798
  def f(k, carry):
4685
4799
  p, q = carry
4686
4800
  return (q, x * q - k * p)
@@ -4688,24 +4802,34 @@ def _aten_special_hermite_polynomial_he(self, n):
4688
4802
  _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x))
4689
4803
  return r
4690
4804
 
4691
- return jnp.piecewise(
4692
- x, [n_i == 1.0, n_i == 0.0, n_i < 0], [one_n, zero_n, negative_n, default]
4693
- )
4805
+ return jnp.piecewise(x, [n_i == 1.0, n_i == 0.0, n_i < 0],
4806
+ [one_n, zero_n, negative_n, default])
4694
4807
 
4695
4808
  return vectorized(self, n.astype(jnp.int64))
4696
4809
 
4697
4810
 
4698
4811
  @op(torch.ops.aten.multinomial, needs_env=True)
4699
- def _aten_multinomial(input, num_samples, replacement=False, *, generator=None, out=None, env=None):
4700
- assert num_samples <= input.shape[-1] or replacement, "cannot take a larger sample than population when replacement=False"
4701
- assert jnp.all(input >= 0), "inputs must be non-negative"
4812
+ def _aten_multinomial(input,
4813
+ num_samples,
4814
+ replacement=False,
4815
+ *,
4816
+ generator=None,
4817
+ out=None,
4818
+ env=None):
4819
+ assert num_samples <= input.shape[
4820
+ -1] or replacement, "cannot take a larger sample than population when replacement=False"
4702
4821
  key = env.get_and_rotate_prng_key(generator)
4703
4822
  if input.ndim == 1:
4704
- assert jnp.sum(input) > 0, "rows of input must have non-zero sum"
4705
- return jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input)
4823
+ return jax.random.choice(
4824
+ key, input.shape[-1], (num_samples,), replace=replacement, p=input)
4706
4825
  else:
4707
- assert jnp.all(jnp.sum(input, axis=1) > 0), "rows of input must have non-zero sum"
4708
- return jnp.array([jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input[i, :]) for i in range(input.shape[0])])
4826
+ return jnp.array([
4827
+ jax.random.choice(
4828
+ key,
4829
+ input.shape[-1], (num_samples,),
4830
+ replace=replacement,
4831
+ p=input[i, :]) for i in range(input.shape[0])
4832
+ ])
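
As a rough standalone illustration of the sampling primitive used here (weights chosen arbitrarily): jax.random.choice draws category indices according to a probability vector p, which is what provides the torch.multinomial semantics above.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
weights = jnp.array([0.1, 0.2, 0.7])   # per-category probabilities for one row
# Five weighted draws with replacement, analogous to
# torch.multinomial(torch.tensor([0.1, 0.2, 0.7]), 5, replacement=True).
samples = jax.random.choice(key, weights.shape[-1], (5,), replace=True, p=weights)
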
4709
4833
 
4710
4834
 
4711
4835
  @op(torch.ops.aten.narrow)
@@ -4738,7 +4862,12 @@ def _aten_flatten(x, start_dim=0, end_dim=-1):
4738
4862
 
4739
4863
  @op(torch.ops.aten.new_empty)
4740
4864
  def _new_empty(self, size, **kwargs):
4741
- return jnp.empty(size)
4865
+ dtype = kwargs.get('dtype')
4866
+ if dtype is not None:
4867
+ dtype = mappings.t2j_dtype(dtype)
4868
+ else:
4869
+ dtype = self.dtype
4870
+ return jnp.empty(size, dtype=dtype)
4742
4871
 
4743
4872
 
4744
4873
  @op(torch.ops.aten.new_empty_strided)
@@ -4756,10 +4885,8 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False):
4756
4885
  return _aten_index_put(self, indices, values, accumulate)
4757
4886
 
4758
4887
 
4759
- @op(torch.ops.aten.conj_physical,
4760
- torch.ops.aten.conj,
4761
- torch.ops.aten._conj_physical,
4762
- torch.ops.aten._conj)
4888
+ @op(torch.ops.aten.conj_physical, torch.ops.aten.conj,
4889
+ torch.ops.aten._conj_physical, torch.ops.aten._conj)
4763
4890
  def _aten_conj_physical(self):
4764
4891
  return jnp.conjugate(self)
4765
4892
 
@@ -4768,6 +4895,7 @@ def _aten_conj_physical(self):
4768
4895
  def _aten_log_sigmoid(x):
4769
4896
  return jax.nn.log_sigmoid(x)
4770
4897
 
4898
+
4771
4899
  # torch.qr
4772
4900
  @op(torch.ops.aten.qr)
4773
4901
  def _aten_qr(input, *args, **kwargs):
@@ -4778,6 +4906,7 @@ def _aten_qr(input, *args, **kwargs):
4778
4906
  jax_mode = "complete"
4779
4907
  return jax.numpy.linalg.qr(input, mode=jax_mode)
4780
4908
 
4909
+
4781
4910
  # torch.linalg.qr
4782
4911
  @op(torch.ops.aten.linalg_qr)
4783
4912
  def _aten_linalg_qr(input, *args, **kwargs):
@@ -4820,19 +4949,25 @@ def _aten__linalg_solve_ex(a, b):
4820
4949
  res = jnp.linalg.solve(a, b)
4821
4950
  if batched:
4822
4951
  res = res.squeeze(-1)
4823
- info_shape = a.shape[0] if len(a.shape) >= 3 else []
4952
+ info_shape = a.shape[:-2]
4824
4953
  info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
4825
4954
  return res, info
4826
4955
 
4827
4956
 
4828
4957
  # torch.linalg.solve_triangular
4829
4958
  @op(torch.ops.aten.linalg_solve_triangular)
4830
- def _aten_linalg_solve_triangular(a, b, *, upper=True, left=True, unitriangular=False):
4959
+ def _aten_linalg_solve_triangular(a,
4960
+ b,
4961
+ *,
4962
+ upper=True,
4963
+ left=True,
4964
+ unitriangular=False):
4831
4965
  if left is False:
4832
4966
  a = jnp.matrix_transpose(a)
4833
4967
  b = jnp.matrix_transpose(b)
4834
4968
  upper = not upper
4835
- res = jax.scipy.linalg.solve_triangular(a, b, lower=not upper, unit_diagonal=unitriangular)
4969
+ res = jax.scipy.linalg.solve_triangular(
4970
+ a, b, lower=not upper, unit_diagonal=unitriangular)
4836
4971
  if left is False:
4837
4972
  res = jnp.matrix_transpose(res)
4838
4973
  return res
@@ -4852,21 +4987,31 @@ def _aten__linalg_check_errors(*args, **kwargs):
4852
4987
 
4853
4988
  @op(torch.ops.aten.median)
4854
4989
  def _aten_median(self, dim=None, keepdim=False):
4855
- output = _with_reduction_scalar(functools.partial(jnp.quantile, q=0.5, method='lower'), self, dim=dim, keepdim=keepdim).astype(self.dtype)
4990
+ output = _with_reduction_scalar(
4991
+ functools.partial(jnp.quantile, q=0.5, method='lower'),
4992
+ self,
4993
+ dim=dim,
4994
+ keepdim=keepdim).astype(self.dtype)
4856
4995
  if dim is None:
4857
4996
  return output
4858
4997
  else:
4859
- index = _with_reduction_scalar(_get_median_index, self, dim, keepdim).astype(jnp.int64)
4998
+ index = _with_reduction_scalar(_get_median_index, self, dim,
4999
+ keepdim).astype(jnp.int64)
4860
5000
  return output, index
4861
5001
 
4862
5002
 
4863
5003
  @op(torch.ops.aten.nanmedian)
4864
5004
  def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
4865
- output = _with_reduction_scalar(functools.partial(jnp.nanquantile, q=0.5, method='lower'), input, dim=dim, keepdim=keepdim).astype(input.dtype)
5005
+ output = _with_reduction_scalar(
5006
+ functools.partial(jnp.nanquantile, q=0.5, method='lower'),
5007
+ input,
5008
+ dim=dim,
5009
+ keepdim=keepdim).astype(input.dtype)
4866
5010
  if dim is None:
4867
5011
  return output
4868
5012
  else:
4869
- index = _with_reduction_scalar(_get_median_index, input, dim, keepdim).astype(jnp.int64)
5013
+ index = _with_reduction_scalar(_get_median_index, input, dim,
5014
+ keepdim).astype(jnp.int64)
4870
5015
  return output, index
4871
5016
 
4872
5017
 
@@ -4874,20 +5019,31 @@ def _get_median_index(x, axis=None, keepdims=False):
4874
5019
  sorted_arg = jnp.argsort(x, axis=axis)
4875
5020
  n = x.shape[axis] if axis is not None else x.size
4876
5021
  if n % 2 == 1:
4877
- index = n // 2
5022
+ index = n // 2
4878
5023
  else:
4879
- index = (n // 2) - 1
5024
+ index = (n // 2) - 1
4880
5025
  if axis is None:
4881
- median_index = sorted_arg[index]
5026
+ median_index = sorted_arg[index]
4882
5027
  else:
4883
- median_index = jnp.take(sorted_arg, index, axis=axis)
5028
+ median_index = jnp.take(sorted_arg, index, axis=axis)
4884
5029
  if keepdims and axis is not None:
4885
- median_index = jnp.expand_dims(median_index, axis)
5030
+ median_index = jnp.expand_dims(median_index, axis)
4886
5031
  return median_index
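
For reference, the (n // 2) - 1 index chosen above for even n matches the 'lower' quantile used for the values: torch.median returns the smaller of the two middle elements. A quick standalone check with a toy input:

import jax.numpy as jnp

x = jnp.array([4.0, 1.0, 3.0, 2.0])
med = jnp.quantile(x, 0.5, method='lower')   # 2.0, the lower middle value
idx = jnp.argsort(x)[(x.size // 2) - 1]      # index of that value in x -> 3
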
4887
5032
 
5033
+
4888
5034
  @op(torch.ops.aten.triangular_solve)
4889
- def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=False):
4890
- return (jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=not upper, transpose_a=transpose, unit_diagonal=unittriangular), a)
5035
+ def _aten_triangular_solve(b,
5036
+ a,
5037
+ upper=True,
5038
+ transpose=False,
5039
+ unittriangular=False):
5040
+ return (jax.lax.linalg.triangular_solve(
5041
+ a,
5042
+ b,
5043
+ left_side=True,
5044
+ lower=not upper,
5045
+ transpose_a=transpose,
5046
+ unit_diagonal=unittriangular), a)
4891
5047
 
4892
5048
 
4893
5049
  # func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
@@ -4895,16 +5051,16 @@ def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=Fal
4895
5051
  def _aten__fft_c2c(self, dim, normalization, forward):
4896
5052
  if forward:
4897
5053
  norm = [
4898
- 'backward',
4899
- 'ortho',
4900
- 'forward',
5054
+ 'backward',
5055
+ 'ortho',
5056
+ 'forward',
4901
5057
  ][normalization]
4902
5058
  return jnp.fft.fftn(self, axes=dim, norm=norm)
4903
5059
  else:
4904
5060
  norm = [
4905
- 'forward',
4906
- 'ortho',
4907
- 'backward',
5061
+ 'forward',
5062
+ 'ortho',
5063
+ 'backward',
4908
5064
  ][normalization]
4909
5065
  return jnp.fft.ifftn(self, axes=dim, norm=norm)
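
For orientation, the integer normalization flag indexes a norm-string table, and the table is reversed between the forward and inverse transforms so that a round trip with the same flag stays self-consistent. A standalone sketch using 'ortho' (flag 1 in both tables; toy input):

import jax.numpy as jnp

x = jnp.arange(8.0) + 0j
freq = jnp.fft.fftn(x, axes=(0,), norm='ortho')      # forward, 1/sqrt(n) scaling
back = jnp.fft.ifftn(freq, axes=(0,), norm='ortho')  # inverse, 1/sqrt(n) scaling
# back recovers x (up to float error) because the two scalings multiply to 1/n.
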
4910
5066
 
@@ -4912,21 +5068,22 @@ def _aten__fft_c2c(self, dim, normalization, forward):
4912
5068
  @op(torch.ops.aten._fft_r2c)
4913
5069
  def _aten__fft_r2c(self, dim, normalization, onesided):
4914
5070
  norm = [
4915
- 'backward',
4916
- 'ortho',
4917
- 'forward',
5071
+ 'backward',
5072
+ 'ortho',
5073
+ 'forward',
4918
5074
  ][normalization]
4919
5075
  if onesided:
4920
5076
  return jnp.fft.rfftn(self, axes=dim, norm=norm)
4921
5077
  else:
4922
5078
  return jnp.fft.fftn(self, axes=dim, norm=norm)
4923
5079
 
5080
+
4924
5081
  @op(torch.ops.aten._fft_c2r)
4925
5082
  def _aten__fft_c2r(self, dim, normalization, last_dim_size):
4926
5083
  norm = [
4927
- 'forward',
4928
- 'ortho',
4929
- 'backward',
5084
+ 'forward',
5085
+ 'ortho',
5086
+ 'backward',
4930
5087
  ][normalization]
4931
5088
  if len(dim) == 1:
4932
5089
  s = [last_dim_size]
@@ -4936,34 +5093,49 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
4936
5093
 
4937
5094
 
4938
5095
  @op(torch.ops.aten._trilinear)
4939
- def _aten_trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1):
4940
- return _aten_sum(jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) * jnp.expand_dims(i3, expand3), sumdim)
5096
+ def _aten_trilinear(i1,
5097
+ i2,
5098
+ i3,
5099
+ expand1,
5100
+ expand2,
5101
+ expand3,
5102
+ sumdim,
5103
+ unroll_dim=1):
5104
+ return _aten_sum(
5105
+ jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) *
5106
+ jnp.expand_dims(i3, expand3), sumdim)
4941
5107
 
4942
5108
 
4943
5109
  @op(torch.ops.aten.max_unpool2d)
4944
5110
  @op(torch.ops.aten.max_unpool3d)
4945
5111
  def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
4946
5112
  if output_size is None:
4947
- raise ValueError("output_size value is not set correctly. It cannot be None or empty.")
5113
+ raise ValueError(
5114
+ "output_size value is not set correctly. It cannot be None or empty.")
4948
5115
 
4949
5116
  output_size = [input.shape[0], input.shape[1]] + output_size
4950
5117
  output = jnp.zeros(output_size, dtype=input.dtype)
4951
5118
 
4952
5119
  for idx in np.ndindex(input.shape):
4953
- max_index = indices[idx]
4954
- spatial_dims = output_size[2:] # (D, H, W)
4955
- unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
4956
- full_idx = idx[:2] + unpooled_spatial_idx
4957
- output = output.at[full_idx].set(input[idx])
5120
+ max_index = indices[idx]
5121
+ spatial_dims = output_size[2:] # (D, H, W)
5122
+ unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
5123
+ full_idx = idx[:2] + unpooled_spatial_idx
5124
+ output = output.at[full_idx].set(input[idx])
4958
5125
 
4959
5126
  return output
4960
5127
 
4961
- @op(torch.ops.aten._upsample_bilinear2d_aa)
4962
- def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors=None, scales_h=None, scales_w=None):
5128
+
5129
+ def _aten_upsample(input,
5130
+ output_size,
5131
+ align_corners,
5132
+ antialias,
5133
+ method,
5134
+ scale_factors=None,
5135
+ scales_h=None,
5136
+ scales_w=None):
4963
5137
  # input: is of type jaxlib.xla_extension.ArrayImpl
4964
5138
  image = input
4965
- method = "bilinear"
4966
- antialias = True # ignored for upsampling
4967
5139
 
4968
5140
  # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
4969
5141
  # Resize does not distinguish batch, channel size.
@@ -4977,12 +5149,12 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
4977
5149
  shape = list(image.shape)
4978
5150
  # overriding output_size
4979
5151
  if scale_factors:
4980
- shape[-1] = int(math.floor(shape[-1]*scale_factors[-1]))
4981
- shape[-2] = int(math.floor(shape[-2]*scale_factors[-2]))
5152
+ shape[-1] = int(math.floor(shape[-1] * scale_factors[-1]))
5153
+ shape[-2] = int(math.floor(shape[-2] * scale_factors[-2]))
4982
5154
  if scales_h:
4983
- shape[-2] = int(math.floor(shape[-2]*scales_h))
5155
+ shape[-2] = int(math.floor(shape[-2] * scales_h))
4984
5156
  if scales_w:
4985
- shape[-1] = int(math.floor(shape[-1]*scales_w))
5157
+ shape[-1] = int(math.floor(shape[-1] * scales_w))
4986
5158
  # output_size overrides scale_factors, scales_*
4987
5159
  if output_size:
4988
5160
  shape[-1] = output_size[-1]
@@ -4992,11 +5164,11 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
4992
5164
  if shape == list(image.shape):
4993
5165
  return image
4994
5166
 
4995
- spatial_dims = (2,3)
5167
+ spatial_dims = (2, 3)
4996
5168
  if len(shape) == 3:
4997
- spatial_dims = (1,2)
5169
+ spatial_dims = (1, 2)
4998
5170
 
4999
- scale = list([shape[i] / image.shape[i] for i in spatial_dims])
5171
+ scale = list([shape[i] / image.shape[i] for i in spatial_dims])
5000
5172
  if scale_factors:
5001
5173
  scale = scale_factors
5002
5174
  if scales_h:
@@ -5008,7 +5180,9 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
5008
5180
  # align_corners is not supported in resize()
5009
5181
  # https://github.com/jax-ml/jax/issues/11206
5010
5182
  if align_corners:
5011
- scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
5183
+ scale = jnp.array([
5184
+ (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims
5185
+ ])
5012
5186
 
5013
5187
  translation = jnp.array([0 for i in spatial_dims])
5014
5188
 
@@ -5022,12 +5196,53 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
5022
5196
  antialias=antialias,
5023
5197
  )
5024
5198
 
5199
+
5200
+ @op(torch.ops.aten._upsample_bilinear2d_aa)
5201
+ def _aten_upsample_billinear_aa(input,
5202
+ output_size,
5203
+ align_corners,
5204
+ scale_factors=None,
5205
+ scales_h=None,
5206
+ scales_w=None):
5207
+ return _aten_upsample(
5208
+ input,
5209
+ output_size,
5210
+ align_corners,
5211
+ True, # antialias
5212
+ "bilinear", # method
5213
+ scale_factors,
5214
+ scales_h,
5215
+ scales_w)
5216
+
5217
+
5218
+ @op(torch.ops.aten._upsample_bicubic2d_aa)
5219
+ def _aten_upsample_bicubic2d_aa(input,
5220
+ output_size,
5221
+ align_corners,
5222
+ scale_factors=None,
5223
+ scales_h=None,
5224
+ scales_w=None):
5225
+ return _aten_upsample(
5226
+ input,
5227
+ output_size,
5228
+ align_corners,
5229
+ True, # antialias
5230
+ "bicubic", # method
5231
+ scale_factors,
5232
+ scales_h,
5233
+ scales_w)
5234
+
5235
+
5025
5236
  @op(torch.ops.aten.polar)
5026
5237
  def _aten_polar(abs, angle, *, out=None):
5027
5238
  return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))
5028
5239
 
5240
+
5029
5241
  @op(torch.ops.aten.cdist)
5030
- def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
5242
+ def _aten_cdist(x1,
5243
+ x2,
5244
+ p=2.0,
5245
+ compute_mode='use_mm_for_euclid_dist_if_necessary'):
5031
5246
  x1 = x1.astype(jnp.float32)
5032
5247
  x2 = x2.astype(jnp.float32)
5033
5248
 
@@ -5036,7 +5251,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
5036
5251
  return _hamming_distance(x1, x2).astype(jnp.float32)
5037
5252
  elif p == 2.0:
5038
5253
  # Use optimized Euclidean distance calculation
5039
- if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (x1.shape[-2] > 25 or x2.shape[-2] > 25):
5254
+ if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (
5255
+ x1.shape[-2] > 25 or x2.shape[-2] > 25):
5040
5256
  return _euclidean_mm(x1, x2)
5041
5257
  elif compute_mode == 'use_mm_for_euclid_dist':
5042
5258
  return _euclidean_mm(x1, x2)
@@ -5045,7 +5261,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
5045
5261
  else:
5046
5262
  # General p-norm distance calculation
5047
5263
  diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3))
5048
- return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32) ** (1 / p)
5264
+ return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)**(1 / p)
5265
+
5049
5266
 
5050
5267
  def _hamming_distance(x1, x2):
5051
5268
  """
@@ -5064,6 +5281,7 @@ def _hamming_distance(x1, x2):
5064
5281
 
5065
5282
  return hamming_dist
5066
5283
 
5284
+
5067
5285
  def _euclidean_mm(x1, x2):
5068
5286
  """
5069
5287
  Computes the Euclidean distance using matrix multiplication.
@@ -5075,8 +5293,8 @@ def _euclidean_mm(x1, x2):
5075
5293
  Returns:
5076
5294
  JAX array of shape (..., P, R) representing pairwise Euclidean distances.
5077
5295
  """
5078
- x1_sq = jnp.sum(x1 ** 2, axis=-1, keepdims=True).astype(jnp.float32)
5079
- x2_sq = jnp.sum(x2 ** 2, axis=-1, keepdims=True).astype(jnp.float32)
5296
+ x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32)
5297
+ x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32)
5080
5298
 
5081
5299
  x2_sq = jnp.swapaxes(x2_sq, -2, -1)
5082
5300
 
@@ -5088,6 +5306,7 @@ def _euclidean_mm(x1, x2):
5088
5306
 
5089
5307
  return dist
5090
5308
 
5309
+
5091
5310
  def _euclidean_direct(x1, x2):
5092
5311
  """
5093
5312
  Computes the Euclidean distance directly without matrix multiplication.
@@ -5101,7 +5320,7 @@ def _euclidean_direct(x1, x2):
5101
5320
  """
5102
5321
  diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)
5103
5322
 
5104
- dist_sq = jnp.sum(diff ** 2, axis=-1).astype(jnp.float32)
5323
+ dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32)
5105
5324
 
5106
5325
  dist_sq = jnp.maximum(dist_sq, 0.0)
5107
5326
 
@@ -5109,13 +5328,14 @@ def _euclidean_direct(x1, x2):
5109
5328
 
5110
5329
  return dist
5111
5330
 
5331
+
5112
5332
  @op(torch.ops.aten.lu_unpack)
5113
5333
  def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5114
5334
  # lu_unpack doesnt exist in jax.
5115
5335
  # Get commonly used data shape variables
5116
5336
  n = LU_data.shape[-2]
5117
5337
  m = LU_data.shape[-1]
5118
- dim = min(n,m)
5338
+ dim = min(n, m)
5119
5339
 
5120
5340
  ### Compute the Lower and Upper triangle
5121
5341
  if unpack_data:
@@ -5130,7 +5350,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5130
5350
  start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
5131
5351
  limit_indices = list(LU_data.shape)
5132
5352
  limit_indices[-1] = dim
5133
- L = jax.lax.slice(L, start_indices, limit_indices)
5353
+ L = jax.lax.slice(L, start_indices, limit_indices)
5134
5354
 
5135
5355
  # Extract upper triangle
5136
5356
  U = jnp.triu(LU_data)
@@ -5160,13 +5380,15 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5160
5380
 
5161
5381
  # closure to be called for each input 2D matrix.
5162
5382
  def _lu_unpack_2d(p, pivot):
5163
- _pivot = pivot - 1 # pivots are offset by 1 in jax
5383
+ _pivot = pivot - 1 # pivots are offset by 1 in jax
5164
5384
  indices = jnp.array([*range(n)], dtype=jnp.int32)
5385
+
5165
5386
  def update_indices(i, _indices):
5166
5387
  tmp = _indices[i]
5167
5388
  _indices = _indices.at[i].set(_indices[_pivot[i]])
5168
5389
  _indices = _indices.at[_pivot[i]].set(tmp)
5169
5390
  return _indices
5391
+
5170
5392
  indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices)
5171
5393
  p = p[jnp.array(indices)]
5172
5394
  p = jnp.transpose(p)
@@ -5191,7 +5413,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5191
5413
  reshapedPivot = LU_pivots.reshape(newPivotshape)
5192
5414
 
5193
5415
  # vmap the reshaped 3d tensors
5194
- v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0,0))
5416
+ v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0))
5195
5417
  unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot)
5196
5418
 
5197
5419
  # reshape result back to P's shape
@@ -5210,3 +5432,207 @@ def linear(input, weight, bias=None):
5210
5432
  if bias is not None:
5211
5433
  res += bias
5212
5434
  return res
5435
+
5436
+
5437
+ @op(torch.ops.aten.kthvalue)
5438
+ def kthvalue(input, k, dim=None, keepdim=False, *, out=None):
5439
+ if input.ndim == 0:
5440
+ return input, jnp.array(0)
5441
+ dimension = -1
5442
+ if dim is not None:
5443
+ dimension = dim
5444
+ while dimension < 0:
5445
+ dimension = dimension + input.ndim
5446
+ values = jax.lax.index_in_dim(
5447
+ jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim)
5448
+ indices = jax.lax.index_in_dim(
5449
+ jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1,
5450
+ dimension, keepdim)
5451
+ return values, indices
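
For reference, the selection above boils down to jnp.partition / jnp.argpartition with a 0-based pivot of k - 1 along the chosen axis; a minimal 1-D sketch with toy data:

import jax.numpy as jnp

x = jnp.array([9.0, 1.0, 7.0, 3.0])
k = 2                                       # 1-based, as in torch.kthvalue
value = jnp.partition(x, k - 1)[k - 1]      # second-smallest value -> 3.0
index = jnp.argpartition(x, k - 1)[k - 1]   # its position in x -> 3
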
5452
+
5453
+
5454
+ @op(torch.ops.aten.take)
5455
+ def _aten_take(self, index):
5456
+ return self.flatten()[index]
5457
+
5458
+
5459
+ # func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor
5460
+ @op(torch.ops.aten.pad)
5461
+ def _aten_pad(self, pad, mode='constant', value=None):
5462
+ if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0:
5463
+ raise ValueError("Padding must be a sequence of even length.")
5464
+
5465
+ num_dims = self.ndim
5466
+ if len(pad) > 2 * num_dims:
5467
+ raise ValueError(
5468
+ f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})."
5469
+ )
5470
+
5471
+ # JAX's pad function expects padding for each dimension as a tuple of (low, high)
5472
+ # We need to reverse the pad sequence and group them for JAX.
5473
+ # pad = [p_l0, p_r0, p_l1, p_r1, ...]
5474
+ # becomes ((0, 0), ..., (p_l1, p_r1), (p_l0, p_r0))
5475
+ jax_pad_width = []
5476
+ # Iterate in reverse pairs
5477
+ for i in range(len(pad) // 2):
5478
+ jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)]))
5479
+
5480
+ # Pad any leading dimensions with (0, 0) if the pad sequence is shorter
5481
+ # than the number of dimensions.
5482
+ for _ in range(num_dims - len(pad) // 2):
5483
+ jax_pad_width.append((0, 0))
5484
+
5485
+ # Reverse the jax_pad_width list to match the dimension order
5486
+ jax_pad_width.reverse()
5487
+
5488
+ if mode == "constant":
5489
+ if value is None:
5490
+ value = 0.0
5491
+ return jnp.pad(
5492
+ self, pad_width=jax_pad_width, mode="constant", constant_values=value)
5493
+ elif mode == "reflect":
5494
+ return jnp.pad(self, pad_width=jax_pad_width, mode="reflect")
5495
+ elif mode == "edge":
5496
+ return jnp.pad(self, pad_width=jax_pad_width, mode="edge")
5497
+ else:
5498
+ raise ValueError(
5499
+ f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'."
5500
+ )
5501
+
5502
+
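# Sketch of the torch-style pad list -> jnp.pad pad_width mapping performed
# above (the shape and pad values are illustrative assumptions):
import jax.numpy as jnp

x = jnp.ones((2, 3))
out = _aten_pad(x, [1, 2, 0, 1])   # last dim padded by (1, 2), second-to-last by (0, 1)
# Internally the list becomes pad_width == [(0, 1), (1, 2)], so
# out.shape == (3, 6), matching torch.nn.functional.pad(torch.ones(2, 3), [1, 2, 0, 1]).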
5503
+ @op(torch.ops.aten.is_nonzero)
5504
+ def _aten_is_nonzero(a):
5505
+ a = jnp.squeeze(a)
5506
+ if a.shape == (0,):
5507
+ raise RuntimeError('bool value of Tensor with no values is ambiguous')
5508
+ if a.ndim != 0:
5509
+ raise RuntimeError(
5510
+ 'bool value of Tensor with more than one value is ambiguous')
5511
+ return a.item() != 0
5512
+
5513
+
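# Small sanity sketch for the is_nonzero lowering above (inputs are
# illustrative assumptions):
import jax.numpy as jnp

print(_aten_is_nonzero(jnp.array([1.5])))    # True  (single nonzero element)
print(_aten_is_nonzero(jnp.array([[0]])))    # False (single zero element)
# Multi-element inputs such as jnp.array([1, 2]) raise RuntimeError, as
# torch.is_nonzero does.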
5514
+ @op(torch.ops.aten.logit)
5515
+ def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
5516
+ """
5517
+ Computes the logit function of the input tensor.
5518
+
5519
+ logit(p) = log(p / (1 - p))
5520
+
5521
+ Args:
5522
+ self: Input tensor.
5523
+ eps: A small value to clip the input tensor to avoid log(0) or division by zero.
5524
+ If None, no clipping is performed.
5525
+
5526
+ Returns:
5527
+ A tensor with the logit of each element of the input.
5528
+ """
5529
+ if eps is not None:
5530
+ self = jnp.clip(self, eps, 1.0 - eps)
5531
+ res = jnp.log(self / (1.0 - self))
5532
+ res = res.astype(mappings.t2j_dtype(torch.get_default_dtype()))
5533
+ return res
5534
+
5535
+
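# Sketch of the logit computation above, log(p / (1 - p)), showing the effect
# of the optional eps clipping (example values are assumptions):
import jax.numpy as jnp

p = jnp.array([0.0, 0.25, 1.0])
print(_aten_logit(p))             # [-inf, -1.0986..., inf]
print(_aten_logit(p, eps=1e-6))   # inputs clipped to [1e-6, 1 - 1e-6], so all outputs are finite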
5536
+ @op(torch.ops.aten.floor_divide)
5537
+ def _aten_floor_divide(x, y):
5538
+ res = jnp.floor_divide(x, y)
5539
+ return res
5540
+
5541
+
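# floor_divide rounds toward negative infinity (floor semantics); a quick
# sketch with assumed example values:
import jax.numpy as jnp

print(_aten_floor_divide(jnp.array([7, -7]), 2))   # [ 3 -4]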
5542
+ @op(torch.ops.aten._assert_tensor_metadata)
5543
+ @op(torch.ops.aten._assert_scalar)
5544
+ def _aten__assert_tensor_metadata(*args, **kwargs):
5545
+ pass
5546
+
5547
+
5548
+ mutation_ops_to_functional = {
5549
+ torch.ops.aten.add_:
5550
+ op_base.InplaceOp(torch.ops.aten.add),
5551
+ torch.ops.aten.sub_:
5552
+ op_base.InplaceOp(torch.ops.aten.sub),
5553
+ torch.ops.aten.mul_:
5554
+ op_base.InplaceOp(torch.ops.aten.mul),
5555
+ torch.ops.aten.div_:
5556
+ op_base.InplaceOp(torch.ops.aten.div),
5557
+ torch.ops.aten.pow_:
5558
+ op_base.InplaceOp(torch.ops.aten.pow),
5559
+ torch.ops.aten.lt_:
5560
+ op_base.InplaceOp(torch.ops.aten.lt),
5561
+ torch.ops.aten.le_:
5562
+ op_base.InplaceOp(torch.ops.aten.le),
5563
+ torch.ops.aten.gt_:
5564
+ op_base.InplaceOp(torch.ops.aten.gt),
5565
+ torch.ops.aten.ge_:
5566
+ op_base.InplaceOp(torch.ops.aten.ge),
5567
+ torch.ops.aten.eq_:
5568
+ op_base.InplaceOp(torch.ops.aten.eq),
5569
+ torch.ops.aten.ne_:
5570
+ op_base.InplaceOp(torch.ops.aten.ne),
5571
+ torch.ops.aten.bernoulli_:
5572
+ op_base.InplaceOp(torch.ops.aten.bernoulli.p),
5573
+ torch.ops.aten.bernoulli_.float:
5574
+ op_base.InplaceOp(_aten_bernoulli, is_jax_func=True),
5575
+ torch.ops.aten.geometric_:
5576
+ op_base.InplaceOp(torch.ops.aten.geometric),
5577
+ torch.ops.aten.normal_:
5578
+ op_base.InplaceOp(torch.ops.aten.normal),
5579
+ torch.ops.aten.random_:
5580
+ op_base.InplaceOp(torch.ops.aten.uniform),
5581
+ torch.ops.aten.uniform_:
5582
+ op_base.InplaceOp(torch.ops.aten.uniform),
5583
+ torch.ops.aten.relu_:
5584
+ op_base.InplaceOp(torch.ops.aten.relu),
5585
+ # squeeze_ is expected to change the tensor's shape, so replace it with the new value
5586
+ torch.ops.aten.squeeze_:
5587
+ op_base.InplaceOp(torch.ops.aten.squeeze, True),
5588
+ torch.ops.aten.sqrt_:
5589
+ op_base.InplaceOp(torch.ops.aten.sqrt),
5590
+ torch.ops.aten.clamp_:
5591
+ op_base.InplaceOp(torch.ops.aten.clamp),
5592
+ torch.ops.aten.clamp_min_:
5593
+ op_base.InplaceOp(torch.ops.aten.clamp_min),
5594
+ torch.ops.aten.sigmoid_:
5595
+ op_base.InplaceOp(torch.ops.aten.sigmoid),
5596
+ torch.ops.aten.tanh_:
5597
+ op_base.InplaceOp(torch.ops.aten.tanh),
5598
+ torch.ops.aten.ceil_:
5599
+ op_base.InplaceOp(torch.ops.aten.ceil),
5600
+ torch.ops.aten.logical_not_:
5601
+ op_base.InplaceOp(torch.ops.aten.logical_not),
5602
+ torch.ops.aten.unsqueeze_:
5603
+ op_base.InplaceOp(torch.ops.aten.unsqueeze),
5604
+ torch.ops.aten.transpose_:
5605
+ op_base.InplaceOp(torch.ops.aten.transpose),
5606
+ torch.ops.aten.log_normal_:
5607
+ op_base.InplaceOp(torch.ops.aten.log_normal),
5608
+ torch.ops.aten.scatter_add_:
5609
+ op_base.InplaceOp(torch.ops.aten.scatter_add),
5610
+ torch.ops.aten.scatter_reduce_.two:
5611
+ op_base.InplaceOp(torch.ops.aten.scatter_reduce),
5612
+ torch.ops.aten.scatter_:
5613
+ op_base.InplaceOp(torch.ops.aten.scatter),
5614
+ torch.ops.aten.bitwise_or_:
5615
+ op_base.InplaceOp(torch.ops.aten.bitwise_or),
5616
+ torch.ops.aten.floor_divide_:
5617
+ op_base.InplaceOp(torch.ops.aten.floor_divide),
5618
+ torch.ops.aten.remainder_:
5619
+ op_base.InplaceOp(torch.ops.aten.remainder),
5620
+ torch.ops.aten.index_put_:
5621
+ op_base.InplaceOp(torch.ops.aten.index_put),
5622
+ }
5623
+
5624
+ # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
5625
+ _jax_version = tuple(int(v) for v in jax.version._version.split("."))
5626
+
5627
+ mutation_needs_env = {
5628
+ torch.ops.aten.bernoulli_,
5629
+ torch.ops.aten.bernoulli_.float,
5630
+ }
5631
+
5632
+ for operator, mutation in mutation_ops_to_functional.items():
5633
+ ops_registry.register_torch_dispatch_op(
5634
+ operator,
5635
+ mutation,
5636
+ is_jax_function=False,
5637
+ is_view_op=True,
5638
+ needs_env=(operator in mutation_needs_env))
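# Conceptual sketch of what the InplaceOp entries above provide (simplified;
# this is not the actual op_base.InplaceOp implementation): run the
# functional op, then write the result back into the mutated first argument.
def _inplace_like(functional_op):
  def _run(self, *args, **kwargs):
    result = functional_op(self, *args, **kwargs)
    self.copy_(result)   # write-back in place of true storage mutation
    return self
  return _run
# Entries such as squeeze_ pass an extra flag so the tensor value is replaced
# outright instead of copied back, since the output shape differs; the
# bernoulli_ variants additionally request the environment (needs_env),
# presumably for its random-number state.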