torchax 0.0.4__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {torchax-0.0.4 → torchax-0.0.5}/.gitignore +10 -6
- torchax-0.0.5/PKG-INFO +307 -0
- torchax-0.0.5/README.md +242 -0
- torchax-0.0.5/dev-requirements.txt +5 -0
- {torchax-0.0.4 → torchax-0.0.5}/docs/fixing_op_info_test.md +10 -10
- {torchax-0.0.4 → torchax-0.0.5}/docs/how_it_works.md +7 -7
- {torchax-0.0.4 → torchax-0.0.5}/docs/ops_registry.md +3 -2
- {torchax-0.0.4 → torchax-0.0.5}/docs/support_a_new_model.md +5 -5
- {torchax-0.0.4 → torchax-0.0.5}/docs/torch_dispatch/README.md +1 -1
- {torchax-0.0.4 → torchax-0.0.5}/docs/torch_dispatch/example.py +18 -16
- {torchax-0.0.4 → torchax-0.0.5}/docs/torch_dispatch/run_env.py +0 -1
- {torchax-0.0.4 → torchax-0.0.5}/docs/torch_xla2_dynamo.md +13 -13
- {torchax-0.0.4 → torchax-0.0.5}/docs/understand_jax_jit/jax_grad.py +12 -7
- {torchax-0.0.4 → torchax-0.0.5}/docs/understand_jax_jit/jax_jit.py +22 -23
- {torchax-0.0.4 → torchax-0.0.5}/docs/understand_jax_jit/torch_module.py +12 -17
- {torchax-0.0.4 → torchax-0.0.5}/examples/README.md +5 -5
- torchax-0.0.5/examples/_diffusion.py +106 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/_grad_of_attention.py +28 -25
- torchax-0.0.5/examples/basic_training.py +195 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/basic_training_jax.py +54 -47
- {torchax-0.0.4 → torchax-0.0.5}/examples/eager_mode.py +14 -14
- torchax-0.0.5/examples/lightning_training.py +82 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/torchbench_models/BERT_pytorch.py +13 -10
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_gpt/train_ddp.py +16 -23
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama/README.md +13 -13
- torchax-0.0.5/examples/train_llama/model.py +510 -0
- torchax-0.0.5/examples/train_llama/train_llama_lightning.py +307 -0
- torchax-0.0.5/examples/train_llama/utils.py +318 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/README.md +20 -20
- torchax-0.0.5/examples/train_llama_torchtitan/helper.py +37 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/splash_attn.py +32 -28
- torchax-0.0.5/examples/train_llama_torchtitan/train_llama.py +385 -0
- {torchax-0.0.4 → torchax-0.0.5}/pyproject.toml +6 -10
- torchax-0.0.5/repro1.py +62 -0
- torchax-0.0.5/temp +444 -0
- torchax-0.0.5/test/gemma/config.py +83 -0
- torchax-0.0.5/test/gemma/model.py +549 -0
- torchax-0.0.5/test/gemma/test_gemma.py +82 -0
- torchax-0.0.5/test/gemma/tokenizer.py +48 -0
- torchax-0.0.5/test/llama/test_llama.py +111 -0
- torchax-0.0.5/test/moe/model.py +307 -0
- torchax-0.0.5/test/moe/moe_test.py +68 -0
- torchax-0.0.5/test/test_amp.py +38 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/test_context.py +23 -26
- {torchax-0.0.4 → torchax-0.0.5}/test/test_conv.py +9 -7
- {torchax-0.0.4 → torchax-0.0.5}/test/test_core_aten_ops.py +49 -16
- {torchax-0.0.4 → torchax-0.0.5}/test/test_exports.py +60 -36
- torchax-0.0.5/test/test_flax.py +97 -0
- torchax-0.0.5/test/test_functions.py +97 -0
- torchax-0.0.5/test/test_image.py +68 -0
- torchax-0.0.5/test/test_interop.py +176 -0
- torchax-0.0.5/test/test_jittable_module.py +39 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/test_libraries.py +19 -14
- {torchax-0.0.4 → torchax-0.0.5}/test/test_ops.py +87 -85
- {torchax-0.0.4 → torchax-0.0.5}/test/test_symbolic_shapes.py +18 -11
- {torchax-0.0.4 → torchax-0.0.5}/test/test_tf_integration.py +3 -4
- {torchax-0.0.4 → torchax-0.0.5}/test/test_train.py +7 -9
- {torchax-0.0.4 → torchax-0.0.5}/test/test_unbounded_dynamism.py +9 -1
- torchax-0.0.5/test/test_util.py +118 -0
- torchax-0.0.5/test/test_view.py +385 -0
- torchax-0.0.5/test-requirements.txt +10 -0
- {torchax-0.0.4 → torchax-0.0.5}/test_dist/test_distributed.py +34 -35
- torchax-0.0.5/test_dist/test_mesh_util.py +51 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/CONTRIBUTING.md +2 -2
- {torchax-0.0.4 → torchax-0.0.5}/torchax/__init__.py +57 -19
- torchax-0.0.5/torchax/amp.py +333 -0
- torchax-0.0.5/torchax/config.py +26 -0
- torchax-0.0.5/torchax/decompositions.py +776 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/device_module.py +7 -1
- {torchax-0.0.4 → torchax-0.0.5}/torchax/distributed.py +55 -60
- {torchax-0.0.4 → torchax-0.0.5}/torchax/export.py +26 -17
- torchax-0.0.5/torchax/flax.py +39 -0
- torchax-0.0.5/torchax/interop.py +343 -0
- torchax-0.0.5/torchax/mesh_util.py +211 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jaten.py +1732 -1293
- {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jax_reimplement.py +23 -21
- {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jc10d.py +5 -4
- torchax-0.0.5/torchax/ops/jimage.py +113 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jlibrary.py +9 -2
- {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jtorch.py +218 -75
- {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/jtorchvision_nms.py +32 -43
- torchax-0.0.5/torchax/ops/mappings.py +139 -0
- torchax-0.0.5/torchax/ops/op_base.py +131 -0
- torchax-0.0.5/torchax/ops/ops_registry.py +55 -0
- torchax-0.0.5/torchax/tensor.py +699 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/train.py +38 -41
- torchax-0.0.5/torchax/util.py +88 -0
- torchax-0.0.5/torchax/view.py +377 -0
- torchax-0.0.4/PKG-INFO +0 -279
- torchax-0.0.4/README.md +0 -211
- torchax-0.0.4/dev-requirements.txt +0 -4
- torchax-0.0.4/examples/_diffusion.py +0 -112
- torchax-0.0.4/examples/basic_training.py +0 -187
- torchax-0.0.4/examples/lightning_training.py +0 -77
- torchax-0.0.4/examples/train_llama/model.py +0 -449
- torchax-0.0.4/examples/train_llama/train_llama_lightning.py +0 -329
- torchax-0.0.4/examples/train_llama/utils.py +0 -314
- torchax-0.0.4/examples/train_llama_torchtitan/helper.py +0 -37
- torchax-0.0.4/examples/train_llama_torchtitan/train_llama.py +0 -351
- torchax-0.0.4/test/BUILD +0 -31
- torchax-0.0.4/test/gemma/config.py +0 -86
- torchax-0.0.4/test/gemma/model.py +0 -561
- torchax-0.0.4/test/gemma/test_gemma.py +0 -86
- torchax-0.0.4/test/gemma/tokenizer.py +0 -48
- torchax-0.0.4/test/llama/test_llama.py +0 -113
- torchax-0.0.4/test/moe/model.py +0 -260
- torchax-0.0.4/test/moe/moe_test.py +0 -75
- torchax-0.0.4/test/test_functions.py +0 -98
- torchax-0.0.4/test/test_interop.py +0 -47
- torchax-0.0.4/test-requirements.txt +0 -9
- torchax-0.0.4/torchax/config.py +0 -19
- torchax-0.0.4/torchax/decompositions.py +0 -308
- torchax-0.0.4/torchax/environment.py +0 -2
- torchax-0.0.4/torchax/interop.py +0 -209
- torchax-0.0.4/torchax/ops/mappings.py +0 -97
- torchax-0.0.4/torchax/ops/op_base.py +0 -104
- torchax-0.0.4/torchax/ops/ops_registry.py +0 -50
- torchax-0.0.4/torchax/tensor.py +0 -557
- {torchax-0.0.4 → torchax-0.0.5}/=2.3.0 +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/LICENSE +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/build_nightly.sh +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/docs/dispatch.png +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/mnist_tpu.ipynb +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/requirements.txt +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_gpt/requirements.txt +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/Dockerfile +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/examples/train_llama_torchtitan/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/format.sh +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/__init__.py +0 -0
- /torchax-0.0.4/test/test_base.py → /torchax-0.0.5/test/base_test_util.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/gemma/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/llama/BUILD +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/llama/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/llama/llama_model.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/llama/model_exportable.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/moe/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test/test_mutations.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test_dist/README.md +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/test_dist/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/ops/__init__.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/tf_integration.py +0 -0
- {torchax-0.0.4 → torchax-0.0.5}/torchax/types.py +0 -0
{torchax-0.0.4 → torchax-0.0.5}/.gitignore

```diff
@@ -6,28 +6,32 @@ torch_xla/pb/cpp/*
 torch_xla/version.py
 torch_xla/csrc/version.cpp
 */**/__pycache__
+*/**/*MNIST/
+torchax/**/runs/
 *.swp
 *.pyc
 *.so
 
 # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
 #
-#
+# Files below are not deleted by "setup.py clean".
 
-# Visual Studio Code files
+# Visual Studio Code files.
 .vs
 .vscode/
 
-# Files autogenerated by docs/docs_build.sh
+# Files autogenerated by docs/docs_build.sh.
 /core
 /docs/src/*
 
-# Local terraform state
+# Local terraform state.
 .terraform
 
 
-#
+# Bazel temporary files.
 bazel-*
+MODULE.bazel
+MODULE.bazel.lock
 
-# Clangd cache directory
+# Clangd cache directory.
 .cache/*
```
torchax-0.0.5/PKG-INFO
ADDED

@@ -0,0 +1,307 @@

```
Metadata-Version: 2.4
Name: torchax
Version: 0.0.5
Summary: torchax is a library for running Jax and PyTorch together
Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
License: BSD 3-Clause License

Copyright (c) 2023, pytorch-tpu

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
License-File: LICENSE
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Provides-Extra: cpu
Requires-Dist: jax[cpu]; extra == 'cpu'
Requires-Dist: jax[cpu]>=0.6.2; extra == 'cpu'
Provides-Extra: cuda
Requires-Dist: jax[cpu]>=0.6.2; extra == 'cuda'
Requires-Dist: jax[cuda12]; extra == 'cuda'
Provides-Extra: odml
Requires-Dist: jax[cpu]; extra == 'odml'
Requires-Dist: jax[cpu]>=0.6.2; extra == 'odml'
Provides-Extra: tpu
Requires-Dist: jax[cpu]>=0.6.2; extra == 'tpu'
Requires-Dist: jax[tpu]; extra == 'tpu'
Description-Content-Type: text/markdown
```

# torchax: Running PyTorch on TPU via JAX

**torchax** is a backend for PyTorch, allowing users to run PyTorch on Google Cloud TPUs. **torchax** is also a library providing graph-level interoperability between PyTorch and JAX.

This means that with **torchax** you can:

* Run PyTorch code on TPUs with as little as 2 lines of code change.
* Call a JAX function from a PyTorch function, passing in `jax.Array`s (see the sketch after this list).
* Call a PyTorch function from a JAX function, passing in `torch.Tensor`s.
* Use JAX features such as `jax.grad`, `optax`, and `GSPMD` to train a PyTorch model.
* Use a PyTorch model as a feature extractor together with a JAX model, and so on.
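To make the two interop bullets concrete, here is a minimal sketch. One assumption to flag: `call_jax` is our guess at a helper in `torchax.interop` (this diff adds `torchax/interop.py` but only `jax_jit` and `JittableModule` are named later in this README), so verify the name against the module before relying on it.

```python
# Hypothetical sketch: calling a JAX function from PyTorch code.
# `call_jax` is an ASSUMED torchax.interop helper, not confirmed by this diff.
import torch
import torchax
from torchax.interop import call_jax  # assumed API

torchax.enable_globally()

def jax_scale(x, factor):
  # A plain JAX function: takes and returns jax.Array.
  return x * factor

a = torch.ones(2, 2, device='jax')
# Unwraps the torch tensor into a jax.Array, runs the JAX function,
# and wraps the result back into a torch tensor on the 'jax' device.
b = call_jax(jax_scale, a, 3.0)
print(b)
```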
## Install

First install torch CPU:

```bash
# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch
```

Then install JAX for the accelerator you want to use:

```bash
# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs (see the note below).
pip install -U jax
```

NOTE: if you want Metal support on Apple devices, install the Metal version of JAX: https://developer.apple.com/metal/jax/

Finally install torchax:

```bash
# Install pre-built torchax.
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
```

## Run a model

Now let's execute a model under torchax. We'll start with a simple 2-layer model. In theory, we can use any instance of `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(28 * 28, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = x.view(-1, 28 * 28)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops:

```python
import torchax
torchax.enable_globally()
```

Then, we can use a `jax` device:

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds a `jax.Array`. You can inspect that JAX array with `res.jax()`.

## What is happening behind the scenes

We took the approach detailed in the [new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py) recipe by Alban (@albanD), using `jax.Array` for `raw_data`.

In other words, when a torch op is executed inside an `env` context manager, which is enabled by `torchax.enable_globally()`, we will swap out the implementation of that op with JAX.

When a model's constructor runs, it will call some tensor constructor, such as `torch.rand`, `torch.ones`, or `torch.zeros`, to create its weights. When torchax is enabled, these constructors will create a `torchax.tensor.Tensor`, which contains a `jax.Array`.

Then, each subsequent op will extract the `jax.Array`, call the op's JAX implementation, and wrap the result back into a `torchax.tensor.Tensor`.

See more at [how it works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
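The mechanism above can be pictured with a deliberately simplified sketch; `Tensor` and `dispatch` here are hypothetical stand-ins for the real machinery in `torchax/tensor.py` and `torchax/ops/jaten.py`.

```python
# Illustrative sketch of the unwrap -> JAX op -> re-wrap dispatch pattern.
import jax.numpy as jnp

class Tensor:
  """Stand-in for torchax.tensor.Tensor: a wrapper around a jax.Array."""
  def __init__(self, jax_array):
    self._elem = jax_array

def dispatch(jax_impl, *tensors):
  # 1. Extract the jax.Array from each wrapped tensor.
  arrays = [t._elem for t in tensors]
  # 2. Call the op's JAX implementation.
  result = jax_impl(*arrays)
  # 3. Wrap the result back into a Tensor.
  return Tensor(result)

a = Tensor(jnp.ones((2, 2)))
b = Tensor(jnp.full((2, 2), 3.0))
c = dispatch(jnp.add, a, b)  # morally what torch.add does under torchax
print(c._elem)               # a jax.Array of 4.0s
```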
### Executing with jax.jit

The above script will execute the model using eager-mode JAX as the backend. This does allow executing torch models on TPUs, but it is often slower than what we can achieve with `jax.jit`.

`jax.jit` transforms a JAX function (i.e. a function that takes JAX arrays and returns JAX arrays) into a compiled (and thus faster) version of the same function.
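For readers new to JAX, here is that idea in plain JAX, independent of torch (the function and shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.tanh(x) * 2.0 + 1.0

f_jit = jax.jit(f)     # a compiled version of f
x = jnp.ones((8, 8))
print(f_jit(x))        # first call traces and compiles f
print(f_jit(x + 1))    # same shape/dtype: reuses the cached executable
```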
We have made a `jax_jit` decorator that accomplishes the same for functions that take and return `torch.Tensor`s. To use it, the first step is to create a functional version of the model: this means the parameters should be passed in as input instead of being attributes of the class:

```python
def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)
```

Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html) from PyTorch to replace the model weights with `param` and then call the model. This is roughly equivalent to:

```python
def model_func(param, inputs):
  m.load_state_dict(param)
  return m(*inputs)
```

Now, we can apply `jax_jit` to `model_func`:

```python
from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```

See more examples at [eager_mode.py](examples/eager_mode.py) and the [examples folder](examples/).

To ease the idiom of creating a functional model and calling it with parameters, we also created the `JittableModule` helper class. It lets us rewrite the above as:

```python
from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)
```

The first time `m_jitted` is called, it will trigger `jax.jit` to compile the function for the given input shapes. Subsequent calls with the same input shapes will be fast, as the compilation is cached.
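Putting the pieces together, a hypothetical end-to-end run, reusing the `MyModel` class and the `'jax'` device from the "Run a model" section above:

```python
import torch
import torchax
from torchax.interop import JittableModule

torchax.enable_globally()

m_jitted = JittableModule(MyModel().to('jax'))  # MyModel as defined earlier
inputs = torch.randn(3, 3, 28, 28, device='jax')

res1 = m_jitted(inputs)  # first call: jax.jit traces and compiles
res2 = m_jitted(inputs)  # same input shapes: uses the cached compilation
```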
## Citation

```
@software{torchax,
  author = {Han Qi and Chun-nien Chan and Will Cromar and Manfei Bai and Kevin Gleason},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}
```

# Maintainers & Contributors

This library is created and maintained by the PyTorch/XLA team at Google Cloud. It has benefitted from many direct and indirect contributions from outside the team: many were made by fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule), others by partner teams at Google and at other companies.

Here is the list of contributors as of 2025-02-25.

```
Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)
```
torchax-0.0.5/README.md
ADDED

@@ -0,0 +1,242 @@

The 242 added lines are identical to the project description embedded in torchax-0.0.5/PKG-INFO above (everything from "# torchax: Running PyTorch on TPU via JAX" through the contributor list), and are not repeated here.
torchax-0.0.5/dev-requirements.txt
ADDED

@@ -0,0 +1,5 @@

```
-f https://download.pytorch.org/whl/torch
torch==2.7.1 ; sys_platform == 'darwin' # macOS
torch==2.7.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
flax==0.10.6
```
{torchax-0.0.4 → torchax-0.0.5}/docs/fixing_op_info_test.md

````diff
@@ -17,7 +17,7 @@ Context:
 ### Remove one op from skiplist
 
 Open [test/test_ops.py](../test/test_ops.py) with your
-favorite text editor.
+favorite text editor.
 Remove one line from the `skiplist` set.
 
 i.e.
@@ -49,7 +49,7 @@ For errors you might get after running test, there are two kind:
 Error gotten:
 
 ```
-(base) hanq-macbookpro:torchax hanq$ python test/test_ops.py
+(base) hanq-macbookpro:torchax hanq$ python test/test_ops.py
 ...
 E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
 ```
@@ -60,10 +60,10 @@ From here we have 2 strategies for fixing this test:
 2. Add an implementation `aten::addbmm` operator using torch ops (this commonly known as "decompositions").
 
 Either way works for torchax. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of
-upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
+upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
 so other projects can benefit from it.
 
-For illustration purposes, let's implement this op in Jax.
+For illustration purposes, let's implement this op in Jax.
 
 (NOTE: this doesn't stop us from upstreaming a decomposition later if we want)
 
@@ -104,7 +104,7 @@ Please try to fix it by following these steps:
 
 ### First Impl
 
-To implement this op using jax ops, we first find what
+To implement this op using jax ops, we first find what
 is the exact semantics in this page:
 https://pytorch.org/docs/stable/generated/torch.addbmm.html
 
````
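A sketch of what such an implementation can look like, following those semantics (illustrative only; the op actually registered in `torchax/ops/jaten.py` also has to handle the dtype and scalar corner cases that this walkthrough goes on to debug):

```python
import jax.numpy as jnp

# torch.addbmm(input, batch1, batch2, *, beta, alpha) computes
#   beta * input + alpha * sum_b(batch1[b] @ batch2[b])
def addbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0):
  # Batched matmul, summed over the batch dimension b.
  mm = jnp.einsum('bij,bjk->ik', batch1, batch2)
  return beta * input + alpha * mm
```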
````diff
@@ -124,7 +124,7 @@ Now running test again:
 python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
 ```
 
-(NOTE: the exact test command is printed out when we run
+(NOTE: the exact test command is printed out when we run
 `pytest test/test_ops.py` so we can only run the failed test instead of running all tests.)
 
 We now see this error:
@@ -140,14 +140,14 @@ Traceback (most recent call last):
 AssertionError: False is not true
 ```
 
-This is telling me that our implementation did not produce
+This is telling me that our implementation did not produce
 the same result as the ops in PyTorch.
 
 To debug this, let's figure out what exact input caused this.
-We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torchax/test/test_ops.py#L644), right before the diff. Here we can
+We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torchax/test/test_ops.py#L644), right before the diff. Here we can
 inspect values of `res` and `res2`, as well as the `sample_input`.
 
-The sample input we get is
+The sample input we get is
 ```
 SampleInput(input=tensor([[-3, -3,  9,  8, -8, -3, -4,  2,  2,  2],
         [-5,  1, -9,  9,  1, -5,  6,  1, -4, -5],
@@ -212,7 +212,7 @@ SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
         [ 7,  3,  0,  1,  1, -9,  5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
 ```
 
-And the `res` from torch is
+And the `res` from torch is
 
 ```
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
````