torchax 0.0.5.tar.gz → 0.0.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- {torchax-0.0.5 → torchax-0.0.6}/PKG-INFO +1 -1
- torchax-0.0.6/docs/api_iterations.md +22 -0
- torchax-0.0.6/test/BUILD +31 -0
- torchax-0.0.6/test/test_base.py +55 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_context.py +16 -20
- {torchax-0.0.5 → torchax-0.0.6}/test/test_core_aten_ops.py +0 -5
- {torchax-0.0.5 → torchax-0.0.6}/test/test_flax.py +1 -1
- {torchax-0.0.5 → torchax-0.0.6}/test/test_functions.py +9 -2
- {torchax-0.0.5 → torchax-0.0.6}/test/test_interop.py +14 -18
- {torchax-0.0.5 → torchax-0.0.6}/test/test_jittable_module.py +19 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_libraries.py +0 -1
- torchax-0.0.6/test/test_misc.py +22 -0
- torchax-0.0.6/test/test_mutations.py +52 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_ops.py +3 -6
- {torchax-0.0.5 → torchax-0.0.6}/test/test_unbounded_dynamism.py +0 -5
- {torchax-0.0.5 → torchax-0.0.6}/torchax/__init__.py +5 -41
- {torchax-0.0.5 → torchax-0.0.6}/torchax/amp.py +2 -3
- {torchax-0.0.5 → torchax-0.0.6}/torchax/config.py +5 -1
- torchax-0.0.6/torchax/configuration.py +30 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/device_module.py +7 -0
- torchax-0.0.6/torchax/environment.py +1 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/interop.py +27 -14
- {torchax-0.0.5 → torchax-0.0.6}/torchax/mesh_util.py +10 -1
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jaten.py +5 -3
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jtorch.py +18 -10
- {torchax-0.0.5 → torchax-0.0.6}/torchax/tensor.py +127 -115
- torchax-0.0.5/examples/mnist_tpu.ipynb +0 -647
- torchax-0.0.5/examples/train_gpt/train_ddp.py +0 -140
- torchax-0.0.5/test/test_mutations.py +0 -36
- torchax-0.0.5/test_dist/README.md +0 -4
- torchax-0.0.5/test_dist/__init__.py +0 -0
- torchax-0.0.5/test_dist/test_distributed.py +0 -154
- torchax-0.0.5/torchax/distributed.py +0 -241
- {torchax-0.0.5 → torchax-0.0.6}/.gitignore +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/=2.3.0 +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/LICENSE +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/README.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/build_nightly.sh +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/dev-requirements.txt +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/dispatch.png +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/fixing_op_info_test.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/how_it_works.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/ops_registry.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/support_a_new_model.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/torch_dispatch/README.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/torch_dispatch/example.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/torch_dispatch/run_env.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/torch_xla2_dynamo.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/understand_jax_jit/jax_grad.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/understand_jax_jit/jax_jit.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/docs/understand_jax_jit/torch_module.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/README.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/_diffusion.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/_grad_of_attention.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/basic_training.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/basic_training_jax.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/eager_mode.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/lightning_training.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/requirements.txt +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/torchbench_models/BERT_pytorch.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_gpt/requirements.txt +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/README.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/model.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/train_llama_lightning.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/utils.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/Dockerfile +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/README.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/helper.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/splash_attn.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/train_llama.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/format.sh +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/pyproject.toml +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/repro1.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/temp +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/base_test_util.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/gemma/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/gemma/config.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/gemma/model.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/gemma/test_gemma.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/gemma/tokenizer.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/llama/BUILD +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/llama/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/llama/llama_model.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/llama/model_exportable.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/llama/test_llama.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/moe/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/moe/model.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/moe/moe_test.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_amp.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_conv.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_exports.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_image.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_symbolic_shapes.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_tf_integration.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_train.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_util.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test/test_view.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test-requirements.txt +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/test_dist/test_mesh_util.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/CONTRIBUTING.md +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/decompositions.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/export.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/flax.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/__init__.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jax_reimplement.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jc10d.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jimage.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jlibrary.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jtorchvision_nms.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/mappings.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/op_base.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/ops_registry.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/tf_integration.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/train.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/types.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/util.py +0 -0
- {torchax-0.0.5 → torchax-0.0.6}/torchax/view.py +0 -0
{torchax-0.0.5 → torchax-0.0.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchax
-Version: 0.0.5
+Version: 0.0.6
 Summary: torchax is a library for running Jax and PyTorch together
 Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
 Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
torchax-0.0.6/docs/api_iterations.md ADDED

@@ -0,0 +1,22 @@
+
+## always create a new environment, use it, discard it?
+
+```
+env = torchax.env()
+with env.with_prng_key(): // or extra inputs
+  do stuff
+
+# discard env
+
+with env.output_shape
+with env.manual axis ...
+
+env.call_torch_func(f, args, kwargs)
+
+functions should take in env
+functions in torch will get env from threadlocal property
+```
+env.call_stateless_torch_func()?
+
+tx = torchax.initialize(...)
+```
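The note above is a design sketch, not a working API. Purely for illustration, here is a minimal, self-contained Python sketch of the "create an environment, use it, discard it" pattern it describes; the names `ScratchEnv`, `with_prng_key`, `call_torch_func`, and `current_env` are hypothetical and are not torchax's actual API.

```python
import contextlib
import threading

# Hypothetical stand-in for the environment object the note sketches; the real
# torchax Environment lives in torchax/tensor.py and differs in detail.
_thread_local = threading.local()


class ScratchEnv:
  """Disposable environment: create it, use it inside a `with`, discard it."""

  def __init__(self, prng_key=0):
    self.prng_key = prng_key

  @contextlib.contextmanager
  def with_prng_key(self, key=None):
    # Install this env in thread-local storage so that torch-facing helpers
    # can find it without an explicit argument ("threadlocal property").
    if key is not None:
      self.prng_key = key
    _thread_local.env = self
    try:
      yield self
    finally:
      _thread_local.env = None

  def call_torch_func(self, f, args, kwargs):
    # Run a torch-level function under this environment.
    return f(*args, **kwargs)


def current_env():
  return getattr(_thread_local, 'env', None)


with ScratchEnv().with_prng_key(1234) as env:
  assert current_env() is env
  env.call_torch_func(print, ('inside the env',), {})
# The env is discarded here; the thread-local slot has been cleared.
assert current_env() is None
```

The thread-local lookup mirrors the note's idea that "functions in torch will get env from threadlocal property".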
torchax-0.0.6/test/BUILD ADDED

@@ -0,0 +1,31 @@
+# TODO(hanq): describe this package.
+
+load(
+    "//third_party/py/torch/google/bazel_rules/rules_python/python:defs.bzl",
+    "py_library",
+    "py_test",
+)
+
+package(
+    default_applicable_licenses = ["//devtools/compliance/licenses:no_external_contributions"],
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+py_library(
+    name = "test_base",
+    srcs = ["test_base.py"],
+    deps = [
+        "//testing/pybase",
+    ],
+)
+
+py_test(
+    name = "test_core_aten_ops",
+    srcs = ["test_core_aten_ops.py"],
+    deps = [
+        ":test_base",
+        "//third_party/py/absl:app",
+        "//third_party/py/torch/google/_torx",
+    ],
+)
torchax-0.0.6/test/test_base.py ADDED

@@ -0,0 +1,55 @@
+import unittest
+import torch
+from torch.utils import _pytree as pytree
+
+from torchax import tensor
+
+TestCase = unittest.TestCase
+main = unittest.main
+
+
+def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
+  if isinstance(output1, torch.Tensor):
+    testcase.assertIsInstance(output2, torch.Tensor)
+    output2_cpu = output2.detach().cpu()
+    if output2_cpu.dtype != output1.dtype:
+      output2_cpu = output2_cpu.to(output1.dtype)
+    testcase.assertTrue(
+        torch.allclose(
+            output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
+  elif isinstance(output1, (tuple, list)):
+    testcase.assertIsInstance(output2, (tuple, list))
+    testcase.assertEqual(len(output1), len(output2))
+    for o1, o2 in zip(output1, output2):
+      diff_output(testcase, o1, o2, rtol, atol)
+  else:
+    testcase.assertEqual(output1, output2)
+
+
+def run_function_and_compare(testcase,
+                             func,
+                             args,
+                             kwargs,
+                             atol=1e-3,
+                             rtol=1e-5,
+                             equal_nan=True,
+                             ignore_indices=False):
+  with testcase.subTest("torch_eval"):
+    res = func(*args, **kwargs)
+    with testcase.subTest("torchax_eval"):
+      args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device,
+                                            (args, kwargs))
+      res2 = func(*args2, **kwargs2)
+      res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
+      with testcase.subTest("torchax_diff:" + str(atol)):
+        if ignore_indices and isinstance(res, tuple) and len(res) == 2:
+          diff_output(
+              testcase,
+              res[0],
+              res2[0],
+              atol=atol,
+              rtol=rtol,
+              equal_nan=equal_nan)
+        else:
+          diff_output(
+              testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan)
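For context, a hypothetical sketch of how these new `test_base` helpers would be used from a test module. The import path and the example test are illustrative only; the actual consumer in this release is test_core_aten_ops.py (see the BUILD deps above).

```python
import torch

import test_base  # illustrative import; the real path depends on how the tests are run


class ExampleOpTest(test_base.TestCase):

  def test_abs(self):
    # Runs torch.abs once on plain torch and once under torchax, then compares.
    args = (torch.randn(10),)
    test_base.run_function_and_compare(self, torch.abs, args, {})


if __name__ == '__main__':
  test_base.main()
```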
{torchax-0.0.5 → torchax-0.0.6}/test/test_context.py

@@ -10,16 +10,9 @@ xla_env = torchax.default_env()

 class TestContext(unittest.TestCase):

-  def setUp(self):
-    self.old_var = xla_env.config.use_torch_native_for_cpu_tensor
-    xla_env.config.use_torch_native_for_cpu_tensor = False
-
-  def tearDown(self):
-    xla_env.config.use_torch_native_for_cpu_tensor = self.old_var
-
   def test_mode_context_manager(self):
     with xla_env:
-      x = torch.full((3, 3), -1)
+      x = torch.full((3, 3), -1, device='jax')
       self.assertIsInstance(x, tensor.Tensor)
       y = x.abs()
       self.assertIsInstance(y, tensor.Tensor)

@@ -27,7 +20,7 @@ class TestContext(unittest.TestCase):
   @staticmethod
   @xla_env
   def _test_mode_decorator():
-    x = torch.full((3, 3), -1)
+    x = torch.full((3, 3), -1).to('jax')
     y = x.abs()

     return x, y

@@ -40,11 +33,11 @@ class TestContext(unittest.TestCase):
   def test_same_manual_seed(self):
     with xla_env:
       xla_env.manual_seed(1234)
-      x = torch.randn((3, 3))
+      x = torch.randn((3, 3), device='jax')
       self.assertIsInstance(x, tensor.Tensor)

       xla_env.manual_seed(1234)
-      y = torch.randn((3, 3))
+      y = torch.randn((3, 3), device='jax')
       self.assertIsInstance(y, tensor.Tensor)

       self.assertTrue(torch.allclose(x, y))

@@ -52,11 +45,11 @@ class TestContext(unittest.TestCase):
   def test_different_manual_seed(self):
     with xla_env:
       xla_env.manual_seed(1234)
-      x = torch.randn((3, 3))
+      x = torch.randn((3, 3), device='jax')
       self.assertIsInstance(x, tensor.Tensor)

       xla_env.manual_seed(12345)
-      y = torch.randn((3, 3))
+      y = torch.randn((3, 3), device='jax')
       self.assertIsInstance(y, tensor.Tensor)

       self.assertFalse(torch.allclose(x, y))

@@ -66,21 +59,24 @@
     with xla_env:

       def random_op():
-        x = torch.randn(3, 3)
-        y = torch.randn(3, 3)
+        x = torch.randn(3, 3, device='jax')
+        y = torch.randn(3, 3, device='jax')
         return x @ y

       random_jit = torchax.interop.jax_jit(random_op)
       self.assertIsInstance(random_jit(), tensor.Tensor)

       # If we run the JIT twice, the random values should be different.
-
-
+      # TODO(qihqi): think about API for passing down seed
+      # with self.assertRaises(AssertionError):
+      # torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0)

   def test_generator_seed(self):
     with xla_env:
-      x = torch.randn(
-
+      x = torch.randn(
+          2, 3, generator=torch.Generator().manual_seed(0), device='jax')
+      y = torch.randn(
+          2, 3, generator=torch.Generator().manual_seed(0), device='jax')

       # Values will be the same given the same seed.
       torch.testing.assert_close(x, y)

@@ -97,7 +93,7 @@ class TestContext(unittest.TestCase):

     # Test context manager.
     with xla_env:
-      m = M()
+      m = M().to('jax')
       self.assertIsInstance(m.c, tensor.Tensor)
       self.assertIsInstance(m.c2, tensor.Tensor)
     # Test `to_xla`.
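Taken together, these hunks replace the old `use_torch_native_for_cpu_tensor` toggling with explicit device placement. A minimal sketch of the pattern the updated tests rely on, assuming the default environment used above:

```python
import torch
import torchax

env = torchax.default_env()

with env:
  x = torch.full((3, 3), -1.0, device='jax')  # created directly as a torchax tensor
  y = torch.randn(3, 3).to('jax')             # or moved over from a CPU torch tensor
  z = (x + y).cpu()                           # bring the result back to a CPU torch tensor
print(z)
```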
{torchax-0.0.5 → torchax-0.0.6}/test/test_core_aten_ops.py

@@ -90,11 +90,6 @@ class TestCoreAtenOps(unittest.TestCase):
     super().setUp()
     torch.manual_seed(0)
     self.env = tensor.Environment()
-    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
-    self.env.config.use_torch_native_for_cpu_tensor = False
-
-  def tearDown(self):
-    self.env.config.use_torch_native_for_cpu_tensor = self.old_var

   def test_aten_abs_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
{torchax-0.0.5 → torchax-0.0.6}/test/test_functions.py

@@ -88,8 +88,15 @@ class TestTorchFunctions(parameterized.TestCase):
     res2 = model(x)
     self.assertTrue(torch.allclose(res, res2.to('cpu')))

-
-
+  @parameterized.named_parameters(
+      ('ones', torch.ones, ((2, 2),)), ('zeros', torch.zeros, ((2, 2),)),
+      ('empty', torch.empty,
+       ((2, 2),)), ('empty_strided', torch.empty_strided,
+                    ((2, 2), (2, 1))), ('tensor', torch.tensor, ([2.0, 2.0],)),
+      ('eye', torch.eye, (2,)), ('randn', torch.randn, ((2, 2),)),
+      ('rand', torch.rand, ((2, 2),)), ('full', torch.full, ((2, 2), 0)))
+  def test_requires_grad(self, func, args):
+    x = func(*args, requires_grad=True, device='jax')
     self.assertEqual(x.requires_grad, True)

{torchax-0.0.5 → torchax-0.0.6}/test/test_interop.py

@@ -2,9 +2,10 @@ import functools
 import torch
 import unittest
 import torchax
-from torchax import interop
+from torchax import interop
 import torchax
 import jax
+import jax.numpy as jnp


 def is_tpu_available():

@@ -142,7 +143,7 @@ class InteropTest(unittest.TestCase):
     self.assertEqual(e.jax_device.platform, "cpu")
     self.assertEqual(e.device.type, "jax")

-    with
+    with jax.default_device(jax.devices("cpu")[0]):
       # move torch.tensor to torchax.tensor CPU
       b = a.to("jax")
       self.assertEqual(b.jax_device.platform, "cpu")

@@ -150,26 +151,21 @@

     if is_tpu_available():
       # move torch.tensor to torchax.tensor TPU
-      with
+      with jax.default_device(jax.local_devices("tpu")[0]):
         c = a.to("jax")
         self.assertEqual(c.jax_device.platform, "tpu")
         self.assertEqual(c.device.type, "jax")

-
-
-
-
-
-
-
-
-
-
-        self.assertEqual(c.jax_device.platform, "tpu")
-        self.assertEqual(c.device.type, "jax")
-        d = c.to("jax")
-        self.assertEqual(d.jax_device.platform, "cpu")
-        self.assertEqual(d.device.type, "jax")
+  def test_torch_jax_view_dtype(self):
+    dtype = torch.float32
+    self.assertEqual(interop.jax_view(dtype), jnp.float32.dtype)
+    self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
+    dtype = torch.bfloat16
+    self.assertEqual(interop.jax_view(dtype), jnp.bfloat16.dtype)
+    self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
+    dtype = torch.int32
+    self.assertEqual(interop.jax_view(dtype), jnp.int32.dtype)
+    self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)


 if __name__ == '__main__':
{torchax-0.0.5 → torchax-0.0.6}/test/test_jittable_module.py

@@ -34,6 +34,25 @@ class JittableModuleTest(unittest.TestCase):
     assert isinstance(JittableMoreAwesomeModel, EvenMoreAwesomeModel)
     assert not isinstance(JittableMoreAwesomeModel, MyAwesomeModel)

+  def test_functional_call_callable(self):
+
+    def outer_function(model, x):
+      return x + 1
+
+    model = MyAwesomeModel()
+    jittable_module = interop.JittableModule(model)
+
+    # Check if the jittable module can be called like a function
+    input_tensor = torch.randn(1, 3, 224, 224)
+    expected_output = input_tensor + 1
+
+    output = jittable_module.functional_call(outer_function,
+                                             jittable_module.params,
+                                             jittable_module.buffers,
+                                             input_tensor)
+
+    assert torch.equal(output, expected_output)
+

 if __name__ == '__main__':
   unittest.main()
torchax-0.0.6/test/test_misc.py ADDED

@@ -0,0 +1,22 @@
+import unittest
+import torch
+import torchax
+
+
+class MiscTest(unittest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    torchax.enable_globally()
+
+  def test_mixed_tensor_math_with_scalar(self):
+    a = torch.tensor(2)
+    b = torch.ones((2, 2), device='jax')
+    c = a * b
+    self.assertTrue(
+        torch.allclose(c.cpu(),
+                       torch.tensor([[2, 2], [2, 2]], dtype=torch.float32)))
+
+
+if __name__ == '__main__':
+  unittest.main()
torchax-0.0.6/test/test_mutations.py ADDED

@@ -0,0 +1,52 @@
+import unittest
+import torchax
+import torch
+from torch.testing._internal.common_utils import TestCase
+
+
+class TestMutations(TestCase):
+
+  def setUp(self):
+    self.env = torchax.tensor.Environment()
+    self.env.config.debug_print_each_op = True
+
+  def test_add(self):
+    with self.env:
+      x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
+      y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
+      x.add_(y)
+      torch.testing.assert_close(x.cpu(),
+                                 torch.tensor([5, 7, 9], dtype=torch.int32))
+
+  def test_sub(self):
+    with self.env:
+      x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
+      y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
+      x.sub_(y)
+      torch.testing.assert_close(x.cpu(),
+                                 torch.tensor([-3, -3, -3], dtype=torch.int32))
+
+  def test_mul(self):
+    with self.env:
+      x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
+      y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
+
+      x.mul_(y)
+      torch.testing.assert_close(x.cpu(),
+                                 torch.tensor([4, 10, 18], dtype=torch.int32))
+
+  def test_index_copy(self):
+    with self.env:
+      x = torch.zeros(5, 3, device='jax')
+      t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+                       device='jax',
+                       dtype=torch.float)
+      index = torch.tensor([0, 4, 2], device='jax')
+      x.index_copy_(0, index, t)
+      expected = torch.tensor([[1., 2., 3.], [0., 0., 0.], [7., 8., 9.],
+                               [0., 0., 0.], [4., 5., 6.]])
+      torch.testing.assert_close(x.cpu(), expected)
+
+
+if __name__ == '__main__':
+  unittest.main()
{torchax-0.0.5 → torchax-0.0.6}/test/test_ops.py

@@ -140,6 +140,8 @@ def run_export_and_compare(testcase,
   with testcase.subTest("torchax_eval"):
     input2, args2, kwargs2 = testcase.env.to_xla(
         (sample_input.input, sample_input.args, sample_input.kwargs))
+    if 'device' in kwargs2:
+      kwargs2['device'] = 'jax'
     with testcase.env:
       res2 = func(input2, *args2, **kwargs2)
       res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)

@@ -186,13 +188,8 @@ class TestOpInfo(TestCase):
     self.env = torchax.default_env()
     torchax.enable_accuracy_mode()
     #self.env.config.debug_accuracy_for_each_op = True
-    self.env.config.debug_print_each_op =
+    self.env.config.debug_print_each_op = False
     torch.manual_seed(0)
-    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
-    self.env.config.use_torch_native_for_cpu_tensor = False
-
-  def tearDown(self):
-    self.env.config.use_torch_native_for_cpu_tensor = self.old_var

   # Replaces all values in the input torch_tensor that are less than the given threshold
   # with the threshold value itself.
{torchax-0.0.5 → torchax-0.0.6}/test/test_unbounded_dynamism.py

@@ -53,13 +53,8 @@ def wrap_func_as_nn_module(f):
 class UnboundedDynamismExportTest(unittest.TestCase):

   def setUp(self):
-    self.env = torchax.default_env()
-    self.env.config.use_torch_native_for_cpu_tensor = False
     torchax.enable_accuracy_mode()

-  def tearDown(self):
-    self.env.config.use_torch_native_for_cpu_tensor = True
-
   def test_add(self):
     args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
     dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),)
{torchax-0.0.5 → torchax-0.0.6}/torchax/__init__.py

@@ -6,10 +6,9 @@ import os
 import torch
 from torch.utils import _pytree as pytree
 from torchax import tensor
-from torchax import distributed # noqa: F401
 from contextlib import contextmanager

-__version__ = "0.0.5"
+__version__ = "0.0.6"
 VERSION = __version__

 __all__ = [

@@ -50,10 +49,11 @@ def extract_jax(mod: torch.nn.Module, env=None):
   states = env.t2j_copy(states)

   #@jax.jit
-  def jax_func(states,
-    (states,
+  def jax_func(states, args, kwargs=None):
+    (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
     with env:
-      res = torch.func.functional_call(
+      res = torch.func.functional_call(
+          mod, states, args, kwargs, tie_weights=False)
     return env.t2j_iso(res)

   return states, jax_func

@@ -81,11 +81,6 @@ def disable_temporarily():

 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
-torch.utils.generate_methods_for_privateuse1_backend(
-    for_tensor=True,
-    for_module=True,
-    for_storage=True,
-    unsupported_dtype=unsupported_dtype)

 import jax
 import torchax.device_module

@@ -129,34 +124,3 @@ def compile(fn, options: Optional[CompileOptions] = None):
     raise RuntimeError('dynamo mode is not supported yet')
   elif options.mode == 'export':
     raise RuntimeError('export mode is not supported yet')
-
-
-@contextmanager
-def jax_device(target_device: str, env: tensor.Environment | None = None):
-  """
-  to("jax") cannot differentiate the device/platform (cpu vs tpu).
-  Use this context manager to control jax array's storage device
-
-  Examples:
-
-  a = torch.ones(3, 3)
-
-  with jax_device("cpu"):
-    b = a.to("jax")
-
-  with jax_device("tpu"):
-    c = a.to("jax")
-
-  with jax_device("tpu"):
-    c = b.to("jax")
-
-  """
-  if env is None:
-    env = default_env()
-
-  prev_target_device = env.target_device
-  try:
-    env.target_device = target_device
-    yield env
-  finally:
-    env.target_device = prev_target_device
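Based on the updated `extract_jax` hunk above, a hedged usage sketch: the function returns the module's state (as JAX arrays) plus a pure function whose positional arguments are passed as a tuple, so it can be wrapped in `jax.jit`. The model and shapes below are illustrative only.

```python
import jax
import jax.numpy as jnp
import torch
import torchax

model = torch.nn.Linear(4, 2)
states, jax_func = torchax.extract_jax(model)

# jax_func(states, args, kwargs=None) mirrors torch.func.functional_call.
jitted = jax.jit(jax_func)
x = jnp.ones((8, 4), dtype=jnp.float32)
out = jitted(states, (x,))  # args are passed as a tuple
print(out.shape)            # expected: (8, 2)
```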
{torchax-0.0.5 → torchax-0.0.6}/torchax/amp.py

@@ -61,9 +61,8 @@ def autocast(device, dtype=torch.bfloat16, env=None):
   if env is None:
     import torchax
     env = torchax.default_env()
-  env.autocast_dtype
-
-  env.autocast_dtype = old
+  with env.override_property(autocast_dtype=dtype):
+    yield


 # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
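The change above swaps manual save/restore of `autocast_dtype` for a scoped `env.override_property(...)`. A hedged sketch of how the context manager might be used; the device string and call pattern are assumptions, not confirmed by this diff:

```python
import torch
import torchax
from torchax import amp

env = torchax.default_env()
with env:
  a = torch.randn(4, 4, device='jax')
  b = torch.randn(4, 4, device='jax')
  # Inside this scope the environment's autocast_dtype is overridden to bfloat16.
  with amp.autocast('jax', dtype=torch.bfloat16, env=env):
    c = a @ b
print(c.dtype)
```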
{torchax-0.0.5 → torchax-0.0.6}/torchax/config.py

@@ -10,6 +10,11 @@ class Configuration:

   use_int32_for_index: bool = False

+  # normally, math between CPU torch.Tensor with torchax.Tensor is not
+  # allowed. However, if that torch.Tensor happens to be scalar, then we
+  # can use scalar * tensor math to handle it
+  allow_mixed_math_with_scalar_tensor: bool = True
+
   # If true, we will convert Views into torchax.Tensors eagerly
   force_materialize_views: bool = False

@@ -22,5 +27,4 @@ class Configuration:

   # device
   treat_cuda_as_jax_device: bool = True
-  use_torch_native_for_cpu_tensor: bool = True
   internal_respect_torch_return_dtypes: bool = False
torchax-0.0.6/torchax/configuration.py ADDED

@@ -0,0 +1,30 @@
+import dataclasses
+
+
+@dataclasses.dataclass
+class Configuration:
+  debug_print_each_op: bool = False
+  debug_accuracy_for_each_op: bool = False
+  debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
+
+  use_int32_for_index: bool = False
+
+  # normally, math between CPU torch.Tensor with torchax.Tensor is not
+  # allowed. However, if that torch.Tensor happens to be scalar, then we
+  # can use scalar * tensor math to handle it
+  allow_mixed_math_with_scalar_tensor: bool = True
+
+  # If true, we will convert Views into torchax.Tensors eagerly
+  force_materialize_views: bool = False
+
+  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
+  use_dlpack_for_data_conversion: bool = False
+
+  # Flash attention
+  use_tpu_flash_attention: bool = False
+  shmap_flash_attention: bool = False
+
+  # device
+  treat_cuda_as_jax_device: bool = True
+  internal_respect_torch_return_dtypes: bool = False
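As the new tests above show, these flags are flipped through the environment's `config` attribute; a small sketch (which fields matter depends on the workload):

```python
import torchax

env = torchax.default_env()
env.config.debug_print_each_op = True                   # log every dispatched op
env.config.allow_mixed_math_with_scalar_tensor = True   # allow scalar CPU tensor * jax tensor math
env.config.use_dlpack_for_data_conversion = False       # copy instead of DLPack when converting
```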
torchax-0.0.6/torchax/environment.py ADDED

@@ -0,0 +1 @@
+