torchax 0.0.10.dev20251114__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic. Click here for more details.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202612.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251114.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/licenses/LICENSE +0 -0
torchax/view.py
CHANGED
|
@@ -12,12 +12,14 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import
|
|
16
|
-
|
|
17
|
-
import jax
|
|
18
|
-
from enum import Enum
|
|
19
|
-
from typing import Union, List, Tuple, Optional, Any, cast
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
20
17
|
from abc import ABC, abstractmethod
|
|
18
|
+
from enum import Enum
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import jax
|
|
22
|
+
import torch
|
|
21
23
|
|
|
22
24
|
# Reference to original PyTorch native functions
|
|
23
25
|
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
|
|
@@ -37,76 +39,75 @@ class ViewInfoType(Enum):
|
|
|
37
39
|
|
|
38
40
|
class ViewInfo(ABC):
|
|
39
41
|
"""
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
42
|
+
Abstract base class for all view operations.
|
|
43
|
+
Defines the interface for applying and updating view transformations.
|
|
44
|
+
"""
|
|
43
45
|
|
|
44
46
|
def __init__(
|
|
45
|
-
|
|
46
|
-
|
|
47
|
+
self,
|
|
48
|
+
view_info_type: ViewInfoType = ViewInfoType.INVALID,
|
|
47
49
|
):
|
|
48
50
|
"""
|
|
49
|
-
|
|
51
|
+
Initialize a ViewInfo object.
|
|
50
52
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
53
|
+
Args:
|
|
54
|
+
view_info_type: The type of view operation
|
|
55
|
+
"""
|
|
54
56
|
self.view_info_type = view_info_type
|
|
55
57
|
|
|
56
58
|
@abstractmethod
|
|
57
|
-
def update_tensor(self, new_value: jax.Array,
|
|
58
|
-
jax_array: jax.Array) -> jax.Array:
|
|
59
|
+
def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
|
|
59
60
|
"""
|
|
60
|
-
|
|
61
|
+
Apply this view transformation to a JAX array and update its value.
|
|
61
62
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
63
|
+
Args:
|
|
64
|
+
new_value: The new values to set in the view
|
|
65
|
+
jax_array: The parent array to update
|
|
65
66
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
67
|
+
Returns:
|
|
68
|
+
Updated array
|
|
69
|
+
"""
|
|
69
70
|
pass
|
|
70
71
|
|
|
71
72
|
@abstractmethod
|
|
72
73
|
def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
|
|
73
74
|
"""
|
|
74
|
-
|
|
75
|
+
Apply this view transformation to a JAX array.
|
|
75
76
|
|
|
76
|
-
|
|
77
|
-
|
|
77
|
+
Args:
|
|
78
|
+
jax_array: The array to transform
|
|
78
79
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
80
|
+
Returns:
|
|
81
|
+
Transformed array
|
|
82
|
+
"""
|
|
82
83
|
pass
|
|
83
84
|
|
|
84
85
|
@abstractmethod
|
|
85
|
-
def calculate_output_shape(self, source: jax.Array) ->
|
|
86
|
+
def calculate_output_shape(self, source: jax.Array) -> list[int]:
|
|
86
87
|
"""
|
|
87
|
-
|
|
88
|
+
Calculate the resulting shape after applying this view.
|
|
88
89
|
|
|
89
|
-
|
|
90
|
-
|
|
90
|
+
Args:
|
|
91
|
+
source: Original jax array before transformation
|
|
91
92
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
93
|
+
Returns:
|
|
94
|
+
Resulting shape after transformation
|
|
95
|
+
"""
|
|
95
96
|
pass
|
|
96
97
|
|
|
97
98
|
|
|
98
99
|
class NarrowInfo(ViewInfo):
|
|
99
100
|
"""
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
101
|
+
Represents a slicing operation on a tensor.
|
|
102
|
+
Handles operations like tensor[1:3, :, 2:5:2].
|
|
103
|
+
"""
|
|
103
104
|
|
|
104
|
-
def __init__(self, slices:
|
|
105
|
+
def __init__(self, slices: slice | tuple[slice]) -> None:
|
|
106
|
+
"""
|
|
107
|
+
Args:
|
|
108
|
+
slices: The slice(s) to apply to the tensor.
|
|
109
|
+
E.g. jax_array.at[slices] will return the transformed tensor.
|
|
105
110
|
"""
|
|
106
|
-
Args:
|
|
107
|
-
slices: The slice(s) to apply to the tensor.
|
|
108
|
-
E.g. jax_array.at[slices] will return the transformed tensor.
|
|
109
|
-
"""
|
|
110
111
|
super().__init__(ViewInfoType.NARROW)
|
|
111
112
|
self.slices = slices
|
|
112
113
|
|
|
@@ -121,25 +122,22 @@ class NarrowInfo(ViewInfo):
|
|
|
121
122
|
except IndexError as e:
|
|
122
123
|
raise IndexError("Invalid slice operation") from e
|
|
123
124
|
|
|
124
|
-
def update_tensor(self, new_value: jax.Array,
|
|
125
|
-
jax_array: jax.Array) -> jax.Array:
|
|
125
|
+
def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
|
|
126
126
|
return jax_array.at[self.slices].set(new_value)
|
|
127
127
|
|
|
128
|
-
def calculate_output_shape(self, source: jax.Array) ->
|
|
128
|
+
def calculate_output_shape(self, source: jax.Array) -> list[int]:
|
|
129
129
|
return source[self.slices].shape
|
|
130
130
|
|
|
131
131
|
|
|
132
132
|
class SelectInfo(ViewInfo):
|
|
133
133
|
"""
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
134
|
+
Represents a selection operation on a tensor.
|
|
135
|
+
Typically used for indexing operations that select specific elements.
|
|
136
|
+
"""
|
|
137
137
|
|
|
138
|
-
def __init__(
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
end: int = 0,
|
|
142
|
-
stride: int = 0) -> None:
|
|
138
|
+
def __init__(
|
|
139
|
+
self, dim: int = 0, start: int = 0, end: int = 0, stride: int = 0
|
|
140
|
+
) -> None:
|
|
143
141
|
super().__init__(ViewInfoType.SELECT)
|
|
144
142
|
self.dim: int = dim
|
|
145
143
|
self.start: int = start
|
|
@@ -149,29 +147,31 @@ class SelectInfo(ViewInfo):
|
|
|
149
147
|
def __eq__(self, other: object) -> bool:
|
|
150
148
|
if not isinstance(other, SelectInfo):
|
|
151
149
|
return False
|
|
152
|
-
return (
|
|
153
|
-
|
|
150
|
+
return (
|
|
151
|
+
self.dim == other.dim
|
|
152
|
+
and self.start == other.start
|
|
153
|
+
and self.end == other.end
|
|
154
|
+
and self.stride == other.stride
|
|
155
|
+
)
|
|
154
156
|
|
|
155
157
|
def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
|
|
156
158
|
raise NotImplementedError("SelectInfo.apply not implemented")
|
|
157
159
|
|
|
158
|
-
def update_tensor(self, new_value: jax.Array,
|
|
159
|
-
jax_array: jax.Array) -> jax.Array:
|
|
160
|
+
def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
|
|
160
161
|
raise NotImplementedError("SelectInfo.update not implemented")
|
|
161
162
|
|
|
162
|
-
def calculate_output_shape(self, source: jax.Array) ->
|
|
163
|
-
raise NotImplementedError(
|
|
164
|
-
"SelectInfo.calculate_output_shape not implemented")
|
|
163
|
+
def calculate_output_shape(self, source: jax.Array) -> list[int]:
|
|
164
|
+
raise NotImplementedError("SelectInfo.calculate_output_shape not implemented")
|
|
165
165
|
|
|
166
166
|
|
|
167
167
|
class AsStridedInfo(ViewInfo):
|
|
168
168
|
"""
|
|
169
|
-
|
|
170
|
-
|
|
169
|
+
Information for as_strided operations.
|
|
170
|
+
"""
|
|
171
171
|
|
|
172
|
-
def __init__(self, stride:
|
|
172
|
+
def __init__(self, stride: list[int], offset: int = 0) -> None:
|
|
173
173
|
super().__init__(ViewInfoType.AS_STRIDED)
|
|
174
|
-
self.stride:
|
|
174
|
+
self.stride: list[int] = stride
|
|
175
175
|
self.offset: int = offset
|
|
176
176
|
|
|
177
177
|
def __eq__(self, other: object) -> bool:
|
|
@@ -182,28 +182,26 @@ class AsStridedInfo(ViewInfo):
|
|
|
182
182
|
def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
|
|
183
183
|
raise NotImplementedError("AsStridedInfo.apply not implemented")
|
|
184
184
|
|
|
185
|
-
def update_tensor(self, new_value: jax.Array,
|
|
186
|
-
jax_array: jax.Array) -> jax.Array:
|
|
185
|
+
def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
|
|
187
186
|
raise NotImplementedError("AsStridedInfo.update not implemented")
|
|
188
187
|
|
|
189
|
-
def calculate_output_shape(self, source: jax.Array) ->
|
|
190
|
-
raise NotImplementedError(
|
|
191
|
-
"AsStridedInfo.calculate_output_shape not implemented")
|
|
188
|
+
def calculate_output_shape(self, source: jax.Array) -> list[int]:
|
|
189
|
+
raise NotImplementedError("AsStridedInfo.calculate_output_shape not implemented")
|
|
192
190
|
|
|
193
191
|
|
|
194
192
|
class DiagonalInfo(ViewInfo):
|
|
195
193
|
"""
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
194
|
+
Information for diagonal operations.
|
|
195
|
+
Extracts diagonal elements from a tensor.
|
|
196
|
+
"""
|
|
199
197
|
|
|
200
198
|
def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None:
|
|
201
199
|
"""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
200
|
+
Args:
|
|
201
|
+
offset: Offset from the main diagonal
|
|
202
|
+
dim1: First dimension for diagonal extraction
|
|
203
|
+
dim2: Second dimension for diagonal extraction
|
|
204
|
+
"""
|
|
207
205
|
super().__init__(ViewInfoType.DIAGONAL)
|
|
208
206
|
self.offset: int = offset
|
|
209
207
|
self.dim1: int = dim1
|
|
@@ -212,56 +210,65 @@ class DiagonalInfo(ViewInfo):
|
|
|
212
210
|
def __eq__(self, other: object) -> bool:
|
|
213
211
|
if not isinstance(other, DiagonalInfo):
|
|
214
212
|
return False
|
|
215
|
-
return (
|
|
216
|
-
|
|
213
|
+
return (
|
|
214
|
+
self.offset == other.offset
|
|
215
|
+
and self.dim1 == other.dim1
|
|
216
|
+
and self.dim2 == other.dim2
|
|
217
|
+
)
|
|
217
218
|
|
|
218
219
|
def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
|
|
219
220
|
raise NotImplementedError("DiagonalInfo.apply not implemented")
|
|
220
221
|
|
|
221
|
-
def update_tensor(self, new_value: jax.Array,
|
|
222
|
-
jax_array: jax.Array) -> jax.Array:
|
|
222
|
+
def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
|
|
223
223
|
raise NotImplementedError("DiagonalInfo.update not implemented")
|
|
224
224
|
|
|
225
|
-
def calculate_output_shape(self, source: jax.Array) ->
|
|
226
|
-
raise NotImplementedError(
|
|
227
|
-
"DiagonalInfo.calculate_output_shape not implemented")
|
|
225
|
+
def calculate_output_shape(self, source: jax.Array) -> list[int]:
|
|
226
|
+
raise NotImplementedError("DiagonalInfo.calculate_output_shape not implemented")
|
|
228
227
|
|
|
229
228
|
|
|
230
229
|
class View(torch.Tensor):
|
|
231
230
|
"""
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
231
|
+
A View is a reference to another Tensor or another View,
|
|
232
|
+
with a transformation applied to it.
|
|
233
|
+
"""
|
|
235
234
|
|
|
236
235
|
@staticmethod
|
|
237
|
-
def __new__(
|
|
238
|
-
|
|
236
|
+
def __new__(
|
|
237
|
+
cls,
|
|
238
|
+
parent: torchax.Tensor | View, # noqa: F821
|
|
239
|
+
view_info: ViewInfo,
|
|
240
|
+
env: Any,
|
|
241
|
+
) -> View:
|
|
242
|
+
"""
|
|
243
|
+
Args:
|
|
244
|
+
parent: Parent tensor or view
|
|
245
|
+
view_info: Information about the view transformation
|
|
246
|
+
env: Environment for tensor operations
|
|
239
247
|
"""
|
|
240
|
-
Args:
|
|
241
|
-
parent: Parent tensor or view
|
|
242
|
-
view_info: Information about the view transformation
|
|
243
|
-
env: Environment for tensor operations
|
|
244
|
-
"""
|
|
245
248
|
shape = view_info.calculate_output_shape(parent.jax())
|
|
246
249
|
return torch.Tensor._make_wrapper_subclass(
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
250
|
+
cls,
|
|
251
|
+
shape,
|
|
252
|
+
device="meta",
|
|
253
|
+
dtype=parent.dtype,
|
|
254
|
+
requires_grad=False,
|
|
252
255
|
)
|
|
253
256
|
|
|
254
|
-
def __init__(
|
|
255
|
-
|
|
257
|
+
def __init__(
|
|
258
|
+
self,
|
|
259
|
+
parent: torchax.Tensor | View, # noqa: F821
|
|
260
|
+
view_info: ViewInfo,
|
|
261
|
+
env: Any,
|
|
262
|
+
) -> None:
|
|
256
263
|
super().__init__()
|
|
257
264
|
self.parent = parent
|
|
258
265
|
self.view_info = view_info
|
|
259
266
|
self._env = env
|
|
260
267
|
|
|
261
|
-
def get_transformation_chain(self) ->
|
|
268
|
+
def get_transformation_chain(self) -> list[ViewInfo]:
|
|
269
|
+
"""
|
|
270
|
+
Get all view transformations from the source tensor to this view.
|
|
262
271
|
"""
|
|
263
|
-
Get all view transformations from the source tensor to this view.
|
|
264
|
-
"""
|
|
265
272
|
if isinstance(self.parent, View):
|
|
266
273
|
transformations = self.parent.get_transformation_chain()
|
|
267
274
|
transformations.append(self.view_info)
|
|
@@ -273,8 +280,8 @@ class View(torch.Tensor):
|
|
|
273
280
|
|
|
274
281
|
def source_jax(self) -> jax.Array:
|
|
275
282
|
"""
|
|
276
|
-
|
|
277
|
-
|
|
283
|
+
Returns the source tensor.
|
|
284
|
+
"""
|
|
278
285
|
if isinstance(self.parent, View):
|
|
279
286
|
return self.parent.source_jax()
|
|
280
287
|
else:
|
|
@@ -282,32 +289,32 @@ class View(torch.Tensor):
|
|
|
282
289
|
|
|
283
290
|
def replace_source_jax(self, new_value: jax.Array) -> None:
|
|
284
291
|
"""
|
|
285
|
-
|
|
286
|
-
|
|
292
|
+
Update the source tensor with new values.
|
|
293
|
+
"""
|
|
287
294
|
if isinstance(self.parent, View):
|
|
288
295
|
self.parent.replace_source_jax(new_value)
|
|
289
296
|
else:
|
|
290
297
|
assert new_value.shape == self.parent._elem.shape
|
|
291
298
|
self.parent._elem = new_value
|
|
292
299
|
|
|
293
|
-
def torch(self) ->
|
|
300
|
+
def torch(self) -> torchax.Tensor: # noqa: F821
|
|
301
|
+
"""
|
|
302
|
+
Returns a Torchax tensor representing this view after all transformations
|
|
294
303
|
"""
|
|
295
|
-
Returns a Torchax tensor representing this view after all transformations
|
|
296
|
-
"""
|
|
297
304
|
from torchax.tensor import Tensor
|
|
298
305
|
|
|
299
306
|
return Tensor(self.jax(), self._env)
|
|
300
307
|
|
|
301
308
|
def update(
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
309
|
+
self,
|
|
310
|
+
new_values: jax.Array | View | torchax.Tensor, # noqa: F821
|
|
311
|
+
view_infos: list[ViewInfo] | None = None,
|
|
305
312
|
) -> None:
|
|
306
313
|
"""
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
314
|
+
Update this view with new values, propagating changes back to source.
|
|
315
|
+
If view_infos is None, it will use the transformation chain
|
|
316
|
+
from the source tensor.
|
|
317
|
+
"""
|
|
311
318
|
if view_infos is None:
|
|
312
319
|
view_infos = self.get_transformation_chain()
|
|
313
320
|
|
|
@@ -324,14 +331,14 @@ class View(torch.Tensor):
|
|
|
324
331
|
# And store intermediate values
|
|
325
332
|
intermediate_values = [source_array]
|
|
326
333
|
for view_info in view_infos[:-1]:
|
|
327
|
-
intermediate_values.append(
|
|
328
|
-
view_info.transform_tensor(intermediate_values[-1]))
|
|
334
|
+
intermediate_values.append(view_info.transform_tensor(intermediate_values[-1]))
|
|
329
335
|
|
|
330
336
|
# TODO: Investigate efficiency of this algorithm
|
|
331
337
|
# Update the source array with the new value by
|
|
332
338
|
# applying inverse transformations in reverse order
|
|
333
339
|
for view_info, parent_array in zip(
|
|
334
|
-
|
|
340
|
+
reversed(view_infos), reversed(intermediate_values), strict=False
|
|
341
|
+
):
|
|
335
342
|
# Apply the inverse transformation to propagate changes back
|
|
336
343
|
new_values = view_info.update_tensor(new_values, parent_array)
|
|
337
344
|
|
|
@@ -340,21 +347,22 @@ class View(torch.Tensor):
|
|
|
340
347
|
|
|
341
348
|
@classmethod
|
|
342
349
|
def __torch_dispatch__(
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
350
|
+
cls,
|
|
351
|
+
func: Any,
|
|
352
|
+
types: tuple[Any, ...],
|
|
353
|
+
args: tuple[Any, ...] = (),
|
|
354
|
+
kwargs: dict | None = None,
|
|
348
355
|
) -> Any:
|
|
349
356
|
raise AssertionError(
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
357
|
+
"torchax Tensors can only do math within the torchax environment."
|
|
358
|
+
"Please wrap your code with `with torchax.default_env()` or "
|
|
359
|
+
"call torchax.enable_globally() before."
|
|
360
|
+
)
|
|
353
361
|
|
|
354
|
-
def create_sub_view(self, view_info: ViewInfo) ->
|
|
362
|
+
def create_sub_view(self, view_info: ViewInfo) -> View:
|
|
363
|
+
"""
|
|
364
|
+
Create a new view that is a child of this view.
|
|
355
365
|
"""
|
|
356
|
-
Create a new view that is a child of this view.
|
|
357
|
-
"""
|
|
358
366
|
return View(self, view_info, self._env)
|
|
359
367
|
|
|
360
368
|
def __str__(self) -> str:
|
|
@@ -362,8 +370,8 @@ class View(torch.Tensor):
|
|
|
362
370
|
|
|
363
371
|
def jax(self) -> jax.Array:
|
|
364
372
|
"""
|
|
365
|
-
|
|
366
|
-
|
|
373
|
+
Returns a copy of the source tensor after transformations.
|
|
374
|
+
"""
|
|
367
375
|
result = self.source_jax()
|
|
368
376
|
for view_info in self.get_transformation_chain():
|
|
369
377
|
result = view_info.transform_tensor(result)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torchax
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.11.dev202612
|
|
4
4
|
Summary: torchax is a library for running Jax and PyTorch together
|
|
5
5
|
Project-URL: Homepage, https://github.com/google/torchax
|
|
6
6
|
Author-email: Han Qi <qihan.dev@gmail.com>, Google Cloud Inference Team <cmcs-inference-eng@google.com>
|
|
@@ -237,6 +237,12 @@ Description-Content-Type: text/markdown
|
|
|
237
237
|
|
|
238
238
|
# torchax: Running PyTorch on TPU via JAX
|
|
239
239
|
|
|
240
|
+
Docs page: https://google.github.io/torchax/
|
|
241
|
+
Discord Discussion Channel: https://discord.gg/JqeJqGPyzC
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+

|
|
245
|
+
|
|
240
246
|
**torchax** is a backend for PyTorch that allows users to run
|
|
241
247
|
PyTorch programs on Google Cloud TPUs. It also provides graph-level
|
|
242
248
|
interoperability between PyTorch and JAX.
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
|
|
2
|
+
torchax/__init__.py,sha256=m9mLdO6-n-WfpeR-R7jNQxkESUUpXQJ5A-0CNnXuSJs,3889
|
|
3
|
+
torchax/amp.py,sha256=WTrfah2NYodapDVMsht7H3zDpl-XslujxhcYXr85g1s,10999
|
|
4
|
+
torchax/checkpoint.py,sha256=2eoGeIQtL1Chof0W9qorB2Q0eCyEVJyWKqGoetf32GQ,2439
|
|
5
|
+
torchax/config.py,sha256=oTwgWDujF9vNSHRNKUvz3ZkocDe0aDaF-yviqAJAewY,1398
|
|
6
|
+
torchax/decompositions.py,sha256=8VU0FfKqbP8h3S7JzHi0iWqT5E1OrwYvuhf6cjzDTlI,29303
|
|
7
|
+
torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
|
|
8
|
+
torchax/export.py,sha256=CFQESWy9-ENo9Ozf08qwqUtrNpt7Y3ubX_dP99o3FdM,9395
|
|
9
|
+
torchax/flax.py,sha256=Eft46Np3qPvSLmBGOltCx8KbnrGsuBeqx3Zu0tlhMpg,1807
|
|
10
|
+
torchax/interop.py,sha256=6R9_pIkd5Kwb5EoACUsq4GEEe8PaSFOdfnD4GIy6Y9U,11815
|
|
11
|
+
torchax/mesh_util.py,sha256=Y3RVKOyLVKpbseyXTYlJlUgyNavmDAsh2MQ6pXYQDUU,9719
|
|
12
|
+
torchax/tensor.py,sha256=IP7JuzUfKZVrhJdA67VdXBEh4bfypwavKRzw6llsD4U,21452
|
|
13
|
+
torchax/train.py,sha256=3sqIYO1Q6GN6gRGkVVoKjjUZ3xYPgh110g5XbkaUxh4,4367
|
|
14
|
+
torchax/types.py,sha256=NDtW1fARypg0ZHcGRVTBZKQqxwJzWtf5E5bntng91gk,981
|
|
15
|
+
torchax/util.py,sha256=Oud7oaAw1SJo_v4fwEZdjuseZ_bvngAsAQ-dOEzy_20,3675
|
|
16
|
+
torchax/view.py,sha256=750VYe6tmwAINNAyjN8GDvPmgaR8luvr74szwuikGts,11256
|
|
17
|
+
torchax/ops/__init__.py,sha256=uc-Rod4Xlk_oJ-rdY9P8-MGTu9dsXPombscqcSys0r8,840
|
|
18
|
+
torchax/ops/jaten.py,sha256=x5vrEiVRFaTigPYvFHtR9ClwW0Z6TEHT66-mEdzfe3k,163435
|
|
19
|
+
torchax/ops/jax_reimplement.py,sha256=Te8Je2ea9jX2SFV34PNPSHVP7-z_bmFykVWeqn8Tqwo,7714
|
|
20
|
+
torchax/ops/jc10d.py,sha256=sO99kYDM9WRnSENHmMkH_MXWkx6HZdDvK5Z_M0C571g,1889
|
|
21
|
+
torchax/ops/jimage.py,sha256=uvYW-fMaGU6-QTXAmTZ8rmHEkkwpXh2I-bu54STf-ic,3594
|
|
22
|
+
torchax/ops/jlibrary.py,sha256=ESeFS5xwV1yiCIp_yKzDXihHd1wcz5eXjFkjFKsmw3w,3470
|
|
23
|
+
torchax/ops/jtorch.py,sha256=OuFTed92B3WuWnfnjBBvMgwkl43PLJJiDWSMb0ZS0sw,16406
|
|
24
|
+
torchax/ops/jtorchvision_nms.py,sha256=VNMshE3LCsIBHVVzkrNEm0kYMF89ZVIKWlQIk6pCZB0,9197
|
|
25
|
+
torchax/ops/mappings.py,sha256=ViEsZaGIi37BhuLw9hx9cA9XXl35OVuaRN6Q9yZygxk,4162
|
|
26
|
+
torchax/ops/op_base.py,sha256=-rQXLpkgNZ1HM3OT1XQkvAV_7Dtq019_rAXNAg97OuE,4135
|
|
27
|
+
torchax/ops/ops_registry.py,sha256=sBT41LRGmUVP4ZJ9YU1DyffatOHxe-x8oXqMhCKh0y8,1836
|
|
28
|
+
torchax-0.0.11.dev202612.dist-info/METADATA,sha256=tVMhNPH26W_cwlb9oA8xPNByKjM25aY7OI4SGnGFTg4,22451
|
|
29
|
+
torchax-0.0.11.dev202612.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
30
|
+
torchax-0.0.11.dev202612.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
31
|
+
torchax-0.0.11.dev202612.dist-info/RECORD,,
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
|
|
2
|
-
torchax/__init__.py,sha256=t7SmghL_AbacVWVL3rtr-q8Fw-oZ1M8cKRNbdnrXAxA,4172
|
|
3
|
-
torchax/amp.py,sha256=h5sYm7jhbQhn8kmtdd8ahzrmH06r2VKGYigHQzRzmZc,12366
|
|
4
|
-
torchax/checkpoint.py,sha256=3nDcGOVbeeBzoaI1dfQxpUPdNiRuGGkVj65DlXkEpF4,2437
|
|
5
|
-
torchax/config.py,sha256=c2JVtKx-GkkQn9vGxYgweGrm57G60mH1BB-SfE6-d6Q,1497
|
|
6
|
-
torchax/decompositions.py,sha256=yR2QXWbwFS3Gb9Wbc318fJDX04wLYc_TCmclafOmqjw,29434
|
|
7
|
-
torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
|
|
8
|
-
torchax/export.py,sha256=ztEECMqORuKX75VvQfRs4qwso1NEhFwylM9j6hVShJM,9544
|
|
9
|
-
torchax/flax.py,sha256=LxXS0ZKoEU_hNmweyuNSeErX9mThbkI4El1nAUBuGic,1855
|
|
10
|
-
torchax/interop.py,sha256=MS5BGmXhiucKA47EMOwJxpHhuwsQu19EyZsGz4zSw9g,11845
|
|
11
|
-
torchax/mesh_util.py,sha256=9QPU7fWxVnb8v2A4S1Accu6PCAeEsNgaXVz1fIfhXeI,9889
|
|
12
|
-
torchax/tensor.py,sha256=lEhIPhpJutnj8d7urFfAZdjihdh4iJdozI0bkJBIU-E,21900
|
|
13
|
-
torchax/train.py,sha256=3Uc5zMxHAp-qjwdIm-ebjsYrEC7KF2y-2JKFsNkHhPY,4726
|
|
14
|
-
torchax/types.py,sha256=IiHQHVNWQvo_6sT2Q3j6YBbc5p0DMrewrXNQB0fZb84,963
|
|
15
|
-
torchax/util.py,sha256=sLGXkmYE5dOBC-3Co2HxPTsjod36l9bPvzhufEw9kEE,3644
|
|
16
|
-
torchax/view.py,sha256=PF9ti7eLKBb5-kfx96lijlgSlSBXPCPPtGt_12iMiQI,11696
|
|
17
|
-
torchax/ops/__init__.py,sha256=VY8LjEJnLhMdSbPaTOzrWFn_nFRseWOSEqLFijyX_6Y,845
|
|
18
|
-
torchax/ops/jaten.py,sha256=GEc8s0GjCF6N9A8UnIyLmtch5TwCgnqWbTRr5iMerkk,177058
|
|
19
|
-
torchax/ops/jax_reimplement.py,sha256=kjU0Fopq_zM-8GJDSE4nPASAG-8es0I108ran45d0tM,7877
|
|
20
|
-
torchax/ops/jc10d.py,sha256=l7-bi-uFLbTDjHKsE6H-7eO_PG3j0j1S1vIFEMAyvII,1897
|
|
21
|
-
torchax/ops/jimage.py,sha256=6YpFBaYSvsXuYgsZY2Y-c5yi8bAROxlRCcXSFQMau0M,3757
|
|
22
|
-
torchax/ops/jlibrary.py,sha256=ol12qM2qP1qfqXnmWcVCpPOHmLODZkhiFZmxZGJAD6o,3505
|
|
23
|
-
torchax/ops/jtorch.py,sha256=SiXz_OQyQ7XW8uJ1i9Fx1n-G6k2Ws23fbrFsRr3fzUk,19191
|
|
24
|
-
torchax/ops/jtorchvision_nms.py,sha256=OZfYntyIatgXtMBU2hywlWvchMKaK028nNsFE2fCy7E,9478
|
|
25
|
-
torchax/ops/mappings.py,sha256=RP75HTKJYfZ08dTByfUnCEJRP42jDHcoQXQrBafWI8M,4362
|
|
26
|
-
torchax/ops/op_base.py,sha256=W_hiXEW-NxZruQfpMg2W_Aru47HiNZjR0kiEo749XAg,4314
|
|
27
|
-
torchax/ops/ops_registry.py,sha256=5wV9BjZqTAUr41tr-A7pchwkQ877YDu_m-RH1f4GWsU,2169
|
|
28
|
-
torchax-0.0.10.dev20251114.dist-info/METADATA,sha256=vVEnW18Y1vesO0NypWIvTtPp8mWNykNl3TG2Vlq6-wI,22315
|
|
29
|
-
torchax-0.0.10.dev20251114.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
30
|
-
torchax-0.0.10.dev20251114.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
31
|
-
torchax-0.0.10.dev20251114.dist-info/RECORD,,
|
{torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|