torchax 0.0.11.dev2025128__py3-none-any.whl → 0.0.11.dev20251229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/__init__.py +1 -1
- torchax/config.py +0 -4
- torchax/interop.py +5 -1
- torchax/ops/jaten.py +9 -0
- torchax/ops/jtorch.py +0 -41
- {torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/METADATA +1 -1
- {torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/RECORD +9 -9
- {torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/WHEEL +0 -0
- {torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/licenses/LICENSE +0 -0
torchax/__init__.py
CHANGED
@@ -27,7 +27,7 @@ from torchax import tensor
 
 from .checkpoint import load_checkpoint, save_checkpoint
 
-__version__ = "0.0.11.dev2025128"
+__version__ = "0.0.11.dev20251229"
 VERSION = __version__
 
 # the "fast path" uses some sparse tensor thingies that currently we
torchax/config.py
CHANGED
@@ -35,10 +35,6 @@ class Configuration:
   # Use DLPack for converting jax.Arrays <-> and torch.Tensor
   use_dlpack_for_data_conversion: bool = False
 
-  # Flash attention
-  use_tpu_flash_attention: bool = False
-  shmap_flash_attention: bool = False
-
   # device
   treat_cuda_as_jax_device: bool = True
   internal_respect_torch_return_dtypes: bool = False
torchax/interop.py
CHANGED
@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp
 import torch
 from jax import tree_util as pytree
-from jax.experimental.shard_map import shard_map
 from torch.nn.utils import stateless as torch_stateless
 
 import torchax
@@ -30,6 +29,11 @@ from torchax import tensor, util
 from torchax.ops import mappings
 from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue
 
+try:
+  from jax import shard_map as shard_map  # for jax since v0.8.0
+except ImportError:
+  from jax.experimental.shard_map import shard_map
+
 
 def extract_all_buffers(m: torch.nn.Module):
   buffers = {}
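
The new try/except import keeps interop.py working on both sides of the JAX API move where shard_map was promoted out of jax.experimental (the diff's own comment says "since v0.8.0"). Below is a minimal standalone sketch of the same pattern in use; the mesh, axis name "fsdp", and the per-shard function are illustrative assumptions, not torchax code:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
  from jax import shard_map  # newer JAX exposes it at the top level
except ImportError:
  from jax.experimental.shard_map import shard_map  # older JAX releases

# Illustrative 1-D mesh over all local devices; "fsdp" is just a label here.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

def per_shard_sum(x):
  return jnp.sum(x, keepdims=True)

sharded_sum = shard_map(
    per_shard_sum, mesh=mesh, in_specs=P("fsdp"), out_specs=P("fsdp"))

print(sharded_sum(jnp.arange(8.0)))  # one partial sum per device shard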
torchax/ops/jaten.py
CHANGED
@@ -5776,7 +5776,16 @@ mutation_ops_to_functional = {
     torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add),
     torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp(torch.ops.aten.scatter_reduce),
     torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter),
+    torch.ops.aten.bitwise_not_: op_base.InplaceOp(torch.ops.aten.bitwise_not),
+    torch.ops.aten.bitwise_and_: op_base.InplaceOp(torch.ops.aten.bitwise_and),
     torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or),
+    torch.ops.aten.bitwise_xor_: op_base.InplaceOp(torch.ops.aten.bitwise_xor),
+    torch.ops.aten.bitwise_left_shift_: op_base.InplaceOp(
+        torch.ops.aten.bitwise_left_shift
+    ),
+    torch.ops.aten.bitwise_right_shift_: op_base.InplaceOp(
+        torch.ops.aten.bitwise_right_shift
+    ),
     torch.ops.aten.floor_divide_: op_base.InplaceOp(torch.ops.aten.floor_divide),
     torch.ops.aten.remainder_: op_base.InplaceOp(torch.ops.aten.remainder),
     torch.ops.aten.index_put_: op_base.InplaceOp(torch.ops.aten.index_put),
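
These entries extend the mutation-to-functional table: each newly listed in-place bitwise aten op is mapped through op_base.InplaceOp to its out-of-place counterpart. A quick plain-PyTorch check of the equivalence this mapping relies on; this is an illustration only, with the assumption that the InplaceOp wrapper runs the functional op and writes the result back into the target tensor:

import torch

a = torch.tensor([0b1100, 0b1010], dtype=torch.int32)
b = torch.tensor([0b1010, 0b0110], dtype=torch.int32)

functional = torch.ops.aten.bitwise_xor(a, b)  # out-of-place form

inplace = a.clone()
torch.ops.aten.bitwise_xor_(inplace, b)        # in-place form mutates `inplace`

assert torch.equal(functional, inplace)        # same values either way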
torchax/ops/jtorch.py
CHANGED
@@ -24,9 +24,6 @@ import jax.numpy as jnp
 import numpy as np
 import torch
 import torch.utils._pytree as pytree
-from jax.experimental.pallas.ops.tpu import flash_attention
-from jax.experimental.shard_map import shard_map
-from jax.sharding import PartitionSpec
 
 import torchax.tensor
 from torchax.ops import jaten, jimage, mappings, op_base
@@ -154,39 +151,6 @@ def _sdpa_reference(
   return attn_weight @ value
 
 
-def _tpu_flash_attention(query, key, value, env):
-  fsdp_partition = PartitionSpec("fsdp")
-
-  def wrap_flash_attention(query, key, value):
-    block_sizes = flash_attention.BlockSizes(
-        block_b=min(2, query.shape[0]),
-        block_q=min(512, query.shape[2]),
-        block_k_major=min(512, key.shape[2]),
-        block_k=min(512, key.shape[2]),
-        block_q_major_dkv=min(512, query.shape[2]),
-        block_k_major_dkv=min(512, key.shape[2]),
-        block_k_dkv=min(512, key.shape[2]),
-        block_q_dkv=min(512, query.shape[2]),
-        block_k_major_dq=min(512, key.shape[2]),
-        block_k_dq=min(256, key.shape[2]),
-        block_q_dq=min(1024, query.shape[2]),
-    )
-    return flash_attention.flash_attention(
-        query, key, value, causal=True, block_sizes=block_sizes
-    )
-
-  if env.config.shmap_flash_attention:
-    wrap_flash_attention = shard_map(
-        wrap_flash_attention,
-        mesh=env._mesh,
-        in_specs=(fsdp_partition, fsdp_partition, fsdp_partition),
-        out_specs=fsdp_partition,
-        check_rep=False,
-    )
-  # return flash_attn_mapped(query, key, value)
-  return wrap_flash_attention(query, key, value)
-
-
 @register_function(torch.nn.functional.one_hot)
 def one_hot(tensor, num_classes=-1):
   if num_classes == -1:
@@ -265,11 +229,6 @@ def scaled_dot_product_attention(
     enable_gqa=False,
     env=None,
 ) -> torch.Tensor:
-  if env.config.use_tpu_flash_attention:
-    jquery, jkey, jvalue = env.t2j_iso((query, key, value))
-    res = _tpu_flash_attention(jquery, jkey, jvalue, env)
-    return env.j2t_iso(res)
-
   return _sdpa_reference(
       query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
   )
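
With the use_tpu_flash_attention branch removed (and the matching flags dropped from config.py above), scaled_dot_product_attention now always routes through _sdpa_reference. A generic sketch of the reference math for the default case, checked against PyTorch's built-in; mask, dropout, causal, and GQA handling in torchax's actual _sdpa_reference may differ:

import math
import torch
import torch.nn.functional as F

def sdpa_reference(query, key, value, scale=None):
  # softmax(Q K^T / sqrt(d)) V, the plain reference formulation
  scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
  attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
  return attn_weight @ value

q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

out = sdpa_reference(q, k, v)
expected = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, expected, atol=1e-5)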
{torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchax
-Version: 0.0.11.dev2025128
+Version: 0.0.11.dev20251229
 Summary: torchax is a library for running Jax and PyTorch together
 Project-URL: Homepage, https://github.com/google/torchax
 Author-email: Han Qi <qihan.dev@gmail.com>, Google Cloud Inference Team <cmcs-inference-eng@google.com>
{torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/RECORD
CHANGED

@@ -1,13 +1,13 @@
 torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
-torchax/__init__.py,sha256=
+torchax/__init__.py,sha256=GYuxucLFbesjExbHu6xu1wjhh1up9ZdzVY8LSwv-EDU,3891
 torchax/amp.py,sha256=WTrfah2NYodapDVMsht7H3zDpl-XslujxhcYXr85g1s,10999
 torchax/checkpoint.py,sha256=2eoGeIQtL1Chof0W9qorB2Q0eCyEVJyWKqGoetf32GQ,2439
-torchax/config.py,sha256=
+torchax/config.py,sha256=oTwgWDujF9vNSHRNKUvz3ZkocDe0aDaF-yviqAJAewY,1398
 torchax/decompositions.py,sha256=8VU0FfKqbP8h3S7JzHi0iWqT5E1OrwYvuhf6cjzDTlI,29303
 torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
 torchax/export.py,sha256=CFQESWy9-ENo9Ozf08qwqUtrNpt7Y3ubX_dP99o3FdM,9395
 torchax/flax.py,sha256=Eft46Np3qPvSLmBGOltCx8KbnrGsuBeqx3Zu0tlhMpg,1807
-torchax/interop.py,sha256=
+torchax/interop.py,sha256=6R9_pIkd5Kwb5EoACUsq4GEEe8PaSFOdfnD4GIy6Y9U,11815
 torchax/mesh_util.py,sha256=Y3RVKOyLVKpbseyXTYlJlUgyNavmDAsh2MQ6pXYQDUU,9719
 torchax/tensor.py,sha256=IP7JuzUfKZVrhJdA67VdXBEh4bfypwavKRzw6llsD4U,21452
 torchax/train.py,sha256=3sqIYO1Q6GN6gRGkVVoKjjUZ3xYPgh110g5XbkaUxh4,4367
@@ -15,17 +15,17 @@ torchax/types.py,sha256=NDtW1fARypg0ZHcGRVTBZKQqxwJzWtf5E5bntng91gk,981
 torchax/util.py,sha256=Oud7oaAw1SJo_v4fwEZdjuseZ_bvngAsAQ-dOEzy_20,3675
 torchax/view.py,sha256=750VYe6tmwAINNAyjN8GDvPmgaR8luvr74szwuikGts,11256
 torchax/ops/__init__.py,sha256=uc-Rod4Xlk_oJ-rdY9P8-MGTu9dsXPombscqcSys0r8,840
-torchax/ops/jaten.py,sha256=
+torchax/ops/jaten.py,sha256=x5vrEiVRFaTigPYvFHtR9ClwW0Z6TEHT66-mEdzfe3k,163435
 torchax/ops/jax_reimplement.py,sha256=Te8Je2ea9jX2SFV34PNPSHVP7-z_bmFykVWeqn8Tqwo,7714
 torchax/ops/jc10d.py,sha256=sO99kYDM9WRnSENHmMkH_MXWkx6HZdDvK5Z_M0C571g,1889
 torchax/ops/jimage.py,sha256=uvYW-fMaGU6-QTXAmTZ8rmHEkkwpXh2I-bu54STf-ic,3594
 torchax/ops/jlibrary.py,sha256=ESeFS5xwV1yiCIp_yKzDXihHd1wcz5eXjFkjFKsmw3w,3470
-torchax/ops/jtorch.py,sha256=
+torchax/ops/jtorch.py,sha256=OuFTed92B3WuWnfnjBBvMgwkl43PLJJiDWSMb0ZS0sw,16406
 torchax/ops/jtorchvision_nms.py,sha256=VNMshE3LCsIBHVVzkrNEm0kYMF89ZVIKWlQIk6pCZB0,9197
 torchax/ops/mappings.py,sha256=ViEsZaGIi37BhuLw9hx9cA9XXl35OVuaRN6Q9yZygxk,4162
 torchax/ops/op_base.py,sha256=-rQXLpkgNZ1HM3OT1XQkvAV_7Dtq019_rAXNAg97OuE,4135
 torchax/ops/ops_registry.py,sha256=sBT41LRGmUVP4ZJ9YU1DyffatOHxe-x8oXqMhCKh0y8,1836
-torchax-0.0.11.
-torchax-0.0.11.
-torchax-0.0.11.
-torchax-0.0.11.
+torchax-0.0.11.dev20251229.dist-info/METADATA,sha256=LhVy_90tMdwogkXLFgi2ljDuCJlMYW94whdboa_khVA,22453
+torchax-0.0.11.dev20251229.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+torchax-0.0.11.dev20251229.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+torchax-0.0.11.dev20251229.dist-info/RECORD,,
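
Each RECORD row has the form path,sha256=<urlsafe-base64 digest, unpadded>,<size in bytes>. A small sketch for recomputing one entry from an installed file, to compare with the values above; the file path is illustrative:

import base64
import hashlib
import os

def record_entry(path):
  # sha256 digest, urlsafe-base64 encoded with the trailing '=' padding stripped
  digest = hashlib.sha256(open(path, "rb").read()).digest()
  b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
  return f"{path},sha256={b64},{os.path.getsize(path)}"

print(record_entry("torchax/config.py"))  # compare with the RECORD line above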
{torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/WHEEL
RENAMED
File without changes

{torchax-0.0.11.dev2025128.dist-info → torchax-0.0.11.dev20251229.dist-info}/licenses/LICENSE
RENAMED
File without changes