torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
{torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA
@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: torchax
-Version: 0.0.4
-Summary: torchax is a library for running PyTorch
+Version: 0.0.6
+Summary: torchax is a library for running Jax and PyTorch together
 Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
 Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
 License: BSD 3-Clause License
@@ -51,111 +51,75 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.10
 Provides-Extra: cpu
 Requires-Dist: jax[cpu]; extra == 'cpu'
-Requires-Dist: jax[cpu]>=0.
+Requires-Dist: jax[cpu]>=0.6.2; extra == 'cpu'
 Provides-Extra: cuda
-Requires-Dist: jax[cpu]>=0.
+Requires-Dist: jax[cpu]>=0.6.2; extra == 'cuda'
 Requires-Dist: jax[cuda12]; extra == 'cuda'
 Provides-Extra: odml
 Requires-Dist: jax[cpu]; extra == 'odml'
-Requires-Dist: jax[cpu]>=0.
+Requires-Dist: jax[cpu]>=0.6.2; extra == 'odml'
 Provides-Extra: tpu
-Requires-Dist: jax[cpu]>=0.
+Requires-Dist: jax[cpu]>=0.6.2; extra == 'tpu'
 Requires-Dist: jax[tpu]; extra == 'tpu'
 Description-Content-Type: text/markdown
 
-# torchax: Running PyTorch on TPU
+# torchax: Running PyTorch on TPU via JAX
 
-**torchax
-PyTorch on Google
-graph-level interoperability between PyTorch and
+**torchax** is a backend for PyTorch, allowing users to run
+PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
+graph-level interoperability between PyTorch and JAX.
 
 This means, with **torchax** you can:
-* Run PyTorch code on
-* Call a
-* Call a
-* Use
-
+* Run PyTorch code on TPUs with as little as 2 lines of code change.
+* Call a JAX function from a PyTorch function, passing in `jax.Array`s.
+* Call a PyTorch function from a JAX function, passing in a `torch.Tensor`s.
+* Use JAX features such as `jax.grad`, `optax`, and `GSPMD` to train a PyTorch
+  model.
+* Use a PyTorch model as feature extractor and use it with a JAX model.
 etc etc.
 
 ## Install
 
-
-### On Google Cloud TPU:
 First install torch CPU:
 
 ```bash
+# On Linux.
 pip install torch --index-url https://download.pytorch.org/whl/cpu
-```
-
-Then install jax TPU:
-
-```bash
-pip install -U jax[tpu]
-```
-
-Finally install torchax
 
-
-pip install
+# Or on Mac.
+pip install torch
 ```
 
-
-First install torch CPU:
+Then install JAX for the accelerator you want to use:
 
 ```bash
-
-
-
-Then install jax CUDA:
+# On Google Cloud TPU.
+pip install -U jax[tpu]
 
-
+# Or, on GPU machines.
 pip install -U jax[cuda12]
-```
-
-Finally install torchax
-
-```bash
-pip install torchax
-```
 
-
-First install torch CPU:
-
-```bash
-# Linux
-pip install torch --index-url https://download.pytorch.org/whl/cpu
-
-# OR Mac:
-pip install torch
-```
-
-Then install jax CPU:
-
-```bash
+# Or, on Linux CPU machines or Macs (see the note below).
 pip install -U jax
 ```
 
-Finally install torchax
-
-```bash
-pip install torchax
-```
-
 NOTE: if you like metal support for Apple devices then install the
-metal version of
+metal version of JAX: https://developer.apple.com/metal/jax/
 
-
-
-Still need to install `torch` CPU and `Jax` of your accelerator (GPU, TPU or None).
+Finally install torchax:
 
 ```bash
+# Install pre-built torchax.
+pip install torchax
+
+# Or, install torchax from source.
 pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
 ```
 
 ## Run a model
 
-Now let's execute a model under torchax. We'll start with a simple 2-layer model
-
+Now let's execute a model under torchax. We'll start with a simple 2-layer model.
+In theory, we can use any instance of `torch.nn.Module`.
 
 ```python
 import torch
@@ -179,75 +143,73 @@ class MyModel(nn.Module):
 
 m = MyModel()
 
-# Execute this model using torch
+# Execute this model using torch.
 inputs = torch.randn(3, 3, 28, 28)
 print(m(inputs))
 ```
 
-
-and it's submodules (`nn.Linear`).
-
-To execute this model with `torchax`; we need to enable torchax to capture pytorch ops.
-To enable this, use:
+To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops:
 
 ```python
 import torchax
 torchax.enable_globally()
 ```
-
+
+Then, we can use a `jax` device:
 
 ```python
 inputs = torch.randn(3, 3, 28, 28, device='jax')
-m = MyModel()
+m = MyModel().to('jax')
 res = m(inputs)
 print(type(res))  # outputs torchax.tensor.Tensor
 ```
 
 `torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
-a `jax.Array`. You can inspect that
-
+a `jax.Array`. You can inspect that JAX array with `res.jax()`.
 
-## What is happening behind the scene
+## What is happening behind the scene
 
-We took the approach detailed in
+We took the approach detailed in the
+[new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py)
+recipe by Alban (@albanD), using `jax.Array` for `raw_data`.
 
-In other words,
-
+In other words, when a torch op is executed inside an `env` context manager,
+which is enabled by `torchax.enable_globally()`, we will swap out the
+implementation of that op with JAX.
 
 When a model's constructor runs, it will call some tensor constructor, such as
-`torch.rand`, `torch.ones
-
+`torch.rand`, `torch.ones`, or `torch.zeros` to create its weights. When torchax
+is enabled, these constructors will create a `torchax.tensor.Tensor`, which
+contains a `jax.Array`.
 
-Then, each subsequent op
-and
-
-See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
+Then, each subsequent op will extract the `jax.Array`, call the op's JAX
+implementation, and wrap the result back into a `torchax.tensor.Tensor`,
 
+See more at [how it works](docs/how_it_works.md) and\
+[ops registry](docs/ops_registry.md).
 
 ### Executing with jax.jit
 
-The above script will execute the model using eager mode
-does allow executing torch models on
+The above script will execute the model using eager mode JAX as the backend. This
+does allow executing torch models on TPUs, but is often slower than what we can
 achieve with `jax.jit`.
 
-`jax.jit` is a function that takes a
-and returns
+`jax.jit` is a function that takes a JAX function (i.e. a function that takes JAX arrays
+and returns JAX arrays) into a compiled (thus faster) version of the same function.
 
-We have made
-that takes and returns `torch.Tensor
+We have made a `jax_jit` decorator that would accomplish the same with functions
+that takes and returns `torch.Tensor`s. To use this, the first step is to create
 a functional version of this model: this means the parameters should be passed in
-as input instead of being attributes
-
+as input instead of being attributes of the class:
 
 ```python
-
 def model_func(param, inputs):
   return torch.func.functional_call(m, param, inputs)
-
 ```
+
 Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
-from PyTorch to replace the model
-
+from PyTorch to replace the model weights with `param` and then call the
+model. This is roughly equivalent to:
 
 ```python
 def model_func(param, inputs):
@@ -255,87 +217,91 @@ def model_func(param, inputs):
   return m(*inputs)
 ```
 
-Now, we can apply `jax_jit`
+Now, we can apply `jax_jit` on `module_func`:
 
 ```python
 from torchax.interop import jax_jit
+
 model_func_jitted = jax_jit(model_func)
 print(model_func_jitted(new_state_dict, inputs))
 ```
 
-See more examples at [eager_mode.py](examples/eager_mode.py) and the
-
-However, to ease the idiom of creating functional model and calling it with parameters,
-we also created the `JittableModule` helper class.
+See more examples at [eager_mode.py](examples/eager_mode.py) and the
+[examples folder](examples/).
 
-
+To ease the idiom of creating functional model and calling it with parameters,
+we also created the `JittableModule` helper class. It lets us rewrite the
+above as:
 
 ```python
-
 from torchax.interop import JittableModule
 
 m_jitted = JittableModule(m)
 res = m_jitted(...)
 ```
 
-The first time
-
-
+The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
+compile for the given input shapes. Subsequent calls with the same input shapes
+will be fast as the compilation is cached.
 
+## Citation
 
-
-
+```
 @software{torchax,
   author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleanson},
-  title = {torchax: PyTorch on TPU and
+  title = {torchax: PyTorch on TPU and JAX interoperability},
   url = {https://github.com/pytorch/xla/tree/master/torchax}
   version = {0.0.4},
   date = {2025-02-24},
 }
+```
 
 # Maintainers & Contributors:
 
 This library is created and maintained by the PyTorch/XLA team at Google Cloud.
 
-
+It benefitted from many direct and indirect
 contributions outside of the team. Many of them done by
-fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule)
+fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule).
+Others by partner teams at Google and other companies.
 
-Here is the
+Here is the list of contributors by 2025-02-25.
 
-
-
+```
+Han Qi (qihqi), PyTorch/XLA
+Manfei Bai (manfeibai), PyTorch/XLA
 Will Cromar (will-cromar), Meta
-Milad Mohammadi (miladm),
-Siyuan Liu (lsy323),
-Bhavya Bahl (bhavya01),
-Pei Zhang (zpcore),
-Yifei Teng (tengyifei),
+Milad Mohammadi (miladm), PyTorch/XLA
+Siyuan Liu (lsy323), PyTorch/XLA
+Bhavya Bahl (bhavya01), PyTorch/XLA
+Pei Zhang (zpcore), PyTorch/XLA
+Yifei Teng (tengyifei), PyTorch/XLA
 Chunnien Chan (chunnienc), Google, ODML
-Alban Desmaison (albanD), Meta,
-Simon Teo (simonteozw), Google(20%)
-David Huang (dvhg), Google(20%)
-Barni Seetharaman (barney-s), Google(20%)
-Anish Karthik (anishfish2)
-Yao Gu (guyao)
-Yenkai Wang (yenkwang)
-Greg Shikhman (commander)
-Matin Akhlaghinia (matinehAkhlaghinia), Google(20%)
-Tracy Chen (tracych477), Google(20%)
-Matthias Guenther (mrguenther)
-WenXin Dong (wenxindongwork), Google(20%)
-Kevin Gleason (GleasonK)
-Nupur Baghel (nupurbaghel), Google(20%)
-Gwen Mittertreiner (gmittert), Google(20%)
+Alban Desmaison (albanD), Meta, PyTorch
+Simon Teo (simonteozw), Google (20%)
+David Huang (dvhg), Google (20%)
+Barni Seetharaman (barney-s), Google (20%)
+Anish Karthik (anishfish2), Google (20%)
+Yao Gu (guyao), Google (20%)
+Yenkai Wang (yenkwang), Google (20%)
+Greg Shikhman (commander), Google (20%)
+Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
+Tracy Chen (tracych477), Google (20%)
+Matthias Guenther (mrguenther), Google (20%)
+WenXin Dong (wenxindongwork), Google (20%)
+Kevin Gleason (GleasonK), Google, StableHLO
+Nupur Baghel (nupurbaghel), Google (20%)
+Gwen Mittertreiner (gmittert), Google (20%)
 Zeev Melumian (zmelumian), Lightricks
-Vyom Sharma (vyom1611), Google(20%)
+Vyom Sharma (vyom1611), Google (20%)
 Shitong Wang (ShitongWang), Adobe
-Rémi Doreau (ayshiff), Google(20%)
+Rémi Doreau (ayshiff), Google (20%)
 Lance Wang (wang2yn84), Google, CoreML
-Hossein Sarshar (hosseinsarshar)
-Daniel Vega-Myhre (danielvegamyhre)
-Tianqi Fan (tqfan28), Google(20%)
-Jim Lin (jimlinntu), Google(20%)
+Hossein Sarshar (hosseinsarshar), Google (20%)
+Daniel Vega-Myhre (danielvegamyhre), Google (20%)
+Tianqi Fan (tqfan28), Google (20%)
+Jim Lin (jimlinntu), Google (20%)
 Fanhai Lu (FanhaiLu1), Google Cloud
 DeWitt Clinton (dewitt), Google PyTorch
-Aman Gupta (aman2930)
+Aman Gupta (aman2930), Google (20%)
+```
torchax-0.0.6.dist-info/RECORD ADDED
@@ -0,0 +1,33 @@
+torchax/CONTRIBUTING.md,sha256=VOL0us6kS-uc4yE6IlSm6SDHYHnx-gw-0upFnP0VkSQ,1369
+torchax/__init__.py,sha256=c98iIGugRTbEVcsx8eWnbAjsC4mpcDrK23ZQqiMycLg,3157
+torchax/amp.py,sha256=-k8t4lrCsJLKHEhI6J0aHE3MAPEL-4DP6wCKtMwo1AM,11791
+torchax/config.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
+torchax/configuration.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
+torchax/decompositions.py,sha256=1p5TFZfAJ2Bs9BiSO1vXbnWEXnbPfC_gCQ54rDXhd9k,28859
+torchax/device_module.py,sha256=7fkdPwXG0qCBTmvDYHp0fvv4xK0W9avV_Ua3MeMzczE,349
+torchax/environment.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+torchax/export.py,sha256=xU-UbrQBvQWUy-GM2FfeIHymlEdmYDYcPymjlcXM23w,8969
+torchax/flax.py,sha256=2Tg8inGskgAfByPxJQh4ItZHHAb-960gYq156bSO8V4,1280
+torchax/interop.py,sha256=7HvJwtxdodcCrMyJzs-Wr47hkHuoh6CWb2-YKoBwqV0,11076
+torchax/mesh_util.py,sha256=Ab4ic2eHWmQ3Mw3jpERvi-TKLIcDvQQoC6tuIZ9ig7Q,9314
+torchax/tensor.py,sha256=XjAp7khpQNhoVsSMzDj-V8l4DFT9jBaL4NVCi88a6K0,20893
+torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
+torchax/train.py,sha256=rtvj6HkdnG9fc3VWYPNwHuxGlUxFJkUXJWED8azgtok,3855
+torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
+torchax/util.py,sha256=cb-eudDE7AX2s-6zYtXdowgyzyvqPqE9MPP82PfH23g,3069
+torchax/view.py,sha256=1ekqRN04lAPd_icgZMKbSYWhr738DzVloc34ynml4wo,11121
+torchax/ops/__init__.py,sha256=Vr1p8zDHwfXZBUbw70iNiCJLZLNdI6gR_vUlaiA7Usg,270
+torchax/ops/jaten.py,sha256=WxfZU6p7b7OR98B3z0LCXKlV6U5aslXxJMJirBr6lns,165835
+torchax/ops/jax_reimplement.py,sha256=idkmFWNCXBilkmaHBGdivKz0XhsjSpqLNlGXxbBOKWQ,7302
+torchax/ops/jc10d.py,sha256=OzSYYle_5jBmNVP64SuJPz9S-rRGD6H7e1a9HHIKsjU,1322
+torchax/ops/jimage.py,sha256=P0lAauYX_au_xjIHDsG7H6jO7Jf54_VCAjzZuIZdhO0,3182
+torchax/ops/jlibrary.py,sha256=YfYUQbf5dKiMtEHUMfdgHTeLuNvvSTJ-l8s7wQNIvO0,2930
+torchax/ops/jtorch.py,sha256=wR4ZdDscxqG4VpxjcLGzgdUKmipa3fp7S0mK3DcD--A,17161
+torchax/ops/jtorchvision_nms.py,sha256=HSnhwU0gFaHucT7EvrEruJdnWkAWTw4T35GY525ohO8,8903
+torchax/ops/mappings.py,sha256=AESERtXJ6i_Hm0ycwEw7z5OJnHu-7QteWlSs-mlUPE4,3492
+torchax/ops/op_base.py,sha256=MLKFxMojIXgz4lkTE6k-8F-ddve-9vEiXkzj3P-YJPs,3739
+torchax/ops/ops_registry.py,sha256=qADpG1up0JOThoybiOQoRDWtAe5TOkHlqcj1bSHjtGY,1594
+torchax-0.0.6.dist-info/METADATA,sha256=uB9hoyxdfrAD14pHy0U8Gh1uCHbYwok-oEW12pEa6qs,10753
+torchax-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+torchax-0.0.6.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
+torchax-0.0.6.dist-info/RECORD,,
torchax/distributed.py DELETED
@@ -1,246 +0,0 @@
-"""`torch.distributed` backend implemented with JAX collective ops.
-
-EXPERIMENTAL: This module is still highly experimental, and it may be removed
-before any stable release.
-
-Note: JAX collective ops require that axis names be defined in `pmap` or
-`shmap`. The distributed backend only supports one axis, named `torch_dist`.
-This name is defined by our mirror implementation of `spawn`.
-"""
-
-import datetime
-import functools
-import logging
-import os
-from typing import List, Optional, Union
-
-import jax
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.distributed._functional_collectives
-from torch._C._distributed_c10d import ProcessGroup  # type: ignore
-import torch.distributed
-import torchax
-from jax.sharding import NamedSharding
-from jax.sharding import Mesh, PartitionSpec as P
-from jax.experimental import mesh_utils
-import torch.utils._pytree as torch_pytree
-from torchax import interop
-
-
-class ProcessGroupJax(ProcessGroup):
-  """Distributed backend implemented with JAX."""
-
-  def __init__(self, prefix_store, rank, size, timeout):
-    super().__init__(rank, size)
-    self._group_name = None
-
-  def getBackendName(self):
-    return "jax"
-
-  # TODO(wcromar): why doesn't default group name setter work?
-  # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152
-  def _set_group_name(self, name: str) -> None:
-    self._group_name = name
-
-  @property
-  def group_name(self):
-    assert self._group_name
-    return self._group_name
-
-  @staticmethod
-  def _work(
-      tensors: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]],
-  ) -> dist.Work:
-    fut = torch.futures.Future()
-    fut.set_result(tensors)
-    return torch._C._distributed_c10d._create_work_from_future(fut)
-
-  def _allgather_base(
-      self,
-      output: torch.Tensor,
-      input: torch.Tensor,
-      opts=...,
-  ) -> dist.Work:
-    assert isinstance(input, torchax.tensor.Tensor)
-    assert isinstance(output, torchax.tensor.Tensor)
-    torch.distributed._functional_collectives.all_gather_tensor_inplace(
-        output, input, group=self
-    )
-    return self._work(output)
-
-  def allreduce(
-      self,
-      tensors: List[torch.Tensor],
-      opts: dist.AllreduceOptions = ...,
-  ) -> dist.Work:
-    assert len(tensors) == 1
-    assert isinstance(tensors[0], torchax.tensor.Tensor)
-    torch.distributed._functional_collectives.all_reduce_inplace(
-        tensors[0],
-        torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
-            opts.reduceOp.op
-        ],
-        self,
-    )
-
-    return self._work(tensors)
-
-  def broadcast(
-      self,
-      tensors: List[torch.Tensor],
-      opts: dist.BroadcastOptions = ...,
-  ) -> dist.Work:
-    assert len(tensors) == 1
-    assert isinstance(tensors[0], torchax.tensor.Tensor)
-    tensors[0].copy_(
-        torch.distributed._functional_collectives.broadcast(
-            tensors[0], opts.rootRank, group=self
-        )
-    )
-
-    return self._work(tensors)
-
-
-dist.Backend.register_backend("jax", ProcessGroupJax)
-
-
-def jax_rendezvous_handler(
-    url: str, timeout: datetime.timedelta = ..., **kwargs
-):
-  """Initialize distributed store with JAX process IDs.
-
-  Requires `$MASTER_ADDR` and `$MASTER_PORT`.
-  """
-  # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU
-  # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part
-  # of their public Python API
-  master_ip = os.environ["MASTER_ADDR"]
-  master_port = int(os.environ["MASTER_PORT"])
-  # TODO(wcromar): Use `torchrun`'s store if available
-  store = dist.TCPStore(
-      master_ip,
-      master_port,
-      jax.process_count(),
-      is_master=jax.process_index() == 0,
-  )
-
-  yield (store, jax.process_index(), jax.process_count())
-
-
-dist.register_rendezvous_handler("jax", jax_rendezvous_handler)
-
-
-def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
-  """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined.
-  `f` is expected to take the replica index as a positional argument, similar
-  to `torch.multiprocessing.spawn`.
-  Note: `spawn` does not actually create parallel processes.
-  """
-  env = env or torchax.default_env()
-
-  def jax_wrapper(index, jax_args):
-    index, args = env.j2t_iso([index, jax_args])
-    torch_outputs = f(index, *args)
-    return env.t2j_iso(torch_outputs)
-
-  jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")(
-      np.arange(jax.device_count()), env.t2j_iso(args)
-  )
-  return env.j2t_iso(jax_outputs)
-
-
-class DistributedDataParallel(torch.nn.Module):
-  """Re-implementation of DistributedDataParallel using JAX SPMD.
-
-  Splits inputs along batch dimension (assumed to be 0) across all devices in
-  JAX runtime, including remote devices. Each process should load a distinct
-  shard of the input data using e.g. DistributedSampler. Each process' shard
-  is then further split among the addressable devices (e.g. local TPU chips)
-  by `shard_input`.
-
-  Note: since parameters are replicated across addressable devices, inputs
-  must also be SPMD sharded using `shard_input` or `replicate_input`.
-
-  Example usage:
-
-  ```
-  jax_model = torchax.distributed.DistributedDataParallel(create_model())
-  for data, dataloader:
-    jax_data = jax_model.shard_input(data)
-    jax_output = jax_model(jax_data)
-  ```
-  """
-  def __init__(
-      self,
-      module: torch.nn.Module,
-      env: Optional[torchax.tensor.Environment] = None,
-      **kwargs,
-  ):
-    if kwargs:
-      logging.warning(f"Unsupported kwargs {kwargs}")
-
-    super().__init__()
-    self._env = env or torchax.default_env()
-    self._mesh = Mesh(
-        mesh_utils.create_device_mesh((jax.device_count(),)),
-        axis_names=("batch",),
-    )
-    replicated_state = torch_pytree.tree_map_only(
-        torch.Tensor,
-        lambda t: self._env.j2t_iso(
-            jax.device_put(
-                self._env.to_xla(t)._elem, NamedSharding(self._mesh, P())
-            )
-        ),
-        module.state_dict(),
-    )
-    # TODO: broadcast
-    module.load_state_dict(replicated_state, assign=True)
-    self._module = module
-
-  def shard_input(self, inp):
-    per_process_batch_size = inp.shape[0]  # assumes batch dim is 0
-    per_replica_batch_size = per_process_batch_size // jax.local_device_count()
-    per_replica_batches = torch.chunk(inp, jax.local_device_count())
-    global_batch_size = per_replica_batch_size * jax.device_count()
-    global_batch_shape = (global_batch_size,) + inp.shape[1:]
-
-    sharding = NamedSharding(self._mesh, P("batch"))
-    return self._env.j2t_iso(jax.make_array_from_single_device_arrays(
-        global_batch_shape,
-        NamedSharding(self._mesh, P("batch")),
-        arrays=[
-            jax.device_put(self._env.to_xla(batch)._elem, device)
-            for batch, device in zip(
-                per_replica_batches, sharding.addressable_devices
-            )
-        ],
-    ))
-
-  def replicate_input(self, inp):
-    return self._env.j2t_iso(
-        jax.device_put(inp._elem, NamedSharding(self._mesh, P()))
-    )
-
-  def jit_step(self, func):
-    @functools.partial(interop.jax_jit,
-                       kwargs_for_jax_jit={'donate_argnums': 0})
-    def _jit_fn(states, args):
-      self.load_state_dict(states)
-      outputs = func(*args)
-      return self.state_dict(), outputs
-
-    @functools.wraps(func)
-    def inner(*args):
-      jax_states = self.state_dict()
-      new_states, outputs = _jit_fn(jax_states, args)
-      self.load_state_dict(new_states)
-      return outputs
-
-    return inner
-
-  def forward(self, *args):
-    with self._env:
-      return self._module(*args)
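
For context on what the removal above takes away: based on the docstrings of the deleted module, 0.0.4-era user code would have looked roughly like the sketch below. `create_model` and `train_loader` are hypothetical placeholders, and this module no longer exists in 0.0.6.

```python
# Hypothetical 0.0.4-era usage of the removed torchax.distributed module,
# reconstructed from the DistributedDataParallel docstring above.
# create_model and train_loader are placeholders, not part of torchax.
import torchax
import torchax.distributed  # removed in 0.0.6

ddp = torchax.distributed.DistributedDataParallel(create_model())

for data, labels in train_loader:
  jax_data = ddp.shard_input(data)    # split the batch dim across all JAX devices
  jax_labels = ddp.shard_input(labels)
  jax_output = ddp(jax_data)          # forward pass runs inside the torchax env
```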