torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. See the registry's advisory page for this release for more details.

torchax/view.py CHANGED
@@ -12,12 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import torch
16
- import torch.utils._pytree as torch_pytree
17
- import jax
18
- from enum import Enum
19
- from typing import Union, List, Tuple, Optional, Any, cast
15
+ from __future__ import annotations
16
+
20
17
  from abc import ABC, abstractmethod
18
+ from enum import Enum
19
+ from typing import Any
20
+
21
+ import jax
22
+ import torch
21
23
 
22
24
  # Reference to original PyTorch native functions
23
25
  # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
@@ -37,76 +39,75 @@ class ViewInfoType(Enum):
37
39
 
38
40
  class ViewInfo(ABC):
39
41
  """
40
- Abstract base class for all view operations.
41
- Defines the interface for applying and updating view transformations.
42
- """
42
+ Abstract base class for all view operations.
43
+ Defines the interface for applying and updating view transformations.
44
+ """
43
45
 
44
46
  def __init__(
45
- self,
46
- view_info_type: ViewInfoType = ViewInfoType.INVALID,
47
+ self,
48
+ view_info_type: ViewInfoType = ViewInfoType.INVALID,
47
49
  ):
48
50
  """
49
- Initialize a ViewInfo object.
51
+ Initialize a ViewInfo object.
50
52
 
51
- Args:
52
- view_info_type: The type of view operation
53
- """
53
+ Args:
54
+ view_info_type: The type of view operation
55
+ """
54
56
  self.view_info_type = view_info_type
55
57
 
56
58
  @abstractmethod
57
- def update_tensor(self, new_value: jax.Array,
58
- jax_array: jax.Array) -> jax.Array:
59
+ def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
59
60
  """
60
- Apply this view transformation to a JAX array and update its value.
61
+ Apply this view transformation to a JAX array and update its value.
61
62
 
62
- Args:
63
- new_value: The new values to set in the view
64
- jax_array: The parent array to update
63
+ Args:
64
+ new_value: The new values to set in the view
65
+ jax_array: The parent array to update
65
66
 
66
- Returns:
67
- Updated array
68
- """
67
+ Returns:
68
+ Updated array
69
+ """
69
70
  pass
70
71
 
71
72
  @abstractmethod
72
73
  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
73
74
  """
74
- Apply this view transformation to a JAX array.
75
+ Apply this view transformation to a JAX array.
75
76
 
76
- Args:
77
- jax_array: The array to transform
77
+ Args:
78
+ jax_array: The array to transform
78
79
 
79
- Returns:
80
- Transformed array
81
- """
80
+ Returns:
81
+ Transformed array
82
+ """
82
83
  pass
83
84
 
84
85
  @abstractmethod
85
- def calculate_output_shape(self, source: jax.Array) -> List[int]:
86
+ def calculate_output_shape(self, source: jax.Array) -> list[int]:
86
87
  """
87
- Calculate the resulting shape after applying this view.
88
+ Calculate the resulting shape after applying this view.
88
89
 
89
- Args:
90
- source: Original jax array before transformation
90
+ Args:
91
+ source: Original jax array before transformation
91
92
 
92
- Returns:
93
- Resulting shape after transformation
94
- """
93
+ Returns:
94
+ Resulting shape after transformation
95
+ """
95
96
  pass
96
97
 
97
98
 
98
99
  class NarrowInfo(ViewInfo):
99
100
  """
100
- Represents a slicing operation on a tensor.
101
- Handles operations like tensor[1:3, :, 2:5:2].
102
- """
101
+ Represents a slicing operation on a tensor.
102
+ Handles operations like tensor[1:3, :, 2:5:2].
103
+ """
103
104
 
104
- def __init__(self, slices: Union[slice, Tuple[slice]]) -> None:
105
+ def __init__(self, slices: slice | tuple[slice]) -> None:
106
+ """
107
+ Args:
108
+ slices: The slice(s) to apply to the tensor.
109
+ E.g. jax_array.at[slices] will return the transformed tensor.
105
110
  """
