torchax 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. Click here for more details.

torchax/ops/jtorch.py ADDED
@@ -0,0 +1,427 @@
1
+ """Tensor constructor overrides"""
2
+ import math
3
+ import collections.abc
4
+ import functools
5
+ from typing import Optional, Sequence
6
+ import numpy as np
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ from jax.experimental.pallas.ops.tpu import flash_attention
11
+ from jax.experimental.shard_map import shard_map
12
+
13
+ import torch
14
+ from torchax.ops.ops_registry import register_torch_function_op
15
+ from torchax.ops import op_base, mappings, jaten
16
+ import torchax.tensor
17
+
18
+
19
def register_function(torch_func, **kwargs):
    """Build a decorator that registers an override for ``torch_func``.

    Args:
        torch_func: the torch callable being overridden.
        **kwargs: forwarded unchanged to ``register_torch_function_op``.

    Returns:
        A ``functools.partial`` of ``register_torch_function_op`` with
        ``torch_func`` and ``kwargs`` already bound.
    """
    decorator = functools.partial(
        register_torch_function_op, torch_func, **kwargs)
    return decorator
21
+
22
+
23
@register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
@op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
def _as_tensor(data, dtype=None, device=None, env=None):
    """Override for ``torch.as_tensor``.

    Args:
        data: a torch tensor, a numpy array, or anything ``_tensor`` accepts.
        dtype: optional dtype (processed by ``op_base.convert_dtype``).
        device: only used when ``data`` is already a ``torch.Tensor``.
        env: the torchax environment supplied by the registry (``needs_env``).

    Returns:
        The result of ``env._to_copy`` for tensor input, otherwise a
        ``torchax.tensor.Tensor`` wrapping the converted jax array.
    """
    if isinstance(data, torch.Tensor):
        return env._to_copy(data, dtype, device)
    if isinstance(data, np.ndarray):
        # Honour an explicitly requested dtype; dtype=None keeps the array's
        # own dtype, which is what the previous code always did.
        jax_res = jnp.asarray(data, dtype=dtype)
    else:
        jax_res = _tensor(data, dtype=dtype)
    return torchax.tensor.Tensor(jax_res, env)
33
+
34
+
35
@register_function(torch.tensor)
@op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
def _tensor(data, *, dtype=None, **kwargs):
    """Override for ``torch.tensor``: build a jax array from Python data.

    When ``dtype`` is not given, it is inferred from the Python type of the
    first leaf of ``data``; if that type is not one of bool/int/float/complex
    (or there are no leaves), torch's default dtype is used.

    Args:
        data: (nested) Python scalars/sequences or array-likes.
        dtype: optional dtype (processed by ``op_base.convert_dtype``).
        **kwargs: accepted for signature compatibility and ignored.
    """
    python_types_to_torch_types = {
        bool: jnp.bool,
        int: jnp.int64,
        float: jnp.float32,
        complex: jnp.complex64,
    }
    # Compare against None explicitly: a dtype *object* may be falsy (numpy
    # dtype instances have length 0), and truthiness checks would then
    # silently discard a dtype the caller asked for.
    if dtype is None:
        leaves = jax.tree_util.tree_leaves(data)
        if leaves:
            dtype = python_types_to_torch_types.get(type(leaves[0]))

    if dtype is None:
        dtype = mappings.t2j_dtype(torch.get_default_dtype())
    return jnp.array(data, dtype=dtype)
51
+
52
+
53
@register_function(torch.allclose)
def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Override for ``torch.allclose`` that delegates to ``jnp.allclose``."""
    tolerances = (rtol, atol, equal_nan)
    return jnp.allclose(input, other, *tolerances)
56
+
57
+
58
@register_function(torch.angle)
def _torch_angle(input):
    """Override for ``torch.angle``; int64 input is cast to float32 first."""
    needs_cast = input.dtype.name == 'int64'
    operand = input.astype(jnp.dtype('float32')) if needs_cast else input
    return jnp.angle(operand)
63
+
64
+
65
@register_function(torch.argsort)
def _torch_argsort(input, dim=-1, descending=False, stable=False):
    """Override for ``torch.argsort`` built on ``jnp.argsort``.

    A rank-0 input is temporarily promoted to rank 1 so that ``dim`` values
    of 0 or -1 are valid for ``jnp.argsort``; the extra axis is squeezed
    away again before returning.
    """
    was_scalar = input.ndim == 0
    if was_scalar:
        input = jnp.expand_dims(input, 0)
    order = jnp.argsort(
        input, axis=dim, descending=descending, stable=stable)
    return order.squeeze() if was_scalar else order
80
+
81
@register_function(torch.diag)
def _diag(input, diagonal=0):
    """Override for ``torch.diag``; ``diagonal`` maps to ``jnp.diag``'s ``k``."""
    offset = diagonal
    return jnp.diag(input, k=offset)
