torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/ops/jtorch.py CHANGED
@@ -1,8 +1,9 @@
  """Tensor constructor overrides"""
+
  import math
  import collections.abc
  import functools
- from typing import Optional, Sequence
+ from typing import Optional, Sequence, Tuple
  import numpy as np

  import jax
@@ -12,8 +13,10 @@ from jax.experimental.shard_map import shard_map

  import torch
  from torchax.ops.ops_registry import register_torch_function_op
- from torchax.ops import op_base, mappings, jaten
+ from torchax.ops import op_base, mappings, jaten, jimage
  import torchax.tensor
+ from torchax.view import View, NarrowInfo
+ import torch.utils._pytree as pytree


  def register_function(torch_func, **kwargs):
@@ -21,7 +24,8 @@ def register_function(torch_func, **kwargs):


  @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
- @op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
+ @op_base.convert_dtype(
+     use_default_dtype=False)  # Attempt to infer type from elements
  def _as_tensor(data, dtype=None, device=None, env=None):
    if isinstance(data, torch.Tensor):
      return env._to_copy(data, dtype, device)
@@ -33,7 +37,8 @@ def _as_tensor(data, dtype=None, device=None, env=None):


  @register_function(torch.tensor)
- @op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
+ @op_base.convert_dtype(
+     use_default_dtype=False)  # Attempt to infer type from elements
  def _tensor(data, *, dtype=None, **kwargs):
    python_types_to_torch_types = {
        bool: jnp.bool,
@@ -57,8 +62,8 @@ def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):

  @register_function(torch.angle)
  def _torch_angle(input):
-   if input.dtype.name == 'int64':
-     input = input.astype(jnp.dtype('float32'))
+   if input.dtype.name == "int64":
+     input = input.astype(jnp.dtype("float32"))
    return jnp.angle(input)


@@ -72,19 +77,21 @@ def _torch_argsort(input, dim=-1, descending=False, stable=False):
      # behavior is the same as a jnp array of rank 1
      expanded = True
      input = jnp.expand_dims(input, 0)
-   res = jnp.argsort(input, axis=dim, descending=descending,
-                     stable=stable)
+   res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
    if expanded:
      res = res.squeeze()
    return res

+
  @register_function(torch.diag)
  def _diag(input, diagonal=0):
    return jnp.diag(input, k=diagonal)

+
  @register_function(torch.einsum)
  @register_function(torch.ops.aten.einsum)
  def _einsum(equation, *operands):
+
    def get_params(*a):
      inner_list = a[0]
      if not isinstance(inner_list, jax.Array):
@@ -95,71 +102,90 @@ def _einsum(equation, *operands):
        A, B = inner_list
        return A, B
    return operands
-   assert isinstance(equation, str), 'Only accept str equation'
+
+   assert isinstance(equation, str), "Only accept str equation"
    filtered_operands = get_params(*operands)
    return jnp.einsum(equation, *filtered_operands)


- def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
-                     is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
-   L, S = query.size(-2), key.size(-2)
-   scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-   attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
-   if is_causal:
-     assert attn_mask is None
-     temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
-     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-     attn_bias.to(query.dtype)
-   if attn_mask is not None:
-     if attn_mask.dtype == torch.bool:
-       attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-     else:
-       attn_bias += attn_mask
-   if enable_gqa:
-     key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
-     value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
-
-   attn_weight = query @ key.transpose(-2, -1) * scale_factor
-   attn_weight += attn_bias
-   attn_weight = torch.softmax(attn_weight, dim=-1)
-   if dropout_p > 0:
-     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-   return attn_weight @ value
+ def _sdpa_reference(
+     query,
+     key,
+     value,
+     attn_mask=None,
+     dropout_p=0.0,
+     is_causal=False,
+     scale=None,
+     enable_gqa=False,
+ ) -> torch.Tensor:
+   L, S = query.size(-2), key.size(-2)
+   scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+   attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+   if is_causal:
+     assert attn_mask is None
+     temp_mask = torch.ones(
+         L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+     attn_bias.to(query.dtype)
+   if attn_mask is not None:
+     if attn_mask.dtype == torch.bool:
+       attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+     else:
+       attn_bias += attn_mask
+   if enable_gqa:
+     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+   attn_weight = query @ key.transpose(-2, -1) * scale_factor
+   attn_weight += attn_bias
+   attn_weight = torch.softmax(attn_weight, dim=-1)
+   if dropout_p > 0:
+     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+   return attn_weight @ value


  from jax.sharding import PartitionSpec

+
  def _tpu_flash_attention(query, key, value, env):
-   fsdp_partition = PartitionSpec('fsdp')
+   fsdp_partition = PartitionSpec("fsdp")
+
    def wrap_flash_attention(query, key, value):
      block_sizes = flash_attention.BlockSizes(
-       block_b=min(2, query.shape[0]),
-       block_q=min(512, query.shape[2]),
-       block_k_major=min(512, key.shape[2]),
-       block_k=min(512, key.shape[2]),
-       block_q_major_dkv=min(512, query.shape[2]),
-       block_k_major_dkv=min(512, key.shape[2]),
-       block_k_dkv=min(512, key.shape[2]),
-       block_q_dkv=min(512, query.shape[2]),
-       block_k_major_dq=min(512, key.shape[2]),
-       block_k_dq=min(256, key.shape[2]),
-       block_q_dq=min(1024, query.shape[2]),
+         block_b=min(2, query.shape[0]),
+         block_q=min(512, query.shape[2]),
+         block_k_major=min(512, key.shape[2]),
+         block_k=min(512, key.shape[2]),
+         block_q_major_dkv=min(512, query.shape[2]),
+         block_k_major_dkv=min(512, key.shape[2]),
+         block_k_dkv=min(512, key.shape[2]),
+         block_q_dkv=min(512, query.shape[2]),
+         block_k_major_dq=min(512, key.shape[2]),
+         block_k_dq=min(256, key.shape[2]),
+         block_q_dq=min(1024, query.shape[2]),
      )
      return flash_attention.flash_attention(
          query, key, value, causal=True, block_sizes=block_sizes)

    if env.config.shmap_flash_attention:
      wrap_flash_attention = shard_map(
-         wrap_flash_attention,
-         mesh=env._mesh,
-         in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
-         out_specs=fsdp_partition ,
-         check_rep=False,
+         wrap_flash_attention,
+         mesh=env._mesh,
+         in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+         out_specs=fsdp_partition,
+         check_rep=False,
      )
-   #return flash_attn_mapped(query, key, value)
+   # return flash_attn_mapped(query, key, value)
    return wrap_flash_attention(query, key, value)


+ @register_function(torch.nn.functional.one_hot)
+ def one_hot(tensor, num_classes=-1):
+   if num_classes == -1:
+     num_classes = jnp.max(tensor) + 1
+   return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
+
+
  @register_function(torch.nn.functional.pad)
  def pad(tensor, pad, mode="constant", value=None):
    # For padding modes that have different names between Torch and NumPy, this
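For reference, the new `one_hot` override above simply delegates to `jax.nn.one_hot` and casts to int64, inferring the class count when `num_classes == -1`. A minimal standalone sketch of that underlying JAX behavior (plain JAX, outside any torchax environment):

import jax
import jax.numpy as jnp

labels = jnp.array([0, 2, 1])
# num_classes == -1 means "infer from the data", as in the override above.
num_classes = int(jnp.max(labels)) + 1
encoded = jax.nn.one_hot(labels, num_classes).astype(jnp.int64)
# encoded == [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
# Note: without jax_enable_x64, the int64 cast is silently narrowed to int32.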
@@ -210,27 +236,60 @@ def pad(tensor, pad, mode="constant", value=None):
    return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)


- @register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
- @register_function(torch.ops.aten.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
+ @register_function(
+     torch.nn.functional.scaled_dot_product_attention,
+     is_jax_function=False,
+     needs_env=True,
+ )
+ @register_function(
+     torch.ops.aten.scaled_dot_product_attention,
+     is_jax_function=False,
+     needs_env=True)
  def scaled_dot_product_attention(
-     query, key, value, attn_mask=None,
-     dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False, env=None) -> torch.Tensor:
-
-   if env.config.use_tpu_flash_attention:
+     query,
+     key,
+     value,
+     attn_mask=None,
+     dropout_p=0.0,
+     is_causal=False,
+     scale=None,
+     enable_gqa=False,
+     env=None,
+ ) -> torch.Tensor:
+
+   if env.config.use_tpu_flash_attention:
      jquery, jkey, jvalue = env.t2j_iso((query, key, value))
      res = _tpu_flash_attention(jquery, jkey, jvalue, env)
      return env.j2t_iso(res)

-   return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
+   return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal,
+                          scale, enable_gqa)
+

- @register_function(torch.Tensor.__getitem__)
+ @register_function(
+     torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
  def getitem(self, indexes):
+
    if isinstance(indexes, list) and isinstance(indexes[0], int):
      # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
-     indexes = (indexes, )
+     indexes = (indexes,)
    elif isinstance(indexes, list):
      indexes = tuple(indexes)
-   return self[indexes]
+
+   def is_narrow_slicing():
+     tensor_free = not pytree.tree_any(
+         lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+         indexes)
+     list_free = not isinstance(indexes, tuple) or all(
+         [False if isinstance(x, list) else True for x in indexes])
+     return tensor_free and list_free
+
+   if is_narrow_slicing():
+     return View(self, view_info=NarrowInfo(indexes), env=self._env)
+
+   indexes = self._env.t2j_iso(indexes)
+   return torchax.tensor.Tensor(self._elem[indexes], self._env)
+

  @register_function(torch.corrcoef)
  def _corrcoef(x):
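The reworked `__getitem__` above routes "narrow" indexing (plain ints and slices) to a lazy `View`, and only materializes a new tensor for advanced indexing with tensors or lists. A rough standalone sketch of that predicate, reusing the same pytree helper the diff imports:

import jax
import torch
import torch.utils._pytree as pytree

def is_narrow_slicing(indexes):
  # True when no index is a torch.Tensor / jax.Array and no element of a
  # tuple index is a list, i.e. the result can be expressed as a narrow view.
  tensor_free = not pytree.tree_any(
      lambda x: isinstance(x, (torch.Tensor, jax.Array)), indexes)
  list_free = not isinstance(indexes, tuple) or all(
      not isinstance(x, list) for x in indexes)
  return tensor_free and list_free

print(is_narrow_slicing((slice(0, 2), 3)))      # True  -> View path
print(is_narrow_slicing(torch.tensor([0, 1])))  # False -> materializing path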
@@ -238,15 +297,22 @@ def _corrcoef(x):
      return jnp.corrcoef(x).astype(jnp.float32)
    return jnp.corrcoef(x)

+
  @register_function(torch.sparse.mm, is_jax_function=False)
- def _sparse_mm(mat1, mat2, reduce='sum'):
+ def _sparse_mm(mat1, mat2, reduce="sum"):
    return torch.mm(mat1, mat2)

+
  @register_function(torch.isclose)
  def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
    return jnp.isclose(input, other, rtol, atol, equal_nan)


+ @register_function(torch.linalg.det)
+ def linalg_det(input):
+   return jnp.linalg.det(input)
+
+
  @register_function(torch.ones)
  def _ones(*size: int, dtype=None, **kwargs):
    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
@@ -281,23 +347,39 @@ def empty(*size: Sequence[int], dtype=None, **kwargs):
      size = size[0]
    return jnp.empty(size, dtype=dtype)

- @register_function(torch.arange, is_jax_function=False)
+
+ @register_function(torch.arange, is_jax_function=True)
  def arange(
-     start, end=None, step=None,
-     out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False,
-     pin_memory=None,
+     start,
+     end=None,
+     step=None,
+     out=None,
+     dtype=None,
+     layout=torch.strided,
+     device=None,
+     requires_grad=False,
+     pin_memory=None,
  ):
    if end is None:
      end = start
      start = 0
    if step is None:
      step = 1
-   return torch.ops.aten.arange(start, end, step, dtype=dtype)
+   return jaten._aten_arange(start, end, step, dtype=dtype)
+

- @register_function(torch.empty_strided, is_jax_function=False)
+ @register_function(torch.empty_strided, is_jax_function=True)
  def empty_strided(
-     size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False):
-   return empty(size, dtype=dtype)
+     size,
+     stride,
+     *,
+     dtype=None,
+     layout=None,
+     device=None,
+     requires_grad=False,
+     pin_memory=False,
+ ):
+   return empty(size, dtype=dtype, requires_grad=requires_grad)


  @register_function(torch.unravel_index)
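The rewritten `arange` above keeps torch's single-argument convention: one positional argument is treated as the exclusive end, with start 0 and step 1, before handing off to `jaten._aten_arange`. A small worked sketch of just that defaulting logic in plain JAX:

import jax.numpy as jnp

def arange_like(start, end=None, step=None):
  # torch.arange(5) is equivalent to torch.arange(0, 5, 1)
  if end is None:
    end = start
    start = 0
  if step is None:
    step = 1
  return jnp.arange(start, end, step)

print(arange_like(5))        # [0 1 2 3 4]
print(arange_like(2, 8, 2))  # [2 4 6]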
@@ -305,27 +387,33 @@ def unravel_index(indices, shape):
    return jnp.unravel_index(indices, shape)


- @register_function(torch.rand, is_jax_function=False)
- def rand(
-     *size, **kwargs
- ):
+ @register_function(torch.rand, is_jax_function=True, needs_env=True)
+ def rand(*size, **kwargs):
    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
      size = size[0]
-   return torch.ops.aten.rand(size, **kwargs)
+   return jaten._rand(size, **kwargs)
+

- @register_function(torch.randn, is_jax_function=False)
+ @register_function(torch.randn, is_jax_function=True, needs_env=True)
  def randn(
-     *size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False
+     *size,
+     generator=None,
+     out=None,
+     dtype=None,
+     layout=torch.strided,
+     device=None,
+     requires_grad=False,
+     pin_memory=False,
+     env=None,
  ):
    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
      size = size[0]
-   return torch.ops.aten.randn(size, generator=generator, dtype=dtype)
+   return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)

- @register_function(torch.randint, is_jax_function=False)
- def randint(
-     *args, **kwargs
- ):
-   return torch.ops.aten.randint(*args, **kwargs)
+
+ @register_function(torch.randint, is_jax_function=False, needs_env=True)
+ def randint(*args, **kwargs):
+   return jaten._aten_randint(*args, **kwargs)


  @register_function(torch.logdet)
@@ -356,14 +444,17 @@ def linalg_solve_ex(a, b):
    res, info = jaten._aten__linalg_solve_ex(a, b)
    return res, info

+
  @register_function(torch.linalg.svd)
  def linalg_svd(a, full_matrices=True):
    return jaten._aten__linalg_svd(a, full_matrices=full_matrices)

+
  @register_function(torch.linalg.matrix_power)
  def matrix_power(A, n, *, out=None):
    return jnp.linalg.matrix_power(A, n)

+
  @register_function(torch.svd)
  def svd(a, some=True, compute_uv=True):
    if not compute_uv:
@@ -374,21 +465,24 @@ def svd(a, some=True, compute_uv=True):
    U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
    return U, S, jnp.matrix_transpose(V)

+
  @register_function(torch.cdist)
- def _cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
-   return jaten._aten_cdist(x1, x2, p, compute_mode)
+ def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
+   return jaten._aten_cdist(x1, x2, p, compute_mode)
+

  @register_function(torch.lu)
  def lu(A, **kwargs):
-   lu,pivots,_ = jax.lax.linalg.lu(A)
+   lu, pivots, _ = jax.lax.linalg.lu(A)
    # JAX pivots are offset by 1 compared to torch
    _pivots = pivots + 1
    info_shape = pivots.shape[:-1]
    info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
-   if kwargs['get_infos'] == True:
+   if kwargs["get_infos"] == True:
      return lu, _pivots, info
    return lu, _pivots

+
  @register_function(torch.lu_solve)
  def lu_solve(b, LU_data, LU_pivots, **kwargs):
    # JAX pivots are offset by 1 compared to torch
@@ -396,6 +490,7 @@ def lu_solve(b, LU_data, LU_pivots, **kwargs):
    x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
    return x

+
  @register_function(torch.linalg.tensorsolve)
  def linalg_tensorsolve(A, b, dims=None):
    # examples:
@@ -425,3 +520,57 @@ def functional_linear(self, weights, bias=None):
    if bias is not None:
      res += bias
    return res
+
+
+ @register_function(torch.nn.functional.interpolate)
+ def functional_interpolate(
+     input,
+     size: Tuple[int, int],
+     scale_factor: Optional[float],
+     mode: str,
+     align_corners: bool,
+     recompute_scale_factor: bool,
+     antialias: bool,
+ ):
+   supported_methods = (
+       "nearest",
+       "linear",
+       "bilinear",
+       "trilinear",
+       "cubic",
+       "bicubic",
+       "tricubic",
+       "lanczos3",
+       "lanczos5",
+   )
+   is_jax_supported = mode in supported_methods
+   if not is_jax_supported:
+     raise torchax.tensor.OperatorNotFound(
+         f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+     )
+   # None check
+   antialias = antialias or False
+   align_corners = align_corners or False
+
+   if mode in ('cubic', 'bicubic',
+               'tricubic') and not antialias and size is not None:
+     return jimage.interpolate_bicubic_no_aa(
+         input,
+         size[0],
+         size[1],
+         align_corners,
+     )
+   else:
+     # fallback
+     raise torchax.tensor.OperatorNotFound(
+         f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+     )
+
+
+ @register_function(torch.Tensor.repeat_interleave)
+ def torch_Tensor_repeat_interleave(self,
+                                    repeats,
+                                    dim=None,
+                                    *,
+                                    output_size=None):
+   return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
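The `repeat_interleave` override added above is a thin wrapper around `jnp.repeat`; torch's keyword-only `output_size` maps to JAX's `total_repeat_length`, which pins the output shape statically when `repeats` is an array. A standalone sketch of that mapping:

import jax.numpy as jnp

x = jnp.array([1, 2, 3])
print(jnp.repeat(x, 2))  # [1 1 2 2 3 3]

repeats = jnp.array([1, 0, 2])
# total_repeat_length gives the op a static output length, the same role
# output_size plays in torch.Tensor.repeat_interleave.
print(jnp.repeat(x, repeats, total_repeat_length=3))  # [1 3 3]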
@@ -53,8 +53,8 @@ def _self_suppression(in_args):
    can_suppress_others = jnp.reshape(
        jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
    iou_suppressed = jnp.reshape(
-       (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype),
-       [batch_size, -1, 1]) * iou
+       (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
+           iou.dtype), [batch_size, -1, 1]) * iou
    iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
    return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new

@@ -65,9 +65,8 @@ def _cross_suppression(in_args):
    new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
                                  [batch_size, _NMS_TILE_SIZE, 4])
    iou = _bbox_overlap(new_slice, box_slice)
-   ret_slice = jnp.expand_dims(
-       (jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype),
-       2) * box_slice
+   ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(
+       box_slice.dtype), 2) * box_slice
    return boxes, ret_slice, iou_threshold, inner_idx + 1


@@ -90,45 +89,40 @@ def _suppression_loop_body(in_args):
    # Iterates over tiles that can possibly suppress the current tile.
    box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                                  [batch_size, _NMS_TILE_SIZE, 4])
+
    def _loop_cond(in_args):
      _, _, _, inner_idx = in_args
      return inner_idx < idx

-   _, box_slice, _, _ = lax.while_loop(
-       _loop_cond,
-       _cross_suppression, (boxes, box_slice, iou_threshold,
-                            0))
+   _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression,
+                                       (boxes, box_slice, iou_threshold, 0))

    # Iterates over the current tile to compute self-suppression.
    iou = _bbox_overlap(box_slice, box_slice)
    mask = jnp.expand_dims(
-       jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) > jnp.reshape(
-           jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
+       jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+       > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
    iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)

    def _loop_cond2(in_args):
      _, loop_condition, _ = in_args
      return loop_condition

-   suppressed_iou, _, _ = lax.while_loop(
-       _loop_cond2, _self_suppression,
-       (iou, True,
-        jnp.sum(iou, [1, 2])))
+   suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression,
+                                         (iou, True, jnp.sum(iou, [1, 2])))
    suppressed_box = jnp.sum(suppressed_iou, 1) > 0
    box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)

    # Uses box_slice to update the input boxes.
-   mask = jnp.reshape(
-       (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype),
-       [1, -1, 1, 1])
+   mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles),
+                                 idx)).astype(boxes.dtype), [1, -1, 1, 1])
    boxes = jnp.tile(jnp.expand_dims(
        box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
            boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
    boxes = jnp.reshape(boxes, [batch_size, -1, 4])

    # Updates output_size.
-   output_size += jnp.sum(
-       jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
+   output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
    return boxes, iou_threshold, output_size, idx + 1


@@ -185,8 +179,8 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    """
    batch_size = boxes.shape[0]
    num_boxes = boxes.shape[1]
-   pad = int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)
-            ) * _NMS_TILE_SIZE - num_boxes
+   pad = int(jnp.ceil(
+       float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
    boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
    scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
    num_boxes += pad
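The reformatted padding expression above rounds the box count up to a whole number of NMS tiles so the tiled while-loops always see full tiles. A quick worked example of the arithmetic, assuming a tile size of 512 purely for illustration (the module defines its own `_NMS_TILE_SIZE`):

import math

_NMS_TILE_SIZE = 512  # illustrative value only

num_boxes = 700
pad = int(math.ceil(num_boxes / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
print(pad)              # 324
print(num_boxes + pad)  # 1024, i.e. two full tiles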
@@ -194,15 +188,12 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    def _loop_cond(in_args):
      unused_boxes, unused_threshold, output_size, idx = in_args
      return jnp.logical_and(
-         jnp.min(output_size) < max_output_size,
-         idx < num_boxes // _NMS_TILE_SIZE)
+         jnp.min(output_size) < max_output_size, idx
+         < num_boxes // _NMS_TILE_SIZE)

    selected_boxes, _, output_size, _ = lax.while_loop(
-       _loop_cond, _suppression_loop_body, (
-           boxes, iou_threshold,
-           jnp.zeros([batch_size], jnp.int32),
-           0
-       ))
+       _loop_cond, _suppression_loop_body,
+       (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
    idx = num_boxes - lax.top_k(
        jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
        jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
@@ -210,30 +201,28 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    idx = jnp.minimum(idx, num_boxes - 1)
    idx = jnp.reshape(
        idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
-
+
    return idx
-   boxes = jnp.reshape(
-       (jnp.reshape(boxes, [-1, 4]))[idx],
-       [batch_size, max_output_size, 4])
-   boxes = boxes * (
-       jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) < jnp.reshape(
-           output_size, [-1, 1, 1])).astype(boxes.dtype)
+   boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
+                       [batch_size, max_output_size, 4])
+   boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
+                    < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
    scores = jnp.reshape(
-       jnp.reshape(scores, [-1, 1])[idx],
-       [batch_size, max_output_size])
-   scores = scores * (
-       jnp.reshape(jnp.arange(max_output_size), [1, -1]) < jnp.reshape(
-           output_size, [-1, 1])).astype(scores.dtype)
+       jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
+   scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1])
+                      < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
    return scores, boxes


  # registry:

+
  def nms(boxes, scores, iou_threshold):
    max_output_size = boxes.shape[0]
    boxes = boxes.reshape((1, *boxes.shape))
    scores = scores.reshape((1, *scores.shape))
-   res = non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold)
+   res = non_max_suppression_padded(scores, boxes, max_output_size,
+                                    iou_threshold)
    return res


@@ -242,4 +231,4 @@ try:
    import torchvision
    ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
  except Exception:
-   pass
+   pass