torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic.

torchax/ops/jtorch.py CHANGED
@@ -1,8 +1,9 @@
  """Tensor constructor overrides"""
+
  import math
  import collections.abc
  import functools
- from typing import Optional, Sequence
+ from typing import Optional, Sequence, Tuple
  import numpy as np

  import jax
@@ -12,8 +13,10 @@ from jax.experimental.shard_map import shard_map

  import torch
  from torchax.ops.ops_registry import register_torch_function_op
- from torchax.ops import op_base, mappings, jaten
+ from torchax.ops import op_base, mappings, jaten, jimage
  import torchax.tensor
+ from torchax.view import View, NarrowInfo
+ import torch.utils._pytree as pytree


  def register_function(torch_func, **kwargs):
@@ -21,7 +24,8 @@ def register_function(torch_func, **kwargs):


  @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
- @op_base.convert_dtype(use_default_dtype=False) # Attempt to infer type from elements
+ @op_base.convert_dtype(
+     use_default_dtype=False) # Attempt to infer type from elements
  def _as_tensor(data, dtype=None, device=None, env=None):
    if isinstance(data, torch.Tensor):
      return env._to_copy(data, dtype, device)
@@ -33,7 +37,8 @@ def _as_tensor(data, dtype=None, device=None, env=None):


  @register_function(torch.tensor)
- @op_base.convert_dtype(use_default_dtype=False) # Attempt to infer type from elements
+ @op_base.convert_dtype(
+     use_default_dtype=False) # Attempt to infer type from elements
  def _tensor(data, *, dtype=None, **kwargs):
    python_types_to_torch_types = {
        bool: jnp.bool,
@@ -57,8 +62,8 @@ def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):

  @register_function(torch.angle)
  def _torch_angle(input):
-   if input.dtype.name == 'int64':
-     input = input.astype(jnp.dtype('float32'))
+   if input.dtype.name == "int64":
+     input = input.astype(jnp.dtype("float32"))
    return jnp.angle(input)


@@ -72,19 +77,21 @@ def _torch_argsort(input, dim=-1, descending=False, stable=False):
      # behavior is the same as a jnp array of rank 1
      expanded = True
      input = jnp.expand_dims(input, 0)
-   res = jnp.argsort(input, axis=dim, descending=descending,
-                     stable=stable)
+   res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
    if expanded:
      res = res.squeeze()
    return res

+
  @register_function(torch.diag)
  def _diag(input, diagonal=0):
    return jnp.diag(input, k=diagonal)

+
  @register_function(torch.einsum)
  @register_function(torch.ops.aten.einsum)
  def _einsum(equation, *operands):
+
    def get_params(*a):
      inner_list = a[0]
      if not isinstance(inner_list, jax.Array):
@@ -95,68 +102,80 @@ def _einsum(equation, *operands):
        A, B = inner_list
        return A, B
      return operands
-   assert isinstance(equation, str), 'Only accept str equation'
+
+   assert isinstance(equation, str), "Only accept str equation"
    filtered_operands = get_params(*operands)
    return jnp.einsum(equation, *filtered_operands)


- def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
-     is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
-   L, S = query.size(-2), key.size(-2)
-   scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-   attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
-   if is_causal:
-     assert attn_mask is None
-     temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
-     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-     attn_bias.to(query.dtype)
-   if attn_mask is not None:
-     if attn_mask.dtype == torch.bool:
-       attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-     else:
-       attn_bias += attn_mask
-   if enable_gqa:
-     key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
-     value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
-
-   attn_weight = query @ key.transpose(-2, -1) * scale_factor
-   attn_weight += attn_bias
-   attn_weight = torch.softmax(attn_weight, dim=-1)
-   if dropout_p > 0:
-     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-   return attn_weight @ value
+ def _sdpa_reference(
+     query,
+     key,
+     value,
+     attn_mask=None,
+     dropout_p=0.0,
+     is_causal=False,
+     scale=None,
+     enable_gqa=False,
+ ) -> torch.Tensor:
+   L, S = query.size(-2), key.size(-2)
+   scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+   attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+   if is_causal:
+     assert attn_mask is None
+     temp_mask = torch.ones(
+         L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+     attn_bias.to(query.dtype)
+   if attn_mask is not None:
+     if attn_mask.dtype == torch.bool:
+       attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+     else:
+       attn_bias += attn_mask
+   if enable_gqa:
+     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+   attn_weight = query @ key.transpose(-2, -1) * scale_factor
+   attn_weight += attn_bias
+   attn_weight = torch.softmax(attn_weight, dim=-1)
+   if dropout_p > 0:
+     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+   return attn_weight @ value


  from jax.sharding import PartitionSpec

+
  def _tpu_flash_attention(query, key, value, env):
-   fsdp_partition = PartitionSpec('fsdp')
+   fsdp_partition = PartitionSpec("fsdp")
+
    def wrap_flash_attention(query, key, value):
      block_sizes = flash_attention.BlockSizes(
-       block_b=min(2, query.shape[0]),
-       block_q=min(512, query.shape[2]),
-       block_k_major=min(512, key.shape[2]),
-       block_k=min(512, key.shape[2]),
-       block_q_major_dkv=min(512, query.shape[2]),
-       block_k_major_dkv=min(512, key.shape[2]),
-       block_k_dkv=min(512, key.shape[2]),
-       block_q_dkv=min(512, query.shape[2]),
-       block_k_major_dq=min(512, key.shape[2]),
-       block_k_dq=min(256, key.shape[2]),
-       block_q_dq=min(1024, query.shape[2]),
+         block_b=min(2, query.shape[0]),
+         block_q=min(512, query.shape[2]),
+         block_k_major=min(512, key.shape[2]),
+         block_k=min(512, key.shape[2]),
+         block_q_major_dkv=min(512, query.shape[2]),
+         block_k_major_dkv=min(512, key.shape[2]),
+         block_k_dkv=min(512, key.shape[2]),
+         block_q_dkv=min(512, query.shape[2]),
+         block_k_major_dq=min(512, key.shape[2]),
+         block_k_dq=min(256, key.shape[2]),
+         block_q_dq=min(1024, query.shape[2]),
      )
      return flash_attention.flash_attention(
          query, key, value, causal=True, block_sizes=block_sizes)

    if env.config.shmap_flash_attention:
      wrap_flash_attention = shard_map(
-       wrap_flash_attention,
-       mesh=env._mesh,
-       in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
-       out_specs=fsdp_partition ,
-       check_rep=False,
+         wrap_flash_attention,
+         mesh=env._mesh,
+         in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+         out_specs=fsdp_partition,
+         check_rep=False,
      )
-   #return flash_attn_mapped(query, key, value)
+   # return flash_attn_mapped(query, key, value)
    return wrap_flash_attention(query, key, value)


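Note: the _sdpa_reference fallback retained above is the standard softmax-attention formula written in eager PyTorch. A minimal sketch of the same math outside torchax, with no mask, dropout, or GQA (shapes are illustrative), is:

import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

scale = 1 / math.sqrt(q.size(-1))
reference = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

# With no mask and no dropout this matches PyTorch's built-in op.
print(torch.allclose(reference, F.scaled_dot_product_attention(q, k, v), atol=1e-5))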
@@ -210,27 +229,60 @@ def pad(tensor, pad, mode="constant", value=None):
    return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)


- @register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
- @register_function(torch.ops.aten.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
+ @register_function(
+     torch.nn.functional.scaled_dot_product_attention,
+     is_jax_function=False,
+     needs_env=True,
+ )
+ @register_function(
+     torch.ops.aten.scaled_dot_product_attention,
+     is_jax_function=False,
+     needs_env=True)
  def scaled_dot_product_attention(
-     query, key, value, attn_mask=None,
-     dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False, env=None) -> torch.Tensor:
-
-   if env.config.use_tpu_flash_attention:
+     query,
+     key,
+     value,
+     attn_mask=None,
+     dropout_p=0.0,
+     is_causal=False,
+     scale=None,
+     enable_gqa=False,
+     env=None,
+ ) -> torch.Tensor:
+
+   if env.config.use_tpu_flash_attention:
      jquery, jkey, jvalue = env.t2j_iso((query, key, value))
      res = _tpu_flash_attention(jquery, jkey, jvalue, env)
      return env.j2t_iso(res)

-   return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
+   return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal,
+                          scale, enable_gqa)
+

- @register_function(torch.Tensor.__getitem__)
+ @register_function(
+     torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
  def getitem(self, indexes):
+
    if isinstance(indexes, list) and isinstance(indexes[0], int):
      # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
-     indexes = (indexes, )
+     indexes = (indexes,)
    elif isinstance(indexes, list):
      indexes = tuple(indexes)
-   return self[indexes]
+
+   def is_narrow_slicing():
+     tensor_free = not pytree.tree_any(
+         lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+         indexes)
+     list_free = not isinstance(indexes, tuple) or all(
+         [False if isinstance(x, list) else True for x in indexes])
+     return tensor_free and list_free
+
+   if is_narrow_slicing():
+     return View(self, view_info=NarrowInfo(indexes), env=self._env)
+
+   indexes = self._env.t2j_iso(indexes)
+   return torchax.tensor.Tensor(self._elem[indexes], self._env)
+

  @register_function(torch.corrcoef)
  def _corrcoef(x):
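Note: the reworked torch.Tensor.__getitem__ above only materializes a new tensor when the index expression contains tensors or Python lists; plain ints and slices stay lazy as a View. A standalone restatement of that predicate for illustration (plain PyTorch, omitting the jax.Array branch used in the actual code):

import torch
import torch.utils._pytree as pytree

def is_narrow_slicing(indexes):
  # No tensor anywhere in the index pytree...
  tensor_free = not pytree.tree_any(lambda x: isinstance(x, torch.Tensor),
                                    indexes)
  # ...and no Python list inside a tuple index.
  list_free = not isinstance(indexes, tuple) or not any(
      isinstance(x, list) for x in indexes)
  return tensor_free and list_free

print(is_narrow_slicing((slice(0, 2), 3)))       # True  -> lazy View path
print(is_narrow_slicing(torch.tensor([1, 2])))   # False -> materialize
print(is_narrow_slicing((slice(None), [0, 1])))  # False -> materialize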
@@ -238,15 +290,22 @@ def _corrcoef(x):
      return jnp.corrcoef(x).astype(jnp.float32)
    return jnp.corrcoef(x)

+
  @register_function(torch.sparse.mm, is_jax_function=False)
- def _sparse_mm(mat1, mat2, reduce='sum'):
+ def _sparse_mm(mat1, mat2, reduce="sum"):
    return torch.mm(mat1, mat2)

+
  @register_function(torch.isclose)
  def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
    return jnp.isclose(input, other, rtol, atol, equal_nan)


+ @register_function(torch.linalg.det)
+ def linalg_det(input):
+   return jnp.linalg.det(input)
+
+
  @register_function(torch.ones)
  def _ones(*size: int, dtype=None, **kwargs):
    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
@@ -281,11 +340,18 @@ def empty(*size: Sequence[int], dtype=None, **kwargs):
      size = size[0]
    return jnp.empty(size, dtype=dtype)

+
  @register_function(torch.arange, is_jax_function=False)
  def arange(
-     start, end=None, step=None,
-     out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False,
-     pin_memory=None,
+     start,
+     end=None,
+     step=None,
+     out=None,
+     dtype=None,
+     layout=torch.strided,
+     device=None,
+     requires_grad=False,
+     pin_memory=None,
  ):
    if end is None:
      end = start
@@ -294,9 +360,18 @@ def arange(
      step = 1
    return torch.ops.aten.arange(start, end, step, dtype=dtype)

+
  @register_function(torch.empty_strided, is_jax_function=False)
  def empty_strided(
-     size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False):
+     size,
+     stride,
+     *,
+     dtype=None,
+     layout=None,
+     device=None,
+     requires_grad=False,
+     pin_memory=False,
+ ):
    return empty(size, dtype=dtype)


@@ -306,25 +381,30 @@ def unravel_index(indices, shape):


  @register_function(torch.rand, is_jax_function=False)
- def rand(
-     *size, **kwargs
- ):
+ def rand(*size, **kwargs):
    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
      size = size[0]
    return torch.ops.aten.rand(size, **kwargs)

+
  @register_function(torch.randn, is_jax_function=False)
  def randn(
-     *size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False
+     *size,
+     generator=None,
+     out=None,
+     dtype=None,
+     layout=torch.strided,
+     device=None,
+     requires_grad=False,
+     pin_memory=False,
  ):
    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
      size = size[0]
    return torch.ops.aten.randn(size, generator=generator, dtype=dtype)

+
  @register_function(torch.randint, is_jax_function=False)
- def randint(
-     *args, **kwargs
- ):
+ def randint(*args, **kwargs):
    return torch.ops.aten.randint(*args, **kwargs)


@@ -356,14 +436,17 @@ def linalg_solve_ex(a, b):
    res, info = jaten._aten__linalg_solve_ex(a, b)
    return res, info

+
  @register_function(torch.linalg.svd)
  def linalg_svd(a, full_matrices=True):
    return jaten._aten__linalg_svd(a, full_matrices=full_matrices)

+
  @register_function(torch.linalg.matrix_power)
  def matrix_power(A, n, *, out=None):
    return jnp.linalg.matrix_power(A, n)

+
  @register_function(torch.svd)
  def svd(a, some=True, compute_uv=True):
    if not compute_uv:
@@ -374,21 +457,24 @@ def svd(a, some=True, compute_uv=True):
    U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
    return U, S, jnp.matrix_transpose(V)

+
  @register_function(torch.cdist)
- def _cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
-     return jaten._aten_cdist(x1, x2, p, compute_mode)
+ def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
+   return jaten._aten_cdist(x1, x2, p, compute_mode)
+

  @register_function(torch.lu)
  def lu(A, **kwargs):
-   lu,pivots,_ = jax.lax.linalg.lu(A)
+   lu, pivots, _ = jax.lax.linalg.lu(A)
    # JAX pivots are offset by 1 compared to torch
    _pivots = pivots + 1
    info_shape = pivots.shape[:-1]
    info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
-   if kwargs['get_infos'] == True:
+   if kwargs["get_infos"] == True:
      return lu, _pivots, info
    return lu, _pivots

+
  @register_function(torch.lu_solve)
  def lu_solve(b, LU_data, LU_pivots, **kwargs):
    # JAX pivots are offset by 1 compared to torch
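Note: this comment is the reason for the pivot adjustments in lu and lu_solve above: jax.lax.linalg.lu reports 0-based pivot rows, while torch.lu uses a 1-based convention. A quick illustration in plain JAX:

import jax
import jax.numpy as jnp

a = jnp.array([[0.0, 1.0], [1.0, 0.0]])
lu_mat, pivots, permutation = jax.lax.linalg.lu(a)
print(pivots)      # [1 1] -- 0-based row swaps from JAX
print(pivots + 1)  # [2 2] -- the 1-based convention torch.lu reports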
@@ -396,6 +482,7 @@ def lu_solve(b, LU_data, LU_pivots, **kwargs):
    x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
    return x

+
  @register_function(torch.linalg.tensorsolve)
  def linalg_tensorsolve(A, b, dims=None):
    # examples:
@@ -425,3 +512,57 @@ def functional_linear(self, weights, bias=None):
    if bias is not None:
      res += bias
    return res
+
+
+ @register_function(torch.nn.functional.interpolate)
+ def functional_interpolate(
+     input,
+     size: Tuple[int, int],
+     scale_factor: Optional[float],
+     mode: str,
+     align_corners: bool,
+     recompute_scale_factor: bool,
+     antialias: bool,
+ ):
+   supported_methods = (
+       "nearest",
+       "linear",
+       "bilinear",
+       "trilinear",
+       "cubic",
+       "bicubic",
+       "tricubic",
+       "lanczos3",
+       "lanczos5",
+   )
+   is_jax_supported = mode in supported_methods
+   if not is_jax_supported:
+     raise torchax.tensor.OperatorNotFound(
+         f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+     )
+   # None check
+   antialias = antialias or False
+   align_corners = align_corners or False
+
+   if mode in ('cubic', 'bicubic',
+               'tricubic') and not antialias and size is not None:
+     return jimage.interpolate_bicubic_no_aa(
+         input,
+         size[0],
+         size[1],
+         align_corners,
+     )
+   else:
+     # fallback
+     raise torchax.tensor.OperatorNotFound(
+         f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+     )
+
+
+ @register_function(torch.Tensor.repeat_interleave)
+ def torch_Tensor_repeat_interleave(self,
+                                    repeats,
+                                    dim=None,
+                                    *,
+                                    output_size=None):
+   return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
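Note: the new repeat_interleave override above forwards torch's output_size hint to jnp.repeat's total_repeat_length, which gives jnp.repeat a known output length even when repeats is a traced array under jit. A small plain-JAX check of that mapping (values are illustrative):

import jax.numpy as jnp

x = jnp.array([1, 2, 3])
repeats = jnp.array([2, 0, 1])

# Output shape depends on the values in `repeats` when no total length is given.
print(jnp.repeat(x, repeats))                         # [1 1 3]
# With total_repeat_length the output shape is fixed ahead of time.
print(jnp.repeat(x, repeats, total_repeat_length=3))  # [1 1 3]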
@@ -53,8 +53,8 @@ def _self_suppression(in_args):
    can_suppress_others = jnp.reshape(
        jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
    iou_suppressed = jnp.reshape(
-       (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype),
-       [batch_size, -1, 1]) * iou
+       (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
+           iou.dtype), [batch_size, -1, 1]) * iou
    iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
    return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new

@@ -65,9 +65,8 @@ def _cross_suppression(in_args):
    new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
                                  [batch_size, _NMS_TILE_SIZE, 4])
    iou = _bbox_overlap(new_slice, box_slice)
-   ret_slice = jnp.expand_dims(
-       (jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype),
-       2) * box_slice
+   ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(
+       box_slice.dtype), 2) * box_slice
    return boxes, ret_slice, iou_threshold, inner_idx + 1


@@ -90,45 +89,40 @@ def _suppression_loop_body(in_args):
    # Iterates over tiles that can possibly suppress the current tile.
    box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                                  [batch_size, _NMS_TILE_SIZE, 4])
+
    def _loop_cond(in_args):
      _, _, _, inner_idx = in_args
      return inner_idx < idx

-   _, box_slice, _, _ = lax.while_loop(
-       _loop_cond,
-       _cross_suppression, (boxes, box_slice, iou_threshold,
-                            0))
+   _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression,
+                                       (boxes, box_slice, iou_threshold, 0))

    # Iterates over the current tile to compute self-suppression.
    iou = _bbox_overlap(box_slice, box_slice)
    mask = jnp.expand_dims(
-       jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) > jnp.reshape(
-           jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
+       jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+       > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
    iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)

    def _loop_cond2(in_args):
      _, loop_condition, _ = in_args
      return loop_condition

-   suppressed_iou, _, _ = lax.while_loop(
-       _loop_cond2, _self_suppression,
-       (iou, True,
-        jnp.sum(iou, [1, 2])))
+   suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression,
+                                         (iou, True, jnp.sum(iou, [1, 2])))
    suppressed_box = jnp.sum(suppressed_iou, 1) > 0
    box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)

    # Uses box_slice to update the input boxes.
-   mask = jnp.reshape(
-       (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype),
-       [1, -1, 1, 1])
+   mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles),
+                                 idx)).astype(boxes.dtype), [1, -1, 1, 1])
    boxes = jnp.tile(jnp.expand_dims(
        box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
            boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
    boxes = jnp.reshape(boxes, [batch_size, -1, 4])

    # Updates output_size.
-   output_size += jnp.sum(
-       jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
+   output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
    return boxes, iou_threshold, output_size, idx + 1


@@ -185,8 +179,8 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    """
    batch_size = boxes.shape[0]
    num_boxes = boxes.shape[1]
-   pad = int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)
-            ) * _NMS_TILE_SIZE - num_boxes
+   pad = int(jnp.ceil(
+       float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
    boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
    scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
    num_boxes += pad
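Note: the pad computation above rounds the box count up to a whole number of NMS tiles. A worked example, assuming a tile size of 512 for _NMS_TILE_SIZE (the actual constant is defined elsewhere in this module):

import math

_NMS_TILE_SIZE = 512  # assumed value, for illustration only
num_boxes = 1000
pad = int(math.ceil(num_boxes / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
print(pad)  # 24, i.e. 1000 boxes are padded to 1024 = 2 full tiles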
@@ -194,15 +188,12 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    def _loop_cond(in_args):
      unused_boxes, unused_threshold, output_size, idx = in_args
      return jnp.logical_and(
-         jnp.min(output_size) < max_output_size,
-         idx < num_boxes // _NMS_TILE_SIZE)
+         jnp.min(output_size) < max_output_size, idx
+         < num_boxes // _NMS_TILE_SIZE)

    selected_boxes, _, output_size, _ = lax.while_loop(
-       _loop_cond, _suppression_loop_body, (
-           boxes, iou_threshold,
-           jnp.zeros([batch_size], jnp.int32),
-           0
-       ))
+       _loop_cond, _suppression_loop_body,
+       (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
    idx = num_boxes - lax.top_k(
        jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
        jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
@@ -210,30 +201,28 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    idx = jnp.minimum(idx, num_boxes - 1)
    idx = jnp.reshape(
        idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
-
+
    return idx
-   boxes = jnp.reshape(
-       (jnp.reshape(boxes, [-1, 4]))[idx],
-       [batch_size, max_output_size, 4])
-   boxes = boxes * (
-       jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) < jnp.reshape(
-           output_size, [-1, 1, 1])).astype(boxes.dtype)
+   boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
+                       [batch_size, max_output_size, 4])
+   boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
+                    < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
    scores = jnp.reshape(
-       jnp.reshape(scores, [-1, 1])[idx],
-       [batch_size, max_output_size])
-   scores = scores * (
-       jnp.reshape(jnp.arange(max_output_size), [1, -1]) < jnp.reshape(
-           output_size, [-1, 1])).astype(scores.dtype)
+       jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
+   scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1])
+                      < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
    return scores, boxes


  # registry:

+
  def nms(boxes, scores, iou_threshold):
    max_output_size = boxes.shape[0]
    boxes = boxes.reshape((1, *boxes.shape))
    scores = scores.reshape((1, *scores.shape))
-   res = non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold)
+   res = non_max_suppression_padded(scores, boxes, max_output_size,
+                                    iou_threshold)
    return res


@@ -242,4 +231,4 @@ try:
    import torchvision
    ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
  except Exception:
-   pass
+   pass