106
- Args:
107
- slices: The slice(s) to apply to the tensor.
108
- E.g. jax_array.at[slices] will return the transformed tensor.
109
- """
110
111
  super().__init__(ViewInfoType.NARROW)
111
112
  self.slices = slices
112
113
 
@@ -121,25 +122,22 @@ class NarrowInfo(ViewInfo):
121
122
  except IndexError as e:
122
123
  raise IndexError("Invalid slice operation") from e
123
124
 
124
- def update_tensor(self, new_value: jax.Array,
125
- jax_array: jax.Array) -> jax.Array:
125
+ def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
126
126
  return jax_array.at[self.slices].set(new_value)
127
127
 
128
- def calculate_output_shape(self, source: jax.Array) -> List[int]:
128
+ def calculate_output_shape(self, source: jax.Array) -> list[int]:
129
129
  return source[self.slices].shape
130
130
 
131
131
 
132
132
  class SelectInfo(ViewInfo):
133
133
  """
134
- Represents a selection operation on a tensor.
135
- Typically used for indexing operations that select specific elements.
136
- """
134
+ Represents a selection operation on a tensor.
135
+ Typically used for indexing operations that select specific elements.
136
+ """
137
137
 
138
- def __init__(self,
139
- dim: int = 0,
140
- start: int = 0,
141
- end: int = 0,
142
- stride: int = 0) -> None:
138
+ def __init__(
139
+ self, dim: int = 0, start: int = 0, end: int = 0, stride: int = 0
140
+ ) -> None:
143
141
  super().__init__(ViewInfoType.SELECT)
144
142
  self.dim: int = dim
145
143
  self.start: int = start
@@ -149,29 +147,31 @@ class SelectInfo(ViewInfo):
149
147
  def __eq__(self, other: object) -> bool:
150
148
  if not isinstance(other, SelectInfo):
151
149
  return False
152
- return (self.dim == other.dim and self.start == other.start and
153
- self.end == other.end and self.stride == other.stride)
150
+ return (
151
+ self.dim == other.dim
152
+ and self.start == other.start
153
+ and self.end == other.end
154
+ and self.stride == other.stride
155
+ )
154
156
 
155
157
  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
156
158
  raise NotImplementedError("SelectInfo.apply not implemented")
157
159
 
158
- def update_tensor(self, new_value: jax.Array,
159
- jax_array: jax.Array) -> jax.Array:
160
+ def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
160
161
  raise NotImplementedError("SelectInfo.update not implemented")
161
162
 
162
- def calculate_output_shape(self, source: jax.Array) -> List[int]:
163
- raise NotImplementedError(
164
- "SelectInfo.calculate_output_shape not implemented")
163
+ def calculate_output_shape(self, source: jax.Array) -> list[int]:
164
+ raise NotImplementedError("SelectInfo.calculate_output_shape not implemented")
165
165
 
166
166
 
167
167
  class AsStridedInfo(ViewInfo):
168
168
  """
169
- Information for as_strided operations.
170
- """
169
+ Information for as_strided operations.
170
+ """
171
171
 
172
- def __init__(self, stride: List[int], offset: int = 0) -> None:
172
+ def __init__(self, stride: list[int], offset: int = 0) -> None:
173
173
  super().__init__(ViewInfoType.AS_STRIDED)
174
- self.stride: List[int] = stride
174
+ self.stride: list[int] = stride
175
175
  self.offset: int = offset
176
176
 
177
177
  def __eq__(self, other: object) -> bool:
@@ -182,28 +182,26 @@ class AsStridedInfo(ViewInfo):
182
182
  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
183
183
  raise NotImplementedError("AsStridedInfo.apply not implemented")
184
184
 
185
- def update_tensor(self, new_value: jax.Array,
186
- jax_array: jax.Array) -> jax.Array:
185
+ def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
187
186
  raise NotImplementedError("AsStridedInfo.update not implemented")
188
187
 
189
- def calculate_output_shape(self, source: jax.Array) -> List[int]:
190
- raise NotImplementedError(
191
- "AsStridedInfo.calculate_output_shape not implemented")
188
+ def calculate_output_shape(self, source: jax.Array) -> list[int]:
189
+ raise NotImplementedError("AsStridedInfo.calculate_output_shape not implemented")
192
190
 
193
191
 
194
192
  class DiagonalInfo(ViewInfo):
195
193
  """
