torchax 0.0.10.dev20251114__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202612.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251114.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/licenses/LICENSE +0 -0
torchax/ops/jtorch.py
CHANGED
@@ -14,618 +14,554 @@
 
 """Tensor constructor overrides"""
 
-import math
 import collections.abc
 import functools
-
-from
-import numpy as np
+import math
+from collections.abc import Sequence
 
 import jax
 import jax.numpy as jnp
-
-from jax.experimental.shard_map import shard_map
-
+import numpy as np
 import torch
-from torchax.ops.ops_registry import register_torch_function_op
-from torchax.ops import op_base, mappings, jaten, jimage
-import torchax.tensor
-from torchax.view import View, NarrowInfo
 import torch.utils._pytree as pytree
 
+import torchax.tensor
+from torchax.ops import jaten, jimage, mappings, op_base
+from torchax.ops.ops_registry import register_torch_function_op
+from torchax.view import NarrowInfo, View
+
 
 def register_function(torch_func, **kwargs):
-
+  return functools.partial(register_torch_function_op, torch_func, **kwargs)
 
 
 @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
-@op_base.convert_dtype(
-    use_default_dtype=False
-) # Attempt to infer type from elements
+@op_base.convert_dtype(use_default_dtype=False) # Attempt to infer type from elements
 def _as_tensor(data, dtype=None, device=None, env=None):
-
-
-
-
-
-
-
+  if isinstance(data, torch.Tensor):
+    return env._to_copy(data, dtype, device)
+  if isinstance(data, np.ndarray):
+    jax_res = jnp.asarray(data)
+  else:
+    jax_res = _tensor(data, dtype=dtype)
+  return torchax.tensor.Tensor(jax_res, env)
 
 
 @register_function(torch.tensor)
-@op_base.convert_dtype(
-    use_default_dtype=False
-) # Attempt to infer type from elements
+@op_base.convert_dtype(use_default_dtype=False) # Attempt to infer type from elements
 def _tensor(data, *, dtype=None, **kwargs):
-
-
-
-
-
-
-
-
-
-
-
-
-      data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype())
-  )
+  python_types_to_torch_types = {
+      bool: jnp.bool,
+      int: jnp.int64,
+      float: jnp.float32,
+      complex: jnp.complex64,
+  }
+  if not dtype:
+    leaves = jax.tree_util.tree_leaves(data)
+    if len(leaves) > 0:
+      dtype = python_types_to_torch_types.get(type(leaves[0]))
+
+  return jnp.array(data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()))
 
 
 @register_function(torch.allclose)
 def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
-
+  return jnp.allclose(input, other, rtol, atol, equal_nan)
 
 
 @register_function(torch.angle)
 def _torch_angle(input):
-
-
-
+  if input.dtype.name == "int64":
+    input = input.astype(jnp.dtype("float32"))
+  return jnp.angle(input)
 
 
 @register_function(torch.argsort)
 def _torch_argsort(input, dim=-1, descending=False, stable=False):
-
-
-
-
-
-
-
-
-
-
-
-
+  expanded = False
+  if input.ndim == 0:
+    # for self of rank 0:
+    # torch.any(x, 0), torch.any(x, -1) works;
+    # torch.any(x, 1) throws out of bounds, so it's
+    # behavior is the same as a jnp array of rank 1
+    expanded = True
+    input = jnp.expand_dims(input, 0)
+  res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
+  if expanded:
+    res = res.squeeze()
+  return res
 
 
 @register_function(torch.diag)
 def _diag(input, diagonal=0):
-
+  return jnp.diag(input, k=diagonal)
 
 
 @register_function(torch.einsum)
 @register_function(torch.ops.aten.einsum)
 def _einsum(equation, *operands):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  def get_params(*a):
