torchax 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (144)
  1. {torchax-0.0.4 → torchax-0.0.5}/.gitignore +10 -6
  2. torchax-0.0.5/PKG-INFO +307 -0
  3. torchax-0.0.5/README.md +242 -0
  4. torchax-0.0.5/dev-requirements.txt +5 -0
  5. {torchax-0.0.4 → torchax-0.0.5}/docs/fixing_op_info_test.md +10 -10
  6. {torchax-0.0.4 → torchax-0.0.5}/docs/how_it_works.md +7 -7
  7. {torchax-0.0.4 → torchax-0.0.5}/docs/ops_registry.md +3 -2
  8. {torchax-0.0.4 → torchax-0.0.5}/docs/support_a_new_model.md +5 -5
  9. {torchax-0.0.4 → torchax-0.0.5}/docs/torch_dispatch/README.md +1 -1
  10. {torchax-0.0.4 → torchax-0.0.5}/docs/torch_dispatch/example.py +18 -16
  11. {torchax-0.0.4 → torchax-0.0.5}/docs/torch_dispatch/run_env.py +0 -1
  12. {torchax-0.0.4 → torchax-0.0.5}/docs/torch_xla2_dynamo.md +13 -13
  13. {torchax-0.0.4 → torchax-0.0.5}/docs/understand_jax_jit/jax_grad.py +12 -7
  14. {torchax-0.0.4 → torchax-0.0.5}/docs/understand_jax_jit/jax_jit.py +22 -23
  15. {torchax-0.0.4 → torchax-0.0.5}/docs/understand_jax_jit/torch_module.py +12 -17
  16. {torchax-0.0.4 → torchax-0.0.5}/examples/README.md +5 -5
  17. torchax-0.0.5/examples/_diffusion.py +106 -0
  18. {torchax-0.0.4 → torchax-0.0.5}/examples/_grad_of_attention.py +28 -25
  19. torchax-0.0.5/examples/basic_training.py +195 -0
  20. {torchax-0.0.4 → torchax-0.0.5}/examples/basic_training_jax.py +54 -47
  21. {torchax-0.0.4 → torchax-0.0.5}/examples/eager_mode.py +14 -14
  22. torchax-0.0.5/examples/lightning_training.py +82 -0
  23. {torchax-0.0.4 → torchax-0.0.5}/examples/torchbench_models/BERT_pytorch.py +13 -10
  24. {torchax-0.0.4 → torchax-0.0.5}/examples/train_gpt/train_ddp.py +16 -23
  25. {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama/README.md +13 -13
  26. torchax-0.0.5/examples/train_llama/model.py +510 -0
  27. torchax-0.0.5/examples/train_llama/train_llama_lightning.py +307 -0
  28. torchax-0.0.5/examples/train_llama/utils.py +318 -0
  29. {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/README.md +20 -20
  30. torchax-0.0.5/examples/train_llama_torchtitan/helper.py +37 -0
  31. {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/splash_attn.py +32 -28
  32. torchax-0.0.5/examples/train_llama_torchtitan/train_llama.py +385 -0
  33. {torchax-0.0.4 → torchax-0.0.5}/pyproject.toml +6 -10
  34. torchax-0.0.5/repro1.py +62 -0
  35. torchax-0.0.5/temp +444 -0
  36. torchax-0.0.5/test/gemma/config.py +83 -0
  37. torchax-0.0.5/test/gemma/model.py +549 -0
  38. torchax-0.0.5/test/gemma/test_gemma.py +82 -0
  39. torchax-0.0.5/test/gemma/tokenizer.py +48 -0
  40. torchax-0.0.5/test/llama/test_llama.py +111 -0
  41. torchax-0.0.5/test/moe/model.py +307 -0
  42. torchax-0.0.5/test/moe/moe_test.py +68 -0
  43. torchax-0.0.5/test/test_amp.py +38 -0
  44. {torchax-0.0.4 → torchax-0.0.5}/test/test_context.py +23 -26
  45. {torchax-0.0.4 → torchax-0.0.5}/test/test_conv.py +9 -7
  46. {torchax-0.0.4 → torchax-0.0.5}/test/test_core_aten_ops.py +49 -16
  47. {torchax-0.0.4 → torchax-0.0.5}/test/test_exports.py +60 -36
  48. torchax-0.0.5/test/test_flax.py +97 -0
  49. torchax-0.0.5/test/test_functions.py +97 -0
  50. torchax-0.0.5/test/test_image.py +68 -0
  51. torchax-0.0.5/test/test_interop.py +176 -0
  52. torchax-0.0.5/test/test_jittable_module.py +39 -0
  53. {torchax-0.0.4 → torchax-0.0.5}/test/test_libraries.py +19 -14
  54. {torchax-0.0.4 → torchax-0.0.5}/test/test_ops.py +87 -85
  55. {torchax-0.0.4 → torchax-0.0.5}/test/test_symbolic_shapes.py +18 -11
  56. {torchax-0.0.4 → torchax-0.0.5}/test/test_tf_integration.py +3 -4
  57. {torchax-0.0.4 → torchax-0.0.5}/test/test_train.py +7 -9
  58. {torchax-0.0.4 → torchax-0.0.5}/test/test_unbounded_dynamism.py +9 -1
  59. torchax-0.0.5/test/test_util.py +118 -0
  60. torchax-0.0.5/test/test_view.py +385 -0
  61. torchax-0.0.5/test-requirements.txt +10 -0
  62. {torchax-0.0.4 → torchax-0.0.5}/test_dist/test_distributed.py +34 -35
  63. torchax-0.0.5/test_dist/test_mesh_util.py +51 -0
  64. {torchax-0.0.4 → torchax-0.0.5}/torchax/CONTRIBUTING.md +2 -2
  65. {torchax-0.0.4 → torchax-0.0.5}/torchax/__init__.py +57 -19
  66. torchax-0.0.5/torchax/amp.py +333 -0
  67. torchax-0.0.5/torchax/config.py +26 -0
  68. torchax-0.0.5/torchax/decompositions.py +776 -0
  69. {torchax-0.0.4 → torchax-0.0.5}/torchax/device_module.py +7 -1
  70. {torchax-0.0.4 → torchax-0.0.5}/torchax/distributed.py +55 -60
  71. {torchax-0.0.4 → torchax-0.0.5}/torchax/export.py +26 -17
  72. torchax-0.0.5/torchax/flax.py +39 -0
  73. torchax-0.0.5/torchax/interop.py +343 -0
  74. torchax-0.0.5/torchax/mesh_util.py +211 -0
  75. {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jaten.py +1732 -1293
  76. {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jax_reimplement.py +23 -21
  77. {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jc10d.py +5 -4
  78. torchax-0.0.5/torchax/ops/jimage.py +113 -0
  79. {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jlibrary.py +9 -2
  80. {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jtorch.py +218 -75
  81. {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jtorchvision_nms.py +32 -43
  82. torchax-0.0.5/torchax/ops/mappings.py +139 -0
  83. torchax-0.0.5/torchax/ops/op_base.py +131 -0
  84. torchax-0.0.5/torchax/ops/ops_registry.py +55 -0
  85. torchax-0.0.5/torchax/tensor.py +699 -0
  86. {torchax-0.0.4 → torchax-0.0.5}/torchax/train.py +38 -41
  87. torchax-0.0.5/torchax/util.py +88 -0
  88. torchax-0.0.5/torchax/view.py +377 -0
  89. torchax-0.0.4/PKG-INFO +0 -279
  90. torchax-0.0.4/README.md +0 -211
  91. torchax-0.0.4/dev-requirements.txt +0 -4
  92. torchax-0.0.4/examples/_diffusion.py +0 -112
  93. torchax-0.0.4/examples/basic_training.py +0 -187
  94. torchax-0.0.4/examples/lightning_training.py +0 -77
  95. torchax-0.0.4/examples/train_llama/model.py +0 -449
  96. torchax-0.0.4/examples/train_llama/train_llama_lightning.py +0 -329
  97. torchax-0.0.4/examples/train_llama/utils.py +0 -314
  98. torchax-0.0.4/examples/train_llama_torchtitan/helper.py +0 -37
  99. torchax-0.0.4/examples/train_llama_torchtitan/train_llama.py +0 -351
  100. torchax-0.0.4/test/BUILD +0 -31
  101. torchax-0.0.4/test/gemma/config.py +0 -86
  102. torchax-0.0.4/test/gemma/model.py +0 -561
  103. torchax-0.0.4/test/gemma/test_gemma.py +0 -86
  104. torchax-0.0.4/test/gemma/tokenizer.py +0 -48
  105. torchax-0.0.4/test/llama/test_llama.py +0 -113
  106. torchax-0.0.4/test/moe/model.py +0 -260
  107. torchax-0.0.4/test/moe/moe_test.py +0 -75
  108. torchax-0.0.4/test/test_functions.py +0 -98
  109. torchax-0.0.4/test/test_interop.py +0 -47
  110. torchax-0.0.4/test-requirements.txt +0 -9
  111. torchax-0.0.4/torchax/config.py +0 -19
  112. torchax-0.0.4/torchax/decompositions.py +0 -308
  113. torchax-0.0.4/torchax/environment.py +0 -2
  114. torchax-0.0.4/torchax/interop.py +0 -209
  115. torchax-0.0.4/torchax/ops/mappings.py +0 -97
  116. torchax-0.0.4/torchax/ops/op_base.py +0 -104
  117. torchax-0.0.4/torchax/ops/ops_registry.py +0 -50
  118. torchax-0.0.4/torchax/tensor.py +0 -557
  119. {torchax-0.0.4 → torchax-0.0.5}/=2.3.0 +0 -0
  120. {torchax-0.0.4 → torchax-0.0.5}/LICENSE +0 -0
  121. {torchax-0.0.4 → torchax-0.0.5}/build_nightly.sh +0 -0
  122. {torchax-0.0.4 → torchax-0.0.5}/docs/dispatch.png +0 -0
  123. {torchax-0.0.4 → torchax-0.0.5}/examples/__init__.py +0 -0
  124. {torchax-0.0.4 → torchax-0.0.5}/examples/mnist_tpu.ipynb +0 -0
  125. {torchax-0.0.4 → torchax-0.0.5}/examples/requirements.txt +0 -0
  126. {torchax-0.0.4 → torchax-0.0.5}/examples/train_gpt/requirements.txt +0 -0
  127. {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama/__init__.py +0 -0
  128. {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/Dockerfile +0 -0
  129. {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/__init__.py +0 -0
  130. {torchax-0.0.4 → torchax-0.0.5}/format.sh +0 -0
  131. {torchax-0.0.4 → torchax-0.0.5}/test/__init__.py +0 -0
  132. /torchax-0.0.4/test/test_base.py → /torchax-0.0.5/test/base_test_util.py +0 -0
  133. {torchax-0.0.4 → torchax-0.0.5}/test/gemma/__init__.py +0 -0
  134. {torchax-0.0.4 → torchax-0.0.5}/test/llama/BUILD +0 -0
  135. {torchax-0.0.4 → torchax-0.0.5}/test/llama/__init__.py +0 -0
  136. {torchax-0.0.4 → torchax-0.0.5}/test/llama/llama_model.py +0 -0
  137. {torchax-0.0.4 → torchax-0.0.5}/test/llama/model_exportable.py +0 -0
  138. {torchax-0.0.4 → torchax-0.0.5}/test/moe/__init__.py +0 -0
  139. {torchax-0.0.4 → torchax-0.0.5}/test/test_mutations.py +0 -0
  140. {torchax-0.0.4 → torchax-0.0.5}/test_dist/README.md +0 -0
  141. {torchax-0.0.4 → torchax-0.0.5}/test_dist/__init__.py +0 -0
  142. {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/__init__.py +0 -0
  143. {torchax-0.0.4 → torchax-0.0.5}/torchax/tf_integration.py +0 -0
  144. {torchax-0.0.4 → torchax-0.0.5}/torchax/types.py +0 -0
{torchax-0.0.4 → torchax-0.0.5}/.gitignore CHANGED

@@ -6,28 +6,32 @@ torch_xla/pb/cpp/*
  torch_xla/version.py
  torch_xla/csrc/version.cpp
  */**/__pycache__
+ */**/*MNIST/
+ torchax/**/runs/
  *.swp
  *.pyc
  *.so

  # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
  #
- # Below files are not deleted by "setup.py clean".
+ # Files below are not deleted by "setup.py clean".

- # Visual Studio Code files
+ # Visual Studio Code files.
  .vs
  .vscode/

- # Files autogenerated by docs/docs_build.sh
+ # Files autogenerated by docs/docs_build.sh.
  /core
  /docs/src/*

- # Local terraform state
+ # Local terraform state.
  .terraform


- # Build system temporary files
+ # Bazel temporary files.
  bazel-*
+ MODULE.bazel
+ MODULE.bazel.lock

- # Clangd cache directory
+ # Clangd cache directory.
  .cache/*
torchax-0.0.5/PKG-INFO ADDED
@@ -0,0 +1,307 @@
Metadata-Version: 2.4
Name: torchax
Version: 0.0.5
Summary: torchax is a library for running Jax and PyTorch together
Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
License: BSD 3-Clause License

Copyright (c) 2023, pytorch-tpu

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
License-File: LICENSE
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Provides-Extra: cpu
Requires-Dist: jax[cpu]; extra == 'cpu'
Requires-Dist: jax[cpu]>=0.6.2; extra == 'cpu'
Provides-Extra: cuda
Requires-Dist: jax[cpu]>=0.6.2; extra == 'cuda'
Requires-Dist: jax[cuda12]; extra == 'cuda'
Provides-Extra: odml
Requires-Dist: jax[cpu]; extra == 'odml'
Requires-Dist: jax[cpu]>=0.6.2; extra == 'odml'
Provides-Extra: tpu
Requires-Dist: jax[cpu]>=0.6.2; extra == 'tpu'
Requires-Dist: jax[tpu]; extra == 'tpu'
Description-Content-Type: text/markdown

# torchax: Running PyTorch on TPU via JAX

**torchax** is a backend for PyTorch, allowing users to run
PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
graph-level interoperability between PyTorch and JAX.

This means that with **torchax** you can:
* Run PyTorch code on TPUs with as little as a 2-line code change.
* Call a JAX function from a PyTorch function, passing in `jax.Array`s.
* Call a PyTorch function from a JAX function, passing in `torch.Tensor`s.
* Use JAX features such as `jax.grad`, `optax`, and `GSPMD` to train a PyTorch
  model.
* Use a PyTorch model as a feature extractor together with a JAX model.
* And more.

## Install

First install torch CPU:

```bash
# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch
```

Then install JAX for the accelerator you want to use:

```bash
# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs (see the note below).
pip install -U jax
```

NOTE: if you want Metal support for Apple devices, install the
Metal version of JAX: https://developer.apple.com/metal/jax/

Finally, install torchax:

```bash
# Install pre-built torchax.
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
```

## Run a model

Now let's execute a model under torchax. We'll start with a simple feed-forward model.
In theory, we can use any instance of `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops:

```python
import torchax
torchax.enable_globally()
```

Then, we can use a `jax` device:

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
a `jax.Array`. You can inspect that JAX array with `res.jax()`.

## What is happening behind the scenes

We took the approach detailed in the
[new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py)
recipe by Alban (@albanD), using `jax.Array` for `raw_data`.

In other words, when a torch op is executed inside an `env` context manager,
which is enabled by `torchax.enable_globally()`, we swap out the
implementation of that op with JAX.

When a model's constructor runs, it calls tensor constructors such as
`torch.rand`, `torch.ones`, or `torch.zeros` to create its weights. When torchax
is enabled, these constructors create a `torchax.tensor.Tensor`, which
contains a `jax.Array`.

Then, each subsequent op extracts the `jax.Array`, calls the op's JAX
implementation, and wraps the result back into a `torchax.tensor.Tensor`.

See more at [how it works](docs/how_it_works.md) and
[ops registry](docs/ops_registry.md).

### Executing with jax.jit

The above script executes the model using eager-mode JAX as the backend. This
does allow executing torch models on TPUs, but it is often slower than what we can
achieve with `jax.jit`.

`jax.jit` is a function that transforms a JAX function (i.e. a function that takes JAX arrays
and returns JAX arrays) into a compiled (and thus faster) version of the same function.

We have made a `jax_jit` decorator that accomplishes the same for functions
that take and return `torch.Tensor`s. To use it, the first step is to create
a functional version of the model: this means the parameters should be passed in
as inputs instead of being attributes of the class:

```python
def model_func(param, inputs):
    return torch.func.functional_call(m, param, inputs)
```

Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model weights with `param` and then call the
model. This is roughly equivalent to:

```python
def model_func(param, inputs):
    m.load_state_dict(param)
    return m(*inputs)
```

Now, we can apply `jax_jit` to `model_func`:

```python
from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```

See more examples at [eager_mode.py](examples/eager_mode.py) and the
[examples folder](examples/).

To ease the idiom of creating a functional model and calling it with parameters,
we also created the `JittableModule` helper class. It lets us rewrite the
above as:

```python
from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)
```

The first time `m_jitted` is called, it triggers `jax.jit` to compile the
function for the given input shapes. Subsequent calls with the same input shapes
are fast because the compilation is cached.

## Citation

```
@software{torchax,
  author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleason},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}
```

# Maintainers & Contributors

This library is created and maintained by the PyTorch/XLA team at Google Cloud.

It has benefitted from many direct and indirect
contributions from outside the team. Many of them were made by
fellow Googlers under [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule);
others came from partner teams at Google and other companies.

Here is the list of contributors as of 2025-02-25.

```
Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)
```
torchax-0.0.5/README.md ADDED

@@ -0,0 +1,242 @@
# torchax: Running PyTorch on TPU via JAX

**torchax** is a backend for PyTorch, allowing users to run
PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
graph-level interoperability between PyTorch and JAX.

This means that with **torchax** you can:
* Run PyTorch code on TPUs with as little as a 2-line code change.
* Call a JAX function from a PyTorch function, passing in `jax.Array`s.
* Call a PyTorch function from a JAX function, passing in `torch.Tensor`s.
* Use JAX features such as `jax.grad`, `optax`, and `GSPMD` to train a PyTorch
  model.
* Use a PyTorch model as a feature extractor together with a JAX model.
* And more.

## Install

First install torch CPU:

```bash
# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch
```

Then install JAX for the accelerator you want to use:

```bash
# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs (see the note below).
pip install -U jax
```

NOTE: if you want Metal support for Apple devices, install the
Metal version of JAX: https://developer.apple.com/metal/jax/

Finally, install torchax:

```bash
# Install pre-built torchax.
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
```

## Run a model

Now let's execute a model under torchax. We'll start with a simple feed-forward model.
In theory, we can use any instance of `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops:

```python
import torchax
torchax.enable_globally()
```

Then, we can use a `jax` device:

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
a `jax.Array`. You can inspect that JAX array with `res.jax()`.
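As a quick, minimal sketch (building on `res` from the snippet above), the underlying `jax.Array` can be handed directly to ordinary JAX APIs:

```python
import jax.numpy as jnp

arr = res.jax()       # the jax.Array stored inside the torchax tensor
print(arr.shape)      # same shape as the torch output above
print(jnp.mean(arr))  # it works with any JAX function
```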
## What is happening behind the scenes

We took the approach detailed in the
[new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py)
recipe by Alban (@albanD), using `jax.Array` for `raw_data`.

In other words, when a torch op is executed inside an `env` context manager,
which is enabled by `torchax.enable_globally()`, we swap out the
implementation of that op with JAX.

When a model's constructor runs, it calls tensor constructors such as
`torch.rand`, `torch.ones`, or `torch.zeros` to create its weights. When torchax
is enabled, these constructors create a `torchax.tensor.Tensor`, which
contains a `jax.Array`.

Then, each subsequent op extracts the `jax.Array`, calls the op's JAX
implementation, and wraps the result back into a `torchax.tensor.Tensor`.

See more at [how it works](docs/how_it_works.md) and
[ops registry](docs/ops_registry.md).
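To make that dispatch concrete, here is a minimal sketch that uses only the APIs already shown above (the shapes are arbitrary):

```python
import torch
import torchax

torchax.enable_globally()

# Tensor constructors now return torchax.tensor.Tensor objects
# whose raw data is a jax.Array.
w = torch.ones(4, 4, device='jax')
x = torch.randn(2, 4, device='jax')

# Each op unwraps the jax.Array inputs, calls the registered JAX
# implementation, and wraps the result back into a torchax tensor.
y = x @ w
print(type(y))        # torchax.tensor.Tensor
print(type(y.jax()))  # a jax.Array
```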
### Executing with jax.jit

The above script executes the model using eager-mode JAX as the backend. This
does allow executing torch models on TPUs, but it is often slower than what we can
achieve with `jax.jit`.

`jax.jit` is a function that transforms a JAX function (i.e. a function that takes JAX arrays
and returns JAX arrays) into a compiled (and thus faster) version of the same function.
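For reference, this is what plain `jax.jit` looks like on a JAX-only function (a standard JAX usage sketch, independent of torchax):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) @ x.T

f_jitted = jax.jit(f)

x = jnp.ones((128, 128))
f_jitted(x)  # first call with this shape/dtype traces and compiles
f_jitted(x)  # subsequent calls reuse the cached compiled program
```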
We have made a `jax_jit` decorator that accomplishes the same for functions
that take and return `torch.Tensor`s. To use it, the first step is to create
a functional version of the model: this means the parameters should be passed in
as inputs instead of being attributes of the class:

```python
def model_func(param, inputs):
    return torch.func.functional_call(m, param, inputs)
```

Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model weights with `param` and then call the
model. This is roughly equivalent to:

```python
def model_func(param, inputs):
    m.load_state_dict(param)
    return m(*inputs)
```

Now, we can apply `jax_jit` to `model_func`:

```python
from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```

See more examples at [eager_mode.py](examples/eager_mode.py) and the
[examples folder](examples/).

To ease the idiom of creating a functional model and calling it with parameters,
we also created the `JittableModule` helper class. It lets us rewrite the
above as:

```python
from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)
```

The first time `m_jitted` is called, it triggers `jax.jit` to compile the
function for the given input shapes. Subsequent calls with the same input shapes
are fast because the compilation is cached.
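Putting the pieces together, a minimal usage sketch (reusing `MyModel` from earlier; the shapes are just an example):

```python
import torch
import torchax
from torchax.interop import JittableModule

torchax.enable_globally()

m = MyModel().to('jax')
m_jitted = JittableModule(m)

inputs = torch.randn(3, 3, 28, 28, device='jax')
out = m_jitted(inputs)  # first call with this shape: jax.jit traces and compiles
out = m_jitted(inputs)  # same shape again: hits the compilation cache
```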
## Citation

```
@software{torchax,
  author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleason},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}
```

# Maintainers & Contributors

This library is created and maintained by the PyTorch/XLA team at Google Cloud.

It has benefitted from many direct and indirect
contributions from outside the team. Many of them were made by
fellow Googlers under [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule);
others came from partner teams at Google and other companies.

Here is the list of contributors as of 2025-02-25.

```
Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)
```
torchax-0.0.5/dev-requirements.txt ADDED

@@ -0,0 +1,5 @@
-f https://download.pytorch.org/whl/torch
torch==2.7.1 ; sys_platform == 'darwin' # macOS
torch==2.7.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
flax==0.10.6
{torchax-0.0.4 → torchax-0.0.5}/docs/fixing_op_info_test.md CHANGED

@@ -17,7 +17,7 @@ Context:
  ### Remove one op from skiplist

  Open [test/test_ops.py](../test/test_ops.py) with your
- favorite text editor.
+ favorite text editor.
  Remove one line from the `skiplist` set.

  i.e.
@@ -49,7 +49,7 @@ For errors you might get after running test, there are two kind:
  Error gotten:

  ```
- (base) hanq-macbookpro:torchax hanq$ python test/test_ops.py
+ (base) hanq-macbookpro:torchax hanq$ python test/test_ops.py
  ...
  E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
  ```
@@ -60,10 +60,10 @@ From here we have 2 strategies for fixing this test:
  2. Add an implementation `aten::addbmm` operator using torch ops (this commonly known as "decompositions").

  Either way works for torchax. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of
- upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
+ upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
  so other projects can benefit from it.

- For illustration purposes, let's implement this op in Jax.
+ For illustration purposes, let's implement this op in Jax.

  (NOTE: this doesn't stop us from upstreaming a decomposition later if we want)
@@ -104,7 +104,7 @@ Please try to fix it by following these steps:

  ### First Impl

- To implement this op using jax ops, we first find what
+ To implement this op using jax ops, we first find what
  is the exact semantics in this page:
  https://pytorch.org/docs/stable/generated/torch.addbmm.html
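(For orientation: `torch.addbmm` is documented as `beta * input + alpha * sum_i(batch1[i] @ batch2[i])`. A JAX decomposition of that formula could look roughly like the sketch below; this is only an illustration, not necessarily the implementation the guide arrives at.)

```python
import jax.numpy as jnp

def addbmm_sketch(input, batch1, batch2, *, beta=1, alpha=1):
    # Batched matmul, then sum over the batch dimension.
    mm = jnp.einsum('bij,bjk->ik', batch1, batch2)
    return beta * input + alpha * mm
```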
 
@@ -124,7 +124,7 @@ Now running test again:
  python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
  ```

- (NOTE: the exact test command is printed out when we run
+ (NOTE: the exact test command is printed out when we run
  `pytest test/test_ops.py` so we can only run the failed test instead of running all tests.)

  We now see this error:
@@ -140,14 +140,14 @@ Traceback (most recent call last):
  AssertionError: False is not true
  ```

- This is telling me that our implementation did not produce
+ This is telling me that our implementation did not produce
  the same result as the ops in PyTorch.

  To debug this, let's figure out what exact input caused this.
- We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torchax/test/test_ops.py#L644), right before the diff. Here we can
+ We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torchax/test/test_ops.py#L644), right before the diff. Here we can
  inspect values of `res` and `res2`, as well as the `sample_input`.

- The sample input we get is
+ The sample input we get is
  ```
  SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
  [-5, 1, -9, 9, 1, -5, 6, 1, -4, -5],
@@ -212,7 +212,7 @@ SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
  [ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
  ```

- And the `res` from torch is
+ And the `res` from torch is

  ```
  tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],