196
- Information for diagonal operations.
197
- Extracts diagonal elements from a tensor.
198
- """
194
+ Information for diagonal operations.
195
+ Extracts diagonal elements from a tensor.
196
+ """
199
197
 
200
198
  def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None:
201
199
  """
202
- Args:
203
- offset: Offset from the main diagonal
204
- dim1: First dimension for diagonal extraction
205
- dim2: Second dimension for diagonal extraction
206
- """
200
+ Args:
201
+ offset: Offset from the main diagonal
202
+ dim1: First dimension for diagonal extraction
203
+ dim2: Second dimension for diagonal extraction
204
+ """
207
205
  super().__init__(ViewInfoType.DIAGONAL)
208
206
  self.offset: int = offset
209
207
  self.dim1: int = dim1
@@ -212,56 +210,65 @@ class DiagonalInfo(ViewInfo):
212
210
  def __eq__(self, other: object) -> bool:
213
211
  if not isinstance(other, DiagonalInfo):
214
212
  return False
215
- return (self.offset == other.offset and self.dim1 == other.dim1 and
216
- self.dim2 == other.dim2)
213
+ return (
214
+ self.offset == other.offset
215
+ and self.dim1 == other.dim1
216
+ and self.dim2 == other.dim2
217
+ )
217
218
 
218
219
  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
219
220
  raise NotImplementedError("DiagonalInfo.apply not implemented")
220
221
 
221
- def update_tensor(self, new_value: jax.Array,
222
- jax_array: jax.Array) -> jax.Array:
222
+ def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
223
223
  raise NotImplementedError("DiagonalInfo.update not implemented")
224
224
 
225
- def calculate_output_shape(self, source: jax.Array) -> List[int]:
226
- raise NotImplementedError(
227
- "DiagonalInfo.calculate_output_shape not implemented")
225
+ def calculate_output_shape(self, source: jax.Array) -> list[int]:
226
+ raise NotImplementedError("DiagonalInfo.calculate_output_shape not implemented")
228
227
 
229
228
 
230
229
  class View(torch.Tensor):
231
230
  """
232
- A View is a reference to another Tensor or another View,
233
- with a transformation applied to it.
234
- """
231
+ A View is a reference to another Tensor or another View,
232
+ with a transformation applied to it.
233
+ """
235
234
 
236
235
  @staticmethod
237
- def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo,
238
- env: Any) -> "View":
236
+ def __new__(
237
+ cls,
238
+ parent: torchax.Tensor | View, # noqa: F821
239
+ view_info: ViewInfo,
240
+ env: Any,
241
+ ) -> View:
242
+ """
243
+ Args:
244
+ parent: Parent tensor or view
245
+ view_info: Information about the view transformation
246
+ env: Environment for tensor operations
239
247
  """
240
- Args:
241
- parent: Parent tensor or view
242
- view_info: Information about the view transformation
243
- env: Environment for tensor operations
244
- """
245
248
  shape = view_info.calculate_output_shape(parent.jax())
246
249
  return torch.Tensor._make_wrapper_subclass(
247
- cls,
248
- shape,
249
- device="meta",
250
- dtype=parent.dtype,
251
- requires_grad=False,
250
+ cls,
251
+ shape,
252
+ device="meta",
253
+ dtype=parent.dtype,
254
+ requires_grad=False,
252
255
  )
253
256
 
254
- def __init__(self, parent: Union["torchax.Tensor", "View"],
255
- view_info: ViewInfo, env: Any) -> None:
257
+ def __init__(
258
+ self,
259
+ parent: torchax.Tensor | View, # noqa: F821
260
+ view_info: ViewInfo,
261
+ env: Any,
262
+ ) -> None:
256
263
  super().__init__()
257
264
  self.parent = parent
258
265
  self.view_info = view_info
259
266
  self._env = env
260
267
 
261
- def get_transformation_chain(self) -> List[ViewInfo]:
268
+ def get_transformation_chain(self) -> list[ViewInfo]:
269
+ """
270
+ Get all view transformations from the source tensor to this view.
262
271
  """
263
- Get all view transformations from the source tensor to this view.
264
- """
265
272
  if isinstance(self.parent, View):