+    inner_list = a[0]
+    if not isinstance(inner_list, jax.Array):
+      if len(inner_list) == 1:
+        A = inner_list
+        return A
+      elif len(inner_list) == 2:
+        A, B = inner_list
+        return A, B
+    return operands
+
+  assert isinstance(equation, str), "Only accept str equation"
+  filtered_operands = get_params(*operands)
+  return jnp.einsum(equation, *filtered_operands)
 
 
 def _sdpa_reference(
-
-
-
-
-
-
-
-
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
 ) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  attn_weight = torch.
-
-  attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-  return attn_weight @ value
-
-
-from jax.sharding import PartitionSpec
-
-
-def _tpu_flash_attention(query, key, value, env):
-  fsdp_partition = PartitionSpec("fsdp")
-
-  def wrap_flash_attention(query, key, value):
-    block_sizes = flash_attention.BlockSizes(
-        block_b=min(2, query.shape[0]),
-        block_q=min(512, query.shape[2]),
-        block_k_major=min(512, key.shape[2]),
-        block_k=min(512, key.shape[2]),
-        block_q_major_dkv=min(512, query.shape[2]),
-        block_k_major_dkv=min(512, key.shape[2]),
-        block_k_dkv=min(512, key.shape[2]),
-        block_q_dkv=min(512, query.shape[2]),
-        block_k_major_dq=min(512, key.shape[2]),
-        block_k_dq=min(256, key.shape[2]),
-        block_q_dq=min(1024, query.shape[2]),
-    )
-    return flash_attention.flash_attention(
-        query, key, value, causal=True, block_sizes=block_sizes
-    )
-
-  if env.config.shmap_flash_attention:
-    wrap_flash_attention = shard_map(
-        wrap_flash_attention,
-        mesh=env._mesh,
-        in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
-        out_specs=fsdp_partition,
-        check_rep=False,
-    )
-  # return flash_attn_mapped(query, key, value)
-  return wrap_flash_attention(query, key, value)
+  L, S = query.size(-2), key.size(-2)
+  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+  attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+  if is_causal:
+    assert attn_mask is None
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+  if attn_mask is not None:
+    if attn_mask.dtype == torch.bool:
+      attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+    else:
+      attn_bias += attn_mask
+  if enable_gqa:
+    key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+    value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+  attn_weight = query @ key.transpose(-2, -1) * scale_factor
+  attn_weight += attn_bias
+  attn_weight = torch.softmax(attn_weight, dim=-1)
+  if dropout_p > 0:
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+  return attn_weight @ value
 
 
 @register_function(torch.nn.functional.one_hot)
 def one_hot(tensor, num_classes=-1):
-
-
-
+  if num_classes == -1:
+    num_classes = jnp.max(tensor) + 1
+  return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
 
 
 @register_function(torch.nn.functional.pad)
 def pad(tensor, pad, mode="constant", value=None):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  else:
-    return jnp.pad(
-        tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs
-    )
+  # For padding modes that have different names between Torch and NumPy, this
+  # dict provides a Torch-to-NumPy translation. Any string not in this dict will
+  # be passed through as-is.
+  MODE_NAME_TRANSLATION = {
+      "circular": "wrap",
+      "replicate": "edge",
+  }
+
+  numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)
+
+  num_prefix_dims = tensor.ndim - len(pad) // 2
+
+  numpy_pad_width = [(0, 0)] * num_prefix_dims
+  nd_slice = [slice(None)] * num_prefix_dims
+
+  for i in range(len(pad) - 2, -1, -2):
+    pad_start, pad_end = pad[i : i + 2]
+    slice_start, slice_end = None, None
+
+    if pad_start < 0:
+      slice_start = -pad_start
+      pad_start = 0
+
+    if pad_end < 0:
+      slice_end = pad_end
+      pad_end = 0
+
+    numpy_pad_width.append((pad_start, pad_end))
+    nd_slice.append(slice(slice_start, slice_end))
+
+  nd_slice = tuple(nd_slice)
+
+  # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
+  # even if the value we pass in is `None`. (It treats `None` as `nan`.)
+  kwargs = {}
+  if mode == "constant" and value is not None:
+    kwargs["constant_values"] = value
+
+  # The "replicate" mode pads first and then slices, whereas the "circular" mode
+  # slices first and then pads. The latter approach deals with smaller tensors,
+  # so we default to that option in modes where the order of operations doesn't
+  # affect the result.
+  if mode == "replicate":
+    return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice]
+  else:
+    return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)
 
 
 @register_function(
-
-
-
+    torch.nn.functional.scaled_dot_product_attention,
+    is_jax_function=False,
+    needs_env=True,
 )
 @register_function(
-
-
-
+    torch.ops.aten.scaled_dot_product_attention,
+    is_jax_function=False,
+    needs_env=True,
 )
 def scaled_dot_product_attention(
-
-
-
-
-
-
-
-
-
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+    env=None,
 ) -> torch.Tensor:
-
-
-
-    return env.j2t_iso(res)
-
-  return _sdpa_reference(
-      query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
-  )
+  return _sdpa_reference(
+      query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
+  )
 
 
-@register_function(
-    torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True
-)
+@register_function(torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
 def getitem(self, indexes):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  if isinstance(indexes, list) and isinstance(indexes[0], int):
+    # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
+    indexes = (indexes,)
+  elif isinstance(indexes, list):
+    indexes = tuple(indexes)
+
+  def is_narrow_slicing():
+    tensor_free = not pytree.tree_any(
+        lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+        indexes,
+    )
+    list_free = not isinstance(indexes, tuple) or all(
+        False if isinstance(x, list) else True for x in indexes
+    )
+    return tensor_free and list_free
 
-
-
+  if is_narrow_slicing():
+    return View(self, view_info=NarrowInfo(indexes), env=self._env)
 
-
-
+  indexes = self._env.t2j_iso(indexes)
+  return torchax.tensor.Tensor(self._elem[indexes], self._env)
 
 
 @register_function(torch.corrcoef)
 def _corrcoef(x):
-
-
-
+  if x.dtype.name == "int64":
+    return jnp.corrcoef(x).astype(jnp.float32)
+  return jnp.corrcoef(x)
 
 
 @register_function(torch.sparse.mm, is_jax_function=False)
 def _sparse_mm(mat1, mat2, reduce="sum"):
-
+  return torch.mm(mat1, mat2)
 
 
 @register_function(torch.isclose)
 def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
-
+  return jnp.isclose(input, other, rtol, atol, equal_nan)
 
 
 @register_function(torch.linalg.det)
 def linalg_det(input):
-
+  return jnp.linalg.det(input)
 
 
 @register_function(torch.ones)
 def _ones(*size: int, dtype=None, **kwargs):
-
-
-
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._ones(size, dtype=dtype)
 
 
 @register_function(torch.zeros, is_jax_function=True)
 def _zeros(*size: int, dtype=None, **kwargs):
-
-
-
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._zeros(size, dtype=dtype)
 
 
 @register_function(torch.eye)
 @op_base.convert_dtype()
-def _eye(n: int, m:
-
+def _eye(n: int, m: int | None = None, *, dtype=None, **kwargs):
+  return jnp.eye(n, m, dtype=dtype)
 
 
 @register_function(torch.full)
 @op_base.convert_dtype(use_default_dtype=False)
 def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
-
-
+  # TODO: handle torch.Size
+  return jnp.full(size, fill_value, dtype=dtype)
 
 
 @register_function(torch.empty)
 @op_base.convert_dtype()
 def empty(*size: Sequence[int], dtype=None, **kwargs):
-
-
-
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jnp.empty(size, dtype=dtype)
 
 
 @register_function(torch.arange, is_jax_function=True)
 def arange(
-
-
-
-
-
-
-
-
-
+    start,
+    end=None,
+    step=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=None,
 ):
-
-
-
-
-
-
+  if end is None:
+    end = start
+    start = 0
+  if step is None:
+    step = 1
+  return jaten._aten_arange(start, end, step, dtype=dtype)
 
 
 @register_function(torch.empty_strided, is_jax_function=True)
 def empty_strided(
-
-
-
-
-
-
-
-
+    size,
+    stride,
+    *,
+    dtype=None,
+    layout=None,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
 ):
-
+  return empty(size, dtype=dtype, requires_grad=requires_grad)
 
 
 @register_function(torch.unravel_index)
 def unravel_index(indices, shape):
-
+  return jnp.unravel_index(indices, shape)
 
 
 @register_function(torch.rand, is_jax_function=True, needs_env=True)
 def rand(*size, **kwargs):
-
-
-
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._rand(size, **kwargs)
 
 
 @register_function(torch.randn, is_jax_function=True, needs_env=True)
 def randn(
-
-
-
-
-
-
-
-
-
+    *size,
+    generator=None,
+    out=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+    env=None,
 ):
-
-
-
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)
 
 
 @register_function(torch.randint, is_jax_function=False, needs_env=True)
 def randint(*args, **kwargs):
-
+  return jaten._aten_randint(*args, **kwargs)
 
 
 @register_function(torch.logdet)
 def logdet(input):
-
-
+  _, logabsdet = jaten._aten__linalg_slogdet(input)
+  return logabsdet
 
 
 @register_function(torch.linalg.slogdet)
 def linalg_slogdet(input):
-
-
+  sign, logabsdet = jaten._aten__linalg_slogdet(input)
+  return torch.return_types.slogdet((sign, logabsdet))
 
 
 @register_function(torch.tensor_split)
 def tensor_split(input, indices_or_sections, dim=0):
-
+  return jnp.array_split(input, indices_or_sections, axis=dim)
 
 
 @register_function(torch.linalg.solve)
 def linalg_solve(a, b):
-
-
+  res, _ = jaten._aten__linalg_solve_ex(a, b)
+  return res
 
 
 @register_function(torch.linalg.solve_ex)
 def linalg_solve_ex(a, b):
-
-
+  res, info = jaten._aten__linalg_solve_ex(a, b)
+  return res, info
 
 
 @register_function(torch.linalg.svd)
 def linalg_svd(a, full_matrices=True):
-
+  return jaten._aten__linalg_svd(a, full_matrices=full_matrices)
 
 
 @register_function(torch.linalg.matrix_power)
 def matrix_power(A, n, *, out=None):
-
+  return jnp.linalg.matrix_power(A, n)
 
 
 @register_function(torch.svd)
 def svd(a, some=True, compute_uv=True):
-
-
-
-
-
-
-
+  if not compute_uv:
+    S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
+    U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype)
+    V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype)
+    return U, S, V
+  U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
+  return U, S, jnp.matrix_transpose(V)
 
 
 @register_function(torch.cdist)
 def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
