torchax 0.0.6__py3-none-any.whl → 0.0.10.dev20251116__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

torchax/ops/jtorch.py CHANGED
@@ -1,9 +1,24 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  """Tensor constructor overrides"""

  import math
  import collections.abc
  import functools
  from typing import Optional, Sequence, Tuple
+ from jax._src.interpreters.mlir import wrap_with_memory_kind
  import numpy as np

  import jax
@@ -20,92 +35,94 @@ import torch.utils._pytree as pytree


  def register_function(torch_func, **kwargs):
- return functools.partial(register_torch_function_op, torch_func, **kwargs)
+ return functools.partial(register_torch_function_op, torch_func, **kwargs)


  @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
  @op_base.convert_dtype(
- use_default_dtype=False) # Attempt to infer type from elements
+ use_default_dtype=False
+ ) # Attempt to infer type from elements
  def _as_tensor(data, dtype=None, device=None, env=None):
- if isinstance(data, torch.Tensor):
- return env._to_copy(data, dtype, device)
- if isinstance(data, np.ndarray):
- jax_res = jnp.asarray(data)
- else:
- jax_res = _tensor(data, dtype=dtype)
- return torchax.tensor.Tensor(jax_res, env)
+ if isinstance(data, torch.Tensor):
+ return env._to_copy(data, dtype, device)
+ if isinstance(data, np.ndarray):
+ jax_res = jnp.asarray(data)
+ else:
+ jax_res = _tensor(data, dtype=dtype)
+ return torchax.tensor.Tensor(jax_res, env)


  @register_function(torch.tensor)
  @op_base.convert_dtype(
- use_default_dtype=False) # Attempt to infer type from elements
+ use_default_dtype=False
+ ) # Attempt to infer type from elements
  def _tensor(data, *, dtype=None, **kwargs):
- python_types_to_torch_types = {
- bool: jnp.bool,
- int: jnp.int64,
- float: jnp.float32,
- complex: jnp.complex64,
- }
- if not dtype:
- leaves = jax.tree_util.tree_leaves(data)
- if len(leaves) > 0:
- dtype = python_types_to_torch_types.get(type(leaves[0]))
-
- return jnp.array(
- data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()))
+ python_types_to_torch_types = {
+ bool: jnp.bool,
+ int: jnp.int64,
+ float: jnp.float32,
+ complex: jnp.complex64,
+ }
+ if not dtype:
+ leaves = jax.tree_util.tree_leaves(data)
+ if len(leaves) > 0:
+ dtype = python_types_to_torch_types.get(type(leaves[0]))
+
+ return jnp.array(
+ data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype())
+ )


  @register_function(torch.allclose)
  def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
- return jnp.allclose(input, other, rtol, atol, equal_nan)
+ return jnp.allclose(input, other, rtol, atol, equal_nan)


  @register_function(torch.angle)
  def _torch_angle(input):
- if input.dtype.name == "int64":
- input = input.astype(jnp.dtype("float32"))
- return jnp.angle(input)
+ if input.dtype.name == "int64":
+ input = input.astype(jnp.dtype("float32"))
+ return jnp.angle(input)


  @register_function(torch.argsort)
  def _torch_argsort(input, dim=-1, descending=False, stable=False):
- expanded = False
- if input.ndim == 0:
- # for self of rank 0:
- # torch.any(x, 0), torch.any(x, -1) works;
- # torch.any(x, 1) throws out of bounds, so it's
- # behavior is the same as a jnp array of rank 1
- expanded = True
- input = jnp.expand_dims(input, 0)
- res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
- if expanded:
- res = res.squeeze()
- return res
+ expanded = False
+ if input.ndim == 0:
+ # for self of rank 0:
+ # torch.any(x, 0), torch.any(x, -1) works;
+ # torch.any(x, 1) throws out of bounds, so it's
+ # behavior is the same as a jnp array of rank 1
+ expanded = True
+ input = jnp.expand_dims(input, 0)
+ res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
+ if expanded:
+ res = res.squeeze()
+ return res


  @register_function(torch.diag)
  def _diag(input, diagonal=0):
- return jnp.diag(input, k=diagonal)
+ return jnp.diag(input, k=diagonal)


  @register_function(torch.einsum)
  @register_function(torch.ops.aten.einsum)
  def _einsum(equation, *operands):
-
- def get_params(*a):
- inner_list = a[0]
- if not isinstance(inner_list, jax.Array):
- if len(inner_list) == 1:
- A = inner_list
- return A
- elif len(inner_list) == 2:
- A, B = inner_list
- return A, B
- return operands
-
- assert isinstance(equation, str), "Only accept str equation"
- filtered_operands = get_params(*operands)
- return jnp.einsum(equation, *filtered_operands)
+ def get_params(*a):
+ inner_list = a[0]
+ if not isinstance(inner_list, jax.Array):
+ if len(inner_list) == 1:
+ A = inner_list
+ return A
+ elif len(inner_list) == 2:
+ A, B = inner_list
+ return A, B
+ return operands
+
+ assert isinstance(equation, str), "Only accept str equation"
+ filtered_operands = get_params(*operands)
+ return jnp.einsum(equation, *filtered_operands)


  def _sdpa_reference(
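
Note: the `_tensor` override in the hunk above infers a JAX dtype from the first leaf of the Python input when no dtype is given, and otherwise falls back to torch's default dtype. A minimal standalone sketch of that inference logic, using only plain jax/torch (the helper name `infer_jax_dtype` is illustrative and not part of torchax):

import jax
import jax.numpy as jnp
import torch

# Mirrors the python_types_to_torch_types table in _tensor above.
_PY_TYPE_TO_JAX_DTYPE = {
    bool: jnp.bool_,
    int: jnp.int64,
    float: jnp.float32,
    complex: jnp.complex64,
}

def infer_jax_dtype(data):
  # Look at the first leaf of the (possibly nested) Python data, as _tensor does.
  leaves = jax.tree_util.tree_leaves(data)
  if leaves:
    dtype = _PY_TYPE_TO_JAX_DTYPE.get(type(leaves[0]))
    if dtype is not None:
      return dtype
  # torchax maps torch's default dtype via mappings.t2j_dtype; here we assume
  # the common float32 default for simplicity.
  return jnp.float32 if torch.get_default_dtype() == torch.float32 else jnp.float64

print(infer_jax_dtype([[1, 2], [3, 4]]))  # int64 (materialized as int32 unless x64 is enabled)
print(infer_jax_dtype([1.0, 2.0]))        # float32
print(infer_jax_dtype([]))                # falls back to the default float dtype
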
@@ -118,122 +135,128 @@ def _sdpa_reference(
  scale=None,
  enable_gqa=False,
  ) -> torch.Tensor:
- L, S = query.size(-2), key.size(-2)
- scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
- attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
- if is_causal:
- assert attn_mask is None
- temp_mask = torch.ones(
- L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
- attn_bias.to(query.dtype)
- if attn_mask is not None:
- if attn_mask.dtype == torch.bool:
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
- else:
- attn_bias += attn_mask
- if enable_gqa:
- key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
- value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
-
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
- attn_weight += attn_bias
- attn_weight = torch.softmax(attn_weight, dim=-1)
- if dropout_p > 0:
- attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
- return attn_weight @ value
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+ if is_causal:
+ assert attn_mask is None
+ temp_mask = torch.ones(
+ L, S, dtype=torch.bool, device=query.device
+ ).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias.to(query.dtype)
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+ if enable_gqa:
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ if dropout_p > 0:
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+ return attn_weight @ value


  from jax.sharding import PartitionSpec


  def _tpu_flash_attention(query, key, value, env):
- fsdp_partition = PartitionSpec("fsdp")
-
- def wrap_flash_attention(query, key, value):
- block_sizes = flash_attention.BlockSizes(
- block_b=min(2, query.shape[0]),
- block_q=min(512, query.shape[2]),
- block_k_major=min(512, key.shape[2]),
- block_k=min(512, key.shape[2]),
- block_q_major_dkv=min(512, query.shape[2]),
- block_k_major_dkv=min(512, key.shape[2]),
- block_k_dkv=min(512, key.shape[2]),
- block_q_dkv=min(512, query.shape[2]),
- block_k_major_dq=min(512, key.shape[2]),
- block_k_dq=min(256, key.shape[2]),
- block_q_dq=min(1024, query.shape[2]),
- )
- return flash_attention.flash_attention(
- query, key, value, causal=True, block_sizes=block_sizes)
-
- if env.config.shmap_flash_attention:
- wrap_flash_attention = shard_map(
- wrap_flash_attention,
- mesh=env._mesh,
- in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
- out_specs=fsdp_partition,
- check_rep=False,
- )
- # return flash_attn_mapped(query, key, value)
- return wrap_flash_attention(query, key, value)
+ fsdp_partition = PartitionSpec("fsdp")
+
+ def wrap_flash_attention(query, key, value):
+ block_sizes = flash_attention.BlockSizes(
+ block_b=min(2, query.shape[0]),
+ block_q=min(512, query.shape[2]),
+ block_k_major=min(512, key.shape[2]),
+ block_k=min(512, key.shape[2]),
+ block_q_major_dkv=min(512, query.shape[2]),
+ block_k_major_dkv=min(512, key.shape[2]),
+ block_k_dkv=min(512, key.shape[2]),
+ block_q_dkv=min(512, query.shape[2]),
+ block_k_major_dq=min(512, key.shape[2]),
+ block_k_dq=min(256, key.shape[2]),
+ block_q_dq=min(1024, query.shape[2]),
+ )
+ return flash_attention.flash_attention(
+ query, key, value, causal=True, block_sizes=block_sizes
+ )
+
+ if env.config.shmap_flash_attention:
+ wrap_flash_attention = shard_map(
+ wrap_flash_attention,
+ mesh=env._mesh,
+ in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+ out_specs=fsdp_partition,
+ check_rep=False,
+ )
+ # return flash_attn_mapped(query, key, value)
+ return wrap_flash_attention(query, key, value)


  @register_function(torch.nn.functional.one_hot)
  def one_hot(tensor, num_classes=-1):
- if num_classes == -1:
- num_classes = jnp.max(tensor) + 1
- return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
+ if num_classes == -1:
+ num_classes = jnp.max(tensor) + 1
+ return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)


  @register_function(torch.nn.functional.pad)
  def pad(tensor, pad, mode="constant", value=None):
- # For padding modes that have different names between Torch and NumPy, this
- # dict provides a Torch-to-NumPy translation. Any string not in this dict will
- # be passed through as-is.
- MODE_NAME_TRANSLATION = {
- "circular": "wrap",
- "replicate": "edge",
- }
-
- numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)
-
- num_prefix_dims = tensor.ndim - len(pad) // 2
-
- numpy_pad_width = [(0, 0)] * num_prefix_dims
- nd_slice = [slice(None)] * num_prefix_dims
-
- for i in range(len(pad) - 2, -1, -2):
- pad_start, pad_end = pad[i:i + 2]
- slice_start, slice_end = None, None
-
- if pad_start < 0:
- slice_start = -pad_start
- pad_start = 0
-
- if pad_end < 0:
- slice_end = pad_end
- pad_end = 0
-
- numpy_pad_width.append((pad_start, pad_end))
- nd_slice.append(slice(slice_start, slice_end))
-
- nd_slice = tuple(nd_slice)
-
- # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
- # even if the value we pass in is `None`. (It treats `None` as `nan`.)
- kwargs = dict()
- if mode == "constant" and value is not None:
- kwargs["constant_values"] = value
-
- # The "replicate" mode pads first and then slices, whereas the "circular" mode
- # slices first and then pads. The latter approach deals with smaller tensors,
- # so we default to that option in modes where the order of operations doesn't
- # affect the result.
- if mode == "replicate":
- return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice]
- else:
- return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)
+ # For padding modes that have different names between Torch and NumPy, this
+ # dict provides a Torch-to-NumPy translation. Any string not in this dict will
+ # be passed through as-is.
+ MODE_NAME_TRANSLATION = {
+ "circular": "wrap",
+ "replicate": "edge",
+ }
+
+ numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)
+
+ num_prefix_dims = tensor.ndim - len(pad) // 2
+
+ numpy_pad_width = [(0, 0)] * num_prefix_dims
+ nd_slice = [slice(None)] * num_prefix_dims
+
+ for i in range(len(pad) - 2, -1, -2):
+ pad_start, pad_end = pad[i : i + 2]
+ slice_start, slice_end = None, None
+
+ if pad_start < 0:
+ slice_start = -pad_start
+ pad_start = 0
+
+ if pad_end < 0:
+ slice_end = pad_end
+ pad_end = 0
+
+ numpy_pad_width.append((pad_start, pad_end))
+ nd_slice.append(slice(slice_start, slice_end))
+
+ nd_slice = tuple(nd_slice)
+
+ # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
+ # even if the value we pass in is `None`. (It treats `None` as `nan`.)
+ kwargs = dict()
+ if mode == "constant" and value is not None:
+ kwargs["constant_values"] = value
+
+ # The "replicate" mode pads first and then slices, whereas the "circular" mode
+ # slices first and then pads. The latter approach deals with smaller tensors,
+ # so we default to that option in modes where the order of operations doesn't
+ # affect the result.
+ if mode == "replicate":
+ return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[
+ nd_slice
+ ]
+ else:
+ return jnp.pad(
+ tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs
+ )


  @register_function(
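
Note: the `pad` override in the hunk above translates a torch-style pad tuple (last dimension listed first, negative values meaning cropping) into a numpy-style pad-width list plus a slice that applies the negative entries. A small self-contained sketch of that translation, assuming only `jax.numpy` (the helper name `torch_pad_to_numpy` is illustrative, not torchax API):

import jax.numpy as jnp

def torch_pad_to_numpy(ndim, pad):
  """Convert a torch F.pad tuple into (pad_width, crop_slices) for jnp.pad."""
  num_prefix_dims = ndim - len(pad) // 2
  pad_width = [(0, 0)] * num_prefix_dims
  crop = [slice(None)] * num_prefix_dims
  # torch lists the last dimension first, so walk the tuple backwards in pairs.
  for i in range(len(pad) - 2, -1, -2):
    start, end = pad[i:i + 2]
    crop.append(slice(-start if start < 0 else None, end if end < 0 else None))
    pad_width.append((max(start, 0), max(end, 0)))
  return pad_width, tuple(crop)

x = jnp.arange(12).reshape(3, 4)
# Pad the last dim by 1 on each side and crop one row off the top: pad = (1, 1, -1, 0).
pad_width, crop = torch_pad_to_numpy(x.ndim, (1, 1, -1, 0))
print(jnp.pad(x[crop], pad_width))  # slice first, then pad, as in the non-"replicate" branch above
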
@@ -244,7 +267,8 @@ def pad(tensor, pad, mode="constant", value=None):
  @register_function(
  torch.ops.aten.scaled_dot_product_attention,
  is_jax_function=False,
- needs_env=True)
+ needs_env=True,
+ )
  def scaled_dot_product_attention(
  query,
  key,
@@ -256,96 +280,98 @@ def scaled_dot_product_attention(
  enable_gqa=False,
  env=None,
  ) -> torch.Tensor:
+ if env.config.use_tpu_flash_attention:
+ jquery, jkey, jvalue = env.t2j_iso((query, key, value))
+ res = _tpu_flash_attention(jquery, jkey, jvalue, env)
+ return env.j2t_iso(res)

- if env.config.use_tpu_flash_attention:
- jquery, jkey, jvalue = env.t2j_iso((query, key, value))
- res = _tpu_flash_attention(jquery, jkey, jvalue, env)
- return env.j2t_iso(res)
-
- return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal,
- scale, enable_gqa)
+ return _sdpa_reference(
+ query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
+ )


  @register_function(
- torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
+ torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True
+ )
  def getitem(self, indexes):
+ if isinstance(indexes, list) and isinstance(indexes[0], int):
+ # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
+ indexes = (indexes,)
+ elif isinstance(indexes, list):
+ indexes = tuple(indexes)

- if isinstance(indexes, list) and isinstance(indexes[0], int):
- # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
- indexes = (indexes,)
- elif isinstance(indexes, list):
- indexes = tuple(indexes)
-
- def is_narrow_slicing():
- tensor_free = not pytree.tree_any(
- lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
- indexes)
- list_free = not isinstance(indexes, tuple) or all(
- [False if isinstance(x, list) else True for x in indexes])
- return tensor_free and list_free
+ def is_narrow_slicing():
+ tensor_free = not pytree.tree_any(
+ lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+ indexes,
+ )
+ list_free = not isinstance(indexes, tuple) or all(
+ [False if isinstance(x, list) else True for x in indexes]
+ )
+ return tensor_free and list_free

- if is_narrow_slicing():
- return View(self, view_info=NarrowInfo(indexes), env=self._env)
+ if is_narrow_slicing():
+ return View(self, view_info=NarrowInfo(indexes), env=self._env)

- indexes = self._env.t2j_iso(indexes)
- return torchax.tensor.Tensor(self._elem[indexes], self._env)
+ indexes = self._env.t2j_iso(indexes)
+ return torchax.tensor.Tensor(self._elem[indexes], self._env)


  @register_function(torch.corrcoef)
  def _corrcoef(x):
- if x.dtype.name == "int64":
- return jnp.corrcoef(x).astype(jnp.float32)
- return jnp.corrcoef(x)
+ if x.dtype.name == "int64":
+ return jnp.corrcoef(x).astype(jnp.float32)
+ return jnp.corrcoef(x)


  @register_function(torch.sparse.mm, is_jax_function=False)
  def _sparse_mm(mat1, mat2, reduce="sum"):
- return torch.mm(mat1, mat2)
+ return torch.mm(mat1, mat2)


  @register_function(torch.isclose)
  def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
- return jnp.isclose(input, other, rtol, atol, equal_nan)
+ return jnp.isclose(input, other, rtol, atol, equal_nan)


  @register_function(torch.linalg.det)
  def linalg_det(input):
- return jnp.linalg.det(input)
+ return jnp.linalg.det(input)


  @register_function(torch.ones)
  def _ones(*size: int, dtype=None, **kwargs):
- if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
- size = size[0]
- return jaten._ones(size, dtype=dtype)
+ if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+ size = size[0]
+ return jaten._ones(size, dtype=dtype)


  @register_function(torch.zeros, is_jax_function=True)
  def _zeros(*size: int, dtype=None, **kwargs):
- if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
- size = size[0]
- return jaten._zeros(size, dtype=dtype)
+ if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+ size = size[0]
+ return jaten._zeros(size, dtype=dtype)


  @register_function(torch.eye)
  @op_base.convert_dtype()
  def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
- return jnp.eye(n, m, dtype=dtype)
+ return jnp.eye(n, m, dtype=dtype)


  @register_function(torch.full)
- @op_base.convert_dtype()
+ @op_base.convert_dtype(use_default_dtype=False)
  def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
- # TODO: handle torch.Size
- return jnp.full(size, fill_value, dtype=dtype)
+ # TODO: handle torch.Size
+ return jnp.full(size, fill_value, dtype=dtype)


  @register_function(torch.empty)
  @op_base.convert_dtype()
  def empty(*size: Sequence[int], dtype=None, **kwargs):
- if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
- size = size[0]
- return jnp.empty(size, dtype=dtype)
+ if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+ size = size[0]
+ return jnp.empty(size, dtype=dtype)


  @register_function(torch.arange, is_jax_function=True)
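
Note: several constructors in the hunk above (`_ones`, `_zeros`, `empty`, and the random constructors below) share the same normalization: torch accepts both `torch.ones(2, 3)` and `torch.ones((2, 3))`, so a single-element `*size` containing an iterable is unwrapped before the shape is handed to JAX. A minimal sketch of that pattern (the `ones_like_torch` name is illustrative only):

import collections.abc
import jax.numpy as jnp

def ones_like_torch(*size, dtype=None):
  # torch.ones(2, 3) arrives as size=(2, 3); torch.ones((2, 3)) arrives as size=((2, 3),).
  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
    size = size[0]
  return jnp.ones(size, dtype=dtype)

assert ones_like_torch(2, 3).shape == ones_like_torch((2, 3)).shape == (2, 3)
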
@@ -360,12 +386,12 @@ def arange(
  requires_grad=False,
  pin_memory=None,
  ):
- if end is None:
- end = start
- start = 0
- if step is None:
- step = 1
- return jaten._aten_arange(start, end, step, dtype=dtype)
+ if end is None:
+ end = start
+ start = 0
+ if step is None:
+ step = 1
+ return jaten._aten_arange(start, end, step, dtype=dtype)


  @register_function(torch.empty_strided, is_jax_function=True)
@@ -379,19 +405,19 @@ def empty_strided(
  requires_grad=False,
  pin_memory=False,
  ):
- return empty(size, dtype=dtype, requires_grad=requires_grad)
+ return empty(size, dtype=dtype, requires_grad=requires_grad)


  @register_function(torch.unravel_index)
  def unravel_index(indices, shape):
- return jnp.unravel_index(indices, shape)
+ return jnp.unravel_index(indices, shape)


  @register_function(torch.rand, is_jax_function=True, needs_env=True)
  def rand(*size, **kwargs):
- if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
- size = size[0]
- return jaten._rand(size, **kwargs)
+ if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+ size = size[0]
+ return jaten._rand(size, **kwargs)


  @register_function(torch.randn, is_jax_function=True, needs_env=True)
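
Note: the `torch.unravel_index` override above forwards directly to `jnp.unravel_index`, which has matching semantics: a flat index (or array of indices) is converted into per-dimension coordinates. A quick standalone sketch of the equivalence in plain `jax.numpy` (no torchax involved):

import jax.numpy as jnp

# Flat indices 1 and 7 in a 3x4 array map to (row, col) coordinates.
rows, cols = jnp.unravel_index(jnp.array([1, 7]), (3, 4))
print(rows, cols)  # [0 1] [1 3]
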
@@ -406,120 +432,120 @@ def randn(
  pin_memory=False,
  env=None,
  ):
- if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
- size = size[0]
- return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)
+ if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+ size = size[0]
+ return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)


  @register_function(torch.randint, is_jax_function=False, needs_env=True)
  def randint(*args, **kwargs):
- return jaten._aten_randint(*args, **kwargs)
+ return jaten._aten_randint(*args, **kwargs)


  @register_function(torch.logdet)
  def logdet(input):
- _, logabsdet = jaten._aten__linalg_slogdet(input)
- return logabsdet
+ _, logabsdet = jaten._aten__linalg_slogdet(input)
+ return logabsdet


  @register_function(torch.linalg.slogdet)
  def linalg_slogdet(input):
- sign, logabsdet = jaten._aten__linalg_slogdet(input)
- return torch.return_types.slogdet((sign, logabsdet))
+ sign, logabsdet = jaten._aten__linalg_slogdet(input)
+ return torch.return_types.slogdet((sign, logabsdet))


  @register_function(torch.tensor_split)
  def tensor_split(input, indices_or_sections, dim=0):
- return jnp.array_split(input, indices_or_sections, axis=dim)
+ return jnp.array_split(input, indices_or_sections, axis=dim)


  @register_function(torch.linalg.solve)
  def linalg_solve(a, b):
- res, _ = jaten._aten__linalg_solve_ex(a, b)
- return res
+ res, _ = jaten._aten__linalg_solve_ex(a, b)
+ return res


  @register_function(torch.linalg.solve_ex)
  def linalg_solve_ex(a, b):
- res, info = jaten._aten__linalg_solve_ex(a, b)
- return res, info
+ res, info = jaten._aten__linalg_solve_ex(a, b)
+ return res, info


  @register_function(torch.linalg.svd)
  def linalg_svd(a, full_matrices=True):
- return jaten._aten__linalg_svd(a, full_matrices=full_matrices)
+ return jaten._aten__linalg_svd(a, full_matrices=full_matrices)


  @register_function(torch.linalg.matrix_power)
  def matrix_power(A, n, *, out=None):
- return jnp.linalg.matrix_power(A, n)
+ return jnp.linalg.matrix_power(A, n)


  @register_function(torch.svd)
  def svd(a, some=True, compute_uv=True):
- if not compute_uv:
- S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
- U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype)
- V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype)
- return U, S, V
- U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
- return U, S, jnp.matrix_transpose(V)
+ if not compute_uv:
+ S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
+ U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype)
+ V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype)
+ return U, S, V
+ U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
+ return U, S, jnp.matrix_transpose(V)


  @register_function(torch.cdist)
  def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
- return jaten._aten_cdist(x1, x2, p, compute_mode)
+ return jaten._aten_cdist(x1, x2, p, compute_mode)


  @register_function(torch.lu)
  def lu(A, **kwargs):
- lu, pivots, _ = jax.lax.linalg.lu(A)
- # JAX pivots are offset by 1 compared to torch
- _pivots = pivots + 1
- info_shape = pivots.shape[:-1]
- info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
- if kwargs["get_infos"] == True:
- return lu, _pivots, info
- return lu, _pivots
+ lu, pivots, _ = jax.lax.linalg.lu(A)
+ # JAX pivots are offset by 1 compared to torch
+ _pivots = pivots + 1
+ info_shape = pivots.shape[:-1]
+ info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
+ if kwargs["get_infos"] == True:
+ return lu, _pivots, info
+ return lu, _pivots


  @register_function(torch.lu_solve)
  def lu_solve(b, LU_data, LU_pivots, **kwargs):
- # JAX pivots are offset by 1 compared to torch
- _pivots = LU_pivots - 1
- x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
- return x
+ # JAX pivots are offset by 1 compared to torch
+ _pivots = LU_pivots - 1
+ x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
+ return x


  @register_function(torch.linalg.tensorsolve)
  def linalg_tensorsolve(A, b, dims=None):
- # examples:
- # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
- # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
- # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
- # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
- # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
- # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6])
- # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
-
- # torch allows b to be shaped differently.
- # especially when axes are moved using dims.
- # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
- # So we are handling the moveaxis and forcing b's shape to match what jax expects
- if dims is not None:
- A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,))
- dims = None
- if A.shape[:b.ndim] != b.shape:
- b = jnp.reshape(b, A.shape[:b.ndim])
- return jnp.linalg.tensorsolve(A, b, axes=dims)
+ # examples:
+ # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
+ # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
+ # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
+ # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
+ # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
+ # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6])
+ # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
+
+ # torch allows b to be shaped differently.
+ # especially when axes are moved using dims.
+ # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
+ # So we are handling the moveaxis and forcing b's shape to match what jax expects
+ if dims is not None:
+ A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,))
+ dims = None
+ if A.shape[: b.ndim] != b.shape:
+ b = jnp.reshape(b, A.shape[: b.ndim])
+ return jnp.linalg.tensorsolve(A, b, axes=dims)


  @register_function(torch.nn.functional.linear)
  def functional_linear(self, weights, bias=None):
- res = jnp.einsum("...a,ba->...b", self, weights)
- if bias is not None:
- res += bias
- return res
+ res = jnp.einsum("...a,ba->...b", self, weights)
+ if bias is not None:
+ res += bias
+ return res


  @register_function(torch.nn.functional.interpolate)
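
Note: the `lu` and `lu_solve` overrides in the hunk above bridge a convention gap: `torch.lu` reports 1-based pivots while `jax.lax.linalg.lu` and `jax.scipy.linalg.lu_solve` use 0-based pivots, so the wrappers add or subtract 1 when crossing the boundary. A standalone sketch of that round trip in plain JAX (no torchax), assuming a well-conditioned square matrix:

import jax
import jax.numpy as jnp

A = jnp.array([[4.0, 3.0], [6.0, 3.0]])
b = jnp.array([10.0, 12.0])

# 0-based pivots straight from JAX.
lu_factors, pivots, _ = jax.lax.linalg.lu(A)

# torch.lu would hand these back as 1-based pivots (pivots + 1);
# the lu_solve override subtracts 1 again before calling into jax.scipy.
torch_style_pivots = pivots + 1
x = jax.scipy.linalg.lu_solve((lu_factors, torch_style_pivots - 1), b)
print(x)      # solution of A @ x = b, approximately [1. 2.]
print(A @ x)  # approximately [10. 12.]
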
@@ -532,45 +558,74 @@ def functional_interpolate(
  recompute_scale_factor: bool,
  antialias: bool,
  ):
- supported_methods = (
- "nearest",
- "linear",
- "bilinear",
- "trilinear",
- "cubic",
- "bicubic",
- "tricubic",
- "lanczos3",
- "lanczos5",
- )
- is_jax_supported = mode in supported_methods
- if not is_jax_supported:
- raise torchax.tensor.OperatorNotFound(
- f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
- )
- # None check
- antialias = antialias or False
- align_corners = align_corners or False
-
- if mode in ('cubic', 'bicubic',
- 'tricubic') and not antialias and size is not None:
- return jimage.interpolate_bicubic_no_aa(
- input,
- size[0],
- size[1],
- align_corners,
- )
- else:
- # fallback
- raise torchax.tensor.OperatorNotFound(
- f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+ supported_methods = (
+ "nearest",
+ "linear",
+ "bilinear",
+ "trilinear",
+ "cubic",
+ "bicubic",
+ "tricubic",
+ "lanczos3",
+ "lanczos5",
  )
+ is_jax_supported = mode in supported_methods
+ if not is_jax_supported:
+ raise torchax.tensor.OperatorNotFound(
+ f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+ )
+ # None check
+ antialias = antialias or False
+ align_corners = align_corners or False
+
+ if (
+ mode in ("cubic", "bicubic", "tricubic")
+ and not antialias
+ and size is not None
+ ):
+ return jimage.interpolate_bicubic_no_aa(
+ input,
+ size[0],
+ size[1],
+ align_corners,
+ )
+ else:
+ # fallback
+ raise torchax.tensor.OperatorNotFound(
+ f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+ )


  @register_function(torch.Tensor.repeat_interleave)
- def torch_Tensor_repeat_interleave(self,
- repeats,
- dim=None,
- *,
- output_size=None):
- return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
+ def torch_Tensor_repeat_interleave(
+ self, repeats, dim=None, *, output_size=None
+ ):
+ return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
+
+
+ @register_function(torch.nn.functional.max_pool2d)
+ def _functional_max_pool2d(
+ input,
+ kernel_size,
+ stride=None,
+ padding=0,
+ dilation=1,
+ ceil_mode=False,
+ return_indices=False,
+ ):
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ if stride is None:
+ stride = kernel_size
+ if isinstance(stride, int):
+ stride = (stride, stride)
+
+ return jaten.max_pool(
+ input,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ ceil_mode,
+ with_index=return_indices,
+ )
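
Note: the newly added `max_pool2d` override normalizes torch's scalar arguments before calling into the internal pooling kernel: an int `kernel_size` becomes a 2-tuple, and `stride` defaults to `kernel_size` when omitted, matching `torch.nn.functional.max_pool2d` semantics. A small sketch of just that normalization step (standalone; `jaten.max_pool` itself is internal to torchax and not reproduced here):

def normalize_pool_args(kernel_size, stride=None):
  # torch allows ints or pairs; stride defaults to the kernel size.
  if isinstance(kernel_size, int):
    kernel_size = (kernel_size, kernel_size)
  if stride is None:
    stride = kernel_size
  if isinstance(stride, int):
    stride = (stride, stride)
  return kernel_size, stride

assert normalize_pool_args(3) == ((3, 3), (3, 3))
assert normalize_pool_args((2, 3), 1) == ((2, 3), (1, 1))
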