torchax 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. Click here for more details.

Files changed (101) hide show
  1. torchax-0.0.4/.gitignore +33 -0
  2. torchax-0.0.4/=2.3.0 +9 -0
  3. torchax-0.0.4/LICENSE +28 -0
  4. torchax-0.0.4/PKG-INFO +279 -0
  5. torchax-0.0.4/README.md +211 -0
  6. torchax-0.0.4/build_nightly.sh +10 -0
  7. torchax-0.0.4/dev-requirements.txt +4 -0
  8. torchax-0.0.4/docs/dispatch.png +0 -0
  9. torchax-0.0.4/docs/fixing_op_info_test.md +254 -0
  10. torchax-0.0.4/docs/how_it_works.md +134 -0
  11. torchax-0.0.4/docs/ops_registry.md +40 -0
  12. torchax-0.0.4/docs/support_a_new_model.md +137 -0
  13. torchax-0.0.4/docs/torch_dispatch/README.md +39 -0
  14. torchax-0.0.4/docs/torch_dispatch/example.py +65 -0
  15. torchax-0.0.4/docs/torch_dispatch/run_env.py +12 -0
  16. torchax-0.0.4/docs/torch_xla2_dynamo.md +194 -0
  17. torchax-0.0.4/docs/understand_jax_jit/jax_grad.py +66 -0
  18. torchax-0.0.4/docs/understand_jax_jit/jax_jit.py +106 -0
  19. torchax-0.0.4/docs/understand_jax_jit/torch_module.py +126 -0
  20. torchax-0.0.4/examples/README.md +115 -0
  21. torchax-0.0.4/examples/__init__.py +0 -0
  22. torchax-0.0.4/examples/_diffusion.py +112 -0
  23. torchax-0.0.4/examples/_grad_of_attention.py +73 -0
  24. torchax-0.0.4/examples/basic_training.py +187 -0
  25. torchax-0.0.4/examples/basic_training_jax.py +131 -0
  26. torchax-0.0.4/examples/eager_mode.py +38 -0
  27. torchax-0.0.4/examples/lightning_training.py +77 -0
  28. torchax-0.0.4/examples/mnist_tpu.ipynb +647 -0
  29. torchax-0.0.4/examples/requirements.txt +3 -0
  30. torchax-0.0.4/examples/torchbench_models/BERT_pytorch.py +49 -0
  31. torchax-0.0.4/examples/train_gpt/requirements.txt +4 -0
  32. torchax-0.0.4/examples/train_gpt/train_ddp.py +147 -0
  33. torchax-0.0.4/examples/train_llama/README.md +194 -0
  34. torchax-0.0.4/examples/train_llama/__init__.py +0 -0
  35. torchax-0.0.4/examples/train_llama/model.py +449 -0
  36. torchax-0.0.4/examples/train_llama/train_llama_lightning.py +329 -0
  37. torchax-0.0.4/examples/train_llama/utils.py +314 -0
  38. torchax-0.0.4/examples/train_llama_torchtitan/Dockerfile +34 -0
  39. torchax-0.0.4/examples/train_llama_torchtitan/README.md +511 -0
  40. torchax-0.0.4/examples/train_llama_torchtitan/__init__.py +0 -0
  41. torchax-0.0.4/examples/train_llama_torchtitan/helper.py +37 -0
  42. torchax-0.0.4/examples/train_llama_torchtitan/splash_attn.py +97 -0
  43. torchax-0.0.4/examples/train_llama_torchtitan/train_llama.py +351 -0
  44. torchax-0.0.4/format.sh +4 -0
  45. torchax-0.0.4/pyproject.toml +57 -0
  46. torchax-0.0.4/test/BUILD +31 -0
  47. torchax-0.0.4/test/__init__.py +0 -0
  48. torchax-0.0.4/test/gemma/__init__.py +0 -0
  49. torchax-0.0.4/test/gemma/config.py +86 -0
  50. torchax-0.0.4/test/gemma/model.py +561 -0
  51. torchax-0.0.4/test/gemma/test_gemma.py +86 -0
  52. torchax-0.0.4/test/gemma/tokenizer.py +48 -0
  53. torchax-0.0.4/test/llama/BUILD +25 -0
  54. torchax-0.0.4/test/llama/__init__.py +0 -0
  55. torchax-0.0.4/test/llama/llama_model.py +310 -0
  56. torchax-0.0.4/test/llama/model_exportable.py +304 -0
  57. torchax-0.0.4/test/llama/test_llama.py +113 -0
  58. torchax-0.0.4/test/moe/__init__.py +0 -0
  59. torchax-0.0.4/test/moe/model.py +260 -0
  60. torchax-0.0.4/test/moe/moe_test.py +75 -0
  61. torchax-0.0.4/test/test_base.py +55 -0
  62. torchax-0.0.4/test/test_context.py +114 -0
  63. torchax-0.0.4/test/test_conv.py +80 -0
  64. torchax-0.0.4/test/test_core_aten_ops.py +4506 -0
  65. torchax-0.0.4/test/test_exports.py +148 -0
  66. torchax-0.0.4/test/test_functions.py +98 -0
  67. torchax-0.0.4/test/test_interop.py +47 -0
  68. torchax-0.0.4/test/test_libraries.py +83 -0
  69. torchax-0.0.4/test/test_mutations.py +36 -0
  70. torchax-0.0.4/test/test_ops.py +236 -0
  71. torchax-0.0.4/test/test_symbolic_shapes.py +95 -0
  72. torchax-0.0.4/test/test_tf_integration.py +52 -0
  73. torchax-0.0.4/test/test_train.py +60 -0
  74. torchax-0.0.4/test/test_unbounded_dynamism.py +670 -0
  75. torchax-0.0.4/test-requirements.txt +9 -0
  76. torchax-0.0.4/test_dist/README.md +4 -0
  77. torchax-0.0.4/test_dist/__init__.py +0 -0
  78. torchax-0.0.4/test_dist/test_distributed.py +155 -0
  79. torchax-0.0.4/torchax/CONTRIBUTING.md +38 -0
  80. torchax-0.0.4/torchax/__init__.py +124 -0
  81. torchax-0.0.4/torchax/config.py +19 -0
  82. torchax-0.0.4/torchax/decompositions.py +308 -0
  83. torchax-0.0.4/torchax/device_module.py +20 -0
  84. torchax-0.0.4/torchax/distributed.py +246 -0
  85. torchax-0.0.4/torchax/environment.py +2 -0
  86. torchax-0.0.4/torchax/export.py +236 -0
  87. torchax-0.0.4/torchax/interop.py +209 -0
  88. torchax-0.0.4/torchax/ops/__init__.py +10 -0
  89. torchax-0.0.4/torchax/ops/jaten.py +5197 -0
  90. torchax-0.0.4/torchax/ops/jax_reimplement.py +169 -0
  91. torchax-0.0.4/torchax/ops/jc10d.py +51 -0
  92. torchax-0.0.4/torchax/ops/jlibrary.py +73 -0
  93. torchax-0.0.4/torchax/ops/jtorch.py +425 -0
  94. torchax-0.0.4/torchax/ops/jtorchvision_nms.py +245 -0
  95. torchax-0.0.4/torchax/ops/mappings.py +97 -0
  96. torchax-0.0.4/torchax/ops/op_base.py +104 -0
  97. torchax-0.0.4/torchax/ops/ops_registry.py +50 -0
  98. torchax-0.0.4/torchax/tensor.py +557 -0
  99. torchax-0.0.4/torchax/tf_integration.py +119 -0
  100. torchax-0.0.4/torchax/train.py +120 -0
  101. torchax-0.0.4/torchax/types.py +12 -0
