torchax-0.0.10.dev20251118-py3-none-any.whl


torchax/train.py ADDED
@@ -0,0 +1,130 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+
+import jax
+import optax
+import torch
+
+import torchax
+from torchax import interop
+from torchax.interop import torch_view
+
+remat = torch_view(jax.remat)
+mark_sharding = torch_view(jax.lax.with_sharding_constraint)
+
+
+def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
+    """Make a function that does one train step given model and loss.
+
+    model_fn: a function representing the model's forward pass,
+        i.e. it has signature Callable[weights, buffers, args] -> result, where
+        weights is a pytree of trainable parameters,
+        buffers is a pytree of non-trainable parameters / constants,
+        args is the input data loaded from the data set, and
+        result is the return value of the model.
+    loss_fn: a function to compute loss,
+        i.e. it has signature Callable[result, label] -> loss, where
+        result is what model_fn returned and
+        label is loaded from the dataloader.
+    optax_optimizer: the optimizer from the optax library, for example optax.adam.
+    remat_policy: one of jax.ad_checkpoint.checkpoint_policies; specifies how
+        to do gradient checkpointing. If None, checkpoint everything.
+    """
+    env = torchax.default_env()
+
+    def loss(weights, buffers, args, label):  # inputs are XLATensor
+        with env, jax.named_scope("compute_loss"):
+            res = model_fn(weights, buffers, args)
+            l = loss_fn(res, label)  # noqa: E741
+            return l
+
+    # loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
+    grad_fn = interop.jax_value_and_grad(loss)
+
+    def step(weights, buffers, opt_state, args, label):  # inputs are array
+        with jax.named_scope("compute_gradient"):
+            loss, gradient = grad_fn(weights, buffers, args, label)
+
+        with jax.named_scope("optimizer_updates"):
+            updates, opt_state = interop.call_jax(
+                optax_optimizer.update, gradient, opt_state, weights
+            )
+            weights = interop.call_jax(optax.apply_updates, weights, updates)
+        return loss, weights, opt_state
+
+    # TODO: apply jax.jit so the user doesn't have to.
+    return step
+
+
+class Container:
+    pass
+
+
+class ScannedModule(torch.nn.Module):
+    def __init__(self, module_list, checkpoint_policy=None):
+        super().__init__()
+
+        self.c = None
+        assert module_list
+        self.c = Container()
+        self.c.one_mod = module_list[0]
+        self.checkpoint_policy = checkpoint_policy
+
+        weights = self._stack_layer_weights(module_list)
+        self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
+        self.params = torch.nn.ParameterDict(
+            {self._param_name_new(k): v for k, v in weights.items()}
+        )
+
+    def _stack_layer_weights(self, module_list):
+        # Create stacked weights: every [n, m] weight becomes [k, n, m],
+        # where k is the number of layers,
+        # i.e. stack the per-layer weights together.
+        temp = collections.defaultdict(list)
+        for m in module_list:
+            for k, v in m.state_dict().items():
+                temp[k].append(v)
+        res = {k: torch.stack(v) for k, v in temp.items()}
+        return res
+
+    def _param_name_new(self, old):
+        return "___".join(old.split("."))
+
+    def _param_name_old(self, new):
+        return ".".join(new.split("___"))
+
+    def forward(self, *args, **kwargs):
+        assert not kwargs
+        weights = {k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys}
+        scan = interop.torch_view(jax.lax.scan)
+
+        def eval_one_layer(args, weight):
+            # unpack args
+            h, *rest = args
+            newh = torch.func.functional_call(self.c.one_mod, weight, args)
+            # next layer's input; and residual to be added to list
+            return (newh, *rest), None
+
+        _eval_one_layer = interop.gradient_checkpoint(
+            eval_one_layer,
+            kwargs={"policy": self.checkpoint_policy},
+        )
+        h, _ = scan(
+            _eval_one_layer,
+            args,
+            weights,
+        )
+        return h[0]
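
A minimal sketch of how make_train_step might be wired up. The model, loss, and optimizer below are illustrative only, and preparing the weights/buffers pytrees and the optimizer state (optimizer.init) is assumed to happen elsewhere; none of that glue is part of this diff.

    import optax
    import torch
    from torchax import train

    model = torch.nn.Linear(8, 2)

    def model_fn(weights, buffers, args):
        # Run the module functionally against the given parameter pytrees.
        return torch.func.functional_call(model, {**weights, **buffers}, args)

    def loss_fn(result, label):
        return torch.nn.functional.cross_entropy(result, label)

    step = train.make_train_step(model_fn, loss_fn, optax.adam(1e-3))
    # step(weights, buffers, opt_state, args, label) -> (loss, weights, opt_state),
    # where weights/buffers are pytrees of torchax tensors and opt_state comes
    # from the optimizer's init(); per the TODO above, the caller is expected to
    # wrap the returned step in jax.jit.
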
torchax/types.py ADDED
@@ -0,0 +1,27 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from typing import Any, ParamSpec, TypeAlias, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+
+P = ParamSpec("P")
+
+TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, "TorchCallable", Any]
+TorchCallable: TypeAlias = Callable[P, TorchValue]
+JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, "JaxCallable", Any]
+JaxCallable: TypeAlias = Callable[P, JaxValue]
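
These aliases are plain typing helpers. A hypothetical use is annotating a function that crosses the torch/JAX boundary (the helper below is not part of torchax):

    from torchax.types import JaxValue, TorchValue

    def to_jax_side(value: TorchValue) -> JaxValue:
        """Illustrative signature only: maps a torch-side value to a JAX-side one."""
        ...
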
torchax/util.py ADDED
@@ -0,0 +1,104 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from typing import Any
+
+
+def partition(
+    original: list[Any], func: Callable[[Any], bool]
+) -> tuple[list[Any], list[Any]]:
+    """Partitions elements into two parallel lists based on a predicate function.
+
+    Iterates through the 'original' list, applying 'func' to each element 'a'.
+    - If `func(a)` returns True, 'a' is appended to the first list ('truthy')
+      and `None` is appended to the second list ('falsy').
+    - If `func(a)` returns False, `None` is appended to the first list ('truthy')
+      and 'a' is appended to the second list ('falsy').
+
+    The result is two lists of the same length as the 'original' list, acting
+    as parallel representations of the partitioned elements, using `None` as
+    placeholders.
+
+    This is useful when we want to mark a group of elements as static (via passing
+    static_argnums) or donated (via donate_argnums) when combining with jax.jit
+    and friends.
+
+    Args:
+        original: The list of elements to partition.
+        func: A callable (function or lambda) that accepts an element from
+            'original' and returns a boolean value (True or False).
+
+    Returns:
+        A tuple containing two lists (`truthy`, `falsy`), both of the same
+        length as `original`:
+        - The first list contains elements `x` where `func(x)` was True, and
+          `None` otherwise.
+        - The second list contains elements `x` where `func(x)` was False, and
+          `None` otherwise.
+
+    Example:
+        >>> def is_even(n): return n % 2 == 0
+        >>> nums = [1, 2, 3, 4, 5, 6]
+        >>> truthy_list, falsy_list = partition(nums, is_even)
+        >>> truthy_list
+        [None, 2, None, 4, None, 6]
+        >>> falsy_list
+        [1, None, 3, None, 5, None]
+    """
+    truthy = []
+    falsy = []
+    for a in original:
+        t, f = (a, None) if func(a) else (None, a)
+        truthy.append(t)
+        falsy.append(f)
+    return truthy, falsy
+
+
+def merge(list1: list[Any], list2: list[Any]) -> list[Any]:
+    """Merges two lists element-wise, prioritizing non-None elements from list1.
+
+    Creates a new list where each element is taken from the corresponding position
+    in 'list1', unless that element is None, in which case the element from the
+    corresponding position in 'list2' is used. Assumes both lists have the
+    same length.
+
+    Invariant: merge(*partition(input_list, predicate)) == input_list for any predicate.
+
+    Args:
+        list1: The primary list. Its elements are preferred unless they are None.
+        list2: The secondary list. Its elements are used as fallbacks when the
+            corresponding element in list1 is None.
+
+    Returns:
+        A new list representing the merged result.
+
+    Raises:
+        AssertionError: If 'list1' and 'list2' do not have the same length.
+
+    Example:
+        >>> l1 = [1, None, 3, None]
+        >>> l2 = [None, 2, None, 4]
+        >>> merge(l1, l2)
+        [1, 2, 3, 4]
+        >>> l3 = [None, 'b', None]
+        >>> l4 = ['a', None, 'c']
+        >>> merge(l3, l4)
+        ['a', 'b', 'c']
+    """
+    assert len(list1) == len(list2)
+    res = []
+    for a, b in zip(list1, list2, strict=False):
+        res.append(b if a is None else a)
+    return res
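
A small, self-contained check of the behavior documented above, including the round-trip invariant noted in merge()'s docstring:

    from torchax.util import merge, partition

    nums = [1, 2, 3, 4, 5, 6]
    evens, odds = partition(nums, lambda n: n % 2 == 0)
    assert evens == [None, 2, None, 4, None, 6]
    assert odds == [1, None, 3, None, 5, None]
    assert merge(evens, odds) == nums  # merge(*partition(xs, p)) == xs
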
torchax/view.py ADDED
@@ -0,0 +1,399 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any
+
+import jax
+import torch
+
+# Reference to original PyTorch native functions
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+
+
+class ViewInfoType(Enum):
+    INVALID = 0
+    NARROW = 1
+    NO_OP = 2
+    PERMUTE = 3
+    RESHAPE = 4
+    RESIZE = 5
+    SELECT = 6
+    AS_STRIDED = 7
+    DIAGONAL = 8
+
+
+class ViewInfo(ABC):
+    """
+    Abstract base class for all view operations.
+    Defines the interface for applying and updating view transformations.
+    """
+
+    def __init__(
+        self,
+        view_info_type: ViewInfoType = ViewInfoType.INVALID,
+    ):
+        """
+        Initialize a ViewInfo object.
+
+        Args:
+            view_info_type: The type of view operation
+        """
+        self.view_info_type = view_info_type
+
+    @abstractmethod
+    def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
+        """
+        Apply this view transformation to a JAX array and update its value.
+
+        Args:
+            new_value: The new values to set in the view
+            jax_array: The parent array to update
+
+        Returns:
+            Updated array
+        """
+        pass
+
+    @abstractmethod
+    def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+        """
+        Apply this view transformation to a JAX array.
+
+        Args:
+            jax_array: The array to transform
+
+        Returns:
+            Transformed array
+        """
+        pass
+
+    @abstractmethod
+    def calculate_output_shape(self, source: jax.Array) -> list[int]:
+        """
+        Calculate the resulting shape after applying this view.
+
+        Args:
+            source: Original jax array before transformation
+
+        Returns:
+            Resulting shape after transformation
+        """
+        pass
+
+
+class NarrowInfo(ViewInfo):
+    """
+    Represents a slicing operation on a tensor.
+    Handles operations like tensor[1:3, :, 2:5:2].
+    """
+
+    def __init__(self, slices: slice | tuple[slice]) -> None:
+        """
+        Args:
+            slices: The slice(s) to apply to the tensor.
+                E.g. jax_array[slices] returns the transformed tensor.
+        """
+        super().__init__(ViewInfoType.NARROW)
+        self.slices = slices
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, NarrowInfo):
+            return False
+        return self.slices == other.slices
+
+    def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+        try:
+            return jax_array[self.slices]
+        except IndexError as e:
+            raise IndexError("Invalid slice operation") from e
+
+    def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
+        return jax_array.at[self.slices].set(new_value)
+
+    def calculate_output_shape(self, source: jax.Array) -> list[int]:
+        return source[self.slices].shape
+
+
+class SelectInfo(ViewInfo):
+    """
+    Represents a selection operation on a tensor.
+    Typically used for indexing operations that select specific elements.
+    """
+
+    def __init__(
+        self, dim: int = 0, start: int = 0, end: int = 0, stride: int = 0
+    ) -> None:
+        super().__init__(ViewInfoType.SELECT)
+        self.dim: int = dim
+        self.start: int = start
+        self.end: int = end
+        self.stride: int = stride
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SelectInfo):
+            return False
+        return (
+            self.dim == other.dim
+            and self.start == other.start
+            and self.end == other.end
+            and self.stride == other.stride
+        )
+
+    def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+        raise NotImplementedError("SelectInfo.apply not implemented")
+
+    def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
+        raise NotImplementedError("SelectInfo.update not implemented")
+
+    def calculate_output_shape(self, source: jax.Array) -> list[int]:
+        raise NotImplementedError("SelectInfo.calculate_output_shape not implemented")
+
+
+class AsStridedInfo(ViewInfo):
+    """
+    Information for as_strided operations.
+    """
+
+    def __init__(self, stride: list[int], offset: int = 0) -> None:
+        super().__init__(ViewInfoType.AS_STRIDED)
+        self.stride: list[int] = stride
+        self.offset: int = offset
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, AsStridedInfo):
+            return False
+        return self.offset == other.offset and self.stride == other.stride
+
+    def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+        raise NotImplementedError("AsStridedInfo.apply not implemented")
+
+    def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
+        raise NotImplementedError("AsStridedInfo.update not implemented")
+
+    def calculate_output_shape(self, source: jax.Array) -> list[int]:
+        raise NotImplementedError("AsStridedInfo.calculate_output_shape not implemented")
+
+
+class DiagonalInfo(ViewInfo):
+    """
+    Information for diagonal operations.
+    Extracts diagonal elements from a tensor.
+    """
+
+    def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None:
+        """
+        Args:
+            offset: Offset from the main diagonal
+            dim1: First dimension for diagonal extraction
+            dim2: Second dimension for diagonal extraction
+        """
+        super().__init__(ViewInfoType.DIAGONAL)
+        self.offset: int = offset
+        self.dim1: int = dim1
+        self.dim2: int = dim2
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, DiagonalInfo):
+            return False
+        return (
+            self.offset == other.offset
+            and self.dim1 == other.dim1
+            and self.dim2 == other.dim2
+        )
+
+    def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+        raise NotImplementedError("DiagonalInfo.apply not implemented")
+
+    def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array:
+        raise NotImplementedError("DiagonalInfo.update not implemented")
+
+    def calculate_output_shape(self, source: jax.Array) -> list[int]:
+        raise NotImplementedError("DiagonalInfo.calculate_output_shape not implemented")
+
+
+class View(torch.Tensor):
+    """
+    A View is a reference to another Tensor or another View,
+    with a transformation applied to it.
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        parent: torchax.Tensor | View,  # noqa: F821
+        view_info: ViewInfo,
+        env: Any,
+    ) -> View:
+        """
+        Args:
+            parent: Parent tensor or view
+            view_info: Information about the view transformation
+            env: Environment for tensor operations
+        """
+        shape = view_info.calculate_output_shape(parent.jax())
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            shape,
+            device="meta",
+            dtype=parent.dtype,
+            requires_grad=False,
+        )
+
+    def __init__(
+        self,
+        parent: torchax.Tensor | View,  # noqa: F821
+        view_info: ViewInfo,
+        env: Any,
+    ) -> None:
+        super().__init__()
+        self.parent = parent
+        self.view_info = view_info
+        self._env = env
+
+    def get_transformation_chain(self) -> list[ViewInfo]:
+        """
+        Get all view transformations from the source tensor to this view.
+        """
+        if isinstance(self.parent, View):
+            transformations = self.parent.get_transformation_chain()
+            transformations.append(self.view_info)
+            return transformations
+        else:
+            return [self.view_info]
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def source_jax(self) -> jax.Array:
+        """
+        Returns the source tensor.
+        """
+        if isinstance(self.parent, View):
+            return self.parent.source_jax()
+        else:
+            return self.parent.jax()
+
+    def replace_source_jax(self, new_value: jax.Array) -> None:
+        """
+        Update the source tensor with new values.
+        """
+        if isinstance(self.parent, View):
+            self.parent.replace_source_jax(new_value)
+        else:
+            assert new_value.shape == self.parent._elem.shape
+            self.parent._elem = new_value
+
+    def torch(self) -> torchax.Tensor:  # noqa: F821
+        """
+        Returns a torchax Tensor representing this view after all transformations.
+        """
+        from torchax.tensor import Tensor
+
+        return Tensor(self.jax(), self._env)
+
+    def update(
+        self,
+        new_values: jax.Array | View | torchax.Tensor,  # noqa: F821
+        view_infos: list[ViewInfo] | None = None,
+    ) -> None:
+        """
+        Update this view with new values, propagating changes back to the source.
+        If view_infos is None, the transformation chain from the source tensor
+        is used.
+        """
+        if view_infos is None:
+            view_infos = self.get_transformation_chain()
+
+        # Get the source JAX array
+        source_array = self.source_jax()
+
+        # Get the new value
+        from torchax.tensor import Tensor
+
+        if isinstance(new_values, View) or isinstance(new_values, Tensor):
+            new_values = new_values.jax()
+
+        # Apply all view transformations to the source array
+        # and store intermediate values
+        intermediate_values = [source_array]
+        for view_info in view_infos[:-1]:
+            intermediate_values.append(view_info.transform_tensor(intermediate_values[-1]))
+
+        # TODO: Investigate efficiency of this algorithm
+        # Update the source array with the new value by
+        # applying inverse transformations in reverse order
+        for view_info, parent_array in zip(
+            reversed(view_infos), reversed(intermediate_values), strict=False
+        ):
+            # Apply the inverse transformation to propagate changes back
+            new_values = view_info.update_tensor(new_values, parent_array)
+
+        # Update the source tensor with the new values
+        self.replace_source_jax(new_values)
+
+    @classmethod
+    def __torch_dispatch__(
+        cls,
+        func: Any,
+        types: tuple[Any, ...],
+        args: tuple[Any, ...] = (),
+        kwargs: dict | None = None,
+    ) -> Any:
+        raise AssertionError(
+            "torchax Tensors can only do math within the torchax environment. "
+            "Please wrap your code with `with torchax.default_env()` or "
+            "call torchax.enable_globally() before."
+        )
+
+    def create_sub_view(self, view_info: ViewInfo) -> View:
+        """
+        Create a new view that is a child of this view.
+        """
+        return View(self, view_info, self._env)
+
+    def __str__(self) -> str:
+        return f"View({self.torch()})"
+
+    def jax(self) -> jax.Array:
+        """
+        Returns a copy of the source tensor after transformations.
+        """
+        result = self.source_jax()
+        for view_info in self.get_transformation_chain():
+            result = view_info.transform_tensor(result)
+        return result
+
+    def __setitem__(self, indexes, val):
+        view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)]
+        self.update(view_infos=view_infos, new_values=val)
+
+    def dim(self):
+        return self.ndim
+
+    @property
+    def device(self):
+        return torch.device("jax:0")
+
+    @property
+    def jax_device(self):
+        return self.jax().device
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    __repr__ = __str__
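
A sketch of how these pieces compose: a View holds a parent plus a ViewInfo, jax() replays the transformation chain forward, and update() writes new values back to the source array via .at[...].set(...). The construction below is illustrative and assumes that tensors created under torchax.default_env() are torchax Tensors exposing .jax():

    import jax.numpy as jnp
    import torch
    import torchax
    from torchax.view import NarrowInfo, View

    env = torchax.default_env()
    with env:
        t = torch.zeros(4, 4)  # assumed to be a torchax Tensor inside the env

    v = View(t, NarrowInfo((slice(0, 2), slice(None))), env)  # behaves like t[0:2, :]
    sub = v.create_sub_view(NarrowInfo(slice(0, 1)))          # like t[0:2, :][0:1]

    sub.update(jnp.ones((1, 4)))  # propagated back through the chain to t
    print(t.jax()[0])             # the first row of the source is now ones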