torchax 0.0.11.dev2025128__py3-none-any.whl → 0.0.11.dev20251229__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


torchax/__init__.py CHANGED
@@ -27,7 +27,7 @@ from torchax import tensor
 
 from .checkpoint import load_checkpoint, save_checkpoint
 
-__version__ = "0.0.11.dev2025128"
+__version__ = "0.0.11.dev20251229"
 VERSION = __version__
 
 # the "fast path" uses some sparse tensor thingies that currently we
torchax/config.py CHANGED
@@ -35,10 +35,6 @@ class Configuration:
   # Use DLPack for converting jax.Arrays <-> and torch.Tensor
   use_dlpack_for_data_conversion: bool = False
 
-  # Flash attention
-  use_tpu_flash_attention: bool = False
-  shmap_flash_attention: bool = False
-
   # device
   treat_cuda_as_jax_device: bool = True
   internal_respect_torch_return_dtypes: bool = False
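The two TPU flash-attention flags are removed from Configuration altogether. Assuming Configuration is a dataclass (as the field-with-default syntax suggests), code that still passes those flags as constructor keywords will now raise a TypeError. A minimal sketch with a hypothetical stand-in class, not the real one:

from dataclasses import dataclass

# Hypothetical stand-in for torchax.config.Configuration after this change;
# the real class defines more fields than shown here.
@dataclass
class Configuration:
  use_dlpack_for_data_conversion: bool = False
  treat_cuda_as_jax_device: bool = True
  internal_respect_torch_return_dtypes: bool = False

try:
  Configuration(use_tpu_flash_attention=True)
except TypeError as err:
  print(err)  # unexpected keyword argument 'use_tpu_flash_attention'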
torchax/interop.py CHANGED
@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp
 import torch
 from jax import tree_util as pytree
-from jax.experimental.shard_map import shard_map
 from torch.nn.utils import stateless as torch_stateless
 
 import torchax
@@ -30,6 +29,11 @@ from torchax import tensor, util
 from torchax.ops import mappings
 from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue
 
+try:
+  from jax import shard_map as shard_map  # for jax since v0.8.0
+except ImportError:
+  from jax.experimental.shard_map import shard_map
+
 
 def extract_all_buffers(m: torch.nn.Module):
   buffers = {}
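interop.py now prefers the top-level jax.shard_map, where newer JAX releases expose it, and falls back to jax.experimental.shard_map on older ones. A minimal, hypothetical sketch of using shard_map through the same compatibility shim, assuming at least one JAX device is available:

import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
  from jax import shard_map  # newer JAX
except ImportError:
  from jax.experimental.shard_map import shard_map  # older JAX

mesh = Mesh(jax.devices(), axis_names=("x",))

@functools.partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
def double(block):
  # Runs per shard; each device sees only its slice of the input.
  return block * 2

x = jnp.arange(2.0 * len(jax.devices()))  # length divides evenly across the mesh
print(double(x))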
torchax/ops/jaten.py CHANGED
@@ -5776,7 +5776,16 @@ mutation_ops_to_functional = {
     torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add),
     torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp(torch.ops.aten.scatter_reduce),
     torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter),
+    torch.ops.aten.bitwise_not_: op_base.InplaceOp(torch.ops.aten.bitwise_not),
+    torch.ops.aten.bitwise_and_: op_base.InplaceOp(torch.ops.aten.bitwise_and),
     torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or),
+    torch.ops.aten.bitwise_xor_: op_base.InplaceOp(torch.ops.aten.bitwise_xor),
+    torch.ops.aten.bitwise_left_shift_: op_base.InplaceOp(
+        torch.ops.aten.bitwise_left_shift
+    ),
+    torch.ops.aten.bitwise_right_shift_: op_base.InplaceOp(
+        torch.ops.aten.bitwise_right_shift
+    ),
     torch.ops.aten.floor_divide_: op_base.InplaceOp(torch.ops.aten.floor_divide),
     torch.ops.aten.remainder_: op_base.InplaceOp(torch.ops.aten.remainder),
     torch.ops.aten.index_put_: op_base.InplaceOp(torch.ops.aten.index_put),
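The five new entries route additional in-place bitwise mutations to their functional aten counterparts, alongside the existing bitwise_or_ mapping. Illustrative only, using plain PyTorch to show the mutations these mappings cover:

import torch

a = torch.tensor([0b1100, 0b1010], dtype=torch.int32)
b = torch.tensor([0b1010, 0b0110], dtype=torch.int32)

a.bitwise_not_()           # aten.bitwise_not_  -> aten.bitwise_not
a.bitwise_and_(b)          # aten.bitwise_and_  -> aten.bitwise_and
a.bitwise_or_(b)           # aten.bitwise_or_   (already mapped)
a.bitwise_xor_(b)          # aten.bitwise_xor_  -> aten.bitwise_xor
a.bitwise_left_shift_(1)   # aten.bitwise_left_shift_  -> aten.bitwise_left_shift
a.bitwise_right_shift_(1)  # aten.bitwise_right_shift_ -> aten.bitwise_right_shift
print(a)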
torchax/ops/jtorch.py CHANGED
@@ -24,9 +24,6 @@ import jax.numpy as jnp
 import numpy as np
 import torch
 import torch.utils._pytree as pytree
-from jax.experimental.pallas.ops.tpu import flash_attention
-from jax.experimental.shard_map import shard_map
-from jax.sharding import PartitionSpec
 
 import torchax.tensor
 from torchax.ops import jaten, jimage, mappings, op_base
@@ -154,39 +151,6 @@ def _sdpa_reference(
   return attn_weight @ value
 
 
-def _tpu_flash_attention(query, key, value, env):
-  fsdp_partition = PartitionSpec("fsdp")
-
-  def wrap_flash_attention(query, key, value):
-    block_sizes = flash_attention.BlockSizes(
-        block_b=min(2, query.shape[0]),
-        block_q=min(512, query.shape[2]),
-        block_k_major=min(512, key.shape[2]),
-        block_k=min(512, key.shape[2]),
-        block_q_major_dkv=min(512, query.shape[2]),
-        block_k_major_dkv=min(512, key.shape[2]),
-        block_k_dkv=min(512, key.shape[2]),
-        block_q_dkv=min(512, query.shape[2]),
-        block_k_major_dq=min(512, key.shape[2]),
-        block_k_dq=min(256, key.shape[2]),
-        block_q_dq=min(1024, query.shape[2]),
-    )
-    return flash_attention.flash_attention(
-        query, key, value, causal=True, block_sizes=block_sizes
-    )
-
-  if env.config.shmap_flash_attention:
-    wrap_flash_attention = shard_map(
-        wrap_flash_attention,
-        mesh=env._mesh,
-        in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
-        out_specs=fsdp_partition,
-        check_rep=False,
-    )
-  # return flash_attn_mapped(query, key, value)
-  return wrap_flash_attention(query, key, value)
-
-
 @register_function(torch.nn.functional.one_hot)
 def one_hot(tensor, num_classes=-1):
   if num_classes == -1:
@@ -265,11 +229,6 @@ def scaled_dot_product_attention(
     enable_gqa=False,
     env=None,
 ) -> torch.Tensor:
-  if env.config.use_tpu_flash_attention:
-    jquery, jkey, jvalue = env.t2j_iso((query, key, value))
-    res = _tpu_flash_attention(jquery, jkey, jvalue, env)
-    return env.j2t_iso(res)
-
   return _sdpa_reference(
       query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
   )
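With the TPU flash-attention branch removed, the registered scaled_dot_product_attention override now always falls through to _sdpa_reference, which follows the same reference formula as the stock PyTorch function shown below. Illustrative only:

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 8)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 16, 8)
v = torch.randn(1, 4, 16, 8)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 4, 16, 8])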
torchax-0.0.11.dev2025128.dist-info/METADATA → torchax-0.0.11.dev20251229.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchax
-Version: 0.0.11.dev2025128
+Version: 0.0.11.dev20251229
 Summary: torchax is a library for running Jax and PyTorch together
 Project-URL: Homepage, https://github.com/google/torchax
 Author-email: Han Qi <qihan.dev@gmail.com>, Google Cloud Inference Team <cmcs-inference-eng@google.com>
torchax-0.0.11.dev2025128.dist-info/RECORD → torchax-0.0.11.dev20251229.dist-info/RECORD RENAMED
@@ -1,13 +1,13 @@
 torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
-torchax/__init__.py,sha256=KUKGHnTP6TQi9GqMlCVo6YhIRxeryRTZA9lKyLr7vb0,3890
+torchax/__init__.py,sha256=GYuxucLFbesjExbHu6xu1wjhh1up9ZdzVY8LSwv-EDU,3891
 torchax/amp.py,sha256=WTrfah2NYodapDVMsht7H3zDpl-XslujxhcYXr85g1s,10999
 torchax/checkpoint.py,sha256=2eoGeIQtL1Chof0W9qorB2Q0eCyEVJyWKqGoetf32GQ,2439
-torchax/config.py,sha256=c2JVtKx-GkkQn9vGxYgweGrm57G60mH1BB-SfE6-d6Q,1497
+torchax/config.py,sha256=oTwgWDujF9vNSHRNKUvz3ZkocDe0aDaF-yviqAJAewY,1398
 torchax/decompositions.py,sha256=8VU0FfKqbP8h3S7JzHi0iWqT5E1OrwYvuhf6cjzDTlI,29303
 torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
 torchax/export.py,sha256=CFQESWy9-ENo9Ozf08qwqUtrNpt7Y3ubX_dP99o3FdM,9395
 torchax/flax.py,sha256=Eft46Np3qPvSLmBGOltCx8KbnrGsuBeqx3Zu0tlhMpg,1807
-torchax/interop.py,sha256=_cdG_lrvRD8KFkzJe5bsHqnapW2sxSABL6K2GskV-r8,11722
+torchax/interop.py,sha256=6R9_pIkd5Kwb5EoACUsq4GEEe8PaSFOdfnD4GIy6Y9U,11815
 torchax/mesh_util.py,sha256=Y3RVKOyLVKpbseyXTYlJlUgyNavmDAsh2MQ6pXYQDUU,9719
 torchax/tensor.py,sha256=IP7JuzUfKZVrhJdA67VdXBEh4bfypwavKRzw6llsD4U,21452
 torchax/train.py,sha256=3sqIYO1Q6GN6gRGkVVoKjjUZ3xYPgh110g5XbkaUxh4,4367
@@ -15,17 +15,17 @@ torchax/types.py,sha256=NDtW1fARypg0ZHcGRVTBZKQqxwJzWtf5E5bntng91gk,981
 torchax/util.py,sha256=Oud7oaAw1SJo_v4fwEZdjuseZ_bvngAsAQ-dOEzy_20,3675
 torchax/view.py,sha256=750VYe6tmwAINNAyjN8GDvPmgaR8luvr74szwuikGts,11256
 torchax/ops/__init__.py,sha256=uc-Rod4Xlk_oJ-rdY9P8-MGTu9dsXPombscqcSys0r8,840
-torchax/ops/jaten.py,sha256=XjYPZLX9w9fi0TtiHYqvAZQkjnjuVSNcl0ge_IYjhl8,162999
+torchax/ops/jaten.py,sha256=x5vrEiVRFaTigPYvFHtR9ClwW0Z6TEHT66-mEdzfe3k,163435
 torchax/ops/jax_reimplement.py,sha256=Te8Je2ea9jX2SFV34PNPSHVP7-z_bmFykVWeqn8Tqwo,7714
 torchax/ops/jc10d.py,sha256=sO99kYDM9WRnSENHmMkH_MXWkx6HZdDvK5Z_M0C571g,1889
 torchax/ops/jimage.py,sha256=uvYW-fMaGU6-QTXAmTZ8rmHEkkwpXh2I-bu54STf-ic,3594
 torchax/ops/jlibrary.py,sha256=ESeFS5xwV1yiCIp_yKzDXihHd1wcz5eXjFkjFKsmw3w,3470
-torchax/ops/jtorch.py,sha256=WtBaDRFgw1PcXk8tW5SFOLbYNWLFULhdIQIrJv6Cy7k,17874
+torchax/ops/jtorch.py,sha256=OuFTed92B3WuWnfnjBBvMgwkl43PLJJiDWSMb0ZS0sw,16406
 torchax/ops/jtorchvision_nms.py,sha256=VNMshE3LCsIBHVVzkrNEm0kYMF89ZVIKWlQIk6pCZB0,9197
 torchax/ops/mappings.py,sha256=ViEsZaGIi37BhuLw9hx9cA9XXl35OVuaRN6Q9yZygxk,4162
 torchax/ops/op_base.py,sha256=-rQXLpkgNZ1HM3OT1XQkvAV_7Dtq019_rAXNAg97OuE,4135
 torchax/ops/ops_registry.py,sha256=sBT41LRGmUVP4ZJ9YU1DyffatOHxe-x8oXqMhCKh0y8,1836
-torchax-0.0.11.dev2025128.dist-info/METADATA,sha256=TqAfHQIUW8aGaQHO9rwwI7ODnc8d9D0DOL6fYHscUqM,22452
-torchax-0.0.11.dev2025128.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-torchax-0.0.11.dev2025128.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-torchax-0.0.11.dev2025128.dist-info/RECORD,,
+torchax-0.0.11.dev20251229.dist-info/METADATA,sha256=LhVy_90tMdwogkXLFgi2ljDuCJlMYW94whdboa_khVA,22453
+torchax-0.0.11.dev20251229.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+torchax-0.0.11.dev20251229.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+torchax-0.0.11.dev20251229.dist-info/RECORD,,