torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff compares the contents of publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Potentially problematic release: this version of torchax has been flagged as possibly problematic.

torchax/ops/jaten.py CHANGED
@@ -15,75 +15,22 @@ from torchax.ops import ops_registry
15
15
  from torchax.ops import op_base, mappings
16
16
  from torchax import interop
17
17
  from torchax.ops import jax_reimplement
18
-
18
+ from torchax.view import View
19
19
  # Keys are OpOverload, value is a callable that takes
20
20
  # Tensor
21
21
  all_ops = {}
22
22
 
23
- # list all Aten ops from pytorch that does mutation
24
- # and need to be implemented in jax
25
-
26
- mutation_ops_to_functional = {
27
- torch.ops.aten.add_: torch.ops.aten.add,
28
- torch.ops.aten.sub_: torch.ops.aten.sub,
29
- torch.ops.aten.mul_: torch.ops.aten.mul,
30
- torch.ops.aten.div_: torch.ops.aten.div,
31
- torch.ops.aten.pow_: torch.ops.aten.pow,
32
- torch.ops.aten.lt_: torch.ops.aten.lt,
33
- torch.ops.aten.le_: torch.ops.aten.le,
34
- torch.ops.aten.gt_: torch.ops.aten.gt,
35
- torch.ops.aten.ge_: torch.ops.aten.ge,
36
- torch.ops.aten.eq_: torch.ops.aten.eq,
37
- torch.ops.aten.ne_: torch.ops.aten.ne,
38
- torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
39
- torch.ops.aten.geometric_: torch.ops.aten.geometric,
40
- torch.ops.aten.normal_: torch.ops.aten.normal,
41
- torch.ops.aten.random_: torch.ops.aten.uniform,
42
- torch.ops.aten.uniform_: torch.ops.aten.uniform,
43
- torch.ops.aten.relu_: torch.ops.aten.relu,
44
- # squeeze_ is expected to change tensor's shape. So replace with new value
45
- torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
46
- torch.ops.aten.sqrt_: torch.ops.aten.sqrt,
47
- torch.ops.aten.clamp_: torch.ops.aten.clamp,
48
- torch.ops.aten.clamp_min_: torch.ops.aten.clamp_min,
49
- torch.ops.aten.sigmoid_: torch.ops.aten.sigmoid,
50
- torch.ops.aten.tanh_: torch.ops.aten.tanh,
51
- torch.ops.aten.ceil_: torch.ops.aten.ceil,
52
- torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
53
- torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
54
- torch.ops.aten.transpose_: torch.ops.aten.transpose,
55
- torch.ops.aten.log_normal_: torch.ops.aten.log_normal,
56
- torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add,
57
- torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
58
- torch.ops.aten.scatter_: torch.ops.aten.scatter,
59
- }
60
-
61
- # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
62
- _jax_version = tuple(int(v) for v in jax.version._version.split("."))
63
-
64
-
65
- def make_mutation(op):
66
- if type(mutation_ops_to_functional[op]) is tuple:
67
- return op_base.InplaceOp(mutation_ops_to_functional[op][0],
68
- replace=mutation_ops_to_functional[op][1],
69
- position_to_mutate=0)
70
- return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0)
71
-
72
-
73
- for op in mutation_ops_to_functional.keys():
74
- ops_registry.register_torch_dispatch_op(
75
- op, make_mutation(op), is_jax_function=False
76
- )
77
-
78
23
 
79
24
  def op(*aten, **kwargs):
25
+
80
26
  def inner(func):
81
27
  for a in aten:
82
28
  ops_registry.register_torch_dispatch_op(a, func, **kwargs)
83
29
  continue
84
30
 
85
31
  if isinstance(a, torch._ops.OpOverloadPacket):
86
- opname = a.default.name() if 'default' in a.overloads() else a._qualified_op_name
32
+ opname = a.default.name() if 'default' in a.overloads(
33
+ ) else a._qualified_op_name
87
34
  elif isinstance(a, torch._ops.OpOverload):
88
35
  opname = a.name()
89
36
  else:
@@ -91,17 +38,18 @@ def op(*aten, **kwargs):
91
38
 
92
39
  torchfunc = functools.partial(interop.call_jax, func)
93
40
  # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
94
- torch.library.impl(opname, 'privateuseone')(torchfunc if a != torch.ops.aten._to_copy else func)
41
+ torch.library.impl(opname, 'privateuseone')(
42
+ torchfunc if a != torch.ops.aten._to_copy else func)
95
43
  return func
96
44
 
97
45
  return inner
98
46
 
99
47
 
100
48
  @op(
101
- torch.ops.aten.view_copy,
102
- torch.ops.aten.view,
103
- torch.ops.aten._unsafe_view,
104
- torch.ops.aten.reshape,
49
+ torch.ops.aten.view_copy,
50
+ torch.ops.aten.view,
51
+ torch.ops.aten._unsafe_view,
52
+ torch.ops.aten.reshape,
105
53
  )
106
54
  def _aten_unsafe_view(x, shape):
107
55
  return jnp.reshape(x, shape)
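
For orientation, the `op(...)` decorator shown above is a registration helper: every listed ATen overload is mapped to the same JAX implementation, which is how `view`, `view_copy`, `_unsafe_view` and `reshape` all end up sharing `jnp.reshape`. A minimal standalone sketch of that pattern follows; `_registry` and the string keys are hypothetical stand-ins for `ops_registry` and the real OpOverload objects, and the `torch.library` / `privateuseone` registration performed by the real decorator is omitted.

    import jax.numpy as jnp

    _registry = {}  # hypothetical stand-in for torchax.ops.ops_registry

    def op(*names):
      def inner(func):
        for name in names:
          _registry[name] = func  # every alias points at the same callable
        return func
      return inner

    @op("aten::view", "aten::reshape")
    def _reshape(x, shape):
      return jnp.reshape(x, shape)

    print(_registry["aten::reshape"](jnp.arange(6), (2, 3)))  # [[0 1 2] [3 4 5]]
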
@@ -121,8 +69,19 @@ def _aten_add(x, y, *, alpha=1):
121
69
  return res
122
70
 
123
71
 
124
- @op(torch.ops.aten.copy_, is_jax_function=False)
125
- def _aten_copy(x, y, memory_format=None):
72
+ @op(torch.ops.aten.copy_,
73
+ is_jax_function=False,
74
+ is_view_op=True,
75
+ needs_env=True)
76
+ def _aten_copy(x, y, memory_format=None, env=None):
77
+
78
+ if y.device.type == 'cpu':
79
+ y = env.to_xla(y)
80
+
81
+ if isinstance(x, View):
82
+ x.update(y)
83
+ return x
84
+
126
85
  if x.ndim == 1 and y.ndim == 0:
127
86
  # case of torch.empty((1,)).copy_(tensor(N))
128
87
  # we need to return 0D tensor([N]) and not scalar tensor(N)
@@ -147,14 +106,20 @@ def _aten_trunc(x):
147
106
 
148
107
  @op(torch.ops.aten.index_copy)
149
108
  def _aten_index_copy(x, dim, indexes, source):
109
+ if x.ndim == 0:
110
+ return source
111
+ if x.ndim == 1:
112
+ source = jnp.squeeze(source)
150
113
  # return jax.lax.scatter(x, index, dim)
114
+ if dim < 0:
115
+ dim = dim + x.ndim
151
116
  dims = []
152
117
  for i in range(len(x.shape)):
153
118
  if i == dim:
154
119
  dims.append(indexes)
155
120
  else:
156
121
  dims.append(slice(None, None, None))
157
- return x.at[dim].set(source)
122
+ return x.at[tuple(dims)].set(source)
158
123
 
159
124
 
160
125
  # aten.cauchy_
@@ -199,7 +164,9 @@ def _aten_complex(real, imag):
199
164
  Returns:
200
165
  A complex array with the specified real and imaginary parts.
201
166
  """
202
- return jnp.array(real, dtype=jnp.float32) + 1j * jnp.array(imag, dtype=jnp.float32)
167
+ return jnp.array(
168
+ real, dtype=jnp.float32) + 1j * jnp.array(
169
+ imag, dtype=jnp.float32)
203
170
 
204
171
 
205
172
  # aten.exponential_
@@ -223,13 +190,14 @@ def _aten_exponential_(x, lambd=1.0):
223
190
  # aten.linalg_householder_product
224
191
  @op(torch.ops.aten.linalg_householder_product)
225
192
  def _aten_linalg_householder_product(input, tau):
226
- return jax.lax.linalg.householder_product(a = input, taus = tau)
193
+ return jax.lax.linalg.householder_product(a=input, taus=tau)
227
194
 
228
195
 
229
196
  @op(torch.ops.aten.select)
230
197
  def _aten_select(x, dim, indexes):
231
198
  return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False)
232
199
 
200
+
233
201
  @op(torch.ops.aten.index_select)
234
202
  @op(torch.ops.aten.select_copy)
235
203
  def _aten_index_select(x, dim, index):
@@ -249,11 +217,10 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
249
217
  raise NotImplementedError(
250
218
  "check_errors=True is not supported in this JAX implementation. "
251
219
  "Check for positive definiteness using jnp.linalg.eigvalsh before "
252
- "calling this function."
253
- )
220
+ "calling this function.")
254
221
 
255
222
  L = jax.scipy.linalg.cholesky(input, lower=not upper)
256
- if len(L.shape) >2:
223
+ if len(L.shape) > 2:
257
224
  info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32)
258
225
  else:
259
226
  info = jnp.array(0, dtype=jnp.int32)
@@ -263,7 +230,7 @@ def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False):
263
230
  @op(torch.ops.aten.cholesky_solve)
264
231
  def _aten_cholesky_solve(input, input2, upper=False):
265
232
  # Ensure input2 is lower triangular for cho_solve
266
- L = input2 if not upper else input2.T
233
+ L = input2 if not upper else input2.T
267
234
  # Use cho_solve to solve the linear system
268
235
  solution = jax.scipy.linalg.cho_solve((L, True), input)
269
236
  return solution
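
The Cholesky-solve mapping above always hands `cho_solve` a lower-triangular factor (transposing first when `upper=True`). A quick standalone sanity check of that mapping, independent of torchax:

    import numpy as np
    import jax.numpy as jnp
    from jax.scipy.linalg import cholesky, cho_solve

    A = jnp.array([[4.0, 1.0], [1.0, 3.0]])   # symmetric positive definite
    b = jnp.array([[1.0], [2.0]])
    L = cholesky(A, lower=True)               # A = L @ L.T
    x = cho_solve((L, True), b)               # True marks the factor as lower, as in the op above
    np.testing.assert_allclose(A @ x, b, atol=1e-5)
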
@@ -275,7 +242,7 @@ def _aten_special_zeta(x, q):
275
242
  res = jax.scipy.special.zeta(x, q)
276
243
  if isinstance(x, int) or isinstance(q, int):
277
244
  res = res.astype(new_dtype)
278
- return res # jax.scipy.special.zeta(x, q)
245
+ return res # jax.scipy.special.zeta(x, q)
279
246
 
280
247
 
281
248
  # aten.igammac
@@ -286,7 +253,7 @@ def _aten_igammac(input, other):
286
253
  if isinstance(other, jnp.ndarray):
287
254
  other = jnp.where(other < 0, jnp.nan, other)
288
255
  else:
289
- if (input==0 and other==0) or (input < 0) or (other < 0):
256
+ if (input == 0 and other == 0) or (input < 0) or (other < 0):
290
257
  other = jnp.nan
291
258
  return jnp.array(jax.scipy.special.gammaincc(input, other))
292
259
 
@@ -294,7 +261,7 @@ def _aten_igammac(input, other):
294
261
  @op(torch.ops.aten.mean)
295
262
  def _aten_mean(x, dim=None, keepdim=False):
296
263
  if x.shape == () and dim is not None:
297
- dim = None # disable dim for jax array without dim
264
+ dim = None # disable dim for jax array without dim
298
265
  return jnp.mean(x, dim, keepdims=keepdim)
299
266
 
300
267
 
@@ -310,13 +277,14 @@ def _torch_binary_scalar_type(scalar, tensor):
310
277
 
311
278
 
312
279
  @op(torch.ops.aten.searchsorted.Tensor)
313
- def _aten_searchsorted(sorted_sequence, values):
280
+ def _aten_searchsorted(sorted_sequence, values):
314
281
  new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
315
282
  res = jnp.searchsorted(sorted_sequence, values)
316
- if sorted_sequence.dtype == np.dtype(np.int32) or sorted_sequence.dtype == np.dtype(np.int32):
283
+ if sorted_sequence.dtype == np.dtype(
284
+ np.int32) or sorted_sequence.dtype == np.dtype(np.int32):
317
285
  # res = res.astype(new_dtype)
318
286
  res = res.astype(np.dtype(np.int64))
319
- return res # jnp.searchsorted(sorted_sequence, values)
287
+ return res # jnp.searchsorted(sorted_sequence, values)
320
288
 
321
289
 
322
290
  @op(torch.ops.aten.sub.Tensor)
@@ -328,7 +296,7 @@ def _aten_sub(x, y, alpha=1):
328
296
  if isinstance(y, float):
329
297
  dtype = _torch_binary_scalar_type(y, x)
330
298
  y = jnp.array(y, dtype=dtype)
331
- return x - y*alpha
299
+ return x - y * alpha
332
300
 
333
301
 
334
302
  @op(torch.ops.aten.numpy_T)
@@ -345,7 +313,6 @@ def _aten_numpy_T(input):
345
313
  return jnp.transpose(input)
346
314
 
347
315
 
348
-
349
316
  @op(torch.ops.aten.mm)
350
317
  def _aten_mm(x, y):
351
318
  res = x @ y
@@ -379,13 +346,15 @@ def _aten_t(x):
379
346
  @op(torch.ops.aten.transpose)
380
347
  @op(torch.ops.aten.transpose_copy)
381
348
  def _aten_transpose(x, dim0, dim1):
382
- shape = list(range(len(x.shape)))
383
- shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
384
- return jnp.transpose(x, shape)
349
+ if x.ndim == 0:
350
+ return x
351
+ dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim
352
+ dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim
353
+ return jnp.swapaxes(x, dim0, dim1)
385
354
 
386
355
 
387
356
  @op(torch.ops.aten.triu)
388
- def _aten_triu(m, k):
357
+ def _aten_triu(m, k=0):
389
358
  return jnp.triu(m, k)
390
359
 
391
360
 
@@ -406,6 +375,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
406
375
  return self[tuple(dims)]
407
376
 
408
377
 
378
+ @op(torch.ops.aten.positive)
409
379
  @op(torch.ops.aten.detach)
410
380
  def _aten_detach(self):
411
381
  return self
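
The next hunk adjusts `repeat_interleave.Tensor`, whose single-argument form expands a vector of counts into repeated indices (the `jnp.repeat(np.arange(n), repeats)` idiom). A standalone example of that semantics, not taken from the package:

    import numpy as np
    import jax.numpy as jnp
    import torch

    repeats = [2, 0, 3]
    print(torch.repeat_interleave(torch.tensor(repeats)))           # tensor([0, 0, 2, 2, 2])
    print(jnp.repeat(np.arange(len(repeats)), jnp.array(repeats)))  # [0 0 2 2 2]
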
@@ -439,7 +409,8 @@ def _aten_resize_as_(x, y):
439
409
 
440
410
  @op(torch.ops.aten.repeat_interleave.Tensor)
441
411
  def repeat_interleave(repeats, dim=0):
442
- return jnp.repeat(jnp.arange(repeats.shape[dim]), repeats)
412
+ return jnp.repeat(np.arange(repeats.shape[dim]), repeats)
413
+
443
414
 
444
415
  @op(torch.ops.aten.repeat_interleave.self_int)
445
416
  @op(torch.ops.aten.repeat_interleave.self_Tensor)
@@ -451,12 +422,6 @@ def repeat_interleave(self, repeats, dim=0):
451
422
  return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length)
452
423
 
453
424
 
454
- # aten.upsample_bilinear2d
455
- @op(torch.ops.aten.upsample_bilinear2d)
456
- def _aten_upsample_bilinear2d(x, output_size, align_corners=False, scale_h=None, scale_w=None):
457
- return _aten_upsample_bilinear2d_aa(x, output_size=output_size, align_corners=align_corners, scale_factors=None, scales_h=scale_h, scales_w=scale_w)
458
-
459
-
460
425
  @op(torch.ops.aten.view_as_real)
461
426
  def _aten_view_as_real(x):
462
427
  real = jnp.real(x)
@@ -473,7 +438,7 @@ def _aten_stack(tensors, dim=0):
473
438
  @op(torch.ops.aten._softmax)
474
439
  @op(torch.ops.aten.softmax)
475
440
  @op(torch.ops.aten.softmax.int)
476
- def _aten_softmax(x, dim, halftofloat = False):
441
+ def _aten_softmax(x, dim, halftofloat=False):
477
442
  if x.shape == ():
478
443
  return jax.nn.softmax(x.reshape([1]), axis=0).reshape([])
479
444
  return jax.nn.softmax(x, dim)
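
`jax.nn.softmax` needs at least one axis to reduce over, so the 0-d branch above reshapes to a length-1 vector and back; for a single element the result is simply 1.0:

    import jax
    import jax.numpy as jnp

    x = jnp.array(3.0)  # 0-d input
    print(jax.nn.softmax(x.reshape([1]), axis=0).reshape([]))  # 1.0
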
@@ -482,10 +447,12 @@ def _aten_softmax(x, dim, halftofloat = False):
482
447
  def _is_int(x):
483
448
  if isinstance(x, int):
484
449
  return True
485
- if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or x.dtype.name.startswith('uint')):
450
+ if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or
451
+ x.dtype.name.startswith('uint')):
486
452
  return True
487
453
  return False
488
454
 
455
+
489
456
  def highest_precision_int_dtype(tensor1, tensor2):
490
457
  if isinstance(tensor1, int):
491
458
  return tensor2.dtype
@@ -493,12 +460,20 @@ def highest_precision_int_dtype(tensor1, tensor2):
493
460
  return tensor1.dtype
494
461
 
495
462
  dtype_hierarchy = {
496
- 'uint8': 8, 'int8': 8,
497
- 'uint16': 16, 'int16': 16,
498
- 'uint32': 32, 'int32': 32,
499
- 'uint64': 64, 'int64': 64,
463
+ 'uint8': 8,
464
+ 'int8': 8,
465
+ 'uint16': 16,
466
+ 'int16': 16,
467
+ 'uint32': 32,
468
+ 'int32': 32,
469
+ 'uint64': 64,
470
+ 'int64': 64,
500
471
  }
501
- return max(tensor1.dtype, tensor2.dtype, key=lambda dtype: dtype_hierarchy[str(dtype)])
472
+ return max(
473
+ tensor1.dtype,
474
+ tensor2.dtype,
475
+ key=lambda dtype: dtype_hierarchy[str(dtype)])
476
+
502
477
 
503
478
  @op(torch.ops.aten.pow)
504
479
  def _aten_pow(x, y):
@@ -553,11 +528,13 @@ def _aten_div(x, y, rounding_mode=""):
553
528
  def _aten_true_divide(x, y):
554
529
  return x / y
555
530
 
531
+
556
532
  @op(torch.ops.aten.dist)
557
533
  def _aten_dist(input, other, p=2):
558
534
  diff = jnp.abs(jnp.subtract(input, other))
559
535
  return _aten_linalg_vector_norm(diff, ord=p)
560
536
 
537
+
561
538
  @op(torch.ops.aten.bmm)
562
539
  def _aten_bmm(x, y):
563
540
  res = x @ y
@@ -567,9 +544,14 @@ def _aten_bmm(x, y):
567
544
 
568
545
  @op(torch.ops.aten.embedding)
569
546
  # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
570
- def _aten_embedding(a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
547
+ def _aten_embedding(a,
548
+ w,
549
+ padding_idx=-1,
550
+ scale_grad_by_freq=False,
551
+ sparse=False):
571
552
  return jnp.take(a, w, axis=0)
572
553
 
554
+
573
555
  @op(torch.ops.aten.embedding_renorm_)
574
556
  def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
575
557
  # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
@@ -587,27 +569,26 @@ def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
587
569
 
588
570
  indices_to_update = unique_indices[indice_idx]
589
571
 
590
- weight = weight.at[indices_to_update].set(
591
- weight[indices_to_update] * scale[:, None]
592
- )
572
+ weight = weight.at[indices_to_update].set(weight[indices_to_update] *
573
+ scale[:, None])
593
574
  return weight
594
575
 
576
+
595
577
  #- func: _embedding_bag_forward_only(
596
578
  # Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
597
579
  # int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
598
580
  @op(torch.ops.aten._embedding_bag)
599
581
  @op(torch.ops.aten._embedding_bag_forward_only)
600
- def _aten__embedding_bag(
601
- weight,
602
- indices,
603
- offsets=None,
604
- scale_grad_by_freq=False,
605
- mode=0,
606
- sparse=False,
607
- per_sample_weights=None,
608
- include_last_offset=False,
609
- padding_idx=-1):
610
- """Jax implementation of the PyTorch _embedding_bag function.
582
+ def _aten__embedding_bag(weight,
583
+ indices,
584
+ offsets=None,
585
+ scale_grad_by_freq=False,
586
+ mode=0,
587
+ sparse=False,
588
+ per_sample_weights=None,
589
+ include_last_offset=False,
590
+ padding_idx=-1):
591
+ """Jax implementation of the PyTorch _embedding_bag function.
611
592
 
612
593
  Args:
613
594
  weight: The learnable weights of the module of shape (num_embeddings, embedding_dim).
@@ -623,48 +604,50 @@ def _aten__embedding_bag(
623
604
  Returns:
624
605
  A tuple of (output, offset2bag, bag_size, max_indices).
625
606
  """
626
- embedded = _aten_embedding(weight, indices, padding_idx)
627
-
628
- if offsets is None:
629
- # offsets is None only when indices.ndim > 1
630
- if mode == 0: # sum
631
- output = jnp.sum(embedded, axis=1)
632
- elif mode == 1: # mean
633
- output = jnp.mean(embedded, axis=1)
634
- elif mode == 2: # max
635
- output = jnp.max(embedded, axis=1)
636
- return output, None, None, None
637
-
638
- if isinstance(offsets, jax.Array):
639
- offsets_np = np.array(offsets)
640
- else:
641
- offsets_np = offsets
642
- offset2bag = np.zeros(indices.shape[0], dtype=np.int64)
643
- bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64)
644
- max_indices = jnp.full_like(indices, -1)
607
+ embedded = _aten_embedding(weight, indices, padding_idx)
608
+
609
+ if offsets is None:
610
+ # offsets is None only when indices.ndim > 1
611
+ if mode == 0: # sum
612
+ output = jnp.sum(embedded, axis=1)
613
+ elif mode == 1: # mean
614
+ output = jnp.mean(embedded, axis=1)
615
+ elif mode == 2: # max
616
+ output = jnp.max(embedded, axis=1)
617
+ return output, None, None, None
618
+
619
+ if isinstance(offsets, jax.Array):
620
+ offsets_np = np.array(offsets)
621
+ else:
622
+ offsets_np = offsets
623
+ offset2bag = np.zeros(indices.shape[0], dtype=np.int64)
624
+ bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64)
625
+ max_indices = jnp.full_like(indices, -1)
645
626
 
646
- for bag in range(offsets_np.shape[0]):
647
- start = int(offsets_np[bag])
627
+ for bag in range(offsets_np.shape[0]):
628
+ start = int(offsets_np[bag])
648
629
 
649
- end = int(indices.shape[0] if bag + 1 == offsets_np.shape[0] else offsets_np[bag + 1])
650
- bag_size[bag] = end - start
651
- offset2bag = offset2bag.at[start:end].set(bag)
630
+ end = int(indices.shape[0] if bag +
631
+ 1 == offsets_np.shape[0] else offsets_np[bag + 1])
632
+ bag_size[bag] = end - start
633
+ offset2bag = offset2bag.at[start:end].set(bag)
652
634
 
653
- if end - start > 0:
654
- if mode == 0:
655
- output_bag = jnp.sum(embedded[start:end], axis=0)
656
- elif mode == 1:
657
- output_bag = jnp.mean(embedded[start:end], axis=0)
658
- elif mode == 2:
659
- output_bag = jnp.max(embedded[start:end], axis=0)
660
- max_indices = max_indices.at[start:end].set(jnp.argmax(embedded[start:end], axis=0))
635
+ if end - start > 0:
636
+ if mode == 0:
637
+ output_bag = jnp.sum(embedded[start:end], axis=0)
638
+ elif mode == 1:
639
+ output_bag = jnp.mean(embedded[start:end], axis=0)
640
+ elif mode == 2:
641
+ output_bag = jnp.max(embedded[start:end], axis=0)
642
+ max_indices = max_indices.at[start:end].set(
643
+ jnp.argmax(embedded[start:end], axis=0))
661
644
 
662
- # The original code returned offset2bag, bag_size, and max_indices as numpy arrays.
663
- # Converting them to JAX arrays for consistency.
664
- offset2bag = jnp.array(offset2bag)
665
- bag_size = jnp.array(bag_size)
645
+ # The original code returned offset2bag, bag_size, and max_indices as numpy arrays.
646
+ # Converting them to JAX arrays for consistency.
647
+ offset2bag = jnp.array(offset2bag)
648
+ bag_size = jnp.array(bag_size)
666
649
 
667
- return output_bag, offset2bag, bag_size, max_indices
650
+ return output_bag, offset2bag, bag_size, max_indices
668
651
 
669
652
 
670
653
  @op(torch.ops.aten.rsqrt)
@@ -676,6 +659,7 @@ def _aten_rsqrt(x):
676
659
  @op(torch.ops.aten.expand)
677
660
  @op(torch.ops.aten.expand_copy)
678
661
  def _aten_expand(x, dims):
662
+
679
663
  def fix_dims(d, xs):
680
664
  if d == -1:
681
665
  return xs
@@ -683,7 +667,9 @@ def _aten_expand(x, dims):
683
667
 
684
668
  shape = list(x.shape)
685
669
  if len(shape) < len(dims):
686
- shape = [1, ] * (len(dims) - len(shape)) + shape
670
+ shape = [
671
+ 1,
672
+ ] * (len(dims) - len(shape)) + shape
687
673
  # make sure that dims and shape is the same by
688
674
  # left pad with 1s. Otherwise the zip below will
689
675
  # truncate
@@ -705,15 +691,15 @@ def _aten__to_copy(self, **kwargs):
705
691
 
706
692
 
707
693
  @op(torch.ops.aten.empty)
708
- @op_base.convert_dtype()
694
+ @op_base.convert_dtype(use_default_dtype=False)
709
695
  def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs):
710
696
  return jnp.empty(size, dtype=dtype)
711
697
 
712
698
 
713
699
  @op(torch.ops.aten.empty_like)
714
- @op_base.convert_dtype()
700
+ @op_base.convert_dtype(use_default_dtype=False)
715
701
  def _aten_empty_like(input, *, dtype=None, **kwargs):
716
- return jnp.empty_like(input, dtype=dtype)
702
+ return jnp.empty_like(input, dtype)
717
703
 
718
704
 
719
705
  @op(torch.ops.aten.ones)
@@ -784,8 +770,8 @@ def split_with_sizes(x, sizes, dim=0):
784
770
  A list of sub-arrays.
785
771
  """
786
772
  if isinstance(sizes, int):
787
- # split equal size
788
- new_sizes = [sizes] * (x.shape[dim] // sizes)
773
+ # split equal size, round up
774
+ new_sizes = [sizes] * (-(-x.shape[dim] // sizes))
789
775
  sizes = new_sizes
790
776
  rank = x.ndim
791
777
  splits = np.cumsum(sizes) # Cumulative sum for split points
@@ -796,14 +782,15 @@ def split_with_sizes(x, sizes, dim=0):
796
782
  return tuple(res)
797
783
 
798
784
  return [
799
- x[make_range(rank, dim, start, end)]
800
- for start, end in zip([0] + list(splits[:-1]), splits)
785
+ x[make_range(rank, dim, start, end)]
786
+ for start, end in zip([0] + list(splits[:-1]), splits)
801
787
  ]
802
788
 
803
789
 
804
790
  @op(torch.ops.aten.permute)
805
791
  @op(torch.ops.aten.permute_copy)
806
792
  def permute(t, dims):
793
+ # TODO: return a View instead
807
794
  return jnp.transpose(t, dims)
808
795
 
809
796
 
@@ -819,6 +806,7 @@ def _aten_unsqueeze(self, dim):
819
806
  def _aten_ne(x, y):
820
807
  return jnp.not_equal(x, y)
821
808
 
809
+
822
810
  # Create indices along a specific axis
823
811
  #
824
812
  # For example
@@ -832,14 +820,12 @@ def _aten_ne(x, y):
832
820
  def _indices_along_axis(x, axis):
833
821
  return jnp.expand_dims(
834
822
  jnp.arange(x.shape[axis]),
835
- axis = [d for d in range(len(x.shape)) if d != axis]
836
- )
823
+ axis=[d for d in range(len(x.shape)) if d != axis])
824
+
837
825
 
838
826
  def _broadcast_indices(indices, shape):
839
- return jnp.broadcast_to(
840
- indices,
841
- shape
842
- )
827
+ return jnp.broadcast_to(indices, shape)
828
+
843
829
 
844
830
  @op(torch.ops.aten.cummax)
845
831
  def _aten_cummax(x, dim):
@@ -851,36 +837,45 @@ def _aten_cummax(x, dim):
851
837
  indice_along_axis = _indices_along_axis(x, axis)
852
838
  indices = _broadcast_indices(indice_along_axis, x.shape)
853
839
 
854
-
855
840
  def cummax_reduce_func(carry, elem):
856
- v1, v2 = carry['val'], elem['val']
841
+ v1, v2 = carry['val'], elem['val']
857
842
  i1, i2 = carry['idx'], elem['idx']
858
843
 
859
844
  v = jnp.maximum(v1, v2)
860
845
  i = jnp.where(v1 > v2, i1, i2)
861
846
  return {'val': v, 'idx': i}
862
- res = jax.lax.associative_scan(cummax_reduce_func, {'val': x, 'idx': indices}, axis=axis)
847
+
848
+ res = jax.lax.associative_scan(
849
+ cummax_reduce_func, {
850
+ 'val': x,
851
+ 'idx': indices
852
+ }, axis=axis)
863
853
  return res['val'], res['idx']
864
854
 
855
+
865
856
  @op(torch.ops.aten.cummin)
866
857
  def _aten_cummin(x, dim):
867
858
  if not x.shape:
868
859
  return x, jnp.zeros_like(x, dtype=jnp.int64)
869
-
860
+
870
861
  axis = dim
871
862
 
872
863
  indice_along_axis = _indices_along_axis(x, axis)
873
864
  indices = _broadcast_indices(indice_along_axis, x.shape)
874
865
 
875
866
  def cummin_reduce_func(carry, elem):
876
- v1, v2 = carry['val'], elem['val']
867
+ v1, v2 = carry['val'], elem['val']
877
868
  i1, i2 = carry['idx'], elem['idx']
878
869
 
879
870
  v = jnp.minimum(v1, v2)
880
871
  i = jnp.where(v1 < v2, i1, i2)
881
872
  return {'val': v, 'idx': i}
882
873
 
883
- res = jax.lax.associative_scan(cummin_reduce_func, {'val': x, 'idx': indices}, axis=axis)
874
+ res = jax.lax.associative_scan(
875
+ cummin_reduce_func, {
876
+ 'val': x,
877
+ 'idx': indices
878
+ }, axis=axis)
884
879
  return res['val'], res['idx']
885
880
 
886
881
 
@@ -908,9 +903,11 @@ def _aten_cumprod(input, dim, dtype=None, out=None):
908
903
 
909
904
 
910
905
  @op(torch.ops.aten.native_layer_norm)
911
- def _aten_native_layer_norm(
912
- input, normalized_shape, weight=None, bias=None, eps=1e-5
913
- ):
906
+ def _aten_native_layer_norm(input,
907
+ normalized_shape,
908
+ weight=None,
909
+ bias=None,
910
+ eps=1e-5):
914
911
  """Implements layer normalization in Jax as defined by `aten::native_layer_norm`.
915
912
 
916
913
  Args:
@@ -944,7 +941,7 @@ def _aten_native_layer_norm(
944
941
  norm_x += bias
945
942
  return norm_x, mean, rstd
946
943
 
947
-
944
+
948
945
  @op(torch.ops.aten.matmul)
949
946
  def _aten_matmul(x, y):
950
947
  return x @ y
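
The tail of `_aten_native_layer_norm` visible above applies the usual layer-norm arithmetic: normalize over the trailing dimensions using the biased variance, then scale and shift. A small self-contained check of that arithmetic (not taken from the package):

    import numpy as np
    import jax.numpy as jnp

    x = jnp.arange(6.0).reshape(2, 3)
    eps = 1e-5
    mean = x.mean(axis=-1, keepdims=True)
    rstd = 1.0 / jnp.sqrt(x.var(axis=-1, keepdims=True) + eps)  # biased variance, as in layer norm
    norm = (x - mean) * rstd
    np.testing.assert_allclose(norm.mean(axis=-1), np.zeros(2), atol=1e-6)  # zero mean per row
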
@@ -960,6 +957,7 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
960
957
  self += alpha * jnp.matmul(mat1, mat2)
961
958
  return self
962
959
 
960
+
963
961
  @op(torch.ops.aten.sparse_sampled_addmm)
964
962
  def _aten_sparse_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
965
963
  alpha = jnp.array(alpha).astype(mat1.dtype)
@@ -974,9 +972,8 @@ def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
974
972
  alpha = jnp.array(alpha).astype(batch1.dtype)
975
973
  beta = jnp.array(beta).astype(batch1.dtype)
976
974
  mm = jnp.einsum("bxy, byz -> xz", batch1, batch2)
977
- return jax.lax.cond(
978
- beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm
979
- )
975
+ return jax.lax.cond(beta == 0, lambda: alpha * mm,
976
+ lambda: beta * input + alpha * mm)
980
977
 
981
978
 
982
979
  @op(torch.ops.aten.gelu)
@@ -987,73 +984,69 @@ def _aten_gelu(self, *, approximate="none"):
987
984
 
988
985
  @op(torch.ops.aten.squeeze)
989
986
  @op(torch.ops.aten.squeeze_copy)
990
- def _aten_squeeze_dim(self, dim):
991
- """Squeezes a Jax tensor by removing a single dimension of size 1.
992
-
993
- Args:
994
- self: The input tensor.
995
- dim: The dimension to squeeze.
996
-
997
- Returns:
998
- The squeezed tensor with the specified dimension removed if it is 1,
999
- otherwise the original tensor is returned.
1000
- """
1001
-
1002
- # Validate input arguments
1003
- if not isinstance(self, jnp.ndarray):
1004
- raise TypeError(f"Expected a Jax tensor, got {type(self)}.")
1005
- if isinstance(dim, int):
1006
- dim = [dim]
1007
-
1008
- # Check if the specified dimension has size 1
1009
- if (len(self.shape) == 0) or all([self.shape[d] != 1 for d in dim]):
987
+ def _aten_squeeze_dim(self, dim=None):
988
+ if self.ndim == 0:
1010
989
  return self
990
+ if dim is not None:
991
+ if isinstance(dim, int):
992
+ if self.shape[dim] != 1:
993
+ return self
994
+ if dim < 0:
995
+ dim += self.ndim
996
+ else:
997
+ # NOTE: torch leaves the dims that is not 1 unchanged,
998
+ # but jax raises error.
999
+ dim = [
1000
+ i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1
1001
+ ]
1011
1002
 
1012
- # Use slicing to remove the dimension if it is 1
1013
- new_shape = list(self.shape)
1014
-
1015
- def fix_dim(p):
1016
- if p < 0:
1017
- return p + len(self.shape)
1018
- return p
1003
+ return jnp.squeeze(self, dim)
1019
1004
 
1020
- dim = [fix_dim(d) for d in dim]
1021
- new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1]
1022
- return self.reshape(new_shape)
1023
1005
 
1024
1006
  @op(torch.ops.aten.bucketize)
1025
- def _aten_bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
1026
- assert boundaries[0] < boundaries[-1], "boundaries must contain a strictly increasing sequence"
1007
+ def _aten_bucketize(input,
1008
+ boundaries,
1009
+ *,
1010
+ out_int32=False,
1011
+ right=False,
1012
+ out=None):
1027
1013
  return_type = jnp.int32 if out_int32 else jnp.int64
1028
1014
  return jnp.digitize(input, boundaries, right=not right).astype(return_type)
1029
1015
 
1030
1016
 
1031
1017
  @op(torch.ops.aten.conv2d)
1032
1018
  def _aten_conv2d(
1033
- input,
1034
- weight,
1035
- bias,
1036
- stride,
1037
- padding,
1038
- dilation,
1039
- groups,
1019
+ input,
1020
+ weight,
1021
+ bias,
1022
+ stride,
1023
+ padding,
1024
+ dilation,
1025
+ groups,
1040
1026
  ):
1041
1027
  return _aten_convolution(
1042
- input, weight, bias, stride, padding,
1043
- dilation, transposed=False,
1044
- output_padding=1, groups=groups)
1028
+ input,
1029
+ weight,
1030
+ bias,
1031
+ stride,
1032
+ padding,
1033
+ dilation,
1034
+ transposed=False,
1035
+ output_padding=1,
1036
+ groups=groups)
1037
+
1045
1038
 
1046
1039
  @op(torch.ops.aten.convolution)
1047
1040
  def _aten_convolution(
1048
- input,
1049
- weight,
1050
- bias,
1051
- stride,
1052
- padding,
1053
- dilation,
1054
- transposed,
1055
- output_padding,
1056
- groups,
1041
+ input,
1042
+ weight,
1043
+ bias,
1044
+ stride,
1045
+ padding,
1046
+ dilation,
1047
+ transposed,
1048
+ output_padding,
1049
+ groups,
1057
1050
  ):
1058
1051
  num_shape_dim = weight.ndim - 1
1059
1052
  batch_dims = input.shape[:-num_shape_dim]
@@ -1068,7 +1061,7 @@ def _aten_convolution(
1068
1061
  # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
1069
1062
  pad_out = []
1070
1063
  for i in range(num_spatial_dims):
1071
- front = dilation[i] * (weight.shape[i+2] - 1) - padding[i]
1064
+ front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i]
1072
1065
  back = front + output_padding[i]
1073
1066
  pad_out.append((front, back))
1074
1067
  return pad_out
@@ -1089,39 +1082,38 @@ def _aten_convolution(
1089
1082
  rhs_spec.append(i + 2)
1090
1083
  out_spec.append(i + 2)
1091
1084
  return jax.lax.ConvDimensionNumbers(
1092
- *map(tuple, (lhs_spec, rhs_spec, out_spec))
1093
- )
1085
+ *map(tuple, (lhs_spec, rhs_spec, out_spec)))
1094
1086
 
1095
1087
  if transposed:
1096
- rhs = jnp.flip(weight, range(2, 1+num_shape_dim))
1088
+ rhs = jnp.flip(weight, range(2, 1 + num_shape_dim))
1097
1089
  if groups != 1:
1098
1090
  # reshape filters for tranposed depthwise convolution
1099
1091
  assert rhs.shape[0] % groups == 0
1100
- rhs_shape = [rhs.shape[0]//groups, rhs.shape[1]*groups]
1092
+ rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups]
1101
1093
  rhs_shape.extend(rhs.shape[2:])
1102
1094
  rhs = jnp.reshape(rhs, rhs_shape)
1103
1095
  res = jax.lax.conv_general_dilated(
1104
- input,
1105
- rhs,
1106
- (1,) * len(stride),
1107
- make_padding(padding, len(stride)),
1108
- lhs_dilation=stride,
1109
- rhs_dilation=dilation,
1110
- dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1111
- feature_group_count=groups,
1112
- batch_group_count=1,
1096
+ input,
1097
+ rhs,
1098
+ (1,) * len(stride),
1099
+ make_padding(padding, len(stride)),
1100
+ lhs_dilation=stride,
1101
+ rhs_dilation=dilation,
1102
+ dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1103
+ feature_group_count=groups,
1104
+ batch_group_count=1,
1113
1105
  )
1114
1106
  else:
1115
1107
  res = jax.lax.conv_general_dilated(
1116
- input,
1117
- weight,
1118
- stride,
1119
- make_padding(padding, len(stride)),
1120
- lhs_dilation=(1,) * len(stride),
1121
- rhs_dilation=dilation,
1122
- dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1123
- feature_group_count=groups,
1124
- batch_group_count=1,
1108
+ input,
1109
+ weight,
1110
+ stride,
1111
+ make_padding(padding, len(stride)),
1112
+ lhs_dilation=(1,) * len(stride),
1113
+ rhs_dilation=dilation,
1114
+ dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
1115
+ feature_group_count=groups,
1116
+ batch_group_count=1,
1125
1117
  )
1126
1118
 
1127
1119
  if bias is not None:
@@ -1137,10 +1129,9 @@ def _aten_convolution(
1137
1129
 
1138
1130
 
1139
1131
  # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps)
1140
- @op(torch.ops.aten._native_batch_norm_legit)
1141
- def _aten__native_batch_norm_legit(
1142
- input, weight, bias, running_mean, running_var, training, momentum, eps
1143
- ):
1132
+ @op(torch.ops.aten._native_batch_norm_legit.default)
1133
+ def _aten__native_batch_norm_legit(input, weight, bias, running_mean,
1134
+ running_var, training, momentum, eps):
1144
1135
  """JAX implementation of batch normalization with optional parameters.
1145
1136
  Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.
1146
1137
 
@@ -1161,8 +1152,7 @@ def _aten__native_batch_norm_legit(
1161
1152
  DeviceArray: Reversed batch variance (C,) or empty if training is False
1162
1153
  """
1163
1154
  reduction_dims = [0] + list(range(2, input.ndim))
1164
- reshape_dims = [1, -1] + [1]*(input.ndim-2)
1165
-
1155
+ reshape_dims = [1, -1] + [1] * (input.ndim - 2)
1166
1156
  if training:
1167
1157
  # Calculate batch mean and variance
1168
1158
  mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
@@ -1175,7 +1165,9 @@ def _aten__native_batch_norm_legit(
1175
1165
  saved_rstd = jnp.squeeze(rstd, reduction_dims)
1176
1166
  else:
1177
1167
  rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
1178
- saved_mean = jnp.array([], dtype=input.dtype) # No need to calculate batch statistics in inference mode
1168
+ saved_mean = jnp.array(
1169
+ [], dtype=input.dtype
1170
+ ) # No need to calculate batch statistics in inference mode
1179
1171
  saved_rstd = jnp.array([], dtype=input.dtype)
1180
1172
 
1181
1173
  # Normalize
@@ -1190,19 +1182,17 @@ def _aten__native_batch_norm_legit(
1190
1182
  if weight is not None:
1191
1183
  x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting
1192
1184
  if bias is not None:
1193
- x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting
1185
+ x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting
1194
1186
 
1195
1187
  return x_hat, saved_mean, saved_rstd
1196
1188
 
1197
1189
 
1198
-
1199
1190
  @op(torch.ops.aten._native_batch_norm_legit_no_training)
1200
- def _aten__native_batch_norm_legit_no_training(
1201
- input, weight, bias, running_mean, running_var, momentum, eps
1202
- ):
1203
- return _aten__native_batch_norm_legit(
1204
- input, weight, bias, running_mean, running_var, False, momentum, eps
1205
- )
1191
+ def _aten__native_batch_norm_legit_no_training(input, weight, bias,
1192
+ running_mean, running_var,
1193
+ momentum, eps):
1194
+ return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
1195
+ running_var, False, momentum, eps)
1206
1196
 
1207
1197
 
1208
1198
  @op(torch.ops.aten.relu)
@@ -1212,7 +1202,15 @@ def _aten_relu(self):
1212
1202
 
1213
1203
  @op(torch.ops.aten.cat)
1214
1204
  def _aten_cat(tensors, dims=0):
1215
- return jnp.concatenate(tensors, dims)
1205
+ # handle empty tensors as a special case.
1206
+ # torch.cat will ignore the empty tensor, while jnp.concatenate
1207
+ # will error if the dims > 0.
1208
+ filtered_tensors = [
1209
+ t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0)
1210
+ ]
1211
+ if filtered_tensors:
1212
+ return jnp.concatenate(filtered_tensors, dims)
1213
+ return tensors[0]
1216
1214
 
1217
1215
 
1218
1216
  def _ceil_mode_padding(
@@ -1220,6 +1218,7 @@ def _ceil_mode_padding(
1220
1218
  input_shape: list[int],
1221
1219
  kernel_size: list[int],
1222
1220
  stride: list[int],
1221
+ dilation: list[int],
1223
1222
  ceil_mode: bool,
1224
1223
  ):
1225
1224
  """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode.
@@ -1232,20 +1231,13 @@ def _ceil_mode_padding(
1232
1231
  right_padding = left_padding
1233
1232
 
1234
1233
  input_size = input_shape[2 + i]
1235
- output_size_rem = (input_size + 2 * left_padding - kernel_size[i]) % stride[
1236
- i
1237
- ]
1234
+ output_size_rem = (input_size + 2 * left_padding -
1235
+ (kernel_size[i] - 1) * dilation[i] - 1) % stride[i]
1238
1236
  if ceil_mode and output_size_rem != 0:
1239
1237
  extra_padding = stride[i] - output_size_rem
1240
- new_output_size = (
1241
- input_size
1242
- + left_padding
1243
- + right_padding
1244
- + extra_padding
1245
- - kernel_size[i]
1246
- + stride[i]
1247
- - 1
1248
- ) // stride[i] + 1
1238
+ new_output_size = (input_size + left_padding + right_padding +
1239
+ extra_padding - (kernel_size[i] - 1) * dilation[i] -
1240
+ 1 + stride[i] - 1) // stride[i] + 1
1249
1241
  # Ensure that the last pooling starts inside the image.
1250
1242
  size_to_compare = input_size + left_padding
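
The reworked remainder above measures the window against the dilated ("effective") kernel extent `(kernel_size - 1) * dilation + 1`, and ceil mode only adds an extra window when that remainder is non-zero. A quick standalone PyTorch check of the ceil-mode behaviour this mirrors (dilation 1 for simplicity):

    import torch
    import torch.nn.functional as F

    # 5x5 input, kernel 2, stride 2, no padding: (5 - 2) % 2 != 0, so ceil mode
    # yields one extra output position per spatial dimension.
    x = torch.arange(25.0).reshape(1, 1, 5, 5)
    print(F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=False).shape)  # torch.Size([1, 1, 2, 2])
    print(F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True).shape)   # torch.Size([1, 1, 3, 3])
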
1251
1243
 
@@ -1258,30 +1250,36 @@ def _ceil_mode_padding(
1258
1250
 
1259
1251
  @op(torch.ops.aten.max_pool2d_with_indices)
1260
1252
  @op(torch.ops.aten.max_pool3d_with_indices)
1261
- def _aten_max_pool2d_with_indices(
1262
- inputs, kernel_size, strides, padding=0, dilation=1, ceil_mode=False
1263
- ):
1253
+ def _aten_max_pool2d_with_indices(inputs,
1254
+ kernel_size,
1255
+ strides=None,
1256
+ padding=0,
1257
+ dilation=1,
1258
+ ceil_mode=False):
1264
1259
  num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
1265
1260
  kernel_size = tuple(kernel_size)
1266
- strides = tuple(strides)
1261
+ # Default stride is kernel_size
1262
+ strides = tuple(strides) if strides else kernel_size
1267
1263
  if isinstance(padding, int):
1268
1264
  padding = [padding for _ in range(len(kernel_size))]
1265
+ if isinstance(dilation, int):
1266
+ dilation = tuple(dilation for _ in range(len(kernel_size)))
1267
+ elif isinstance(dilation, list):
1268
+ dilation = tuple(dilation)
1269
1269
 
1270
1270
  input_shape = inputs.shape
1271
1271
  if num_batch_dims == 0:
1272
1272
  input_shape = [1, *input_shape]
1273
- padding = _ceil_mode_padding(
1274
- padding, input_shape, kernel_size, strides, ceil_mode
1275
- )
1273
+ padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
1274
+ dilation, ceil_mode)
1276
1275
 
1277
- window_shape = kernel_size
1278
- num_batch_dims = inputs.ndim - (len(window_shape) + 1)
1279
- strides = strides or (1,) * len(window_shape)
1280
- assert len(window_shape) == len(
1281
- strides
1282
- ), f"len({window_shape}) must equal len({strides})"
1276
+ assert len(kernel_size) == len(
1277
+ strides), f"len({kernel_size=}) must equal len({strides=})"
1278
+ assert len(kernel_size) == len(
1279
+ dilation), f"len({kernel_size=}) must equal len({dilation=})"
1283
1280
  strides = (1,) * (1 + num_batch_dims) + strides
1284
- dims = (1,) * (1 + num_batch_dims) + window_shape
1281
+ dims = (1,) * (1 + num_batch_dims) + kernel_size
1282
+ dilation = (1,) * (1 + num_batch_dims) + dilation
1285
1283
 
1286
1284
  is_single_input = False
1287
1285
  if num_batch_dims == 0:
@@ -1290,26 +1288,27 @@ def _aten_max_pool2d_with_indices(
1290
1288
  inputs = inputs[None]
1291
1289
  strides = (1,) + strides
1292
1290
  dims = (1,) + dims
1291
+ dilation = (1,) + dilation
1293
1292
  is_single_input = True
1294
1293
 
1295
1294
  assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
1296
1295
  if not isinstance(padding, str):
1297
1296
  padding = tuple(map(tuple, padding))
1298
- assert len(padding) == len(window_shape), (
1299
- f"padding {padding} must specify pads for same number of dims as "
1300
- f"window_shape {window_shape}"
1301
- )
1302
- assert all(
1303
- [len(x) == 2 for x in padding]
1304
- ), f"each entry in padding {padding} must be length 2"
1297
+ assert len(padding) == len(kernel_size), (
1298
+ f"padding {padding} must specify pads for same number of dims as "
1299
+ f"kernel_size {kernel_size}")
1300
+ assert all([len(x) == 2 for x in padding
1301
+ ]), f"each entry in padding {padding} must be length 2"
1305
1302
  padding = ((0, 0), (0, 0)) + padding
1306
1303
 
1307
- indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape)
1304
+ indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size):]))
1305
+ indices = indices.reshape(inputs.shape[-len(kernel_size):])
1306
+ indices = jnp.broadcast_to(indices, inputs.shape)
1308
1307
 
1309
1308
  def reduce_fn(a, b):
1310
1309
  ai, av = a
1311
1310
  bi, bv = b
1312
- which = av > bv
1311
+ which = av >= bv # torch breaks ties in favor of later indices
1313
1312
  return jnp.where(which, ai, bi), jnp.where(which, av, bv)
1314
1313
 
1315
1314
  init_val = -jnp.inf
@@ -1317,44 +1316,90 @@ def _aten_max_pool2d_with_indices(
1317
1316
  init_val = -(1 << 31)
1318
1317
  init_val = jnp.array(init_val).astype(inputs.dtype)
1319
1318
 
1320
- # Separate maxpool result and indices into two reduce_window ops. Since
1321
- # the indices tensor is usually unused in inference, separating the two
1319
+ # Separate maxpool result and indices into two reduce_window ops. Since
1320
+ # the indices tensor is usually unused in inference, separating the two
1322
1321
  # can help DCE computations for argmax.
1323
1322
  y = jax.lax.reduce_window(
1324
- inputs, init_val, jax.lax.max, dims, strides, padding
1325
- )
1323
+ inputs,
1324
+ init_val,
1325
+ jax.lax.max,
1326
+ dims,
1327
+ strides,
1328
+ padding,
1329
+ window_dilation=dilation)
1326
1330
  indices, _ = jax.lax.reduce_window(
1327
- (indices, inputs), (0, init_val), reduce_fn, dims, strides, padding
1331
+ (indices, inputs),
1332
+ (0, init_val),
1333
+ reduce_fn,
1334
+ dims,
1335
+ strides,
1336
+ padding,
1337
+ window_dilation=dilation,
1328
1338
  )
1329
1339
  if is_single_input:
1330
1340
  indices = jnp.squeeze(indices, axis=0)
1331
1341
  y = jnp.squeeze(y, axis=0)
1332
-
1342
+
1333
1343
  return y, indices
1334
1344
 
1335
1345
 
1346
+ # Aten ops registered under the `xla` library.
1347
+ try:
1348
+
1349
+ @op(torch.ops.xla.max_pool2d_forward)
1350
+ def _xla_max_pool2d_forward(*args, **kwargs):
1351
+ return _aten_max_pool2d_with_indices(*args, **kwargs)[0]
1352
+
1353
+ @op(torch.ops.xla.aot_mark_sharding)
1354
+ def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str):
1355
+ from jax.sharding import PartitionSpec as P, NamedSharding
1356
+ import ast
1357
+ import torch_xla.distributed.spmd as xs
1358
+ pmesh = xs.Mesh.from_str(mesh)
1359
+ assert pmesh is not None
1360
+ partition_spec_eval = ast.literal_eval(partition_spec)
1361
+ jmesh = pmesh.get_jax_mesh()
1362
+ return jax.lax.with_sharding_constraint(
1363
+ t, NamedSharding(jmesh, P(*partition_spec_eval)))
1364
+
1365
+ @op(torch.ops.xla.einsum_linear_forward)
1366
+ def _xla_einsum_linear_forward(input, weight, bias):
1367
+ with jax.named_scope('einsum_linear_forward'):
1368
+ product = jax.numpy.einsum('...n,mn->...m', input, weight)
1369
+ if bias is not None:
1370
+ return product + bias
1371
+ return product
1372
+
1373
+ except AttributeError:
1374
+ pass
1375
+
1336
1376
  # TODO add more ops
1337
1377
 
1338
1378
 
1339
1379
  @op(torch.ops.aten.min)
1340
1380
  def _aten_min(x, dim=None, keepdim=False):
1341
1381
  if dim is not None:
1342
- return _with_reduction_scalar(jnp.min, x, dim, keepdim), _with_reduction_scalar(jnp.argmin, x, dim, keepdim).astype(jnp.int64)
1382
+ return _with_reduction_scalar(jnp.min, x, dim,
1383
+ keepdim), _with_reduction_scalar(
1384
+ jnp.argmin, x, dim,
1385
+ keepdim).astype(jnp.int64)
1343
1386
  else:
1344
1387
  return _with_reduction_scalar(jnp.min, x, dim, keepdim)
1345
1388
 
1346
1389
 
1347
1390
  @op(torch.ops.aten.mode)
1348
1391
  def _aten_mode(input, dim=-1, keepdim=False, *, out=None):
1349
- if input.ndim == 0: # single number
1392
+ if input.ndim == 0: # single number
1350
1393
  return input, jnp.array(0)
1351
- dim = (input.ndim + dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim
1394
+ dim = (input.ndim +
1395
+ dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim
1352
1396
  # keepdims must be True for accurate broadcasting
1353
1397
  mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True)
1354
1398
  mode_broadcast = jnp.broadcast_to(mode, input.shape)
1355
1399
  if not keepdim:
1356
1400
  mode = mode.squeeze(axis=dim)
1357
- indices = jnp.argmax(jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim)
1401
+ indices = jnp.argmax(
1402
+ jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim)
1358
1403
  return mode, indices
1359
1404
 
1360
1405
 
@@ -1388,8 +1433,7 @@ def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None):
1388
1433
  @op(torch.ops.prims.broadcast_in_dim)
1389
1434
  def _prims_broadcast_in_dim(t, shape, broadcast_dimensions):
1390
1435
  return jax.lax.broadcast_in_dim(
1391
- t, shape, broadcast_dimensions=broadcast_dimensions
1392
- )
1436
+ t, shape, broadcast_dimensions=broadcast_dimensions)
1393
1437
 
1394
1438
 
1395
1439
  # aten.native_group_norm -- should use decomp table
@@ -1432,17 +1476,15 @@ def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5):
1432
1476
  normalized = (x - mean) * rstd
1433
1477
  return normalized, mean, rstd
1434
1478
 
1435
- normalized, group_mean, group_rstd = jax.lax.map(
1436
- group_norm_body, reshaped_input
1437
- )
1479
+ normalized, group_mean, group_rstd = jax.lax.map(group_norm_body,
1480
+ reshaped_input)
1438
1481
 
1439
1482
  # Reshape back to original input shape
1440
1483
  output = jnp.reshape(normalized, input_shape)
1441
1484
 
1442
1485
  # **Affine transformation**
1443
- affine_shape = [
1444
- -1 if i == 1 else 1 for i in range(input.ndim)
1445
- ] # Shape for broadcasting
1486
+ affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim)
1487
+ ] # Shape for broadcasting
1446
1488
  if weight is not None and bias is not None:
1447
1489
  output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape)
1448
1490
  elif weight is not None:
@@ -1474,22 +1516,25 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
1474
1516
  The tensor containing the calculated vector norms.
1475
1517
  """
1476
1518
 
1477
- if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance(ord, (int, float)):
1519
+ if ord not in {2, float("inf"), float("-inf"), "fro"
1520
+ } and not isinstance(ord, (int, float)):
1478
1521
  raise ValueError(
1479
- f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
1480
- " 'fro'."
1481
- )
1482
-
1522
+ f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and"
1523
+ " 'fro'.")
1524
+
1483
1525
  # Special cases (for efficiency and clarity)
1484
1526
  if ord == 0:
1485
1527
  if self.shape == ():
1486
1528
  # float sets it to float64. set it back to input type
1487
1529
  result = jnp.astype(jnp.array(float(self != 0)), self.dtype)
1488
1530
  else:
1489
- result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim)
1531
+ result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim,
1532
+ keepdim)
1490
1533
 
1491
1534
  elif ord == 2: # Euclidean norm
1492
- result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))
1535
+ result = jnp.sqrt(
1536
+ _with_reduction_scalar(jnp.sum,
1537
+ jnp.abs(self)**2, dim, keepdim))
1493
1538
 
1494
1539
  elif ord == float("inf"):
1495
1540
  result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim)
@@ -1498,12 +1543,14 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):
1498
1543
  result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim)
1499
1544
 
1500
1545
  elif ord == "fro": # Frobenius norm
1501
- result = jnp.sqrt(_with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim))
1546
+ result = jnp.sqrt(
1547
+ _with_reduction_scalar(jnp.sum,
1548
+ jnp.abs(self)**2, dim, keepdim))
1502
1549
 
1503
1550
  else: # General case (e.g., ord = 1, ord = 3)
1504
- result = _with_reduction_scalar(jnp.sum, jnp.abs(self) ** ord, dim, keepdim) ** (
1505
- 1.0 / ord
1506
- )
1551
+ result = _with_reduction_scalar(jnp.sum,
1552
+ jnp.abs(self)**ord, dim,
1553
+ keepdim)**(1.0 / ord)
1507
1554
 
1508
1555
  # (Optional) dtype conversion
1509
1556
  if dtype is not None:
@@ -1539,9 +1586,12 @@ def _aten_sinh(self):
1539
1586
 
1540
1587
  # aten.native_layer_norm_backward
1541
1588
  @op(torch.ops.aten.native_layer_norm_backward)
1542
- def _aten_native_layer_norm_backward(
1543
- grad_out, input, normalized_shape, weight, bias, eps=1e-5
1544
- ):
1589
+ def _aten_native_layer_norm_backward(grad_out,
1590
+ input,
1591
+ normalized_shape,
1592
+ weight,
1593
+ bias,
1594
+ eps=1e-5):
1545
1595
  """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`.
1546
1596
 
1547
1597
  Args:
@@ -1555,9 +1605,8 @@ def _aten_native_layer_norm_backward(
1555
1605
  Returns:
1556
1606
  A tuple of (grad_input, grad_weight, grad_bias).
1557
1607
  """
1558
- return jax.lax.native_layer_norm_backward(
1559
- grad_out, input, normalized_shape, weight, bias, eps
1560
- )
1608
+ return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape,
1609
+ weight, bias, eps)
1561
1610
 
1562
1611
 
1563
1612
  # aten.reflection_pad3d_backward
@@ -1585,12 +1634,14 @@ def _aten_bitwise_not(self):
1585
1634
 
1586
1635
 
1587
1636
  # aten.bitwise_left_shift
1637
+ @op(torch.ops.aten.__lshift__)
1588
1638
  @op(torch.ops.aten.bitwise_left_shift)
1589
1639
  def _aten_bitwise_left_shift(input, other):
1590
1640
  return jnp.left_shift(input, other)
1591
1641
 
1592
1642
 
1593
1643
  # aten.bitwise_right_shift
1644
+ @op(torch.ops.aten.__rshift__)
1594
1645
  @op(torch.ops.aten.bitwise_right_shift)
1595
1646
  def _aten_bitwise_right_shift(input, other):
1596
1647
  return jnp.right_shift(input, other)
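
The two added registrations route the `<<` and `>>` operators (`aten.__lshift__` / `aten.__rshift__`) to the same jnp shift functions as their `bitwise_*` counterparts; on plain integer inputs the two backends agree:

    import jax.numpy as jnp
    import torch

    print(torch.tensor([1, 2, 3]) << 2)             # tensor([ 4,  8, 12])
    print(jnp.left_shift(jnp.array([1, 2, 3]), 2))  # [ 4  8 12]
    print(torch.tensor([16, 8]) >> 3)               # tensor([2, 1])
    print(jnp.right_shift(jnp.array([16, 8]), 3))   # [2 1]
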
@@ -1671,10 +1722,8 @@ def _scatter_index(dim, index):
1671
1722
  target_shape = [1] * len(index_shape)
1672
1723
  target_shape[i] = index_shape[i]
1673
1724
  input_indexes.append(
1674
- jnp.broadcast_to(
1675
- jnp.arange(index_shape[i]).reshape(target_shape), index_shape
1676
- )
1677
- )
1725
+ jnp.broadcast_to(
1726
+ jnp.arange(index_shape[i]).reshape(target_shape), index_shape))
1678
1727
  return tuple(input_indexes), tuple(source_indexes)
1679
1728
 
1680
1729
 
@@ -1686,6 +1735,7 @@ def _aten_scatter_add(input, dim, index, src):
1686
1735
  input_indexes, source_indexes = _scatter_index(dim, index)
1687
1736
  return input.at[input_indexes].add(src[source_indexes])
1688
1737
 
1738
+
1689
1739
  # aten.masked_scatter
1690
1740
  @op(torch.ops.aten.masked_scatter)
1691
1741
  def _aten_masked_scatter(self, mask, source):
@@ -1707,6 +1757,7 @@ def _aten_masked_scatter(self, mask, source):
1707
1757
 
1708
1758
  return final_arr
1709
1759
 
1760
+
1710
1761
  @op(torch.ops.aten.masked_select)
1711
1762
  def _aten_masked_select(self, mask, *args, **kwargs):
1712
1763
  broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape)
@@ -1722,6 +1773,7 @@ def _aten_masked_select(self, mask, *args, **kwargs):
1722
1773
 
1723
1774
  return self_flat[true_indices]
1724
1775
 
1776
+
1725
1777
  # aten.logical_not
1726
1778
 
1727
1779
 
@@ -1730,11 +1782,13 @@ def _aten_masked_select(self, mask, *args, **kwargs):
1730
1782
  def _aten_sign(x):
1731
1783
  return jnp.sign(x)
1732
1784
 
1785
+
1733
1786
  # aten.signbit
1734
1787
  @op(torch.ops.aten.signbit)
1735
1788
  def _aten_signbit(x):
1736
1789
  return jnp.signbit(x)
1737
1790
 
1791
+
1738
1792
  # aten.sigmoid
1739
1793
  @op(torch.ops.aten.sigmoid)
1740
1794
  @op_base.promote_int_input
@@ -1760,7 +1814,13 @@ def _aten_atan(self):
1760
1814
 
1761
1815
  @op(torch.ops.aten.scatter_reduce)
1762
1816
  @op(torch.ops.aten.scatter)
1763
- def _aten_scatter_reduce(input, dim, index, src, reduce=None, *, include_self=True):
1817
+ def _aten_scatter_reduce(input,
1818
+ dim,
1819
+ index,
1820
+ src,
1821
+ reduce=None,
1822
+ *,
1823
+ include_self=True):
1764
1824
  if not isinstance(src, jnp.ndarray):
1765
1825
  src = jnp.array(src, dtype=input.dtype)
1766
1826
  input_indexes, source_indexes = _scatter_index(dim, index)
@@ -1817,41 +1877,6 @@ def _aten_gt(self, other):
1817
1877
  return self > other
1818
1878
 
1819
1879
 
1820
- # aten.pixel_shuffle
1821
- @op(torch.ops.aten.pixel_shuffle)
1822
- def _aten_pixel_shuffle(x, upscale_factor):
1823
- """PixelShuffle implementation in JAX.
1824
-
1825
- Args:
1826
- x: Input tensor. Typically a feature map.
1827
- upscale_factor: Integer by which to upscale the spatial dimensions.
1828
-
1829
- Returns:
1830
- Tensor after PixelShuffle operation.
1831
- """
1832
-
1833
- batch_size, channels, height, width = x.shape
1834
-
1835
- if channels % (upscale_factor**2) != 0:
1836
- raise ValueError(
1837
- "Number of channels must be divisible by the square of the upscale factor."
1838
- )
1839
-
1840
- new_channels = channels // (upscale_factor**2)
1841
- new_height = height * upscale_factor
1842
- new_width = width * upscale_factor
1843
-
1844
- x = x.reshape(
1845
- batch_size, new_channels, upscale_factor, upscale_factor, height, width
1846
- )
1847
- x = jnp.transpose(
1848
- x, (0, 1, 2, 4, 3, 5)
1849
- ) # Move channels to spatial dimensions
1850
- x = x.reshape(batch_size, new_channels, new_height, new_width)
1851
-
1852
- return x
1853
-
1854
-
1855
1880
  # aten.sym_stride
1856
1881
  # aten.lt
1857
1882
  @op(torch.ops.aten.lt)
@@ -1883,8 +1908,7 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
1883
1908
  num_batch_dims = inputs.ndim - (len(window_shape) + 1)
1884
1909
  strides = strides or (1,) * len(window_shape)
1885
1910
  assert len(window_shape) == len(
1886
- strides
1887
- ), f"len({window_shape}) must equal len({strides})"
1911
+ strides), f"len({window_shape}) must equal len({strides})"
1888
1912
  strides = (1,) * (1 + num_batch_dims) + strides
1889
1913
  dims = (1,) * (1 + num_batch_dims) + window_shape
1890
1914
 
@@ -1901,23 +1925,22 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
1901
1925
  if not isinstance(padding, str):
1902
1926
  padding = tuple(map(tuple, padding))
1903
1927
  assert len(padding) == len(window_shape), (
1904
- f"padding {padding} must specify pads for same number of dims as "
1905
- f"window_shape {window_shape}"
1906
- )
1907
- assert all(
1908
- [len(x) == 2 for x in padding]
1909
- ), f"each entry in padding {padding} must be length 2"
1928
+ f"padding {padding} must specify pads for same number of dims as "
1929
+ f"window_shape {window_shape}")
1930
+ assert all([len(x) == 2 for x in padding
1931
+ ]), f"each entry in padding {padding} must be length 2"
1910
1932
  padding = ((0, 0), (0, 0)) + padding
1911
1933
  y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
1912
1934
  if is_single_input:
1913
1935
  y = jnp.squeeze(y, axis=0)
1914
1936
  return y
1915
1937
 
1916
-
1938
+
1917
1939
  @op(torch.ops.aten._adaptive_avg_pool2d)
1918
1940
  @op(torch.ops.aten._adaptive_avg_pool3d)
1919
- def adaptive_avg_pool2or3d(input: jnp.ndarray, output_size: Tuple[int, int]) -> jnp.ndarray:
1920
- """
1941
+ def adaptive_avg_pool2or3d(input: jnp.ndarray,
1942
+ output_size: Tuple[int, int]) -> jnp.ndarray:
1943
+ """
1921
1944
  Applies a 2/3D adaptive average pooling over an input signal composed of several input planes.
1922
1945
 
1923
1946
  See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.
@@ -1929,124 +1952,128 @@ def adaptive_avg_pool2or3d(input: jnp.ndarray, output_size: Tuple[int, int]) ->
1929
1952
  Context:
1930
1953
  https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401
1931
1954
  """
1932
- shape = input.shape
1933
- ndim = len(shape)
1934
- out_dim = len(output_size)
1935
- num_spatial_dim = ndim - out_dim
1936
-
1937
- # Preconditions
1938
-
1939
- assert ndim in (out_dim + 1, out_dim + 2), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}"
1940
- for d in input.shape[-2:]:
1941
- assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \
1942
- f"non-batch dimensions, but input has shape {tuple(shape)}."
1943
-
1944
- # Optimisation (we should also do this in the kernel implementation)
1945
- if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])):
1946
- stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size))
1947
- kernel = tuple(i - (o - 1) * s for i, o, s in zip(shape[-out_dim:], output_size, stride))
1948
- return _aten_avg_pool(
1949
- input,
1950
- kernel,
1951
- strides=stride,
1952
- )
1953
-
1954
- def start_index(a, b, c):
1955
- return (a * c) // b
1956
-
1957
- def end_index(a, b, c):
1958
- return ((a + 1) * c + b - 1) // b
1959
-
1960
- def compute_idx(in_size, out_size):
1961
- orange = jnp.arange(out_size, dtype=jnp.int64)
1962
- i0 = start_index(orange, out_size, in_size)
1963
- # Let length = end_index - start_index, i.e. the length of the pooling kernels
1964
- # length.max() can be computed analytically as follows:
1965
- maxlength = in_size // out_size + 1
1966
- in_size_mod = in_size % out_size
1967
- # adaptive = True iff there are kernels with different lengths
1968
- adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
1969
- if adaptive:
1970
- maxlength += 1
1971
- elif in_size_mod == 0:
1972
- maxlength -= 1
1973
-
1974
- range_max = jnp.arange(maxlength, dtype=jnp.int64)
1975
- idx = i0[:, None] + range_max
1976
- if adaptive:
1977
- # Need to clamp to avoid accessing out-of-bounds memory
1978
- idx = jnp.minimum(idx, in_size - 1)
1979
-
1980
- # Compute the length
1981
- i1 = end_index(orange, out_size, in_size)
1982
- length = i1 - i0
1983
- else:
1984
- length = maxlength
1985
- return idx, length, range_max, adaptive
1986
-
1987
- idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)]
1988
- # length is not None if it's constant, otherwise we'll need to compute it
1989
- for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)):
1990
- idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o)
1991
-
1992
- def _unsqueeze_to_dim(x, dim):
1993
- ndim = len(x.shape)
1994
- return jax.lax.expand_dims(x, tuple(range(ndim, dim)))
1995
-
1996
- if out_dim == 2:
1997
- # NOTE: unsqueeze to insert extra 1 in ranks; so they
1998
- # would broadcast
1999
- vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]]
2000
- reduce_axis = (-3, -1)
1955
+ shape = input.shape
1956
+ ndim = len(shape)
1957
+ out_dim = len(output_size)
1958
+ num_spatial_dim = ndim - out_dim
1959
+
1960
+ # Preconditions
1961
+
1962
+ assert ndim in (
1963
+ out_dim + 1, out_dim + 2
1964
+ ), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}"
1965
+ for d in input.shape[-2:]:
1966
+ assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \
1967
+ f"non-batch dimensions, but input has shape {tuple(shape)}."
1968
+
1969
+ # Optimisation (we should also do this in the kernel implementation)
1970
+ if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])):
1971
+ stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size))
1972
+ kernel = tuple(i - (o - 1) * s
1973
+ for i, o, s in zip(shape[-out_dim:], output_size, stride))
1974
+ return _aten_avg_pool(
1975
+ input,
1976
+ kernel,
1977
+ strides=stride,
1978
+ )
1979
+
1980
+ def start_index(a, b, c):
1981
+ return (a * c) // b
1982
+
1983
+ def end_index(a, b, c):
1984
+ return ((a + 1) * c + b - 1) // b
1985
+
1986
+ def compute_idx(in_size, out_size):
1987
+ orange = jnp.arange(out_size, dtype=jnp.int64)
1988
+ i0 = start_index(orange, out_size, in_size)
1989
+ # Let length = end_index - start_index, i.e. the length of the pooling kernels
1990
+ # length.max() can be computed analytically as follows:
1991
+ maxlength = in_size // out_size + 1
1992
+ in_size_mod = in_size % out_size
1993
+ # adaptive = True iff there are kernels with different lengths
1994
+ adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
1995
+ if adaptive:
1996
+ maxlength += 1
1997
+ elif in_size_mod == 0:
1998
+ maxlength -= 1
1999
+
2000
+ range_max = jnp.arange(maxlength, dtype=jnp.int64)
2001
+ idx = i0[:, None] + range_max
2002
+ if adaptive:
2003
+ # Need to clamp to avoid accessing out-of-bounds memory
2004
+ idx = jnp.minimum(idx, in_size - 1)
2005
+
2006
+ # Compute the length
2007
+ i1 = end_index(orange, out_size, in_size)
2008
+ length = i1 - i0
2009
+ else:
2010
+ length = maxlength
2011
+ return idx, length, range_max, adaptive
2012
+
2013
+ idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)]
2014
+ # length is not None if it's constant, otherwise we'll need to compute it
2015
+ for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)):
2016
+ idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o)
2017
+
2018
+ def _unsqueeze_to_dim(x, dim):
2019
+ ndim = len(x.shape)
2020
+ return jax.lax.expand_dims(x, tuple(range(ndim, dim)))
2021
+
2022
+ if out_dim == 2:
2023
+ # NOTE: unsqueeze to insert extra 1 in ranks; so they
2024
+ # would broadcast
2025
+ vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]]
2026
+ reduce_axis = (-3, -1)
2027
+ else:
2028
+ assert out_dim == 3
2029
+ vals = input[...,
2030
+ _unsqueeze_to_dim(idx[0], 6),
2031
+ _unsqueeze_to_dim(idx[1], 4), idx[2]]
2032
+ reduce_axis = (-5, -3, -1)
2033
+
2034
+ # Shortcut for the simpler case
2035
+ if not any(adaptive):
2036
+ return jnp.mean(vals, axis=reduce_axis)
2037
+
2038
+ def maybe_mask(vals, length, range_max, adaptive, dim):
2039
+ if isinstance(length, int):
2040
+ return vals, length
2001
2041
  else:
2002
- assert out_dim == 3
2003
- vals = input[..., _unsqueeze_to_dim(idx[0], 6),
2004
- _unsqueeze_to_dim(idx[1], 4),
2005
- idx[2]]
2006
- reduce_axis = (-5, -3, -1)
2007
-
2008
- # Shortcut for the simpler case
2009
- if not any(adaptive):
2010
- return jnp.mean(vals, axis=reduce_axis)
2011
-
2012
- def maybe_mask(vals, length, range_max, adaptive, dim):
2013
- if isinstance(length, int):
2014
- return vals, length
2015
- else:
2016
- # zero-out the things we didn't really want to select
2017
- assert dim < 0
2018
- # hack
2019
- mask = range_max >= length[:, None]
2020
- if dim == -2:
2021
- mask = _unsqueeze_to_dim(mask, 4)
2022
- elif dim == -3:
2023
- mask = _unsqueeze_to_dim(mask, 6)
2024
- vals = jnp.where(mask, 0.0, vals)
2025
- # Compute the length of each window
2026
- length = _unsqueeze_to_dim(length, -dim)
2027
- return vals, length
2028
-
2029
- for i in range(len(length)):
2030
- vals, length[i] = maybe_mask(vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
2031
-
2032
- # We unroll the sum as we assume that the kernels are going to be small
2033
- ret = jnp.sum(vals, axis=reduce_axis)
2034
- # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
2035
- # this is multiplication with broadcasting, not regular pointwise product
2036
- return ret / math.prod(length)
2037
-
2042
+ # zero-out the things we didn't really want to select
2043
+ assert dim < 0
2044
+ # hack
2045
+ mask = range_max >= length[:, None]
2046
+ if dim == -2:
2047
+ mask = _unsqueeze_to_dim(mask, 4)
2048
+ elif dim == -3:
2049
+ mask = _unsqueeze_to_dim(mask, 6)
2050
+ vals = jnp.where(mask, 0.0, vals)
2051
+ # Compute the length of each window
2052
+ length = _unsqueeze_to_dim(length, -dim)
2053
+ return vals, length
2054
+
2055
+ for i in range(len(length)):
2056
+ vals, length[i] = maybe_mask(
2057
+ vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim))
2058
+
2059
+ # We unroll the sum as we assume that the kernels are going to be small
2060
+ ret = jnp.sum(vals, axis=reduce_axis)
2061
+ # NOTE: math.prod because we want to expand it to length[0] * length[1] * ...
2062
+ # this is multiplication with broadcasting, not regular pointwise product
2063
+ return ret / math.prod(length)
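A small worked example of the start/end index rule used by compute_idx above, pooling an input of size 10 down to 3 outputs (plain arithmetic, no torchax code involved):

in_size, out_size = 10, 3
starts = [(i * in_size) // out_size for i in range(out_size)]                     # [0, 3, 6]
ends = [((i + 1) * in_size + out_size - 1) // out_size for i in range(out_size)]  # [4, 7, 10]
# All three windows have length 4, so maxlength == in_size // out_size + 1 == 4,
# adaptive stays False, and the simple jnp.mean shortcut applies.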
2064
+
2038
2065
 
2039
2066
  @op(torch.ops.aten.avg_pool1d)
2040
2067
  @op(torch.ops.aten.avg_pool2d)
2041
2068
  @op(torch.ops.aten.avg_pool3d)
2042
2069
  def _aten_avg_pool(
2043
- inputs,
2044
- kernel_size,
2045
- strides=None,
2046
- padding=0,
2047
- ceil_mode=False,
2048
- count_include_pad=True,
2049
- divisor_override=None,
2070
+ inputs,
2071
+ kernel_size,
2072
+ strides=None,
2073
+ padding=0,
2074
+ ceil_mode=False,
2075
+ count_include_pad=True,
2076
+ divisor_override=None,
2050
2077
  ):
2051
2078
  num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
2052
2079
  kernel_size = tuple(kernel_size)
@@ -2060,7 +2087,7 @@ def _aten_avg_pool(
2060
2087
  if num_batch_dims == 0:
2061
2088
  input_shape = [1, *input_shape]
2062
2089
  padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides,
2063
- ceil_mode)
2090
+ [1] * len(kernel_size), ceil_mode)
2064
2091
 
2065
2092
  y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding)
2066
2093
  if divisor_override is not None:
@@ -2102,9 +2129,11 @@ def _aten_avg_pool(
2102
2129
  )
2103
2130
  return y.astype(inputs.dtype)
2104
2131
 
2132
+
2105
2133
  # helper function to generate all indices to iterate through ndarray
2106
- def _generate_indices(dims, skip_dim_indices = []):
2134
+ def _generate_indices(dims, skip_dim_indices=[]):
2107
2135
  res = []
2136
+
2108
2137
  def _helper(curr_dim_idx, sofar):
2109
2138
  if curr_dim_idx in skip_dim_indices:
2110
2139
  _helper(curr_dim_idx + 1, sofar[:])
@@ -2115,10 +2144,11 @@ def _generate_indices(dims, skip_dim_indices = []):
2115
2144
  for i in range(dims[curr_dim_idx]):
2116
2145
  sofar[curr_dim_idx] = i
2117
2146
  _helper(curr_dim_idx + 1, sofar[:])
2118
-
2147
+
2119
2148
  _helper(0, [0 for _ in dims])
2120
2149
  return res
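Assuming the elided base case of _helper appends sofar once every dimension index has been assigned (which the recursion above suggests), the helper enumerates coordinate tuples like this; a hypothetical call pattern, with skipped dimensions left at 0 for the caller to fill in:

all_idx = _generate_indices((2, 3))                        # [0, 0], [0, 1], ..., [1, 2]
partial = _generate_indices((2, 3), skip_dim_indices=[1])  # [0, 0] and [1, 0] only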
2121
2150
 
2151
+
2122
2152
  # aten.sym_numel
2123
2153
  # aten.reciprocal
2124
2154
  @op(torch.ops.aten.reciprocal)
@@ -2174,10 +2204,14 @@ def _aten_round(input, decimals=0):
2174
2204
  @op(torch.ops.aten.max)
2175
2205
  def _aten_max(self, dim=None, keepdim=False):
2176
2206
  if dim is not None:
2177
- return _with_reduction_scalar(jnp.max, self, dim, keepdim), _with_reduction_scalar(jnp.argmax, self, dim, keepdim).astype(jnp.int64)
2207
+ return _with_reduction_scalar(jnp.max, self, dim,
2208
+ keepdim), _with_reduction_scalar(
2209
+ jnp.argmax, self, dim,
2210
+ keepdim).astype(jnp.int64)
2178
2211
  else:
2179
2212
  return _with_reduction_scalar(jnp.max, self, dim, keepdim)
2180
2213
 
2214
+
2181
2215
  # aten.maximum
2182
2216
  @op(torch.ops.aten.maximum)
2183
2217
  def _aten_maximum(self, other):
@@ -2216,27 +2250,28 @@ def _with_reduction_scalar(jax_func, self, dim, keepdim):
2216
2250
  def _aten_any(self, dim=None, keepdim=False):
2217
2251
  return _with_reduction_scalar(jnp.any, self, dim, keepdim)
2218
2252
 
2253
+
2219
2254
  # aten.arange
2220
2255
  @op(torch.ops.aten.arange.start_step)
2221
2256
  @op(torch.ops.aten.arange.start)
2222
2257
  @op(torch.ops.aten.arange.default)
2223
2258
  @op_base.convert_dtype(use_default_dtype=False)
2224
2259
  def _aten_arange(
2225
- start,
2226
- end=None,
2227
- step=None,
2228
- *,
2229
- dtype=None,
2230
- layout=None,
2231
- requires_grad=False,
2232
- device=None,
2233
- pin_memory=False,
2260
+ start,
2261
+ end=None,
2262
+ step=None,
2263
+ *,
2264
+ dtype=None,
2265
+ layout=None,
2266
+ requires_grad=False,
2267
+ device=None,
2268
+ pin_memory=False,
2234
2269
  ):
2235
2270
  return jnp.arange(
2236
- op_base.maybe_convert_constant_dtype(start, dtype),
2237
- op_base.maybe_convert_constant_dtype(end, dtype),
2238
- op_base.maybe_convert_constant_dtype(step, dtype),
2239
- dtype=dtype,
2271
+ op_base.maybe_convert_constant_dtype(start, dtype),
2272
+ op_base.maybe_convert_constant_dtype(end, dtype),
2273
+ op_base.maybe_convert_constant_dtype(step, dtype),
2274
+ dtype=dtype,
2240
2275
  )
2241
2276
 
2242
2277
 
@@ -2245,6 +2280,7 @@ def _aten_arange(
2245
2280
  def _aten_argmax(self, dim=None, keepdim=False):
2246
2281
  return _with_reduction_scalar(jnp.argmax, self, dim, keepdim)
2247
2282
 
2283
+
2248
2284
  def _strided_index(sizes, strides, storage_offset=None):
2249
2285
  ind = jnp.zeros(sizes, dtype=jnp.int32)
2250
2286
 
@@ -2257,6 +2293,7 @@ def _strided_index(sizes, strides, storage_offset=None):
2257
2293
  ind += storage_offset
2258
2294
  return ind
2259
2295
 
2296
+
2260
2297
  # aten.as_strided
2261
2298
  @op(torch.ops.aten.as_strided)
2262
2299
  @op(torch.ops.aten.as_strided_copy)
@@ -2311,7 +2348,7 @@ def _aten_broadcast_tensors(*tensors):
2311
2348
  Args:
2312
2349
  shapes: A list of tuples representing the shapes of the input tensors.
2313
2350
 
2314
- Returns:
2351
+ Returns:
2315
2352
  A tuple representing the broadcasted output shape.
2316
2353
  """
2317
2354
 
@@ -2344,11 +2381,13 @@ def _aten_broadcast_tensors(*tensors):
2344
2381
  A tuple specifying which dimensions of the input tensor should be broadcasted.
2345
2382
  """
2346
2383
 
2347
- res = tuple(i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)))
2384
+ res = tuple(
2385
+ i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)))
2348
2386
  return res
2349
2387
 
2350
2388
  # clean some function's previous wrap
2351
- if len(tensors)==1 and len(tensors[0])>=1 and isinstance(tensors[0][0], jax.Array):
2389
+ if len(tensors) == 1 and len(tensors[0]) >= 1 and isinstance(
2390
+ tensors[0][0], jax.Array):
2352
2391
  tensors = tensors[0]
2353
2392
 
2354
2393
  # Get the shapes of all input tensors
@@ -2357,7 +2396,8 @@ def _aten_broadcast_tensors(*tensors):
2357
2396
  output_shape = _get_broadcast_shape(shapes)
2358
2397
  # Broadcast each tensor to the output shape
2359
2398
  broadcasted_tensors = [
2360
- jax.lax.broadcast_in_dim(t, output_shape, _broadcast_dimensions(t.shape, output_shape))
2399
+ jax.lax.broadcast_in_dim(t, output_shape,
2400
+ _broadcast_dimensions(t.shape, output_shape))
2361
2401
  for t in tensors
2362
2402
  ]
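jax.lax.broadcast_in_dim, used above, needs an explicit mapping from each input axis to an output axis. A minimal standalone example of that mapping (not torchax code):

import jax
import jax.numpy as jnp

x = jnp.ones((3, 1))
y = jax.lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(1, 2))
# x's axis 0 (size 3) lands on output axis 1; its size-1 axis broadcasts to 4,
# so y.shape == (2, 3, 4).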
2363
2403
 
@@ -2376,6 +2416,7 @@ def _aten_broadcast_to(input, shape):
2376
2416
  def _aten_clamp(self, min=None, max=None):
2377
2417
  return jnp.clip(self, min, max)
2378
2418
 
2419
+
2379
2420
  @op(torch.ops.aten.clamp_min)
2380
2421
  def _aten_clamp_min(input, min):
2381
2422
  return jnp.clip(input, min=min)
@@ -2394,7 +2435,7 @@ def _aten_constant_pad_nd(input, padding, value=0):
2394
2435
  rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)]
2395
2436
  pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding)
2396
2437
  value_casted = jax.numpy.array(value, dtype=input.dtype)
2397
- return jax.lax.pad(input, padding_value=value_casted, padding_config = pad_dim)
2438
+ return jax.lax.pad(input, padding_value=value_casted, padding_config=pad_dim)
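torch's constant_pad_nd takes a flat pad list ordered from the last dimension backwards, while jax.lax.pad wants one (low, high, interior) triple per axis with leading axes first; that is the conversion rev_padding/pad_dim performs above. A small concrete instance, illustrative only:

import jax
import jax.numpy as jnp

x = jnp.zeros((2, 3))
# torch padding [1, 2] pads only the last dim: 1 before, 2 after.
pad_dim = [(0, 0, 0), (1, 2, 0)]   # what the wrapper builds for a 2-D input
y = jax.lax.pad(x, jnp.array(0.0, dtype=x.dtype), pad_dim)
# y.shape == (2, 6)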
2398
2439
 
2399
2440
 
2400
2441
  # aten.convolution_backward
@@ -2421,9 +2462,8 @@ def _aten_cdist_forward(x1, x2, p, compute_mode=""):
2421
2462
  @op(torch.ops.aten._pdist_forward)
2422
2463
  def _aten__pdist_forward(x, p=2):
2423
2464
  pairwise_dists = _aten_cdist_forward(x, x, p)
2424
- condensed_dists = pairwise_dists[
2425
- jnp.triu_indices(pairwise_dists.shape[0], k=1)
2426
- ]
2465
+ condensed_dists = pairwise_dists[jnp.triu_indices(
2466
+ pairwise_dists.shape[0], k=1)]
2427
2467
  return condensed_dists
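pdist keeps only the strict upper triangle of the full pairwise-distance matrix, which is exactly what the triu_indices lookup above does. A quick standalone check with illustrative values:

import jax.numpy as jnp

x = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
pairwise = jnp.linalg.norm(x[:, None] - x[None, :], axis=-1)
condensed = pairwise[jnp.triu_indices(pairwise.shape[0], k=1)]
# condensed == [5., 10., 5.]  (the three distinct pair distances, row-major order)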
2428
2468
 
2429
2469
 
@@ -2449,25 +2489,33 @@ def _aten_cosh(input):
2449
2489
  return jnp.cosh(input)
2450
2490
 
2451
2491
 
2492
+ @op(torch.ops.aten.diag)
2493
+ def _aten_diag(input, diagonal=0):
2494
+ return jnp.diag(input, diagonal)
2495
+
2496
+
2452
2497
  # aten.diagonal
2453
2498
  @op(torch.ops.aten.diagonal)
2499
+ @op(torch.ops.aten.diagonal_copy)
2454
2500
  def _aten_diagonal(input, offset=0, dim1=0, dim2=1):
2455
2501
  return jnp.diagonal(input, offset, dim1, dim2)
2456
2502
 
2457
2503
 
2458
2504
  def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1):
2459
- input_len = len(input_shape)
2460
- if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len):
2461
- raise ValueError("dim1 and dim2 must be different and in range [0, " + str(input_len-1)+ "]")
2462
-
2463
- size1, size2 = input_shape[dim1], input_shape[dim2]
2464
- if offset >= 0:
2465
- indices1 = jnp.arange(min(size1, size2 - offset))
2466
- indices2 = jnp.arange(offset, offset + len(indices1))
2467
- else:
2468
- indices2 = jnp.arange(min(size1 + offset, size2 ))
2469
- indices1 = jnp.arange(-offset, -offset + len(indices2))
2470
- return [indices1, indices2]
2505
+ input_len = len(input_shape)
2506
+ if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len):
2507
+ raise ValueError("dim1 and dim2 must be different and in range [0, " +
2508
+ str(input_len - 1) + "]")
2509
+
2510
+ size1, size2 = input_shape[dim1], input_shape[dim2]
2511
+ if offset >= 0:
2512
+ indices1 = jnp.arange(min(size1, size2 - offset))
2513
+ indices2 = jnp.arange(offset, offset + len(indices1))
2514
+ else:
2515
+ indices2 = jnp.arange(min(size1 + offset, size2))
2516
+ indices1 = jnp.arange(-offset, -offset + len(indices2))
2517
+ return [indices1, indices2]
2518
+
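diag_indices_with_offset above returns the row and column coordinates of a possibly offset diagonal, which diagonal_scatter then writes into. For a 4x3 matrix and offset=-1 it yields rows [1, 2, 3] and columns [0, 1, 2]; equivalently, in plain JAX:

import jax.numpy as jnp

rows, cols = jnp.arange(1, 4), jnp.arange(0, 3)
a = jnp.zeros((4, 3)).at[rows, cols].set(1.0)   # ones on the first sub-diagonal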
2471
2519
 
2472
2520
  @op(torch.ops.aten.diagonal_scatter)
2473
2521
  def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
@@ -2476,17 +2524,17 @@ def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
2476
2524
  if input.ndim == 2:
2477
2525
  return input.at[tuple(indexes)].set(src)
2478
2526
  else:
2479
- # src has the same shape as the output of
2527
+ # src has the same shape as the output of
2480
2528
  # jnp.diagonal(input, offset, dim1, dim2).
2481
2529
  # Last dimension always contains the diagonal elements,
2482
2530
  # while the preceding dimensions represent the "slices"
2483
2531
  # from which these diagonals are extracted. Thus,
2484
2532
  # we alter input axes to match this assumption, write src
2485
2533
  # and then move the axes back to the original state.
2486
- input = jnp.moveaxis(input, (dim1, dim2), (-2,-1))
2487
- multi_indexes = [slice(None)]*(input.ndim-2) + indexes
2534
+ input = jnp.moveaxis(input, (dim1, dim2), (-2, -1))
2535
+ multi_indexes = [slice(None)] * (input.ndim - 2) + indexes
2488
2536
  input = input.at[tuple(multi_indexes)].set(src)
2489
- return jnp.moveaxis(input, (-2,-1), (dim1, dim2))
2537
+ return jnp.moveaxis(input, (-2, -1), (dim1, dim2))
2490
2538
 
2491
2539
 
2492
2540
  # aten.diagflat
@@ -2507,9 +2555,9 @@ def _aten_eq(input1, input2):
2507
2555
 
2508
2556
 
2509
2557
  # aten.equal
2510
- @op(torch.ops.aten.equal, is_jax_function=False)
2558
+ @op(torch.ops.aten.equal)
2511
2559
  def _aten_equal(input, other):
2512
- res = jnp.array_equal(input._elem, other._elem)
2560
+ res = jnp.array_equal(input, other)
2513
2561
  return bool(res)
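aten.equal is now registered as a regular JAX function operating on raw arrays instead of unwrapping ._elem by hand; semantically it is torch.equal, i.e. same shape and identical values collapsed to a Python bool:

import jax.numpy as jnp

a = jnp.arange(4).reshape(2, 2)
assert bool(jnp.array_equal(a, a)) is True
assert bool(jnp.array_equal(a, a.T)) is False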
2514
2562
 
2515
2563
 
@@ -2559,7 +2607,12 @@ def _aten_exp2(input):
2559
2607
  # aten.fill
2560
2608
  @op(torch.ops.aten.fill)
2561
2609
  @op(torch.ops.aten.full_like)
2562
- def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None, device=None):
2610
+ def _aten_fill(x,
2611
+ value,
2612
+ dtype=None,
2613
+ pin_memory=None,
2614
+ memory_format=None,
2615
+ device=None):
2563
2616
  if dtype is None:
2564
2617
  dtype = x.dtype
2565
2618
  else:
@@ -2634,7 +2687,8 @@ def _aten_glu(x, dim=-1):
2634
2687
  # aten.hardtanh
2635
2688
  @op(torch.ops.aten.hardtanh)
2636
2689
  def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
2637
- if input.dtype == np.int64 and isinstance(max_val, float) and isinstance(min_val, float):
2690
+ if input.dtype == np.int64 and isinstance(max_val, float) and isinstance(
2691
+ min_val, float):
2638
2692
  min_val = int(min_val)
2639
2693
  max_val = int(max_val)
2640
2694
  return jnp.clip(input, min_val, max_val)
@@ -2644,7 +2698,7 @@ def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False):
2644
2698
  @op(torch.ops.aten.histc)
2645
2699
  def _aten_histc(input, bins=100, min=0, max=0):
2646
2700
  # TODO(@manfei): this function might cause some uncertainty
2647
- if min==0 and max==0:
2701
+ if min == 0 and max == 0:
2648
2702
  if isinstance(input, jnp.ndarray) and input.size == 0:
2649
2703
  min = 0
2650
2704
  max = 0
@@ -2652,7 +2706,8 @@ def _aten_histc(input, bins=100, min=0, max=0):
2652
2706
  min = jnp.min(input)
2653
2707
  max = jnp.max(input)
2654
2708
  range_value = (min, max)
2655
- hist, bin_edges = jnp.histogram(input, bins=bins, range=range_value, weights=None, density=None)
2709
+ hist, bin_edges = jnp.histogram(
2710
+ input, bins=bins, range=range_value, weights=None, density=None)
2656
2711
  return hist
2657
2712
 
2658
2713
 
@@ -2667,22 +2722,28 @@ def _aten_digamma(input, *, out=None):
2667
2722
  # replace indices where input == 0 with -inf in res
2668
2723
  return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res)
2669
2724
 
2725
+
2670
2726
  @op(torch.ops.aten.igamma)
2671
2727
  def _aten_igamma(input, other):
2672
2728
  return jax.scipy.special.gammainc(input, other)
2673
2729
 
2730
+
2674
2731
  @op(torch.ops.aten.lgamma)
2675
2732
  def _aten_lgamma(input, *, out=None):
2676
2733
  return jax.scipy.special.gammaln(input).astype(jnp.float32)
2677
2734
 
2735
+
2678
2736
  @op(torch.ops.aten.mvlgamma)
2679
2737
  def _aten_mvlgamma(input, p, *, out=None):
2680
- return jax.scipy.special.multigammaln(input, d)
2738
+ input = input.astype(mappings.t2j_dtype(torch.get_default_dtype()))
2739
+ return jax.scipy.special.multigammaln(input, p)
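The previous mvlgamma body referenced an undefined name d; the fixed version casts to the default floating dtype and forwards p to jax.scipy.special.multigammaln. A quick sanity check of that call, using the identity that the order-1 multivariate log-gamma equals the ordinary log-gamma:

import jax.numpy as jnp
from jax.scipy.special import gammaln, multigammaln

x = jnp.array(3.5)
assert jnp.allclose(multigammaln(x, 1), gammaln(x))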
2740
+
2681
2741
 
2682
2742
  @op(torch.ops.aten.linalg_eig)
2683
2743
  def _aten_linalg_eig(A):
2684
2744
  return jnp.linalg.eig(A)
2685
2745
 
2746
+
2686
2747
  @op(torch.ops.aten._linalg_eigh)
2687
2748
  def _aten_linalg_eigh(A, UPLO='L'):
2688
2749
  return jnp.linalg.eigh(A, UPLO)
@@ -2704,7 +2765,9 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
2704
2765
  A_reshaped = A.reshape((batch_size,) + A.shape[-2:])
2705
2766
  B_reshaped = B.reshape((batch_size,) + B.shape[-2:])
2706
2767
 
2707
- X, residuals, rank, singular_values = jax.vmap(jnp.linalg.lstsq, in_axes=(0, 0))(A_reshaped, B_reshaped, rcond=rcond)
2768
+ X, residuals, rank, singular_values = jax.vmap(
2769
+ jnp.linalg.lstsq, in_axes=(0,
2770
+ 0))(A_reshaped, B_reshaped, rcond=rcond)
2708
2771
 
2709
2772
  X = X.reshape(batch_shape + X.shape[-2:])
2710
2773
 
@@ -2720,7 +2783,8 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
2720
2783
  residuals = residuals.reshape(batch_shape + residuals.shape[-1:])
2721
2784
 
2722
2785
  if driver in ['gelsd', 'gelss']:
2723
- singular_values = singular_values.reshape(batch_shape + singular_values.shape[-1:])
2786
+ singular_values = singular_values.reshape(batch_shape +
2787
+ singular_values.shape[-1:])
2724
2788
  else:
2725
2789
  singular_values = jnp.array([], dtype=input_dtype)
2726
2790
 
@@ -2729,17 +2793,17 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'):
2729
2793
  X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond)
2730
2794
 
2731
2795
  if driver not in ['gelsd', 'gelsy', 'gelss']:
2732
- rank = jnp.array([], dtype=jnp.int64)
2796
+ rank = jnp.array([], dtype=jnp.int64)
2733
2797
 
2734
2798
  rank_value = None
2735
2799
  if rank.size > 0:
2736
- rank_value = int(rank.item())
2737
- rank = jnp.array(rank_value, dtype=jnp.int64)
2800
+ rank_value = int(rank.item())
2801
+ rank = jnp.array(rank_value, dtype=jnp.int64)
2738
2802
 
2739
2803
  # When driver is ‘gels’, assume that A is full-rank.
2740
- full_rank = driver == 'gels' or rank_value == n
2804
+ full_rank = driver == 'gels' or rank_value == n
2741
2805
  if driver == 'gelsy' or m <= n or (not full_rank):
2742
- residuals = jnp.array([], dtype=input_dtype)
2806
+ residuals = jnp.array([], dtype=input_dtype)
2743
2807
 
2744
2808
  if driver not in ['gelsd', 'gelss']:
2745
2809
  singular_values = jnp.array([], dtype=input_dtype)
@@ -2753,8 +2817,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
2753
2817
  # https://github.com/jax-ml/jax/issues/12779
2754
2818
  # TODO: Not tested for complex inputs. Does not support hermitian=True
2755
2819
  pivots = jnp.broadcast_to(
2756
- jnp.arange(1, A.shape[-1]+1, dtype=jnp.int32), A.shape[:-1]
2757
- )
2820
+ jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1])
2758
2821
  info = jnp.zeros(A.shape[:-2], jnp.int32)
2759
2822
  C = jnp.linalg.cholesky(A)
2760
2823
  if C.size == 0:
@@ -2767,7 +2830,7 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
2767
2830
 
2768
2831
  D = C * jnp.eye(C.shape[-1], dtype=A.dtype)
2769
2832
  LD = C @ jnp.linalg.inv(D)
2770
- LD = fill_diagonal_batch(LD, D*D)
2833
+ LD = fill_diagonal_batch(LD, D * D)
2771
2834
  return LD, pivots, info
2772
2835
 
2773
2836
 
@@ -2787,9 +2850,9 @@ def _aten_linalg_lu(A, pivot=True, out=None):
2787
2850
  U = jnp.triu(lu[..., :k, :])
2788
2851
 
2789
2852
  def perm_to_P(perm):
2790
- m = perm.shape[-1]
2791
- P = jnp.eye(m, dtype=dtype)[perm].T
2792
- return P
2853
+ m = perm.shape[-1]
2854
+ P = jnp.eye(m, dtype=dtype)[perm].T
2855
+ return P
2793
2856
 
2794
2857
  if permutation.ndim > 1:
2795
2858
  num_batch_dims = permutation.ndim - 1
@@ -2798,7 +2861,7 @@ def _aten_linalg_lu(A, pivot=True, out=None):
2798
2861
 
2799
2862
  P = perm_to_P(permutation)
2800
2863
 
2801
- return P,L,U
2864
+ return P, L, U
2802
2865
 
2803
2866
 
2804
2867
  @op(torch.ops.aten.linalg_lu_factor_ex)
@@ -2810,6 +2873,21 @@ def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False):
2810
2873
  return lu, pivots, info
2811
2874
 
2812
2875
 
2876
+ @op(torch.ops.aten.linalg_lu_solve)
2877
+ def _aten_linalg_lu_solve(LU, pivots, B, left=True, adjoint=False):
2878
+ # JAX pivots are offset by 1 compared to torch
2879
+ pivots = pivots - 1
2880
+ if not left:
2881
+ # XA = B is same as A'X = B'
2882
+ trans = 0 if adjoint else 2
2883
+ x = jax.scipy.linalg.lu_solve((LU, pivots), jnp.matrix_transpose(B), trans)
2884
+ x = jnp.matrix_transpose(x)
2885
+ else:
2886
+ trans = 2 if adjoint else 0
2887
+ x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans)
2888
+ return x
2889
+
2890
+
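The new linalg_lu_solve wrapper above bridges two conventions: torch's LU pivots are 1-based (LAPACK style), while jax.scipy.linalg.lu_solve expects the 0-based pivots produced by jax.scipy.linalg.lu_factor, hence the pivots - 1; left=False solves XA = B by transposing the right-hand side and toggling trans. A standalone round trip with the JAX primitives it defers to (illustrative, not torchax code):

import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
b = jnp.array([[1.0], [2.0]])
lu, piv = jax.scipy.linalg.lu_factor(a)      # piv is already 0-based here
x = jax.scipy.linalg.lu_solve((lu, piv), b)
assert jnp.allclose(a @ x, b, atol=1e-5)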
2813
2891
  @op(torch.ops.aten.gcd)
2814
2892
  def _aten_gcd(input, other):
2815
2893
  return jnp.gcd(input, other)
@@ -2874,12 +2952,14 @@ def _aten_log2(x):
2874
2952
 
2875
2953
  # aten.logical_and
2876
2954
  @op(torch.ops.aten.logical_and)
2955
+ @op(torch.ops.aten.__and__)
2877
2956
  def _aten_logical_and(self, other):
2878
2957
  return jnp.logical_and(self, other)
2879
2958
 
2880
2959
 
2881
2960
  # aten.logical_or
2882
2961
  @op(torch.ops.aten.logical_or)
2962
+ @op(torch.ops.aten.__or__)
2883
2963
  def _aten_logical_or(self, other):
2884
2964
  return jnp.logical_or(self, other)
2885
2965
 
@@ -2894,7 +2974,7 @@ def _aten_logical_not(self):
2894
2974
  @op(torch.ops.aten._log_softmax)
2895
2975
  def _aten_log_softmax(self, axis=-1, half_to_float=False):
2896
2976
  if self.shape == ():
2897
- return jnp.astype(0.0, self.dtype)
2977
+ return jnp.astype(0.0, self.dtype)
2898
2978
  return jax.nn.log_softmax(self, axis)
2899
2979
 
2900
2980
 
@@ -2921,6 +3001,7 @@ def _aten_logcumsumexp(self, dim=None):
2921
3001
  # aten.max_pool3d_backward
2922
3002
  # aten.logical_xor
2923
3003
  @op(torch.ops.aten.logical_xor)
3004
+ @op(torch.ops.aten.__xor__)
2924
3005
  def _aten_logical_xor(self, other):
2925
3006
  return jnp.logical_xor(self, other)
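aten.__and__, __or__ and __xor__ are now routed to these logical implementations; for boolean tensors the Python bitwise operators and the logical ops coincide, for example:

import jax.numpy as jnp

a = jnp.array([True, False, True])
b = jnp.array([True, True, False])
assert bool(jnp.all((a & b) == jnp.logical_and(a, b)))
assert bool(jnp.all((a ^ b) == jnp.logical_xor(a, b)))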
2926
3007
 
@@ -2933,19 +3014,22 @@ def _aten_logical_xor(self, other):
2933
3014
  def _aten_neg(x):
2934
3015
  return -1 * x
2935
3016
 
3017
+
2936
3018
  @op(torch.ops.aten.nextafter)
2937
3019
  def _aten_nextafter(input, other, *, out=None):
2938
3020
  return jnp.nextafter(input, other)
2939
3021
 
2940
3022
 
2941
3023
  @op(torch.ops.aten.nonzero_static)
2942
- def _aten_nonzero_static(input, size, fill_value = -1):
3024
+ def _aten_nonzero_static(input, size, fill_value=-1):
2943
3025
  indices = jnp.argwhere(input)
2944
3026
 
2945
3027
  if size < indices.shape[0]:
2946
3028
  indices = indices[:size]
2947
3029
  elif size > indices.shape[0]:
2948
- padding = jnp.full((size - indices.shape[0], indices.shape[1]), fill_value, dtype=indices.dtype)
3030
+ padding = jnp.full((size - indices.shape[0], indices.shape[1]),
3031
+ fill_value,
3032
+ dtype=indices.dtype)
2949
3033
  indices = jnp.concatenate((indices, padding))
2950
3034
 
2951
3035
  return indices
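nonzero_static pads (or truncates) the result of argwhere to a caller-fixed row count so the output shape is static. A minimal standalone version of the padding branch above:

import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])
idx = jnp.argwhere(x)                                   # shape (2, 1)
size, fill_value = 4, -1
pad = jnp.full((size - idx.shape[0], idx.shape[1]), fill_value, dtype=idx.dtype)
out = jnp.concatenate((idx, pad))                       # [[1], [3], [-1], [-1]]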
@@ -2954,9 +3038,11 @@ def _aten_nonzero_static(input, size, fill_value = -1):
2954
3038
  # aten.nonzero
2955
3039
  @op(torch.ops.aten.nonzero)
2956
3040
  def _aten_nonzero(x, as_tuple=False):
2957
- if jnp.ndim(x) == 0 and (as_tuple or x.item()==0):
3041
+ if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0):
2958
3042
  return torch.empty(0, 0, dtype=torch.int64)
2959
- if jnp.ndim(x) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64)
3043
+ if jnp.ndim(
3044
+ x
3045
+ ) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64)
2960
3046
  res = torch.empty(1, 0, dtype=torch.int64)
2961
3047
  return jnp.array(res.numpy())
2962
3048
  index_tuple = jnp.nonzero(x)
@@ -2997,15 +3083,15 @@ def _aten_put(self, index, source, accumulate=False):
2997
3083
  # aten.randperm
2998
3084
  # randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
2999
3085
  @op(torch.ops.aten.randperm, needs_env=True)
3000
- def _aten_randperm(
3001
- n, *,
3002
- generator=None,
3003
- dtype=None,
3004
- layout=None,
3005
- device=None,
3006
- pin_memory=None,
3007
- env=None):
3008
- """
3086
+ def _aten_randperm(n,
3087
+ *,
3088
+ generator=None,
3089
+ dtype=None,
3090
+ layout=None,
3091
+ device=None,
3092
+ pin_memory=None,
3093
+ env=None):
3094
+ """
3009
3095
  Generates a random permutation of integers from 0 to n-1.
3010
3096
 
3011
3097
  Args:
@@ -3019,14 +3105,14 @@ def _aten_randperm(
3019
3105
  Returns:
3020
3106
  A DeviceArray containing a random permutation of integers from 0 to n-1.
3021
3107
  """
3022
- if dtype:
3023
- dtype = mappings.t2j_dtype(dtype)
3024
- else:
3025
- dtype = jnp.int64.dtype
3026
- key = env.get_and_rotate_prng_key(generator)
3027
- indices = jnp.arange(n, dtype=dtype)
3028
- permutation = jax.random.permutation(key, indices)
3029
- return permutation
3108
+ if dtype:
3109
+ dtype = mappings.t2j_dtype(dtype)
3110
+ else:
3111
+ dtype = jnp.int64.dtype
3112
+ key = env.get_and_rotate_prng_key(generator)
3113
+ indices = jnp.arange(n, dtype=dtype)
3114
+ permutation = jax.random.permutation(key, indices)
3115
+ return permutation
3030
3116
 
3031
3117
 
3032
3118
  # aten.reflection_pad3d
@@ -3071,8 +3157,8 @@ def _aten_sort(a, dim=-1, descending=False, stable=False):
3071
3157
  if a.shape == ():
3072
3158
  return (a, jnp.astype(0, 'int64'))
3073
3159
  return (
3074
- jnp.sort(a, axis=dim, stable=stable, descending=descending),
3075
- jnp.argsort(a, axis=dim, stable=stable, descending=descending),
3160
+ jnp.sort(a, axis=dim, stable=stable, descending=descending),
3161
+ jnp.argsort(a, axis=dim, stable=stable, descending=descending),
3076
3162
  )
3077
3163
 
3078
3164
 
@@ -3114,8 +3200,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
3114
3200
  if dim != -1 and dim != len(input.shape) - 1:
3115
3201
  transpose_shape = list(range(len(input.shape)))
3116
3202
  transpose_shape[dim], transpose_shape[-1] = (
3117
- transpose_shape[-1],
3118
- transpose_shape[dim],
3203
+ transpose_shape[-1],
3204
+ transpose_shape[dim],
3119
3205
  )
3120
3206
  input = jnp.transpose(input, transpose_shape)
3121
3207
 
@@ -3124,8 +3210,7 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
3124
3210
  if sorted:
3125
3211
  values = jnp.sort(values, descending=True)
3126
3212
  indices = jnp.take_along_axis(
3127
- indices, jnp.argsort(values, axis=-1, descending=True), axis=-1
3128
- )
3213
+ indices, jnp.argsort(values, axis=-1, descending=True), axis=-1)
3129
3214
 
3130
3215
  if not largest:
3131
3216
  values = -values # Negate values back if we found smallest
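Smallest-k is obtained above by negating, taking the top-k, and negating back, since jax.lax.top_k only returns the largest entries. Standalone:

import jax
import jax.numpy as jnp

x = jnp.array([5.0, 1.0, 4.0, 2.0])
neg_vals, idx = jax.lax.top_k(-x, 2)
smallest = -neg_vals        # [1., 2.], at indices [1, 3]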
@@ -3140,21 +3225,39 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
3140
3225
  # aten.tril_indices
3141
3226
  #tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
3142
3227
  @op(torch.ops.aten.tril_indices)
3143
- def _aten_tril_indices(row, col, offset=0, *, dtype=jnp.int64.dtype, layout=None, device=None, pin_memory=None):
3228
+ def _aten_tril_indices(row,
3229
+ col,
3230
+ offset=0,
3231
+ *,
3232
+ dtype=jnp.int64.dtype,
3233
+ layout=None,
3234
+ device=None,
3235
+ pin_memory=None):
3144
3236
  a, b = jnp.tril_indices(row, offset, col)
3145
3237
  return jnp.stack((a, b))
3146
3238
 
3239
+
3147
3240
  # aten.tril_indices
3148
3241
  #tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None)
3149
3242
  @op(torch.ops.aten.triu_indices)
3150
- def _aten_triu_indices(row, col, offset=0, *, dtype=jnp.int64.dtype, layout=None, device=None, pin_memory=None):
3243
+ def _aten_triu_indices(row,
3244
+ col,
3245
+ offset=0,
3246
+ *,
3247
+ dtype=jnp.int64.dtype,
3248
+ layout=None,
3249
+ device=None,
3250
+ pin_memory=None):
3151
3251
  a, b = jnp.triu_indices(row, offset, col)
3152
3252
  return jnp.stack((a, b))
3153
3253
 
3154
3254
 
3155
3255
  @op(torch.ops.aten.unbind_copy)
3156
3256
  def _aten_unbind(a, dim=0):
3157
- return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]
3257
+ return [
3258
+ jax.lax.index_in_dim(a, i, dim, keepdims=False)
3259
+ for i in range(a.shape[dim])
3260
+ ]
3158
3261
 
3159
3262
 
3160
3263
  # aten.unique_dim
@@ -3167,12 +3270,13 @@ def _aten_unique_dim(input_tensor,
3167
3270
  sort=True,
3168
3271
  return_inverse=False,
3169
3272
  return_counts=False):
3170
- result_tensor_or_tuple = jnp.unique(input_tensor,
3171
- return_index=False,
3172
- return_inverse=return_inverse,
3173
- return_counts=return_counts,
3174
- axis=dim,
3175
- equal_nan=False)
3273
+ result_tensor_or_tuple = jnp.unique(
3274
+ input_tensor,
3275
+ return_index=False,
3276
+ return_inverse=return_inverse,
3277
+ return_counts=return_counts,
3278
+ axis=dim,
3279
+ equal_nan=False)
3176
3280
  result_list = (
3177
3281
  list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple)
3178
3282
  else [result_tensor_or_tuple])
@@ -3197,15 +3301,14 @@ def _aten_unique_dim(input_tensor,
3197
3301
  # NOTE: Like the CUDA and CPU implementations, this implementation always sorts
3198
3302
  # the tensor regardless of the `sorted` argument passed to `torch.unique`.
3199
3303
  @op(torch.ops.aten._unique)
3200
- def _aten_unique(input_tensor,
3201
- sort=True,
3202
- return_inverse=False):
3203
- result_tensor_or_tuple = jnp.unique(input_tensor,
3204
- return_index=False,
3205
- return_inverse=return_inverse,
3206
- return_counts=False,
3207
- axis=None,
3208
- equal_nan=False)
3304
+ def _aten_unique(input_tensor, sort=True, return_inverse=False):
3305
+ result_tensor_or_tuple = jnp.unique(
3306
+ input_tensor,
3307
+ return_index=False,
3308
+ return_inverse=return_inverse,
3309
+ return_counts=False,
3310
+ axis=None,
3311
+ equal_nan=False)
3209
3312
  if return_inverse:
3210
3313
  return result_tensor_or_tuple
3211
3314
  else:
@@ -3221,11 +3324,12 @@ def _aten_unique2(input_tensor,
3221
3324
  sort=True,
3222
3325
  return_inverse=False,
3223
3326
  return_counts=False):
3224
- return _aten_unique_dim(input_tensor=input_tensor,
3225
- dim=None,
3226
- sort=sort,
3227
- return_inverse=return_inverse,
3228
- return_counts=return_counts)
3327
+ return _aten_unique_dim(
3328
+ input_tensor=input_tensor,
3329
+ dim=None,
3330
+ sort=sort,
3331
+ return_inverse=return_inverse,
3332
+ return_counts=return_counts)
3229
3333
 
3230
3334
 
3231
3335
  # aten.unique_consecutive
@@ -3255,17 +3359,18 @@ def _aten_unique_consecutive(input_tensor,
3255
3359
  if dim < 0:
3256
3360
  dim += ndim
3257
3361
 
3258
- nd_slice_0 = tuple(slice(None, -1) if d == dim else slice(None)
3259
- for d in range(ndim))
3260
- nd_slice_1 = tuple(slice(1, None) if d == dim else slice(None)
3261
- for d in range(ndim))
3362
+ nd_slice_0 = tuple(
3363
+ slice(None, -1) if d == dim else slice(None) for d in range(ndim))
3364
+ nd_slice_1 = tuple(
3365
+ slice(1, None) if d == dim else slice(None) for d in range(ndim))
3262
3366
 
3263
3367
  axes_to_reduce = tuple(d for d in range(ndim) if d != dim)
3264
3368
 
3265
3369
  does_not_equal_prior = (
3266
- jnp.any(input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
3267
- axis=axes_to_reduce,
3268
- keepdims=False))
3370
+ jnp.any(
3371
+ input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
3372
+ axis=axes_to_reduce,
3373
+ keepdims=False))
3269
3374
 
3270
3375
  if input_tensor.shape[dim] != 0:
3271
3376
  # Prepend `True` to represent the first element of the input.
@@ -3273,18 +3378,17 @@ def _aten_unique_consecutive(input_tensor,
3273
3378
 
3274
3379
  include_indices = jnp.argwhere(does_not_equal_prior)[:, 0]
3275
3380
 
3276
- output_tensor = input_tensor[
3277
- tuple(include_indices if d == dim else slice(None) for d in range(ndim))]
3381
+ output_tensor = input_tensor[tuple(
3382
+ include_indices if d == dim else slice(None) for d in range(ndim))]
3278
3383
 
3279
3384
  if return_inverse or return_counts:
3280
- counts = (jnp.append(include_indices[1:], input_tensor.shape[dim]) -
3281
- include_indices[:])
3385
+ counts = (
3386
+ jnp.append(include_indices[1:], input_tensor.shape[dim]) -
3387
+ include_indices[:])
3282
3388
 
3283
3389
  inverse = (
3284
3390
  jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape)
3285
- if return_inverse
3286
- else None
3287
- )
3391
+ if return_inverse else None)
3288
3392
 
3289
3393
  return output_tensor, inverse, counts
3290
3394
 
@@ -3302,25 +3406,33 @@ def _aten_unique_consecutive(input_tensor,
3302
3406
  @op(torch.ops.aten.where.ScalarSelf)
3303
3407
  @op(torch.ops.aten.where.ScalarOther)
3304
3408
  @op(torch.ops.aten.where.Scalar)
3305
- def _aten_where(condition, x = None, y = None):
3409
+ def _aten_where(condition, x=None, y=None):
3306
3410
  return jnp.where(condition, x, y)
3307
3411
 
3308
3412
 
3309
3413
  # aten.to.dtype
3310
3414
  # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None
3311
3415
  @op(torch.ops.aten.to.dtype)
3312
- def _aten_to_dtype(
3313
- a, dtype, non_blocking=False, copy=False, memory_format=None
3314
- ):
3416
+ def _aten_to_dtype(a,
3417
+ dtype,
3418
+ non_blocking=False,
3419
+ copy=False,
3420
+ memory_format=None):
3315
3421
  if dtype:
3316
3422
  jaxdtype = mappings.t2j_dtype(dtype)
3317
3423
  return a.astype(jaxdtype)
3318
3424
 
3319
3425
 
3320
3426
  @op(torch.ops.aten.to.dtype_layout)
3321
- def _aten_to_dtype_layout(
3322
- a, *, dtype=None, layout=None, device=None, pin_memory=None, non_blocking=False, copy=False, memory_format=None
3323
- ):
3427
+ def _aten_to_dtype_layout(a,
3428
+ *,
3429
+ dtype=None,
3430
+ layout=None,
3431
+ device=None,
3432
+ pin_memory=None,
3433
+ non_blocking=False,
3434
+ copy=False,
3435
+ memory_format=None):
3324
3436
  return _aten_to_dtype(
3325
3437
  a,
3326
3438
  dtype,
@@ -3328,6 +3440,7 @@ def _aten_to_dtype_layout(
3328
3440
  copy=copy,
3329
3441
  memory_format=memory_format)
3330
3442
 
3443
+
3331
3444
  # aten.to.device
3332
3445
 
3333
3446
 
@@ -3348,9 +3461,11 @@ def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False):
3348
3461
 
3349
3462
  @op(torch.ops.aten.scalar_tensor)
3350
3463
  @op_base.convert_dtype()
3351
- def _aten_scalar_tensor(
3352
- s, dtype=None, layout=None, device=None, pin_memory=None
3353
- ):
3464
+ def _aten_scalar_tensor(s,
3465
+ dtype=None,
3466
+ layout=None,
3467
+ device=None,
3468
+ pin_memory=None):
3354
3469
  return jnp.array(s, dtype=dtype)
3355
3470
 
3356
3471
 
@@ -3360,9 +3475,9 @@ def _aten_to_device(x, device, dtype):
3360
3475
 
3361
3476
 
3362
3477
  @op(torch.ops.aten.max_pool2d_with_indices_backward)
3363
- def max_pool2d_with_indices_backward_custom(
3364
- grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
3365
- ):
3478
+ def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size,
3479
+ stride, padding, dilation,
3480
+ ceil_mode, indices):
3366
3481
  """
3367
3482
  Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward.
3368
3483
 
@@ -3418,15 +3533,15 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0):
3418
3533
  @op(torch.ops.aten.randn, needs_env=True)
3419
3534
  @op_base.convert_dtype()
3420
3535
  def _randn(
3421
- *size,
3422
- generator=None,
3423
- out=None,
3424
- dtype=None,
3425
- layout=torch.strided,
3426
- device=None,
3427
- requires_grad=False,
3428
- pin_memory=False,
3429
- env=None,
3536
+ *size,
3537
+ generator=None,
3538
+ out=None,
3539
+ dtype=None,
3540
+ layout=torch.strided,
3541
+ device=None,
3542
+ requires_grad=False,
3543
+ pin_memory=False,
3544
+ env=None,
3430
3545
  ):
3431
3546
  shape = size
3432
3547
  if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3437,13 +3552,14 @@ def _randn(
3437
3552
  res = res.astype(dtype)
3438
3553
  return res
3439
3554
 
3555
+
3440
3556
  @op(torch.ops.aten.bernoulli.p, needs_env=True)
3441
- def _bernoulli(
3442
- self,
3443
- p = 0.5,
3444
- *,
3445
- generator=None,
3446
- env=None,
3557
+ def _aten_bernoulli(
3558
+ self,
3559
+ p=0.5,
3560
+ *,
3561
+ generator=None,
3562
+ env=None,
3447
3563
  ):
3448
3564
  key = env.get_and_rotate_prng_key(generator)
3449
3565
  res = jax.random.uniform(key, self.shape) < p
@@ -3460,14 +3576,14 @@ def geometric(self, p, *, generator=None, env=None):
3460
3576
  @op(torch.ops.aten.randn_like, needs_env=True)
3461
3577
  @op_base.convert_dtype()
3462
3578
  def _aten_randn_like(
3463
- x,
3464
- *,
3465
- dtype=None,
3466
- layout=None,
3467
- device=None,
3468
- pin_memory=False,
3469
- memory_format=torch.preserve_format,
3470
- env=None,
3579
+ x,
3580
+ *,
3581
+ dtype=None,
3582
+ layout=None,
3583
+ device=None,
3584
+ pin_memory=False,
3585
+ memory_format=torch.preserve_format,
3586
+ env=None,
3471
3587
  ):
3472
3588
  key = env.get_and_rotate_prng_key()
3473
3589
  return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape)
@@ -3476,15 +3592,15 @@ def _aten_randn_like(
3476
3592
  @op(torch.ops.aten.rand, needs_env=True)
3477
3593
  @op_base.convert_dtype()
3478
3594
  def _rand(
3479
- *size,
3480
- generator=None,
3481
- out=None,
3482
- dtype=None,
3483
- layout=torch.strided,
3484
- device=None,
3485
- requires_grad=False,
3486
- pin_memory=False,
3487
- env=None,
3595
+ *size,
3596
+ generator=None,
3597
+ out=None,
3598
+ dtype=None,
3599
+ layout=torch.strided,
3600
+ device=None,
3601
+ requires_grad=False,
3602
+ pin_memory=False,
3603
+ env=None,
3488
3604
  ):
3489
3605
  shape = size
3490
3606
  if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
@@ -3505,18 +3621,32 @@ def _aten_outer(a, b):
3505
3621
  def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
3506
3622
  return jnp.allclose(input, other, rtol, atol, equal_nan)
3507
3623
 
3624
+
3508
3625
  @op(torch.ops.aten.native_batch_norm)
3509
- def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):
3626
+ def _aten_native_batch_norm(input,
3627
+ weight,
3628
+ bias,
3629
+ running_mean,
3630
+ running_var,
3631
+ training=False,
3632
+ momentum=0.1,
3633
+ eps=1e-5):
3510
3634
 
3511
3635
  if running_mean is None:
3512
- running_mean = jnp.zeros(input.shape[1], dtype=input.dtype) # Initialize running mean if None
3636
+ running_mean = jnp.zeros(
3637
+ input.shape[1], dtype=input.dtype) # Initialize running mean if None
3513
3638
  if running_var is None:
3514
- running_var = jnp.ones(input.shape[1], dtype=input.dtype) # Initialize running variance if None
3639
+ running_var = jnp.ones(
3640
+ input.shape[1],
3641
+ dtype=input.dtype) # Initialize running variance if None
3515
3642
 
3516
3643
  if training:
3517
- return _aten__native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
3644
+ return _aten__native_batch_norm_legit(input, weight, bias, running_mean,
3645
+ running_var, training, momentum, eps)
3518
3646
  else:
3519
- return _aten__native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)
3647
+ return _aten__native_batch_norm_legit_no_training(input, weight, bias,
3648
+ running_mean, running_var,
3649
+ momentum, eps)
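The wrapper above only fills in default running statistics and dispatches to the training or no-training kernels defined elsewhere in this file. For reference, inference-mode batch norm reduces to the usual per-channel affine normalization; a minimal sketch assuming the channel dimension is axis 1, not the torchax kernel itself:

import jax
import jax.numpy as jnp

def batch_norm_inference(x, weight, bias, running_mean, running_var, eps=1e-5):
  shape = (1, -1) + (1,) * (x.ndim - 2)              # broadcast over batch and spatial dims
  inv = jax.lax.rsqrt(running_var.reshape(shape) + eps)
  return (x - running_mean.reshape(shape)) * inv * weight.reshape(shape) + bias.reshape(shape)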
3520
3650
 
3521
3651
 
3522
3652
  @op(torch.ops.aten.normal, needs_env=True)
@@ -3525,12 +3655,14 @@ def _aten_normal(self, mean=0, std=1, generator=None, env=None):
3525
3655
  res = _randn(*shape, generator=generator, env=env)
3526
3656
  return res * std + mean
3527
3657
 
3658
+
3528
3659
  # TODO: not clear what this function should actually do
3529
3660
  # https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940
3530
3661
  @op(torch.ops.aten.lift_fresh)
3531
3662
  def _aten_lift_fresh(self):
3532
3663
  return self
3533
3664
 
3665
+
3534
3666
  @op(torch.ops.aten.uniform, needs_env=True)
3535
3667
  def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
3536
3668
  assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})'
@@ -3538,16 +3670,18 @@ def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None):
3538
3670
  res = _rand(*shape, generator=generator, env=env)
3539
3671
  return res * (to - from_) + from_
3540
3672
 
3673
+
3541
3674
  #func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
3542
3675
 
3676
+
3543
3677
  @op(torch.ops.aten.randint, needs_env=True)
3544
3678
  @op_base.convert_dtype(use_default_dtype=False)
3545
3679
  def _aten_randint(
3546
- *args,
3547
- generator=None,
3548
- dtype=None,
3549
- env=None,
3550
- **kwargs,
3680
+ *args,
3681
+ generator=None,
3682
+ dtype=None,
3683
+ env=None,
3684
+ **kwargs,
3551
3685
  ):
3552
3686
  if len(args) == 3:
3553
3687
  # low, high, size
@@ -3556,7 +3690,8 @@ def _aten_randint(
3556
3690
  high, size = args
3557
3691
  low = 0
3558
3692
  else:
3559
- raise AssertionError(f'Expected at 2 or 3 args for Aten::randint, got {len(args)}')
3693
+ raise AssertionError(
3694
+ f'Expected at 2 or 3 args for Aten::randint, got {len(args)}')
3560
3695
 
3561
3696
  key = env.get_and_rotate_prng_key(generator)
3562
3697
  res = jax.random.randint(key, size, low, high)
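Both torch.randint call shapes, (high, size) and (low, high, size), funnel into jax.random.randint, whose argument order is (key, shape, minval, maxval) with an exclusive upper bound:

import jax

key = jax.random.PRNGKey(0)
sample = jax.random.randint(key, (2, 3), 0, 10)   # integers drawn from [0, 10)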
@@ -3564,15 +3699,18 @@ def _aten_randint(
3564
3699
  res = res.astype(dtype)
3565
3700
  return res
3566
3701
 
3567
- @op(torch.ops.aten.randint_like, torch.ops.aten.randint.generator, needs_env=True)
3702
+
3703
+ @op(torch.ops.aten.randint_like,
3704
+ torch.ops.aten.randint.generator,
3705
+ needs_env=True)
3568
3706
  @op_base.convert_dtype(use_default_dtype=False)
3569
3707
  def _aten_randint_like(
3570
- input,
3571
- *args,
3572
- generator=None,
3573
- dtype=None,
3574
- env=None,
3575
- **kwargs,
3708
+ input,
3709
+ *args,
3710
+ generator=None,
3711
+ dtype=None,
3712
+ env=None,
3713
+ **kwargs,
3576
3714
  ):
3577
3715
  if len(args) == 2:
3578
3716
  low, high = args
@@ -3580,7 +3718,8 @@ def _aten_randint_like(
3580
3718
  high = args[0]
3581
3719
  low = 0
3582
3720
  else:
3583
- raise AssertionError(f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}')
3721
+ raise AssertionError(
3722
+ f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}')
3584
3723
 
3585
3724
  shape = input.shape
3586
3725
  dtype = dtype or input.dtype
@@ -3590,6 +3729,7 @@ def _aten_randint_like(
3590
3729
  res = res.astype(dtype)
3591
3730
  return res
3592
3731
 
3732
+
3593
3733
  @op(torch.ops.aten.dim, is_jax_function=False)
3594
3734
  def _aten_dim(self):
3595
3735
  return len(self.shape)
@@ -3602,10 +3742,11 @@ def _aten_copysign(input, other, *, out=None):
3602
3742
  # regardless of their exact integer dtype, whereas jax.copysign returns
3603
3743
  # float64 when one or both of them is int64.
3604
3744
  if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype(
3605
- other.dtype, jnp.integer
3606
- ):
3745
+ other.dtype, jnp.integer):
3607
3746
  result = result.astype(jnp.float32)
3608
3747
  return result
3748
+
3749
+
3609
3750
  @op(torch.ops.aten.i0)
3610
3751
  @op_base.promote_int_input
3611
3752
  def _aten_i0(self):
@@ -3637,6 +3778,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
3637
3778
 
3638
3779
  @jnp.vectorize
3639
3780
  def vectorized(x, n_i):
3781
+
3640
3782
  def negative_n(x):
3641
3783
  return jnp.zeros_like(x)
3642
3784
 
@@ -3650,6 +3792,7 @@ def _aten_special_laguerre_polynomial_l(self, n):
3650
3792
  return jnp.ones_like(x)
3651
3793
 
3652
3794
  def default(x):
3795
+
3653
3796
  def f(k, carry):
3654
3797
  p, q = carry
3655
3798
  return (q, ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1))
@@ -3658,9 +3801,9 @@ def _aten_special_laguerre_polynomial_l(self, n):
3658
3801
  return q
3659
3802
 
3660
3803
  return jnp.piecewise(
3661
- x, [n_i == 1, n_i == 0, jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], [
3662
- one_n, zero_n, zero_abs, negative_n, default]
3663
- )
3804
+ x, [n_i == 1, n_i == 0,
3805
+ jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0],
3806
+ [one_n, zero_n, zero_abs, negative_n, default])
3664
3807
 
3665
3808
  return vectorized(self, n.astype(jnp.int64))
3666
3809
 
@@ -3760,125 +3903,124 @@ def _aten_special_modified_bessel_i0(self):
3760
3903
  return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x)
3761
3904
 
3762
3905
  self = jnp.abs(self)
3763
- return jnp.piecewise(
3764
- self, [self <= 8], [small, default]
3765
- )
3906
+ return jnp.piecewise(self, [self <= 8], [small, default])
3907
+
3766
3908
 
3767
3909
  @op(torch.ops.aten.special_modified_bessel_i1)
3768
3910
  @op_base.promote_int_input
3769
3911
  def _aten_special_modified_bessel_i1(self):
3770
- # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364
3771
-
3772
- def small(x):
3773
- A = jnp.array(
3774
- [
3775
- 2.77791411276104639959e-18,
3776
- -2.11142121435816608115e-17,
3777
- 1.55363195773620046921e-16,
3778
- -1.10559694773538630805e-15,
3779
- 7.60068429473540693410e-15,
3780
- -5.04218550472791168711e-14,
3781
- 3.22379336594557470981e-13,
3782
- -1.98397439776494371520e-12,
3783
- 1.17361862988909016308e-11,
3784
- -6.66348972350202774223e-11,
3785
- 3.62559028155211703701e-10,
3786
- -1.88724975172282928790e-09,
3787
- 9.38153738649577178388e-09,
3788
- -4.44505912879632808065e-08,
3789
- 2.00329475355213526229e-07,
3790
- -8.56872026469545474066e-07,
3791
- 3.47025130813767847674e-06,
3792
- -1.32731636560394358279e-05,
3793
- 4.78156510755005422638e-05,
3794
- -1.61760815825896745588e-04,
3795
- 5.12285956168575772895e-04,
3796
- -1.51357245063125314899e-03,
3797
- 4.15642294431288815669e-03,
3798
- -1.05640848946261981558e-02,
3799
- 2.47264490306265168283e-02,
3800
- -5.29459812080949914269e-02,
3801
- 1.02643658689847095384e-01,
3802
- -1.76416518357834055153e-01,
3803
- 2.52587186443633654823e-01,
3804
- ],
3805
- dtype=self.dtype,
3806
- )
3807
-
3808
- def f(carry, val):
3809
- p, q, a = carry
3810
- p, q = q, a
3811
- return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
3812
-
3813
- (p, _, a), _ = jax.lax.scan(
3814
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3815
-
3816
- return jax.lax.cond(
3817
- x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))
3818
- )
3912
+ # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364
3819
3913
 
3820
- def default(x):
3821
- B = jnp.array(
3822
- [
3823
- 7.51729631084210481353e-18,
3824
- 4.41434832307170791151e-18,
3825
- -4.65030536848935832153e-17,
3826
- -3.20952592199342395980e-17,
3827
- 2.96262899764595013876e-16,
3828
- 3.30820231092092828324e-16,
3829
- -1.88035477551078244854e-15,
3830
- -3.81440307243700780478e-15,
3831
- 1.04202769841288027642e-14,
3832
- 4.27244001671195135429e-14,
3833
- -2.10154184277266431302e-14,
3834
- -4.08355111109219731823e-13,
3835
- -7.19855177624590851209e-13,
3836
- 2.03562854414708950722e-12,
3837
- 1.41258074366137813316e-11,
3838
- 3.25260358301548823856e-11,
3839
- -1.89749581235054123450e-11,
3840
- -5.58974346219658380687e-10,
3841
- -3.83538038596423702205e-09,
3842
- -2.63146884688951950684e-08,
3843
- -2.51223623787020892529e-07,
3844
- -3.88256480887769039346e-06,
3845
- -1.10588938762623716291e-04,
3846
- -9.76109749136146840777e-03,
3847
- 7.78576235018280120474e-01,
3848
- ],
3849
- dtype=self.dtype,
3850
- )
3851
-
3852
- def f(carry, val):
3853
- p, q, b = carry
3854
- p, q = q, b
3855
- return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
3856
-
3857
- (p, _, b), _ = jax.lax.scan(
3858
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
3859
-
3860
- return jax.lax.cond(
3861
- x < 0, lambda: -(jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))
3862
- )
3914
+ def small(x):
3915
+ A = jnp.array(
3916
+ [
3917
+ 2.77791411276104639959e-18,
3918
+ -2.11142121435816608115e-17,
3919
+ 1.55363195773620046921e-16,
3920
+ -1.10559694773538630805e-15,
3921
+ 7.60068429473540693410e-15,
3922
+ -5.04218550472791168711e-14,
3923
+ 3.22379336594557470981e-13,
3924
+ -1.98397439776494371520e-12,
3925
+ 1.17361862988909016308e-11,
3926
+ -6.66348972350202774223e-11,
3927
+ 3.62559028155211703701e-10,
3928
+ -1.88724975172282928790e-09,
3929
+ 9.38153738649577178388e-09,
3930
+ -4.44505912879632808065e-08,
3931
+ 2.00329475355213526229e-07,
3932
+ -8.56872026469545474066e-07,
3933
+ 3.47025130813767847674e-06,
3934
+ -1.32731636560394358279e-05,
3935
+ 4.78156510755005422638e-05,
3936
+ -1.61760815825896745588e-04,
3937
+ 5.12285956168575772895e-04,
3938
+ -1.51357245063125314899e-03,
3939
+ 4.15642294431288815669e-03,
3940
+ -1.05640848946261981558e-02,
3941
+ 2.47264490306265168283e-02,
3942
+ -5.29459812080949914269e-02,
3943
+ 1.02643658689847095384e-01,
3944
+ -1.76416518357834055153e-01,
3945
+ 2.52587186443633654823e-01,
3946
+ ],
3947
+ dtype=self.dtype,
3948
+ )
3863
3949
 
3864
- return jnp.piecewise(
3865
- self, [self <= 8], [small, default]
3950
+ def f(carry, val):
3951
+ p, q, a = carry
3952
+ p, q = q, a
3953
+ return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None
3954
+
3955
+ (p, _, a), _ = jax.lax.scan(
3956
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3957
+
3958
+ return jax.lax.cond(
3959
+ x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))),
3960
+ lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x)))
3961
+
3962
+ def default(x):
3963
+ B = jnp.array(
3964
+ [
3965
+ 7.51729631084210481353e-18,
3966
+ 4.41434832307170791151e-18,
3967
+ -4.65030536848935832153e-17,
3968
+ -3.20952592199342395980e-17,
3969
+ 2.96262899764595013876e-16,
3970
+ 3.30820231092092828324e-16,
3971
+ -1.88035477551078244854e-15,
3972
+ -3.81440307243700780478e-15,
3973
+ 1.04202769841288027642e-14,
3974
+ 4.27244001671195135429e-14,
3975
+ -2.10154184277266431302e-14,
3976
+ -4.08355111109219731823e-13,
3977
+ -7.19855177624590851209e-13,
3978
+ 2.03562854414708950722e-12,
3979
+ 1.41258074366137813316e-11,
3980
+ 3.25260358301548823856e-11,
3981
+ -1.89749581235054123450e-11,
3982
+ -5.58974346219658380687e-10,
3983
+ -3.83538038596423702205e-09,
3984
+ -2.63146884688951950684e-08,
3985
+ -2.51223623787020892529e-07,
3986
+ -3.88256480887769039346e-06,
3987
+ -1.10588938762623716291e-04,
3988
+ -9.76109749136146840777e-03,
3989
+ 7.78576235018280120474e-01,
3990
+ ],
3991
+ dtype=self.dtype,
3866
3992
  )
3867
3993
 
3994
+ def f(carry, val):
3995
+ p, q, b = carry
3996
+ p, q = q, b
3997
+ return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None
3998
+
3999
+ (p, _, b), _ = jax.lax.scan(
4000
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
4001
+
4002
+ return jax.lax.cond(
4003
+ x < 0, lambda: -(jnp.exp(jnp.abs(x)) *
4004
+ (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))),
4005
+ lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)))
4006
+
4007
+ return jnp.piecewise(self, [self <= 8], [small, default])
4008
+
4009
+
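The modified-Bessel wrappers above and below lean on jnp.piecewise, where later conditions take precedence over earlier ones and the extra trailing function is the fallback; that ordering is why the k0/k1 variants list [x <= 2, x < 0, x == 0] with the zero case last. A small demonstration of the precedence rule, with purely illustrative branch functions:

import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0, 3.0])
y = jnp.piecewise(
    x,
    [x <= 2, x < 0, x == 0],
    [lambda v: v * 0 + 1,    # plays the role of "small"
     lambda v: v * 0 - 1,    # "negative"
     lambda v: v * 0,        # "zero"
     lambda v: v * 0 + 9])   # fallback, i.e. "default"
# y == [-1., 0., 1., 9.]: the negative and zero branches win where they overlap
# with x <= 2, and only x == 3.0 falls through to the fallback.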
3868
4010
  @op(torch.ops.aten.special_modified_bessel_k0)
3869
4011
  @op_base.promote_int_input
3870
4012
  def _aten_special_modified_bessel_k0(self):
3871
- # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441
4013
+ # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441
3872
4014
 
3873
- def zero(x):
3874
- return jnp.array(jnp.inf, x.dtype)
4015
+ def zero(x):
4016
+ return jnp.array(jnp.inf, x.dtype)
3875
4017
 
3876
- def negative(x):
3877
- return jnp.array(jnp.nan, x.dtype)
4018
+ def negative(x):
4019
+ return jnp.array(jnp.nan, x.dtype)
3878
4020
 
3879
- def small(x):
3880
- A = jnp.array(
3881
- [
4021
+ def small(x):
4022
+ A = jnp.array(
4023
+ [
3882
4024
  1.37446543561352307156e-16,
3883
4025
  4.25981614279661018399e-14,
3884
4026
  1.03496952576338420167e-11,
@@ -3889,23 +4031,24 @@ def _aten_special_modified_bessel_k0(self):
3889
4031
  3.59799365153615016266e-02,
3890
4032
  3.44289899924628486886e-01,
3891
4033
  -5.35327393233902768720e-01,
3892
- ],
3893
- dtype=self.dtype,
3894
- )
4034
+ ],
4035
+ dtype=self.dtype,
4036
+ )
3895
4037
 
3896
- def f(carry, val):
3897
- p, q, a = carry
3898
- p, q = q, a
3899
- return (p, q, (x * x - 2.0) * q - p + val), None
4038
+ def f(carry, val):
4039
+ p, q, a = carry
4040
+ p, q = q, a
4041
+ return (p, q, (x * x - 2.0) * q - p + val), None
4042
+
4043
+ (p, _, a), _ = jax.lax.scan(
4044
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3900
4045
 
3901
- (p, _, a), _ = jax.lax.scan(
3902
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3903
-
3904
- return 0.5 * (a - p) - jnp.log(0.5 * x) * _aten_special_modified_bessel_i0(x)
4046
+ return 0.5 * (a - p) - jnp.log(
4047
+ 0.5 * x) * _aten_special_modified_bessel_i0(x)
3905
4048
 
3906
- def default(x):
3907
- B = jnp.array(
3908
- [
4049
+ def default(x):
4050
+ B = jnp.array(
4051
+ [
3909
4052
  5.30043377268626276149e-18,
3910
4053
  -1.64758043015242134646e-17,
3911
4054
  5.21039150503902756861e-17,
@@ -3931,38 +4074,38 @@ def _aten_special_modified_bessel_k0(self):
3931
4074
  1.56988388573005337491e-03,
3932
4075
  -3.14481013119645005427e-02,
3933
4076
  2.44030308206595545468e+00,
3934
- ],
3935
- dtype=self.dtype,
3936
- )
4077
+ ],
4078
+ dtype=self.dtype,
4079
+ )
4080
+
4081
+ def f(carry, val):
4082
+ p, q, b = carry
4083
+ p, q = q, b
4084
+ return (p, q, (8.0 / x - 2.0) * q - p + val), None
3937
4085
 
3938
- def f(carry, val):
3939
- p, q, b = carry
3940
- p, q = q, b
3941
- return (p, q, (8.0 / x - 2.0) * q - p + val), None
4086
+ (p, _, b), _ = jax.lax.scan(
4087
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
3942
4088
 
3943
- (p, _, b), _ = jax.lax.scan(
3944
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
3945
-
3946
- return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4089
+ return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4090
+
4091
+ return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
4092
+ [small, negative, zero, default])
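Note on the jnp.piecewise call above: conditions are evaluated left to right and the last condition that holds wins, while the extra trailing function acts as the default branch. A small illustrative sketch of that behaviour (not part of the op):

import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0, 3.0])
out = jnp.piecewise(
    x,
    [x <= 2, x < 0, x == 0],
    [lambda v: v,                # "small" branch
     lambda v: v * 0 + jnp.nan,  # "negative" overrides "small"
     lambda v: v * 0 + jnp.inf,  # "zero" overrides "small"
     lambda v: v * 10])          # default branch
# out is [nan, inf, 1.0, 30.0]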
3947
4093
 
3948
- return jnp.piecewise(
3949
- self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
3950
- )
3951
4094
 
3952
4095
  @op(torch.ops.aten.special_modified_bessel_k1)
3953
4096
  @op_base.promote_int_input
3954
4097
  def _aten_special_modified_bessel_k1(self):
3955
- # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519
4098
+ # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519
3956
4099
 
3957
- def zero(x):
3958
- return jnp.array(jnp.inf, x.dtype)
4100
+ def zero(x):
4101
+ return jnp.array(jnp.inf, x.dtype)
3959
4102
 
3960
- def negative(x):
3961
- return jnp.array(jnp.nan, x.dtype)
4103
+ def negative(x):
4104
+ return jnp.array(jnp.nan, x.dtype)
3962
4105
 
3963
- def small(x):
3964
- A = jnp.array(
3965
- [
4106
+ def small(x):
4107
+ A = jnp.array(
4108
+ [
3966
4109
  -7.02386347938628759343e-18,
3967
4110
  -2.42744985051936593393e-15,
3968
4111
  -6.66690169419932900609e-13,
@@ -3974,24 +4117,25 @@ def _aten_special_modified_bessel_k1(self):
3974
4117
  -1.22611180822657148235e-01,
3975
4118
  -3.53155960776544875667e-01,
3976
4119
  1.52530022733894777053e+00,
3977
- ],
3978
- dtype=self.dtype,
3979
- )
4120
+ ],
4121
+ dtype=self.dtype,
4122
+ )
3980
4123
 
3981
- def f(carry, val):
3982
- p, q, a = carry
3983
- p, q = q, a
3984
- a = (x * x - 2.0) * q - p + val
3985
- return (p, q, a), None
4124
+ def f(carry, val):
4125
+ p, q, a = carry
4126
+ p, q = q, a
4127
+ a = (x * x - 2.0) * q - p + val
4128
+ return (p, q, a), None
4129
+
4130
+ (p, _, a), _ = jax.lax.scan(
4131
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3986
4132
 
3987
- (p, _, a), _ = jax.lax.scan(
3988
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A)
3989
-
3990
- return jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
4133
+ return jnp.log(
4134
+ 0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x
3991
4135
 
3992
- def default(x):
3993
- B = jnp.array(
3994
- [
4136
+ def default(x):
4137
+ B = jnp.array(
4138
+ [
3995
4139
  -5.75674448366501715755e-18,
3996
4140
  1.79405087314755922667e-17,
3997
4141
  -5.68946255844285935196e-17,
@@ -4017,24 +4161,24 @@ def _aten_special_modified_bessel_k1(self):
4017
4161
  -2.85781685962277938680e-03,
4018
4162
  1.03923736576817238437e-01,
4019
4163
  2.72062619048444266945e+00,
4020
- ],
4021
- dtype=self.dtype,
4022
- )
4164
+ ],
4165
+ dtype=self.dtype,
4166
+ )
4167
+
4168
+ def f(carry, val):
4169
+ p, q, b = carry
4170
+ p, q = q, b
4171
+ b = (8.0 / x - 2.0) * q - p + val
4172
+ return (p, q, b), None
4173
+
4174
+ (p, _, b), _ = jax.lax.scan(
4175
+ f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
4023
4176
 
4024
- def f(carry, val):
4025
- p, q, b = carry
4026
- p, q = q, b
4027
- b = (8.0 / x - 2.0) * q - p + val
4028
- return (p, q, b), None
4177
+ return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4029
4178
 
4030
- (p, _, b), _ = jax.lax.scan(
4031
- f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B)
4032
-
4033
- return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x)
4179
+ return jnp.piecewise(self, [self <= 2, self < 0, self == 0],
4180
+ [small, negative, zero, default])
4034
4181
 
4035
- return jnp.piecewise(
4036
- self, [self <= 2, self < 0, self == 0], [small, negative, zero, default]
4037
- )
4038
4182
 
4039
4183
  @op(torch.ops.aten.polygamma)
4040
4184
  def _aten_polygamma(x, n):
@@ -4042,10 +4186,12 @@ def _aten_polygamma(x, n):
4042
4186
  n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
4043
4187
  return jax.lax.polygamma(jnp.float32(x), n)
4044
4188
 
4189
+
4045
4190
  @op(torch.ops.aten.special_ndtri)
4046
4191
  @op_base.promote_int_input
4047
4192
  def _aten_special_ndtri(self):
4048
- return jax.scipy.special.ndtri(self)
4193
+ return jax.scipy.special.ndtri(self)
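For reference, jax.scipy.special.ndtri is the inverse of the standard normal CDF, which is what torch.special.ndtri computes; a quick illustrative check:

import jax.numpy as jnp
from jax.scipy.special import ndtri

print(ndtri(jnp.array([0.5, 0.975])))  # approximately [0.0, 1.96]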
4194
+
4049
4195
 
4050
4196
  @op(torch.ops.aten.special_bessel_j0)
4051
4197
  @op_base.promote_int_input
@@ -4057,112 +4203,104 @@ def _aten_special_bessel_j0(self):
4057
4203
 
4058
4204
  def small(x):
4059
4205
  RP = jnp.array(
4060
- [
4061
- -4.79443220978201773821e09,
4062
- 1.95617491946556577543e12,
4063
- -2.49248344360967716204e14,
4064
- 9.70862251047306323952e15,
4065
- ],
4066
- dtype=self.dtype,
4206
+ [
4207
+ -4.79443220978201773821e09,
4208
+ 1.95617491946556577543e12,
4209
+ -2.49248344360967716204e14,
4210
+ 9.70862251047306323952e15,
4211
+ ],
4212
+ dtype=self.dtype,
4067
4213
  )
4068
4214
  RQ = jnp.array(
4069
- [
4070
- 4.99563147152651017219e02,
4071
- 1.73785401676374683123e05,
4072
- 4.84409658339962045305e07,
4073
- 1.11855537045356834862e10,
4074
- 2.11277520115489217587e12,
4075
- 3.10518229857422583814e14,
4076
- 3.18121955943204943306e16,
4077
- 1.71086294081043136091e18,
4078
- ],
4079
- dtype=self.dtype,
4215
+ [
4216
+ 4.99563147152651017219e02,
4217
+ 1.73785401676374683123e05,
4218
+ 4.84409658339962045305e07,
4219
+ 1.11855537045356834862e10,
4220
+ 2.11277520115489217587e12,
4221
+ 3.10518229857422583814e14,
4222
+ 3.18121955943204943306e16,
4223
+ 1.71086294081043136091e18,
4224
+ ],
4225
+ dtype=self.dtype,
4080
4226
  )
4081
4227
 
4082
4228
  rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
4083
4229
  rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
4084
4230
 
4085
- return (
4086
- (x * x - 5.78318596294678452118e00)
4087
- * (x * x - 3.04712623436620863991e01)
4088
- * rp
4089
- / rq
4090
- )
4231
+ return ((x * x - 5.78318596294678452118e00) *
4232
+ (x * x - 3.04712623436620863991e01) * rp / rq)
4091
4233
 
4092
4234
  def default(x):
4093
4235
  PP = jnp.array(
4094
- [
4095
- 7.96936729297347051624e-04,
4096
- 8.28352392107440799803e-02,
4097
- 1.23953371646414299388e00,
4098
- 5.44725003058768775090e00,
4099
- 8.74716500199817011941e00,
4100
- 5.30324038235394892183e00,
4101
- 9.99999999999999997821e-01,
4102
- ],
4103
- dtype=self.dtype,
4236
+ [
4237
+ 7.96936729297347051624e-04,
4238
+ 8.28352392107440799803e-02,
4239
+ 1.23953371646414299388e00,
4240
+ 5.44725003058768775090e00,
4241
+ 8.74716500199817011941e00,
4242
+ 5.30324038235394892183e00,
4243
+ 9.99999999999999997821e-01,
4244
+ ],
4245
+ dtype=self.dtype,
4104
4246
  )
4105
4247
  PQ = jnp.array(
4106
- [
4107
- 9.24408810558863637013e-04,
4108
- 8.56288474354474431428e-02,
4109
- 1.25352743901058953537e00,
4110
- 5.47097740330417105182e00,
4111
- 8.76190883237069594232e00,
4112
- 5.30605288235394617618e00,
4113
- 1.00000000000000000218e00,
4114
- ],
4115
- dtype=self.dtype,
4248
+ [
4249
+ 9.24408810558863637013e-04,
4250
+ 8.56288474354474431428e-02,
4251
+ 1.25352743901058953537e00,
4252
+ 5.47097740330417105182e00,
4253
+ 8.76190883237069594232e00,
4254
+ 5.30605288235394617618e00,
4255
+ 1.00000000000000000218e00,
4256
+ ],
4257
+ dtype=self.dtype,
4116
4258
  )
4117
4259
  QP = jnp.array(
4118
- [
4119
- -1.13663838898469149931e-02,
4120
- -1.28252718670509318512e00,
4121
- -1.95539544257735972385e01,
4122
- -9.32060152123768231369e01,
4123
- -1.77681167980488050595e02,
4124
- -1.47077505154951170175e02,
4125
- -5.14105326766599330220e01,
4126
- -6.05014350600728481186e00,
4127
- ],
4128
- dtype=self.dtype,
4260
+ [
4261
+ -1.13663838898469149931e-02,
4262
+ -1.28252718670509318512e00,
4263
+ -1.95539544257735972385e01,
4264
+ -9.32060152123768231369e01,
4265
+ -1.77681167980488050595e02,
4266
+ -1.47077505154951170175e02,
4267
+ -5.14105326766599330220e01,
4268
+ -6.05014350600728481186e00,
4269
+ ],
4270
+ dtype=self.dtype,
4129
4271
  )
4130
4272
  QQ = jnp.array(
4131
- [
4132
- 6.43178256118178023184e01,
4133
- 8.56430025976980587198e02,
4134
- 3.88240183605401609683e03,
4135
- 7.24046774195652478189e03,
4136
- 5.93072701187316984827e03,
4137
- 2.06209331660327847417e03,
4138
- 2.42005740240291393179e02,
4139
- ],
4140
- dtype=self.dtype,
4273
+ [
4274
+ 6.43178256118178023184e01,
4275
+ 8.56430025976980587198e02,
4276
+ 3.88240183605401609683e03,
4277
+ 7.24046774195652478189e03,
4278
+ 5.93072701187316984827e03,
4279
+ 2.06209331660327847417e03,
4280
+ 2.42005740240291393179e02,
4281
+ ],
4282
+ dtype=self.dtype,
4141
4283
  )
4142
4284
 
4143
- pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4144
- pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4145
- qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4146
- qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4147
-
4148
- return (
4149
- (
4150
- pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721)
4151
- - 5.0
4152
- / x
4153
- * (qp / qq)
4154
- * jnp.sin(x - 0.785398163397448309615660845819875721)
4155
- )
4156
- * 0.797884560802865355879892119868763737
4157
- / jnp.sqrt(x)
4158
- )
4285
+ pp = op_base.foreach_loop(
4286
+ PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4287
+ pq = op_base.foreach_loop(
4288
+ PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4289
+ qp = op_base.foreach_loop(
4290
+ QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4291
+ qq = op_base.foreach_loop(
4292
+ QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4293
+
4294
+ return ((pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) -
4295
+ 5.0 / x *
4296
+ (qp / qq) * jnp.sin(x - 0.785398163397448309615660845819875721)) *
4297
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4159
4298
 
4160
4299
  self = jnp.abs(self)
4161
4300
  # Last True condition in `piecewise` takes priority, but last function is
4162
4301
  # default. See https://github.com/numpy/numpy/issues/16475
4163
- return jnp.piecewise(
4164
- self, [self <= 5.0, self < 0.00001], [small, very_small, default]
4165
- )
4302
+ return jnp.piecewise(self, [self <= 5.0, self < 0.00001],
4303
+ [small, very_small, default])
4166
4304
 
4167
4305
 
4168
4306
  @op(torch.ops.aten.special_bessel_j1)
@@ -4172,114 +4310,106 @@ def _aten_special_bessel_j1(self):
4172
4310
 
4173
4311
  def small(x):
4174
4312
  RP = jnp.array(
4175
- [
4176
- -8.99971225705559398224e08,
4177
- 4.52228297998194034323e11,
4178
- -7.27494245221818276015e13,
4179
- 3.68295732863852883286e15,
4180
- ],
4181
- dtype=self.dtype,
4313
+ [
4314
+ -8.99971225705559398224e08,
4315
+ 4.52228297998194034323e11,
4316
+ -7.27494245221818276015e13,
4317
+ 3.68295732863852883286e15,
4318
+ ],
4319
+ dtype=self.dtype,
4182
4320
  )
4183
4321
  RQ = jnp.array(
4184
- [
4185
- 6.20836478118054335476e02,
4186
- 2.56987256757748830383e05,
4187
- 8.35146791431949253037e07,
4188
- 2.21511595479792499675e10,
4189
- 4.74914122079991414898e12,
4190
- 7.84369607876235854894e14,
4191
- 8.95222336184627338078e16,
4192
- 5.32278620332680085395e18,
4193
- ],
4194
- dtype=self.dtype,
4322
+ [
4323
+ 6.20836478118054335476e02,
4324
+ 2.56987256757748830383e05,
4325
+ 8.35146791431949253037e07,
4326
+ 2.21511595479792499675e10,
4327
+ 4.74914122079991414898e12,
4328
+ 7.84369607876235854894e14,
4329
+ 8.95222336184627338078e16,
4330
+ 5.32278620332680085395e18,
4331
+ ],
4332
+ dtype=self.dtype,
4195
4333
  )
4196
4334
 
4197
4335
  rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i)
4198
4336
  rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i)
4199
4337
 
4200
- return (
4201
- rp
4202
- / rq
4203
- * x
4204
- * (x * x - 1.46819706421238932572e01)
4205
- * (x * x - 4.92184563216946036703e01)
4206
- )
4338
+ return (rp / rq * x * (x * x - 1.46819706421238932572e01) *
4339
+ (x * x - 4.92184563216946036703e01))
4207
4340
 
4208
4341
  def default(x):
4209
4342
  PP = jnp.array(
4210
- [
4211
- 7.62125616208173112003e-04,
4212
- 7.31397056940917570436e-02,
4213
- 1.12719608129684925192e00,
4214
- 5.11207951146807644818e00,
4215
- 8.42404590141772420927e00,
4216
- 5.21451598682361504063e00,
4217
- 1.00000000000000000254e00,
4218
- ],
4219
- dtype=self.dtype,
4343
+ [
4344
+ 7.62125616208173112003e-04,
4345
+ 7.31397056940917570436e-02,
4346
+ 1.12719608129684925192e00,
4347
+ 5.11207951146807644818e00,
4348
+ 8.42404590141772420927e00,
4349
+ 5.21451598682361504063e00,
4350
+ 1.00000000000000000254e00,
4351
+ ],
4352
+ dtype=self.dtype,
4220
4353
  )
4221
4354
  PQ = jnp.array(
4222
- [
4223
- 5.71323128072548699714e-04,
4224
- 6.88455908754495404082e-02,
4225
- 1.10514232634061696926e00,
4226
- 5.07386386128601488557e00,
4227
- 8.39985554327604159757e00,
4228
- 5.20982848682361821619e00,
4229
- 9.99999999999999997461e-01,
4230
- ],
4231
- dtype=self.dtype,
4355
+ [
4356
+ 5.71323128072548699714e-04,
4357
+ 6.88455908754495404082e-02,
4358
+ 1.10514232634061696926e00,
4359
+ 5.07386386128601488557e00,
4360
+ 8.39985554327604159757e00,
4361
+ 5.20982848682361821619e00,
4362
+ 9.99999999999999997461e-01,
4363
+ ],
4364
+ dtype=self.dtype,
4232
4365
  )
4233
4366
  QP = jnp.array(
4234
- [
4235
- 5.10862594750176621635e-02,
4236
- 4.98213872951233449420e00,
4237
- 7.58238284132545283818e01,
4238
- 3.66779609360150777800e02,
4239
- 7.10856304998926107277e02,
4240
- 5.97489612400613639965e02,
4241
- 2.11688757100572135698e02,
4242
- 2.52070205858023719784e01,
4243
- ],
4244
- dtype=self.dtype,
4367
+ [
4368
+ 5.10862594750176621635e-02,
4369
+ 4.98213872951233449420e00,
4370
+ 7.58238284132545283818e01,
4371
+ 3.66779609360150777800e02,
4372
+ 7.10856304998926107277e02,
4373
+ 5.97489612400613639965e02,
4374
+ 2.11688757100572135698e02,
4375
+ 2.52070205858023719784e01,
4376
+ ],
4377
+ dtype=self.dtype,
4245
4378
  )
4246
4379
  QQ = jnp.array(
4247
- [
4248
- 7.42373277035675149943e01,
4249
- 1.05644886038262816351e03,
4250
- 4.98641058337653607651e03,
4251
- 9.56231892404756170795e03,
4252
- 7.99704160447350683650e03,
4253
- 2.82619278517639096600e03,
4254
- 3.36093607810698293419e02,
4255
- ],
4256
- dtype=self.dtype,
4380
+ [
4381
+ 7.42373277035675149943e01,
4382
+ 1.05644886038262816351e03,
4383
+ 4.98641058337653607651e03,
4384
+ 9.56231892404756170795e03,
4385
+ 7.99704160447350683650e03,
4386
+ 2.82619278517639096600e03,
4387
+ 3.36093607810698293419e02,
4388
+ ],
4389
+ dtype=self.dtype,
4257
4390
  )
4258
4391
 
4259
- pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4260
- pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4261
- qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4262
- qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4263
-
4264
- return (
4265
- (
4266
- pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163)
4267
- - 5.0
4268
- / x
4269
- * (qp / qq)
4270
- * jnp.sin(x - 2.356194490192344928846982537459627163)
4271
- )
4272
- * 0.797884560802865355879892119868763737
4273
- / jnp.sqrt(x)
4274
- )
4392
+ pp = op_base.foreach_loop(
4393
+ PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i)
4394
+ pq = op_base.foreach_loop(
4395
+ PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i)
4396
+ qp = op_base.foreach_loop(
4397
+ QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i)
4398
+ qq = op_base.foreach_loop(
4399
+ QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i)
4400
+
4401
+ return ((pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) -
4402
+ 5.0 / x *
4403
+ (qp / qq) * jnp.sin(x - 2.356194490192344928846982537459627163)) *
4404
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4275
4405
 
4276
4406
  # If x < 0, bessel_j1(x) = -bessel_j1(-x)
4277
4407
  sign = jnp.sign(self)
4278
4408
  self = jnp.abs(self)
4279
4409
  return sign * jnp.piecewise(
4280
- self,
4281
- [self <= 5.0],
4282
- [small, default],
4410
+ self,
4411
+ [self <= 5.0],
4412
+ [small, default],
4283
4413
  )
4284
4414
 
4285
4415
 
@@ -4296,85 +4426,86 @@ def _aten_special_bessel_y0(self):
4296
4426
 
4297
4427
  def small(x):
4298
4428
  YP = jnp.array(
4299
- [
4300
- 1.55924367855235737965e04,
4301
- -1.46639295903971606143e07,
4302
- 5.43526477051876500413e09,
4303
- -9.82136065717911466409e11,
4304
- 8.75906394395366999549e13,
4305
- -3.46628303384729719441e15,
4306
- 4.42733268572569800351e16,
4307
- -1.84950800436986690637e16,
4308
- ],
4309
- dtype=self.dtype,
4429
+ [
4430
+ 1.55924367855235737965e04,
4431
+ -1.46639295903971606143e07,
4432
+ 5.43526477051876500413e09,
4433
+ -9.82136065717911466409e11,
4434
+ 8.75906394395366999549e13,
4435
+ -3.46628303384729719441e15,
4436
+ 4.42733268572569800351e16,
4437
+ -1.84950800436986690637e16,
4438
+ ],
4439
+ dtype=self.dtype,
4310
4440
  )
4311
4441
  YQ = jnp.array(
4312
- [
4313
- 1.04128353664259848412e03,
4314
- 6.26107330137134956842e05,
4315
- 2.68919633393814121987e08,
4316
- 8.64002487103935000337e10,
4317
- 2.02979612750105546709e13,
4318
- 3.17157752842975028269e15,
4319
- 2.50596256172653059228e17,
4320
- ],
4321
- dtype=self.dtype,
4442
+ [
4443
+ 1.04128353664259848412e03,
4444
+ 6.26107330137134956842e05,
4445
+ 2.68919633393814121987e08,
4446
+ 8.64002487103935000337e10,
4447
+ 2.02979612750105546709e13,
4448
+ 3.17157752842975028269e15,
4449
+ 2.50596256172653059228e17,
4450
+ ],
4451
+ dtype=self.dtype,
4322
4452
  )
4323
4453
 
4324
4454
  yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
4325
4455
  yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
4326
4456
 
4327
- return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) * _aten_special_bessel_j0(x))
4457
+ return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) *
4458
+ _aten_special_bessel_j0(x))
4328
4459
 
4329
4460
  def default(x):
4330
4461
  PP = jnp.array(
4331
- [
4332
- 7.96936729297347051624e-04,
4333
- 8.28352392107440799803e-02,
4334
- 1.23953371646414299388e00,
4335
- 5.44725003058768775090e00,
4336
- 8.74716500199817011941e00,
4337
- 5.30324038235394892183e00,
4338
- 9.99999999999999997821e-01,
4339
- ],
4340
- dtype=self.dtype,
4462
+ [
4463
+ 7.96936729297347051624e-04,
4464
+ 8.28352392107440799803e-02,
4465
+ 1.23953371646414299388e00,
4466
+ 5.44725003058768775090e00,
4467
+ 8.74716500199817011941e00,
4468
+ 5.30324038235394892183e00,
4469
+ 9.99999999999999997821e-01,
4470
+ ],
4471
+ dtype=self.dtype,
4341
4472
  )
4342
4473
  PQ = jnp.array(
4343
- [
4344
- 9.24408810558863637013e-04,
4345
- 8.56288474354474431428e-02,
4346
- 1.25352743901058953537e00,
4347
- 5.47097740330417105182e00,
4348
- 8.76190883237069594232e00,
4349
- 5.30605288235394617618e00,
4350
- 1.00000000000000000218e00,
4351
- ],
4352
- dtype=self.dtype,
4474
+ [
4475
+ 9.24408810558863637013e-04,
4476
+ 8.56288474354474431428e-02,
4477
+ 1.25352743901058953537e00,
4478
+ 5.47097740330417105182e00,
4479
+ 8.76190883237069594232e00,
4480
+ 5.30605288235394617618e00,
4481
+ 1.00000000000000000218e00,
4482
+ ],
4483
+ dtype=self.dtype,
4353
4484
  )
4354
4485
  QP = jnp.array(
4355
- [
4356
- -1.13663838898469149931e-02,
4357
- -1.28252718670509318512e00,
4358
- -1.95539544257735972385e01,
4359
- -9.32060152123768231369e01,
4360
- -1.77681167980488050595e02,
4361
- -1.47077505154951170175e02,
4362
- -5.14105326766599330220e01,
4363
- -6.05014350600728481186e00,
4364
- ],
4365
- dtype=self.dtype,
4486
+ [
4487
+ -1.13663838898469149931e-02,
4488
+ -1.28252718670509318512e00,
4489
+ -1.95539544257735972385e01,
4490
+ -9.32060152123768231369e01,
4491
+ -1.77681167980488050595e02,
4492
+ -1.47077505154951170175e02,
4493
+ -5.14105326766599330220e01,
4494
+ -6.05014350600728481186e00,
4495
+ ],
4496
+ dtype=self.dtype,
4366
4497
  )
4367
4498
  QQ = jnp.array(
4368
- [
4369
- 6.43178256118178023184e01,
4370
- 8.56430025976980587198e02,
4371
- 3.88240183605401609683e03,
4372
- 7.24046774195652478189e03,
4373
- 5.93072701187316984827e03,
4374
- 2.06209331660327847417e03,
4375
- 2.42005740240291393179e02,
4376
- ],
4377
- dtype=self.dtype,
4499
+ [
4500
+ 6.43178256118178023184e01,
4501
+ 8.56430025976980587198e02,
4502
+ 3.88240183605401609683e03,
4503
+ 7.24046774195652478189e03,
4504
+ 5.93072701187316984827e03,
4505
+ 2.06209331660327847417e03,
4506
+ 2.42005740240291393179e02,
4507
+ ],
4508
+ dtype=self.dtype,
4378
4509
  )
4379
4510
 
4380
4511
  factor = 25.0 / (x * x)
@@ -4383,22 +4514,15 @@ def _aten_special_bessel_y0(self):
4383
4514
  qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
4384
4515
  qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
4385
4516
 
4386
- return (
4387
- (
4388
- pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721)
4389
- + 5.0
4390
- / x
4391
- * (qp / qq)
4392
- * jnp.cos(x - 0.785398163397448309615660845819875721)
4393
- )
4394
- * 0.797884560802865355879892119868763737
4395
- / jnp.sqrt(x)
4396
- )
4517
+ return ((pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) +
4518
+ 5.0 / x *
4519
+ (qp / qq) * jnp.cos(x - 0.785398163397448309615660845819875721)) *
4520
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4397
4521
 
4398
4522
  return jnp.piecewise(
4399
- self,
4400
- [self <= 5.0, self < 0., self == 0.],
4401
- [small, negative, zero, default],
4523
+ self,
4524
+ [self <= 5.0, self < 0., self == 0.],
4525
+ [small, negative, zero, default],
4402
4526
  )
4403
4527
 
4404
4528
 
@@ -4415,90 +4539,86 @@ def _aten_special_bessel_y1(self):
4415
4539
 
4416
4540
  def small(x):
4417
4541
  YP = jnp.array(
4418
- [
4419
- 1.26320474790178026440e09,
4420
- -6.47355876379160291031e11,
4421
- 1.14509511541823727583e14,
4422
- -8.12770255501325109621e15,
4423
- 2.02439475713594898196e17,
4424
- -7.78877196265950026825e17,
4425
- ],
4426
- dtype=self.dtype,
4542
+ [
4543
+ 1.26320474790178026440e09,
4544
+ -6.47355876379160291031e11,
4545
+ 1.14509511541823727583e14,
4546
+ -8.12770255501325109621e15,
4547
+ 2.02439475713594898196e17,
4548
+ -7.78877196265950026825e17,
4549
+ ],
4550
+ dtype=self.dtype,
4427
4551
  )
4428
4552
  YQ = jnp.array(
4429
- [
4430
- 5.94301592346128195359e02,
4431
- 2.35564092943068577943e05,
4432
- 7.34811944459721705660e07,
4433
- 1.87601316108706159478e10,
4434
- 3.88231277496238566008e12,
4435
- 6.20557727146953693363e14,
4436
- 6.87141087355300489866e16,
4437
- 3.97270608116560655612e18,
4438
- ],
4439
- dtype=self.dtype,
4553
+ [
4554
+ 5.94301592346128195359e02,
4555
+ 2.35564092943068577943e05,
4556
+ 7.34811944459721705660e07,
4557
+ 1.87601316108706159478e10,
4558
+ 3.88231277496238566008e12,
4559
+ 6.20557727146953693363e14,
4560
+ 6.87141087355300489866e16,
4561
+ 3.97270608116560655612e18,
4562
+ ],
4563
+ dtype=self.dtype,
4440
4564
  )
4441
4565
 
4442
4566
  yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i)
4443
4567
  yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i)
4444
4568
 
4445
- return (
4446
- x * (yp / yq)
4447
- + (
4448
- 0.636619772367581343075535053490057448
4449
- * (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)
4450
- )
4451
- )
4569
+ return (x * (yp / yq) +
4570
+ (0.636619772367581343075535053490057448 *
4571
+ (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x)))
4452
4572
 
4453
4573
  def default(x):
4454
4574
  PP = jnp.array(
4455
- [
4456
- 7.62125616208173112003e-04,
4457
- 7.31397056940917570436e-02,
4458
- 1.12719608129684925192e00,
4459
- 5.11207951146807644818e00,
4460
- 8.42404590141772420927e00,
4461
- 5.21451598682361504063e00,
4462
- 1.00000000000000000254e00,
4463
- ],
4464
- dtype=self.dtype,
4575
+ [
4576
+ 7.62125616208173112003e-04,
4577
+ 7.31397056940917570436e-02,
4578
+ 1.12719608129684925192e00,
4579
+ 5.11207951146807644818e00,
4580
+ 8.42404590141772420927e00,
4581
+ 5.21451598682361504063e00,
4582
+ 1.00000000000000000254e00,
4583
+ ],
4584
+ dtype=self.dtype,
4465
4585
  )
4466
4586
  PQ = jnp.array(
4467
- [
4468
- 5.71323128072548699714e-04,
4469
- 6.88455908754495404082e-02,
4470
- 1.10514232634061696926e00,
4471
- 5.07386386128601488557e00,
4472
- 8.39985554327604159757e00,
4473
- 5.20982848682361821619e00,
4474
- 9.99999999999999997461e-01,
4475
- ],
4476
- dtype=self.dtype,
4587
+ [
4588
+ 5.71323128072548699714e-04,
4589
+ 6.88455908754495404082e-02,
4590
+ 1.10514232634061696926e00,
4591
+ 5.07386386128601488557e00,
4592
+ 8.39985554327604159757e00,
4593
+ 5.20982848682361821619e00,
4594
+ 9.99999999999999997461e-01,
4595
+ ],
4596
+ dtype=self.dtype,
4477
4597
  )
4478
4598
  QP = jnp.array(
4479
- [
4480
- 5.10862594750176621635e-02,
4481
- 4.98213872951233449420e00,
4482
- 7.58238284132545283818e01,
4483
- 3.66779609360150777800e02,
4484
- 7.10856304998926107277e02,
4485
- 5.97489612400613639965e02,
4486
- 2.11688757100572135698e02,
4487
- 2.52070205858023719784e01,
4488
- ],
4489
- dtype=self.dtype,
4599
+ [
4600
+ 5.10862594750176621635e-02,
4601
+ 4.98213872951233449420e00,
4602
+ 7.58238284132545283818e01,
4603
+ 3.66779609360150777800e02,
4604
+ 7.10856304998926107277e02,
4605
+ 5.97489612400613639965e02,
4606
+ 2.11688757100572135698e02,
4607
+ 2.52070205858023719784e01,
4608
+ ],
4609
+ dtype=self.dtype,
4490
4610
  )
4491
4611
  QQ = jnp.array(
4492
- [
4493
- 7.42373277035675149943e01,
4494
- 1.05644886038262816351e03,
4495
- 4.98641058337653607651e03,
4496
- 9.56231892404756170795e03,
4497
- 7.99704160447350683650e03,
4498
- 2.82619278517639096600e03,
4499
- 3.36093607810698293419e02,
4500
- ],
4501
- dtype=self.dtype,
4612
+ [
4613
+ 7.42373277035675149943e01,
4614
+ 1.05644886038262816351e03,
4615
+ 4.98641058337653607651e03,
4616
+ 9.56231892404756170795e03,
4617
+ 7.99704160447350683650e03,
4618
+ 2.82619278517639096600e03,
4619
+ 3.36093607810698293419e02,
4620
+ ],
4621
+ dtype=self.dtype,
4502
4622
  )
4503
4623
 
4504
4624
  factor = 25.0 / (x * x)
@@ -4507,22 +4627,15 @@ def _aten_special_bessel_y1(self):
4507
4627
  qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i)
4508
4628
  qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i)
4509
4629
 
4510
- return (
4511
- (
4512
- pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163)
4513
- + 5.0
4514
- / x
4515
- * (qp / qq)
4516
- * jnp.cos(x - 2.356194490192344928846982537459627163)
4517
- )
4518
- * 0.797884560802865355879892119868763737
4519
- / jnp.sqrt(x)
4520
- )
4630
+ return ((pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) +
4631
+ 5.0 / x *
4632
+ (qp / qq) * jnp.cos(x - 2.356194490192344928846982537459627163)) *
4633
+ 0.797884560802865355879892119868763737 / jnp.sqrt(x))
4521
4634
 
4522
4635
  return jnp.piecewise(
4523
- self,
4524
- [self <= 5.0, self < 0., self == 0.],
4525
- [small, negative, zero, default],
4636
+ self,
4637
+ [self <= 5.0, self < 0., self == 0.],
4638
+ [small, negative, zero, default],
4526
4639
  )
4527
4640
 
4528
4641
 
@@ -4533,11 +4646,13 @@ def _aten_special_chebyshev_polynomial_t(self, n):
4533
4646
 
4534
4647
  @jnp.vectorize
4535
4648
  def vectorized(x, n_i):
4649
+
4536
4650
  def negative_n(x):
4537
4651
  return jnp.zeros_like(x)
4538
4652
 
4539
4653
  def one_x(x):
4540
- return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x), -jnp.ones_like(x))
4654
+ return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x),
4655
+ -jnp.ones_like(x))
4541
4656
 
4542
4657
  def large_n_small_x(x):
4543
4658
  return jnp.cos(n_i * jnp.acos(x))
@@ -4549,24 +4664,18 @@ def _aten_special_chebyshev_polynomial_t(self, n):
4549
4664
  return x
4550
4665
 
4551
4666
  def default(x):
4667
+
4552
4668
  def f(_, carry):
4553
4669
  p, q = carry
4554
4670
  return (q, 2 * x * q - p)
4555
4671
 
4556
- _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x))
4672
+ _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x))
4557
4673
  return r
4558
4674
 
4559
- return jnp.piecewise(
4560
- x,
4561
- [
4562
- n_i == 1,
4563
- n_i == 0,
4564
- (n_i == 6) & (jnp.abs(x) < 1),
4565
- jnp.abs(x) == 1.,
4566
- n_i < 0
4567
- ],
4568
- [one_n, zero_n, large_n_small_x, one_x, negative_n, default]
4569
- )
4675
+ return jnp.piecewise(x, [
4676
+ n_i == 1, n_i == 0, (n_i == 6) & (jnp.abs(x) < 1),
4677
+ jnp.abs(x) == 1., n_i < 0
4678
+ ], [one_n, zero_n, large_n_small_x, one_x, negative_n, default])
4570
4679
 
4571
4680
  # Explicitly vectorize since we must vectorize over both self and n
4572
4681
  return vectorized(self, n.astype(jnp.int64))
@@ -4579,6 +4688,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4579
4688
 
4580
4689
  @jnp.vectorize
4581
4690
  def vectorized(x, n_i):
4691
+
4582
4692
  def negative_n(x):
4583
4693
  return jnp.zeros_like(x)
4584
4694
 
@@ -4588,9 +4698,9 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4588
4698
  def large_n_small_x(x):
4589
4699
  sin_acos_x = jnp.sin(jnp.acos(x))
4590
4700
  return jnp.where(
4591
- sin_acos_x != 0,
4592
- jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x,
4593
- (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x,
4701
+ sin_acos_x != 0,
4702
+ jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x,
4703
+ (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x,
4594
4704
  )
4595
4705
 
4596
4706
  def zero_n(x):
@@ -4600,6 +4710,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4600
4710
  return 2 * x
4601
4711
 
4602
4712
  def default(x):
4713
+
4603
4714
  def f(_, carry):
4604
4715
  p, q = carry
4605
4716
  return (q, 2 * x * q - p)
@@ -4608,15 +4719,15 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4608
4719
  return r
4609
4720
 
4610
4721
  return jnp.piecewise(
4611
- x,
4612
- [
4613
- n_i == 1,
4614
- n_i == 0,
4615
- (n_i > 8) & (jnp.abs(x) < 1),
4616
- jnp.abs(x) == 1.0,
4617
- n_i < 0,
4618
- ],
4619
- [one_n, zero_n, large_n_small_x, one_x, negative_n, default],
4722
+ x,
4723
+ [
4724
+ n_i == 1,
4725
+ n_i == 0,
4726
+ (n_i > 8) & (jnp.abs(x) < 1),
4727
+ jnp.abs(x) == 1.0,
4728
+ n_i < 0,
4729
+ ],
4730
+ [one_n, zero_n, large_n_small_x, one_x, negative_n, default],
4620
4731
  )
4621
4732
 
4622
4733
  return vectorized(self, n.astype(jnp.int64))
@@ -4627,6 +4738,7 @@ def _aten_special_chebyshev_polynomial_u(self, n):
4627
4738
  def _aten_special_erfcx(x):
4628
4739
  return jnp.exp(x * x) * jax.lax.erfc(x)
4629
4740
 
4741
+
4630
4742
  @op(torch.ops.aten.erfc)
4631
4743
  @op_base.promote_int_input
4632
4744
  def _aten_erfcx(x):
@@ -4640,6 +4752,7 @@ def _aten_special_hermite_polynomial_h(self, n):
4640
4752
 
4641
4753
  @jnp.vectorize
4642
4754
  def vectorized(x, n_i):
4755
+
4643
4756
  def negative_n(x):
4644
4757
  return jnp.zeros_like(x)
4645
4758
 
@@ -4650,6 +4763,7 @@ def _aten_special_hermite_polynomial_h(self, n):
4650
4763
  return 2 * x
4651
4764
 
4652
4765
  def default(x):
4766
+
4653
4767
  def f(k, carry):
4654
4768
  p, q = carry
4655
4769
  return (q, 2 * x * q - 2 * k * p)
@@ -4657,9 +4771,8 @@ def _aten_special_hermite_polynomial_h(self, n):
4657
4771
  _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x))
4658
4772
  return r
4659
4773
 
4660
- return jnp.piecewise(
4661
- x, [n_i == 1, n_i == 0, n_i < 0], [one_n, zero_n, negative_n, default]
4662
- )
4774
+ return jnp.piecewise(x, [n_i == 1, n_i == 0, n_i < 0],
4775
+ [one_n, zero_n, negative_n, default])
4663
4776
 
4664
4777
  return vectorized(self, n.astype(jnp.int64))
4665
4778
 
@@ -4671,6 +4784,7 @@ def _aten_special_hermite_polynomial_he(self, n):
4671
4784
 
4672
4785
  @jnp.vectorize
4673
4786
  def vectorized(x, n_i):
4787
+
4674
4788
  def negative_n(x):
4675
4789
  return jnp.zeros_like(x)
4676
4790
 
@@ -4681,6 +4795,7 @@ def _aten_special_hermite_polynomial_he(self, n):
4681
4795
  return x
4682
4796
 
4683
4797
  def default(x):
4798
+
4684
4799
  def f(k, carry):
4685
4800
  p, q = carry
4686
4801
  return (q, x * q - k * p)
@@ -4688,24 +4803,34 @@ def _aten_special_hermite_polynomial_he(self, n):
4688
4803
  _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x))
4689
4804
  return r
4690
4805
 
4691
- return jnp.piecewise(
4692
- x, [n_i == 1.0, n_i == 0.0, n_i < 0], [one_n, zero_n, negative_n, default]
4693
- )
4806
+ return jnp.piecewise(x, [n_i == 1.0, n_i == 0.0, n_i < 0],
4807
+ [one_n, zero_n, negative_n, default])
4694
4808
 
4695
4809
  return vectorized(self, n.astype(jnp.int64))
4696
4810
 
4697
4811
 
4698
4812
  @op(torch.ops.aten.multinomial, needs_env=True)
4699
- def _aten_multinomial(input, num_samples, replacement=False, *, generator=None, out=None, env=None):
4700
- assert num_samples <= input.shape[-1] or replacement, "cannot take a larger sample than population when replacement=False"
4701
- assert jnp.all(input >= 0), "inputs must be non-negative"
4813
+ def _aten_multinomial(input,
4814
+ num_samples,
4815
+ replacement=False,
4816
+ *,
4817
+ generator=None,
4818
+ out=None,
4819
+ env=None):
4820
+ assert num_samples <= input.shape[
4821
+ -1] or replacement, "cannot take a larger sample than population when replacement=False"
4702
4822
  key = env.get_and_rotate_prng_key(generator)
4703
4823
  if input.ndim == 1:
4704
- assert jnp.sum(input) > 0, "rows of input must have non-zero sum"
4705
- return jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input)
4824
+ return jax.random.choice(
4825
+ key, input.shape[-1], (num_samples,), replace=replacement, p=input)
4706
4826
  else:
4707
- assert jnp.all(jnp.sum(input, axis=1) > 0), "rows of input must have non-zero sum"
4708
- return jnp.array([jax.random.choice(key, input.shape[-1], (num_samples,), replace=replacement, p=input[i, :]) for i in range(input.shape[0])])
4827
+ return jnp.array([
4828
+ jax.random.choice(
4829
+ key,
4830
+ input.shape[-1], (num_samples,),
4831
+ replace=replacement,
4832
+ p=input[i, :]) for i in range(input.shape[0])
4833
+ ])
4709
4834
 
4710
4835
 
4711
4836
  @op(torch.ops.aten.narrow)
@@ -4738,7 +4863,12 @@ def _aten_flatten(x, start_dim=0, end_dim=-1):
4738
4863
 
4739
4864
  @op(torch.ops.aten.new_empty)
4740
4865
  def _new_empty(self, size, **kwargs):
4741
- return jnp.empty(size)
4866
+ dtype = kwargs.get('dtype')
4867
+ if dtype is not None:
4868
+ dtype = mappings.t2j_dtype(dtype)
4869
+ else:
4870
+ dtype = self.dtype
4871
+ return jnp.empty(size, dtype=dtype)
4742
4872
 
4743
4873
 
4744
4874
  @op(torch.ops.aten.new_empty_strided)
@@ -4756,10 +4886,8 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False):
4756
4886
  return _aten_index_put(self, indices, values, accumulate)
4757
4887
 
4758
4888
 
4759
- @op(torch.ops.aten.conj_physical,
4760
- torch.ops.aten.conj,
4761
- torch.ops.aten._conj_physical,
4762
- torch.ops.aten._conj)
4889
+ @op(torch.ops.aten.conj_physical, torch.ops.aten.conj,
4890
+ torch.ops.aten._conj_physical, torch.ops.aten._conj)
4763
4891
  def _aten_conj_physical(self):
4764
4892
  return jnp.conjugate(self)
4765
4893
 
@@ -4768,6 +4896,7 @@ def _aten_conj_physical(self):
4768
4896
  def _aten_log_sigmoid(x):
4769
4897
  return jax.nn.log_sigmoid(x)
4770
4898
 
4899
+
4771
4900
  # torch.qr
4772
4901
  @op(torch.ops.aten.qr)
4773
4902
  def _aten_qr(input, *args, **kwargs):
@@ -4778,6 +4907,7 @@ def _aten_qr(input, *args, **kwargs):
4778
4907
  jax_mode = "complete"
4779
4908
  return jax.numpy.linalg.qr(input, mode=jax_mode)
4780
4909
 
4910
+
4781
4911
  # torch.linalg.qr
4782
4912
  @op(torch.ops.aten.linalg_qr)
4783
4913
  def _aten_linalg_qr(input, *args, **kwargs):
@@ -4820,19 +4950,25 @@ def _aten__linalg_solve_ex(a, b):
4820
4950
  res = jnp.linalg.solve(a, b)
4821
4951
  if batched:
4822
4952
  res = res.squeeze(-1)
4823
- info_shape = a.shape[0] if len(a.shape) >= 3 else []
4953
+ info_shape = a.shape[:-2]
4824
4954
  info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
4825
4955
  return res, info
4826
4956
 
4827
4957
 
4828
4958
  # torch.linalg.solve_triangular
4829
4959
  @op(torch.ops.aten.linalg_solve_triangular)
4830
- def _aten_linalg_solve_triangular(a, b, *, upper=True, left=True, unitriangular=False):
4960
+ def _aten_linalg_solve_triangular(a,
4961
+ b,
4962
+ *,
4963
+ upper=True,
4964
+ left=True,
4965
+ unitriangular=False):
4831
4966
  if left is False:
4832
4967
  a = jnp.matrix_transpose(a)
4833
4968
  b = jnp.matrix_transpose(b)
4834
4969
  upper = not upper
4835
- res = jax.scipy.linalg.solve_triangular(a, b, lower=not upper, unit_diagonal=unitriangular)
4970
+ res = jax.scipy.linalg.solve_triangular(
4971
+ a, b, lower=not upper, unit_diagonal=unitriangular)
4836
4972
  if left is False:
4837
4973
  res = jnp.matrix_transpose(res)
4838
4974
  return res
@@ -4852,21 +4988,31 @@ def _aten__linalg_check_errors(*args, **kwargs):
4852
4988
 
4853
4989
  @op(torch.ops.aten.median)
4854
4990
  def _aten_median(self, dim=None, keepdim=False):
4855
- output = _with_reduction_scalar(functools.partial(jnp.quantile, q=0.5, method='lower'), self, dim=dim, keepdim=keepdim).astype(self.dtype)
4991
+ output = _with_reduction_scalar(
4992
+ functools.partial(jnp.quantile, q=0.5, method='lower'),
4993
+ self,
4994
+ dim=dim,
4995
+ keepdim=keepdim).astype(self.dtype)
4856
4996
  if dim is None:
4857
4997
  return output
4858
4998
  else:
4859
- index = _with_reduction_scalar(_get_median_index, self, dim, keepdim).astype(jnp.int64)
4999
+ index = _with_reduction_scalar(_get_median_index, self, dim,
5000
+ keepdim).astype(jnp.int64)
4860
5001
  return output, index
4861
5002
 
4862
5003
 
4863
5004
  @op(torch.ops.aten.nanmedian)
4864
5005
  def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None):
4865
- output = _with_reduction_scalar(functools.partial(jnp.nanquantile, q=0.5, method='lower'), input, dim=dim, keepdim=keepdim).astype(input.dtype)
5006
+ output = _with_reduction_scalar(
5007
+ functools.partial(jnp.nanquantile, q=0.5, method='lower'),
5008
+ input,
5009
+ dim=dim,
5010
+ keepdim=keepdim).astype(input.dtype)
4866
5011
  if dim is None:
4867
5012
  return output
4868
5013
  else:
4869
- index = _with_reduction_scalar(_get_median_index, input, dim, keepdim).astype(jnp.int64)
5014
+ index = _with_reduction_scalar(_get_median_index, input, dim,
5015
+ keepdim).astype(jnp.int64)
4870
5016
  return output, index
4871
5017
 
4872
5018
 
@@ -4874,20 +5020,31 @@ def _get_median_index(x, axis=None, keepdims=False):
4874
5020
  sorted_arg = jnp.argsort(x, axis=axis)
4875
5021
  n = x.shape[axis] if axis is not None else x.size
4876
5022
  if n % 2 == 1:
4877
- index = n // 2
5023
+ index = n // 2
4878
5024
  else:
4879
- index = (n // 2) - 1
5025
+ index = (n // 2) - 1
4880
5026
  if axis is None:
4881
- median_index = sorted_arg[index]
5027
+ median_index = sorted_arg[index]
4882
5028
  else:
4883
- median_index = jnp.take(sorted_arg, index, axis=axis)
5029
+ median_index = jnp.take(sorted_arg, index, axis=axis)
4884
5030
  if keepdims and axis is not None:
4885
- median_index = jnp.expand_dims(median_index, axis)
5031
+ median_index = jnp.expand_dims(median_index, axis)
4886
5032
  return median_index
4887
5033
 
5034
+
4888
5035
  @op(torch.ops.aten.triangular_solve)
4889
- def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=False):
4890
- return (jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=not upper, transpose_a=transpose, unit_diagonal=unittriangular), a)
5036
+ def _aten_triangular_solve(b,
5037
+ a,
5038
+ upper=True,
5039
+ transpose=False,
5040
+ unittriangular=False):
5041
+ return (jax.lax.linalg.triangular_solve(
5042
+ a,
5043
+ b,
5044
+ left_side=True,
5045
+ lower=not upper,
5046
+ transpose_a=transpose,
5047
+ unit_diagonal=unittriangular), a)
4891
5048
 
4892
5049
 
4893
5050
  # func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
@@ -4895,16 +5052,16 @@ def _aten_triangular_solve(b, a, upper=True, transpose=False, unittriangular=Fal
4895
5052
  def _aten__fft_c2c(self, dim, normalization, forward):
4896
5053
  if forward:
4897
5054
  norm = [
4898
- 'backward',
4899
- 'ortho',
4900
- 'forward',
5055
+ 'backward',
5056
+ 'ortho',
5057
+ 'forward',
4901
5058
  ][normalization]
4902
5059
  return jnp.fft.fftn(self, axes=dim, norm=norm)
4903
5060
  else:
4904
5061
  norm = [
4905
- 'forward',
4906
- 'ortho',
4907
- 'backward',
5062
+ 'forward',
5063
+ 'ortho',
5064
+ 'backward',
4908
5065
  ][normalization]
4909
5066
  return jnp.fft.ifftn(self, axes=dim, norm=norm)
4910
5067
 
@@ -4912,21 +5069,22 @@ def _aten__fft_c2c(self, dim, normalization, forward):
4912
5069
  @op(torch.ops.aten._fft_r2c)
4913
5070
  def _aten__fft_r2c(self, dim, normalization, onesided):
4914
5071
  norm = [
4915
- 'backward',
4916
- 'ortho',
4917
- 'forward',
5072
+ 'backward',
5073
+ 'ortho',
5074
+ 'forward',
4918
5075
  ][normalization]
4919
5076
  if onesided:
4920
5077
  return jnp.fft.rfftn(self, axes=dim, norm=norm)
4921
5078
  else:
4922
5079
  return jnp.fft.fftn(self, axes=dim, norm=norm)
4923
5080
 
5081
+
4924
5082
  @op(torch.ops.aten._fft_c2r)
4925
5083
  def _aten__fft_c2r(self, dim, normalization, last_dim_size):
4926
5084
  norm = [
4927
- 'forward',
4928
- 'ortho',
4929
- 'backward',
5085
+ 'forward',
5086
+ 'ortho',
5087
+ 'backward',
4930
5088
  ][normalization]
4931
5089
  if len(dim) == 1:
4932
5090
  s = [last_dim_size]
@@ -4936,34 +5094,49 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
4936
5094
 
4937
5095
 
4938
5096
  @op(torch.ops.aten._trilinear)
4939
- def _aten_trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1):
4940
- return _aten_sum(jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) * jnp.expand_dims(i3, expand3), sumdim)
5097
+ def _aten_trilinear(i1,
5098
+ i2,
5099
+ i3,
5100
+ expand1,
5101
+ expand2,
5102
+ expand3,
5103
+ sumdim,
5104
+ unroll_dim=1):
5105
+ return _aten_sum(
5106
+ jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) *
5107
+ jnp.expand_dims(i3, expand3), sumdim)
4941
5108
 
4942
5109
 
4943
5110
  @op(torch.ops.aten.max_unpool2d)
4944
5111
  @op(torch.ops.aten.max_unpool3d)
4945
5112
  def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
4946
5113
  if output_size is None:
4947
- raise ValueError("output_size value is not set correctly. It cannot be None or empty.")
5114
+ raise ValueError(
5115
+ "output_size value is not set correctly. It cannot be None or empty.")
4948
5116
 
4949
5117
  output_size = [input.shape[0], input.shape[1]] + output_size
4950
5118
  output = jnp.zeros(output_size, dtype=input.dtype)
4951
5119
 
4952
5120
  for idx in np.ndindex(input.shape):
4953
- max_index = indices[idx]
4954
- spatial_dims = output_size[2:] # (D, H, W)
4955
- unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
4956
- full_idx = idx[:2] + unpooled_spatial_idx
4957
- output = output.at[full_idx].set(input[idx])
5121
+ max_index = indices[idx]
5122
+ spatial_dims = output_size[2:] # (D, H, W)
5123
+ unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
5124
+ full_idx = idx[:2] + unpooled_spatial_idx
5125
+ output = output.at[full_idx].set(input[idx])
4958
5126
 
4959
5127
  return output
4960
5128
 
4961
- @op(torch.ops.aten._upsample_bilinear2d_aa)
4962
- def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors=None, scales_h=None, scales_w=None):
5129
+
5130
+ def _aten_upsample(input,
5131
+ output_size,
5132
+ align_corners,
5133
+ antialias,
5134
+ method,
5135
+ scale_factors=None,
5136
+ scales_h=None,
5137
+ scales_w=None):
4963
5138
  # input: is of type jaxlib.xla_extension.ArrayImpl
4964
5139
  image = input
4965
- method = "bilinear"
4966
- antialias = True # ignored for upsampling
4967
5140
 
4968
5141
  # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
4969
5142
  # Resize does not distinguish batch, channel size.
@@ -4977,12 +5150,12 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
4977
5150
  shape = list(image.shape)
4978
5151
  # overriding output_size
4979
5152
  if scale_factors:
4980
- shape[-1] = int(math.floor(shape[-1]*scale_factors[-1]))
4981
- shape[-2] = int(math.floor(shape[-2]*scale_factors[-2]))
5153
+ shape[-1] = int(math.floor(shape[-1] * scale_factors[-1]))
5154
+ shape[-2] = int(math.floor(shape[-2] * scale_factors[-2]))
4982
5155
  if scales_h:
4983
- shape[-2] = int(math.floor(shape[-2]*scales_h))
5156
+ shape[-2] = int(math.floor(shape[-2] * scales_h))
4984
5157
  if scales_w:
4985
- shape[-1] = int(math.floor(shape[-1]*scales_w))
5158
+ shape[-1] = int(math.floor(shape[-1] * scales_w))
4986
5159
  # output_size overrides scale_factors, scales_*
4987
5160
  if output_size:
4988
5161
  shape[-1] = output_size[-1]
@@ -4992,11 +5165,11 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
4992
5165
  if shape == list(image.shape):
4993
5166
  return image
4994
5167
 
4995
- spatial_dims = (2,3)
5168
+ spatial_dims = (2, 3)
4996
5169
  if len(shape) == 3:
4997
- spatial_dims = (1,2)
5170
+ spatial_dims = (1, 2)
4998
5171
 
4999
- scale = list([shape[i] / image.shape[i] for i in spatial_dims])
5172
+ scale = list([shape[i] / image.shape[i] for i in spatial_dims])
5000
5173
  if scale_factors:
5001
5174
  scale = scale_factors
5002
5175
  if scales_h:
@@ -5008,7 +5181,9 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
5008
5181
  # align_corners is not supported in resize()
5009
5182
  # https://github.com/jax-ml/jax/issues/11206
5010
5183
  if align_corners:
5011
- scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
5184
+ scale = jnp.array([
5185
+ (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims
5186
+ ])
5012
5187
 
5013
5188
  translation = jnp.array([0 for i in spatial_dims])
5014
5189
 
@@ -5022,12 +5197,53 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
5022
5197
  antialias=antialias,
5023
5198
  )
5024
5199
 
5200
+
5201
+ @op(torch.ops.aten._upsample_bilinear2d_aa)
5202
+ def _aten_upsample_billinear_aa(input,
5203
+ output_size,
5204
+ align_corners,
5205
+ scale_factors=None,
5206
+ scales_h=None,
5207
+ scales_w=None):
5208
+ return _aten_upsample(
5209
+ input,
5210
+ output_size,
5211
+ align_corners,
5212
+ True, # antialias
5213
+ "bilinear", # method
5214
+ scale_factors,
5215
+ scales_h,
5216
+ scales_w)
5217
+
5218
+
5219
+ @op(torch.ops.aten._upsample_bicubic2d_aa)
5220
+ def _aten_upsample_bicubic2d_aa(input,
5221
+ output_size,
5222
+ align_corners,
5223
+ scale_factors=None,
5224
+ scales_h=None,
5225
+ scales_w=None):
5226
+ return _aten_upsample(
5227
+ input,
5228
+ output_size,
5229
+ align_corners,
5230
+ True, # antialias
5231
+ "bicubic", # method
5232
+ scale_factors,
5233
+ scales_h,
5234
+ scales_w)
5235
+
5236
+
5025
5237
  @op(torch.ops.aten.polar)
5026
5238
  def _aten_polar(abs, angle, *, out=None):
5027
5239
  return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))
5028
5240
 
5241
+
5029
5242
  @op(torch.ops.aten.cdist)
5030
- def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
5243
+ def _aten_cdist(x1,
5244
+ x2,
5245
+ p=2.0,
5246
+ compute_mode='use_mm_for_euclid_dist_if_necessary'):
5031
5247
  x1 = x1.astype(jnp.float32)
5032
5248
  x2 = x2.astype(jnp.float32)
5033
5249
 
@@ -5036,7 +5252,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
5036
5252
  return _hamming_distance(x1, x2).astype(jnp.float32)
5037
5253
  elif p == 2.0:
5038
5254
  # Use optimized Euclidean distance calculation
5039
- if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (x1.shape[-2] > 25 or x2.shape[-2] > 25):
5255
+ if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and (
5256
+ x1.shape[-2] > 25 or x2.shape[-2] > 25):
5040
5257
  return _euclidean_mm(x1, x2)
5041
5258
  elif compute_mode == 'use_mm_for_euclid_dist':
5042
5259
  return _euclidean_mm(x1, x2)
@@ -5045,7 +5262,8 @@ def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary
5045
5262
  else:
5046
5263
  # General p-norm distance calculation
5047
5264
  diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3))
5048
- return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32) ** (1 / p)
5265
+ return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)**(1 / p)
5266
+
5049
5267
 
5050
5268
  def _hamming_distance(x1, x2):
5051
5269
  """
@@ -5064,6 +5282,7 @@ def _hamming_distance(x1, x2):
5064
5282
 
5065
5283
  return hamming_dist
5066
5284
 
5285
+
5067
5286
  def _euclidean_mm(x1, x2):
5068
5287
  """
5069
5288
  Computes the Euclidean distance using matrix multiplication.
@@ -5075,8 +5294,8 @@ def _euclidean_mm(x1, x2):
5075
5294
  Returns:
5076
5295
  JAX array of shape (..., P, R) representing pairwise Euclidean distances.
5077
5296
  """
5078
- x1_sq = jnp.sum(x1 ** 2, axis=-1, keepdims=True).astype(jnp.float32)
5079
- x2_sq = jnp.sum(x2 ** 2, axis=-1, keepdims=True).astype(jnp.float32)
5297
+ x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32)
5298
+ x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32)
5080
5299
 
5081
5300
  x2_sq = jnp.swapaxes(x2_sq, -2, -1)
5082
5301
 
@@ -5088,6 +5307,7 @@ def _euclidean_mm(x1, x2):
5088
5307
 
5089
5308
  return dist
5090
5309
 
5310
+
5091
5311
  def _euclidean_direct(x1, x2):
5092
5312
  """
5093
5313
  Computes the Euclidean distance directly without matrix multiplication.
@@ -5101,7 +5321,7 @@ def _euclidean_direct(x1, x2):
5101
5321
  """
5102
5322
  diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)
5103
5323
 
5104
- dist_sq = jnp.sum(diff ** 2, axis=-1).astype(jnp.float32)
5324
+ dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32)
5105
5325
 
5106
5326
  dist_sq = jnp.maximum(dist_sq, 0.0)
5107
5327
 
@@ -5109,13 +5329,14 @@ def _euclidean_direct(x1, x2):
5109
5329
 
5110
5330
  return dist
5111
5331
 
5332
+
5112
5333
  @op(torch.ops.aten.lu_unpack)
5113
5334
  def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5114
5335
  # lu_unpack doesn't exist in jax.
5115
5336
  # Get commonly used data shape variables
5116
5337
  n = LU_data.shape[-2]
5117
5338
  m = LU_data.shape[-1]
5118
- dim = min(n,m)
5339
+ dim = min(n, m)
5119
5340
 
5120
5341
  ### Compute the Lower and Upper triangle
5121
5342
  if unpack_data:
@@ -5130,7 +5351,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5130
5351
  start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
5131
5352
  limit_indices = list(LU_data.shape)
5132
5353
  limit_indices[-1] = dim
5133
- L = jax.lax.slice(L, start_indices, limit_indices)
5354
+ L = jax.lax.slice(L, start_indices, limit_indices)
5134
5355
 
5135
5356
  # Extract upper triangle
5136
5357
  U = jnp.triu(LU_data)
@@ -5160,13 +5381,15 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5160
5381
 
5161
5382
  # closure to be called for each input 2D matrix.
5162
5383
  def _lu_unpack_2d(p, pivot):
5163
- _pivot = pivot - 1 # pivots are offset by 1 in jax
5384
+ _pivot = pivot - 1 # pivots are offset by 1 in jax
5164
5385
  indices = jnp.array([*range(n)], dtype=jnp.int32)
5386
+
5165
5387
  def update_indices(i, _indices):
5166
5388
  tmp = _indices[i]
5167
5389
  _indices = _indices.at[i].set(_indices[_pivot[i]])
5168
5390
  _indices = _indices.at[_pivot[i]].set(tmp)
5169
5391
  return _indices
5392
+
5170
5393
  indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices)
5171
5394
  p = p[jnp.array(indices)]
5172
5395
  p = jnp.transpose(p)
@@ -5191,7 +5414,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
5191
5414
  reshapedPivot = LU_pivots.reshape(newPivotshape)
5192
5415
 
5193
5416
  # vmap the reshaped 3d tensors
5194
- v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0,0))
5417
+ v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0))
5195
5418
  unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot)
5196
5419
 
5197
5420
  # reshape result back to P's shape
@@ -5210,3 +5433,204 @@ def linear(input, weight, bias=None):
5210
5433
  if bias is not None:
5211
5434
  res += bias
5212
5435
  return res
5436
+
5437
+
5438
+ @op(torch.ops.aten.kthvalue)
5439
+ def kthvalue(input, k, dim=None, keepdim=False, *, out=None):
5440
+ if input.ndim == 0:
5441
+ return input, jnp.array(0)
5442
+ dimension = -1
5443
+ if dim is not None:
5444
+ dimension = dim
5445
+ while dimension < 0:
5446
+ dimension = dimension + input.ndim
5447
+ values = jax.lax.index_in_dim(
5448
+ jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim)
5449
+ indices = jax.lax.index_in_dim(
5450
+ jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1,
5451
+ dimension, keepdim)
5452
+ return values, indices
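Note that torch.kthvalue uses a 1-based k (the k-th smallest element along the dimension), while jnp.partition / jnp.argpartition take a 0-based index, hence the k - 1 above. An illustrative check of the values path:

import jax.numpy as jnp

x = jnp.array([9., 1., 7., 3.])
k = 2  # torch-style: second smallest
# jnp.partition places the (k - 1)-th order statistic at position k - 1.
assert jnp.partition(x, k - 1)[k - 1] == 3.0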
5453
+
5454
+
5455
+ @op(torch.ops.aten.take)
5456
+ def _aten_take(self, index):
5457
+ return self.flatten()[index]
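torch.take reads the index as positions into the flattened (row-major) input, which is exactly the one-liner above; for example:

import jax.numpy as jnp

x = jnp.array([[4, 3], [2, 1]])
assert (x.flatten()[jnp.array([0, 3])] == jnp.array([4, 1])).all()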
5458
+
5459
+
5460
+ # func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor
5461
+ @op(torch.ops.aten.pad)
5462
+ def _aten_pad(self, pad, mode='constant', value=None):
5463
+ if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0:
5464
+ raise ValueError("Padding must be a sequence of even length.")
5465
+
5466
+ num_dims = self.ndim
5467
+ if len(pad) > 2 * num_dims:
5468
+ raise ValueError(
5469
+ f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})."
5470
+ )
5471
+
5472
+ # JAX's pad function expects padding for each dimension as a tuple of (low, high)
5473
+ # We need to reverse the pad sequence and group them for JAX.
5474
+ # pad = [p_l0, p_r0, p_l1, p_r1, ...]
5475
+ # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0)))
5476
+ jax_pad_width = []
5477
+ # Iterate in reverse pairs
5478
+ for i in range(len(pad) // 2):
5479
+ jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)]))
5480
+
5481
+ # Pad any leading dimensions with (0, 0) if the pad sequence is shorter
5482
+ # than the number of dimensions.
5483
+ for _ in range(num_dims - len(pad) // 2):
5484
+ jax_pad_width.append((0, 0))
5485
+
5486
+ # Reverse the jax_pad_width list to match the dimension order
5487
+ jax_pad_width.reverse()
5488
+
5489
+ if mode == "constant":
5490
+ if value is None:
5491
+ value = 0.0
5492
+ return jnp.pad(
5493
+ self, pad_width=jax_pad_width, mode="constant", constant_values=value)
5494
+ elif mode == "reflect":
5495
+ return jnp.pad(self, pad_width=jax_pad_width, mode="reflect")
5496
+ elif mode == "edge":
5497
+ return jnp.pad(self, pad_width=jax_pad_width, mode="edge")
5498
+ else:
5499
+ raise ValueError(
5500
+ f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'."
5501
+ )
5502
+
5503
+
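# Illustrative sketch of the pad-width mapping described in the comments above;
# the shape and pad values are made-up examples, and calling _aten_pad directly
# assumes the @op decorator returns the wrapped function.
import jax.numpy as jnp

x = jnp.zeros((2, 3, 4))
pad = [1, 2, 0, 3]        # torch order: (left, right) pairs, last dimension first
# pairs -> [(1, 2), (0, 3)]; missing leading dims get (0, 0); reversed ->
# [(0, 0), (0, 3), (1, 2)], i.e. the pad_width that jnp.pad expects.
y = _aten_pad(x, pad)     # same result as jnp.pad(x, [(0, 0), (0, 3), (1, 2)])
# y.shape == (2, 6, 7)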
5504
+ @op(torch.ops.aten.is_nonzero)
5505
+ def _aten_is_nonzero(a):
5506
+ a = jnp.squeeze(a)
5507
+ if a.shape == (0,):
5508
+ raise RuntimeError('bool value of Tensor with no values is ambiguous')
5509
+ if a.ndim != 0:
5510
+ raise RuntimeError(
5511
+ 'bool value of Tensor with more than one value is ambiguous')
5512
+ return a.item() != 0
5513
+
5514
+
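# Illustrative usage sketch for _aten_is_nonzero above (made-up inputs; direct
# calls assume the @op decorator returns the wrapped function).
import jax.numpy as jnp

_aten_is_nonzero(jnp.array([1.0]))    # True  - squeezes to a nonzero scalar
_aten_is_nonzero(jnp.array([[0]]))    # False - squeezes to a zero scalar
# _aten_is_nonzero(jnp.array([1, 2])) raises: more than one value is ambiguous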
5515
+ @op(torch.ops.aten.logit)
5516
+ def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
5517
+ """
5518
+ Computes the logit function of the input tensor.
5519
+
5520
+ logit(p) = log(p / (1 - p))
5521
+
5522
+ Args:
5523
+ self: Input tensor.
5524
+ eps: A small value to clip the input tensor to avoid log(0) or division by zero.
5525
+ If None, no clipping is performed.
5526
+
5527
+ Returns:
5528
+ A tensor with the logit of each element of the input.
5529
+ """
5530
+ if eps is not None:
5531
+ self = jnp.clip(self, eps, 1.0 - eps)
5532
+ res = jnp.log(self / (1.0 - self))
5533
+ res = res.astype(mappings.t2j_dtype(torch.get_default_dtype()))
5534
+ return res
5535
+
5536
+
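# Illustrative usage sketch for _aten_logit above (made-up inputs; direct calls
# assume the @op decorator returns the wrapped function). eps clips the input
# into [eps, 1 - eps] so the endpoints stay finite.
import jax.numpy as jnp

p = jnp.array([0.0, 0.5, 1.0])
_aten_logit(p)            # -> [-inf, 0.0, inf]
_aten_logit(p, eps=1e-6)  # -> [~-13.8, 0.0, ~13.8], i.e. +/- log((1 - 1e-6) / 1e-6)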
5537
+ @op(torch.ops.aten.floor_divide)
5538
+ def _aten_floor_divide(x, y):
5539
+ res = jnp.floor_divide(x, y)
5540
+ return res
5541
+
5542
+
5543
+ @op(torch.ops.aten._assert_tensor_metadata)
5544
+ def _aten__assert_tensor_metadata(*args, **kwargs):
5545
+ pass
5546
+
5547
+
5548
+ mutation_ops_to_functional = {
5549
+ torch.ops.aten.add_:
5550
+ op_base.InplaceOp(torch.ops.aten.add),
5551
+ torch.ops.aten.sub_:
5552
+ op_base.InplaceOp(torch.ops.aten.sub),
5553
+ torch.ops.aten.mul_:
5554
+ op_base.InplaceOp(torch.ops.aten.mul),
5555
+ torch.ops.aten.div_:
5556
+ op_base.InplaceOp(torch.ops.aten.div),
5557
+ torch.ops.aten.pow_:
5558
+ op_base.InplaceOp(torch.ops.aten.pow),
5559
+ torch.ops.aten.lt_:
5560
+ op_base.InplaceOp(torch.ops.aten.lt),
5561
+ torch.ops.aten.le_:
5562
+ op_base.InplaceOp(torch.ops.aten.le),
5563
+ torch.ops.aten.gt_:
5564
+ op_base.InplaceOp(torch.ops.aten.gt),
5565
+ torch.ops.aten.ge_:
5566
+ op_base.InplaceOp(torch.ops.aten.ge),
5567
+ torch.ops.aten.eq_:
5568
+ op_base.InplaceOp(torch.ops.aten.eq),
5569
+ torch.ops.aten.ne_:
5570
+ op_base.InplaceOp(torch.ops.aten.ne),
5571
+ torch.ops.aten.bernoulli_:
5572
+ op_base.InplaceOp(torch.ops.aten.bernoulli.p),
5573
+ torch.ops.aten.bernoulli_.float:
5574
+ op_base.InplaceOp(_aten_bernoulli, is_jax_func=True),
5575
+ torch.ops.aten.geometric_:
5576
+ op_base.InplaceOp(torch.ops.aten.geometric),
5577
+ torch.ops.aten.normal_:
5578
+ op_base.InplaceOp(torch.ops.aten.normal),
5579
+ torch.ops.aten.random_:
5580
+ op_base.InplaceOp(torch.ops.aten.uniform),
5581
+ torch.ops.aten.uniform_:
5582
+ op_base.InplaceOp(torch.ops.aten.uniform),
5583
+ torch.ops.aten.relu_:
5584
+ op_base.InplaceOp(torch.ops.aten.relu),
5585
+ # squeeze_ is expected to change the tensor's shape, so replace it with the new value
5586
+ torch.ops.aten.squeeze_:
5587
+ op_base.InplaceOp(torch.ops.aten.squeeze, True),
5588
+ torch.ops.aten.sqrt_:
5589
+ op_base.InplaceOp(torch.ops.aten.sqrt),
5590
+ torch.ops.aten.clamp_:
5591
+ op_base.InplaceOp(torch.ops.aten.clamp),
5592
+ torch.ops.aten.clamp_min_:
5593
+ op_base.InplaceOp(torch.ops.aten.clamp_min),
5594
+ torch.ops.aten.sigmoid_:
5595
+ op_base.InplaceOp(torch.ops.aten.sigmoid),
5596
+ torch.ops.aten.tanh_:
5597
+ op_base.InplaceOp(torch.ops.aten.tanh),
5598
+ torch.ops.aten.ceil_:
5599
+ op_base.InplaceOp(torch.ops.aten.ceil),
5600
+ torch.ops.aten.logical_not_:
5601
+ op_base.InplaceOp(torch.ops.aten.logical_not),
5602
+ torch.ops.aten.unsqueeze_:
5603
+ op_base.InplaceOp(torch.ops.aten.unsqueeze),
5604
+ torch.ops.aten.transpose_:
5605
+ op_base.InplaceOp(torch.ops.aten.transpose),
5606
+ torch.ops.aten.log_normal_:
5607
+ op_base.InplaceOp(torch.ops.aten.log_normal),
5608
+ torch.ops.aten.scatter_add_:
5609
+ op_base.InplaceOp(torch.ops.aten.scatter_add),
5610
+ torch.ops.aten.scatter_reduce_.two:
5611
+ op_base.InplaceOp(torch.ops.aten.scatter_reduce),
5612
+ torch.ops.aten.scatter_:
5613
+ op_base.InplaceOp(torch.ops.aten.scatter),
5614
+ torch.ops.aten.bitwise_or_:
5615
+ op_base.InplaceOp(torch.ops.aten.bitwise_or),
5616
+ torch.ops.aten.floor_divide_:
5617
+ op_base.InplaceOp(torch.ops.aten.floor_divide),
5618
+ torch.ops.aten.remainder_:
5619
+ op_base.InplaceOp(torch.ops.aten.remainder),
5620
+ }
5621
+
5622
+ # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
5623
+ _jax_version = tuple(int(v) for v in jax.version._version.split("."))
5624
+
5625
+ mutation_needs_env = {
5626
+ torch.ops.aten.bernoulli_,
5627
+ torch.ops.aten.bernoulli_.float,
5628
+ }
5629
+
5630
+ for operator, mutation in mutation_ops_to_functional.items():
5631
+ ops_registry.register_torch_dispatch_op(
5632
+ operator,
5633
+ mutation,
5634
+ is_jax_function=False,
5635
+ is_view_op=True,
5636
+ needs_env=(operator in mutation_needs_env))
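# Illustrative sketch of extending the in-place-op table above. torch.ops.aten.exp_
# and torch.ops.aten.exp are real ATen ops, but registering them here is a
# hypothetical example following the same pattern, not something this release does.
mutation_ops_to_functional[torch.ops.aten.exp_] = op_base.InplaceOp(
    torch.ops.aten.exp)
ops_registry.register_torch_dispatch_op(
    torch.ops.aten.exp_,
    mutation_ops_to_functional[torch.ops.aten.exp_],
    is_jax_function=False,
    is_view_op=True,
    needs_env=False)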