torchax-0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- torchax/CONTRIBUTING.md +38 -0
- torchax/__init__.py +124 -0
- torchax/config.py +19 -0
- torchax/decompositions.py +308 -0
- torchax/device_module.py +20 -0
- torchax/distributed.py +246 -0
- torchax/environment.py +2 -0
- torchax/export.py +236 -0
- torchax/interop.py +209 -0
- torchax/ops/__init__.py +10 -0
- torchax/ops/jaten.py +5212 -0
- torchax/ops/jax_reimplement.py +169 -0
- torchax/ops/jc10d.py +51 -0
- torchax/ops/jlibrary.py +73 -0
- torchax/ops/jtorch.py +427 -0
- torchax/ops/jtorchvision_nms.py +245 -0
- torchax/ops/mappings.py +97 -0
- torchax/ops/op_base.py +104 -0
- torchax/ops/ops_registry.py +50 -0
- torchax/tensor.py +557 -0
- torchax/tf_integration.py +119 -0
- torchax/train.py +120 -0
- torchax/types.py +12 -0
- torchax-0.0.4.dist-info/METADATA +341 -0
- torchax-0.0.4.dist-info/RECORD +27 -0
- torchax-0.0.4.dist-info/WHEEL +4 -0
- torchax-0.0.4.dist-info/licenses/LICENSE +28 -0
torchax/train.py
ADDED
@@ -0,0 +1,120 @@
import collections
import functools
import torch
import jax
import torchax
from torchax import interop
from torchax.interop import torch_view, jax_view
import optax


remat = torch_view(jax.remat)
mark_sharding = torch_view(jax.lax.with_sharding_constraint)


def make_train_step(model_fn,
                    loss_fn, optax_optimizer,
                    remat_policy=None):
  """Make a function that does one train step given model and loss.

  model_fn: a function representing the model's forward:
    i.e. has signature Callable[weights, buffers, args] -> result. Where,
    weights is a pytree of trainable parameters
    buffers is a pytree of non-trainable parameters / constants
    args is the input data loaded from the data set
    result is the return value of the model
  loss_fn: a function to compute loss.
    i.e. it has signature of Callable[result, label] -> loss
    where, result is what model_fn returned
    label is loaded from the dataloader.
  optax_optimizer: the optimizer from the optax library. For example, optax.adam
  remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
    to do gradient checkpointing. If None, then it means checkpoint everything.
  """
  env = torchax.default_env()

  def loss(weights, buffers, args, label):  # inputs are XLATensor
    with env, jax.named_scope('compute_loss'):
      res = model_fn(weights, buffers, args)
      l = loss_fn(res, label)
      return l

  loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
  grad_fn = interop.jax_value_and_grad(loss)

  def step(weights, buffers, opt_state, args, label):  # inputs are array
    with jax.named_scope('compute_gradient'):
      loss, gradient = grad_fn(weights, buffers, args, label)

    with jax.named_scope("optimizer_updates"):
      updates, opt_state = interop.call_jax(
          optax_optimizer.update,
          gradient, opt_state, weights)
      weights = interop.call_jax(optax.apply_updates, weights, updates)
    return loss, weights, opt_state

  # TODO: apply jax.jit so the user doesn't have to.
  return step


class Container:
  # Plain holder object: modules stored on it are not registered as submodules.
  pass


class ScannedModule(torch.nn.Module):
  # Stacks the weights of a list of identical layers and applies the layers
  # sequentially with jax.lax.scan, gradient-checkpointing each layer.

  def __init__(self, module_list, checkpoint_policy=None):
    super().__init__()

    self.c = None
    assert module_list
    self.c = Container()
    self.c.one_mod = module_list[0]
    self.checkpoint_policy = checkpoint_policy

    weights = self._stack_layer_weights(module_list)
    self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
    self.params = torch.nn.ParameterDict({
        self._param_name_new(k): v for k, v in weights.items()
    })

  def _stack_layer_weights(self, module_list):
    # Create weights such that every [n, m] weight
    # becomes [k, n, m], where k is the number of layers,
    # i.e. stacking the layer weights together.
    temp = collections.defaultdict(list)
    for m in module_list:
      for k, v in m.state_dict().items():
        temp[k].append(v)
    res = {k: torch.stack(v) for k, v in temp.items()}
    return res

  def _param_name_new(self, old):
    return '___'.join(old.split('.'))

  def _param_name_old(self, new):
    return '.'.join(new.split('___'))

  def forward(self, *args, **kwargs):
    assert not kwargs
    weights = {
        k: self.params[self._param_name_new(k)]
        for k in self.layer_weights_keys
    }
    scan = interop.torch_view(jax.lax.scan)

    def eval_one_layer(args, weight):
      # unpack args
      h, *rest = args
      newh = torch.func.functional_call(self.c.one_mod, weight, args)
      # next layer's input; and residual to be added to list
      return (newh, *rest), None

    _eval_one_layer = interop.gradient_checkpoint(
        eval_one_layer,
        kwargs={'policy': self.checkpoint_policy},
    )
    h, _ = scan(
        _eval_one_layer,
        args,
        weights,
    )
    return h[0]
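For orientation, here is a minimal usage sketch (not part of the packaged file) showing how `make_train_step` might be wired up. The model, batch, and label names are hypothetical, and device placement and jitting are left out; it only shows how the pieces from this module fit together.

```python
import optax
import torch
import torchax
from torchax import interop, train

torchax.enable_globally()

# Hypothetical tiny model, created directly on the 'jax' device.
model = torch.nn.Linear(784, 10, device='jax')

def model_fn(weights, buffers, args):
  # Call the module functionally so that weights are explicit inputs.
  return torch.func.functional_call(model, {**weights, **buffers}, args)

def loss_fn(result, label):
  return torch.nn.functional.cross_entropy(result, label)

optimizer = optax.adam(1e-3)
step = train.make_train_step(model_fn, loss_fn, optimizer)

weights = dict(model.named_parameters())
buffers = dict(model.named_buffers())
# Assumption: optimizer.init can be bridged with call_jax the same way
# train.py bridges optimizer.update above.
opt_state = interop.call_jax(optimizer.init, weights)

# One step returns the loss plus updated weights and optimizer state, e.g.:
# loss, weights, opt_state = step(weights, buffers, opt_state, (batch,), label)
```

Per the TODO in the source, the returned `step` is not jitted automatically, so in practice it would typically be wrapped with something like `interop.jax_jit` before being called in a loop.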
torchax/types.py
ADDED
@@ -0,0 +1,12 @@
from typing import Callable, Any, Union, ParamSpec, TypeAlias
import torch
import jax
import jax.numpy as jnp
import sys

P = ParamSpec('P')

TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]
TorchCallable: TypeAlias = Callable[P, TorchValue]
JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any]
JaxCallable: TypeAlias = Callable[P, JaxValue]
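These aliases are only type hints. A hypothetical sketch (not from the package) shows the intent: functions that cross the boundary can advertise whether they work with torch values or jax values.

```python
import torch
import jax.numpy as jnp

from torchax.types import TorchCallable, TorchValue, JaxCallable, JaxValue

# Hypothetical torch-level helper, annotated with the torch aliases.
def scale_torch(x: torch.Tensor, factor: float) -> TorchValue:
  return x * factor

torch_fn: TorchCallable = scale_torch

# Hypothetical jax-level counterpart, annotated with the jax aliases.
def scale_jax(x: jnp.ndarray, factor: float) -> JaxValue:
  return x * factor

jax_fn: JaxCallable = scale_jax
```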
torchax-0.0.4.dist-info/METADATA
ADDED
@@ -0,0 +1,341 @@
Metadata-Version: 2.4
Name: torchax
Version: 0.0.4
Summary: torchax is a library for running PyTorch on TPU
Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
License: BSD 3-Clause License

Copyright (c) 2023, pytorch-tpu

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

License-File: LICENSE
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Provides-Extra: cpu
Requires-Dist: jax[cpu]; extra == 'cpu'
Requires-Dist: jax[cpu]>=0.4.30; extra == 'cpu'
Provides-Extra: cuda
Requires-Dist: jax[cpu]>=0.4.30; extra == 'cuda'
Requires-Dist: jax[cuda12]; extra == 'cuda'
Provides-Extra: odml
Requires-Dist: jax[cpu]; extra == 'odml'
Requires-Dist: jax[cpu]>=0.4.30; extra == 'odml'
Provides-Extra: tpu
Requires-Dist: jax[cpu]>=0.4.30; extra == 'tpu'
Requires-Dist: jax[tpu]; extra == 'tpu'
Description-Content-Type: text/markdown

# torchax: Running PyTorch on TPU

**torchax** is a backend for PyTorch, allowing users to run
PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
graph-level interoperability between PyTorch and Jax.

This means, with **torchax** you can:
* Run PyTorch code on TPU with as little as 2 lines of code change.
* Call a jax function from a pytorch function, passing in `jax.Array`s.
* Call a pytorch function from a jax function, passing in a `torch.Tensor` subclass.
* Use jax features such as `jax.grad`, `optax` and `GSPMD` to train a PyTorch model.
* Use a PyTorch model as a feature extractor and use it with a Jax model.

and so on.

## Install

### On Google Cloud TPU:
First install torch CPU:

```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
```

Then install jax TPU:

```bash
pip install -U jax[tpu]
```

Finally install torchax:

```bash
pip install torchax
```

### On GPU machines:
First install torch CPU:

```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
```

Then install jax CUDA:

```bash
pip install -U jax[cuda12]
```

Finally install torchax:

```bash
pip install torchax
```

### On CPU machines (Mac included)
First install torch CPU:

```bash
# Linux
pip install torch --index-url https://download.pytorch.org/whl/cpu

# OR Mac:
pip install torch
```

Then install jax CPU:

```bash
pip install -U jax
```

Finally install torchax:

```bash
pip install torchax
```

NOTE: if you would like Metal support for Apple devices, install the
metal version of jax: https://developer.apple.com/metal/jax/

### Installing `torchax` from source

You still need to install `torch` CPU and the `jax` build for your accelerator (GPU, TPU, or none).

```bash
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
```

## Run a model

Now let's execute a model under torchax. We'll start with a simple fully connected model;
in theory it can be any instance of `torch.nn.Module`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(28 * 28, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = x.view(-1, 28 * 28)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

m = MyModel()

# Execute this model using torch
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

This model `m` contains two parts: the weights stored inside the model,
and its submodules (`nn.Linear`).

To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops.
To enable this, use:

```python
import torchax
torchax.enable_globally()
```
Then, a `jax` device will be available to use:

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel()
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
a `jax.Array`. You can inspect that jax array with `res.jax()`.


## What is happening behind the scenes

We took the approach detailed in the [new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py) recipe by Alban (@albanD), using `jax.Array` for the `raw_data`.

In other words, when a torch op is executed inside the `env` context manager (which is enabled with `torchax.enable_globally()`), we can swap in an implementation of that op written in Jax.

When a model's constructor runs, it will call some tensor constructor, such as
`torch.rand`, `torch.ones` or `torch.zeros`, to create its weights. The constructor
will create a `torch.Tensor` subclass that contains a `jax.Array`.

Then, each subsequent op can unpack the `jax.Array`, call the op implementation,
and wrap the result back into a `torch.Tensor` subclass.

See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).


### Executing with jax.jit

The above script will execute the model using eager mode Jax as the backend. This
does allow executing torch models on TPU, but is often slower than what we can
achieve with `jax.jit`.

`jax.jit` is a function that turns a Jax function (i.e. a function that takes jax arrays
and returns jax arrays) into the same function, but faster.

We have made the `jax_jit` decorator that accomplishes the same for functions
that take and return `torch.Tensor`s. To use it, the first step is to create
a functional version of this model: this means the parameters should be passed in
as inputs instead of being attributes of the class:


```python

def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)

```
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model
weights with `param`, then call the model. This is roughly equivalent to:

```python
def model_func(param, inputs):
  m.load_state_dict(param)
  return m(*inputs)
```

Now, we can apply `jax_jit`:

```python
from torchax.interop import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```

See more examples at [eager_mode.py](examples/eager_mode.py) and the [examples folder](examples/).

However, to ease the idiom of creating a functional model and calling it with parameters,
we also created the `JittableModule` helper class.

So the above can be written as:

```python

from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)
```

The first time `m_jitted` is called, it will trigger `jax.jit`;
subsequent computation with inputs of the same shape will be fast.


# Citation:

@software{torchax,
  author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleanson},
  title = {torchax: PyTorch on TPU and Jax interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}

# Maintainers & Contributors:

This library is created and maintained by the PyTorch/XLA team at Google Cloud.

However, it benefitted from many direct and indirect
contributions outside of the team. Many of them were made by
fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule), others by partner teams.

Here is the full list of contributors as of 2025-02-25.

Han Qi (qihqi), Pytorch / XLA
Manfei Bai (manfeibai), Pytorch / XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), Pytorch / XLA
Siyuan Liu (lsy323), Pytorch / XLA
Bhavya Bahl (bhavya01), Pytorch / XLA
Pei Zhang (zpcore), Pytorch / XLA
Yifei Teng (tengyifei), Pytorch / XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, Pytorch
Simon Teo (simonteozw), Google(20%)
David Huang (dvhg), Google(20%)
Barni Seetharaman (barney-s), Google(20%)
Anish Karthik (anishfish2), Google(20%)
Yao Gu (guyao), Google(20%)
Yenkai Wang (yenkwang), Google(20%)
Greg Shikhman (commander), Google(20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google(20%)
Tracy Chen (tracych477), Google(20%)
Matthias Guenther (mrguenther), Google(20%)
WenXin Dong (wenxindongwork), Google(20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google(20%)
Gwen Mittertreiner (gmittert), Google(20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google(20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google(20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google(20%)
Daniel Vega-Myhre (danielvegamyhre), Google(20%)
Tianqi Fan (tqfan28), Google(20%)
Jim Lin (jimlinntu), Google(20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google(20%)
torchax-0.0.4.dist-info/RECORD
ADDED
@@ -0,0 +1,27 @@
torchax/CONTRIBUTING.md,sha256=T6Uy7vZvzXArbPqqUswEtQBgTU_eRwtJnHOiZNth2Pw,1371
torchax/__init__.py,sha256=0NoxpzaPIMYMJXZz_JiRCTYZCuICnTFrLDm5cZKGIm0,3242
torchax/config.py,sha256=Zpfcn7Q3QsKnn4cCx1-bYxA8q_lqrWJcUz30B3QcIc4,535
torchax/decompositions.py,sha256=C067xSd9MpGs1SDqlC5T9RnqNj-i-m31mysfFGDxLEQ,11221
torchax/device_module.py,sha256=NNEGSfk9ApVVSy5dRjHkvGuNGamopCEjSlItcNlAIbE,253
torchax/distributed.py,sha256=CXkpV0K2Oas7fcRA-i3s2VXTz8l4CEI7_NnbYOsHtRw,7519
torchax/environment.py,sha256=daEdpEyAJIa8b2VkCqSKcw8PaExcB6Qro80XNes_sHA,2
torchax/export.py,sha256=MG19Y0QcRcX3S3gruy9K-wFGdemD4I0ii03HWQM31xk,8853
torchax/interop.py,sha256=fncREqC1Nib-ud7wyekmtKaif7GSDql7ITmzTt5A9e4,6774
torchax/tensor.py,sha256=89yXmyO1uT6F1KI-Kn6NCnIieCiiMsDutgS071_9p64,16644
torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
torchax/train.py,sha256=Ym_SC0WOmhQMhSCH0aZ-zgP7Q19WdCDTeoI-JBFrqx8,3930
torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
torchax/ops/__init__.py,sha256=Vr1p8zDHwfXZBUbw70iNiCJLZLNdI6gR_vUlaiA7Usg,270
torchax/ops/jaten.py,sha256=ycTSyCzgJpz5zu936BTZXmRJWuzhxeGOOjoffU00c3g,154999
torchax/ops/jax_reimplement.py,sha256=29SGPF0bszgCuyVxiZ4A2xTPtAuPXBDvDG79M3C0Vxo,7468
torchax/ops/jc10d.py,sha256=DiJjiUFHyYmFUHEgqEDbGsLaCsJs94xJlBatRwXiPfg,1317
torchax/ops/jlibrary.py,sha256=R22GUlK6lUivuDSmvtqwbV9ku_uT0ddLcQcwHI96vqM,2927
torchax/ops/jtorch.py,sha256=fWU2z5uD3FpVE5ui-O5nfHEule2hod4BqCnsbV5jVfQ,14463
torchax/ops/jtorchvision_nms.py,sha256=SiMN9bUV7GADD2S2PkH9spn5oQGgr2jhFQ1nyd3eHgI,8873
torchax/ops/mappings.py,sha256=a8C5CYIEucGIxUPQHCsDZypBC3b_BYtOKQTWf9bdiRc,3060
torchax/ops/op_base.py,sha256=EdlAx3oqU2brMngjZvTxECin0eNiPdAUfOLRtFL3tgg,2900
torchax/ops/ops_registry.py,sha256=nXCAXiQNpC0uBUl8xEim7TOUTyROZm_4H1LaOvYwJVQ,1241
torchax-0.0.4.dist-info/METADATA,sha256=4VQb39quiN9U9j955fs9zY00x4t4Orubo92ONCCLyug,11119
torchax-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
torchax-0.0.4.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
torchax-0.0.4.dist-info/RECORD,,
torchax-0.0.4.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,28 @@
BSD 3-Clause License

Copyright (c) 2023, pytorch-tpu

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.