torchax-0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


torchax/train.py ADDED
@@ -0,0 +1,120 @@
+ import collections
+ import functools
+ import torch
+ import jax
+ import torchax
+ from torchax import interop
+ from torchax.interop import torch_view, jax_view
+ import optax
+
+
+ remat = torch_view(jax.remat)
+ mark_sharding = torch_view(jax.lax.with_sharding_constraint)
+
+
+ def make_train_step(model_fn,
+                     loss_fn, optax_optimizer,
+                     remat_policy=None):
+   """Make a function that does one train step given model and loss.
+
+   model_fn: a function representing the model's forward pass,
+     i.e. it has signature Callable[weights, buffers, args] -> result, where
+     weights is a pytree of trainable parameters,
+     buffers is a pytree of non-trainable parameters / constants,
+     args is the input data loaded from the data set, and
+     result is the return value of the model.
+   loss_fn: a function to compute loss,
+     i.e. it has signature Callable[result, label] -> loss,
+     where result is what model_fn returned and
+     label is loaded from the dataloader.
+   optax_optimizer: the optimizer from the optax library, for example optax.adam.
+   remat_policy: one of jax.ad_checkpoint.checkpoint_policies; specifies how
+     to do gradient checkpointing. If None, then it means checkpoint everything.
+   """
+   env = torchax.default_env()
+
+   def loss(weights, buffers, args, label):  # inputs are XLATensor
+     with env, jax.named_scope('compute_loss'):
+       res = model_fn(weights, buffers, args)
+       l = loss_fn(res, label)
+       return l
+
+   loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
+   grad_fn = interop.jax_value_and_grad(loss)
+
+   def step(weights, buffers, opt_state, args, label):  # inputs are arrays
+     with jax.named_scope('compute_gradient'):
+       loss, gradient = grad_fn(weights, buffers, args, label)
+
+     with jax.named_scope("optimizer_updates"):
+       updates, opt_state = interop.call_jax(
+           optax_optimizer.update,
+           gradient, opt_state, weights)
+       weights = interop.call_jax(optax.apply_updates, weights, updates)
+     return loss, weights, opt_state
+
+   # TODO: apply jax.jit so the user doesn't have to.
+   return step
+
+
+
+ class Container:
+   pass
+
+
+ class ScannedModule(torch.nn.Module):
+   """Runs a list of identical layers with jax.lax.scan, storing the
+   per-layer weights stacked along a new leading dimension."""
+
+   def __init__(self, module_list, checkpoint_policy=None):
+     super().__init__()
+
+     self.c = None
+     assert module_list
+     self.c = Container()
+     self.c.one_mod = module_list[0]
+     self.checkpoint_policy = checkpoint_policy
+
+     weights = self._stack_layer_weights(module_list)
+     self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
+     self.params = torch.nn.ParameterDict({
+         self._param_name_new(k): v for k, v in weights.items()
+     })
+
+   def _stack_layer_weights(self, module_list):
+     # Create weights such that every [n, m] weight
+     # becomes [k, n, m], where k is the number of layers,
+     # i.e. stack the layer weights together.
+     temp = collections.defaultdict(list)
+     for m in module_list:
+       for k, v in m.state_dict().items():
+         temp[k].append(v)
+     res = {k: torch.stack(v) for k, v in temp.items()}
+     return res
+
+   def _param_name_new(self, old):
+     return '___'.join(old.split('.'))
+
+   def _param_name_old(self, new):
+     return '.'.join(new.split('___'))
+
+   def forward(self, *args, **kwargs):
+     assert not kwargs
+     weights = {k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys}
+     scan = interop.torch_view(jax.lax.scan)
+
+     def eval_one_layer(args, weight):
+       # unpack args
+       h, *rest = args
+       newh = torch.func.functional_call(self.c.one_mod, weight, args)
+       # next layer's input; and residual to be added to the list
+       return (newh, *rest), None
+
+     _eval_one_layer = interop.gradient_checkpoint(
+         eval_one_layer,
+         kwargs={'policy': self.checkpoint_policy},
+     )
+     h, _ = scan(
+         _eval_one_layer,
+         args,
+         weights,
+     )
+     return h[0]
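To see how these pieces are meant to fit together, here is a hypothetical usage sketch (an editorial illustration, not part of the wheel): `MyModel` is the example module from the README below, the batch and label tensors are placeholders, and wrapping the returned `step` with `interop.jax_jit` follows the TODO note above rather than anything the function does for you.

```python
import optax
import torch
import torchax
from torchax import interop, train

torchax.enable_globally()

model = MyModel()                               # weights land on the 'jax' device, as in the README example
weights = dict(model.named_parameters())
buffers = dict(model.named_buffers())

def model_fn(weights, buffers, args):
  # Functional forward pass: plug the weights/buffers back into the module.
  return torch.func.functional_call(model, {**weights, **buffers}, args)

def loss_fn(result, label):
  return torch.nn.functional.cross_entropy(result, label)

optimizer = optax.adam(1e-3)
opt_state = interop.call_jax(optimizer.init, weights)

step = train.make_train_step(model_fn, loss_fn, optimizer)
step = interop.jax_jit(step)                    # make_train_step does not jit for you yet (see TODO above)

batch = torch.randn(8, 28 * 28, device='jax')   # placeholder batch
label = torch.randint(0, 10, (8,), device='jax')
loss, weights, opt_state = step(weights, buffers, opt_state, batch, label)
print(loss)
```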
torchax/types.py ADDED
@@ -0,0 +1,12 @@
+ from typing import Callable, Any, Union, ParamSpec, TypeAlias
+ import torch
+ import jax
+ import jax.numpy as jnp
+ import sys
+
+ P = ParamSpec('P')
+
+ TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]
+ TorchCallable: TypeAlias = Callable[P, TorchValue]
+ JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any]
+ JaxCallable: TypeAlias = Callable[P, JaxValue]
torchax-0.0.4.dist-info/METADATA ADDED
@@ -0,0 +1,341 @@
+ Metadata-Version: 2.4
+ Name: torchax
+ Version: 0.0.4
+ Summary: torchax is a library for running PyTorch on TPU
+ Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
+ Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
+ License: BSD 3-Clause License
+
+ Copyright (c) 2023, pytorch-tpu
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ License-File: LICENSE
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: BSD License
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Scientific/Engineering
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
+ Classifier: Topic :: Software Development
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.10
+ Provides-Extra: cpu
+ Requires-Dist: jax[cpu]; extra == 'cpu'
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'cpu'
+ Provides-Extra: cuda
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'cuda'
+ Requires-Dist: jax[cuda12]; extra == 'cuda'
+ Provides-Extra: odml
+ Requires-Dist: jax[cpu]; extra == 'odml'
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'odml'
+ Provides-Extra: tpu
+ Requires-Dist: jax[cpu]>=0.4.30; extra == 'tpu'
+ Requires-Dist: jax[tpu]; extra == 'tpu'
+ Description-Content-Type: text/markdown
+
+ # torchax: Running PyTorch on TPU
+
+ **torchax** is a backend for PyTorch, allowing users to run
+ PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
+ graph-level interoperability between PyTorch and Jax.
+
+ This means, with **torchax** you can:
+ * Run PyTorch code on TPU with as little as 2 lines of code change.
+ * Call a jax function from a pytorch function, passing in `jax.Array`s.
+ * Call a pytorch function from a jax function, passing in a `torch.Tensor` subclass.
+ * Use jax features such as `jax.grad`, `optax` and GSPMD to train a Pytorch model.
+ * Use a Pytorch model as a feature extractor and use it with a Jax model, etc.
+
+ ## Install
+
+ ### On Google Cloud TPU:
+ First install torch CPU:
+
+ ```bash
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ ```
+
+ Then install jax TPU:
+
+ ```bash
+ pip install -U jax[tpu]
+ ```
+
+ Finally, install torchax:
+
+ ```bash
+ pip install torchax
+ ```
+
+ ### On GPU machines:
+ First install torch CPU:
+
+ ```bash
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ ```
+
+ Then install jax CUDA:
+
+ ```bash
+ pip install -U jax[cuda12]
+ ```
+
+ Finally, install torchax:
+
+ ```bash
+ pip install torchax
+ ```
+
+ ### On CPU machines (Mac included)
+ First install torch CPU:
+
+ ```bash
+ # Linux
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+ # OR Mac:
+ pip install torch
+ ```
+
+ Then install jax CPU:
+
+ ```bash
+ pip install -U jax
+ ```
+
+ Finally, install torchax:
+
+ ```bash
+ pip install torchax
+ ```
+
+ NOTE: if you would like Metal support for Apple devices, install the
+ Metal version of jax: https://developer.apple.com/metal/jax/
+
+ ### Installing `torchax` from source
+
+ You still need to install `torch` CPU and the `jax` build for your accelerator (GPU, TPU, or none).
+
+ ```bash
+ pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
+ ```
+
+ ## Run a model
+
+ Now let's execute a model under torchax. We'll start with a simple 2-layer model;
+ in theory it can be any instance of `torch.nn.Module`.
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class MyModel(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.fc1 = nn.Linear(28 * 28, 120)
+         self.fc2 = nn.Linear(120, 84)
+         self.fc3 = nn.Linear(84, 10)
+
+     def forward(self, x):
+         x = x.view(-1, 28 * 28)
+         x = F.relu(self.fc1(x))
+         x = F.relu(self.fc2(x))
+         x = self.fc3(x)
+         return x
+
+ m = MyModel()
+
+ # Execute this model using torch
+ inputs = torch.randn(3, 3, 28, 28)
+ print(m(inputs))
+ ```
+
+ This model `m` contains 2 parts: the weights that are stored inside the model,
+ and its submodules (`nn.Linear`).
+
+ To execute this model with `torchax`, we need to enable torchax to capture pytorch ops.
+ To enable this, use:
+
+ ```python
+ import torchax
+ torchax.enable_globally()
+ ```
+
+ Then, a `jax` device will be available to use:
+
+ ```python
+ inputs = torch.randn(3, 3, 28, 28, device='jax')
+ m = MyModel()
+ res = m(inputs)
+ print(type(res))  # outputs torchax.tensor.Tensor
+ ```
+
+ `torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
+ a `jax.Array`. You can inspect that jax array with `res.jax()`.
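For example, a small sketch building on the snippet above (`interop.call_jax` is the same helper used in `torchax/train.py`; the printed type is our assumption):

```python
import jax.numpy as jnp
from torchax import interop

jarr = res.jax()            # the underlying jax.Array held by the torchax tensor
print(jarr.shape)           # (9, 10) here, since forward() flattens with view(-1, 28 * 28)

total = interop.call_jax(jnp.sum, res)  # call a jax function directly on the torch-side tensor
print(type(total))          # expected: torchax.tensor.Tensor
```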
+
+
+ ## What is happening behind the scenes
+
+ We took the approach detailed in the [new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py) recipe by Alban (@albanD), using a `jax.Array` as the `raw_data`.
+
+ In other words, when a torch op is executed inside the `env` context manager (which is what `torchax.enable_globally()` turns on), we swap in an
+ implementation of that op written in Jax.
+
+ When a model's constructor runs, it calls tensor constructors such as
+ `torch.rand`, `torch.ones` or `torch.zeros` to create its weights. Each constructor
+ creates a `torch.Tensor` subclass that contains a `jax.Array`.
+
+ Then, each subsequent op can unpack the `jax.Array`, call the op's jax implementation,
+ and wrap the result back into the `torch.Tensor` subclass.
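A minimal sketch of that flow, using only the calls shown in this README (the printed type is our assumption):

```python
import torch
import torchax

torchax.enable_globally()               # torch ops now dispatch to their jax implementations

x = torch.randn(2, 2, device='jax')     # the constructor produces a torchax.tensor.Tensor backed by a jax.Array
y = x @ x                               # matmul unwraps the jax.Arrays, runs the jax op, and wraps the result
print(type(y))                          # expected: torchax.tensor.Tensor
```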
+
+ See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
+
+
+ ### Executing with jax.jit
+
+ The above script will execute the model using eager mode Jax as the backend. This
+ does allow executing torch models on TPU, but is often slower than what we can
+ achieve with `jax.jit`.
+
+ `jax.jit` is a function that transforms a Jax function (i.e. a function that takes jax arrays
+ and returns jax arrays) into the same function, but faster.
+
+ We have made the `jax_jit` decorator that accomplishes the same for functions
+ that take and return `torch.Tensor`s. To use it, the first step is to create
+ a functional version of the model: this means the parameters should be passed in
+ as inputs instead of being attributes of the class:
+
+ ```python
+ def model_func(param, inputs):
+     return torch.func.functional_call(m, param, inputs)
+ ```
+
+ Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
+ from PyTorch to replace the model
+ weights with `param`, then call the model. This is roughly equivalent to:
+
+ ```python
+ def model_func(param, inputs):
+     m.load_state_dict(param)
+     return m(*inputs)
+ ```
+
+ Now, we can apply `jax_jit`:
+
+ ```python
+ from torchax.interop import jax_jit
+ model_func_jitted = jax_jit(model_func)
+ print(model_func_jitted(new_state_dict, inputs))
+ ```
+
+ See more examples at [eager_mode.py](examples/eager_mode.py) and the [examples folder](examples/).
+
+ However, to ease the idiom of creating a functional model and calling it with parameters,
+ we also created the `JittableModule` helper class.
+
+ So the above can be written as:
+
+ ```python
+ from torchax.interop import JittableModule
+
+ m_jitted = JittableModule(m)
+ res = m_jitted(...)
+ ```
+
+ The first time `m_jitted` is called, it triggers `jax.jit`;
+ subsequent computations with inputs of the same shape will be fast.
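Putting the pieces together, a minimal end-to-end sketch (reusing `MyModel` from the "Run a model" section above; an editorial illustration, not part of the package):

```python
import torch
import torchax
from torchax.interop import JittableModule

torchax.enable_globally()

m = MyModel()                        # created after enabling torchax, as in the example above
m_jitted = JittableModule(m)

inputs = torch.randn(3, 3, 28, 28, device='jax')
res = m_jitted(inputs)               # first call traces and compiles with jax.jit
res = m_jitted(inputs)               # same input shape: reuses the compiled program
print(res.shape)                     # torch.Size([9, 10]), since forward() reshapes with view(-1, 28 * 28)
```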
+
+
+ # Citation:
+
+ @software{torchax,
+   author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleason},
+   title = {torchax: PyTorch on TPU and Jax interoperability},
+   url = {https://github.com/pytorch/xla/tree/master/torchax},
+   version = {0.0.4},
+   date = {2025-02-24},
+ }
+
+ # Maintainers & Contributors:
+
+ This library is created and maintained by the PyTorch/XLA team at Google Cloud.
+
+ However, it has benefitted from many direct and indirect
+ contributions outside of the team. Many of them were made by
+ fellow Googlers under [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule), others by partner teams.
+
+ Here is the full list of contributors as of 2025-02-25.
+
+ Han Qi (qihqi), Pytorch / XLA
+ Manfei Bai (manfeibai), Pytorch / XLA
+ Will Cromar (will-cromar), Meta
+ Milad Mohammadi (miladm), Pytorch / XLA
+ Siyuan Liu (lsy323), Pytorch / XLA
+ Bhavya Bahl (bhavya01), Pytorch / XLA
+ Pei Zhang (zpcore), Pytorch / XLA
+ Yifei Teng (tengyifei), Pytorch / XLA
+ Chunnien Chan (chunnienc), Google, ODML
+ Alban Desmaison (albanD), Meta, Pytorch
+ Simon Teo (simonteozw), Google (20%)
+ David Huang (dvhg), Google (20%)
+ Barni Seetharaman (barney-s), Google (20%)
+ Anish Karthik (anishfish2), Google (20%)
+ Yao Gu (guyao), Google (20%)
+ Yenkai Wang (yenkwang), Google (20%)
+ Greg Shikhman (commander), Google (20%)
+ Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
+ Tracy Chen (tracych477), Google (20%)
+ Matthias Guenther (mrguenther), Google (20%)
+ WenXin Dong (wenxindongwork), Google (20%)
+ Kevin Gleason (GleasonK), Google, StableHLO
+ Nupur Baghel (nupurbaghel), Google (20%)
+ Gwen Mittertreiner (gmittert), Google (20%)
+ Zeev Melumian (zmelumian), Lightricks
+ Vyom Sharma (vyom1611), Google (20%)
+ Shitong Wang (ShitongWang), Adobe
+ Rémi Doreau (ayshiff), Google (20%)
+ Lance Wang (wang2yn84), Google, CoreML
+ Hossein Sarshar (hosseinsarshar), Google (20%)
+ Daniel Vega-Myhre (danielvegamyhre), Google (20%)
+ Tianqi Fan (tqfan28), Google (20%)
+ Jim Lin (jimlinntu), Google (20%)
+ Fanhai Lu (FanhaiLu1), Google Cloud
+ DeWitt Clinton (dewitt), Google PyTorch
+ Aman Gupta (aman2930), Google (20%)
torchax-0.0.4.dist-info/RECORD ADDED
@@ -0,0 +1,27 @@
+ torchax/CONTRIBUTING.md,sha256=T6Uy7vZvzXArbPqqUswEtQBgTU_eRwtJnHOiZNth2Pw,1371
+ torchax/__init__.py,sha256=0NoxpzaPIMYMJXZz_JiRCTYZCuICnTFrLDm5cZKGIm0,3242
+ torchax/config.py,sha256=Zpfcn7Q3QsKnn4cCx1-bYxA8q_lqrWJcUz30B3QcIc4,535
+ torchax/decompositions.py,sha256=C067xSd9MpGs1SDqlC5T9RnqNj-i-m31mysfFGDxLEQ,11221
+ torchax/device_module.py,sha256=NNEGSfk9ApVVSy5dRjHkvGuNGamopCEjSlItcNlAIbE,253
+ torchax/distributed.py,sha256=CXkpV0K2Oas7fcRA-i3s2VXTz8l4CEI7_NnbYOsHtRw,7519
+ torchax/environment.py,sha256=daEdpEyAJIa8b2VkCqSKcw8PaExcB6Qro80XNes_sHA,2
+ torchax/export.py,sha256=MG19Y0QcRcX3S3gruy9K-wFGdemD4I0ii03HWQM31xk,8853
+ torchax/interop.py,sha256=fncREqC1Nib-ud7wyekmtKaif7GSDql7ITmzTt5A9e4,6774
+ torchax/tensor.py,sha256=89yXmyO1uT6F1KI-Kn6NCnIieCiiMsDutgS071_9p64,16644
+ torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
+ torchax/train.py,sha256=Ym_SC0WOmhQMhSCH0aZ-zgP7Q19WdCDTeoI-JBFrqx8,3930
+ torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
+ torchax/ops/__init__.py,sha256=Vr1p8zDHwfXZBUbw70iNiCJLZLNdI6gR_vUlaiA7Usg,270
+ torchax/ops/jaten.py,sha256=ycTSyCzgJpz5zu936BTZXmRJWuzhxeGOOjoffU00c3g,154999
+ torchax/ops/jax_reimplement.py,sha256=29SGPF0bszgCuyVxiZ4A2xTPtAuPXBDvDG79M3C0Vxo,7468
+ torchax/ops/jc10d.py,sha256=DiJjiUFHyYmFUHEgqEDbGsLaCsJs94xJlBatRwXiPfg,1317
+ torchax/ops/jlibrary.py,sha256=R22GUlK6lUivuDSmvtqwbV9ku_uT0ddLcQcwHI96vqM,2927
+ torchax/ops/jtorch.py,sha256=fWU2z5uD3FpVE5ui-O5nfHEule2hod4BqCnsbV5jVfQ,14463
+ torchax/ops/jtorchvision_nms.py,sha256=SiMN9bUV7GADD2S2PkH9spn5oQGgr2jhFQ1nyd3eHgI,8873
+ torchax/ops/mappings.py,sha256=a8C5CYIEucGIxUPQHCsDZypBC3b_BYtOKQTWf9bdiRc,3060
+ torchax/ops/op_base.py,sha256=EdlAx3oqU2brMngjZvTxECin0eNiPdAUfOLRtFL3tgg,2900
+ torchax/ops/ops_registry.py,sha256=nXCAXiQNpC0uBUl8xEim7TOUTyROZm_4H1LaOvYwJVQ,1241
+ torchax-0.0.4.dist-info/METADATA,sha256=4VQb39quiN9U9j955fs9zY00x4t4Orubo92ONCCLyug,11119
+ torchax-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ torchax-0.0.4.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
+ torchax-0.0.4.dist-info/RECORD,,
torchax-0.0.4.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
torchax-0.0.4.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,28 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2023, pytorch-tpu
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.