266
273
  transformations = self.parent.get_transformation_chain()
267
274
  transformations.append(self.view_info)
@@ -273,8 +280,8 @@ class View(torch.Tensor):
273
280
 
274
281
  def source_jax(self) -> jax.Array:
275
282
  """
276
- Returns the source tensor.
277
- """
283
+ Returns the source tensor.
284
+ """
278
285
  if isinstance(self.parent, View):
279
286
  return self.parent.source_jax()
280
287
  else:
@@ -282,32 +289,32 @@ class View(torch.Tensor):
282
289
 
283
290
  def replace_source_jax(self, new_value: jax.Array) -> None:
284
291
  """
285
- Update the source tensor with new values.
286
- """
292
+ Update the source tensor with new values.
293
+ """
287
294
  if isinstance(self.parent, View):
288
295
  self.parent.replace_source_jax(new_value)
289
296
  else:
290
297
  assert new_value.shape == self.parent._elem.shape
291
298
  self.parent._elem = new_value
292
299
 
293
- def torch(self) -> "torchax.Tensor":
300
+ def torch(self) -> torchax.Tensor: # noqa: F821
301
+ """
302
+ Returns a Torchax tensor representing this view after all transformations
294
303
  """
295
- Returns a Torchax tensor representing this view after all transformations
296
- """
297
304
  from torchax.tensor import Tensor
298
305
 
299
306
  return Tensor(self.jax(), self._env)
300
307
 
301
308
  def update(
302
- self,
303
- new_values: Union[jax.Array, "View", "torchax.Tensor"],
304
- view_infos: Optional[List[ViewInfo]] = None,
309
+ self,
310
+ new_values: jax.Array | View | torchax.Tensor, # noqa: F821
311
+ view_infos: list[ViewInfo] | None = None,
305
312
  ) -> None:
306
313
  """
307
- Update this view with new values, propagating changes back to source.
308
- If view_infos is None, it will use the transformation chain
309
- from the source tensor.
310
- """
314
+ Update this view with new values, propagating changes back to source.
315
+ If view_infos is None, it will use the transformation chain
316
+ from the source tensor.
317
+ """
311
318
  if view_infos is None:
312
319
  view_infos = self.get_transformation_chain()
313
320
 
@@ -324,14 +331,14 @@ class View(torch.Tensor):
324
331
  # And store intermediate values
325
332
  intermediate_values = [source_array]
326
333
  for view_info in view_infos[:-1]:
327
- intermediate_values.append(
328
- view_info.transform_tensor(intermediate_values[-1]))
334
+ intermediate_values.append(view_info.transform_tensor(intermediate_values[-1]))
329
335
 
330
336
  # TODO: Investigate efficiency of this algorithm
331
337
  # Update the source array with the new value by
332
338
  # applying inverse transformations in reverse order
333
339
  for view_info, parent_array in zip(
334
- reversed(view_infos), reversed(intermediate_values)):
340
+ reversed(view_infos), reversed(intermediate_values), strict=False
341
+ ):
335
342
  # Apply the inverse transformation to propagate changes back
336
343
  new_values = view_info.update_tensor(new_values, parent_array)
337
344
 
@@ -340,21 +347,22 @@ class View(torch.Tensor):
340
347
 
341
348
  @classmethod
