torchax-0.0.11.dev202612-py3-none-any.whl → torchax-0.0.11.dev2025128-py3-none-any.whl


torchax/__init__.py CHANGED
@@ -27,7 +27,7 @@ from torchax import tensor

from .checkpoint import load_checkpoint, save_checkpoint

- __version__ = "0.0.11.dev202612"
+ __version__ = "0.0.11.dev2025128"
VERSION = __version__

# the "fast path" uses some sparse tensor thingies that currently we
torchax/config.py CHANGED
@@ -35,6 +35,10 @@ class Configuration:
  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
  use_dlpack_for_data_conversion: bool = False

+   # Flash attention
+   use_tpu_flash_attention: bool = False
+   shmap_flash_attention: bool = False
+
  # device
  treat_cuda_as_jax_device: bool = True
  internal_respect_torch_return_dtypes: bool = False
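
The two new Configuration fields gate the TPU flash-attention path added in torchax/ops/jtorch.py below. A minimal sketch of flipping them on, assuming the usual torchax.default_env() entry point exposes this Configuration object as env.config (as it does for existing flags such as use_dlpack_for_data_conversion):

import torchax

env = torchax.default_env()
# Route scaled_dot_product_attention to the Pallas TPU kernel.
env.config.use_tpu_flash_attention = True
# Optionally also wrap the kernel in shard_map across the device mesh.
env.config.shmap_flash_attention = True
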
torchax/interop.py CHANGED
@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp
import torch
from jax import tree_util as pytree
+ from jax.experimental.shard_map import shard_map
from torch.nn.utils import stateless as torch_stateless

import torchax
@@ -29,11 +30,6 @@ from torchax import tensor, util
from torchax.ops import mappings
from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue

- try:
-   from jax import shard_map as shard_map  # for jax since v0.8.0
- except ImportError:
-   from jax.experimental.shard_map import shard_map
-

def extract_all_buffers(m: torch.nn.Module):
  buffers = {}
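
This hunk reverts interop.py to importing shard_map from jax.experimental.shard_map unconditionally, dropping the fallback that preferred the top-level jax.shard_map introduced in jax v0.8.0. Code that must straddle both jax generations can keep the removed shim, reproduced here for reference:

try:
  from jax import shard_map  # top-level export since jax v0.8.0
except ImportError:
  from jax.experimental.shard_map import shard_map  # older jax releases
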
torchax/ops/jaten.py CHANGED
@@ -5776,16 +5776,7 @@ mutation_ops_to_functional = {
    torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add),
    torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp(torch.ops.aten.scatter_reduce),
    torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter),
-     torch.ops.aten.bitwise_not_: op_base.InplaceOp(torch.ops.aten.bitwise_not),
-     torch.ops.aten.bitwise_and_: op_base.InplaceOp(torch.ops.aten.bitwise_and),
    torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or),
-     torch.ops.aten.bitwise_xor_: op_base.InplaceOp(torch.ops.aten.bitwise_xor),
-     torch.ops.aten.bitwise_left_shift_: op_base.InplaceOp(
-         torch.ops.aten.bitwise_left_shift
-     ),
-     torch.ops.aten.bitwise_right_shift_: op_base.InplaceOp(
-         torch.ops.aten.bitwise_right_shift
-     ),
    torch.ops.aten.floor_divide_: op_base.InplaceOp(torch.ops.aten.floor_divide),
    torch.ops.aten.remainder_: op_base.InplaceOp(torch.ops.aten.remainder),
    torch.ops.aten.index_put_: op_base.InplaceOp(torch.ops.aten.index_put),
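
mutation_ops_to_functional maps each in-place aten op (trailing underscore) to op_base.InplaceOp wrapping its functional counterpart; this release drops the bitwise entries except bitwise_or_. For intuition, a hypothetical, simplified version of such a wrapper (torchax's real op_base.InplaceOp differs in detail):

class InplaceOp:
  # Hypothetical sketch: run the functional op, copy the result back.

  def __init__(self, functional_op):
    self.functional_op = functional_op

  def __call__(self, target, *args, **kwargs):
    # Compute the pure result, then emulate in-place semantics by
    # overwriting the original tensor's contents.
    result = self.functional_op(target, *args, **kwargs)
    target.copy_(result)
    return target
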
torchax/ops/jtorch.py CHANGED
@@ -24,6 +24,9 @@ import jax.numpy as jnp
import numpy as np
import torch
import torch.utils._pytree as pytree
+ from jax.experimental.pallas.ops.tpu import flash_attention
+ from jax.experimental.shard_map import shard_map
+ from jax.sharding import PartitionSpec

import torchax.tensor
from torchax.ops import jaten, jimage, mappings, op_base
@@ -151,6 +154,39 @@ def _sdpa_reference(
  return attn_weight @ value


+ def _tpu_flash_attention(query, key, value, env):
+   fsdp_partition = PartitionSpec("fsdp")
+
+   def wrap_flash_attention(query, key, value):
+     block_sizes = flash_attention.BlockSizes(
+         block_b=min(2, query.shape[0]),
+         block_q=min(512, query.shape[2]),
+         block_k_major=min(512, key.shape[2]),
+         block_k=min(512, key.shape[2]),
+         block_q_major_dkv=min(512, query.shape[2]),
+         block_k_major_dkv=min(512, key.shape[2]),
+         block_k_dkv=min(512, key.shape[2]),
+         block_q_dkv=min(512, query.shape[2]),
+         block_k_major_dq=min(512, key.shape[2]),
+         block_k_dq=min(256, key.shape[2]),
+         block_q_dq=min(1024, query.shape[2]),
+     )
+     return flash_attention.flash_attention(
+         query, key, value, causal=True, block_sizes=block_sizes
+     )
+
+   if env.config.shmap_flash_attention:
+     wrap_flash_attention = shard_map(
+         wrap_flash_attention,
+         mesh=env._mesh,
+         in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+         out_specs=fsdp_partition,
+         check_rep=False,
+     )
+   # return flash_attn_mapped(query, key, value)
+   return wrap_flash_attention(query, key, value)
+
+
@register_function(torch.nn.functional.one_hot)
def one_hot(tensor, num_classes=-1):
  if num_classes == -1:
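
_tpu_flash_attention assumes the PyTorch SDPA layout of (batch, num_heads, seq_len, head_dim): the min(...) clamps read the batch size from shape[0] and the sequence length from shape[2], shrinking each Pallas block so the kernel also works on inputs smaller than the default tile sizes. When shmap_flash_attention is set, the kernel is additionally shard_map-ped over an "fsdp" axis of env._mesh. A standalone sketch of the underlying kernel call (illustrative shapes; requires a TPU backend, and block_sizes may be omitted to use the kernel's defaults):

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu import flash_attention

# (batch, heads, seq_len, head_dim) -- the layout the clamps above assume.
q = jnp.ones((2, 8, 1024, 128), dtype=jnp.bfloat16)
k = jnp.ones((2, 8, 1024, 128), dtype=jnp.bfloat16)
v = jnp.ones((2, 8, 1024, 128), dtype=jnp.bfloat16)
out = flash_attention.flash_attention(q, k, v, causal=True)
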
@@ -229,6 +265,11 @@ def scaled_dot_product_attention(
    enable_gqa=False,
    env=None,
) -> torch.Tensor:
+   if env.config.use_tpu_flash_attention:
+     jquery, jkey, jvalue = env.t2j_iso((query, key, value))
+     res = _tpu_flash_attention(jquery, jkey, jvalue, env)
+     return env.j2t_iso(res)
+
  return _sdpa_reference(
      query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
  )
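
With the flag set, scaled_dot_product_attention converts its torch inputs to jax.Arrays (t2j_iso), runs the flash kernel, and converts the result back (j2t_iso). Note that this path is unconditionally causal and ignores attn_mask, dropout_p, scale, and enable_gqa; _sdpa_reference is used only when the flag is off. An end-to-end sketch, assuming torchax's 'jax' device registration and the config access shown earlier (shapes are illustrative; TPU required):

import torch
import torch.nn.functional as F
import torchax

env = torchax.default_env()
env.config.use_tpu_flash_attention = True

with env:
  q = torch.randn(2, 8, 1024, 128).to("jax")
  k = torch.randn(2, 8, 1024, 128).to("jax")
  v = torch.randn(2, 8, 1024, 128).to("jax")
  # Dispatches to _tpu_flash_attention instead of _sdpa_reference.
  out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
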
{torchax-0.0.11.dev202612.dist-info → torchax-0.0.11.dev2025128.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
Metadata-Version: 2.4
Name: torchax
- Version: 0.0.11.dev202612
+ Version: 0.0.11.dev2025128
Summary: torchax is a library for running Jax and PyTorch together
Project-URL: Homepage, https://github.com/google/torchax
Author-email: Han Qi <qihan.dev@gmail.com>, Google Cloud Inference Team <cmcs-inference-eng@google.com>
{torchax-0.0.11.dev202612.dist-info → torchax-0.0.11.dev2025128.dist-info}/RECORD RENAMED
@@ -1,13 +1,13 @@
torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
- torchax/__init__.py,sha256=m9mLdO6-n-WfpeR-R7jNQxkESUUpXQJ5A-0CNnXuSJs,3889
+ torchax/__init__.py,sha256=KUKGHnTP6TQi9GqMlCVo6YhIRxeryRTZA9lKyLr7vb0,3890
torchax/amp.py,sha256=WTrfah2NYodapDVMsht7H3zDpl-XslujxhcYXr85g1s,10999
torchax/checkpoint.py,sha256=2eoGeIQtL1Chof0W9qorB2Q0eCyEVJyWKqGoetf32GQ,2439
- torchax/config.py,sha256=oTwgWDujF9vNSHRNKUvz3ZkocDe0aDaF-yviqAJAewY,1398
+ torchax/config.py,sha256=c2JVtKx-GkkQn9vGxYgweGrm57G60mH1BB-SfE6-d6Q,1497
torchax/decompositions.py,sha256=8VU0FfKqbP8h3S7JzHi0iWqT5E1OrwYvuhf6cjzDTlI,29303
torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
torchax/export.py,sha256=CFQESWy9-ENo9Ozf08qwqUtrNpt7Y3ubX_dP99o3FdM,9395
torchax/flax.py,sha256=Eft46Np3qPvSLmBGOltCx8KbnrGsuBeqx3Zu0tlhMpg,1807
- torchax/interop.py,sha256=6R9_pIkd5Kwb5EoACUsq4GEEe8PaSFOdfnD4GIy6Y9U,11815
+ torchax/interop.py,sha256=_cdG_lrvRD8KFkzJe5bsHqnapW2sxSABL6K2GskV-r8,11722
torchax/mesh_util.py,sha256=Y3RVKOyLVKpbseyXTYlJlUgyNavmDAsh2MQ6pXYQDUU,9719
torchax/tensor.py,sha256=IP7JuzUfKZVrhJdA67VdXBEh4bfypwavKRzw6llsD4U,21452
torchax/train.py,sha256=3sqIYO1Q6GN6gRGkVVoKjjUZ3xYPgh110g5XbkaUxh4,4367
@@ -15,17 +15,17 @@ torchax/types.py,sha256=NDtW1fARypg0ZHcGRVTBZKQqxwJzWtf5E5bntng91gk,981
torchax/util.py,sha256=Oud7oaAw1SJo_v4fwEZdjuseZ_bvngAsAQ-dOEzy_20,3675
torchax/view.py,sha256=750VYe6tmwAINNAyjN8GDvPmgaR8luvr74szwuikGts,11256
torchax/ops/__init__.py,sha256=uc-Rod4Xlk_oJ-rdY9P8-MGTu9dsXPombscqcSys0r8,840
- torchax/ops/jaten.py,sha256=x5vrEiVRFaTigPYvFHtR9ClwW0Z6TEHT66-mEdzfe3k,163435
+ torchax/ops/jaten.py,sha256=XjYPZLX9w9fi0TtiHYqvAZQkjnjuVSNcl0ge_IYjhl8,162999
torchax/ops/jax_reimplement.py,sha256=Te8Je2ea9jX2SFV34PNPSHVP7-z_bmFykVWeqn8Tqwo,7714
torchax/ops/jc10d.py,sha256=sO99kYDM9WRnSENHmMkH_MXWkx6HZdDvK5Z_M0C571g,1889
torchax/ops/jimage.py,sha256=uvYW-fMaGU6-QTXAmTZ8rmHEkkwpXh2I-bu54STf-ic,3594
torchax/ops/jlibrary.py,sha256=ESeFS5xwV1yiCIp_yKzDXihHd1wcz5eXjFkjFKsmw3w,3470
- torchax/ops/jtorch.py,sha256=OuFTed92B3WuWnfnjBBvMgwkl43PLJJiDWSMb0ZS0sw,16406
+ torchax/ops/jtorch.py,sha256=WtBaDRFgw1PcXk8tW5SFOLbYNWLFULhdIQIrJv6Cy7k,17874
torchax/ops/jtorchvision_nms.py,sha256=VNMshE3LCsIBHVVzkrNEm0kYMF89ZVIKWlQIk6pCZB0,9197
torchax/ops/mappings.py,sha256=ViEsZaGIi37BhuLw9hx9cA9XXl35OVuaRN6Q9yZygxk,4162
torchax/ops/op_base.py,sha256=-rQXLpkgNZ1HM3OT1XQkvAV_7Dtq019_rAXNAg97OuE,4135
torchax/ops/ops_registry.py,sha256=sBT41LRGmUVP4ZJ9YU1DyffatOHxe-x8oXqMhCKh0y8,1836
- torchax-0.0.11.dev202612.dist-info/METADATA,sha256=tVMhNPH26W_cwlb9oA8xPNByKjM25aY7OI4SGnGFTg4,22451
- torchax-0.0.11.dev202612.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- torchax-0.0.11.dev202612.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- torchax-0.0.11.dev202612.dist-info/RECORD,,
+ torchax-0.0.11.dev2025128.dist-info/METADATA,sha256=TqAfHQIUW8aGaQHO9rwwI7ODnc8d9D0DOL6fYHscUqM,22452
+ torchax-0.0.11.dev2025128.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ torchax-0.0.11.dev2025128.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ torchax-0.0.11.dev2025128.dist-info/RECORD,,