torchax 0.0.5__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

Files changed (121)
  1. {torchax-0.0.5 → torchax-0.0.6}/PKG-INFO +1 -1
  2. torchax-0.0.6/docs/api_iterations.md +22 -0
  3. torchax-0.0.6/test/BUILD +31 -0
  4. torchax-0.0.6/test/test_base.py +55 -0
  5. {torchax-0.0.5 → torchax-0.0.6}/test/test_context.py +16 -20
  6. {torchax-0.0.5 → torchax-0.0.6}/test/test_core_aten_ops.py +0 -5
  7. {torchax-0.0.5 → torchax-0.0.6}/test/test_flax.py +1 -1
  8. {torchax-0.0.5 → torchax-0.0.6}/test/test_functions.py +9 -2
  9. {torchax-0.0.5 → torchax-0.0.6}/test/test_interop.py +14 -18
  10. {torchax-0.0.5 → torchax-0.0.6}/test/test_jittable_module.py +19 -0
  11. {torchax-0.0.5 → torchax-0.0.6}/test/test_libraries.py +0 -1
  12. torchax-0.0.6/test/test_misc.py +22 -0
  13. torchax-0.0.6/test/test_mutations.py +52 -0
  14. {torchax-0.0.5 → torchax-0.0.6}/test/test_ops.py +3 -6
  15. {torchax-0.0.5 → torchax-0.0.6}/test/test_unbounded_dynamism.py +0 -5
  16. {torchax-0.0.5 → torchax-0.0.6}/torchax/__init__.py +5 -41
  17. {torchax-0.0.5 → torchax-0.0.6}/torchax/amp.py +2 -3
  18. {torchax-0.0.5 → torchax-0.0.6}/torchax/config.py +5 -1
  19. torchax-0.0.6/torchax/configuration.py +30 -0
  20. {torchax-0.0.5 → torchax-0.0.6}/torchax/device_module.py +7 -0
  21. torchax-0.0.6/torchax/environment.py +1 -0
  22. {torchax-0.0.5 → torchax-0.0.6}/torchax/interop.py +27 -14
  23. {torchax-0.0.5 → torchax-0.0.6}/torchax/mesh_util.py +10 -1
  24. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jaten.py +5 -3
  25. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jtorch.py +18 -10
  26. {torchax-0.0.5 → torchax-0.0.6}/torchax/tensor.py +127 -115
  27. torchax-0.0.5/examples/mnist_tpu.ipynb +0 -647
  28. torchax-0.0.5/examples/train_gpt/train_ddp.py +0 -140
  29. torchax-0.0.5/test/test_mutations.py +0 -36
  30. torchax-0.0.5/test_dist/README.md +0 -4
  31. torchax-0.0.5/test_dist/__init__.py +0 -0
  32. torchax-0.0.5/test_dist/test_distributed.py +0 -154
  33. torchax-0.0.5/torchax/distributed.py +0 -241
  34. {torchax-0.0.5 → torchax-0.0.6}/.gitignore +0 -0
  35. {torchax-0.0.5 → torchax-0.0.6}/=2.3.0 +0 -0
  36. {torchax-0.0.5 → torchax-0.0.6}/LICENSE +0 -0
  37. {torchax-0.0.5 → torchax-0.0.6}/README.md +0 -0
  38. {torchax-0.0.5 → torchax-0.0.6}/build_nightly.sh +0 -0
  39. {torchax-0.0.5 → torchax-0.0.6}/dev-requirements.txt +0 -0
  40. {torchax-0.0.5 → torchax-0.0.6}/docs/dispatch.png +0 -0
  41. {torchax-0.0.5 → torchax-0.0.6}/docs/fixing_op_info_test.md +0 -0
  42. {torchax-0.0.5 → torchax-0.0.6}/docs/how_it_works.md +0 -0
  43. {torchax-0.0.5 → torchax-0.0.6}/docs/ops_registry.md +0 -0
  44. {torchax-0.0.5 → torchax-0.0.6}/docs/support_a_new_model.md +0 -0
  45. {torchax-0.0.5 → torchax-0.0.6}/docs/torch_dispatch/README.md +0 -0
  46. {torchax-0.0.5 → torchax-0.0.6}/docs/torch_dispatch/example.py +0 -0
  47. {torchax-0.0.5 → torchax-0.0.6}/docs/torch_dispatch/run_env.py +0 -0
  48. {torchax-0.0.5 → torchax-0.0.6}/docs/torch_xla2_dynamo.md +0 -0
  49. {torchax-0.0.5 → torchax-0.0.6}/docs/understand_jax_jit/jax_grad.py +0 -0
  50. {torchax-0.0.5 → torchax-0.0.6}/docs/understand_jax_jit/jax_jit.py +0 -0
  51. {torchax-0.0.5 → torchax-0.0.6}/docs/understand_jax_jit/torch_module.py +0 -0
  52. {torchax-0.0.5 → torchax-0.0.6}/examples/README.md +0 -0
  53. {torchax-0.0.5 → torchax-0.0.6}/examples/__init__.py +0 -0
  54. {torchax-0.0.5 → torchax-0.0.6}/examples/_diffusion.py +0 -0
  55. {torchax-0.0.5 → torchax-0.0.6}/examples/_grad_of_attention.py +0 -0
  56. {torchax-0.0.5 → torchax-0.0.6}/examples/basic_training.py +0 -0
  57. {torchax-0.0.5 → torchax-0.0.6}/examples/basic_training_jax.py +0 -0
  58. {torchax-0.0.5 → torchax-0.0.6}/examples/eager_mode.py +0 -0
  59. {torchax-0.0.5 → torchax-0.0.6}/examples/lightning_training.py +0 -0
  60. {torchax-0.0.5 → torchax-0.0.6}/examples/requirements.txt +0 -0
  61. {torchax-0.0.5 → torchax-0.0.6}/examples/torchbench_models/BERT_pytorch.py +0 -0
  62. {torchax-0.0.5 → torchax-0.0.6}/examples/train_gpt/requirements.txt +0 -0
  63. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/README.md +0 -0
  64. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/__init__.py +0 -0
  65. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/model.py +0 -0
  66. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/train_llama_lightning.py +0 -0
  67. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama/utils.py +0 -0
  68. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/Dockerfile +0 -0
  69. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/README.md +0 -0
  70. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/__init__.py +0 -0
  71. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/helper.py +0 -0
  72. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/splash_attn.py +0 -0
  73. {torchax-0.0.5 → torchax-0.0.6}/examples/train_llama_torchtitan/train_llama.py +0 -0
  74. {torchax-0.0.5 → torchax-0.0.6}/format.sh +0 -0
  75. {torchax-0.0.5 → torchax-0.0.6}/pyproject.toml +0 -0
  76. {torchax-0.0.5 → torchax-0.0.6}/repro1.py +0 -0
  77. {torchax-0.0.5 → torchax-0.0.6}/temp +0 -0
  78. {torchax-0.0.5 → torchax-0.0.6}/test/__init__.py +0 -0
  79. {torchax-0.0.5 → torchax-0.0.6}/test/base_test_util.py +0 -0
  80. {torchax-0.0.5 → torchax-0.0.6}/test/gemma/__init__.py +0 -0
  81. {torchax-0.0.5 → torchax-0.0.6}/test/gemma/config.py +0 -0
  82. {torchax-0.0.5 → torchax-0.0.6}/test/gemma/model.py +0 -0
  83. {torchax-0.0.5 → torchax-0.0.6}/test/gemma/test_gemma.py +0 -0
  84. {torchax-0.0.5 → torchax-0.0.6}/test/gemma/tokenizer.py +0 -0
  85. {torchax-0.0.5 → torchax-0.0.6}/test/llama/BUILD +0 -0
  86. {torchax-0.0.5 → torchax-0.0.6}/test/llama/__init__.py +0 -0
  87. {torchax-0.0.5 → torchax-0.0.6}/test/llama/llama_model.py +0 -0
  88. {torchax-0.0.5 → torchax-0.0.6}/test/llama/model_exportable.py +0 -0
  89. {torchax-0.0.5 → torchax-0.0.6}/test/llama/test_llama.py +0 -0
  90. {torchax-0.0.5 → torchax-0.0.6}/test/moe/__init__.py +0 -0
  91. {torchax-0.0.5 → torchax-0.0.6}/test/moe/model.py +0 -0
  92. {torchax-0.0.5 → torchax-0.0.6}/test/moe/moe_test.py +0 -0
  93. {torchax-0.0.5 → torchax-0.0.6}/test/test_amp.py +0 -0
  94. {torchax-0.0.5 → torchax-0.0.6}/test/test_conv.py +0 -0
  95. {torchax-0.0.5 → torchax-0.0.6}/test/test_exports.py +0 -0
  96. {torchax-0.0.5 → torchax-0.0.6}/test/test_image.py +0 -0
  97. {torchax-0.0.5 → torchax-0.0.6}/test/test_symbolic_shapes.py +0 -0
  98. {torchax-0.0.5 → torchax-0.0.6}/test/test_tf_integration.py +0 -0
  99. {torchax-0.0.5 → torchax-0.0.6}/test/test_train.py +0 -0
  100. {torchax-0.0.5 → torchax-0.0.6}/test/test_util.py +0 -0
  101. {torchax-0.0.5 → torchax-0.0.6}/test/test_view.py +0 -0
  102. {torchax-0.0.5 → torchax-0.0.6}/test-requirements.txt +0 -0
  103. {torchax-0.0.5 → torchax-0.0.6}/test_dist/test_mesh_util.py +0 -0
  104. {torchax-0.0.5 → torchax-0.0.6}/torchax/CONTRIBUTING.md +0 -0
  105. {torchax-0.0.5 → torchax-0.0.6}/torchax/decompositions.py +0 -0
  106. {torchax-0.0.5 → torchax-0.0.6}/torchax/export.py +0 -0
  107. {torchax-0.0.5 → torchax-0.0.6}/torchax/flax.py +0 -0
  108. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/__init__.py +0 -0
  109. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jax_reimplement.py +0 -0
  110. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jc10d.py +0 -0
  111. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jimage.py +0 -0
  112. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jlibrary.py +0 -0
  113. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/jtorchvision_nms.py +0 -0
  114. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/mappings.py +0 -0
  115. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/op_base.py +0 -0
  116. {torchax-0.0.5 → torchax-0.0.6}/torchax/ops/ops_registry.py +0 -0
  117. {torchax-0.0.5 → torchax-0.0.6}/torchax/tf_integration.py +0 -0
  118. {torchax-0.0.5 → torchax-0.0.6}/torchax/train.py +0 -0
  119. {torchax-0.0.5 → torchax-0.0.6}/torchax/types.py +0 -0
  120. {torchax-0.0.5 → torchax-0.0.6}/torchax/util.py +0 -0
  121. {torchax-0.0.5 → torchax-0.0.6}/torchax/view.py +0 -0
{torchax-0.0.5 → torchax-0.0.6}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchax
- Version: 0.0.5
+ Version: 0.0.6
  Summary: torchax is a library for running Jax and PyTorch together
  Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
  Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
torchax-0.0.6/docs/api_iterations.md
@@ -0,0 +1,22 @@
+
+ ## always create a new environment, use it, discard it?
+
+ ```
+ env = torchax.env()
+ with env.with_prng_key(): // or extra inputs
+ do stuff
+
+ # discard env
+
+ with env.output_shape
+ with env.manual axis ...
+
+ env.call_torch_func(f, args, kwargs)
+
+ functions should take in env
+ functions in torch will get env from threadlocal property
+ ```
+ env.call_stateless_torch_func()?
+
+ tx = torchax.initialize(...)
+ ```
torchax-0.0.6/test/BUILD
@@ -0,0 +1,31 @@
+ # TODO(hanq): describe this package.
+
+ load(
+     "//third_party/py/torch/google/bazel_rules/rules_python/python:defs.bzl",
+     "py_library",
+     "py_test",
+ )
+
+ package(
+     default_applicable_licenses = ["//devtools/compliance/licenses:no_external_contributions"],
+     default_visibility = ["//visibility:public"],
+     licenses = ["notice"],
+ )
+
+ py_library(
+     name = "test_base",
+     srcs = ["test_base.py"],
+     deps = [
+         "//testing/pybase",
+     ],
+ )
+
+ py_test(
+     name = "test_core_aten_ops",
+     srcs = ["test_core_aten_ops.py"],
+     deps = [
+         ":test_base",
+         "//third_party/py/absl:app",
+         "//third_party/py/torch/google/_torx",
+     ],
+ )
torchax-0.0.6/test/test_base.py
@@ -0,0 +1,55 @@
+ import unittest
+ import torch
+ from torch.utils import _pytree as pytree
+
+ from torchax import tensor
+
+ TestCase = unittest.TestCase
+ main = unittest.main
+
+
+ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
+   if isinstance(output1, torch.Tensor):
+     testcase.assertIsInstance(output2, torch.Tensor)
+     output2_cpu = output2.detach().cpu()
+     if output2_cpu.dtype != output1.dtype:
+       output2_cpu = output2_cpu.to(output1.dtype)
+     testcase.assertTrue(
+         torch.allclose(
+             output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
+   elif isinstance(output1, (tuple, list)):
+     testcase.assertIsInstance(output2, (tuple, list))
+     testcase.assertEqual(len(output1), len(output2))
+     for o1, o2 in zip(output1, output2):
+       diff_output(testcase, o1, o2, rtol, atol)
+   else:
+     testcase.assertEqual(output1, output2)
+
+
+ def run_function_and_compare(testcase,
+                              func,
+                              args,
+                              kwargs,
+                              atol=1e-3,
+                              rtol=1e-5,
+                              equal_nan=True,
+                              ignore_indices=False):
+   with testcase.subTest("torch_eval"):
+     res = func(*args, **kwargs)
+   with testcase.subTest("torchax_eval"):
+     args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device,
+                                           (args, kwargs))
+     res2 = func(*args2, **kwargs2)
+     res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
+     with testcase.subTest("torchax_diff:" + str(atol)):
+       if ignore_indices and isinstance(res, tuple) and len(res) == 2:
+         diff_output(
+             testcase,
+             res[0],
+             res2[0],
+             atol=atol,
+             rtol=rtol,
+             equal_nan=equal_nan)
+       else:
+         diff_output(
+             testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan)
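Note: a hedged usage sketch for the helper added above. The import path (a plain test_base module on the test path) and the sample test are assumptions for illustration, not part of the diff.

    import torch
    import test_base  # assumed importable from the test/ directory


    class AbsTest(test_base.TestCase):

      def test_abs(self):
        # Runs torch.abs once on plain torch tensors and once through torchax,
        # then compares the two results within the default tolerances.
        test_base.run_function_and_compare(self, torch.abs, (torch.randn(4),), {})


    if __name__ == '__main__':
      test_base.main()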
{torchax-0.0.5 → torchax-0.0.6}/test/test_context.py
@@ -10,16 +10,9 @@ xla_env = torchax.default_env()
 
  class TestContext(unittest.TestCase):
 
-   def setUp(self):
-     self.old_var = xla_env.config.use_torch_native_for_cpu_tensor
-     xla_env.config.use_torch_native_for_cpu_tensor = False
-
-   def tearDown(self):
-     xla_env.config.use_torch_native_for_cpu_tensor = self.old_var
-
    def test_mode_context_manager(self):
      with xla_env:
-       x = torch.full((3, 3), -1)
+       x = torch.full((3, 3), -1, device='jax')
        self.assertIsInstance(x, tensor.Tensor)
        y = x.abs()
        self.assertIsInstance(y, tensor.Tensor)
@@ -27,7 +20,7 @@ class TestContext(unittest.TestCase):
    @staticmethod
    @xla_env
    def _test_mode_decorator():
-     x = torch.full((3, 3), -1)
+     x = torch.full((3, 3), -1).to('jax')
      y = x.abs()
 
      return x, y
@@ -40,11 +33,11 @@ class TestContext(unittest.TestCase):
    def test_same_manual_seed(self):
      with xla_env:
        xla_env.manual_seed(1234)
-       x = torch.randn((3, 3))
+       x = torch.randn((3, 3), device='jax')
        self.assertIsInstance(x, tensor.Tensor)
 
        xla_env.manual_seed(1234)
-       y = torch.randn((3, 3))
+       y = torch.randn((3, 3), device='jax')
        self.assertIsInstance(y, tensor.Tensor)
 
        self.assertTrue(torch.allclose(x, y))
@@ -52,11 +45,11 @@ class TestContext(unittest.TestCase):
    def test_different_manual_seed(self):
      with xla_env:
        xla_env.manual_seed(1234)
-       x = torch.randn((3, 3))
+       x = torch.randn((3, 3), device='jax')
        self.assertIsInstance(x, tensor.Tensor)
 
        xla_env.manual_seed(12345)
-       y = torch.randn((3, 3))
+       y = torch.randn((3, 3), device='jax')
        self.assertIsInstance(y, tensor.Tensor)
 
        self.assertFalse(torch.allclose(x, y))
@@ -66,21 +59,24 @@
      with xla_env:
 
        def random_op():
-         x = torch.randn(3, 3)
-         y = torch.randn(3, 3)
+         x = torch.randn(3, 3, device='jax')
+         y = torch.randn(3, 3, device='jax')
          return x @ y
 
        random_jit = torchax.interop.jax_jit(random_op)
        self.assertIsInstance(random_jit(), tensor.Tensor)
 
        # If we run the JIT twice, the random values should be different.
-       with self.assertRaises(AssertionError):
-         torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0)
+       # TODO(qihqi): think about API for passing down seed
+       # with self.assertRaises(AssertionError):
+       #   torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0)
 
    def test_generator_seed(self):
      with xla_env:
-       x = torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
-       y = torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
+       x = torch.randn(
+           2, 3, generator=torch.Generator().manual_seed(0), device='jax')
+       y = torch.randn(
+           2, 3, generator=torch.Generator().manual_seed(0), device='jax')
 
        # Values will be the same given the same seed.
        torch.testing.assert_close(x, y)
@@ -97,7 +93,7 @@ class TestContext(unittest.TestCase):
 
      # Test context manager.
      with xla_env:
-       m = M()
+       m = M().to('jax')
        self.assertIsInstance(m.c, tensor.Tensor)
        self.assertIsInstance(m.c2, tensor.Tensor)
      # Test `to_xla`.
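Note: the test_context.py changes reflect that 0.0.6 drops the use_torch_native_for_cpu_tensor toggle, so tensors are placed on the JAX backend explicitly. A minimal sketch of the pattern the updated tests rely on (torchax.default_env() and device='jax' are taken from the diff itself):

    import torch
    import torchax

    env = torchax.default_env()
    with env:
      x = torch.randn((3, 3), device='jax')  # created directly as a torchax tensor
      y = torch.full((3, 3), -1).to('jax')   # or created on CPU and then moved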
{torchax-0.0.5 → torchax-0.0.6}/test/test_core_aten_ops.py
@@ -90,11 +90,6 @@ class TestCoreAtenOps(unittest.TestCase):
      super().setUp()
      torch.manual_seed(0)
      self.env = tensor.Environment()
-     self.old_var = self.env.config.use_torch_native_for_cpu_tensor
-     self.env.config.use_torch_native_for_cpu_tensor = False
-
-   def tearDown(self):
-     self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
    def test_aten_abs_0(self):
      args = (torch.randn((10, 10)).to(torch.float32),)
{torchax-0.0.5 → torchax-0.0.6}/test/test_flax.py
@@ -81,7 +81,7 @@ class FlaxTest(unittest.TestCase):
        return res
 
      with env:
-       nn_module = Parent()
+       nn_module = Parent().to('jax')
 
      @jax_jit
      def jitted(weights, args):
{torchax-0.0.5 → torchax-0.0.6}/test/test_functions.py
@@ -88,8 +88,15 @@ class TestTorchFunctions(parameterized.TestCase):
      res2 = model(x)
      self.assertTrue(torch.allclose(res, res2.to('cpu')))
 
-   def test_randn_requires_grad(self):
-     x = torch.randn((3, 3), requires_grad=True, device='jax')
+   @parameterized.named_parameters(
+       ('ones', torch.ones, ((2, 2),)), ('zeros', torch.zeros, ((2, 2),)),
+       ('empty', torch.empty,
+        ((2, 2),)), ('empty_strided', torch.empty_strided,
+                     ((2, 2), (2, 1))), ('tensor', torch.tensor, ([2.0, 2.0],)),
+       ('eye', torch.eye, (2,)), ('randn', torch.randn, ((2, 2),)),
+       ('rand', torch.rand, ((2, 2),)), ('full', torch.full, ((2, 2), 0)))
+   def test_requires_grad(self, func, args):
+     x = func(*args, requires_grad=True, device='jax')
      self.assertEqual(x.requires_grad, True)
 
 
{torchax-0.0.5 → torchax-0.0.6}/test/test_interop.py
@@ -2,9 +2,10 @@ import functools
  import torch
  import unittest
  import torchax
- from torchax import interop, jax_device
+ from torchax import interop
  import torchax
  import jax
+ import jax.numpy as jnp
 
 
  def is_tpu_available():
@@ -142,7 +143,7 @@ class InteropTest(unittest.TestCase):
      self.assertEqual(e.jax_device.platform, "cpu")
      self.assertEqual(e.device.type, "jax")
 
-     with jax_device("cpu"):
+     with jax.default_device(jax.devices("cpu")[0]):
        # move torch.tensor to torchax.tensor CPU
        b = a.to("jax")
        self.assertEqual(b.jax_device.platform, "cpu")
@@ -150,26 +151,21 @@
 
 
      if is_tpu_available():
        # move torch.tensor to torchax.tensor TPU
-       with jax_device("tpu"):
+       with jax.default_device(jax.local_devices("tpu")[0]):
          c = a.to("jax")
          self.assertEqual(c.jax_device.platform, "tpu")
          self.assertEqual(c.device.type, "jax")
 
-       # move torchax.tensor on CPU to TPU
-       with jax_device("tpu"):
-         self.assertEqual(b.jax_device.platform, "cpu")
-         self.assertEqual(c.device.type, "jax")
-         c = b.to("jax")
-         self.assertEqual(c.jax_device.platform, "tpu")
-         self.assertEqual(c.device.type, "jax")
-
-       # move torchax.tensor on TPU to CPU
-       with jax_device("cpu"):
-         self.assertEqual(c.jax_device.platform, "tpu")
-         self.assertEqual(c.device.type, "jax")
-         d = c.to("jax")
-         self.assertEqual(d.jax_device.platform, "cpu")
-         self.assertEqual(d.device.type, "jax")
+   def test_torch_jax_view_dtype(self):
+     dtype = torch.float32
+     self.assertEqual(interop.jax_view(dtype), jnp.float32.dtype)
+     self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
+     dtype = torch.bfloat16
+     self.assertEqual(interop.jax_view(dtype), jnp.bfloat16.dtype)
+     self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
+     dtype = torch.int32
+     self.assertEqual(interop.jax_view(dtype), jnp.int32.dtype)
+     self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
 
 
  if __name__ == '__main__':
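Note: the torchax.jax_device context manager is removed in 0.0.6 (see the torchax/__init__.py hunks below), and the updated test controls placement with JAX's own default-device context instead. A sketch of the replacement pattern; the call to torchax.enable_globally() is taken from the new test_misc.py and is assumed to be the way the 'jax' device string gets registered:

    import jax
    import torch
    import torchax

    torchax.enable_globally()
    a = torch.ones(3, 3)
    with jax.default_device(jax.devices("cpu")[0]):
      b = a.to("jax")  # torchax tensor backed by a CPU jax.Array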
{torchax-0.0.5 → torchax-0.0.6}/test/test_jittable_module.py
@@ -34,6 +34,25 @@ class JittableModuleTest(unittest.TestCase):
     assert isinstance(JittableMoreAwesomeModel, EvenMoreAwesomeModel)
     assert not isinstance(JittableMoreAwesomeModel, MyAwesomeModel)
 
+   def test_functional_call_callable(self):
+
+     def outer_function(model, x):
+       return x + 1
+
+     model = MyAwesomeModel()
+     jittable_module = interop.JittableModule(model)
+
+     # Check if the jittable module can be called like a function
+     input_tensor = torch.randn(1, 3, 224, 224)
+     expected_output = input_tensor + 1
+
+     output = jittable_module.functional_call(outer_function,
+                                              jittable_module.params,
+                                              jittable_module.buffers,
+                                              input_tensor)
+
+     assert torch.equal(output, expected_output)
+
 
  if __name__ == '__main__':
    unittest.main()
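Note: a short sketch of what the new test exercises. The callable passed to functional_call receives the module followed by the remaining arguments; loss_fn here is a hypothetical example, and it assumes functional_call binds the given params/buffers to the module (as torch.func.functional_call does) before invoking the callable:

    import torch
    from torchax import interop

    m = torch.nn.Linear(4, 4)
    jm = interop.JittableModule(m)

    def loss_fn(model, x):
      # Arbitrary user function run against the module's captured state.
      return model(x).sum()

    out = jm.functional_call(loss_fn, jm.params, jm.buffers, torch.randn(2, 4))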
{torchax-0.0.5 → torchax-0.0.6}/test/test_libraries.py
@@ -54,7 +54,6 @@ class LibraryTest(unittest.TestCase):
 
    def setUp(self):
      torch.manual_seed(0)
-     torchax.default_env().config.use_torch_native_for_cpu_tensor = False
 
    def test_basic_sdpa_library(self):
 
torchax-0.0.6/test/test_misc.py
@@ -0,0 +1,22 @@
+ import unittest
+ import torch
+ import torchax
+
+
+ class MiscTest(unittest.TestCase):
+
+   @classmethod
+   def setUpClass(cls):
+     torchax.enable_globally()
+
+   def test_mixed_tensor_math_with_scalar(self):
+     a = torch.tensor(2)
+     b = torch.ones((2, 2), device='jax')
+     c = a * b
+     self.assertTrue(
+         torch.allclose(c.cpu(),
+                        torch.tensor([[2, 2], [2, 2]], dtype=torch.float32)))
+
+
+ if __name__ == '__main__':
+   unittest.main()
torchax-0.0.6/test/test_mutations.py
@@ -0,0 +1,52 @@
+ import unittest
+ import torchax
+ import torch
+ from torch.testing._internal.common_utils import TestCase
+
+
+ class TestMutations(TestCase):
+
+   def setUp(self):
+     self.env = torchax.tensor.Environment()
+     self.env.config.debug_print_each_op = True
+
+   def test_add(self):
+     with self.env:
+       x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
+       y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
+       x.add_(y)
+       torch.testing.assert_close(x.cpu(),
+                                  torch.tensor([5, 7, 9], dtype=torch.int32))
+
+   def test_sub(self):
+     with self.env:
+       x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
+       y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
+       x.sub_(y)
+       torch.testing.assert_close(x.cpu(),
+                                  torch.tensor([-3, -3, -3], dtype=torch.int32))
+
+   def test_mul(self):
+     with self.env:
+       x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
+       y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
+
+       x.mul_(y)
+       torch.testing.assert_close(x.cpu(),
+                                  torch.tensor([4, 10, 18], dtype=torch.int32))
+
+   def test_index_copy(self):
+     with self.env:
+       x = torch.zeros(5, 3, device='jax')
+       t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+                        device='jax',
+                        dtype=torch.float)
+       index = torch.tensor([0, 4, 2], device='jax')
+       x.index_copy_(0, index, t)
+       expected = torch.tensor([[1., 2., 3.], [0., 0., 0.], [7., 8., 9.],
+                                [0., 0., 0.], [4., 5., 6.]])
+       torch.testing.assert_close(x.cpu(), expected)
+
+
+ if __name__ == '__main__':
+   unittest.main()
{torchax-0.0.5 → torchax-0.0.6}/test/test_ops.py
@@ -140,6 +140,8 @@ def run_export_and_compare(testcase,
    with testcase.subTest("torchax_eval"):
      input2, args2, kwargs2 = testcase.env.to_xla(
          (sample_input.input, sample_input.args, sample_input.kwargs))
+     if 'device' in kwargs2:
+       kwargs2['device'] = 'jax'
      with testcase.env:
        res2 = func(input2, *args2, **kwargs2)
        res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
@@ -186,13 +188,8 @@ class TestOpInfo(TestCase):
      self.env = torchax.default_env()
      torchax.enable_accuracy_mode()
      #self.env.config.debug_accuracy_for_each_op = True
-     self.env.config.debug_print_each_op = True
+     self.env.config.debug_print_each_op = False
      torch.manual_seed(0)
-     self.old_var = self.env.config.use_torch_native_for_cpu_tensor
-     self.env.config.use_torch_native_for_cpu_tensor = False
-
-   def tearDown(self):
-     self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
    # Replaces all values in the input torch_tensor that are less than the given threshold
    # with the threshold value itself.
{torchax-0.0.5 → torchax-0.0.6}/test/test_unbounded_dynamism.py
@@ -53,13 +53,8 @@ def wrap_func_as_nn_module(f):
  class UnboundedDynamismExportTest(unittest.TestCase):
 
    def setUp(self):
-     self.env = torchax.default_env()
-     self.env.config.use_torch_native_for_cpu_tensor = False
      torchax.enable_accuracy_mode()
 
-   def tearDown(self):
-     self.env.config.use_torch_native_for_cpu_tensor = True
-
    def test_add(self):
      args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
      dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),)
{torchax-0.0.5 → torchax-0.0.6}/torchax/__init__.py
@@ -6,10 +6,9 @@ import os
  import torch
  from torch.utils import _pytree as pytree
  from torchax import tensor
- from torchax import distributed  # noqa: F401
  from contextlib import contextmanager
 
- __version__ = "0.0.5"
+ __version__ = "0.0.6"
  VERSION = __version__
 
  __all__ = [
@@ -50,10 +49,11 @@ def extract_jax(mod: torch.nn.Module, env=None):
    states = env.t2j_copy(states)
 
    #@jax.jit
-   def jax_func(states, inputs):
-     (states, inputs) = env.j2t_iso((states, inputs))
+   def jax_func(states, args, kwargs=None):
+     (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
      with env:
-       res = torch.func.functional_call(mod, states, inputs, tie_weights=False)
+       res = torch.func.functional_call(
+           mod, states, args, kwargs, tie_weights=False)
      return env.t2j_iso(res)
 
    return states, jax_func
@@ -81,11 +81,6 @@ def disable_temporarily():
 
  torch.utils.rename_privateuse1_backend('jax')
  unsupported_dtype = [torch.quint8]
- torch.utils.generate_methods_for_privateuse1_backend(
-     for_tensor=True,
-     for_module=True,
-     for_storage=True,
-     unsupported_dtype=unsupported_dtype)
 
  import jax
  import torchax.device_module
@@ -129,34 +124,3 @@ def compile(fn, options: Optional[CompileOptions] = None):
      raise RuntimeError('dynamo mode is not supported yet')
    elif options.mode == 'export':
      raise RuntimeError('export mode is not supported yet')
-
-
- @contextmanager
- def jax_device(target_device: str, env: tensor.Environment | None = None):
-   """
-   to("jax") cannot differentiate the device/platform (cpu vs tpu).
-   Use this context manager to control jax array's storage device
-
-   Examples:
-
-   a = torch.ones(3, 3)
-
-   with jax_device("cpu"):
-     b = a.to("jax")
-
-   with jax_device("tpu"):
-     c = a.to("jax")
-
-   with jax_device("tpu"):
-     c = b.to("jax")
-
-   """
-   if env is None:
-     env = default_env()
-
-   prev_target_device = env.target_device
-   try:
-     env.target_device = target_device
-     yield env
-   finally:
-     env.target_device = prev_target_device
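Note: the function returned by extract_jax now takes positional args as a tuple and an optional kwargs dict instead of a single inputs argument. A hedged sketch of the new calling convention; the Linear model and shapes are illustrative only:

    import jax.numpy as jnp
    import torch
    import torchax

    model = torch.nn.Linear(4, 2)
    states, jax_func = torchax.extract_jax(model)
    # args is a tuple of jax arrays; kwargs is an optional third argument.
    out = jax_func(states, (jnp.ones((3, 4)),))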
{torchax-0.0.5 → torchax-0.0.6}/torchax/amp.py
@@ -61,9 +61,8 @@ def autocast(device, dtype=torch.bfloat16, env=None):
    if env is None:
      import torchax
      env = torchax.default_env()
-   env.autocast_dtype, old = dtype, env.autocast_dtype
-   yield
-   env.autocast_dtype = old
+   with env.override_property(autocast_dtype=dtype):
+     yield
 
 
  # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
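Note: autocast now delegates to env.override_property, which restores the previous autocast dtype when the block exits. Assuming autocast keeps its context-manager form, usage itself is unchanged; a sketch:

    import torch
    import torchax
    from torchax import amp

    env = torchax.default_env()
    with amp.autocast('jax', dtype=torch.bfloat16, env=env):
      # ops dispatched here see env.autocast_dtype == torch.bfloat16
      ...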
{torchax-0.0.5 → torchax-0.0.6}/torchax/config.py
@@ -10,6 +10,11 @@ class Configuration:
 
    use_int32_for_index: bool = False
 
+   # normally, math between CPU torch.Tensor with torchax.Tensor is not
+   # allowed. However, if that torch.Tensor happens to be scalar, then we
+   # can use scalar * tensor math to handle it
+   allow_mixed_math_with_scalar_tensor: bool = True
+
    # If true, we will convert Views into torchax.Tensors eagerly
    force_materialize_views: bool = False
 
@@ -22,5 +27,4 @@ class Configuration:
 
    # device
    treat_cuda_as_jax_device: bool = True
-   use_torch_native_for_cpu_tensor: bool = True
    internal_respect_torch_return_dtypes: bool = False
torchax-0.0.6/torchax/configuration.py
@@ -0,0 +1,30 @@
+ import dataclasses
+
+
+ @dataclasses.dataclass
+ class Configuration:
+   debug_print_each_op: bool = False
+   debug_accuracy_for_each_op: bool = False
+   debug_mixed_tensor: bool = False
+   debug_print_each_op_operands: bool = False
+
+   use_int32_for_index: bool = False
+
+   # normally, math between CPU torch.Tensor with torchax.Tensor is not
+   # allowed. However, if that torch.Tensor happens to be scalar, then we
+   # can use scalar * tensor math to handle it
+   allow_mixed_math_with_scalar_tensor: bool = True
+
+   # If true, we will convert Views into torchax.Tensors eagerly
+   force_materialize_views: bool = False
+
+   # Use DLPack for converting jax.Arrays <-> and torch.Tensor
+   use_dlpack_for_data_conversion: bool = False
+
+   # Flash attention
+   use_tpu_flash_attention: bool = False
+   shmap_flash_attention: bool = False
+
+   # device
+   treat_cuda_as_jax_device: bool = True
+   internal_respect_torch_return_dtypes: bool = False
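Note: these flags are plain dataclass fields, and the new tests flip them on the environment's config object. A sketch, assuming the default environment exposes the Configuration instance as env.config (as the tests in this release do):

    import torchax

    env = torchax.default_env()
    env.config.debug_print_each_op = True                   # log every dispatched op
    env.config.allow_mixed_math_with_scalar_tensor = True   # permit scalar CPU tensors in mixed math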
{torchax-0.0.5 → torchax-0.0.6}/torchax/device_module.py
@@ -1,3 +1,6 @@
+ import torch
+
+
  def _is_in_bad_fork():
    return False
 
@@ -24,3 +27,7 @@ def is_available():
 
  def current_device():
    return 0
+
+
+ def get_amp_supported_dtype():
+   return [torch.float16, torch.bfloat16]
torchax-0.0.6/torchax/environment.py
@@ -0,0 +1 @@
+