342
349
  def __torch_dispatch__(
343
- cls,
344
- func: Any,
345
- types: Tuple[Any, ...],
346
- args: Tuple[Any, ...] = (),
347
- kwargs: Optional[dict] = None,
350
+ cls,
351
+ func: Any,
352
+ types: tuple[Any, ...],
353
+ args: tuple[Any, ...] = (),
354
+ kwargs: dict | None = None,
348
355
  ) -> Any:
349
356
  raise AssertionError(
350
- 'torchax Tensors can only do math within the torchax environment.'
351
- 'Please wrap your code with `with torchax.default_env()` or '
352
- 'call torchax.enable_globally() before.')
357
+ "torchax Tensors can only do math within the torchax environment."
358
+ "Please wrap your code with `with torchax.default_env()` or "
359
+ "call torchax.enable_globally() before."
360
+ )
353
361
 
354
- def create_sub_view(self, view_info: ViewInfo) -> "View":
362
+ def create_sub_view(self, view_info: ViewInfo) -> View:
363
+ """
364
+ Create a new view that is a child of this view.
355
365
  """
356
- Create a new view that is a child of this view.
357
- """
358
366
  return View(self, view_info, self._env)
359
367
 
360
368
  def __str__(self) -> str:
@@ -362,8 +370,8 @@ class View(torch.Tensor):
362
370
 
363
371
  def jax(self) -> jax.Array:
364
372
  """
365
- Returns a copy of the source tensor after transformations.
366
- """
373
+ Returns a copy of the source tensor after transformations.
374
+ """
367
375
  result = self.source_jax()
368
376
  for view_info in self.get_transformation_chain():
369
377
  result = view_info.transform_tensor(result)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchax
3
- Version: 0.0.10.dev20251116
3
+ Version: 0.0.11.dev202612
4
4
  Summary: torchax is a library for running Jax and PyTorch together
5
5
  Project-URL: Homepage, https://github.com/google/torchax
6
6
  Author-email: Han Qi <qihan.dev@gmail.com>, Google Cloud Inference Team <cmcs-inference-eng@google.com>
@@ -237,6 +237,12 @@ Description-Content-Type: text/markdown
237
237
 
238
238
  # torchax: Running PyTorch on TPU via JAX
239
239
 
240
+ Docs page: https://google.github.io/torchax/
241
+ Discord Discussion Channel: https://discord.gg/JqeJqGPyzC
242
+
243
+
244
+ ![](docs/docs/assets/logo.jpeg)
245
+
240
246
  **torchax** is a backend for PyTorch that allows users to run
241
247
  PyTorch programs on Google Cloud TPUs. It also provides graph-level
242
248
  interoperability between PyTorch and JAX.
