torchax 0.0.10.dev20251118__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic. Click here for more details.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +149 -0
- torchax/amp.py +218 -0
- torchax/checkpoint.py +85 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +796 -0
- torchax/device_module.py +47 -0
- torchax/export.py +258 -0
- torchax/flax.py +55 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +236 -0
- torchax/ops/__init__.py +25 -0
- torchax/ops/jaten.py +5753 -0
- torchax/ops/jax_reimplement.py +211 -0
- torchax/ops/jc10d.py +64 -0
- torchax/ops/jimage.py +122 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +608 -0
- torchax/ops/jtorchvision_nms.py +268 -0
- torchax/ops/mappings.py +139 -0
- torchax/ops/op_base.py +137 -0
- torchax/ops/ops_registry.py +74 -0
- torchax/tensor.py +732 -0
- torchax/train.py +130 -0
- torchax/types.py +27 -0
- torchax/util.py +104 -0
- torchax/view.py +399 -0
- torchax-0.0.10.dev20251118.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251118.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251118.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251118.dist-info/licenses/LICENSE +201 -0
torchax/ops/jtorch.py
ADDED
|
@@ -0,0 +1,608 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tensor constructor overrides"""
|
|
16
|
+
|
|
17
|
+
import collections.abc
|
|
18
|
+
import functools
|
|
19
|
+
import math
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
|
|
22
|
+
import jax
|
|
23
|
+
import jax.numpy as jnp
|
|
24
|
+
import numpy as np
|
|
25
|
+
import torch
|
|
26
|
+
import torch.utils._pytree as pytree
|
|
27
|
+
from jax.experimental.pallas.ops.tpu import flash_attention
|
|
28
|
+
from jax.experimental.shard_map import shard_map
|
|
29
|
+
from jax.sharding import PartitionSpec
|
|
30
|
+
|
|
31
|
+
import torchax.tensor
|
|
32
|
+
from torchax.ops import jaten, jimage, mappings, op_base
|
|
33
|
+
from torchax.ops.ops_registry import register_torch_function_op
|
|
34
|
+
from torchax.view import NarrowInfo, View
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def register_function(torch_func, **kwargs):
  """Build a decorator that registers an implementation for `torch_func`.

  `@register_function(torch.foo, **opts)` calls
  `register_torch_function_op(torch.foo, impl, **opts)` on the decorated
  `impl` and returns whatever that call returns. Keyword arguments given at
  decoration time are merged with (and overridden by) any given at call time.
  """

  def _decorator(*args, **extra):
    return register_torch_function_op(torch_func, *args, **{**kwargs, **extra})

  return _decorator
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
@op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
def _as_tensor(data, dtype=None, device=None, env=None):
  """Implements torch.as_tensor.

  - torch.Tensor input: delegated to `env._to_copy(data, dtype, device)`.
  - numpy.ndarray input: converted with `jnp.asarray`, honoring `dtype` when
    given (with dtype=None the array's own dtype is kept).
  - anything else (nested Python data): routed through `_tensor`, the
    torch.tensor implementation, with the same `dtype`.
  `device` is only used for the torch.Tensor case.
  """
  if isinstance(data, torch.Tensor):
    return env._to_copy(data, dtype, device)
  if isinstance(data, np.ndarray):
    # Previously `dtype` was dropped here, so as_tensor(arr, dtype=...) ignored it.
    jax_res = jnp.asarray(data, dtype=dtype)
  else:
    jax_res = _tensor(data, dtype=dtype)
  return torchax.tensor.Tensor(jax_res, env)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@register_function(torch.tensor)
@op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
def _tensor(data, *, dtype=None, **kwargs):
  """Implements torch.tensor by building a jax array from (nested) data.

  When `dtype` is None it is inferred from the pytree leaves of `data`:
  if every leaf is a Python bool/int/float/complex, the "widest" kind wins
  (bool < int < float < complex), as torch does for e.g. [1, 2.5] -> float.
  Otherwise the type of the first leaf is looked up. If nothing matches, the
  torch default dtype is used. Extra keyword arguments are ignored.
  """
  python_types_to_torch_types = {
      bool: jnp.bool,
      int: jnp.int64,
      float: jnp.float32,
      complex: jnp.complex64,
  }
  # Promotion order among Python scalar kinds, lowest first.
  kind_rank = {bool: 0, int: 1, float: 2, complex: 3}

  # `is None` rather than truthiness: a dtype object must never be mistaken for "unset".
  if dtype is None:
    leaves = jax.tree_util.tree_leaves(data)
    leaf_types = {type(leaf) for leaf in leaves}
    if leaf_types and all(t in kind_rank for t in leaf_types):
      widest = max(leaf_types, key=kind_rank.__getitem__)
      dtype = python_types_to_torch_types[widest]
    elif leaves:
      dtype = python_types_to_torch_types.get(type(leaves[0]))

  if dtype is None:
    dtype = mappings.t2j_dtype(torch.get_default_dtype())
  return jnp.array(data, dtype=dtype)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@register_function(torch.allclose)
def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
  """torch.allclose: delegates to jnp.allclose with the same tolerances."""
  return jnp.allclose(
      input,
      other,
      rtol=rtol,
      atol=atol,
      equal_nan=equal_nan,
  )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@register_function(torch.angle)
def _torch_angle(input):
  """torch.angle: element-wise phase of `input`, computed with jnp.angle.

  int64 inputs are cast to float32 before the call — presumably so the
  result is float32 rather than a 64-bit float; confirm intent.
  """
  needs_cast = input.dtype.name == "int64"
  operand = input.astype(jnp.dtype("float32")) if needs_cast else input
  return jnp.angle(operand)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@register_function(torch.argsort)
def _torch_argsort(input, dim=-1, descending=False, stable=False):
  """torch.argsort via jnp.argsort.

  For a 0-d input, torch accepts dim=0 and dim=-1, which is the same as
  sorting a rank-1 array. So the scalar is lifted to shape (1,), sorted,
  and squeezed back to 0-d.
  """
  if input.ndim != 0:
    return jnp.argsort(input, axis=dim, descending=descending, stable=stable)
  lifted = jnp.expand_dims(input, 0)
  order = jnp.argsort(lifted, axis=dim, descending=descending, stable=stable)
  return order.squeeze()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@register_function(torch.diag)
def _diag(input, diagonal=0):
  """torch.diag via jnp.diag; `diagonal` maps to jnp's offset `k`."""
  offset = diagonal
  return jnp.diag(input, k=offset)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@register_function(torch.einsum)
@register_function(torch.ops.aten.einsum)
def _einsum(equation, *operands):
  """torch.einsum / aten.einsum via jnp.einsum.

  Operands may come either as varargs (`einsum(eq, a, b, c)`) or as a single
  list/tuple (`einsum(eq, [a, b, c])`, which is the form aten.einsum uses).
  The sequence form is unpacked regardless of how many operands it holds.
  """
  assert isinstance(equation, str), "Only accept str equation"
  # Only unpack real Python sequences: a lone array operand must not be
  # split into its rows (the old check unpacked any non-jax.Array).
  if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
    operands = tuple(operands[0])
  return jnp.einsum(equation, *operands)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _sdpa_reference(
|
|
123
|
+
query,
|
|
124
|
+
key,
|
|
125
|
+
value,
|
|
126
|
+
attn_mask=None,
|
|
127
|
+
dropout_p=0.0,
|
|
128
|
+
is_causal=False,
|
|
129
|
+
scale=None,
|
|
130
|
+
enable_gqa=False,
|
|
131
|
+
) -> torch.Tensor:
|
|
132
|
+
L, S = query.size(-2), key.size(-2)
|
|
133
|
+
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
|
134
|
+
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
|
|
135
|
+
if is_causal:
|
|
136
|
+
assert attn_mask is None
|
|
137
|
+
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
|
|
138
|
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
|
139
|
+
attn_bias.to(query.dtype)
|
|
140
|
+
if attn_mask is not None:
|
|
141
|
+
if attn_mask.dtype == torch.bool:
|
|
142
|
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
|
143
|
+
else:
|
|
144
|
+
attn_bias += attn_mask
|
|
145
|
+
if enable_gqa:
|
|
146
|
+
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
|
|
147
|
+
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
|
|
148
|
+
|
|
149
|
+
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
|
150
|
+
attn_weight += attn_bias
|
|
151
|
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
|
152
|
+
if dropout_p > 0:
|
|
153
|
+
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
|
|
154
|
+
return attn_weight @ value
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _tpu_flash_attention(query, key, value, env):
  """Run attention through the Pallas TPU flash-attention kernel.

  Args:
    query, key, value: jax arrays. The block sizes read dim 0 as the batch
      and dim 2 as the sequence length, which suggests a
      (batch, heads, seq, head_dim) layout. TODO: confirm against callers.
    env: torchax environment. When `env.config.shmap_flash_attention` is
      set, the kernel is wrapped in `shard_map` over `env._mesh`, and every
      operand and the output are sharded along the "fsdp" mesh axis.

  Returns:
    The kernel output. The kernel is always called with `causal=True` and
    without an explicit `sm_scale`, so the kernel's own default scale applies.
  """
  fsdp_partition = PartitionSpec("fsdp")

  def wrap_flash_attention(q, k, v):
    batch = q.shape[0]
    q_len = q.shape[2]
    kv_len = k.shape[2]
    # Fixed upper bounds per tile, clipped to the real extents so that small
    # batches / short sequences still yield valid block sizes.
    block_sizes = flash_attention.BlockSizes(
        block_b=min(2, batch),
        block_q=min(512, q_len),
        block_k_major=min(512, kv_len),
        block_k=min(512, kv_len),
        block_q_major_dkv=min(512, q_len),
        block_k_major_dkv=min(512, kv_len),
        block_k_dkv=min(512, kv_len),
        block_q_dkv=min(512, q_len),
        block_k_major_dq=min(512, kv_len),
        block_k_dq=min(256, kv_len),
        block_q_dq=min(1024, q_len),
    )
    return flash_attention.flash_attention(
        q, k, v, causal=True, block_sizes=block_sizes
    )

  attention = wrap_flash_attention
  if env.config.shmap_flash_attention:
    attention = shard_map(
        attention,
        mesh=env._mesh,
        in_specs=(fsdp_partition,) * 3,
        out_specs=fsdp_partition,
        check_rep=False,
    )
  return attention(query, key, value)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@register_function(torch.nn.functional.one_hot)
def one_hot(tensor, num_classes=-1):
  """torch.nn.functional.one_hot via jax.nn.one_hot, cast to int64.

  With num_classes == -1 the class count is max(tensor) + 1. This needs a
  concrete value, so it will not work under tracing.
  """
  depth = jnp.max(tensor) + 1 if num_classes == -1 else num_classes
  encoded = jax.nn.one_hot(tensor, depth)
  return encoded.astype(jnp.int64)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@register_function(torch.nn.functional.pad)
def pad(tensor, pad, mode="constant", value=None):
  """Implements torch.nn.functional.pad on top of jnp.pad.

  `pad` uses torch's layout: (before, after) pairs starting from the LAST
  dimension and moving toward the front. Negative amounts crop instead of
  pad and are turned into slices.

  Raises:
    ValueError: if `pad` has odd length (otherwise the loop below would
      silently ignore pad[0]) or describes more dims than `tensor` has.
  """
  if len(pad) % 2 != 0:
    raise ValueError(f"Padding length must be divisible by 2, got {len(pad)}.")
  if len(pad) // 2 > tensor.ndim:
    raise ValueError(
        f"Padding length {len(pad)} is too large for a tensor with "
        f"{tensor.ndim} dimensions."
    )

  # For padding modes that have different names between Torch and NumPy, this
  # dict provides a Torch-to-NumPy translation. Any string not in this dict will
  # be passed through as-is.
  MODE_NAME_TRANSLATION = {
      "circular": "wrap",
      "replicate": "edge",
  }

  numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)

  # Leading dims that `pad` does not mention stay untouched.
  num_prefix_dims = tensor.ndim - len(pad) // 2

  numpy_pad_width = [(0, 0)] * num_prefix_dims
  nd_slice = [slice(None)] * num_prefix_dims

  # Walk the pairs back-to-front so the outermost padded dim comes first.
  for i in range(len(pad) - 2, -1, -2):
    pad_start, pad_end = pad[i : i + 2]
    slice_start, slice_end = None, None

    if pad_start < 0:
      slice_start = -pad_start
      pad_start = 0

    if pad_end < 0:
      slice_end = pad_end
      pad_end = 0

    numpy_pad_width.append((pad_start, pad_end))
    nd_slice.append(slice(slice_start, slice_end))

  nd_slice = tuple(nd_slice)

  # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
  # even if the value we pass in is `None`. (It treats `None` as `nan`.)
  kwargs = {}
  if mode == "constant" and value is not None:
    kwargs["constant_values"] = value

  # The "replicate" mode pads first and then slices, whereas the "circular" mode
  # slices first and then pads. The latter approach deals with smaller tensors,
  # so we default to that option in modes where the order of operations doesn't
  # affect the result.
  if mode == "replicate":
    return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice]
  else:
    return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@register_function(
    torch.nn.functional.scaled_dot_product_attention,
    is_jax_function=False,
    needs_env=True,
)
@register_function(
    torch.ops.aten.scaled_dot_product_attention,
    is_jax_function=False,
    needs_env=True,
)
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
    env=None,
) -> torch.Tensor:
  """torch scaled_dot_product_attention for torchax tensors.

  The TPU flash path (`env.config.use_tpu_flash_attention`) is only used
  when the request matches what `_tpu_flash_attention` computes: causal,
  no attention mask, no dropout and no GQA. Previously the flash path
  silently ignored these arguments, for example by returning causal attention
  for is_causal=False. All other cases use the torch reference
  implementation.

  `_tpu_flash_attention` does not forward a softmax scale to the kernel.
  The query is therefore pre-multiplied by the torch-semantics scale
  (`scale`, or 1/sqrt(head_dim) when None). NOTE: this relies on the Pallas
  kernel's `sm_scale` default of 1.0; confirm against the kernel API.
  """
  use_flash = (
      env.config.use_tpu_flash_attention
      and is_causal
      and attn_mask is None
      and dropout_p == 0
      and not enable_gqa
  )
  if use_flash:
    jquery, jkey, jvalue = env.t2j_iso((query, key, value))
    scale_factor = 1 / math.sqrt(jquery.shape[-1]) if scale is None else scale
    res = _tpu_flash_attention(jquery * scale_factor, jkey, jvalue, env)
    return env.j2t_iso(res)

  return _sdpa_reference(
      query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
  )
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@register_function(torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
def getitem(self, indexes):
  """Implements Tensor.__getitem__ for torchax tensors.

  Index expressions that contain no tensors/arrays and no list indices are
  "narrow" slicing; they return a lazy `View` described by `NarrowInfo`.
  Everything else (advanced indexing) converts the index with
  `self._env.t2j_iso` and indexes the underlying jax array `self._elem`.
  """
  if isinstance(indexes, list):
    if not indexes or isinstance(indexes[0], int):
      # A list of ints, i.e. x[[1, 2]] and NOT x[1, 2] (the latter is a tuple).
      # The empty list x[[]] is also a single advanced index. Without the
      # emptiness check, `indexes[0]` raised IndexError here.
      indexes = (indexes,)
    else:
      indexes = tuple(indexes)

  def is_narrow_slicing():
    tensor_free = not pytree.tree_any(
        lambda x: isinstance(x, (torch.Tensor, jax.Array)),
        indexes,
    )
    list_free = not isinstance(indexes, tuple) or all(
        not isinstance(x, list) for x in indexes
    )
    return tensor_free and list_free

  if is_narrow_slicing():
    return View(self, view_info=NarrowInfo(indexes), env=self._env)

  indexes = self._env.t2j_iso(indexes)
  return torchax.tensor.Tensor(self._elem[indexes], self._env)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@register_function(torch.corrcoef)
def _corrcoef(x):
  """torch.corrcoef via jnp.corrcoef; int64 inputs yield a float32 result."""
  result = jnp.corrcoef(x)
  if x.dtype.name == "int64":
    result = result.astype(jnp.float32)
  return result
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@register_function(torch.sparse.mm, is_jax_function=False)
def _sparse_mm(mat1, mat2, reduce="sum"):
  """torch.sparse.mm, implemented as a dense torch.mm.

  Only the default reduce="sum" matches a plain matrix product. Other
  reductions used to be silently ignored and returned sum results; they now
  raise OperatorNotFound, consistent with how unsupported variants are
  reported elsewhere in this module.
  """
  if reduce != "sum":
    raise torchax.tensor.OperatorNotFound(
        f"torch.sparse.mm with reduce={reduce!r} is not supported; "
        "only reduce='sum' is implemented."
    )
  return torch.mm(mat1, mat2)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
@register_function(torch.isclose)
def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
  """torch.isclose: element-wise jnp.isclose with the same tolerances."""
  return jnp.isclose(
      input,
      other,
      rtol=rtol,
      atol=atol,
      equal_nan=equal_nan,
  )
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
@register_function(torch.linalg.det)
def linalg_det(input):
  """torch.linalg.det: determinant of (a batch of) square matrices."""
  determinant = jnp.linalg.det(input)
  return determinant
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@register_function(torch.ones)
def _ones(*size: int, dtype=None, **kwargs):
  """Implements torch.ones via jaten._ones.

  The shape may be given as varargs (`ones(2, 3)`), as a single sequence
  (`ones((2, 3))`), or as the `size=` keyword (`ones(size=(2, 3))`).
  Previously the keyword form was ignored and produced a 0-d result. Other
  keyword arguments (device, layout, ...) are accepted and ignored.
  """
  if not size and "size" in kwargs:
    size = (kwargs.pop("size"),)
  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
    size = size[0]
  return jaten._ones(size, dtype=dtype)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
@register_function(torch.zeros, is_jax_function=True)
def _zeros(*size: int, dtype=None, **kwargs):
  """Implements torch.zeros via jaten._zeros.

  The shape may be given as varargs (`zeros(2, 3)`), as a single sequence
  (`zeros((2, 3))`), or as the `size=` keyword (`zeros(size=(2, 3))`).
  Previously the keyword form was ignored and produced a 0-d result. Other
  keyword arguments are accepted and ignored.
  """
  if not size and "size" in kwargs:
    size = (kwargs.pop("size"),)
  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
    size = size[0]
  return jaten._zeros(size, dtype=dtype)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@register_function(torch.eye)
@op_base.convert_dtype()
def _eye(n: int, m: int | None = None, *, dtype=None, **kwargs):
  """torch.eye: n x m identity (n x n when m is None); extra kwargs ignored."""
  return jnp.eye(n, M=m, dtype=dtype)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
@register_function(torch.full)
@op_base.convert_dtype(use_default_dtype=False)
def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
  """torch.full: array of shape `size` filled with `fill_value`.

  With dtype=None, jnp.full picks the dtype from `fill_value`.
  NOTE(review): an earlier TODO asked about torch.Size inputs. torch.Size is
  a tuple subclass, so presumably jnp.full accepts it; confirm.
  """
  shape = size
  return jnp.full(shape, fill_value, dtype=dtype)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
@register_function(torch.empty)
@op_base.convert_dtype()
def empty(*size: Sequence[int], dtype=None, **kwargs):
  """Implements torch.empty via jnp.empty.

  The shape may be given as varargs, as a single sequence, or as the
  `size=` keyword (previously ignored, which gave a 0-d result). Other
  keyword arguments (device, requires_grad, ...) are accepted and ignored.
  """
  if not size and "size" in kwargs:
    size = (kwargs.pop("size"),)
  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
    size = size[0]
  return jnp.empty(size, dtype=dtype)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
@register_function(torch.arange, is_jax_function=True)
def arange(
    start,
    end=None,
    step=None,
    out=None,
    dtype=None,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=None,
):
  """torch.arange via jaten._aten_arange.

  A single positional argument is the exclusive end, so the range starts at 0.
  `step` defaults to 1. out, layout, device, requires_grad and pin_memory
  are accepted for signature compatibility but unused.
  """
  if end is None:
    start, end = 0, start
  step = 1 if step is None else step
  return jaten._aten_arange(start, end, step, dtype=dtype)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
@register_function(torch.empty_strided, is_jax_function=True)
def empty_strided(
    size,
    stride,
    *,
    dtype=None,
    layout=None,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
  """torch.empty_strided: forwarded to `empty`.

  jax arrays carry no user-visible strides, so `stride` (along with layout,
  device and pin_memory) is ignored.
  """
  return empty(size, requires_grad=requires_grad, dtype=dtype)
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
@register_function(torch.unravel_index)
def unravel_index(indices, shape):
  """torch.unravel_index: flat indices -> tuple of per-dimension coordinates."""
  coordinates = jnp.unravel_index(indices, shape)
  return coordinates
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
@register_function(torch.rand, is_jax_function=True, needs_env=True)
def rand(*size, **kwargs):
  """torch.rand: accepts varargs or a single sequence for the shape.

  All keyword arguments are forwarded to jaten._rand. Since the function is
  registered with needs_env=True, they presumably include `env`; confirm.
  """
  single_sequence = len(size) == 1 and isinstance(size[0], collections.abc.Iterable)
  shape = size[0] if single_sequence else size
  return jaten._rand(shape, **kwargs)
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
@register_function(torch.randn, is_jax_function=True, needs_env=True)
def randn(
    *size,
    generator=None,
    out=None,
    dtype=None,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
    env=None,
):
  """torch.randn via jaten._aten_randn.

  The shape may be given as varargs or as a single sequence. Only generator,
  dtype and env are forwarded; out, layout, device, requires_grad and
  pin_memory are accepted but unused.
  """
  single_sequence = len(size) == 1 and isinstance(size[0], collections.abc.Iterable)
  shape = size[0] if single_sequence else size
  return jaten._aten_randn(shape, generator=generator, dtype=dtype, env=env)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@register_function(torch.randint, is_jax_function=False, needs_env=True)
def randint(*args, **kwargs):
  """torch.randint: forwards every argument unchanged to jaten._aten_randint."""
  result = jaten._aten_randint(*args, **kwargs)
  return result
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
@register_function(torch.logdet)
def logdet(input):
  """torch.logdet: natural log of the determinant (not of its absolute value).

  Computed as log(sign) + log|det| from slogdet:
    * real, det > 0  -> log(1) = 0, so the result is log|det|
    * real, det < 0  -> log(-1) = NaN, as torch returns
    * det == 0       -> -inf
    * complex input  -> log(sign) adds i*arg(det), the complex logarithm
  The previous version returned log|det| unconditionally, which was wrong for
  negative and complex determinants.
  """
  sign, logabsdet = jaten._aten__linalg_slogdet(input)
  return jnp.log(sign) + logabsdet
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
@register_function(torch.linalg.slogdet)
def linalg_slogdet(input):
  """torch.linalg.slogdet: (sign, log|det|) packed as torch.return_types.slogdet."""
  result = jaten._aten__linalg_slogdet(input)
  sign, logabsdet = result
  return torch.return_types.slogdet((sign, logabsdet))
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
@register_function(torch.tensor_split)
def tensor_split(input, indices_or_sections, dim=0):
  """torch.tensor_split via jnp.array_split along axis `dim`."""
  pieces = jnp.array_split(input, indices_or_sections, axis=dim)
  return pieces
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
@register_function(torch.linalg.solve)
def linalg_solve(a, b):
  """torch.linalg.solve: the solution part of jaten._aten__linalg_solve_ex."""
  solution, _info = jaten._aten__linalg_solve_ex(a, b)
  return solution
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@register_function(torch.linalg.solve_ex)
def linalg_solve_ex(a, b):
  """torch.linalg.solve_ex: (solution, info) from jaten._aten__linalg_solve_ex."""
  solution, info = jaten._aten__linalg_solve_ex(a, b)
  return (solution, info)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
@register_function(torch.linalg.svd)
def linalg_svd(a, full_matrices=True):
  """torch.linalg.svd: delegates to jaten._aten__linalg_svd."""
  decomposition = jaten._aten__linalg_svd(a, full_matrices=full_matrices)
  return decomposition
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@register_function(torch.linalg.matrix_power)
def matrix_power(A, n, *, out=None):
  """torch.linalg.matrix_power: A raised to the integer power n (`out` unused)."""
  powered = jnp.linalg.matrix_power(A, n)
  return powered
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
@register_function(torch.svd)
def svd(a, some=True, compute_uv=True):
  """Legacy torch.svd, returning (U, S, V) where V (not V^H) is returned.

  * compute_uv=False: S holds the singular values, and U and V are
    zero-filled with shapes (*, m, m) and (*, n, n). The leading batch dims of
    `a` are kept; previously they were dropped.
  * otherwise the reduced (some=True) or full decomposition is computed and
    V = conj(Vh^T). The conjugate matters for complex input; for real input
    it is a no-op.
  """
  if not compute_uv:
    S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
    batch_shape = tuple(a.shape[:-2])
    m, n = a.shape[-2], a.shape[-1]
    U = jnp.zeros(batch_shape + (m, m), dtype=a.dtype)
    V = jnp.zeros(batch_shape + (n, n), dtype=a.dtype)
    return U, S, V
  U, S, Vh = jaten._aten__linalg_svd(a, full_matrices=not some)
  return U, S, jnp.conj(jnp.matrix_transpose(Vh))
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
@register_function(torch.cdist)
def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
  """torch.cdist: p-norm pairwise distances, delegated to jaten._aten_cdist."""
  distances = jaten._aten_cdist(x1, x2, p, compute_mode)
  return distances
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@register_function(torch.lu)
def lu(A, **kwargs):
  """Legacy torch.lu via jax.lax.linalg.lu.

  Returns (LU, pivots), or (LU, pivots, infos) when get_infos=True. The
  flag defaults to False as in torch. Previously an omitted flag raised
  KeyError. Other keyword arguments (e.g. pivot) are ignored.

  NOTE(review): `infos` is always zeros here. jax.lax.linalg.lu reports no
  singularity information, so this does not flag singular inputs.
  """
  lu_factor, pivots, _ = jax.lax.linalg.lu(A)
  # JAX pivots are 0-based; torch's are 1-based.
  torch_pivots = pivots + 1
  if kwargs.get("get_infos", False):
    info = jnp.zeros(pivots.shape[:-1], dtype=mappings.t2j_dtype(torch.int32))
    return lu_factor, torch_pivots, info
  return lu_factor, torch_pivots
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
@register_function(torch.lu_solve)
def lu_solve(b, LU_data, LU_pivots, **kwargs):
  """torch.lu_solve via jax.scipy.linalg.lu_solve; extra kwargs ignored."""
  # torch pivots are 1-based; JAX expects 0-based.
  zero_based_pivots = LU_pivots - 1
  return jax.scipy.linalg.lu_solve((LU_data, zero_based_pivots), b)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
@register_function(torch.linalg.tensorsolve)
def linalg_tensorsolve(A, b, dims=None):
  """torch.linalg.tensorsolve via jnp.linalg.tensorsolve.

  Example shapes that torch accepts:
    A=(2, 3, 6),          b=(2, 3)    -> (3, 6)
    A=(9, 2, 6, 3),       b=(6, 3)    -> (6, 3)
    A=(18, 6, 3),         b=(18,)     -> (6, 3)
    A=(3, 8, 1, 2, 2, 6), b=(3, 4, 2) -> (2, 2, 6)

  When `dims` is given, torch moves those axes of A, in order, to the end.
  jax then requires `A.shape[:b.ndim] == b.shape`, e.g. it errors with
  "After moving axes to end, leading shape of a must match shape of b".
  Since torch only needs matching element counts, b is reshaped to the
  leading shape of A.
  """
  if dims is not None:
    # Each moved axis gets its own trailing slot. The previous destination,
    # len(dims) * (A.ndim - 1,), repeated an axis and failed for len(dims) > 1.
    trailing = tuple(range(A.ndim - len(dims), A.ndim))
    A = jnp.moveaxis(A, dims, trailing)
    dims = None
  if A.shape[: b.ndim] != b.shape:
    b = jnp.reshape(b, A.shape[: b.ndim])
  return jnp.linalg.tensorsolve(A, b, axes=dims)
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
@register_function(torch.nn.functional.linear)
def functional_linear(self, weights, bias=None):
  """torch.nn.functional.linear: self @ weights^T (+ bias).

  `weights` is indexed as (out_features, in_features) by the einsum below;
  any leading dims of `self` are preserved.
  """
  projected = jnp.einsum("...a,ba->...b", self, weights)
  return projected if bias is None else projected + bias
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
@register_function(torch.nn.functional.interpolate)
def functional_interpolate(
    input,
    size: tuple[int, int] | int | None = None,
    scale_factor: float | None = None,
    mode: str = "nearest",
    align_corners: bool | None = None,
    recompute_scale_factor: bool | None = None,
    antialias: bool = False,
):
  """torch.nn.functional.interpolate (partial support).

  Defaults mirror torch's signature. The only path implemented is the cubic
  family ("cubic"/"bicubic"/"tricubic") with an explicit `size` and
  antialias=False, which goes to jimage.interpolate_bicubic_no_aa. An int
  `size` means (size, size). scale_factor and recompute_scale_factor are
  not used.

  Raises:
    torchax.tensor.OperatorNotFound: for modes outside `supported_methods`,
      or for any other mode/argument combination that is not implemented.
  """
  supported_methods = (
      "nearest",
      "linear",
      "bilinear",
      "trilinear",
      "cubic",
      "bicubic",
      "tricubic",
      "lanczos3",
      "lanczos5",
  )
  if mode not in supported_methods:
    raise torchax.tensor.OperatorNotFound(
        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
    )
  # torch passes None for unset flags; normalize them to booleans.
  antialias = antialias or False
  align_corners = align_corners or False

  if mode in ("cubic", "bicubic", "tricubic") and not antialias and size is not None:
    if isinstance(size, int):
      size = (size, size)
    return jimage.interpolate_bicubic_no_aa(
        input,
        size[0],
        size[1],
        align_corners,
    )

  # The old message claimed the (listed!) mode was unsupported by JAX. What is
  # actually missing is torchax's implementation of this combination.
  raise torchax.tensor.OperatorNotFound(
      "torchax only implements interpolate for cubic modes with an explicit "
      f"`size` and antialias=False; got mode={mode}, size={size}, "
      f"scale_factor={scale_factor}, antialias={antialias}."
  )
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
@register_function(torch.Tensor.repeat_interleave)
def torch_Tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
  """Tensor.repeat_interleave via jnp.repeat.

  dim=None repeats over the flattened array. `output_size` maps to jnp's
  `total_repeat_length`.
  """
  repeated = jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
  return repeated
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
@register_function(torch.nn.functional.max_pool2d)
def _functional_max_pool2d(
    input,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
    return_indices=False,
):
  """torch.nn.functional.max_pool2d via jaten.max_pool.

  An int kernel_size or stride is expanded to a pair. stride=None falls back
  to the (expanded) kernel size. padding and dilation are forwarded as given.
  """

  def _pair(v):
    return (v, v) if isinstance(v, int) else v

  window = _pair(kernel_size)
  strides = window if stride is None else _pair(stride)

  return jaten.max_pool(
      input,
      window,
      strides,
      padding,
      dilation,
      ceil_mode,
      with_index=return_indices,
  )
|