torchax 0.0.11.dev202617__py3-none-any.whl → 0.0.11.dev2025128__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/__init__.py +1 -1
- torchax/config.py +4 -0
- torchax/interop.py +1 -5
- torchax/ops/jaten.py +0 -9
- torchax/ops/jtorch.py +41 -0
- {torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/METADATA +1 -1
- {torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/RECORD +9 -9
- {torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/WHEEL +0 -0
- {torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/licenses/LICENSE +0 -0
torchax/__init__.py
CHANGED
@@ -27,7 +27,7 @@ from torchax import tensor
 
 from .checkpoint import load_checkpoint, save_checkpoint
 
-__version__ = "0.0.11.dev202617"
+__version__ = "0.0.11.dev2025128"
 VERSION = __version__
 
 # the "fast path" uses some sparse tensor thingies that currently we
torchax/config.py
CHANGED
@@ -35,6 +35,10 @@ class Configuration:
   # Use DLPack for converting jax.Arrays <-> and torch.Tensor
   use_dlpack_for_data_conversion: bool = False
 
+  # Flash attention
+  use_tpu_flash_attention: bool = False
+  shmap_flash_attention: bool = False
+
   # device
   treat_cuda_as_jax_device: bool = True
   internal_respect_torch_return_dtypes: bool = False
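Usage note (not part of the package itself): a minimal sketch of how these new flags could be toggled at runtime, assuming the Configuration instance is reachable through torchax.default_env().config, the same object the new jtorch code reads as env.config.

import torchax

# Hypothetical usage sketch: enable the new flash-attention flags on the
# active torchax environment before running attention under torchax.
env = torchax.default_env()
env.config.use_tpu_flash_attention = True   # route SDPA through the Pallas TPU kernel
env.config.shmap_flash_attention = True     # additionally wrap the kernel in shard_map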
torchax/interop.py
CHANGED
@@ -22,6 +22,7 @@ import jax
 import jax.numpy as jnp
 import torch
 from jax import tree_util as pytree
+from jax.experimental.shard_map import shard_map
 from torch.nn.utils import stateless as torch_stateless
 
 import torchax
@@ -29,11 +30,6 @@ from torchax import tensor, util
 from torchax.ops import mappings
 from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue
 
-try:
-  from jax import shard_map as shard_map  # for jax since v0.8.0
-except ImportError:
-  from jax.experimental.shard_map import shard_map
-
 
 def extract_all_buffers(m: torch.nn.Module):
   buffers = {}
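Only the import path changes here; shard_map itself behaves the same. For readers unfamiliar with it, a minimal, self-contained illustration in plain JAX (not torchax code) of what shard_map does over a 1-D "fsdp" mesh, mirroring how interop and jtorch use it:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Build a 1-D mesh over all available devices and name the axis "fsdp".
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
x = jnp.arange(4 * jax.device_count(), dtype=jnp.float32)

def per_shard_double(x_shard):
  # Each device sees only its slice of `x` along the "fsdp" axis.
  return x_shard * 2

doubled = shard_map(
    per_shard_double, mesh=mesh, in_specs=P("fsdp"), out_specs=P("fsdp")
)(x)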
torchax/ops/jaten.py
CHANGED
@@ -5776,16 +5776,7 @@ mutation_ops_to_functional = {
     torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add),
     torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp(torch.ops.aten.scatter_reduce),
     torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter),
-    torch.ops.aten.bitwise_not_: op_base.InplaceOp(torch.ops.aten.bitwise_not),
-    torch.ops.aten.bitwise_and_: op_base.InplaceOp(torch.ops.aten.bitwise_and),
     torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or),
-    torch.ops.aten.bitwise_xor_: op_base.InplaceOp(torch.ops.aten.bitwise_xor),
-    torch.ops.aten.bitwise_left_shift_: op_base.InplaceOp(
-        torch.ops.aten.bitwise_left_shift
-    ),
-    torch.ops.aten.bitwise_right_shift_: op_base.InplaceOp(
-        torch.ops.aten.bitwise_right_shift
-    ),
     torch.ops.aten.floor_divide_: op_base.InplaceOp(torch.ops.aten.floor_divide),
     torch.ops.aten.remainder_: op_base.InplaceOp(torch.ops.aten.remainder),
     torch.ops.aten.index_put_: op_base.InplaceOp(torch.ops.aten.index_put),
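The removed entries map in-place bitwise ATen ops onto their functional counterparts. For context, a sketch of the general "functional op plus copy-back" pattern such a mapping relies on; this is an illustration only, not torchax's actual op_base.InplaceOp implementation.

import torch

def inplace_via_functional(functional_op):
  # Illustrative wrapper: compute the functional result, then write it back
  # into the first argument so the call behaves like an in-place op.
  def wrapper(self_tensor, *args, **kwargs):
    result = functional_op(self_tensor, *args, **kwargs)
    self_tensor.copy_(result)
    return self_tensor
  return wrapper

# Example: emulate aten.bitwise_or_ with the functional bitwise_or.
bitwise_or_ = inplace_via_functional(torch.ops.aten.bitwise_or)
t = torch.tensor([1, 2, 4])
bitwise_or_(t, torch.tensor([1, 1, 1]))  # t becomes [1, 3, 5]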
torchax/ops/jtorch.py
CHANGED
@@ -24,6 +24,9 @@ import jax.numpy as jnp
 import numpy as np
 import torch
 import torch.utils._pytree as pytree
+from jax.experimental.pallas.ops.tpu import flash_attention
+from jax.experimental.shard_map import shard_map
+from jax.sharding import PartitionSpec
 
 import torchax.tensor
 from torchax.ops import jaten, jimage, mappings, op_base
@@ -151,6 +154,39 @@ def _sdpa_reference(
   return attn_weight @ value
 
 
+def _tpu_flash_attention(query, key, value, env):
+  fsdp_partition = PartitionSpec("fsdp")
+
+  def wrap_flash_attention(query, key, value):
+    block_sizes = flash_attention.BlockSizes(
+        block_b=min(2, query.shape[0]),
+        block_q=min(512, query.shape[2]),
+        block_k_major=min(512, key.shape[2]),
+        block_k=min(512, key.shape[2]),
+        block_q_major_dkv=min(512, query.shape[2]),
+        block_k_major_dkv=min(512, key.shape[2]),
+        block_k_dkv=min(512, key.shape[2]),
+        block_q_dkv=min(512, query.shape[2]),
+        block_k_major_dq=min(512, key.shape[2]),
+        block_k_dq=min(256, key.shape[2]),
+        block_q_dq=min(1024, query.shape[2]),
+    )
+    return flash_attention.flash_attention(
+        query, key, value, causal=True, block_sizes=block_sizes
+    )
+
+  if env.config.shmap_flash_attention:
+    wrap_flash_attention = shard_map(
+        wrap_flash_attention,
+        mesh=env._mesh,
+        in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
+        out_specs=fsdp_partition,
+        check_rep=False,
+    )
+  # return flash_attn_mapped(query, key, value)
+  return wrap_flash_attention(query, key, value)
+
+
 @register_function(torch.nn.functional.one_hot)
 def one_hot(tensor, num_classes=-1):
   if num_classes == -1:
@@ -229,6 +265,11 @@ def scaled_dot_product_attention(
     enable_gqa=False,
     env=None,
 ) -> torch.Tensor:
+  if env.config.use_tpu_flash_attention:
+    jquery, jkey, jvalue = env.t2j_iso((query, key, value))
+    res = _tpu_flash_attention(jquery, jkey, jvalue, env)
+    return env.j2t_iso(res)
+
   return _sdpa_reference(
       query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
   )
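End-to-end, the new branch routes torch.nn.functional.scaled_dot_product_attention through the Pallas TPU flash-attention kernel whenever the flag is set. A hypothetical usage sketch, assuming a TPU runtime and the torchax environment/device conventions from the project README; the shapes and the `with env:` / device="jax" pattern are assumptions, not documented behavior of this release.

import torch
import torch.nn.functional as F
import torchax

env = torchax.default_env()
env.config.use_tpu_flash_attention = True
# env.config.shmap_flash_attention = True  # only with an env._mesh that defines an "fsdp" axis

with env:
  # (batch, heads, seq_len, head_dim) tensors living on the torchax "jax" device.
  q = torch.randn(2, 8, 512, 128, device="jax")
  k = torch.randn(2, 8, 512, 128, device="jax")
  v = torch.randn(2, 8, 512, 128, device="jax")
  out = F.scaled_dot_product_attention(q, k, v, is_causal=True)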
{torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchax
-Version: 0.0.11.dev202617
+Version: 0.0.11.dev2025128
 Summary: torchax is a library for running Jax and PyTorch together
 Project-URL: Homepage, https://github.com/google/torchax
 Author-email: Han Qi <qihan.dev@gmail.com>, Google Cloud Inference Team <cmcs-inference-eng@google.com>
{torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/RECORD
CHANGED

@@ -1,13 +1,13 @@
 torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
-torchax/__init__.py,sha256=
+torchax/__init__.py,sha256=KUKGHnTP6TQi9GqMlCVo6YhIRxeryRTZA9lKyLr7vb0,3890
 torchax/amp.py,sha256=WTrfah2NYodapDVMsht7H3zDpl-XslujxhcYXr85g1s,10999
 torchax/checkpoint.py,sha256=2eoGeIQtL1Chof0W9qorB2Q0eCyEVJyWKqGoetf32GQ,2439
-torchax/config.py,sha256=
+torchax/config.py,sha256=c2JVtKx-GkkQn9vGxYgweGrm57G60mH1BB-SfE6-d6Q,1497
 torchax/decompositions.py,sha256=8VU0FfKqbP8h3S7JzHi0iWqT5E1OrwYvuhf6cjzDTlI,29303
 torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
 torchax/export.py,sha256=CFQESWy9-ENo9Ozf08qwqUtrNpt7Y3ubX_dP99o3FdM,9395
 torchax/flax.py,sha256=Eft46Np3qPvSLmBGOltCx8KbnrGsuBeqx3Zu0tlhMpg,1807
-torchax/interop.py,sha256=
+torchax/interop.py,sha256=_cdG_lrvRD8KFkzJe5bsHqnapW2sxSABL6K2GskV-r8,11722
 torchax/mesh_util.py,sha256=Y3RVKOyLVKpbseyXTYlJlUgyNavmDAsh2MQ6pXYQDUU,9719
 torchax/tensor.py,sha256=IP7JuzUfKZVrhJdA67VdXBEh4bfypwavKRzw6llsD4U,21452
 torchax/train.py,sha256=3sqIYO1Q6GN6gRGkVVoKjjUZ3xYPgh110g5XbkaUxh4,4367
@@ -15,17 +15,17 @@ torchax/types.py,sha256=NDtW1fARypg0ZHcGRVTBZKQqxwJzWtf5E5bntng91gk,981
 torchax/util.py,sha256=Oud7oaAw1SJo_v4fwEZdjuseZ_bvngAsAQ-dOEzy_20,3675
 torchax/view.py,sha256=750VYe6tmwAINNAyjN8GDvPmgaR8luvr74szwuikGts,11256
 torchax/ops/__init__.py,sha256=uc-Rod4Xlk_oJ-rdY9P8-MGTu9dsXPombscqcSys0r8,840
-torchax/ops/jaten.py,sha256=
+torchax/ops/jaten.py,sha256=XjYPZLX9w9fi0TtiHYqvAZQkjnjuVSNcl0ge_IYjhl8,162999
 torchax/ops/jax_reimplement.py,sha256=Te8Je2ea9jX2SFV34PNPSHVP7-z_bmFykVWeqn8Tqwo,7714
 torchax/ops/jc10d.py,sha256=sO99kYDM9WRnSENHmMkH_MXWkx6HZdDvK5Z_M0C571g,1889
 torchax/ops/jimage.py,sha256=uvYW-fMaGU6-QTXAmTZ8rmHEkkwpXh2I-bu54STf-ic,3594
 torchax/ops/jlibrary.py,sha256=ESeFS5xwV1yiCIp_yKzDXihHd1wcz5eXjFkjFKsmw3w,3470
-torchax/ops/jtorch.py,sha256=
+torchax/ops/jtorch.py,sha256=WtBaDRFgw1PcXk8tW5SFOLbYNWLFULhdIQIrJv6Cy7k,17874
 torchax/ops/jtorchvision_nms.py,sha256=VNMshE3LCsIBHVVzkrNEm0kYMF89ZVIKWlQIk6pCZB0,9197
 torchax/ops/mappings.py,sha256=ViEsZaGIi37BhuLw9hx9cA9XXl35OVuaRN6Q9yZygxk,4162
 torchax/ops/op_base.py,sha256=-rQXLpkgNZ1HM3OT1XQkvAV_7Dtq019_rAXNAg97OuE,4135
 torchax/ops/ops_registry.py,sha256=sBT41LRGmUVP4ZJ9YU1DyffatOHxe-x8oXqMhCKh0y8,1836
-torchax-0.0.11.
-torchax-0.0.11.
-torchax-0.0.11.
-torchax-0.0.11.
+torchax-0.0.11.dev2025128.dist-info/METADATA,sha256=TqAfHQIUW8aGaQHO9rwwI7ODnc8d9D0DOL6fYHscUqM,22452
+torchax-0.0.11.dev2025128.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+torchax-0.0.11.dev2025128.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+torchax-0.0.11.dev2025128.dist-info/RECORD,,
{torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/WHEEL
File without changes

{torchax-0.0.11.dev202617.dist-info → torchax-0.0.11.dev2025128.dist-info}/licenses/LICENSE
File without changes