84
+
85
@register_function(torch.einsum)
@register_function(torch.ops.aten.einsum)
def _einsum(equation, *operands):
    """Override for ``torch.einsum`` / ``aten.einsum`` using ``jnp.einsum``.

    Operands may be passed either as separate positional arguments
    (``einsum(eq, a, b)``) or as one list/tuple (``einsum(eq, [a, b, c])``).

    Args:
        equation: the einsum subscript string.
        *operands: the arrays, or a single sequence containing them.

    Raises:
        AssertionError: if ``equation`` is not a ``str``.
    """
    assert isinstance(equation, str), 'Only accept str equation'
    # Unpack the "single sequence of operands" calling convention for any
    # number of operands; previously only sequences of length 1 or 2 were
    # unpacked and longer ones reached jnp.einsum as a single operand.
    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        operands = tuple(operands[0])
    return jnp.einsum(equation, *operands)
101
+
102
+
103
def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
                    is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Reference scaled-dot-product attention written with plain torch ops.

    Args:
        query, key, value: tensors whose last two dims are (length, head_dim).
        attn_mask: optional mask broadcastable to the attention weights. A
            bool mask marks positions to keep; any other dtype is added to
            the attention scores.
        dropout_p: dropout probability applied to the attention weights.
        is_causal: apply a lower-triangular causal mask (``attn_mask`` must
            then be ``None``).
        scale: score scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: repeat key/value heads (dim -3) to match the query heads.

    Returns:
        ``softmax(query @ key^T * scale + bias) @ value``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(
            L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        # Combine out of place: attn_bias is only (L, S), so in-place updates
        # cannot broadcast against masks that carry batch/head dimensions.
        if attn_mask.dtype == torch.bool:
            mask_bias = torch.zeros(
                attn_mask.shape, dtype=query.dtype, device=query.device)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            attn_bias = attn_bias + attn_mask
    if enable_gqa:
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p > 0:
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
128
+
129
+
130
+ from jax.sharding import PartitionSpec
131
+
132
def _tpu_flash_attention(query, key, value, env):
    """Run the Pallas TPU flash-attention kernel on jax arrays.

    Block sizes are capped by the batch size (``shape[0]``) and by the
    query/key sizes along ``shape[2]``. When
    ``env.config.shmap_flash_attention`` is set, the kernel is wrapped in
    ``shard_map`` over ``env._mesh`` with every input and the output
    partitioned along the ``'fsdp'`` axis.

    NOTE(review): the kernel is always invoked with ``causal=True``; callers'
    mask / is_causal / scale settings are not visible here - confirm intended.
    """
    def _run_kernel(q, k, v):
        batch, q_len, kv_len = q.shape[0], q.shape[2], k.shape[2]
        sizes = flash_attention.BlockSizes(
            block_b=min(2, batch),
            block_q=min(512, q_len),
            block_k_major=min(512, kv_len),
            block_k=min(512, kv_len),
            block_q_major_dkv=min(512, q_len),
            block_k_major_dkv=min(512, kv_len),
            block_k_dkv=min(512, kv_len),
            block_q_dkv=min(512, q_len),
            block_k_major_dq=min(512, kv_len),
            block_k_dq=min(256, kv_len),
            block_q_dq=min(1024, q_len),
        )
        return flash_attention.flash_attention(
            q, k, v, causal=True, block_sizes=sizes)

    spec = PartitionSpec('fsdp')
    if not env.config.shmap_flash_attention:
        return _run_kernel(query, key, value)

    sharded_kernel = shard_map(
        _run_kernel,
        mesh=env._mesh,
        in_specs=(spec, spec, spec),
        out_specs=spec,
        check_rep=False,
    )
    return sharded_kernel(query, key, value)
161
+
162
+
163
@register_function(torch.nn.functional.pad)
def pad(tensor, pad, mode="constant", value=None):
    """Override for ``torch.nn.functional.pad`` built on ``jnp.pad``.

    ``pad`` lists (start, end) amounts starting from the LAST dimension, so
    the pairs are walked back to front to build ``jnp.pad``'s per-dimension
    widths. Negative amounts are implemented as slicing (cropping).

    Args:
        tensor: the array to pad.
        pad: flat sequence of (start, end) pairs, last dimension first.
        mode: torch padding mode; "circular" and "replicate" are renamed to
            "wrap" and "edge", every other name is passed through as-is.
        value: fill value, only forwarded for "constant" mode when not None.
    """
    torch_to_numpy_mode = {
        "circular": "wrap",
        "replicate": "edge",
    }
    numpy_mode = torch_to_numpy_mode.get(mode, mode)

    # Dimensions not mentioned in `pad` are left untouched.
    untouched = tensor.ndim - len(pad) // 2
    widths = [(0, 0)] * untouched
    crops = [slice(None)] * untouched

    for offset in range(len(pad) - 2, -1, -2):
        before, after = pad[offset:offset + 2]
        crop_from, crop_to = None, None
        if before < 0:
            crop_from, before = -before, 0
        if after < 0:
            crop_to, after = after, 0
        widths.append((before, after))
        crops.append(slice(crop_from, crop_to))
    crops = tuple(crops)

    # `jax.numpy.pad` rejects an irrelevant `constant_values` argument, even
    # when it is None (which it would treat as nan), so only pass it if used.
    extra = {}
    if mode == "constant" and value is not None:
        extra["constant_values"] = value

    # "replicate" must pad first and crop afterwards; every other mode crops
    # first so the padding works on the smaller array.
    if mode == "replicate":
        padded = jnp.pad(tensor, widths, mode=numpy_mode, **extra)
        return padded[crops]
    return jnp.pad(tensor[crops], widths, mode=numpy_mode, **extra)
211
+
212
+
213
@register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
@register_function(torch.ops.aten.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
def scaled_dot_product_attention(
        query, key, value, attn_mask=None,
        dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False, env=None) -> torch.Tensor:
    """Override for scaled-dot-product attention.

    Uses the TPU flash-attention kernel when
    ``env.config.use_tpu_flash_attention`` is set, otherwise the torch
    reference implementation.

    NOTE(review): the flash-attention path receives only query/key/value;
    ``attn_mask``, ``dropout_p``, ``is_causal``, ``scale`` and ``enable_gqa``
    are not forwarded to it - confirm this is intended.
    """
    if not env.config.use_tpu_flash_attention:
        return _sdpa_reference(
            query, key, value, attn_mask, dropout_p, is_causal, scale,
            enable_gqa)

    jquery, jkey, jvalue = env.t2j_iso((query, key, value))
    attended = _tpu_flash_attention(jquery, jkey, jvalue, env)
    return env.j2t_iso(attended)
225
+
226
@register_function(torch.Tensor.__getitem__)
def getitem(self, indexes):
    """Override for ``Tensor.__getitem__`` that normalises list indices.

    A list of ints (``x[[1, 2]]``) selects along the first axis, so it is
    wrapped in a 1-tuple; any other list is treated as a tuple of per-axis
    indices. Non-list indices are used unchanged.
    """
    if isinstance(indexes, list):
        if isinstance(indexes[0], int):
            # x[[1, 2]] is NOT x[1, 2] (the latter is a tuple of ints).
            indexes = (indexes, )
        else:
            indexes = tuple(indexes)
    return self[indexes]
234
+
235
@register_function(torch.corrcoef)
def _corrcoef(x):
    """Override for ``torch.corrcoef``; int64 input yields a float32 result."""
    coefficients = jnp.corrcoef(x)
    if x.dtype.name == "int64":
        coefficients = coefficients.astype(jnp.float32)
    return coefficients
240
+
241
@register_function(torch.sparse.mm, is_jax_function=False)
def _sparse_mm(mat1, mat2, reduce='sum'):
    """Override for ``torch.sparse.mm`` that falls back to ``torch.mm``.

    ``reduce`` is accepted for signature compatibility and is not used.
    """
    product = torch.mm(mat1, mat2)
    return product
244
+
245
@register_function(torch.isclose)
def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Override for ``torch.isclose`` that delegates to ``jnp.isclose``."""
    tolerances = (rtol, atol, equal_nan)
    return jnp.isclose(input, other, *tolerances)
248
+
249
+
250
@register_function(torch.ones)
def _ones(*size: int, dtype=None, **kwargs):
    """Override for ``torch.ones``.

    Accepts the shape either as separate ints (``ones(2, 3)``) or as one
    iterable (``ones((2, 3))``) and delegates to ``jaten._ones``.
    """
    shape = size
    if len(shape) == 1 and isinstance(shape[0], collections.abc.Iterable):
        shape = shape[0]
    return jaten._ones(shape, dtype=dtype)
255
+
256
+
257
@register_function(torch.zeros, is_jax_function=True)
def _zeros(*size: int, dtype=None, **kwargs):
    """Override for ``torch.zeros``.

    Accepts the shape either as separate ints (``zeros(2, 3)``) or as one
    iterable (``zeros((2, 3))``) and delegates to ``jaten._zeros``.
    """
    shape = size
    if len(shape) == 1 and isinstance(shape[0], collections.abc.Iterable):
        shape = shape[0]
    return jaten._zeros(shape, dtype=dtype)
262
+
263
+
264
@register_function(torch.eye)
@op_base.convert_dtype()
def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
    """Override for ``torch.eye``: an ``n`` x ``m`` identity via ``jnp.eye``."""
    identity = jnp.eye(n, m, dtype=dtype)
    return identity
268
+
269
+
270
@register_function(torch.full)
@op_base.convert_dtype()
def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
    """Override for ``torch.full``: array of shape ``size`` set to ``fill_value``."""
    # TODO: handle torch.Size
    filled = jnp.full(size, fill_value, dtype=dtype)
    return filled
275
+
276
+
277
@register_function(torch.empty)
@op_base.convert_dtype()
def empty(*size: Sequence[int], dtype=None, **kwargs):
    """Override for ``torch.empty`` built on ``jnp.empty``.

    The shape may be given as separate ints or as one iterable.
    """
    single_iterable = (
        len(size) == 1 and isinstance(size[0], collections.abc.Iterable))
    shape = size[0] if single_iterable else size
    return jnp.empty(shape, dtype=dtype)
283
+
284
@register_function(torch.arange, is_jax_function=False)
def arange(
        start, end=None, step=None,
        out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False,
        pin_memory=None,
):
    """Override for ``torch.arange`` that normalises its arguments.

    ``arange(n)`` becomes ``arange(0, n, 1)``; a missing ``step`` defaults to
    1. Only ``start``, ``end``, ``step`` and ``dtype`` are forwarded to
    ``torch.ops.aten.arange``; the other parameters are accepted and ignored.
    """
    if end is None:
        start, end = 0, start
    step = 1 if step is None else step
    return torch.ops.aten.arange(start, end, step, dtype=dtype)
296
+
297
@register_function(torch.empty_strided, is_jax_function=False)
def empty_strided(
        size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False):
    """Override for ``torch.empty_strided``.

    Delegates to ``empty`` with ``size`` and ``dtype``; ``stride`` and the
    remaining keyword arguments are accepted and ignored.
    """
    result = empty(size, dtype=dtype)
    return result
301
+
302
+
303
@register_function(torch.unravel_index)
def unravel_index(indices, shape):
    """Override for ``torch.unravel_index`` via ``jnp.unravel_index``."""
    coordinates = jnp.unravel_index(indices, shape)
    return coordinates
306
+
307
+
308
@register_function(torch.rand, is_jax_function=False)
def rand(*size, **kwargs):
    """Override for ``torch.rand``.

    Normalises the shape (separate ints or one iterable) and forwards it with
    all keyword arguments to ``torch.ops.aten.rand``.
    """
    shape = size
    if len(shape) == 1 and isinstance(shape[0], collections.abc.Iterable):
        shape = shape[0]
    return torch.ops.aten.rand(shape, **kwargs)
315
+
316
@register_function(torch.randn, is_jax_function=False)
def randn(
        *size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False
):
    """Override for ``torch.randn``.

    Normalises the shape (separate ints or one iterable). Only ``generator``
    and ``dtype`` are forwarded to ``torch.ops.aten.randn``; the other
    keyword arguments are accepted and ignored.
    """
    shape = size
    if len(shape) == 1 and isinstance(shape[0], collections.abc.Iterable):
        shape = shape[0]
    return torch.ops.aten.randn(shape, generator=generator, dtype=dtype)
323
+
324
@register_function(torch.randint, is_jax_function=False)
def randint(*args, **kwargs):
    """Override for ``torch.randint``: forwards everything to ``aten.randint``."""
    aten_randint = torch.ops.aten.randint
    return aten_randint(*args, **kwargs)
329
+
330
+
331
@register_function(torch.logdet)
def logdet(input):
    """Override for ``torch.logdet``.

    Returns the second element (log-abs-determinant) of
    ``jaten._aten__linalg_slogdet``; the sign is discarded.
    """
    sign_and_logabsdet = jaten._aten__linalg_slogdet(input)
    _, log_abs_det = sign_and_logabsdet
    return log_abs_det
335
+
336
+
337
@register_function(torch.linalg.slogdet)
def linalg_slogdet(input):
    """Override for ``torch.linalg.slogdet``.

    Returns a ``torch.return_types.slogdet`` of (sign, logabsdet) computed by
    ``jaten._aten__linalg_slogdet``.
    """
    sign, log_abs_det = jaten._aten__linalg_slogdet(input)
    packed = (sign, log_abs_det)
    return torch.return_types.slogdet(packed)
341
+
342
+
343
@register_function(torch.tensor_split)
def tensor_split(input, indices_or_sections, dim=0):
    """Override for ``torch.tensor_split`` via ``jnp.array_split`` along ``dim``."""
    pieces = jnp.array_split(input, indices_or_sections, axis=dim)
    return pieces
346
+
347
+
348
@register_function(torch.linalg.solve)
def linalg_solve(a, b):
    """Override for ``torch.linalg.solve``.

    Returns only the solution from ``jaten._aten__linalg_solve_ex``; the
    accompanying info value is dropped.
    """
    solution_and_info = jaten._aten__linalg_solve_ex(a, b)
    solution, _ = solution_and_info
    return solution
352
+
353
+
354
@register_function(torch.linalg.solve_ex)
def linalg_solve_ex(a, b):
    """Override for ``torch.linalg.solve_ex``: returns ``(result, info)``."""
    solution, status = jaten._aten__linalg_solve_ex(a, b)
    return solution, status
358
+
359
@register_function(torch.linalg.svd)
def linalg_svd(a, full_matrices=True):
    """Override for ``torch.linalg.svd`` via ``jaten._aten__linalg_svd``."""
    factors = jaten._aten__linalg_svd(a, full_matrices=full_matrices)
    return factors
362
+
363
@register_function(torch.linalg.matrix_power)
def matrix_power(A, n, *, out=None):
    """Override for ``torch.linalg.matrix_power``; ``out`` is ignored."""
    powered = jnp.linalg.matrix_power(A, n)
    return powered
366
+
367
@register_function(torch.svd)
def svd(a, some=True, compute_uv=True):
    """Override for the legacy ``torch.svd``.

    Args:
        a: input of shape (..., m, n).
        some: compute the reduced decomposition (``full_matrices=not some``).
        compute_uv: when False only the singular values are computed and
            zero-filled U and V placeholders are returned.

    Returns:
        ``(U, S, V)``. ``V`` is the transpose of the third factor returned by
        ``jaten._aten__linalg_svd``.
    """
    if not compute_uv:
        S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
        # Keep the leading batch dimensions on the zero placeholders so that
        # batched input yields (..., m, m) and (..., n, n).
        batch = tuple(a.shape[:-2])
        rows, cols = a.shape[-2], a.shape[-1]
        U = jnp.zeros(batch + (rows, rows), dtype=a.dtype)
        V = jnp.zeros(batch + (cols, cols), dtype=a.dtype)
        return U, S, V
    U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
    return U, S, jnp.matrix_transpose(V)
376
+
377
@register_function(torch.cdist)
def _cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
    """Override for ``torch.cdist`` that delegates to ``jaten._aten_cdist``."""
    distances = jaten._aten_cdist(x1, x2, p, compute_mode)
    return distances
380
+
381
@register_function(torch.lu)
def lu(A, **kwargs):
    """Override for ``torch.lu`` built on ``jax.lax.linalg.lu``.

    Args:
        A: the matrix (or batch of matrices) to factorise.
        **kwargs: only ``get_infos`` is consulted; it defaults to False when
            absent. Other keywords are ignored.

    Returns:
        ``(LU, pivots)``, or ``(LU, pivots, infos)`` when ``get_infos`` is
        truthy. ``infos`` is an all-zero int32 array shaped like ``pivots``
        without its last dimension.
    """
    factors, pivots, _ = jax.lax.linalg.lu(A)
    # JAX pivots are offset by 1 compared to torch
    torch_pivots = pivots + 1
    # Use .get(): indexing kwargs directly raised KeyError whenever the
    # caller relied on the default and did not pass get_infos.
    if kwargs.get('get_infos', False):
        info = jnp.zeros(
            pivots.shape[:-1], dtype=mappings.t2j_dtype(torch.int32))
        return factors, torch_pivots, info
    return factors, torch_pivots
391
+
392
@register_function(torch.lu_solve)
def lu_solve(b, LU_data, LU_pivots, **kwargs):
    """Override for ``torch.lu_solve`` via ``jax.scipy.linalg.lu_solve``.

    ``LU_pivots`` is shifted down by one before the call because JAX pivots
    are offset by 1 compared to torch. ``kwargs`` is ignored.
    """
    jax_pivots = LU_pivots - 1
    factorization = (LU_data, jax_pivots)
    return jax.scipy.linalg.lu_solve(factorization, b)
398
+
399
@register_function(torch.linalg.tensorsolve)
def linalg_tensorsolve(A, b, dims=None):
    """Override for ``torch.linalg.tensorsolve`` via ``jnp.linalg.tensorsolve``.

    torch is more permissive than jax about the shape of ``b``, especially
    once axes have been moved with ``dims``; jax raises e.g.
    "After moving axes to end, leading shape of a must match shape of b".
    So the axis move is done here and ``b`` is reshaped to the leading shape
    of ``A`` that jax expects.

    Shape examples (A, b -> result):
        (2, 3, 6), (2, 3)             -> (3, 6)
        (9, 2, 6, 3), (6, 3)          -> (6, 3)
        (9, 2, 3, 6), (6, 3)          -> (3, 6)
        (18, 6, 3), (18,)             -> (6, 3)
        (3, 8, 4, 6), (4, 6)          -> (4, 6)
        (3, 8, 1, 2, 2, 6), (3, 4, 2) -> (2, 2, 6)
    """
    if dims is not None:
        # Move every listed axis to the end ourselves, then tell jax that no
        # further axis handling is needed.
        last_axis = A.ndim - 1
        A = jnp.moveaxis(A, dims, len(dims) * (last_axis,))
        dims = None
    expected_b_shape = A.shape[:b.ndim]
    if expected_b_shape != b.shape:
        b = jnp.reshape(b, expected_b_shape)
    return jnp.linalg.tensorsolve(A, b, axes=dims)
420
+
421
+
422
@register_function(torch.nn.functional.linear)
def functional_linear(self, weights, bias=None):
    """Override for ``torch.nn.functional.linear``.

    Contracts the last axis of ``self`` with the last axis of ``weights``
    (einsum ``"...a,ba->...b"``) and adds ``bias`` when one is given.
    """
    projected = jnp.einsum("...a,ba->...b", self, weights)
    if bias is None:
        return projected
    projected += bias
    return projected
@@ -0,0 +1,245 @@
1
+ """
2
+ Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
3
+ """
4
+
5
+ import functools
6
+ from typing import List, Union, Optional, Tuple
7
+
8
+ import torch
9
+ from jax import lax
10
+ import jax.numpy as jnp
11
+ from . import ops_registry
12
+
13
# Number of boxes handled per tile by the suppression loops in this module;
# the padded NMS entry point pads the box count up to a multiple of it.
_NMS_TILE_SIZE = 256
14
+
15
+
16
def _bbox_overlap(boxes, gt_boxes):
    """Compute the pairwise IoU between two batched sets of boxes.

    Args:
        boxes: first set of boxes; axis 2 holds the four coordinates
            (y_min, x_min, y_max, x_max).
        gt_boxes: second set of boxes, same layout.

    Returns:
        The intersection-over-union matrix: entry [b, i, j] compares
        ``boxes[b, i]`` with ``gt_boxes[b, j]``.
    """
    def _as_row(column):
        # Swap the last two axes so it broadcasts against a column.
        return jnp.transpose(column, [0, 2, 1])

    y_lo, x_lo, y_hi, x_hi = jnp.split(boxes, 4, axis=2)
    gt_y_lo, gt_x_lo, gt_y_hi, gt_x_hi = jnp.split(gt_boxes, 4, axis=2)

    # Intersection rectangle, clamped at zero when the boxes do not overlap.
    left = jnp.maximum(x_lo, _as_row(gt_x_lo))
    right = jnp.minimum(x_hi, _as_row(gt_x_hi))
    top = jnp.maximum(y_lo, _as_row(gt_y_lo))
    bottom = jnp.minimum(y_hi, _as_row(gt_y_hi))
    intersection = jnp.maximum((right - left), 0) * jnp.maximum((bottom - top), 0)

    # Union area; the small epsilon avoids a divide-by-zero.
    area = (y_hi - y_lo) * (x_hi - x_lo)
    gt_area = (gt_y_hi - gt_y_lo) * (gt_x_hi - gt_x_lo)
    union = area + _as_row(gt_area) - intersection + 1e-8

    return intersection / union
48
+
49
+
50
def _self_suppression(in_args):
    """One iteration of suppression inside a single tile.

    Args:
        in_args: tuple ``(iou, _, iou_sum)``: the current IoU matrix, an
            ignored loop flag, and the per-batch sum of ``iou`` from the
            previous iteration.

    Returns:
        ``(iou_suppressed, changed, iou_sum_new)`` where ``changed`` is True
        when any batch's IoU sum dropped by more than 0.5.
    """
    iou, _, previous_sum = in_args
    column_shape = [iou.shape[0], -1, 1]

    # Boxes whose incoming IoU never exceeds 0.5 may suppress others.
    not_suppressed = jnp.max(iou, 1) <= 0.5
    can_suppress_others = jnp.reshape(
        not_suppressed, column_shape).astype(iou.dtype)

    # Zero the rows of boxes that those suppressors knock out.
    survives = (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype)
    iou_suppressed = jnp.reshape(survives, column_shape) * iou

    new_sum = jnp.sum(iou_suppressed, [1, 2])
    changed = jnp.any(previous_sum - new_sum > 0.5)
    return iou_suppressed, changed, new_sum
60
+
61
+
62
def _cross_suppression(in_args):
    """Suppress boxes of the current tile using one earlier tile.

    Args:
        in_args: tuple ``(boxes, box_slice, iou_threshold, inner_idx)``:
            all boxes, the tile being filtered, the IoU threshold, and the
            index of the tile used as suppressor.

    Returns:
        The same tuple with ``box_slice`` filtered (boxes overlapping any box
        of the suppressing tile at or above the threshold are zeroed) and
        ``inner_idx`` advanced by one.
    """
    boxes, box_slice, iou_threshold, inner_idx = in_args
    tile_start = [0, inner_idx * _NMS_TILE_SIZE, 0]
    tile_shape = [boxes.shape[0], _NMS_TILE_SIZE, 4]
    suppressor_tile = lax.dynamic_slice(boxes, tile_start, tile_shape)

    overlap = _bbox_overlap(suppressor_tile, box_slice)
    keep = (jnp.all(overlap < iou_threshold, [1])).astype(box_slice.dtype)
    filtered_slice = jnp.expand_dims(keep, 2) * box_slice
    return boxes, filtered_slice, iou_threshold, inner_idx + 1
72
+
73
+
74
def _suppression_loop_body(in_args):
    """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).

    The tile is first filtered against every earlier tile (cross
    suppression), then against itself (self suppression), written back into
    ``boxes``, and the count of surviving boxes is added to ``output_size``.

    Args:
      in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx

    Returns:
      boxes: updated boxes.
      iou_threshold: pass down iou_threshold to the next iteration.
      output_size: the updated output_size.
      idx: the updated induction variable.
    """
    boxes, iou_threshold, output_size, idx = in_args
    num_tiles = boxes.shape[1] // _NMS_TILE_SIZE
    batch_size = boxes.shape[0]

    # Iterates over tiles that can possibly suppress the current tile.
    box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                                  [batch_size, _NMS_TILE_SIZE, 4])
    def _loop_cond(in_args):
        # Only tiles before the current one (inner_idx < idx) are visited.
        _, _, _, inner_idx = in_args
        return inner_idx < idx

    _, box_slice, _, _ = lax.while_loop(
        _loop_cond,
        _cross_suppression, (boxes, box_slice, iou_threshold,
                             0))

    # Iterates over the current tile to compute self-suppression.
    iou = _bbox_overlap(box_slice, box_slice)
    # Strict upper-triangular mask: entry [i, j] is kept only when j > i.
    mask = jnp.expand_dims(
        jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) > jnp.reshape(
            jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
    # Keep only overlaps at or above the threshold inside that triangle.
    iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)

    def _loop_cond2(in_args):
        # Continue while the previous iteration reported a change.
        _, loop_condition, _ = in_args
        return loop_condition

    suppressed_iou, _, _ = lax.while_loop(
        _loop_cond2, _self_suppression,
        (iou, True,
         jnp.sum(iou, [1, 2])))
    # A box with any remaining incoming overlap is suppressed: zero it out.
    suppressed_box = jnp.sum(suppressed_iou, 1) > 0
    box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)

    # Uses box_slice to update the input boxes.
    # The one-hot mask over tiles selects tile `idx`; all other tiles keep
    # their existing contents.
    mask = jnp.reshape(
        (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype),
        [1, -1, 1, 1])
    boxes = jnp.tile(jnp.expand_dims(
        box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
            boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
    boxes = jnp.reshape(boxes, [batch_size, -1, 4])

    # Updates output_size.
    # A box counts as surviving when any of its coordinates is positive.
    output_size += jnp.sum(
        jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
    return boxes, iou_threshold, output_size, idx + 1
133
+
134
+
135
def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
    """A wrapper that handles non-maximum suppression.

    Assumption:
      * The boxes are sorted by scores unless the box is a dot (all coordinates
        are zero).
      * Boxes with higher scores can be used to suppress boxes with lower scores.

    The overall design of the algorithm is to handle boxes tile-by-tile:

    boxes = boxes.pad_to_multiple_of(tile_size)
    num_tiles = len(boxes) // tile_size
    output_boxes = []
    for i in range(num_tiles):
      box_tile = boxes[i*tile_size : (i+1)*tile_size]
      for j in range(i - 1):
        suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
        iou = _bbox_overlap(box_tile, suppressing_tile)
        # if the box is suppressed in iou, clear it to a dot
        box_tile *= _update_boxes(iou)
      # Iteratively handle the diagonal tile.
      iou = _bbox_overlap(box_tile, box_tile)
      iou_changed = True
      while iou_changed:
        # boxes that are not suppressed by anything else
        suppressing_boxes = _get_suppressing_boxes(iou)
        # boxes that are suppressed by suppressing_boxes
        suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
        # clear iou to 0 for boxes that are suppressed, as they cannot be used
        # to suppress other boxes any more
        new_iou = _clear_iou(iou, suppressed_boxes)
        iou_changed = (new_iou != iou)
        iou = new_iou
      # remaining boxes that can still suppress others, are selected boxes.
      output_boxes.append(_get_suppressing_boxes(iou))
      if len(output_boxes) >= max_output_size:
        break

    Args:
      scores: a tensor with a shape of [batch_size, anchors]. It is not used
        by the computation; the boxes are assumed to be pre-sorted by score.
      boxes: a tensor with a shape of [batch_size, anchors, 4].
      max_output_size: the maximum number of boxes to be selected per batch
        element by non max suppression.
      iou_threshold: a float representing the threshold for deciding whether boxes
        overlap too much with respect to IOU.
    Returns:
      idx: a flat int32 array with batch_size * max_output_size entries. Each
        entry indexes a selected box in the padded boxes reshaped to [-1, 4]
        (batch element ``b`` is offset by ``b * padded_anchor_count``).
        Output slots with no selected box hold the index of the last padded
        box of their batch element.
    """
    batch_size = boxes.shape[0]
    num_boxes = boxes.shape[1]
    # Pad the box count up to a multiple of the tile size (pure integer
    # arithmetic; no float or device round trip needed).
    pad = -num_boxes % _NMS_TILE_SIZE
    boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
    num_boxes += pad

    def _loop_cond(in_args):
        # Stop once every batch element has enough boxes or tiles run out.
        unused_boxes, unused_threshold, output_size, idx = in_args
        return jnp.logical_and(
            jnp.min(output_size) < max_output_size,
            idx < num_boxes // _NMS_TILE_SIZE)

    selected_boxes, _, _, _ = lax.while_loop(
        _loop_cond, _suppression_loop_body, (
            boxes, iou_threshold,
            jnp.zeros([batch_size], jnp.int32),
            0
        ))
    # Surviving boxes get weight num_boxes - position, so top_k returns them
    # in ascending position order; suppressed boxes get weight 0.
    idx = num_boxes - lax.top_k(
        jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
        jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
        max_output_size)[0].astype(jnp.int32)
    # Weight 0 maps to num_boxes, which is out of range: clamp it.
    idx = jnp.minimum(idx, num_boxes - 1)
    idx = jnp.reshape(
        idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
    return idx
228
+
229
+
230
+ # registry:
231
+
232
def nms(boxes, scores, iou_threshold):
    """Entry point registered for ``torch.ops.torchvision.nms``.

    Adds a leading batch axis of size 1 to ``boxes`` and ``scores`` and runs
    ``non_max_suppression_padded`` with ``max_output_size`` equal to the
    number of input boxes.

    NOTE(review): the padded routine documents an assumption that boxes are
    already sorted by score - confirm callers guarantee this.
    """
    box_count = boxes.shape[0]
    batched_boxes = boxes.reshape((1, *boxes.shape))
    batched_scores = scores.reshape((1, *scores.shape))
    return non_max_suppression_padded(
        batched_scores, batched_boxes, box_count, iou_threshold)
238
+
239
+
240
# Best-effort registration: torchvision is optional, so any failure while
# importing it or looking up its nms op simply leaves the op unregistered.
try:
    import torch
    import torchvision
    ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
except Exception:
    pass