torchax 0.0.6__py3-none-any.whl → 0.0.10.dev20251116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- torchax/CONTRIBUTING.md +10 -5
- torchax/__init__.py +92 -65
- torchax/amp.py +14 -0
- torchax/checkpoint.py +79 -0
- torchax/config.py +14 -0
- torchax/decompositions.py +14 -0
- torchax/device_module.py +14 -0
- torchax/export.py +14 -0
- torchax/flax.py +14 -0
- torchax/interop.py +44 -31
- torchax/mesh_util.py +14 -0
- torchax/ops/__init__.py +14 -0
- torchax/ops/jaten.py +3985 -3686
- torchax/ops/jax_reimplement.py +14 -0
- torchax/ops/jc10d.py +14 -0
- torchax/ops/jimage.py +14 -0
- torchax/ops/jlibrary.py +14 -0
- torchax/ops/jtorch.py +364 -309
- torchax/ops/jtorchvision_nms.py +14 -0
- torchax/ops/mappings.py +26 -4
- torchax/ops/op_base.py +14 -0
- torchax/ops/ops_registry.py +14 -0
- torchax/tensor.py +38 -13
- torchax/train.py +112 -97
- torchax/types.py +14 -0
- torchax/util.py +14 -0
- torchax/view.py +14 -0
- torchax-0.0.10.dev20251116.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251116.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251116.dist-info/licenses/LICENSE +201 -0
- torchax/configuration.py +0 -30
- torchax/environment.py +0 -1
- torchax/tf_integration.py +0 -119
- torchax-0.0.6.dist-info/METADATA +0 -307
- torchax-0.0.6.dist-info/RECORD +0 -33
- torchax-0.0.6.dist-info/licenses/LICENSE +0 -28
- {torchax-0.0.6.dist-info → torchax-0.0.10.dev20251116.dist-info}/WHEEL +0 -0
torchax/ops/jtorch.py
CHANGED
@@ -1,9 +1,24 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Tensor constructor overrides"""
 
 import math
 import collections.abc
 import functools
 from typing import Optional, Sequence, Tuple
+from jax._src.interpreters.mlir import wrap_with_memory_kind
 import numpy as np
 
 import jax
@@ -20,92 +35,94 @@ import torch.utils._pytree as pytree
 
 
 def register_function(torch_func, **kwargs):
+    return functools.partial(register_torch_function_op, torch_func, **kwargs)
 
 
 @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
 @op_base.convert_dtype(
-  use_default_dtype=False
+    use_default_dtype=False
+)  # Attempt to infer type from elements
 def _as_tensor(data, dtype=None, device=None, env=None):
+    if isinstance(data, torch.Tensor):
+        return env._to_copy(data, dtype, device)
+    if isinstance(data, np.ndarray):
+        jax_res = jnp.asarray(data)
+    else:
+        jax_res = _tensor(data, dtype=dtype)
+    return torchax.tensor.Tensor(jax_res, env)
 
 
 @register_function(torch.tensor)
 @op_base.convert_dtype(
-  use_default_dtype=False
+    use_default_dtype=False
+)  # Attempt to infer type from elements
 def _tensor(data, *, dtype=None, **kwargs):
+    python_types_to_torch_types = {
+        bool: jnp.bool,
+        int: jnp.int64,
+        float: jnp.float32,
+        complex: jnp.complex64,
+    }
+    if not dtype:
+        leaves = jax.tree_util.tree_leaves(data)
+        if len(leaves) > 0:
+            dtype = python_types_to_torch_types.get(type(leaves[0]))
+
+    return jnp.array(
+        data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype())
+    )
 
 
 @register_function(torch.allclose)
 def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
+    return jnp.allclose(input, other, rtol, atol, equal_nan)
 
 
 @register_function(torch.angle)
 def _torch_angle(input):
+    if input.dtype.name == "int64":
+        input = input.astype(jnp.dtype("float32"))
+    return jnp.angle(input)
 
 
 @register_function(torch.argsort)
 def _torch_argsort(input, dim=-1, descending=False, stable=False):
+    expanded = False
+    if input.ndim == 0:
+        # for self of rank 0:
+        # torch.any(x, 0), torch.any(x, -1) works;
+        # torch.any(x, 1) throws out of bounds, so it's
+        # behavior is the same as a jnp array of rank 1
+        expanded = True
+        input = jnp.expand_dims(input, 0)
+    res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
+    if expanded:
+        res = res.squeeze()
+    return res
 
 
 @register_function(torch.diag)
 def _diag(input, diagonal=0):
+    return jnp.diag(input, k=diagonal)
 
 
 @register_function(torch.einsum)
 @register_function(torch.ops.aten.einsum)
 def _einsum(equation, *operands):
-    return
-  return jnp.einsum(equation, *filtered_operands)
+    def get_params(*a):
+        inner_list = a[0]
+        if not isinstance(inner_list, jax.Array):
+            if len(inner_list) == 1:
+                A = inner_list
+                return A
+            elif len(inner_list) == 2:
+                A, B = inner_list
+                return A, B
+        return operands
+
+    assert isinstance(equation, str), "Only accept str equation"
+    filtered_operands = get_params(*operands)
+    return jnp.einsum(equation, *filtered_operands)
 
 
 def _sdpa_reference(
@@ -118,122 +135,128 @@ def _sdpa_reference(
     scale=None,
     enable_gqa=False,
 ) -> torch.Tensor:
-  if attn_mask
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(
+            L, S, dtype=torch.bool, device=query.device
+        ).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias.to(query.dtype)
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    if dropout_p > 0:
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value
 
 
 from jax.sharding import PartitionSpec
 
 
 def _tpu_flash_attention(query, key, value, env):
-  wrap_flash_attention
+    fsdp_partition = PartitionSpec("fsdp")
+
+    def wrap_flash_attention(query, key, value):
+        block_sizes = flash_attention.BlockSizes(
+            block_b=min(2, query.shape[0]),
+            block_q=min(512, query.shape[2]),
+            block_k_major=min(512, key.shape[2]),
+            block_k=min(512, key.shape[2]),
+            block_q_major_dkv=min(512, query.shape[2]),
+            block_k_major_dkv=min(512, key.shape[2]),
+            block_k_dkv=min(512, key.shape[2]),
+            block_q_dkv=min(512, query.shape[2]),
+            block_k_major_dq=min(512, key.shape[2]),
+            block_k_dq=min(256, key.shape[2]),
+            block_q_dq=min(1024, query.shape[2]),
+        )
+        return flash_attention.flash_attention(
+            query, key, value, causal=True, block_sizes=block_sizes
+        )
+
+    if env.config.shmap_flash_attention:
+        wrap_flash_attention = shard_map(
+            wrap_flash_attention,
+            mesh=env._mesh,
+            in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+            out_specs=fsdp_partition,
+            check_rep=False,
+        )
+    # return flash_attn_mapped(query, key, value)
+    return wrap_flash_attention(query, key, value)
 
 
 @register_function(torch.nn.functional.one_hot)
 def one_hot(tensor, num_classes=-1):
+    if num_classes == -1:
+        num_classes = jnp.max(tensor) + 1
+    return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
 
 
 @register_function(torch.nn.functional.pad)
 def pad(tensor, pad, mode="constant", value=None):
+    # For padding modes that have different names between Torch and NumPy, this
+    # dict provides a Torch-to-NumPy translation. Any string not in this dict will
+    # be passed through as-is.
+    MODE_NAME_TRANSLATION = {
+        "circular": "wrap",
+        "replicate": "edge",
+    }
+
+    numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)
+
+    num_prefix_dims = tensor.ndim - len(pad) // 2
+
+    numpy_pad_width = [(0, 0)] * num_prefix_dims
+    nd_slice = [slice(None)] * num_prefix_dims
+
+    for i in range(len(pad) - 2, -1, -2):
+        pad_start, pad_end = pad[i : i + 2]
+        slice_start, slice_end = None, None
+
+        if pad_start < 0:
+            slice_start = -pad_start
+            pad_start = 0
+
+        if pad_end < 0:
+            slice_end = pad_end
+            pad_end = 0
+
+        numpy_pad_width.append((pad_start, pad_end))
+        nd_slice.append(slice(slice_start, slice_end))
+
+    nd_slice = tuple(nd_slice)
+
+    # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
+    # even if the value we pass in is `None`. (It treats `None` as `nan`.)
+    kwargs = dict()
+    if mode == "constant" and value is not None:
+        kwargs["constant_values"] = value
+
+    # The "replicate" mode pads first and then slices, whereas the "circular" mode
+    # slices first and then pads. The latter approach deals with smaller tensors,
+    # so we default to that option in modes where the order of operations doesn't
+    # affect the result.
+    if mode == "replicate":
+        return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[
+            nd_slice
+        ]
+    else:
+        return jnp.pad(
+            tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs
+        )
 
 
 @register_function(
@@ -244,7 +267,8 @@ def pad(tensor, pad, mode="constant", value=None):
 @register_function(
     torch.ops.aten.scaled_dot_product_attention,
     is_jax_function=False,
-    needs_env=True
+    needs_env=True,
+)
 def scaled_dot_product_attention(
     query,
     key,
@@ -256,96 +280,98 @@ def scaled_dot_product_attention(
     enable_gqa=False,
     env=None,
 ) -> torch.Tensor:
+    if env.config.use_tpu_flash_attention:
+        jquery, jkey, jvalue = env.t2j_iso((query, key, value))
+        res = _tpu_flash_attention(jquery, jkey, jvalue, env)
+        return env.j2t_iso(res)
 
-    return env.j2t_iso(res)
-  return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal,
-                         scale, enable_gqa)
+    return _sdpa_reference(
+        query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
+    )
 
 
 @register_function(
-    torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True
+    torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True
+)
 def getitem(self, indexes):
+    if isinstance(indexes, list) and isinstance(indexes[0], int):
+        # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
+        indexes = (indexes,)
+    elif isinstance(indexes, list):
+        indexes = tuple(indexes)
 
-        indexes)
-    list_free = not isinstance(indexes, tuple) or all(
-        [False if isinstance(x, list) else True for x in indexes])
-    return tensor_free and list_free
+    def is_narrow_slicing():
+        tensor_free = not pytree.tree_any(
+            lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+            indexes,
+        )
+        list_free = not isinstance(indexes, tuple) or all(
+            [False if isinstance(x, list) else True for x in indexes]
+        )
+        return tensor_free and list_free
 
+    if is_narrow_slicing():
+        return View(self, view_info=NarrowInfo(indexes), env=self._env)
 
+    indexes = self._env.t2j_iso(indexes)
+    return torchax.tensor.Tensor(self._elem[indexes], self._env)
 
 
 @register_function(torch.corrcoef)
 def _corrcoef(x):
+    if x.dtype.name == "int64":
+        return jnp.corrcoef(x).astype(jnp.float32)
+    return jnp.corrcoef(x)
 
 
 @register_function(torch.sparse.mm, is_jax_function=False)
 def _sparse_mm(mat1, mat2, reduce="sum"):
+    return torch.mm(mat1, mat2)
 
 
 @register_function(torch.isclose)
 def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
+    return jnp.isclose(input, other, rtol, atol, equal_nan)
 
 
 @register_function(torch.linalg.det)
 def linalg_det(input):
+    return jnp.linalg.det(input)
 
 
 @register_function(torch.ones)
 def _ones(*size: int, dtype=None, **kwargs):
+    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+        size = size[0]
+    return jaten._ones(size, dtype=dtype)
 
 
 @register_function(torch.zeros, is_jax_function=True)
 def _zeros(*size: int, dtype=None, **kwargs):
+    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+        size = size[0]
+    return jaten._zeros(size, dtype=dtype)
 
 
 @register_function(torch.eye)
 @op_base.convert_dtype()
 def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
+    return jnp.eye(n, m, dtype=dtype)
 
 
 @register_function(torch.full)
-@op_base.convert_dtype()
+@op_base.convert_dtype(use_default_dtype=False)
 def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
+    # TODO: handle torch.Size
+    return jnp.full(size, fill_value, dtype=dtype)
 
 
 @register_function(torch.empty)
 @op_base.convert_dtype()
 def empty(*size: Sequence[int], dtype=None, **kwargs):
+    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+        size = size[0]
+    return jnp.empty(size, dtype=dtype)
 
 
 @register_function(torch.arange, is_jax_function=True)
@@ -360,12 +386,12 @@ def arange(
     requires_grad=False,
     pin_memory=None,
 ):
+    if end is None:
+        end = start
+        start = 0
+    if step is None:
+        step = 1
+    return jaten._aten_arange(start, end, step, dtype=dtype)
 
 
 @register_function(torch.empty_strided, is_jax_function=True)
@@ -379,19 +405,19 @@ def empty_strided(
     requires_grad=False,
     pin_memory=False,
 ):
+    return empty(size, dtype=dtype, requires_grad=requires_grad)
 
 
 @register_function(torch.unravel_index)
 def unravel_index(indices, shape):
+    return jnp.unravel_index(indices, shape)
 
 
 @register_function(torch.rand, is_jax_function=True, needs_env=True)
 def rand(*size, **kwargs):
+    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+        size = size[0]
+    return jaten._rand(size, **kwargs)
 
 
 @register_function(torch.randn, is_jax_function=True, needs_env=True)
@@ -406,120 +432,120 @@ def randn(
     pin_memory=False,
     env=None,
 ):
+    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+        size = size[0]
+    return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)
 
 
 @register_function(torch.randint, is_jax_function=False, needs_env=True)
 def randint(*args, **kwargs):
+    return jaten._aten_randint(*args, **kwargs)
 
 
 @register_function(torch.logdet)
 def logdet(input):
+    _, logabsdet = jaten._aten__linalg_slogdet(input)
+    return logabsdet
 
 
 @register_function(torch.linalg.slogdet)
 def linalg_slogdet(input):
+    sign, logabsdet = jaten._aten__linalg_slogdet(input)
+    return torch.return_types.slogdet((sign, logabsdet))
 
 
 @register_function(torch.tensor_split)
 def tensor_split(input, indices_or_sections, dim=0):
+    return jnp.array_split(input, indices_or_sections, axis=dim)
 
 
 @register_function(torch.linalg.solve)
 def linalg_solve(a, b):
+    res, _ = jaten._aten__linalg_solve_ex(a, b)
+    return res
 
 
 @register_function(torch.linalg.solve_ex)
 def linalg_solve_ex(a, b):
+    res, info = jaten._aten__linalg_solve_ex(a, b)
+    return res, info
 
 
 @register_function(torch.linalg.svd)
 def linalg_svd(a, full_matrices=True):
+    return jaten._aten__linalg_svd(a, full_matrices=full_matrices)
 
 
 @register_function(torch.linalg.matrix_power)
 def matrix_power(A, n, *, out=None):
+    return jnp.linalg.matrix_power(A, n)
 
 
 @register_function(torch.svd)
 def svd(a, some=True, compute_uv=True):
+    if not compute_uv:
+        S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
+        U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype)
+        V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype)
+        return U, S, V
+    U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
+    return U, S, jnp.matrix_transpose(V)
 
 
 @register_function(torch.cdist)
 def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
+    return jaten._aten_cdist(x1, x2, p, compute_mode)
 
 
 @register_function(torch.lu)
 def lu(A, **kwargs):
+    lu, pivots, _ = jax.lax.linalg.lu(A)
+    # JAX pivots are offset by 1 compared to torch
+    _pivots = pivots + 1
+    info_shape = pivots.shape[:-1]
+    info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
+    if kwargs["get_infos"] == True:
+        return lu, _pivots, info
+    return lu, _pivots
 
 
 @register_function(torch.lu_solve)
 def lu_solve(b, LU_data, LU_pivots, **kwargs):
+    # JAX pivots are offset by 1 compared to torch
+    _pivots = LU_pivots - 1
+    x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
+    return x
 
 
 @register_function(torch.linalg.tensorsolve)
 def linalg_tensorsolve(A, b, dims=None):
+    # examples:
+    # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
+    # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
+    # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
+    # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
+    # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
+    # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6])
+    # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
+
+    # torch allows b to be shaped differently.
+    # especially when axes are moved using dims.
+    # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
+    # So we are handling the moveaxis and forcing b's shape to match what jax expects
+    if dims is not None:
+        A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,))
+        dims = None
+    if A.shape[: b.ndim] != b.shape:
+        b = jnp.reshape(b, A.shape[: b.ndim])
+    return jnp.linalg.tensorsolve(A, b, axes=dims)
 
 
 @register_function(torch.nn.functional.linear)
 def functional_linear(self, weights, bias=None):
+    res = jnp.einsum("...a,ba->...b", self, weights)
+    if bias is not None:
+        res += bias
+    return res
 
 
 @register_function(torch.nn.functional.interpolate)
@@ -532,45 +558,74 @@ def functional_interpolate(
     recompute_scale_factor: bool,
     antialias: bool,
 ):
-  )
-  is_jax_supported = mode in supported_methods
-  if not is_jax_supported:
-    raise torchax.tensor.OperatorNotFound(
-        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
-    )
-  # None check
-  antialias = antialias or False
-  align_corners = align_corners or False
-
-  if mode in ('cubic', 'bicubic',
-              'tricubic') and not antialias and size is not None:
-    return jimage.interpolate_bicubic_no_aa(
-        input,
-        size[0],
-        size[1],
-        align_corners,
-    )
-  else:
-    # fallback
-    raise torchax.tensor.OperatorNotFound(
-        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+    supported_methods = (
+        "nearest",
+        "linear",
+        "bilinear",
+        "trilinear",
+        "cubic",
+        "bicubic",
+        "tricubic",
+        "lanczos3",
+        "lanczos5",
     )
+    is_jax_supported = mode in supported_methods
+    if not is_jax_supported:
+        raise torchax.tensor.OperatorNotFound(
+            f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+        )
+    # None check
+    antialias = antialias or False
+    align_corners = align_corners or False
+
+    if (
+        mode in ("cubic", "bicubic", "tricubic")
+        and not antialias
+        and size is not None
+    ):
+        return jimage.interpolate_bicubic_no_aa(
+            input,
+            size[0],
+            size[1],
+            align_corners,
+        )
+    else:
+        # fallback
+        raise torchax.tensor.OperatorNotFound(
+            f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+        )
 
 
 @register_function(torch.Tensor.repeat_interleave)
-def torch_Tensor_repeat_interleave(
+def torch_Tensor_repeat_interleave(
+    self, repeats, dim=None, *, output_size=None
+):
+    return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
+
+
+@register_function(torch.nn.functional.max_pool2d)
+def _functional_max_pool2d(
+    input,
+    kernel_size,
+    stride=None,
+    padding=0,
+    dilation=1,
+    ceil_mode=False,
+    return_indices=False,
+):
+    if isinstance(kernel_size, int):
+        kernel_size = (kernel_size, kernel_size)
+    if stride is None:
+        stride = kernel_size
+    if isinstance(stride, int):
+        stride = (stride, stride)
+
+    return jaten.max_pool(
+        input,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        ceil_mode,
+        with_index=return_indices,
+    )