-
+  return jaten._aten_cdist(x1, x2, p, compute_mode)
 
 
 @register_function(torch.lu)
 def lu(A, **kwargs):
-
-
-
-
-
-
-
-
+  lu, pivots, _ = jax.lax.linalg.lu(A)
+  # JAX pivots are offset by 1 compared to torch
+  _pivots = pivots + 1
+  info_shape = pivots.shape[:-1]
+  info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
+  if kwargs["get_infos"]:
+    return lu, _pivots, info
+  return lu, _pivots
 
 
 @register_function(torch.lu_solve)
 def lu_solve(b, LU_data, LU_pivots, **kwargs):
-
-
-
-
+  # JAX pivots are offset by 1 compared to torch
+  _pivots = LU_pivots - 1
+  x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
+  return x
 
 
 @register_function(torch.linalg.tensorsolve)
 def linalg_tensorsolve(A, b, dims=None):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  # examples:
+  # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
+  # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
+  # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
+  # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
+  # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
+  # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6])
+  # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
+
+  # torch allows b to be shaped differently.
+  # especially when axes are moved using dims.
+  # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
+  # So we are handling the moveaxis and forcing b's shape to match what jax expects
+  if dims is not None:
+    A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,))
+    dims = None
+  if A.shape[: b.ndim] != b.shape:
+    b = jnp.reshape(b, A.shape[: b.ndim])
+  return jnp.linalg.tensorsolve(A, b, axes=dims)
 
 
 @register_function(torch.nn.functional.linear)
 def functional_linear(self, weights, bias=None):