@@ -0,0 +1,33 @@
1
+ build/
2
+ dist/
3
+ *.egg-info/
4
+ torch_xla/lib/
5
+ torch_xla/pb/cpp/*
6
+ torch_xla/version.py
7
+ torch_xla/csrc/version.cpp
8
+ */**/__pycache__
9
+ *.swp
10
+ *.pyc
11
+ *.so
12
+
13
+ # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
14
+ #
15
+ # Below files are not deleted by "setup.py clean".
16
+
17
+ # Visual Studio Code files
18
+ .vs
19
+ .vscode/
20
+
21
+ # Files autogenerated by docs/docs_build.sh
22
+ /core
23
+ /docs/src/*
24
+
25
+ # Local terraform state
26
+ .terraform
27
+
28
+
29
+ # Build system temporary files
30
+ bazel-*
31
+
32
+ # Clangd cache directory
33
+ .cache/*
torchax-0.0.4/=2.3.0 ADDED
@@ -0,0 +1,9 @@
1
+ Requirement already satisfied: torch in /Users/hanq/git/qihqi/pytorch (2.1.0a0+gited11333)
2
+ Requirement already satisfied: filelock in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (3.12.0)
3
+ Requirement already satisfied: typing-extensions in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (4.5.0)
4
+ Requirement already satisfied: sympy in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (1.12)
5
+ Requirement already satisfied: networkx in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (3.1)
6
+ Requirement already satisfied: jinja2 in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (3.1.2)
7
+ Requirement already satisfied: fsspec in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (2023.5.0)
8
+ Requirement already satisfied: MarkupSafe>=2.0 in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from jinja2->torch) (2.1.2)
9
+ Requirement already satisfied: mpmath>=0.19 in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from sympy->torch) (1.3.0)
torchax-0.0.4/LICENSE ADDED
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, pytorch-tpu
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
torchax-0.0.4/PKG-INFO ADDED
@@ -0,0 +1,279 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchax
3
+ Version: 0.0.4
4
+ Summary: torchax is a library for running PyTorch on TPU
5
+ Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
6
+ Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
7
+ License: BSD 3-Clause License
8
+
9
+ Copyright (c) 2023, pytorch-tpu
10
+
11
+ Redistribution and use in source and binary forms, with or without
12
+ modification, are permitted provided that the following conditions are met:
13
+
14
+ 1. Redistributions of source code must retain the above copyright notice, this
15
+ list of conditions and the following disclaimer.
16
+
17
+ 2. Redistributions in binary form must reproduce the above copyright notice,
18
+ this list of conditions and the following disclaimer in the documentation
19
+ and/or other materials provided with the distribution.
20
+
21
+ 3. Neither the name of the copyright holder nor the names of its
22
+ contributors may be used to endorse or promote products derived from
23
+ this software without specific prior written permission.
24
+
25
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+ License-File: LICENSE
36
+ Classifier: Development Status :: 3 - Alpha
37
+ Classifier: Intended Audience :: Developers
38
+ Classifier: Intended Audience :: Education
39
+ Classifier: Intended Audience :: Science/Research
40
+ Classifier: License :: OSI Approved :: BSD License
41
+ Classifier: Programming Language :: Python :: 3.10
42
+ Classifier: Programming Language :: Python :: 3.11
43
+ Classifier: Programming Language :: Python :: 3.12
44
+ Classifier: Programming Language :: Python :: 3.13
45
+ Classifier: Topic :: Scientific/Engineering
46
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
47
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
48
+ Classifier: Topic :: Software Development
49
+ Classifier: Topic :: Software Development :: Libraries
50
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
51
+ Requires-Python: >=3.10
52
+ Provides-Extra: cpu
53
+ Requires-Dist: jax[cpu]; extra == 'cpu'
54
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'cpu'
55
+ Requires-Dist: tensorflow-cpu; extra == 'cpu'
56
+ Provides-Extra: cuda
57
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'cuda'
58
+ Requires-Dist: jax[cuda12]; extra == 'cuda'
59
+ Requires-Dist: tensorflow-cpu; extra == 'cuda'
60
+ Provides-Extra: odml
61
+ Requires-Dist: jax[cpu]; extra == 'odml'
62
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'odml'
63
+ Provides-Extra: tpu
64
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'tpu'
65
+ Requires-Dist: jax[tpu]; extra == 'tpu'
66
+ Requires-Dist: tensorflow-cpu; extra == 'tpu'
67
+ Description-Content-Type: text/markdown
68
+
69
+ # torchax
70
+
71
+ ## Install
72
+
73
+ Currently this is only source-installable. Requires Python version >= 3.10.
74
+
75
+ ### NOTE:
76
+
77
+ Please don't install torch-xla from instructions in
78
+ https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
79
+ In particular, the following are not needed:
80
+
81
+ * There is no need to build pytorch/pytorch from source.
82
+ * There is no need to clone pytorch/xla project inside of pytorch/pytorch
83
+ git checkout.
84
+
85
+
86
+ torchax and torch-xla have different installation instructions. Please follow
87
+ the instructions below from scratch (in a fresh venv / conda environment).
88
+
89
+
90
+ ### 1. Installing `torchax`
91
+
92
+ The following instructions assume you are in the `torchax` directory:
93
+
94
+ ```
95
+ Fork the repository
96
+ $ git clone https://github.com/<github_username>/xla.git
97
+ $ cd xla/torchax
98
+ ```
99
+
100
+
101
+ #### 1.0 (recommended) Make a virtualenv / conda env
102
+
103
+ If you are using VSCode, then [you can create a new environment from
104
+ UI](https://code.visualstudio.com/docs/python/environments). Select the
105
+ `dev-requirements.txt` when asked to install project dependencies.
106
+
107
+ Otherwise create a new environment from the command line.
108
+
109
+ ```bash
110
+ # Option 1: venv
111
+ python -m venv my_venv
112
+ source my_venv/bin/activate
113
+
114
+ # Option 2: conda
115
+ conda create --name <your_name> python=3.10
116
+ conda activate <your_name>
117
+
118
+ # Either way, install the dev requirements.
119
+ pip install -r dev-requirements.txt
120
+ ```
121
+
122
+ Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.
123
+
124
+ #### 1.1 Install this package
125
+
126
+ Install `torchax` from source for your platform:
127
+ ```bash
128
+ pip install -e .[cpu]
129
+ pip install -e .[cuda]
130
+ pip install -e .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
131
+ ```
132
+
133
+ #### 1.2 (optional) verify installation by running tests
134
+
135
+ ```bash
136
+ pip install -r test-requirements.txt
137
+ pytest test
138
+ ```
139
+
140
+ ## Run a model
141
+
142
+ Now let's execute a model under torchax. We'll start with a simple 3-layer model;
143
+ in theory, it can be any instance of `torch.nn.Module`.
144
+
145
+ ```python
146
+ import torch
147
+ import torch.nn as nn
148
+ import torch.nn.functional as F
149
+
150
+
151
+ class MyModel(nn.Module):
152
+ def __init__(self):
153
+ super().__init__()
154
+ self.fc1 = nn.Linear(28 * 28, 120)
155
+ self.fc2 = nn.Linear(120, 84)
156
+ self.fc3 = nn.Linear(84, 10)
157
+
158
+ def forward(self, x):
159
+ x = x.view(-1, 28 * 28)
160
+ x = F.relu(self.fc1(x))
161
+ x = F.relu(self.fc2(x))
162
+ x = self.fc3(x)
163
+ return x
164
+
165
+ m = MyModel()
166
+
167
+ # Execute this model using torch
168
+ inputs = torch.randn(3, 3, 28, 28)
169
+ print(m(inputs))
170
+ ```
171
+
172
+ This model `m` contains 2 parts: the weights that are stored inside of the model
173
+ and its submodules (`nn.Linear`).
174
+
175
+ To execute this model with `torchax`, we need to construct and run the model
176
+ under an `environment` that captures PyTorch ops and swaps them with their TPU equivalents.
177
+
178
+ To create this environment: use
179
+
180
+ ```python
181
+ import torchax
182
+
183
+ env = torchax.default_env()
184
+ ```
185
+ Then, execute the instantiation of the model, as well as evaluation of model,
186
+ using `env` as a context manager:
187
+
188
+ ```python
189
+ with env:
190
+ inputs = torch.randn(3, 3, 28, 28)
191
+ m = MyModel()
192
+ res = m(inputs)
193
+ print(type(res)) # outputs Tensor
194
+ ```
195
+
196
+ You can also enable the environment globally with
197
+ ```python
198
+ import torchax
199
+
200
+ torchax.enable_globally()
201
+ ```
202
+
203
+ Then everything afterwards is run with XLA.
204
+
205
+ ## What is happening behind the scenes
206
+
207
+ When a torch op is executed inside of `env` context manager, we can swap out the
208
+ implementation of that op with a version that runs on TPU.
209
+ When a model's constructor runs, it will call some tensor constructor, such as
210
+ `torch.rand`, `torch.ones` or `torch.zeros` etc to create its weights. Those
211
+ ops are captured by `env` too and placed directly on TPU.
212
+
213
+ See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
214
+
215
+ ### What if I created the model outside of `env`?
216
+
217
+ So if you have
218
+
219
+ ```
220
+ m = MyModel()
221
+ ```
222
+ outside of env, then regular torch ops will run when creating this model.
223
+ Then presumably the model's weights will be on CPU (as instances of `torch.Tensor`).
224
+
225
+ To move this model onto the XLA device, one can use the `env.to_xla()` function.
226
+
227
+ i.e.
228
+ ```
229
+ m2 = env.to_xla(m)
230
+ inputs = env.to_xla(inputs)
231
+
232
+ with env:
233
+ res = m2(inputs)
234
+ ```
235
+
236
+ NOTE that we also need to move inputs to xla using `.to_xla`.
237
+ `to_xla` works with all pytrees of `torch.Tensor`.
238
+
239
+
240
+ ### Executing with jax.jit
241
+
242
+ The above script will execute the model using eager mode Jax as backend. This
243
+ does allow executing torch models on TPU, but is often slower than what we can
244
+ achieve with `jax.jit`.
245
+
246
+ `jax.jit` is a function that transforms a JAX function (i.e. a function that takes JAX arrays
247
+ and returns JAX arrays) into the same function, but faster.
248
+
249
+ We have made the `jax_jit` decorator that accomplishes the same for functions
250
+ that take and return `torch.Tensor`. To use this, the first step is to create
251
+ a functional version of this model: this means the parameters should be passed in
252
+ as input instead of being attributes on class:
253
+
254
+
255
+ ```python
256
+
257
+ def model_func(param, inputs):
258
+ return torch.func.functional_call(m, param, inputs)
259
+
260
+ ```
261
+ Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
262
+ from PyTorch to replace the model
263
+ weights with `param`, then call the model. This is equivalent to:
264
+
265
+ ```python
266
+ def model_func(param, inputs):
267
+ m.load_state_dict(param)
268
+ return m(*inputs)
269
+ ```
270
+
271
+ Now, we can apply `jax_jit`
272
+
273
+ ```python
274
+ from torchax.interop import jax_jit
275
+ model_func_jitted = jax_jit(model_func)
276
+ print(model_func_jitted(new_state_dict, inputs))
277
+ ```
278
+
279
+ See more examples at [eager_mode.py](examples/eager_mode.py) and the [examples folder](examples/).
@@ -0,0 +1,211 @@
1
+ # torchax
2
+
3
+ ## Install
4
+
5
+ Currently this is only source-installable. Requires Python version >= 3.10.
6
+
7
+ ### NOTE:
8
+
9
+ Please don't install torch-xla from instructions in
10
+ https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
11
+ In particular, the following are not needed:
12
+
13
+ * There is no need to build pytorch/pytorch from source.
14
+ * There is no need to clone pytorch/xla project inside of pytorch/pytorch
15
+ git checkout.
16
+
17
+
18
+ torchax and torch-xla have different installation instructions. Please follow
19
+ the instructions below from scratch (in a fresh venv / conda environment).
20
+
21
+
22
+ ### 1. Installing `torchax`
23
+
24
+ The following instructions assume you are in the `torchax` directory:
25
+
26
+ ```
27
+ Fork the repository
28
+ $ git clone https://github.com/<github_username>/xla.git
29
+ $ cd xla/torchax
30
+ ```
31
+
32
+
33
+ #### 1.0 (recommended) Make a virtualenv / conda env
34
+
35
+ If you are using VSCode, then [you can create a new environment from
36
+ UI](https://code.visualstudio.com/docs/python/environments). Select the
37
+ `dev-requirements.txt` when asked to install project dependencies.
38
+
39
+ Otherwise create a new environment from the command line.
40
+
41
+ ```bash
42
+ # Option 1: venv
43
+ python -m venv my_venv
44
+ source my_venv/bin/activate
45
+
46
+ # Option 2: conda
47
+ conda create --name <your_name> python=3.10
48
+ conda activate <your_name>
49
+
50
+ # Either way, install the dev requirements.
51
+ pip install -r dev-requirements.txt
52
+ ```
53
+
54
+ Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.
55
+
56
+ #### 1.1 Install this package
57
+
58
+ Install `torchax` from source for your platform:
59
+ ```bash
60
+ pip install -e .[cpu]
61
+ pip install -e .[cuda]
62
+ pip install -e .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
63
+ ```
64
+
65
+ #### 1.2 (optional) verify installation by running tests
66
+
67
+ ```bash
68
+ pip install -r test-requirements.txt
69
+ pytest test
70
+ ```
71
+
72
+ ## Run a model
73
+
74
+ Now let's execute a model under torchax. We'll start with a simple 3-layer model;
75
+ in theory, it can be any instance of `torch.nn.Module`.
76
+
77
+ ```python
78
+ import torch
79
+ import torch.nn as nn
80
+ import torch.nn.functional as F
81
+
82
+
83
+ class MyModel(nn.Module):
84
+ def __init__(self):
85
+ super().__init__()
86
+ self.fc1 = nn.Linear(28 * 28, 120)
87
+ self.fc2 = nn.Linear(120, 84)
88
+ self.fc3 = nn.Linear(84, 10)
89
+
90
+ def forward(self, x):
91
+ x = x.view(-1, 28 * 28)
92
+ x = F.relu(self.fc1(x))
93
+ x = F.relu(self.fc2(x))
94
+ x = self.fc3(x)
95
+ return x
96
+
97
+ m = MyModel()
98
+
99
+ # Execute this model using torch
100
+ inputs = torch.randn(3, 3, 28, 28)
101
+ print(m(inputs))
102
+ ```
103
+
104
+ This model `m` contains 2 parts: the weights that are stored inside of the model
105
+ and its submodules (`nn.Linear`).
106
+
107
+ To execute this model with `torchax`, we need to construct and run the model
108
+ under an `environment` that captures PyTorch ops and swaps them with their TPU equivalents.
109
+
110
+ To create this environment: use
111
+
112
+ ```python
113
+ import torchax
114
+
115
+ env = torchax.default_env()
116
+ ```
117
+ Then, execute the instantiation of the model, as well as evaluation of model,
118
+ using `env` as a context manager:
119
+
120
+ ```python
121
+ with env:
122
+ inputs = torch.randn(3, 3, 28, 28)
123
+ m = MyModel()
124
+ res = m(inputs)
125
+ print(type(res)) # outputs Tensor
126
+ ```
127
+
128
+ You can also enable the environment globally with
129
+ ```python
130
+ import torchax
131
+
132
+ torchax.enable_globally()
133
+ ```
134
+
135
+ Then everything afterwards is run with XLA.
136
+
137
+ ## What is happening behind the scenes
138
+
139
+ When a torch op is executed inside of `env` context manager, we can swap out the
140
+ implementation of that op with a version that runs on TPU.
141
+ When a model's constructor runs, it will call some tensor constructor, such as
142
+ `torch.rand`, `torch.ones` or `torch.zeros` etc to create its weights. Those
143
+ ops are captured by `env` too and placed directly on TPU.
144
+
145
+ See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
146
+
147
+ ### What if I created the model outside of `env`?
148
+
149
+ So if you have
150
+
151
+ ```
152
+ m = MyModel()
153
+ ```
154
+ outside of env, then regular torch ops will run when creating this model.
155
+ Then presumably the model's weights will be on CPU (as instances of `torch.Tensor`).
156
+
157
+ To move this model onto the XLA device, one can use the `env.to_xla()` function.
158
+
159
+ i.e.
160
+ ```
161
+ m2 = env.to_xla(m)
162
+ inputs = env.to_xla(inputs)
163
+
164
+ with env:
165
+ res = m2(inputs)
166
+ ```
167
+
168
+ NOTE that we also need to move inputs to xla using `.to_xla`.
169
+ `to_xla` works with all pytrees of `torch.Tensor`.
170
+
171
+
172
+ ### Executing with jax.jit
173
+
174
+ The above script will execute the model using eager mode Jax as backend. This
175
+ does allow executing torch models on TPU, but is often slower than what we can
176
+ achieve with `jax.jit`.
177
+
178
+ `jax.jit` is a function that transforms a JAX function (i.e. a function that takes JAX arrays
179
+ and returns JAX arrays) into the same function, but faster.
180
+
181
+ We have made the `jax_jit` decorator that accomplishes the same for functions
182
+ that take and return `torch.Tensor`. To use this, the first step is to create
183
+ a functional version of this model: this means the parameters should be passed in
184
+ as input instead of being attributes on class:
185
+
186
+
187
+ ```python
188
+
189
+ def model_func(param, inputs):
190
+ return torch.func.functional_call(m, param, inputs)
191
+
192
+ ```
193
+ Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
194
+ from PyTorch to replace the model
195
+ weights with `param`, then call the model. This is equivalent to:
196
+
197
+ ```python
198
+ def model_func(param, inputs):
199
+ m.load_state_dict(param)
200
+ return m(*inputs)
201
+ ```
202
+
203
+ Now, we can apply `jax_jit`
204
+
205
+ ```python
206
+ from torchax.interop import jax_jit
207
+ model_func_jitted = jax_jit(model_func)
208
+ print(model_func_jitted(new_state_dict, inputs))
209
+ ```
210
+
211
+ See more examples at [eager_mode.py](examples/eager_mode.py) and the [examples folder](examples/).
@@ -0,0 +1,10 @@
1
+ #!/usr/bin/env bash
2
+ set -ex
3
+
4
+ NIGHTLY_VERSION=$(date '+%Y%m%d%H%M')
5
+
6
+ # Update the version to <version>.devYYYYMMDDHHMM in __init__.py
7
+ VERSION_UPDATE_PATTERN="s/^__version__\s*=\s*\"([^\"]+)\"/__version__ = \"\1.dev$NIGHTLY_VERSION\"/g;"
8
+ sed -r "$VERSION_UPDATE_PATTERN" torchax/__init__.py --in-place
9
+
10
+ hatch build -t wheel
@@ -0,0 +1,4 @@
1
+ -f https://download.pytorch.org/whl/torch
2
+ torch==2.5.1; sys_platform == 'darwin' # macOS
3
+ torch==2.5.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
4
+ ruff~=0.3.5
Binary file