torchax 0.0.10.dev20251114__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

Potentially problematic release: this version of torchax has been flagged as possibly problematic.

torchax/ops/jtorch.py CHANGED
@@ -14,618 +14,554 @@
 
 """Tensor constructor overrides"""
 
-import math
 import collections.abc
 import functools
-from typing import Optional, Sequence, Tuple
-from jax._src.interpreters.mlir import wrap_with_memory_kind
-import numpy as np
+import math
+from collections.abc import Sequence
 
 import jax
 import jax.numpy as jnp
-from jax.experimental.pallas.ops.tpu import flash_attention
-from jax.experimental.shard_map import shard_map
-
+import numpy as np
 import torch
-from torchax.ops.ops_registry import register_torch_function_op
-from torchax.ops import op_base, mappings, jaten, jimage
-import torchax.tensor
-from torchax.view import View, NarrowInfo
 import torch.utils._pytree as pytree
 
+import torchax.tensor
+from torchax.ops import jaten, jimage, mappings, op_base
+from torchax.ops.ops_registry import register_torch_function_op
+from torchax.view import NarrowInfo, View
+
 
 def register_function(torch_func, **kwargs):
-    return functools.partial(register_torch_function_op, torch_func, **kwargs)
+  return functools.partial(register_torch_function_op, torch_func, **kwargs)
 
 
 @register_function(torch.as_tensor, is_jax_function=False, needs_env=True)
-@op_base.convert_dtype(
-    use_default_dtype=False
-)  # Attempt to infer type from elements
+@op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
 def _as_tensor(data, dtype=None, device=None, env=None):
-    if isinstance(data, torch.Tensor):
-        return env._to_copy(data, dtype, device)
-    if isinstance(data, np.ndarray):
-        jax_res = jnp.asarray(data)
-    else:
-        jax_res = _tensor(data, dtype=dtype)
-    return torchax.tensor.Tensor(jax_res, env)
+  if isinstance(data, torch.Tensor):
+    return env._to_copy(data, dtype, device)
+  if isinstance(data, np.ndarray):
+    jax_res = jnp.asarray(data)
+  else:
+    jax_res = _tensor(data, dtype=dtype)
+  return torchax.tensor.Tensor(jax_res, env)
 
 
 @register_function(torch.tensor)
-@op_base.convert_dtype(
-    use_default_dtype=False
-)  # Attempt to infer type from elements
+@op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
 def _tensor(data, *, dtype=None, **kwargs):
-    python_types_to_torch_types = {
-        bool: jnp.bool,
-        int: jnp.int64,
-        float: jnp.float32,
-        complex: jnp.complex64,
-    }
-    if not dtype:
-        leaves = jax.tree_util.tree_leaves(data)
-        if len(leaves) > 0:
-            dtype = python_types_to_torch_types.get(type(leaves[0]))
-
-    return jnp.array(
-        data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype())
-    )
+  python_types_to_torch_types = {
+    bool: jnp.bool,
+    int: jnp.int64,
+    float: jnp.float32,
+    complex: jnp.complex64,
+  }
+  if not dtype:
+    leaves = jax.tree_util.tree_leaves(data)
+    if len(leaves) > 0:
+      dtype = python_types_to_torch_types.get(type(leaves[0]))
+
+  return jnp.array(data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()))
 
 
 @register_function(torch.allclose)
 def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
-    return jnp.allclose(input, other, rtol, atol, equal_nan)
+  return jnp.allclose(input, other, rtol, atol, equal_nan)
 
 
 @register_function(torch.angle)
 def _torch_angle(input):
-    if input.dtype.name == "int64":
-        input = input.astype(jnp.dtype("float32"))
-    return jnp.angle(input)
+  if input.dtype.name == "int64":
+    input = input.astype(jnp.dtype("float32"))
+  return jnp.angle(input)
 
 
 @register_function(torch.argsort)
 def _torch_argsort(input, dim=-1, descending=False, stable=False):
-    expanded = False
-    if input.ndim == 0:
-        # for self of rank 0:
-        # torch.any(x, 0), torch.any(x, -1) works;
-        # torch.any(x, 1) throws out of bounds, so it's
-        # behavior is the same as a jnp array of rank 1
-        expanded = True
-        input = jnp.expand_dims(input, 0)
-    res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
-    if expanded:
-        res = res.squeeze()
-    return res
+  expanded = False
+  if input.ndim == 0:
+    # for self of rank 0:
+    # torch.any(x, 0), torch.any(x, -1) works;
+    # torch.any(x, 1) throws out of bounds, so it's
+    # behavior is the same as a jnp array of rank 1
+    expanded = True
+    input = jnp.expand_dims(input, 0)
+  res = jnp.argsort(input, axis=dim, descending=descending, stable=stable)
+  if expanded:
+    res = res.squeeze()
+  return res
 
 
 @register_function(torch.diag)
 def _diag(input, diagonal=0):
-    return jnp.diag(input, k=diagonal)
+  return jnp.diag(input, k=diagonal)
 
 
 @register_function(torch.einsum)
 @register_function(torch.ops.aten.einsum)
 def _einsum(equation, *operands):
-    def get_params(*a):
-        inner_list = a[0]
-        if not isinstance(inner_list, jax.Array):
-            if len(inner_list) == 1:
-                A = inner_list
-                return A
-            elif len(inner_list) == 2:
-                A, B = inner_list
-                return A, B
-        return operands
-
-    assert isinstance(equation, str), "Only accept str equation"
-    filtered_operands = get_params(*operands)
-    return jnp.einsum(equation, *filtered_operands)
+  def get_params(*a):
+    inner_list = a[0]
+    if not isinstance(inner_list, jax.Array):
+      if len(inner_list) == 1:
+        A = inner_list
+        return A
+      elif len(inner_list) == 2:
+        A, B = inner_list
+        return A, B
+    return operands
+
+  assert isinstance(equation, str), "Only accept str equation"
+  filtered_operands = get_params(*operands)
+  return jnp.einsum(equation, *filtered_operands)
 
 
 def _sdpa_reference(
-    query,
-    key,
-    value,
-    attn_mask=None,
-    dropout_p=0.0,
-    is_causal=False,
-    scale=None,
-    enable_gqa=False,
+  query,
+  key,
+  value,
+  attn_mask=None,
+  dropout_p=0.0,
+  is_causal=False,
+  scale=None,
+  enable_gqa=False,
 ) -> torch.Tensor:
-    L, S = query.size(-2), key.size(-2)
-    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
-    if is_causal:
-        assert attn_mask is None
-        temp_mask = torch.ones(
-            L, S, dtype=torch.bool, device=query.device
-        ).tril(diagonal=0)
-        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-        attn_bias.to(query.dtype)
-    if attn_mask is not None:
-        if attn_mask.dtype == torch.bool:
-            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-        else:
-            attn_bias += attn_mask
-    if enable_gqa:
-        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
-        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
-
-    attn_weight = query @ key.transpose(-2, -1) * scale_factor
-    attn_weight += attn_bias
-    attn_weight = torch.softmax(attn_weight, dim=-1)
-    if dropout_p > 0:
-        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-    return attn_weight @ value
-
-
-from jax.sharding import PartitionSpec
-
-
-def _tpu_flash_attention(query, key, value, env):
-    fsdp_partition = PartitionSpec("fsdp")
-
-    def wrap_flash_attention(query, key, value):
-        block_sizes = flash_attention.BlockSizes(
-            block_b=min(2, query.shape[0]),
-            block_q=min(512, query.shape[2]),
-            block_k_major=min(512, key.shape[2]),
-            block_k=min(512, key.shape[2]),
-            block_q_major_dkv=min(512, query.shape[2]),
-            block_k_major_dkv=min(512, key.shape[2]),
-            block_k_dkv=min(512, key.shape[2]),
-            block_q_dkv=min(512, query.shape[2]),
-            block_k_major_dq=min(512, key.shape[2]),
-            block_k_dq=min(256, key.shape[2]),
-            block_q_dq=min(1024, query.shape[2]),
-        )
-        return flash_attention.flash_attention(
-            query, key, value, causal=True, block_sizes=block_sizes
-        )
-
-    if env.config.shmap_flash_attention:
-        wrap_flash_attention = shard_map(
-            wrap_flash_attention,
-            mesh=env._mesh,
-            in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
-            out_specs=fsdp_partition,
-            check_rep=False,
-        )
-    # return flash_attn_mapped(query, key, value)
-    return wrap_flash_attention(query, key, value)
+  L, S = query.size(-2), key.size(-2)
+  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+  attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+  if is_causal:
+    assert attn_mask is None
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+  if attn_mask is not None:
+    if attn_mask.dtype == torch.bool:
+      attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+    else:
+      attn_bias += attn_mask
+  if enable_gqa:
+    key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+    value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+  attn_weight = query @ key.transpose(-2, -1) * scale_factor
+  attn_weight += attn_bias
+  attn_weight = torch.softmax(attn_weight, dim=-1)
+  if dropout_p > 0:
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+  return attn_weight @ value
 
 
 @register_function(torch.nn.functional.one_hot)
 def one_hot(tensor, num_classes=-1):
-    if num_classes == -1:
-        num_classes = jnp.max(tensor) + 1
-    return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
+  if num_classes == -1:
+    num_classes = jnp.max(tensor) + 1
+  return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
 
 
 @register_function(torch.nn.functional.pad)
 def pad(tensor, pad, mode="constant", value=None):
-    # For padding modes that have different names between Torch and NumPy, this
-    # dict provides a Torch-to-NumPy translation. Any string not in this dict will
-    # be passed through as-is.
-    MODE_NAME_TRANSLATION = {
-        "circular": "wrap",
-        "replicate": "edge",
-    }
-
-    numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)
-
-    num_prefix_dims = tensor.ndim - len(pad) // 2
-
-    numpy_pad_width = [(0, 0)] * num_prefix_dims
-    nd_slice = [slice(None)] * num_prefix_dims
-
-    for i in range(len(pad) - 2, -1, -2):
-        pad_start, pad_end = pad[i : i + 2]
-        slice_start, slice_end = None, None
-
-        if pad_start < 0:
-            slice_start = -pad_start
-            pad_start = 0
-
-        if pad_end < 0:
-            slice_end = pad_end
-            pad_end = 0
-
-        numpy_pad_width.append((pad_start, pad_end))
-        nd_slice.append(slice(slice_start, slice_end))
-
-    nd_slice = tuple(nd_slice)
-
-    # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
-    # even if the value we pass in is `None`. (It treats `None` as `nan`.)
-    kwargs = dict()
-    if mode == "constant" and value is not None:
-        kwargs["constant_values"] = value
-
-    # The "replicate" mode pads first and then slices, whereas the "circular" mode
-    # slices first and then pads. The latter approach deals with smaller tensors,
-    # so we default to that option in modes where the order of operations doesn't
-    # affect the result.
-    if mode == "replicate":
-        return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[
-            nd_slice
-        ]
-    else:
-        return jnp.pad(
-            tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs
-        )
+  # For padding modes that have different names between Torch and NumPy, this
+  # dict provides a Torch-to-NumPy translation. Any string not in this dict will
+  # be passed through as-is.
+  MODE_NAME_TRANSLATION = {
+    "circular": "wrap",
+    "replicate": "edge",
+  }
+
+  numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)
+
+  num_prefix_dims = tensor.ndim - len(pad) // 2
+
+  numpy_pad_width = [(0, 0)] * num_prefix_dims
+  nd_slice = [slice(None)] * num_prefix_dims
+
+  for i in range(len(pad) - 2, -1, -2):
+    pad_start, pad_end = pad[i : i + 2]
+    slice_start, slice_end = None, None
+
+    if pad_start < 0:
+      slice_start = -pad_start
+      pad_start = 0
+
+    if pad_end < 0:
+      slice_end = pad_end
+      pad_end = 0
+
+    numpy_pad_width.append((pad_start, pad_end))
+    nd_slice.append(slice(slice_start, slice_end))
+
+  nd_slice = tuple(nd_slice)
+
+  # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
+  # even if the value we pass in is `None`. (It treats `None` as `nan`.)
+  kwargs = {}
+  if mode == "constant" and value is not None:
+    kwargs["constant_values"] = value
+
+  # The "replicate" mode pads first and then slices, whereas the "circular" mode
+  # slices first and then pads. The latter approach deals with smaller tensors,
+  # so we default to that option in modes where the order of operations doesn't
+  # affect the result.
+  if mode == "replicate":
+    return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice]
+  else:
+    return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)
 
 
 @register_function(
-    torch.nn.functional.scaled_dot_product_attention,
-    is_jax_function=False,
-    needs_env=True,
+  torch.nn.functional.scaled_dot_product_attention,
+  is_jax_function=False,
+  needs_env=True,
 )
 @register_function(
-    torch.ops.aten.scaled_dot_product_attention,
-    is_jax_function=False,
-    needs_env=True,
+  torch.ops.aten.scaled_dot_product_attention,
+  is_jax_function=False,
+  needs_env=True,
 )
 def scaled_dot_product_attention(
-    query,
-    key,
-    value,
-    attn_mask=None,
-    dropout_p=0.0,
-    is_causal=False,
-    scale=None,
-    enable_gqa=False,
-    env=None,
+  query,
+  key,
+  value,
+  attn_mask=None,
+  dropout_p=0.0,
+  is_causal=False,
+  scale=None,
+  enable_gqa=False,
+  env=None,
 ) -> torch.Tensor:
-    if env.config.use_tpu_flash_attention:
-        jquery, jkey, jvalue = env.t2j_iso((query, key, value))
-        res = _tpu_flash_attention(jquery, jkey, jvalue, env)
-        return env.j2t_iso(res)
-
-    return _sdpa_reference(
-        query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
-    )
+  return _sdpa_reference(
+    query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
+  )
 
 
-@register_function(
-    torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True
-)
+@register_function(torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True)
 def getitem(self, indexes):
-    if isinstance(indexes, list) and isinstance(indexes[0], int):
-        # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
-        indexes = (indexes,)
-    elif isinstance(indexes, list):
-        indexes = tuple(indexes)
-
-    def is_narrow_slicing():
-        tensor_free = not pytree.tree_any(
-            lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
-            indexes,
-        )
-        list_free = not isinstance(indexes, tuple) or all(
-            [False if isinstance(x, list) else True for x in indexes]
-        )
-        return tensor_free and list_free
+  if isinstance(indexes, list) and isinstance(indexes[0], int):
+    # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int)
+    indexes = (indexes,)
+  elif isinstance(indexes, list):
+    indexes = tuple(indexes)
+
+  def is_narrow_slicing():
+    tensor_free = not pytree.tree_any(
+      lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array),
+      indexes,
+    )
+    list_free = not isinstance(indexes, tuple) or all(
+      False if isinstance(x, list) else True for x in indexes
+    )
+    return tensor_free and list_free
 
-    if is_narrow_slicing():
-        return View(self, view_info=NarrowInfo(indexes), env=self._env)
+  if is_narrow_slicing():
+    return View(self, view_info=NarrowInfo(indexes), env=self._env)
 
-    indexes = self._env.t2j_iso(indexes)
-    return torchax.tensor.Tensor(self._elem[indexes], self._env)
+  indexes = self._env.t2j_iso(indexes)
+  return torchax.tensor.Tensor(self._elem[indexes], self._env)
 
 
 @register_function(torch.corrcoef)
 def _corrcoef(x):
-    if x.dtype.name == "int64":
-        return jnp.corrcoef(x).astype(jnp.float32)
-    return jnp.corrcoef(x)
+  if x.dtype.name == "int64":
+    return jnp.corrcoef(x).astype(jnp.float32)
+  return jnp.corrcoef(x)
 
 
 @register_function(torch.sparse.mm, is_jax_function=False)
 def _sparse_mm(mat1, mat2, reduce="sum"):
-    return torch.mm(mat1, mat2)
+  return torch.mm(mat1, mat2)
 
 
 @register_function(torch.isclose)
 def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
-    return jnp.isclose(input, other, rtol, atol, equal_nan)
+  return jnp.isclose(input, other, rtol, atol, equal_nan)
 
 
 @register_function(torch.linalg.det)
 def linalg_det(input):
-    return jnp.linalg.det(input)
+  return jnp.linalg.det(input)
 
 
 @register_function(torch.ones)
 def _ones(*size: int, dtype=None, **kwargs):
-    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
-        size = size[0]
-    return jaten._ones(size, dtype=dtype)
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._ones(size, dtype=dtype)
 
 
 @register_function(torch.zeros, is_jax_function=True)
 def _zeros(*size: int, dtype=None, **kwargs):
-    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
-        size = size[0]
-    return jaten._zeros(size, dtype=dtype)
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._zeros(size, dtype=dtype)
 
 
 @register_function(torch.eye)
 @op_base.convert_dtype()
-def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
-    return jnp.eye(n, m, dtype=dtype)
+def _eye(n: int, m: int | None = None, *, dtype=None, **kwargs):
+  return jnp.eye(n, m, dtype=dtype)
 
 
 @register_function(torch.full)
 @op_base.convert_dtype(use_default_dtype=False)
 def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
-    # TODO: handle torch.Size
-    return jnp.full(size, fill_value, dtype=dtype)
+  # TODO: handle torch.Size
+  return jnp.full(size, fill_value, dtype=dtype)
 
 
 @register_function(torch.empty)
 @op_base.convert_dtype()
 def empty(*size: Sequence[int], dtype=None, **kwargs):
-    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
-        size = size[0]
-    return jnp.empty(size, dtype=dtype)
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jnp.empty(size, dtype=dtype)
 
 
 @register_function(torch.arange, is_jax_function=True)
 def arange(
-    start,
-    end=None,
-    step=None,
-    out=None,
-    dtype=None,
-    layout=torch.strided,
-    device=None,
-    requires_grad=False,
-    pin_memory=None,
+  start,
+  end=None,
+  step=None,
+  out=None,
+  dtype=None,
+  layout=torch.strided,
+  device=None,
+  requires_grad=False,
+  pin_memory=None,
 ):
-    if end is None:
-        end = start
-        start = 0
-    if step is None:
-        step = 1
-    return jaten._aten_arange(start, end, step, dtype=dtype)
+  if end is None:
+    end = start
+    start = 0
+  if step is None:
+    step = 1
+  return jaten._aten_arange(start, end, step, dtype=dtype)
 
 
 @register_function(torch.empty_strided, is_jax_function=True)
 def empty_strided(
-    size,
-    stride,
-    *,
-    dtype=None,
-    layout=None,
-    device=None,
-    requires_grad=False,
-    pin_memory=False,
+  size,
+  stride,
+  *,
+  dtype=None,
+  layout=None,
+  device=None,
+  requires_grad=False,
+  pin_memory=False,
 ):
-    return empty(size, dtype=dtype, requires_grad=requires_grad)
+  return empty(size, dtype=dtype, requires_grad=requires_grad)
 
 
 @register_function(torch.unravel_index)
 def unravel_index(indices, shape):
-    return jnp.unravel_index(indices, shape)
+  return jnp.unravel_index(indices, shape)
 
 
 @register_function(torch.rand, is_jax_function=True, needs_env=True)
 def rand(*size, **kwargs):
-    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
-        size = size[0]
-    return jaten._rand(size, **kwargs)
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._rand(size, **kwargs)
 
 
 @register_function(torch.randn, is_jax_function=True, needs_env=True)
 def randn(
-    *size,
-    generator=None,
-    out=None,
-    dtype=None,
-    layout=torch.strided,
-    device=None,
-    requires_grad=False,
-    pin_memory=False,
-    env=None,
+  *size,
+  generator=None,
+  out=None,
+  dtype=None,
+  layout=torch.strided,
+  device=None,
+  requires_grad=False,
+  pin_memory=False,
+  env=None,
 ):
-    if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
-        size = size[0]
-    return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)
+  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
+    size = size[0]
+  return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)
 
 
 @register_function(torch.randint, is_jax_function=False, needs_env=True)
 def randint(*args, **kwargs):
-    return jaten._aten_randint(*args, **kwargs)
+  return jaten._aten_randint(*args, **kwargs)
 
 
 @register_function(torch.logdet)
 def logdet(input):
-    _, logabsdet = jaten._aten__linalg_slogdet(input)
-    return logabsdet
+  _, logabsdet = jaten._aten__linalg_slogdet(input)
+  return logabsdet
 
 
 @register_function(torch.linalg.slogdet)
 def linalg_slogdet(input):
-    sign, logabsdet = jaten._aten__linalg_slogdet(input)
-    return torch.return_types.slogdet((sign, logabsdet))
+  sign, logabsdet = jaten._aten__linalg_slogdet(input)
+  return torch.return_types.slogdet((sign, logabsdet))
 
 
 @register_function(torch.tensor_split)
 def tensor_split(input, indices_or_sections, dim=0):
-    return jnp.array_split(input, indices_or_sections, axis=dim)
+  return jnp.array_split(input, indices_or_sections, axis=dim)
 
 
 @register_function(torch.linalg.solve)
 def linalg_solve(a, b):
-    res, _ = jaten._aten__linalg_solve_ex(a, b)
-    return res
+  res, _ = jaten._aten__linalg_solve_ex(a, b)
+  return res
 
 
 @register_function(torch.linalg.solve_ex)
 def linalg_solve_ex(a, b):
-    res, info = jaten._aten__linalg_solve_ex(a, b)
-    return res, info
+  res, info = jaten._aten__linalg_solve_ex(a, b)
+  return res, info
 
 
 @register_function(torch.linalg.svd)
 def linalg_svd(a, full_matrices=True):
-    return jaten._aten__linalg_svd(a, full_matrices=full_matrices)
+  return jaten._aten__linalg_svd(a, full_matrices=full_matrices)
 
 
 @register_function(torch.linalg.matrix_power)
 def matrix_power(A, n, *, out=None):
-    return jnp.linalg.matrix_power(A, n)
+  return jnp.linalg.matrix_power(A, n)
 
 
 @register_function(torch.svd)
 def svd(a, some=True, compute_uv=True):
-    if not compute_uv:
-        S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
-        U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype)
-        V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype)
-        return U, S, V
-    U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
-    return U, S, jnp.matrix_transpose(V)
+  if not compute_uv:
+    S = jaten._aten__linalg_svd(a, full_matrices=False)[1]
+    U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype)
+    V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype)
+    return U, S, V
+  U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some)
+  return U, S, jnp.matrix_transpose(V)
 
 
 @register_function(torch.cdist)
 def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
-    return jaten._aten_cdist(x1, x2, p, compute_mode)
+  return jaten._aten_cdist(x1, x2, p, compute_mode)
 
 
 @register_function(torch.lu)
 def lu(A, **kwargs):
-    lu, pivots, _ = jax.lax.linalg.lu(A)
-    # JAX pivots are offset by 1 compared to torch
-    _pivots = pivots + 1
-    info_shape = pivots.shape[:-1]
-    info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
-    if kwargs["get_infos"] == True:
-        return lu, _pivots, info
-    return lu, _pivots
+  lu, pivots, _ = jax.lax.linalg.lu(A)
+  # JAX pivots are offset by 1 compared to torch
+  _pivots = pivots + 1
+  info_shape = pivots.shape[:-1]
+  info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
+  if kwargs["get_infos"]:
+    return lu, _pivots, info
+  return lu, _pivots
 
 
 @register_function(torch.lu_solve)
 def lu_solve(b, LU_data, LU_pivots, **kwargs):
-    # JAX pivots are offset by 1 compared to torch
-    _pivots = LU_pivots - 1
-    x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
-    return x
+  # JAX pivots are offset by 1 compared to torch
+  _pivots = LU_pivots - 1
+  x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b)
+  return x
 
 
 @register_function(torch.linalg.tensorsolve)
 def linalg_tensorsolve(A, b, dims=None):
-    # examples:
-    # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
-    # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
-    # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
-    # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
-    # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
-    # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6])
-    # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
-
-    # torch allows b to be shaped differently.
-    # especially when axes are moved using dims.
-    # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
-    # So we are handling the moveaxis and forcing b's shape to match what jax expects
-    if dims is not None:
-        A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,))
-        dims = None
-    if A.shape[: b.ndim] != b.shape:
-        b = jnp.reshape(b, A.shape[: b.ndim])
-    return jnp.linalg.tensorsolve(A, b, axes=dims)
+  # examples:
+  # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
+  # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
+  # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
+  # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
+  # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
+  # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6])
+  # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
+
+  # torch allows b to be shaped differently.
+  # especially when axes are moved using dims.
+  # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
+  # So we are handling the moveaxis and forcing b's shape to match what jax expects
+  if dims is not None:
+    A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,))
+    dims = None
+  if A.shape[: b.ndim] != b.shape:
+    b = jnp.reshape(b, A.shape[: b.ndim])
+  return jnp.linalg.tensorsolve(A, b, axes=dims)
 
 
 @register_function(torch.nn.functional.linear)
 def functional_linear(self, weights, bias=None):
-    res = jnp.einsum("...a,ba->...b", self, weights)
-    if bias is not None:
-        res += bias
-    return res
+  res = jnp.einsum("...a,ba->...b", self, weights)
+  if bias is not None:
+    res += bias
+  return res
 
 
 @register_function(torch.nn.functional.interpolate)
 def functional_interpolate(
-    input,
-    size: Tuple[int, int],
-    scale_factor: Optional[float],
-    mode: str,
-    align_corners: bool,
-    recompute_scale_factor: bool,
-    antialias: bool,
+  input,
+  size: tuple[int, int],
+  scale_factor: float | None,
+  mode: str,
+  align_corners: bool,
+  recompute_scale_factor: bool,
+  antialias: bool,
 ):
-    supported_methods = (
-        "nearest",
-        "linear",
-        "bilinear",
-        "trilinear",
-        "cubic",
-        "bicubic",
-        "tricubic",
-        "lanczos3",
-        "lanczos5",
+  supported_methods = (
+    "nearest",
+    "linear",
+    "bilinear",
+    "trilinear",
+    "cubic",
+    "bicubic",
+    "tricubic",
+    "lanczos3",
+    "lanczos5",
+  )
+  is_jax_supported = mode in supported_methods
+  if not is_jax_supported:
+    raise torchax.tensor.OperatorNotFound(
+      f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
+    )
+  # None check
+  antialias = antialias or False
+  align_corners = align_corners or False
+
+  if mode in ("cubic", "bicubic", "tricubic") and not antialias and size is not None:
+    return jimage.interpolate_bicubic_no_aa(
+      input,
+      size[0],
+      size[1],
+      align_corners,
+    )
+  else:
+    # fallback
+    raise torchax.tensor.OperatorNotFound(
+      f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
     )
-    is_jax_supported = mode in supported_methods
-    if not is_jax_supported:
-        raise torchax.tensor.OperatorNotFound(
-            f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
-        )
-    # None check
-    antialias = antialias or False
-    align_corners = align_corners or False
-
-    if (
-        mode in ("cubic", "bicubic", "tricubic")
-        and not antialias
-        and size is not None
-    ):
-        return jimage.interpolate_bicubic_no_aa(
-            input,
-            size[0],
-            size[1],
-            align_corners,
-        )
-    else:
-        # fallback
-        raise torchax.tensor.OperatorNotFound(
-            f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
-        )
 
 
 @register_function(torch.Tensor.repeat_interleave)
-def torch_Tensor_repeat_interleave(
-    self, repeats, dim=None, *, output_size=None
-):
-    return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
+def torch_Tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
+  return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size)
 
 
 @register_function(torch.nn.functional.max_pool2d)
 def _functional_max_pool2d(
+  input,
+  kernel_size,
+  stride=None,
+  padding=0,
+  dilation=1,
+  ceil_mode=False,
+  return_indices=False,
+):
+  if isinstance(kernel_size, int):
+    kernel_size = (kernel_size, kernel_size)
+  if stride is None:
+    stride = kernel_size
+  if isinstance(stride, int):
+    stride = (stride, stride)
+
+  return jaten.max_pool(
     input,
     kernel_size,
-    stride=None,
-    padding=0,
-    dilation=1,
-    ceil_mode=False,
-    return_indices=False,
-):
-    if isinstance(kernel_size, int):
-        kernel_size = (kernel_size, kernel_size)
-    if stride is None:
-        stride = kernel_size
-    if isinstance(stride, int):
-        stride = (stride, stride)
-
-    return jaten.max_pool(
-        input,
-        kernel_size,
-        stride,
-        padding,
-        dilation,
-        ceil_mode,
-        with_index=return_indices,
-    )
+    stride,
+    padding,
+    dilation,
+    ceil_mode,
+    with_index=return_indices,
+  )
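
For orientation, the reference attention path kept in this file (`_sdpa_reference`) is the plain softmax(Q Kᵀ / sqrt(d)) V computation. The following standalone sketch reproduces that math directly in JAX for the no-mask, no-dropout, non-GQA case; it is illustrative only — the function name `sdpa_reference_jax` is not part of torchax, and only `jax` and `math` are assumed.

import math

import jax
import jax.numpy as jnp


def sdpa_reference_jax(query, key, value, is_causal=False, scale=None):
  # query, key, value: arrays shaped [..., seq, head_dim]
  L, S = query.shape[-2], key.shape[-2]
  scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
  attn_bias = jnp.zeros((L, S), dtype=query.dtype)
  if is_causal:
    # Bias future positions to -inf so the softmax ignores them.
    causal_mask = jnp.tril(jnp.ones((L, S), dtype=bool))
    attn_bias = jnp.where(causal_mask, attn_bias, -jnp.inf)
  attn_weight = query @ jnp.swapaxes(key, -2, -1) * scale_factor
  attn_weight = attn_weight + attn_bias
  attn_weight = jax.nn.softmax(attn_weight, axis=-1)
  return attn_weight @ value

Called with query, key, and value of shape (batch, heads, seq, dim), this matches what the reference path above computes when attn_mask, dropout, and enable_gqa are not used.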