torchax-0.0.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- torchax-0.0.4/.gitignore +33 -0
- torchax-0.0.4/=2.3.0 +9 -0
- torchax-0.0.4/LICENSE +28 -0
- torchax-0.0.4/PKG-INFO +279 -0
- torchax-0.0.4/README.md +211 -0
- torchax-0.0.4/build_nightly.sh +10 -0
- torchax-0.0.4/dev-requirements.txt +4 -0
- torchax-0.0.4/docs/dispatch.png +0 -0
- torchax-0.0.4/docs/fixing_op_info_test.md +254 -0
- torchax-0.0.4/docs/how_it_works.md +134 -0
- torchax-0.0.4/docs/ops_registry.md +40 -0
- torchax-0.0.4/docs/support_a_new_model.md +137 -0
- torchax-0.0.4/docs/torch_dispatch/README.md +39 -0
- torchax-0.0.4/docs/torch_dispatch/example.py +65 -0
- torchax-0.0.4/docs/torch_dispatch/run_env.py +12 -0
- torchax-0.0.4/docs/torch_xla2_dynamo.md +194 -0
- torchax-0.0.4/docs/understand_jax_jit/jax_grad.py +66 -0
- torchax-0.0.4/docs/understand_jax_jit/jax_jit.py +106 -0
- torchax-0.0.4/docs/understand_jax_jit/torch_module.py +126 -0
- torchax-0.0.4/examples/README.md +115 -0
- torchax-0.0.4/examples/__init__.py +0 -0
- torchax-0.0.4/examples/_diffusion.py +112 -0
- torchax-0.0.4/examples/_grad_of_attention.py +73 -0
- torchax-0.0.4/examples/basic_training.py +187 -0
- torchax-0.0.4/examples/basic_training_jax.py +131 -0
- torchax-0.0.4/examples/eager_mode.py +38 -0
- torchax-0.0.4/examples/lightning_training.py +77 -0
- torchax-0.0.4/examples/mnist_tpu.ipynb +647 -0
- torchax-0.0.4/examples/requirements.txt +3 -0
- torchax-0.0.4/examples/torchbench_models/BERT_pytorch.py +49 -0
- torchax-0.0.4/examples/train_gpt/requirements.txt +4 -0
- torchax-0.0.4/examples/train_gpt/train_ddp.py +147 -0
- torchax-0.0.4/examples/train_llama/README.md +194 -0
- torchax-0.0.4/examples/train_llama/__init__.py +0 -0
- torchax-0.0.4/examples/train_llama/model.py +449 -0
- torchax-0.0.4/examples/train_llama/train_llama_lightning.py +329 -0
- torchax-0.0.4/examples/train_llama/utils.py +314 -0
- torchax-0.0.4/examples/train_llama_torchtitan/Dockerfile +34 -0
- torchax-0.0.4/examples/train_llama_torchtitan/README.md +511 -0
- torchax-0.0.4/examples/train_llama_torchtitan/__init__.py +0 -0
- torchax-0.0.4/examples/train_llama_torchtitan/helper.py +37 -0
- torchax-0.0.4/examples/train_llama_torchtitan/splash_attn.py +97 -0
- torchax-0.0.4/examples/train_llama_torchtitan/train_llama.py +351 -0
- torchax-0.0.4/format.sh +4 -0
- torchax-0.0.4/pyproject.toml +57 -0
- torchax-0.0.4/test/BUILD +31 -0
- torchax-0.0.4/test/__init__.py +0 -0
- torchax-0.0.4/test/gemma/__init__.py +0 -0
- torchax-0.0.4/test/gemma/config.py +86 -0
- torchax-0.0.4/test/gemma/model.py +561 -0
- torchax-0.0.4/test/gemma/test_gemma.py +86 -0
- torchax-0.0.4/test/gemma/tokenizer.py +48 -0
- torchax-0.0.4/test/llama/BUILD +25 -0
- torchax-0.0.4/test/llama/__init__.py +0 -0
- torchax-0.0.4/test/llama/llama_model.py +310 -0
- torchax-0.0.4/test/llama/model_exportable.py +304 -0
- torchax-0.0.4/test/llama/test_llama.py +113 -0
- torchax-0.0.4/test/moe/__init__.py +0 -0
- torchax-0.0.4/test/moe/model.py +260 -0
- torchax-0.0.4/test/moe/moe_test.py +75 -0
- torchax-0.0.4/test/test_base.py +55 -0
- torchax-0.0.4/test/test_context.py +114 -0
- torchax-0.0.4/test/test_conv.py +80 -0
- torchax-0.0.4/test/test_core_aten_ops.py +4506 -0
- torchax-0.0.4/test/test_exports.py +148 -0
- torchax-0.0.4/test/test_functions.py +98 -0
- torchax-0.0.4/test/test_interop.py +47 -0
- torchax-0.0.4/test/test_libraries.py +83 -0
- torchax-0.0.4/test/test_mutations.py +36 -0
- torchax-0.0.4/test/test_ops.py +236 -0
- torchax-0.0.4/test/test_symbolic_shapes.py +95 -0
- torchax-0.0.4/test/test_tf_integration.py +52 -0
- torchax-0.0.4/test/test_train.py +60 -0
- torchax-0.0.4/test/test_unbounded_dynamism.py +670 -0
- torchax-0.0.4/test-requirements.txt +9 -0
- torchax-0.0.4/test_dist/README.md +4 -0
- torchax-0.0.4/test_dist/__init__.py +0 -0
- torchax-0.0.4/test_dist/test_distributed.py +155 -0
- torchax-0.0.4/torchax/CONTRIBUTING.md +38 -0
- torchax-0.0.4/torchax/__init__.py +124 -0
- torchax-0.0.4/torchax/config.py +19 -0
- torchax-0.0.4/torchax/decompositions.py +308 -0
- torchax-0.0.4/torchax/device_module.py +20 -0
- torchax-0.0.4/torchax/distributed.py +246 -0
- torchax-0.0.4/torchax/environment.py +2 -0
- torchax-0.0.4/torchax/export.py +236 -0
- torchax-0.0.4/torchax/interop.py +209 -0
- torchax-0.0.4/torchax/ops/__init__.py +10 -0
- torchax-0.0.4/torchax/ops/jaten.py +5197 -0
- torchax-0.0.4/torchax/ops/jax_reimplement.py +169 -0
- torchax-0.0.4/torchax/ops/jc10d.py +51 -0
- torchax-0.0.4/torchax/ops/jlibrary.py +73 -0
- torchax-0.0.4/torchax/ops/jtorch.py +425 -0
- torchax-0.0.4/torchax/ops/jtorchvision_nms.py +245 -0
- torchax-0.0.4/torchax/ops/mappings.py +97 -0
- torchax-0.0.4/torchax/ops/op_base.py +104 -0
- torchax-0.0.4/torchax/ops/ops_registry.py +50 -0
- torchax-0.0.4/torchax/tensor.py +557 -0
- torchax-0.0.4/torchax/tf_integration.py +119 -0
- torchax-0.0.4/torchax/train.py +120 -0
- torchax-0.0.4/torchax/types.py +12 -0
torchax-0.0.4/.gitignore
ADDED
@@ -0,0 +1,33 @@
+build/
+dist/
+*.egg-info/
+torch_xla/lib/
+torch_xla/pb/cpp/*
+torch_xla/version.py
+torch_xla/csrc/version.cpp
+*/**/__pycache__
+*.swp
+*.pyc
+*.so
+
+# BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
+#
+# Below files are not deleted by "setup.py clean".
+
+# Visual Studio Code files
+.vs
+.vscode/
+
+# Files autogenerated by docs/docs_build.sh
+/core
+/docs/src/*
+
+# Local terraform state
+.terraform
+
+
+# Build system temporary files
+bazel-*
+
+# Clangd cache directory
+.cache/*
torchax-0.0.4/=2.3.0
ADDED
@@ -0,0 +1,9 @@
+Requirement already satisfied: torch in /Users/hanq/git/qihqi/pytorch (2.1.0a0+gited11333)
+Requirement already satisfied: filelock in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (3.12.0)
+Requirement already satisfied: typing-extensions in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (4.5.0)
+Requirement already satisfied: sympy in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (1.12)
+Requirement already satisfied: networkx in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (3.1)
+Requirement already satisfied: jinja2 in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (3.1.2)
+Requirement already satisfied: fsspec in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from torch) (2023.5.0)
+Requirement already satisfied: MarkupSafe>=2.0 in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from jinja2->torch) (2.1.2)
+Requirement already satisfied: mpmath>=0.19 in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/torch310/lib/python3.10/site-packages (from sympy->torch) (1.3.0)
torchax-0.0.4/LICENSE
ADDED
@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright (c) 2023, pytorch-tpu
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
torchax-0.0.4/PKG-INFO
ADDED
@@ -0,0 +1,279 @@
+Metadata-Version: 2.4
+Name: torchax
+Version: 0.0.4
+Summary: torchax is a library for running PyTorch on TPU
+Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
+Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
+License: BSD 3-Clause License
+
+Copyright (c) 2023, pytorch-tpu
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+License-File: LICENSE
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Education
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Classifier: Topic :: Scientific/Engineering
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Scientific/Engineering :: Mathematics
+Classifier: Topic :: Software Development
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Requires-Python: >=3.10
+Provides-Extra: cpu
+Requires-Dist: jax[cpu]; extra == 'cpu'
+Requires-Dist: jax[cpu]>=0.4.30; extra == 'cpu'
+Requires-Dist: tensorflow-cpu; extra == 'cpu'
+Provides-Extra: cuda
+Requires-Dist: jax[cpu]>=0.4.30; extra == 'cuda'
+Requires-Dist: jax[cuda12]; extra == 'cuda'
+Requires-Dist: tensorflow-cpu; extra == 'cuda'
+Provides-Extra: odml
+Requires-Dist: jax[cpu]; extra == 'odml'
+Requires-Dist: jax[cpu]>=0.4.30; extra == 'odml'
+Provides-Extra: tpu
+Requires-Dist: jax[cpu]>=0.4.30; extra == 'tpu'
+Requires-Dist: jax[tpu]; extra == 'tpu'
+Requires-Dist: tensorflow-cpu; extra == 'tpu'
+Description-Content-Type: text/markdown
+
+# torchxla2
+
+## Install
+
+Currently this is only source-installable. Requires Python version >= 3.10.
+
+### NOTE:
+
+Please don't install torch-xla using the instructions in
+https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
+In particular, the following are not needed:
+
+* There is no need to build pytorch/pytorch from source.
+* There is no need to clone the pytorch/xla project inside of a
+  pytorch/pytorch git checkout.
+
+TorchXLA2 and torch-xla have different installation instructions; please follow
+the instructions below from scratch (in a fresh venv / conda environment).
+
+### 1. Installing `torchax`
+
+The following instructions assume you are in the `torchax` directory:
+
+```
+# Fork the repository, then:
+$ git clone https://github.com/<github_username>/xla.git
+$ cd xla/torchax
+```
+
+#### 1.0 (recommended) Make a virtualenv / conda env
+
+If you are using VSCode, then [you can create a new environment from the
+UI](https://code.visualstudio.com/docs/python/environments). Select
+`dev-requirements.txt` when asked to install project dependencies.
+
+Otherwise, create a new environment from the command line.
+
+```bash
+# Option 1: venv
+python -m venv my_venv
+source my_venv/bin/activate
+
+# Option 2: conda
+conda create --name <your_name> python=3.10
+conda activate <your_name>
+
+# Either way, install the dev requirements.
+pip install -r dev-requirements.txt
+```
+
+Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.
+
+#### 1.1 Install this package
+
+Install `torchax` from source, choosing the extra that matches your platform:
+```bash
+pip install -e .[cpu]
+pip install -e .[cuda]
+pip install -e .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+#### 1.2 (optional) verify installation by running tests
+
+```bash
+pip install -r test-requirements.txt
+pytest test
+```
+
+## Run a model
+
+Now let's execute a model under torchax. We'll start with a simple model of a
+few linear layers; in theory it can be any instance of `torch.nn.Module`.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(28 * 28, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = x.view(-1, 28 * 28)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+m = MyModel()
+
+# Execute this model using torch
+inputs = torch.randn(3, 3, 28, 28)
+print(m(inputs))
+```
+
+This model `m` has two parts: the weights, which are stored inside the model,
+and its submodules (`nn.Linear`).
+
+To execute this model with `torchax`, we need to construct and run the model
+under an `environment` that captures PyTorch ops and swaps them with their TPU
+equivalents.
+
+To create this environment, use:
+
+```python
+import torchax
+
+env = torchax.default_env()
+```
+Then execute the instantiation of the model, as well as the evaluation of the
+model, using `env` as a context manager:
+
+```python
+with env:
+    inputs = torch.randn(3, 3, 28, 28)
+    m = MyModel()
+    res = m(inputs)
+    print(type(res))  # outputs Tensor
+```
+
+You can also enable the environment globally with:
+```python
+import torchax
+
+torchax.enable_globally()
+```
+
+Everything afterwards then runs with XLA.
+
+## What is happening behind the scenes
+
+When a torch op is executed inside the `env` context manager, we can swap out
+the implementation of that op with a version that runs on TPU.
+When a model's constructor runs, it calls tensor constructors such as
+`torch.rand`, `torch.ones`, or `torch.zeros` to create its weights. Those ops
+are captured by `env` too, and the resulting tensors are placed directly on
+the TPU.
+
+See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
+
+### What if I created the model outside of `env`?
+
+If you have
+
+```
+m = MyModel()
+```
+outside of `env`, then regular torch ops will run when creating this model,
+and the model's weights will be on CPU (as instances of `torch.Tensor`).
+
+To move this model onto the XLA device, use the `env.to_xla()` function:
+
+```
+m2 = env.to_xla(m)
+inputs = env.to_xla(inputs)
+
+with env:
+    res = m2(inputs)
+```
+
+Note that we also need to move the inputs to XLA using `.to_xla`;
+`to_xla` works with all pytrees of `torch.Tensor`.
+
+### Executing with jax.jit
+
+The above script executes the model using eager-mode JAX as the backend. This
+does allow executing torch models on TPU, but it is often slower than what we
+can achieve with `jax.jit`.
+
+`jax.jit` is a function that transforms a JAX function (i.e. a function that
+takes JAX arrays and returns JAX arrays) into the same function, but faster.
+
+We have made a `jax_jit` decorator that accomplishes the same for functions
+that take and return `torch.Tensor`s. To use it, the first step is to create a
+functional version of the model, meaning the parameters should be passed in as
+inputs instead of being attributes on the class:
+
+```python
+def model_func(param, inputs):
+    return torch.func.functional_call(m, param, inputs)
+```
+Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
+from PyTorch to replace the model weights with `param`, then call the model.
+This is equivalent to:
+
+```python
+def model_func(param, inputs):
+    m.load_state_dict(param)
+    return m(*inputs)
+```
+
+Now we can apply `jax_jit`:
+
+```python
+from torchax.interop import jax_jit
+
+model_func_jitted = jax_jit(model_func)
+print(model_func_jitted(new_state_dict, inputs))
+```
+
+See more examples at [eager_mode.py](examples/eager_mode.py) and the
+[examples folder](examples/).
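A note on the "behind the scenes" section above: the interception it describes is easy to observe. The following is a minimal sketch, not taken from the package docs; it assumes a CPU install (`pip install -e .[cpu]`) and uses only the `torchax.default_env()` API shown in the README. That the wrapper tensor type lives in `torchax/tensor.py` is inferred from this diff's file list.

```python
# Minimal sketch (assumption: CPU install of torchax, no TPU attached).
import torch
import torchax

env = torchax.default_env()

with env:
    # torch.ones is one of the constructors the environment captures,
    # so `t` is backed by a JAX array instead of regular CPU storage.
    t = torch.ones(2, 2)
    print(type(t))  # torchax's wrapper Tensor (defined in torchax/tensor.py)

# Outside the environment, the same call produces an ordinary torch.Tensor.
u = torch.ones(2, 2)
print(type(u))  # <class 'torch.Tensor'>
```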
torchax-0.0.4/README.md
ADDED
@@ -0,0 +1,211 @@
(211 added lines, identical to the README embedded above as the long description in PKG-INFO)
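Putting the README's `to_xla`, `functional_call`, and `jax_jit` snippets together, here is one possible end-to-end sketch. The README leaves `new_state_dict` unspecified, so using `m.state_dict()` below is an assumption, as is running the jitted call under `env`; `MyModel` is the class defined in the README.

```python
# Sketch assembled from the README snippets above; assumes MyModel from the README.
import torch
import torchax
from torchax.interop import jax_jit

env = torchax.default_env()

m = MyModel()                      # created outside env, so weights start on CPU
inputs = torch.randn(3, 3, 28, 28)

m = env.to_xla(m)                  # move the weights to the XLA device
inputs = env.to_xla(inputs)

def model_func(param, inputs):
    # Run the module with `param` substituted for its registered weights.
    return torch.func.functional_call(m, param, inputs)

with env:
    params = m.state_dict()        # assumption: state_dict() as the param pytree
    model_func_jitted = jax_jit(model_func)
    print(model_func_jitted(params, inputs))
```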
torchax-0.0.4/build_nightly.sh
ADDED
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+set -ex
+
+NIGHTLY_VERSION=$(date '+%Y%m%d%H%M')
+
+# Update the version to <version>.devYYYYMMDDHHMM in __init__.py
+VERSION_UPDATE_PATTERN="s/^__version__\s*=\s*\"([^\"]+)\"/__version__ = \"\1.dev$NIGHTLY_VERSION\"/g;"
+sed -r "$VERSION_UPDATE_PATTERN" torchax/__init__.py --in-place
+
+hatch build -t wheel
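For readers unfamiliar with the `sed` invocation in `build_nightly.sh`: it rewrites the `__version__` string in `torchax/__init__.py` to a `.devYYYYMMDDHHMM` nightly version before building the wheel. Below is a Python re-implementation for illustration only; the input line and timestamp are hypothetical.

```python
# Illustration of build_nightly.sh's sed substitution; not part of the package.
import re

line = '__version__ = "0.0.4"'   # hypothetical line from torchax/__init__.py
nightly = "202501010101"         # hypothetical $(date '+%Y%m%d%H%M') output

print(re.sub(r'^__version__\s*=\s*"([^"]+)"',
             lambda m: f'__version__ = "{m.group(1)}.dev{nightly}"',
             line))
# -> __version__ = "0.0.4.dev202501010101"
```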
Binary file