torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/ops/jtorch.py
CHANGED
@@ -1,8 +1,9 @@
 """Tensor constructor overrides"""
+
 import math
 import collections.abc
 import functools
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 import numpy as np
 
 import jax
@@ -12,8 +13,10 @@ from jax.experimental.shard_map import shard_map
 
 import torch
 from torchax.ops.ops_registry import register_torch_function_op
-from torchax.ops import op_base, mappings, jaten
+from torchax.ops import op_base, mappings, jaten, jimage
 import torchax.tensor
+from torchax.view import View, NarrowInfo
+import torch.utils._pytree as pytree
 
 
 def register_function(torch_func, **kwargs):
@@ -21,7 +24,8 @@ def register_function(torch_func, **kwargs):
 
 
 @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
-@op_base.convert_dtype(
+@op_base.convert_dtype(
+    use_default_dtype=False)  # Attempt to infer type from elements
 def _as_tensor(data, dtype=None, device=None, env=None):
   if isinstance(data, torch.Tensor):
     return env._to_copy(data, dtype, device)
@@ -33,7 +37,8 @@ def _as_tensor(data, dtype=None, device=None, env=None):
 
 
 @register_function(torch.tensor)
-@op_base.convert_dtype(
+@op_base.convert_dtype(
+    use_default_dtype=False)  # Attempt to infer type from elements
 def _tensor(data, *, dtype=None, **kwargs):
   python_types_to_torch_types = {
       bool: jnp.bool,
|
|
|
57
62
|
|
|
58
63
|
@register_function(torch.angle)
|
|
59
64
|
def _torch_angle(input):
|
|
60
|
-
if input.dtype.name ==
|
|
61
|
-
input = input.astype(jnp.dtype(
|
|
65
|
+
if input.dtype.name == "int64":
|
|
66
|
+
input = input.astype(jnp.dtype("float32"))
|
|
62
67
|
return jnp.angle(input)
|
|
63
68
|
|
|
64
69
|
|
|
@@ -72,19 +77,21 @@ def _torch_argsort(input, dim=-1, descending=False, stable=False):
     # behavior is the same as a jnp array of rank 1
     expanded = True
     input = jnp.expand_dims(input, 0)
-  res = jnp.argsort(input, axis=dim, descending=descending,
-                    stable=stable)
+  res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
   if expanded:
     res = res.squeeze()
   return res
 
+
 @register_function(torch.diag)
 def _diag(input, diagonal=0):
   return jnp.diag(input, k=diagonal)
 
+
 @register_function(torch.einsum)
 @register_function(torch.ops.aten.einsum)
 def _einsum(equation, *operands):
+
   def get_params(*a):
     inner_list = a[0]
     if not isinstance(inner_list, jax.Array):
@@ -95,71 +102,90 @@ def _einsum(equation, *operands):
       A, B = inner_list
       return A, B
     return operands
+
+  assert isinstance(equation, str), "Only accept str equation"
   filtered_operands = get_params(*operands)
   return jnp.einsum(equation, *filtered_operands)
 
 
-def _sdpa_reference(
+def _sdpa_reference(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+) -> torch.Tensor:
+  L, S = query.size(-2), key.size(-2)
+  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+  attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+  if is_causal:
+    assert attn_mask is None
+    temp_mask = torch.ones(
+        L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+  if attn_mask is not None:
+    if attn_mask.dtype == torch.bool:
+      attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+    else:
+      attn_bias += attn_mask
+  if enable_gqa:
+    key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+    value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+  attn_weight = query @ key.transpose(-2, -1) * scale_factor
+  attn_weight += attn_bias
+  attn_weight = torch.softmax(attn_weight, dim=-1)
+  if dropout_p > 0:
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+  return attn_weight @ value
 
 
 from jax.sharding import PartitionSpec
 
+
 def _tpu_flash_attention(query, key, value, env):
-  fsdp_partition = PartitionSpec(
+  fsdp_partition = PartitionSpec("fsdp")
+
   def wrap_flash_attention(query, key, value):
     block_sizes = flash_attention.BlockSizes(
+        block_b=min(2, query.shape[0]),
+        block_q=min(512, query.shape[2]),
+        block_k_major=min(512, key.shape[2]),
+        block_k=min(512, key.shape[2]),
+        block_q_major_dkv=min(512, query.shape[2]),
+        block_k_major_dkv=min(512, key.shape[2]),
+        block_k_dkv=min(512, key.shape[2]),
+        block_q_dkv=min(512, query.shape[2]),
+        block_k_major_dq=min(512, key.shape[2]),
+        block_k_dq=min(256, key.shape[2]),
+        block_q_dq=min(1024, query.shape[2]),
     )
     return flash_attention.flash_attention(
         query, key, value, causal=True, block_sizes=block_sizes)
 
   if env.config.shmap_flash_attention:
     wrap_flash_attention = shard_map(
+        wrap_flash_attention,
+        mesh=env._mesh,
+        in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+        out_specs=fsdp_partition,
+        check_rep=False,
     )
-  #return flash_attn_mapped(query, key, value)
+  # return flash_attn_mapped(query, key, value)
   return wrap_flash_attention(query, key, value)
 
 
+@register_function(torch.nn.functional.one_hot)
+def one_hot(tensor, num_classes=-1):
+  if num_classes == -1:
+    num_classes = jnp.max(tensor) + 1
+  return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
+
+
 @register_function(torch.nn.functional.pad)
 def pad(tensor, pad, mode="constant", value=None):
   # For padding modes that have different names between Torch and NumPy, this
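The one_hot override added above lowers torch.nn.functional.one_hot to jax.nn.one_hot. A minimal standalone sketch (not part of the package, illustrative inputs only) of the equivalence it relies on:

import numpy as np
import torch
import jax
import jax.numpy as jnp

labels = np.array([0, 2, 1], dtype=np.int64)
# torch's one_hot returns an int64 tensor; jax.nn.one_hot returns floats,
# hence the .astype(jnp.int64) in the override above.
expected = torch.nn.functional.one_hot(torch.from_numpy(labels), num_classes=3).numpy()
got = np.asarray(jax.nn.one_hot(jnp.asarray(labels), 3).astype(jnp.int64))
assert (expected == got).all()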
@@ -210,27 +236,60 @@ def pad(tensor, pad, mode="constant", value=None):
   return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)
 
 
-@register_function(
+@register_function(
+    torch.nn.functional.scaled_dot_product_attention,
+    is_jax_function=False,
+    needs_env=True,
+)
+@register_function(
+    torch.ops.aten.scaled_dot_product_attention,
+    is_jax_function=False,
+    needs_env=True)
 def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+    env=None,
+) -> torch.Tensor:
+
+  if env.config.use_tpu_flash_attention:
     jquery, jkey, jvalue = env.t2j_iso((query, key, value))
     res = _tpu_flash_attention(jquery, jkey, jvalue, env)
     return env.j2t_iso(res)
 
+  return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal,
+                         scale, enable_gqa)
+
 
-@register_function(
+@register_function(
+    torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
 def getitem(self, indexes):
+
   if isinstance(indexes, list) and isinstance(indexes[0], int):
     # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
-    indexes = (indexes,
+    indexes = (indexes,)
   elif isinstance(indexes, list):
     indexes = tuple(indexes)
+
+  def is_narrow_slicing():
+    tensor_free = not pytree.tree_any(
+        lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+        indexes)
+    list_free = not isinstance(indexes, tuple) or all(
+        [False if isinstance(x, list) else True for x in indexes])
+    return tensor_free and list_free
+
+  if is_narrow_slicing():
+    return View(self, view_info=NarrowInfo(indexes), env=self._env)
+
+  indexes = self._env.t2j_iso(indexes)
+  return torchax.tensor.Tensor(self._elem[indexes], self._env)
+
 
 @register_function(torch.corrcoef)
 def _corrcoef(x):
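The new __getitem__ override only returns a lazy View for what it calls narrow slicing. A standalone sketch (assumed semantics, not part of the package) of the same predicate, showing which index expressions stay lazy and which materialize a new tensor:

import torch
import jax
import torch.utils._pytree as pytree

def is_narrow_slicing(indexes):
  # "Narrow" means: no torch.Tensor / jax.Array anywhere in the index,
  # and no Python list inside a tuple index.
  tensor_free = not pytree.tree_any(
      lambda x: isinstance(x, (torch.Tensor, jax.Array)), indexes)
  list_free = not isinstance(indexes, tuple) or all(
      not isinstance(x, list) for x in indexes)
  return tensor_free and list_free

print(is_narrow_slicing((slice(0, 2), 3)))       # True  -> returned as a View
print(is_narrow_slicing(([0, 2], slice(None))))  # False -> materialized Tensor
print(is_narrow_slicing(torch.tensor([0, 1])))   # False -> materialized Tensor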
@@ -238,15 +297,22 @@ def _corrcoef(x):
     return jnp.corrcoef(x).astype(jnp.float32)
   return jnp.corrcoef(x)
 
+
 @register_function(torch.sparse.mm, is_jax_function=False)
-def _sparse_mm(mat1, mat2, reduce=
+def _sparse_mm(mat1, mat2, reduce="sum"):
   return torch.mm(mat1, mat2)
 
+
 @register_function(torch.isclose)
 def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
   return jnp.isclose(input, other, rtol, atol, equal_nan)
 
 
+@register_function(torch.linalg.det)
+def linalg_det(input):
+  return jnp.linalg.det(input)
+
+
 @register_function(torch.ones)
 def _ones(*size: int, dtype=None, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
@@ -281,23 +347,39 @@ def empty(*size: Sequence[int], dtype=None, **kwargs):
     size = size[0]
   return jnp.empty(size, dtype=dtype)
 
+
+@register_function(torch.arange, is_jax_function=True)
 def arange(
+    start,
+    end=None,
+    step=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=None,
 ):
   if end is None:
     end = start
     start = 0
   if step is None:
     step = 1
-  return
+  return jaten._aten_arange(start, end, step, dtype=dtype)
+
 
-@register_function(torch.empty_strided, is_jax_function=
+@register_function(torch.empty_strided, is_jax_function=True)
 def empty_strided(
+    size,
+    stride,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+):
+  return empty(size, dtype=dtype, requires_grad=requires_grad)
 
 
 @register_function(torch.unravel_index)
@@ -305,27 +387,33 @@ def unravel_index(indices, shape):
   return jnp.unravel_index(indices, shape)
 
 
-@register_function(torch.rand, is_jax_function=
-def rand(
-    *size, **kwargs
-):
+@register_function(torch.rand, is_jax_function=True, needs_env=True)
+def rand(*size, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return
+  return jaten._rand(size, **kwargs)
+
 
-@register_function(torch.randn, is_jax_function=
+@register_function(torch.randn, is_jax_function=True, needs_env=True)
 def randn(
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+    env=None,
 ):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return
+  return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)
 
-)
-  return torch.ops.aten.randint(*args, **kwargs)
+
+@register_function(torch.randint, is_jax_function=False, needs_env=True)
+def randint(*args, **kwargs):
+  return jaten._aten_randint(*args, **kwargs)
 
 
 @register_function(torch.logdet)
@@ -356,14 +444,17 @@ def linalg_solve_ex(a, b):
   res, info = jaten._aten__linalg_solve_ex(a, b)
   return res, info
 
+
 @register_function(torch.linalg.svd)
 def linalg_svd(a, full_matrices=True):
   return jaten._aten__linalg_svd(a, full_matrices=full_matrices)
 
+
 @register_function(torch.linalg.matrix_power)
 def matrix_power(A, n, *, out=None):
   return jnp.linalg.matrix_power(A, n)
 
+
 @register_function(torch.svd)
 def svd(a, some=True, compute_uv=True):
   if not compute_uv:
@@ -374,21 +465,24 @@ def svd(a, some=True, compute_uv=True):
   U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
   return U, S, jnp.matrix_transpose(V)
 
+
 @register_function(torch.cdist)
-def _cdist(x1, x2, p=2.0, compute_mode=
+def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
+  return jaten._aten_cdist(x1, x2, p, compute_mode)
+
 
 @register_function(torch.lu)
 def lu(A, **kwargs):
-  lu,pivots,_ = jax.lax.linalg.lu(A)
+  lu, pivots, _ = jax.lax.linalg.lu(A)
   # JAX pivots are offset by 1 compared to torch
   _pivots = pivots + 1
   info_shape = pivots.shape[:-1]
   info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
-  if kwargs[
+  if kwargs["get_infos"] == True:
     return lu, _pivots, info
   return lu, _pivots
 
+
 @register_function(torch.lu_solve)
 def lu_solve(b, LU_data, LU_pivots, **kwargs):
   # JAX pivots are offset by 1 compared to torch
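The `pivots + 1` adjustment in `lu` above compensates for an indexing-convention difference: jax.lax.linalg.lu returns 0-based pivot indices, while torch.lu and torch.lu_solve expect 1-based (LAPACK-style) pivots. A small illustrative sketch (not part of the package):

import jax
import jax.numpy as jnp

A = jnp.array([[0., 1.],
               [1., 0.]])
lu_mat, pivots, _ = jax.lax.linalg.lu(A)
print(pivots)      # 0-based indices as returned by JAX, e.g. [1 1]
print(pivots + 1)  # 1-based indices in the convention torch expects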
@@ -396,6 +490,7 @@ def lu_solve(b, LU_data, LU_pivots, **kwargs):
   x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
   return x
 
+
 @register_function(torch.linalg.tensorsolve)
 def linalg_tensorsolve(A, b, dims=None):
   # examples:
@@ -425,3 +520,57 @@ def functional_linear(self, weights, bias=None):
   if bias is not None:
     res += bias
   return res
+
+
+@register_function(torch.nn.functional.interpolate)
+def functional_interpolate(
+    input,
+    size: Tuple[int, int],
+    scale_factor: Optional[float],
+    mode: str,
+    align_corners: bool,
+    recompute_scale_factor: bool,
+    antialias: bool,
+):
+  supported_methods = (
+      "nearest",
+      "linear",
+      "bilinear",
+      "trilinear",
+      "cubic",
+      "bicubic",
+      "tricubic",
+      "lanczos3",
+      "lanczos5",
+  )
+  is_jax_supported = mode in supported_methods
+  if not is_jax_supported:
+    raise torchax.tensor.OperatorNotFound(
+        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+    )
+  # None check
+  antialias = antialias or False
+  align_corners = align_corners or False
+
+  if mode in ('cubic', 'bicubic',
+              'tricubic') and not antialias and size is not None:
+    return jimage.interpolate_bicubic_no_aa(
+        input,
+        size[0],
+        size[1],
+        align_corners,
+    )
+  else:
+    # fallback
+    raise torchax.tensor.OperatorNotFound(
+        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+    )
+
+
+@register_function(torch.Tensor.repeat_interleave)
+def torch_Tensor_repeat_interleave(self,
+                                   repeats,
+                                   dim=None,
+                                   *,
+                                   output_size=None):
+  return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
torchax/ops/jtorchvision_nms.py
CHANGED
@@ -53,8 +53,8 @@ def _self_suppression(in_args):
   can_suppress_others = jnp.reshape(
       jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
   iou_suppressed = jnp.reshape(
-      (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
+      (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
+          iou.dtype), [batch_size, -1, 1]) * iou
   iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
   return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new
 
@@ -65,9 +65,8 @@ def _cross_suppression(in_args):
   new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
                                 [batch_size, _NMS_TILE_SIZE, 4])
   iou = _bbox_overlap(new_slice, box_slice)
-  ret_slice = jnp.expand_dims(
-      2) * box_slice
+  ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(
+      box_slice.dtype), 2) * box_slice
   return boxes, ret_slice, iou_threshold, inner_idx + 1
 
 
@@ -90,45 +89,40 @@ def _suppression_loop_body(in_args):
   # Iterates over tiles that can possibly suppress the current tile.
   box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                                 [batch_size, _NMS_TILE_SIZE, 4])
+
   def _loop_cond(in_args):
     _, _, _, inner_idx = in_args
     return inner_idx < idx
 
-  _, box_slice, _, _ = lax.while_loop(
-      _cross_suppression, (boxes, box_slice, iou_threshold,
-      0))
+  _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression,
+                                      (boxes, box_slice, iou_threshold, 0))
 
   # Iterates over the current tile to compute self-suppression.
   iou = _bbox_overlap(box_slice, box_slice)
   mask = jnp.expand_dims(
-      jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+      jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+      > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
   iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)
 
   def _loop_cond2(in_args):
     _, loop_condition, _ = in_args
     return loop_condition
 
-  suppressed_iou, _, _ = lax.while_loop(
-      (iou, True,
-      jnp.sum(iou, [1, 2])))
+  suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression,
+                                        (iou, True, jnp.sum(iou, [1, 2])))
   suppressed_box = jnp.sum(suppressed_iou, 1) > 0
   box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)
 
   # Uses box_slice to update the input boxes.
-  mask = jnp.reshape(
-      [1, -1, 1, 1])
+  mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles),
+                                idx)).astype(boxes.dtype), [1, -1, 1, 1])
   boxes = jnp.tile(jnp.expand_dims(
       box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
           boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
   boxes = jnp.reshape(boxes, [batch_size, -1, 4])
 
   # Updates output_size.
-  output_size += jnp.sum(
-      jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
+  output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
   return boxes, iou_threshold, output_size, idx + 1
 
 
@@ -185,8 +179,8 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
   """
   batch_size = boxes.shape[0]
   num_boxes = boxes.shape[1]
-  pad = int(jnp.ceil(
+  pad = int(jnp.ceil(
+      float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
   boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
   scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
   num_boxes += pad
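To make the padding arithmetic above concrete: with, say, 1000 input boxes and a tile size of 512 (illustrative numbers only; _NMS_TILE_SIZE is defined elsewhere in this module), pad = ceil(1000 / 512) * 512 - 1000 = 24, so the padded box count of 1024 is an exact multiple of the tile size.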
@@ -194,15 +188,12 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
   def _loop_cond(in_args):
     unused_boxes, unused_threshold, output_size, idx = in_args
     return jnp.logical_and(
-        jnp.min(output_size) < max_output_size,
+        jnp.min(output_size) < max_output_size, idx
+        < num_boxes // _NMS_TILE_SIZE)
 
   selected_boxes, _, output_size, _ = lax.while_loop(
-      _loop_cond, _suppression_loop_body,
-      jnp.zeros([batch_size], jnp.int32),
-      0
-      ))
+      _loop_cond, _suppression_loop_body,
+      (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
   idx = num_boxes - lax.top_k(
       jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
       jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
@@ -210,30 +201,28 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
     idx = jnp.minimum(idx, num_boxes - 1)
     idx = jnp.reshape(
         idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
+
     return idx
-  boxes = jnp.reshape(
-      jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) < jnp.reshape(
-          output_size, [-1, 1, 1])).astype(boxes.dtype)
+  boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
+                      [batch_size, max_output_size, 4])
+  boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
+                   < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
   scores = jnp.reshape(
-      jnp.reshape(scores, [-1, 1])[idx],
-      jnp.reshape(jnp.arange(max_output_size), [1, -1]) < jnp.reshape(
-          output_size, [-1, 1])).astype(scores.dtype)
+      jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
+  scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1])
+                     < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
   return scores, boxes
 
 
 # registry:
 
+
 def nms(boxes, scores, iou_threshold):
   max_output_size = boxes.shape[0]
   boxes = boxes.reshape((1, *boxes.shape))
   scores = scores.reshape((1, *scores.shape))
-  res = non_max_suppression_padded(scores, boxes, max_output_size,
+  res = non_max_suppression_padded(scores, boxes, max_output_size,
+                                   iou_threshold)
   return res
 
 
@@ -242,4 +231,4 @@ try:
   import torchvision
   ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
 except Exception:
-  pass
+  pass
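For context on the suppression criterion used throughout this module: a pair of boxes is suppressed when their IoU reaches iou_threshold. A small self-contained sketch (the iou helper below is illustrative, not the module's _bbox_overlap):

def iou(a, b):
  # a, b: [x1, y1, x2, y2]
  ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
  ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
  inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
  area_a = (a[2] - a[0]) * (a[3] - a[1])
  area_b = (b[2] - b[0]) * (b[3] - b[1])
  return inter / (area_a + area_b - inter)

print(iou([0, 0, 10, 10], [1, 1, 11, 11]))    # ~0.68: suppressed at threshold 0.5
print(iou([0, 0, 10, 10], [50, 50, 60, 60]))  # 0.0: kept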