-
-
-
-
+  res = jnp.einsum("...a,ba->...b", self, weights)
+  if bias is not None:
+    res += bias
+  return res
 
 
 @register_function(torch.nn.functional.interpolate)
 def functional_interpolate(
-
-
-
-
-
-
-
+    input,
+    size: tuple[int, int],
+    scale_factor: float | None,
+    mode: str,
+    align_corners: bool,
+    recompute_scale_factor: bool,
+    antialias: bool,
 ):
-
-
-
-
-
-
-
-
-
-
+  supported_methods = (
+      "nearest",
+      "linear",
+      "bilinear",
+      "trilinear",
+      "cubic",
+      "bicubic",
+      "tricubic",
+      "lanczos3",
+      "lanczos5",
+  )
+  is_jax_supported = mode in supported_methods
+  if not is_jax_supported:
+    raise torchax.tensor.OperatorNotFound(
+        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+    )
+  # None check
+  antialias = antialias or False
+  align_corners = align_corners or False
+
+  if mode in ("cubic", "bicubic", "tricubic") and not antialias and size is not None:
+    return jimage.interpolate_bicubic_no_aa(
+        input,
+        size[0],
+        size[1],
+        align_corners,
+    )
+  else:
+    # fallback
+    raise torchax.tensor.OperatorNotFound(
+        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
     )
-  is_jax_supported = mode in supported_methods
-  if not is_jax_supported:
-    raise torchax.tensor.OperatorNotFound(
-        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
-    )
-  # None check
-  antialias = antialias or False
-  align_corners = align_corners or False
-
-  if (
-      mode in ("cubic", "bicubic", "tricubic")
-      and not antialias
-      and size is not None
-  ):
-    return jimage.interpolate_bicubic_no_aa(
-        input,
-        size[0],
-        size[1],
-        align_corners,
-    )
-  else:
-    # fallback
-    raise torchax.tensor.OperatorNotFound(
-        f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
-    )
 
 
 @register_function(torch.Tensor.repeat_interleave)
-def torch_Tensor_repeat_interleave(
-
-):
-  return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
+def torch_Tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
+  return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
 
 
 @register_function(torch.nn.functional.max_pool2d)
 def _functional_max_pool2d(
+    input,
+    kernel_size,
+    stride=None,
+    padding=0,
+    dilation=1,
+    ceil_mode=False,
+    return_indices=False,
+):
+  if isinstance(kernel_size, int):
+    kernel_size = (kernel_size, kernel_size)
+  if stride is None:
+    stride = kernel_size
+  if isinstance(stride, int):
+    stride = (stride, stride)
+
+  return jaten.max_pool(
       input,
       kernel_size,
-    stride
-    padding
-    dilation
-    ceil_mode
-    return_indices
-)
-  if isinstance(kernel_size, int):
-    kernel_size = (kernel_size, kernel_size)
-  if stride is None:
-    stride = kernel_size
-  if isinstance(stride, int):
-    stride = (stride, stride)
-
-  return jaten.max_pool(
-      input,
-      kernel_size,
-      stride,
-      padding,
-      dilation,
-      ceil_mode,
-      with_index=return_indices,
-  )
+      stride,
+      padding,
+      dilation,
+      ceil_mode,
+      with_index=return_indices,
+  )