@@ -0,0 +1,31 @@
1
+ torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
2
+ torchax/__init__.py,sha256=m9mLdO6-n-WfpeR-R7jNQxkESUUpXQJ5A-0CNnXuSJs,3889
3
+ torchax/amp.py,sha256=WTrfah2NYodapDVMsht7H3zDpl-XslujxhcYXr85g1s,10999
4
+ torchax/checkpoint.py,sha256=2eoGeIQtL1Chof0W9qorB2Q0eCyEVJyWKqGoetf32GQ,2439
5
+ torchax/config.py,sha256=oTwgWDujF9vNSHRNKUvz3ZkocDe0aDaF-yviqAJAewY,1398
6
+ torchax/decompositions.py,sha256=8VU0FfKqbP8h3S7JzHi0iWqT5E1OrwYvuhf6cjzDTlI,29303
7
+ torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
8
+ torchax/export.py,sha256=CFQESWy9-ENo9Ozf08qwqUtrNpt7Y3ubX_dP99o3FdM,9395
9
+ torchax/flax.py,sha256=Eft46Np3qPvSLmBGOltCx8KbnrGsuBeqx3Zu0tlhMpg,1807
10
+ torchax/interop.py,sha256=6R9_pIkd5Kwb5EoACUsq4GEEe8PaSFOdfnD4GIy6Y9U,11815
11
+ torchax/mesh_util.py,sha256=Y3RVKOyLVKpbseyXTYlJlUgyNavmDAsh2MQ6pXYQDUU,9719
12
+ torchax/tensor.py,sha256=IP7JuzUfKZVrhJdA67VdXBEh4bfypwavKRzw6llsD4U,21452
13
+ torchax/train.py,sha256=3sqIYO1Q6GN6gRGkVVoKjjUZ3xYPgh110g5XbkaUxh4,4367
14
+ torchax/types.py,sha256=NDtW1fARypg0ZHcGRVTBZKQqxwJzWtf5E5bntng91gk,981
15
+ torchax/util.py,sha256=Oud7oaAw1SJo_v4fwEZdjuseZ_bvngAsAQ-dOEzy_20,3675
16
+ torchax/view.py,sha256=750VYe6tmwAINNAyjN8GDvPmgaR8luvr74szwuikGts,11256
17
+ torchax/ops/__init__.py,sha256=uc-Rod4Xlk_oJ-rdY9P8-MGTu9dsXPombscqcSys0r8,840
18
+ torchax/ops/jaten.py,sha256=x5vrEiVRFaTigPYvFHtR9ClwW0Z6TEHT66-mEdzfe3k,163435
19
+ torchax/ops/jax_reimplement.py,sha256=Te8Je2ea9jX2SFV34PNPSHVP7-z_bmFykVWeqn8Tqwo,7714
20
+ torchax/ops/jc10d.py,sha256=sO99kYDM9WRnSENHmMkH_MXWkx6HZdDvK5Z_M0C571g,1889
21
+ torchax/ops/jimage.py,sha256=uvYW-fMaGU6-QTXAmTZ8rmHEkkwpXh2I-bu54STf-ic,3594
22
+ torchax/ops/jlibrary.py,sha256=ESeFS5xwV1yiCIp_yKzDXihHd1wcz5eXjFkjFKsmw3w,3470
23
+ torchax/ops/jtorch.py,sha256=OuFTed92B3WuWnfnjBBvMgwkl43PLJJiDWSMb0ZS0sw,16406
24
+ torchax/ops/jtorchvision_nms.py,sha256=VNMshE3LCsIBHVVzkrNEm0kYMF89ZVIKWlQIk6pCZB0,9197
25
+ torchax/ops/mappings.py,sha256=ViEsZaGIi37BhuLw9hx9cA9XXl35OVuaRN6Q9yZygxk,4162
26
+ torchax/ops/op_base.py,sha256=-rQXLpkgNZ1HM3OT1XQkvAV_7Dtq019_rAXNAg97OuE,4135
27
+ torchax/ops/ops_registry.py,sha256=sBT41LRGmUVP4ZJ9YU1DyffatOHxe-x8oXqMhCKh0y8,1836
28
+ torchax-0.0.11.dev202612.dist-info/METADATA,sha256=tVMhNPH26W_cwlb9oA8xPNByKjM25aY7OI4SGnGFTg4,22451
29
+ torchax-0.0.11.dev202612.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
30
+ torchax-0.0.11.dev202612.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
31
+ torchax-0.0.11.dev202612.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,31 +0,0 @@
1
- torchax/CONTRIBUTING.md,sha256=gbB2ewxDLC-HHRBC3B8HdppV_d9MbDd-9rvzGQt7vZU,1440
2
- torchax/__init__.py,sha256=V6m7jGoh59hP3nVH7-gK-993qLz5JdmgMIWDbVC3rXw,4172
3
- torchax/amp.py,sha256=h5sYm7jhbQhn8kmtdd8ahzrmH06r2VKGYigHQzRzmZc,12366
4
- torchax/checkpoint.py,sha256=3nDcGOVbeeBzoaI1dfQxpUPdNiRuGGkVj65DlXkEpF4,2437
5
- torchax/config.py,sha256=c2JVtKx-GkkQn9vGxYgweGrm57G60mH1BB-SfE6-d6Q,1497
6
- torchax/decompositions.py,sha256=yR2QXWbwFS3Gb9Wbc318fJDX04wLYc_TCmclafOmqjw,29434
7
- torchax/device_module.py,sha256=7WrLUBjMQiAilVfRwEwJrbfkizPZAC3022UO70U5uEQ,924
8
- torchax/export.py,sha256=ztEECMqORuKX75VvQfRs4qwso1NEhFwylM9j6hVShJM,9544
9
- torchax/flax.py,sha256=LxXS0ZKoEU_hNmweyuNSeErX9mThbkI4El1nAUBuGic,1855
10
- torchax/interop.py,sha256=MS5BGmXhiucKA47EMOwJxpHhuwsQu19EyZsGz4zSw9g,11845
11
- torchax/mesh_util.py,sha256=9QPU7fWxVnb8v2A4S1Accu6PCAeEsNgaXVz1fIfhXeI,9889
12
- torchax/tensor.py,sha256=lEhIPhpJutnj8d7urFfAZdjihdh4iJdozI0bkJBIU-E,21900
13
- torchax/train.py,sha256=3Uc5zMxHAp-qjwdIm-ebjsYrEC7KF2y-2JKFsNkHhPY,4726
14
- torchax/types.py,sha256=IiHQHVNWQvo_6sT2Q3j6YBbc5p0DMrewrXNQB0fZb84,963
15
- torchax/util.py,sha256=sLGXkmYE5dOBC-3Co2HxPTsjod36l9bPvzhufEw9kEE,3644
16
- torchax/view.py,sha256=PF9ti7eLKBb5-kfx96lijlgSlSBXPCPPtGt_12iMiQI,11696
17
- torchax/ops/__init__.py,sha256=VY8LjEJnLhMdSbPaTOzrWFn_nFRseWOSEqLFijyX_6Y,845
18
- torchax/ops/jaten.py,sha256=GEc8s0GjCF6N9A8UnIyLmtch5TwCgnqWbTRr5iMerkk,177058
19
- torchax/ops/jax_reimplement.py,sha256=kjU0Fopq_zM-8GJDSE4nPASAG-8es0I108ran45d0tM,7877
20
- torchax/ops/jc10d.py,sha256=l7-bi-uFLbTDjHKsE6H-7eO_PG3j0j1S1vIFEMAyvII,1897
21
- torchax/ops/jimage.py,sha256=6YpFBaYSvsXuYgsZY2Y-c5yi8bAROxlRCcXSFQMau0M,3757
22
- torchax/ops/jlibrary.py,sha256=ol12qM2qP1qfqXnmWcVCpPOHmLODZkhiFZmxZGJAD6o,3505
23
- torchax/ops/jtorch.py,sha256=SiXz_OQyQ7XW8uJ1i9Fx1n-G6k2Ws23fbrFsRr3fzUk,19191
24
- torchax/ops/jtorchvision_nms.py,sha256=OZfYntyIatgXtMBU2hywlWvchMKaK028nNsFE2fCy7E,9478
25
- torchax/ops/mappings.py,sha256=RP75HTKJYfZ08dTByfUnCEJRP42jDHcoQXQrBafWI8M,4362
26
- torchax/ops/op_base.py,sha256=W_hiXEW-NxZruQfpMg2W_Aru47HiNZjR0kiEo749XAg,4314
27
- torchax/ops/ops_registry.py,sha256=5wV9BjZqTAUr41tr-A7pchwkQ877YDu_m-RH1f4GWsU,2169
28
- torchax-0.0.10.dev20251116.dist-info/METADATA,sha256=yj4zBuvXI62XXwjDUkMjJgTvvbhWjkCwPIMe4eKS4TQ,22315
29
- torchax-0.0.10.dev20251116.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
30
- torchax-0.0.10.dev20251116.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
31
- torchax-0.0.10.dev20251116.dist-info/RECORD,,