torchax 0.0.4-py3-none-any.whl → 0.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,7 +1,7 @@
  Metadata-Version: 2.4
  Name: torchax
- Version: 0.0.4
- Summary: torchax is a library for running PyTorch on TPU
+ Version: 0.0.6
+ Summary: torchax is a library for running Jax and PyTorch together
  Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
  Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
  License: BSD 3-Clause License
@@ -51,111 +51,75 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
  Requires-Python: >=3.10
  Provides-Extra: cpu
  Requires-Dist: jax[cpu]; extra == 'cpu'
- Requires-Dist: jax[cpu]>=0.4.30; extra == 'cpu'
+ Requires-Dist: jax[cpu]>=0.6.2; extra == 'cpu'
  Provides-Extra: cuda
- Requires-Dist: jax[cpu]>=0.4.30; extra == 'cuda'
+ Requires-Dist: jax[cpu]>=0.6.2; extra == 'cuda'
  Requires-Dist: jax[cuda12]; extra == 'cuda'
  Provides-Extra: odml
  Requires-Dist: jax[cpu]; extra == 'odml'
- Requires-Dist: jax[cpu]>=0.4.30; extra == 'odml'
+ Requires-Dist: jax[cpu]>=0.6.2; extra == 'odml'
  Provides-Extra: tpu
- Requires-Dist: jax[cpu]>=0.4.30; extra == 'tpu'
+ Requires-Dist: jax[cpu]>=0.6.2; extra == 'tpu'
  Requires-Dist: jax[tpu]; extra == 'tpu'
  Description-Content-Type: text/markdown

- # torchax: Running PyTorch on TPU
+ # torchax: Running PyTorch on TPU via JAX

- **torchax!** is a backend for PyTorch, allowing users to run
- PyTorch on Google CloudTPUs. **torchax!** is also a library for providing
- graph-level interoperability between PyTorch and Jax.
+ **torchax** is a backend for PyTorch, allowing users to run
+ PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
+ graph-level interoperability between PyTorch and JAX.

  This means, with **torchax** you can:
- * Run PyTorch code on TPU with as little as 2 lines of code change.
- * Call a jax function from a pytorch function, passing in `jax.Array`s
- * Call a pytorch function from a jax function, passing in a `torch.Tensor` subclass.
- * Use jax features such as `jax.grad`, `optax` and `GSMPD` to train a Pytorch model.
- * Use a Pytorch model as feature extractor and use it with a Jax model.
+ * Run PyTorch code on TPUs with as little as 2 lines of code change.
+ * Call a JAX function from a PyTorch function, passing in `jax.Array`s.
+ * Call a PyTorch function from a JAX function, passing in a `torch.Tensor`s.
+ * Use JAX features such as `jax.grad`, `optax`, and `GSPMD` to train a PyTorch
+ model.
+ * Use a PyTorch model as feature extractor and use it with a JAX model.
  etc etc.

  ## Install

-
- ### On Google Cloud TPU:
  First install torch CPU:

  ```bash
+ # On Linux.
  pip install torch --index-url https://download.pytorch.org/whl/cpu
- ```
-
- Then install jax TPU:
-
- ```bash
- pip install -U jax[tpu]
- ```
-
- Finally install torchax

- ```bash
- pip install torchax
+ # Or on Mac.
+ pip install torch
  ```

- ### On GPU machines:
- First install torch CPU:
+ Then install JAX for the accelerator you want to use:

  ```bash
- pip install torch --index-url https://download.pytorch.org/whl/cpu
- ```
-
- Then install jax CUDA:
+ # On Google Cloud TPU.
+ pip install -U jax[tpu]

- ```bash
+ # Or, on GPU machines.
  pip install -U jax[cuda12]
- ```
-
- Finally install torchax
-
- ```bash
- pip install torchax
- ```

- ### On CPU machines (mac included)
- First install torch CPU:
-
- ```bash
- # Linux
- pip install torch --index-url https://download.pytorch.org/whl/cpu
-
- # OR Mac:
- pip install torch
- ```
-
- Then install jax CPU:
-
- ```bash
+ # Or, on Linux CPU machines or Macs (see the note below).
  pip install -U jax
  ```

- Finally install torchax
-
- ```bash
- pip install torchax
- ```
-
  NOTE: if you like metal support for Apple devices then install the
- metal version of jax: https://developer.apple.com/metal/jax/
+ metal version of JAX: https://developer.apple.com/metal/jax/

- ### Installing `torchax` from source
-
- Still need to install `torch` CPU and `Jax` of your accelerator (GPU, TPU or None).
+ Finally install torchax:

  ```bash
+ # Install pre-built torchax.
+ pip install torchax
+
+ # Or, install torchax from source.
  pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
  ```

  ## Run a model

- Now let's execute a model under torchax. We'll start with a simple 2-layer model
- it can be in theory any instance of `torch.nn.Module`.
+ Now let's execute a model under torchax. We'll start with a simple 2-layer model.
+ In theory, we can use any instance of `torch.nn.Module`.

  ```python
  import torch
@@ -179,75 +143,73 @@ class MyModel(nn.Module):

  m = MyModel()

- # Execute this model using torch
+ # Execute this model using torch.
  inputs = torch.randn(3, 3, 28, 28)
  print(m(inputs))
  ```

- This model `m` contains 2 parts: the weights that is stored inside of the model
- and it's submodules (`nn.Linear`).
-
- To execute this model with `torchax`; we need to enable torchax to capture pytorch ops.
- To enable this, use:
+ To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops:

  ```python
  import torchax
  torchax.enable_globally()
  ```
- Then, a `jax` device will be available to use
+
+ Then, we can use a `jax` device:

  ```python
  inputs = torch.randn(3, 3, 28, 28, device='jax')
- m = MyModel()
+ m = MyModel().to('jax')
  res = m(inputs)
  print(type(res)) # outputs torchax.tensor.Tensor
  ```

  `torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
- a `jax.Array`. You can inspect that jax array with `res.jax()`
-
+ a `jax.Array`. You can inspect that JAX array with `res.jax()`.

- ## What is happening behind the scene:
+ ## What is happening behind the scene

- We took the approach detailed in [new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py) recipe by Alban (@albanD); using `jax.Array` for the `raw_data`.
+ We took the approach detailed in the
+ [new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py)
+ recipe by Alban (@albanD), using `jax.Array` for `raw_data`.

- In other words, When a torch op is executed inside of `env` context manager (which is enabled with `torchax.enable_globally()`), we can swap out the
- implementation of that op written in Jax.
+ In other words, when a torch op is executed inside an `env` context manager,
+ which is enabled by `torchax.enable_globally()`, we will swap out the
+ implementation of that op with JAX.

  When a model's constructor runs, it will call some tensor constructor, such as
- `torch.rand`, `torch.ones` or `torch.zeros` etc to create its weights. The constructor
- will create an `torch.Tensor` subclass that contains a `jax.Array`.
+ `torch.rand`, `torch.ones`, or `torch.zeros` to create its weights. When torchax
+ is enabled, these constructors will create a `torchax.tensor.Tensor`, which
+ contains a `jax.Array`.

- Then, each subsequent op can unpack the `jax.Array`, call the op implementation,
- and wraps it back into `torch.Tensor` subclass.
-
- See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
+ Then, each subsequent op will extract the `jax.Array`, call the op's JAX
+ implementation, and wrap the result back into a `torchax.tensor.Tensor`,

+ See more at [how it works](docs/how_it_works.md) and\
+ [ops registry](docs/ops_registry.md).

  ### Executing with jax.jit

- The above script will execute the model using eager mode Jax as backend. This
- does allow executing torch models on TPU, but is often slower than what we can
+ The above script will execute the model using eager mode JAX as the backend. This
+ does allow executing torch models on TPUs, but is often slower than what we can
  achieve with `jax.jit`.

- `jax.jit` is a function that takes a Jax function (i.e. a function that takes jax array
- and returns jax array) into the same function, but faster.
+ `jax.jit` is a function that takes a JAX function (i.e. a function that takes JAX arrays
+ and returns JAX arrays) into a compiled (thus faster) version of the same function.

- We have made the `jax_jit` decorator that would accomplish the same with functions
- that takes and returns `torch.Tensor`. To use this, the first step is to create
+ We have made a `jax_jit` decorator that would accomplish the same with functions
+ that takes and returns `torch.Tensor`s. To use this, the first step is to create
  a functional version of this model: this means the parameters should be passed in
- as input instead of being attributes on class:
-
+ as input instead of being attributes of the class:

  ```python
-
  def model_func(param, inputs):
    return torch.func.functional_call(m, param, inputs)
-
  ```
+
  Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
- from PyTorch to replace the model
- weights with `param`, then call the model. This is roughly equivalent to:
+ from PyTorch to replace the model weights with `param` and then call the
+ model. This is roughly equivalent to:

  ```python
  def model_func(param, inputs):
@@ -255,87 +217,91 @@ def model_func(param, inputs):
    return m(*inputs)
  ```

- Now, we can apply `jax_jit`
+ Now, we can apply `jax_jit` on `module_func`:

  ```python
  from torchax.interop import jax_jit
+
  model_func_jitted = jax_jit(model_func)
  print(model_func_jitted(new_state_dict, inputs))
  ```

- See more examples at [eager_mode.py](examples/eager_mode.py) and the (examples folder)[examples/]
-
- However, to ease the idiom of creating functional model and calling it with parameters,
- we also created the `JittableModule` helper class.
+ See more examples at [eager_mode.py](examples/eager_mode.py) and the
+ [examples folder](examples/).

- So the above can be written as:
+ To ease the idiom of creating functional model and calling it with parameters,
+ we also created the `JittableModule` helper class. It lets us rewrite the
+ above as:

  ```python
-
  from torchax.interop import JittableModule

  m_jitted = JittableModule(m)
  res = m_jitted(...)
  ```

- The first time that `m_jitted` is called , it will trigger `jax.jit`
- then the subsequent computation with inputs of same shape will be fast.
-
+ The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
+ compile for the given input shapes. Subsequent calls with the same input shapes
+ will be fast as the compilation is cached.

+ ## Citation

- # Citation:
-
+ ```
  @software{torchax,
  author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleanson},
- title = {torchax: PyTorch on TPU and Jax interoperability},
+ title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax}
  version = {0.0.4},
  date = {2025-02-24},
  }
+ ```

  # Maintainers & Contributors:

  This library is created and maintained by the PyTorch/XLA team at Google Cloud.

- However, it benefitted from many direct and indirect
+ It benefitted from many direct and indirect
  contributions outside of the team. Many of them done by
- fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule), others by partner teams.
+ fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule).
+ Others by partner teams at Google and other companies.

- Here is the full list of contributors by 2025-02-25.
+ Here is the list of contributors by 2025-02-25.

- Han Qi (qihqi), Pytorch / XLA
- Manfei Bai (manfeibai), Pytorch / XLA
+ ```
+ Han Qi (qihqi), PyTorch/XLA
+ Manfei Bai (manfeibai), PyTorch/XLA
  Will Cromar (will-cromar), Meta
- Milad Mohammadi (miladm), Pytorch / XLA
- Siyuan Liu (lsy323), Pytorch / XLA
- Bhavya Bahl (bhavya01), Pytorch / XLA
- Pei Zhang (zpcore), Pytorch / XLA
- Yifei Teng (tengyifei), Pytorch / XLA
+ Milad Mohammadi (miladm), PyTorch/XLA
+ Siyuan Liu (lsy323), PyTorch/XLA
+ Bhavya Bahl (bhavya01), PyTorch/XLA
+ Pei Zhang (zpcore), PyTorch/XLA
+ Yifei Teng (tengyifei), PyTorch/XLA
  Chunnien Chan (chunnienc), Google, ODML
- Alban Desmaison (albanD), Meta, Pytorch
- Simon Teo (simonteozw), Google(20%)
- David Huang (dvhg), Google(20%)
- Barni Seetharaman (barney-s), Google(20%)
- Anish Karthik (anishfish2) , Google(20%)
- Yao Gu (guyao) , Google(20%)
- Yenkai Wang (yenkwang) , Google(20%)
- Greg Shikhman (commander) , Google(20%)
- Matin Akhlaghinia (matinehAkhlaghinia), Google(20%)
- Tracy Chen (tracych477), Google(20%)
- Matthias Guenther (mrguenther) , Google(20%)
- WenXin Dong (wenxindongwork), Google(20%)
- Kevin Gleason (GleasonK) , Google, StableHLO
- Nupur Baghel (nupurbaghel), Google(20%)
- Gwen Mittertreiner (gmittert), Google(20%)
+ Alban Desmaison (albanD), Meta, PyTorch
+ Simon Teo (simonteozw), Google (20%)
+ David Huang (dvhg), Google (20%)
+ Barni Seetharaman (barney-s), Google (20%)
+ Anish Karthik (anishfish2), Google (20%)
+ Yao Gu (guyao), Google (20%)
+ Yenkai Wang (yenkwang), Google (20%)
+ Greg Shikhman (commander), Google (20%)
+ Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
+ Tracy Chen (tracych477), Google (20%)
+ Matthias Guenther (mrguenther), Google (20%)
+ WenXin Dong (wenxindongwork), Google (20%)
+ Kevin Gleason (GleasonK), Google, StableHLO
+ Nupur Baghel (nupurbaghel), Google (20%)
+ Gwen Mittertreiner (gmittert), Google (20%)
  Zeev Melumian (zmelumian), Lightricks
- Vyom Sharma (vyom1611), Google(20%)
+ Vyom Sharma (vyom1611), Google (20%)
  Shitong Wang (ShitongWang), Adobe
- Rémi Doreau (ayshiff), Google(20%)
+ Rémi Doreau (ayshiff), Google (20%)
  Lance Wang (wang2yn84), Google, CoreML
- Hossein Sarshar (hosseinsarshar) , Google(20%)
- Daniel Vega-Myhre (danielvegamyhre) , Google(20%)
- Tianqi Fan (tqfan28), Google(20%)
- Jim Lin (jimlinntu), Google(20%)
+ Hossein Sarshar (hosseinsarshar), Google (20%)
+ Daniel Vega-Myhre (danielvegamyhre), Google (20%)
+ Tianqi Fan (tqfan28), Google (20%)
+ Jim Lin (jimlinntu), Google (20%)
  Fanhai Lu (FanhaiLu1), Google Cloud
  DeWitt Clinton (dewitt), Google PyTorch
- Aman Gupta (aman2930) , Google(20%)
+ Aman Gupta (aman2930), Google (20%)
+ ```
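
To make the workflow the updated README describes concrete, here is a minimal end-to-end sketch assembled from the snippets above (`torchax.enable_globally()`, the `'jax'` device, `torch.func.functional_call`, and `jax_jit` / `JittableModule` from `torchax.interop`). The two-layer model, the shapes, and the use of `m.state_dict()` as the parameter mapping are illustrative assumptions made for this review, not content of the 0.0.6 wheel.

```python
import torch
import torch.nn as nn
import torchax
from torchax.interop import JittableModule, jax_jit

# Route torch ops to their JAX implementations and expose the 'jax' device.
torchax.enable_globally()

class TwoLayer(nn.Module):  # hypothetical stand-in for the README's MyModel
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(784, 128)
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    return self.fc2(torch.relu(self.fc1(x)))

m = TwoLayer().to('jax')                    # weights become torchax tensors
inputs = torch.randn(8, 784, device='jax')  # inputs hold jax.Arrays

# Functional form: parameters are passed in instead of read from attributes.
def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)

model_func_jitted = jax_jit(model_func)  # compiled on first call per input shape
print(model_func_jitted(m.state_dict(), inputs))

# Or let JittableModule handle the functional-call idiom.
m_jitted = JittableModule(m)
print(m_jitted(inputs))
```

Whether `state_dict()` is exactly the parameter container `jax_jit` expects follows the README's own `new_state_dict` example; treat this as a sketch of the documented API rather than verified 0.0.6 behaviour.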
@@ -0,0 +1,33 @@
+ torchax/CONTRIBUTING.md,sha256=VOL0us6kS-uc4yE6IlSm6SDHYHnx-gw-0upFnP0VkSQ,1369
+ torchax/__init__.py,sha256=c98iIGugRTbEVcsx8eWnbAjsC4mpcDrK23ZQqiMycLg,3157
+ torchax/amp.py,sha256=-k8t4lrCsJLKHEhI6J0aHE3MAPEL-4DP6wCKtMwo1AM,11791
+ torchax/config.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
+ torchax/configuration.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
+ torchax/decompositions.py,sha256=1p5TFZfAJ2Bs9BiSO1vXbnWEXnbPfC_gCQ54rDXhd9k,28859
+ torchax/device_module.py,sha256=7fkdPwXG0qCBTmvDYHp0fvv4xK0W9avV_Ua3MeMzczE,349
+ torchax/environment.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ torchax/export.py,sha256=xU-UbrQBvQWUy-GM2FfeIHymlEdmYDYcPymjlcXM23w,8969
+ torchax/flax.py,sha256=2Tg8inGskgAfByPxJQh4ItZHHAb-960gYq156bSO8V4,1280
+ torchax/interop.py,sha256=7HvJwtxdodcCrMyJzs-Wr47hkHuoh6CWb2-YKoBwqV0,11076
+ torchax/mesh_util.py,sha256=Ab4ic2eHWmQ3Mw3jpERvi-TKLIcDvQQoC6tuIZ9ig7Q,9314
+ torchax/tensor.py,sha256=XjAp7khpQNhoVsSMzDj-V8l4DFT9jBaL4NVCi88a6K0,20893
+ torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
+ torchax/train.py,sha256=rtvj6HkdnG9fc3VWYPNwHuxGlUxFJkUXJWED8azgtok,3855
+ torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
+ torchax/util.py,sha256=cb-eudDE7AX2s-6zYtXdowgyzyvqPqE9MPP82PfH23g,3069
+ torchax/view.py,sha256=1ekqRN04lAPd_icgZMKbSYWhr738DzVloc34ynml4wo,11121
+ torchax/ops/__init__.py,sha256=Vr1p8zDHwfXZBUbw70iNiCJLZLNdI6gR_vUlaiA7Usg,270
+ torchax/ops/jaten.py,sha256=WxfZU6p7b7OR98B3z0LCXKlV6U5aslXxJMJirBr6lns,165835
+ torchax/ops/jax_reimplement.py,sha256=idkmFWNCXBilkmaHBGdivKz0XhsjSpqLNlGXxbBOKWQ,7302
+ torchax/ops/jc10d.py,sha256=OzSYYle_5jBmNVP64SuJPz9S-rRGD6H7e1a9HHIKsjU,1322
+ torchax/ops/jimage.py,sha256=P0lAauYX_au_xjIHDsG7H6jO7Jf54_VCAjzZuIZdhO0,3182
+ torchax/ops/jlibrary.py,sha256=YfYUQbf5dKiMtEHUMfdgHTeLuNvvSTJ-l8s7wQNIvO0,2930
+ torchax/ops/jtorch.py,sha256=wR4ZdDscxqG4VpxjcLGzgdUKmipa3fp7S0mK3DcD--A,17161
+ torchax/ops/jtorchvision_nms.py,sha256=HSnhwU0gFaHucT7EvrEruJdnWkAWTw4T35GY525ohO8,8903
+ torchax/ops/mappings.py,sha256=AESERtXJ6i_Hm0ycwEw7z5OJnHu-7QteWlSs-mlUPE4,3492
+ torchax/ops/op_base.py,sha256=MLKFxMojIXgz4lkTE6k-8F-ddve-9vEiXkzj3P-YJPs,3739
+ torchax/ops/ops_registry.py,sha256=qADpG1up0JOThoybiOQoRDWtAe5TOkHlqcj1bSHjtGY,1594
+ torchax-0.0.6.dist-info/METADATA,sha256=uB9hoyxdfrAD14pHy0U8Gh1uCHbYwok-oEW12pEa6qs,10753
+ torchax-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ torchax-0.0.6.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
+ torchax-0.0.6.dist-info/RECORD,,
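
The RECORD file listed above pairs every file in the 0.0.6 wheel with a `sha256=` digest. Per the wheel spec (PEP 376/427), that digest is the URL-safe base64 encoding of the raw SHA-256 hash with the trailing `=` padding stripped. The small helper below was written for this review (it is not part of the package) and shows how an entry can be re-derived from an unpacked wheel, e.g. after `pip download torchax==0.0.6 --no-deps` and unzipping.

```python
# Hypothetical spot-check of a RECORD entry from an unpacked torchax 0.0.6 wheel.
import base64
import hashlib

def record_hash(path: str) -> str:
  """Return a digest in the urlsafe-base64, unpadded form that RECORD uses."""
  with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).digest()
  return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

# Should reproduce the value listed above for torchax/device_module.py.
print(record_hash("torchax/device_module.py"))
```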
torchax/distributed.py DELETED
@@ -1,246 +0,0 @@
- """`torch.distributed` backend implemented with JAX collective ops.
-
- EXPERIMENTAL: This module is still highly experimental, and it may be removed
- before any stable release.
-
- Note: JAX collective ops require that axis names be defined in `pmap` or
- `shmap`. The distributed backend only supports one axis, named `torch_dist`.
- This name is defined by our mirror implementation of `spawn`.
- """
-
- import datetime
- import functools
- import logging
- import os
- from typing import List, Optional, Union
-
- import jax
- import numpy as np
- import torch
- import torch.distributed as dist
- import torch.distributed._functional_collectives
- from torch._C._distributed_c10d import ProcessGroup # type: ignore
- import torch.distributed
- import torchax
- from jax.sharding import NamedSharding
- from jax.sharding import Mesh, PartitionSpec as P
- from jax.experimental import mesh_utils
- import torch.utils._pytree as torch_pytree
- from torchax import interop
-
-
- class ProcessGroupJax(ProcessGroup):
-   """Distributed backend implemented with JAX."""
-
-   def __init__(self, prefix_store, rank, size, timeout):
-     super().__init__(rank, size)
-     self._group_name = None
-
-   def getBackendName(self):
-     return "jax"
-
-   # TODO(wcromar): why doesn't default group name setter work?
-   # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152
-   def _set_group_name(self, name: str) -> None:
-     self._group_name = name
-
-   @property
-   def group_name(self):
-     assert self._group_name
-     return self._group_name
-
-   @staticmethod
-   def _work(
-       tensors: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]],
-   ) -> dist.Work:
-     fut = torch.futures.Future()
-     fut.set_result(tensors)
-     return torch._C._distributed_c10d._create_work_from_future(fut)
-
-   def _allgather_base(
-       self,
-       output: torch.Tensor,
-       input: torch.Tensor,
-       opts=...,
-   ) -> dist.Work:
-     assert isinstance(input, torchax.tensor.Tensor)
-     assert isinstance(output, torchax.tensor.Tensor)
-     torch.distributed._functional_collectives.all_gather_tensor_inplace(
-         output, input, group=self
-     )
-     return self._work(output)
-
-   def allreduce(
-       self,
-       tensors: List[torch.Tensor],
-       opts: dist.AllreduceOptions = ...,
-   ) -> dist.Work:
-     assert len(tensors) == 1
-     assert isinstance(tensors[0], torchax.tensor.Tensor)
-     torch.distributed._functional_collectives.all_reduce_inplace(
-         tensors[0],
-         torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
-             opts.reduceOp.op
-         ],
-         self,
-     )
-
-     return self._work(tensors)
-
-   def broadcast(
-       self,
-       tensors: List[torch.Tensor],
-       opts: dist.BroadcastOptions = ...,
-   ) -> dist.Work:
-     assert len(tensors) == 1
-     assert isinstance(tensors[0], torchax.tensor.Tensor)
-     tensors[0].copy_(
-         torch.distributed._functional_collectives.broadcast(
-             tensors[0], opts.rootRank, group=self
-         )
-     )
-
-     return self._work(tensors)
-
-
- dist.Backend.register_backend("jax", ProcessGroupJax)
-
-
- def jax_rendezvous_handler(
-     url: str, timeout: datetime.timedelta = ..., **kwargs
- ):
-   """Initialize distributed store with JAX process IDs.
-
-   Requires `$MASTER_ADDR` and `$MASTER_PORT`.
-   """
-   # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU
-   # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part
-   # of their public Python API
-   master_ip = os.environ["MASTER_ADDR"]
-   master_port = int(os.environ["MASTER_PORT"])
-   # TODO(wcromar): Use `torchrun`'s store if available
-   store = dist.TCPStore(
-       master_ip,
-       master_port,
-       jax.process_count(),
-       is_master=jax.process_index() == 0,
-   )
-
-   yield (store, jax.process_index(), jax.process_count())
-
-
- dist.register_rendezvous_handler("jax", jax_rendezvous_handler)
-
-
- def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
-   """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined.
-   `f` is expected to take the replica index as a positional argument, similar
-   to `torch.multiprocessing.spawn`.
-   Note: `spawn` does not actually create parallel processes.
-   """
-   env = env or torchax.default_env()
-
-   def jax_wrapper(index, jax_args):
-     index, args = env.j2t_iso([index, jax_args])
-     torch_outputs = f(index, *args)
-     return env.t2j_iso(torch_outputs)
-
-   jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")(
-       np.arange(jax.device_count()), env.t2j_iso(args)
-   )
-   return env.j2t_iso(jax_outputs)
-
-
- class DistributedDataParallel(torch.nn.Module):
-   """Re-implementation of DistributedDataParallel using JAX SPMD.
-
-   Splits inputs along batch dimension (assumed to be 0) across all devices in
-   JAX runtime, including remote devices. Each process should load a distinct
-   shard of the input data using e.g. DistributedSampler. Each process' shard
-   is then further split among the addressable devices (e.g. local TPU chips)
-   by `shard_input`.
-
-   Note: since parameters are replicated across addressable devices, inputs
-   must also be SPMD sharded using `shard_input` or `replicate_input`.
-
-   Example usage:
-
-   ```
-   jax_model = torchax.distributed.DistributedDataParallel(create_model())
-   for data, dataloader:
-     jax_data = jax_model.shard_input(data)
-     jax_output = jax_model(jax_data)
-   ```
-   """
-   def __init__(
-       self,
-       module: torch.nn.Module,
-       env: Optional[torchax.tensor.Environment] = None,
-       **kwargs,
-   ):
-     if kwargs:
-       logging.warning(f"Unsupported kwargs {kwargs}")
-
-     super().__init__()
-     self._env = env or torchax.default_env()
-     self._mesh = Mesh(
-         mesh_utils.create_device_mesh((jax.device_count(),)),
-         axis_names=("batch",),
-     )
-     replicated_state = torch_pytree.tree_map_only(
-         torch.Tensor,
-         lambda t: self._env.j2t_iso(
-             jax.device_put(
-                 self._env.to_xla(t)._elem, NamedSharding(self._mesh, P())
-             )
-         ),
-         module.state_dict(),
-     )
-     # TODO: broadcast
-     module.load_state_dict(replicated_state, assign=True)
-     self._module = module
-
-   def shard_input(self, inp):
-     per_process_batch_size = inp.shape[0] # assumes batch dim is 0
-     per_replica_batch_size = per_process_batch_size // jax.local_device_count()
-     per_replica_batches = torch.chunk(inp, jax.local_device_count())
-     global_batch_size = per_replica_batch_size * jax.device_count()
-     global_batch_shape = (global_batch_size,) + inp.shape[1:]
-
-     sharding = NamedSharding(self._mesh, P("batch"))
-     return self._env.j2t_iso(jax.make_array_from_single_device_arrays(
-         global_batch_shape,
-         NamedSharding(self._mesh, P("batch")),
-         arrays=[
-             jax.device_put(self._env.to_xla(batch)._elem, device)
-             for batch, device in zip(
-                 per_replica_batches, sharding.addressable_devices
-             )
-         ],
-     ))
-
-   def replicate_input(self, inp):
-     return self._env.j2t_iso(
-         jax.device_put(inp._elem, NamedSharding(self._mesh, P()))
-     )
-
-   def jit_step(self, func):
-     @functools.partial(interop.jax_jit,
-                        kwargs_for_jax_jit={'donate_argnums': 0})
-     def _jit_fn(states, args):
-       self.load_state_dict(states)
-       outputs = func(*args)
-       return self.state_dict(), outputs
-
-     @functools.wraps(func)
-     def inner(*args):
-       jax_states = self.state_dict()
-       new_states, outputs = _jit_fn(jax_states, args)
-       self.load_state_dict(new_states)
-       return outputs
-
-     return inner
-
-   def forward(self, *args):
-     with self._env:
-       return self._module(*args)