torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +57 -19
- torchax/amp.py +333 -0
- torchax/config.py +19 -12
- torchax/decompositions.py +663 -195
- torchax/device_module.py +7 -1
- torchax/distributed.py +55 -60
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +275 -141
- torchax/mesh_util.py +211 -0
- torchax/ops/jaten.py +1718 -1294
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +219 -78
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +417 -275
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/METADATA +111 -145
- torchax-0.0.5.dist-info/RECORD +32 -0
- torchax/environment.py +0 -2
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/licenses/LICENSE +0 -0
torchax/ops/jtorch.py
CHANGED
@@ -1,8 +1,9 @@
 """Tensor constructor overrides"""
+
 import math
 import collections.abc
 import functools
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 import numpy as np

 import jax
@@ -12,8 +13,10 @@ from jax.experimental.shard_map import shard_map

 import torch
 from torchax.ops.ops_registry import register_torch_function_op
-from torchax.ops import op_base, mappings, jaten
+from torchax.ops import op_base, mappings, jaten, jimage
 import torchax.tensor
+from torchax.view import View, NarrowInfo
+import torch.utils._pytree as pytree


 def register_function(torch_func, **kwargs):
@@ -21,7 +24,8 @@ def register_function(torch_func, **kwargs):


 @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
-@op_base.convert_dtype(
+@op_base.convert_dtype(
+    use_default_dtype=False)  # Attempt to infer type from elements
 def _as_tensor(data, dtype=None, device=None, env=None):
   if isinstance(data, torch.Tensor):
     return env._to_copy(data, dtype, device)
@@ -33,7 +37,8 @@ def _as_tensor(data, dtype=None, device=None, env=None):


 @register_function(torch.tensor)
-@op_base.convert_dtype(
+@op_base.convert_dtype(
+    use_default_dtype=False)  # Attempt to infer type from elements
 def _tensor(data, *, dtype=None, **kwargs):
   python_types_to_torch_types = {
       bool: jnp.bool,
@@ -57,8 +62,8 @@ def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):

 @register_function(torch.angle)
 def _torch_angle(input):
-  if input.dtype.name ==
-    input = input.astype(jnp.dtype(
+  if input.dtype.name == "int64":
+    input = input.astype(jnp.dtype("float32"))
   return jnp.angle(input)


@@ -72,19 +77,21 @@ def _torch_argsort(input, dim=-1, descending=False, stable=False):
     # behavior is the same as a jnp array of rank 1
     expanded = True
     input = jnp.expand_dims(input, 0)
-  res = jnp.argsort(input, axis=dim, descending=descending,
-                    stable=stable)
+  res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
   if expanded:
     res = res.squeeze()
   return res

+
 @register_function(torch.diag)
 def _diag(input, diagonal=0):
   return jnp.diag(input, k=diagonal)

+
 @register_function(torch.einsum)
 @register_function(torch.ops.aten.einsum)
 def _einsum(equation, *operands):
+
   def get_params(*a):
     inner_list = a[0]
     if not isinstance(inner_list, jax.Array):
@@ -95,68 +102,80 @@ def _einsum(equation, *operands):
         A, B = inner_list
         return A, B
     return operands
+
+  assert isinstance(equation, str), "Only accept str equation"
   filtered_operands = get_params(*operands)
   return jnp.einsum(equation, *filtered_operands)


-def _sdpa_reference(
+def _sdpa_reference(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+) -> torch.Tensor:
+  L, S = query.size(-2), key.size(-2)
+  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+  attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+  if is_causal:
+    assert attn_mask is None
+    temp_mask = torch.ones(
+        L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+  if attn_mask is not None:
+    if attn_mask.dtype == torch.bool:
+      attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+    else:
+      attn_bias += attn_mask
+  if enable_gqa:
+    key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+    value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+  attn_weight = query @ key.transpose(-2, -1) * scale_factor
+  attn_weight += attn_bias
+  attn_weight = torch.softmax(attn_weight, dim=-1)
+  if dropout_p > 0:
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+  return attn_weight @ value


 from jax.sharding import PartitionSpec

+
 def _tpu_flash_attention(query, key, value, env):
-  fsdp_partition = PartitionSpec(
+  fsdp_partition = PartitionSpec("fsdp")
+
   def wrap_flash_attention(query, key, value):
     block_sizes = flash_attention.BlockSizes(
+        block_b=min(2, query.shape[0]),
+        block_q=min(512, query.shape[2]),
+        block_k_major=min(512, key.shape[2]),
+        block_k=min(512, key.shape[2]),
+        block_q_major_dkv=min(512, query.shape[2]),
+        block_k_major_dkv=min(512, key.shape[2]),
+        block_k_dkv=min(512, key.shape[2]),
+        block_q_dkv=min(512, query.shape[2]),
+        block_k_major_dq=min(512, key.shape[2]),
+        block_k_dq=min(256, key.shape[2]),
+        block_q_dq=min(1024, query.shape[2]),
     )
     return flash_attention.flash_attention(
         query, key, value, causal=True, block_sizes=block_sizes)

   if env.config.shmap_flash_attention:
     wrap_flash_attention = shard_map(
+        wrap_flash_attention,
+        mesh=env._mesh,
+        in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+        out_specs=fsdp_partition,
+        check_rep=False,
     )
-  #return flash_attn_mapped(query, key, value)
+  # return flash_attn_mapped(query, key, value)
   return wrap_flash_attention(query, key, value)


@@ -210,27 +229,60 @@ def pad(tensor, pad, mode="constant", value=None):
   return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)


-@register_function(
+@register_function(
+    torch.nn.functional.scaled_dot_product_attention,
+    is_jax_function=False,
+    needs_env=True,
+)
+@register_function(
+    torch.ops.aten.scaled_dot_product_attention,
+    is_jax_function=False,
+    needs_env=True)
 def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+    env=None,
+) -> torch.Tensor:
+
+  if env.config.use_tpu_flash_attention:
     jquery, jkey, jvalue = env.t2j_iso((query, key, value))
     res = _tpu_flash_attention(jquery, jkey, jvalue, env)
     return env.j2t_iso(res)

+  return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal,
+                         scale, enable_gqa)
+

-@register_function(
+@register_function(
+    torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
 def getitem(self, indexes):
+
   if isinstance(indexes, list) and isinstance(indexes[0], int):
     # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
-    indexes = (indexes,
+    indexes = (indexes,)
   elif isinstance(indexes, list):
     indexes = tuple(indexes)
+
+  def is_narrow_slicing():
+    tensor_free = not pytree.tree_any(
+        lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+        indexes)
+    list_free = not isinstance(indexes, tuple) or all(
+        [False if isinstance(x, list) else True for x in indexes])
+    return tensor_free and list_free
+
+  if is_narrow_slicing():
+    return View(self, view_info=NarrowInfo(indexes), env=self._env)
+
+  indexes = self._env.t2j_iso(indexes)
+  return torchax.tensor.Tensor(self._elem[indexes], self._env)
+

 @register_function(torch.corrcoef)
 def _corrcoef(x):
@@ -238,15 +290,22 @@ def _corrcoef(x):
     return jnp.corrcoef(x).astype(jnp.float32)
   return jnp.corrcoef(x)

+
 @register_function(torch.sparse.mm, is_jax_function=False)
-def _sparse_mm(mat1, mat2, reduce=
+def _sparse_mm(mat1, mat2, reduce="sum"):
   return torch.mm(mat1, mat2)

+
 @register_function(torch.isclose)
 def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
   return jnp.isclose(input, other, rtol, atol, equal_nan)


+@register_function(torch.linalg.det)
+def linalg_det(input):
+  return jnp.linalg.det(input)
+
+
 @register_function(torch.ones)
 def _ones(*size: int, dtype=None, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
@@ -281,11 +340,18 @@ def empty(*size: Sequence[int], dtype=None, **kwargs):
     size = size[0]
   return jnp.empty(size, dtype=dtype)

+
 @register_function(torch.arange, is_jax_function=False)
 def arange(
+    start,
+    end=None,
+    step=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=None,
 ):
   if end is None:
     end = start
@@ -294,9 +360,18 @@ def arange(
     step = 1
   return torch.ops.aten.arange(start, end, step, dtype=dtype)

+
 @register_function(torch.empty_strided, is_jax_function=False)
 def empty_strided(
+    size,
+    stride,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+):
   return empty(size, dtype=dtype)


@@ -306,25 +381,30 @@ def unravel_index(indices, shape):


 @register_function(torch.rand, is_jax_function=False)
-def rand(
-    *size, **kwargs
-):
+def rand(*size, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
   return torch.ops.aten.rand(size, **kwargs)

+
 @register_function(torch.randn, is_jax_function=False)
 def randn(
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
 ):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
   return torch.ops.aten.randn(size, generator=generator, dtype=dtype)

+
 @register_function(torch.randint, is_jax_function=False)
-def randint(
-    *args, **kwargs
-):
+def randint(*args, **kwargs):
   return torch.ops.aten.randint(*args, **kwargs)


@@ -356,14 +436,17 @@ def linalg_solve_ex(a, b):
   res, info = jaten._aten__linalg_solve_ex(a, b)
   return res, info

+
 @register_function(torch.linalg.svd)
 def linalg_svd(a, full_matrices=True):
   return jaten._aten__linalg_svd(a, full_matrices=full_matrices)

+
 @register_function(torch.linalg.matrix_power)
 def matrix_power(A, n, *, out=None):
   return jnp.linalg.matrix_power(A, n)

+
 @register_function(torch.svd)
 def svd(a, some=True, compute_uv=True):
   if not compute_uv:
@@ -374,21 +457,24 @@ def svd(a, some=True, compute_uv=True):
   U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
   return U, S, jnp.matrix_transpose(V)

+
 @register_function(torch.cdist)
-def _cdist(x1, x2, p=2.0, compute_mode=
+def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
+  return jaten._aten_cdist(x1, x2, p, compute_mode)
+

 @register_function(torch.lu)
 def lu(A, **kwargs):
-  lu,pivots,_ = jax.lax.linalg.lu(A)
+  lu, pivots, _ = jax.lax.linalg.lu(A)
   # JAX pivots are offset by 1 compared to torch
   _pivots = pivots + 1
   info_shape = pivots.shape[:-1]
   info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
-  if kwargs[
+  if kwargs["get_infos"] == True:
     return lu, _pivots, info
   return lu, _pivots

+
 @register_function(torch.lu_solve)
 def lu_solve(b, LU_data, LU_pivots, **kwargs):
   # JAX pivots are offset by 1 compared to torch
@@ -396,6 +482,7 @@ def lu_solve(b, LU_data, LU_pivots, **kwargs):
   x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
   return x

+
 @register_function(torch.linalg.tensorsolve)
 def linalg_tensorsolve(A, b, dims=None):
   # examples:
@@ -425,3 +512,57 @@ def functional_linear(self, weights, bias=None):
   if bias is not None:
     res += bias
   return res
+
+
+@register_function(torch.nn.functional.interpolate)
+def functional_interpolate(
+    input,
+    size: Tuple[int, int],
+    scale_factor: Optional[float],
+    mode: str,
+    align_corners: bool,
+    recompute_scale_factor: bool,
+    antialias: bool,
+):
+  supported_methods = (
+      "nearest",
+      "linear",
+      "bilinear",
+      "trilinear",
+      "cubic",
+      "bicubic",
+      "tricubic",
+      "lanczos3",
+      "lanczos5",
+  )
+  is_jax_supported = mode in supported_methods
+  if not is_jax_supported:
+    raise torchax.tensor.OperatorNotFound(
+        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+    )
+  # None check
+  antialias = antialias or False
+  align_corners = align_corners or False
+
+  if mode in ('cubic', 'bicubic',
+              'tricubic') and not antialias and size is not None:
+    return jimage.interpolate_bicubic_no_aa(
+        input,
+        size[0],
+        size[1],
+        align_corners,
+    )
+  else:
+    # fallback
+    raise torchax.tensor.OperatorNotFound(
+        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+    )
+
+
+@register_function(torch.Tensor.repeat_interleave)
+def torch_Tensor_repeat_interleave(self,
+                                   repeats,
+                                   dim=None,
+                                   *,
+                                   output_size=None):
+  return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
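Most of the hunks above map a torch callable one-to-one onto its jax.numpy counterpart, e.g. the new torch.linalg.det registration onto jnp.linalg.det, and torch.Tensor.repeat_interleave onto jnp.repeat, whose total_repeat_length keyword receives the output_size hint. A minimal sketch of those two equivalences, using plain torch and jax.numpy rather than torchax itself:

import torch
import jax.numpy as jnp

x = torch.tensor([[1, 2], [3, 4]])
print(torch.repeat_interleave(x, 2, dim=1))    # [[1, 1, 2, 2], [3, 3, 4, 4]]
# total_repeat_length plays the role of the output_size hint forwarded above.
print(jnp.repeat(jnp.asarray(x.numpy()), 2, axis=1, total_repeat_length=4))

a = torch.tensor([[2.0, 0.0], [0.0, 3.0]])
print(torch.linalg.det(a))                     # 6.0
print(jnp.linalg.det(jnp.asarray(a.numpy())))  # 6.0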
torchax/ops/jtorchvision_nms.py
CHANGED
@@ -53,8 +53,8 @@ def _self_suppression(in_args):
   can_suppress_others = jnp.reshape(
       jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
   iou_suppressed = jnp.reshape(
-      (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
+      (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
+          iou.dtype), [batch_size, -1, 1]) * iou
   iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
   return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new

@@ -65,9 +65,8 @@ def _cross_suppression(in_args):
   new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
                                 [batch_size, _NMS_TILE_SIZE, 4])
   iou = _bbox_overlap(new_slice, box_slice)
-  ret_slice = jnp.expand_dims(
-      2) * box_slice
+  ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(
+      box_slice.dtype), 2) * box_slice
   return boxes, ret_slice, iou_threshold, inner_idx + 1


@@ -90,45 +89,40 @@ def _suppression_loop_body(in_args):
   # Iterates over tiles that can possibly suppress the current tile.
   box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                                 [batch_size, _NMS_TILE_SIZE, 4])
+
   def _loop_cond(in_args):
     _, _, _, inner_idx = in_args
     return inner_idx < idx

-  _, box_slice, _, _ = lax.while_loop(
-      _cross_suppression, (boxes, box_slice, iou_threshold,
-      0))
+  _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression,
+                                      (boxes, box_slice, iou_threshold, 0))

   # Iterates over the current tile to compute self-suppression.
   iou = _bbox_overlap(box_slice, box_slice)
   mask = jnp.expand_dims(
-      jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+      jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+      > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
   iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)

   def _loop_cond2(in_args):
     _, loop_condition, _ = in_args
     return loop_condition

-  suppressed_iou, _, _ = lax.while_loop(
-      (iou, True,
-      jnp.sum(iou, [1, 2])))
+  suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression,
+                                        (iou, True, jnp.sum(iou, [1, 2])))
   suppressed_box = jnp.sum(suppressed_iou, 1) > 0
   box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)

   # Uses box_slice to update the input boxes.
-  mask = jnp.reshape(
-      [1, -1, 1, 1])
+  mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles),
+                                idx)).astype(boxes.dtype), [1, -1, 1, 1])
   boxes = jnp.tile(jnp.expand_dims(
       box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
           boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
   boxes = jnp.reshape(boxes, [batch_size, -1, 4])

   # Updates output_size.
-  output_size += jnp.sum(
-      jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
+  output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
   return boxes, iou_threshold, output_size, idx + 1


@@ -185,8 +179,8 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
   """
   batch_size = boxes.shape[0]
   num_boxes = boxes.shape[1]
-  pad = int(jnp.ceil(
+  pad = int(jnp.ceil(
+      float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
   boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
   scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
   num_boxes += pad
@@ -194,15 +188,12 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
   def _loop_cond(in_args):
     unused_boxes, unused_threshold, output_size, idx = in_args
     return jnp.logical_and(
-        jnp.min(output_size) < max_output_size,
+        jnp.min(output_size) < max_output_size, idx
+        < num_boxes // _NMS_TILE_SIZE)

   selected_boxes, _, output_size, _ = lax.while_loop(
-      _loop_cond, _suppression_loop_body,
-      jnp.zeros([batch_size], jnp.int32),
-      0
-      ))
+      _loop_cond, _suppression_loop_body,
+      (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
   idx = num_boxes - lax.top_k(
       jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
       jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
@@ -210,30 +201,28 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
   idx = jnp.minimum(idx, num_boxes - 1)
   idx = jnp.reshape(
       idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
+
   return idx
-  boxes = jnp.reshape(
-      jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) < jnp.reshape(
-          output_size, [-1, 1, 1])).astype(boxes.dtype)
+  boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
+                      [batch_size, max_output_size, 4])
+  boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
+                   < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
   scores = jnp.reshape(
-      jnp.reshape(scores, [-1, 1])[idx],
-      jnp.reshape(jnp.arange(max_output_size), [1, -1]) < jnp.reshape(
-          output_size, [-1, 1])).astype(scores.dtype)
+      jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
+  scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1])
+                     < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
   return scores, boxes


 # registry:

+
 def nms(boxes, scores, iou_threshold):
   max_output_size = boxes.shape[0]
   boxes = boxes.reshape((1, *boxes.shape))
   scores = scores.reshape((1, *scores.shape))
-  res = non_max_suppression_padded(scores, boxes, max_output_size,
+  res = non_max_suppression_padded(scores, boxes, max_output_size,
+                                   iou_threshold)
   return res


@@ -242,4 +231,4 @@ try:
   import torchvision
   ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
 except Exception:
-  pass
+  pass
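The hunks above are essentially reflow/reformatting: the tiled, padding-based NMS algorithm and its lax.while_loop structure are unchanged. A minimal usage sketch of the module-level nms wrapper shown above (the one registered for torch.ops.torchvision.nms), assuming torchax 0.0.5 is installed and the module is imported directly; per non_max_suppression_padded, the result is padded scores and boxes with a leading batch dimension of 1:

import jax.numpy as jnp
from torchax.ops.jtorchvision_nms import nms

# Three boxes as (N, 4) corner coordinates and (N,) scores; the first two overlap heavily.
boxes = jnp.array([
    [0.0, 0.0, 10.0, 10.0],
    [1.0, 1.0, 11.0, 11.0],
    [50.0, 50.0, 60.0, 60.0],
])
scores = jnp.array([0.9, 0.8, 0.7])

kept_scores, kept_boxes = nms(boxes, scores, 0.5)
print(kept_scores.shape, kept_boxes.shape)  # (1, 3) and (1, 3, 4); suppressed slots come back zeroed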