torchax 0.0.10.dev20251117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchax/train.py ADDED
@@ -0,0 +1,132 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections
+ import functools
+ import torch
+ import jax
+ import torchax
+ from torchax import interop
+ from torchax.interop import torch_view, jax_view
+ import optax
+
+ remat = torch_view(jax.remat)
+ mark_sharding = torch_view(jax.lax.with_sharding_constraint)
+
+
+ def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
+   """Makes a function that does one train step, given a model and a loss.
+
+   model_fn: a function representing the model's forward pass,
+     i.e. with signature Callable[weights, buffers, args] -> result, where
+     weights is a pytree of trainable parameters,
+     buffers is a pytree of non-trainable parameters / constants,
+     args is the input data loaded from the data set, and
+     result is the return value of the model.
+   loss_fn: a function to compute the loss,
+     i.e. with signature Callable[result, label] -> loss, where
+     result is what model_fn returned and
+     label is loaded from the dataloader.
+   optax_optimizer: the optimizer from the optax library, e.g. optax.adam.
+   remat_policy: one of jax.ad_checkpoint.checkpoint_policies; specifies how
+     to do gradient checkpointing. If None, checkpoint everything.
+   """
+   env = torchax.default_env()
+
+   def loss(weights, buffers, args, label):  # inputs are XLATensor
+     with env, jax.named_scope("compute_loss"):
+       res = model_fn(weights, buffers, args)
+       l = loss_fn(res, label)
+       return l
+
+   # loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
+   grad_fn = interop.jax_value_and_grad(loss)
+
+   def step(weights, buffers, opt_state, args, label):  # inputs are arrays
+     with jax.named_scope("compute_gradient"):
+       loss, gradient = grad_fn(weights, buffers, args, label)
+
+     with jax.named_scope("optimizer_updates"):
+       updates, opt_state = interop.call_jax(
+           optax_optimizer.update, gradient, opt_state, weights
+       )
+       weights = interop.call_jax(optax.apply_updates, weights, updates)
+     return loss, weights, opt_state
+
+   # TODO: apply jax.jit so the user doesn't have to.
+   return step
+
+
+ class Container:
+   pass
+
+
+ class ScannedModule(torch.nn.Module):
+   def __init__(self, module_list, checkpoint_policy=None):
+     super().__init__()
+
+     self.c = None
+     assert module_list
+     self.c = Container()
+     self.c.one_mod = module_list[0]
+     self.checkpoint_policy = checkpoint_policy
+
+     weights = self._stack_layer_weights(module_list)
+     self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
+     self.params = torch.nn.ParameterDict(
+         {self._param_name_new(k): v for k, v in weights.items()}
+     )
+
+   def _stack_layer_weights(self, module_list):
+     # Create weights such that every [n, m] weight
+     # becomes [k, n, m], where k is the number of layers,
+     # i.e. stack the layer weights together.
+     temp = collections.defaultdict(list)
+     for m in module_list:
+       for k, v in m.state_dict().items():
+         temp[k].append(v)
+     res = {k: torch.stack(v) for k, v in temp.items()}
+     return res
+
+   def _param_name_new(self, old):
+     return "___".join(old.split("."))
+
+   def _param_name_old(self, new):
+     return ".".join(new.split("___"))
+
+   def forward(self, *args, **kwargs):
+     assert not kwargs
+     weights = {
+         k: self.params[self._param_name_new(k)]
+         for k in self.layer_weights_keys
+     }
+     scan = interop.torch_view(jax.lax.scan)
+
+     def eval_one_layer(args, weight):
+       # unpack the carry: h is the running hidden state
+       h, *rest = args
+       newh = torch.func.functional_call(self.c.one_mod, weight, args)
+       # newh is the next layer's input; rest is carried through unchanged
+       return (newh, *rest), None
+
+     _eval_one_layer = interop.gradient_checkpoint(
+         eval_one_layer,
+         kwargs={"policy": self.checkpoint_policy},
+     )
+     h, _ = scan(
+         _eval_one_layer,
+         args,
+         weights,
+     )
+     return h[0]
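
For orientation, here is a rough usage sketch of make_train_step (not part of the package contents above). The toy model_fn, loss_fn, weight names and hyperparameters are illustrative assumptions, and per the TODO in the source the caller is still expected to jit the returned step function and to initialize the optax state themselves:

  import optax
  import torch
  from torchax.train import make_train_step

  # Hypothetical functional forward pass: weights/buffers are pytrees,
  # args is the batch loaded from the dataloader.
  def model_fn(weights, buffers, args):
    x = torch.nn.functional.linear(args, weights["w1"], weights["b1"])
    x = torch.nn.functional.relu(x)
    return torch.nn.functional.linear(x, weights["w2"], weights["b2"])

  def loss_fn(result, label):
    return torch.nn.functional.cross_entropy(result, label)

  optimizer = optax.adam(1e-3)
  step = make_train_step(model_fn, loss_fn, optimizer)
  # One training step (opt_state comes from initializing the optimizer on the weights):
  # loss, weights, opt_state = step(weights, buffers, opt_state, batch, label)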
torchax/types.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Callable, Any, Union, ParamSpec, TypeAlias
+ import torch
+ import jax
+ import jax.numpy as jnp
+ import sys
+
+ P = ParamSpec('P')
+
+ TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]
+ TorchCallable: TypeAlias = Callable[P, TorchValue]
+ JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any]
+ JaxCallable: TypeAlias = Callable[P, JaxValue]
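
These aliases are plain typing helpers. As a hedged illustration (not from the package), a converter that crosses the torch/JAX boundary might be annotated with them; wrap_as_jax below is a made-up name used only to show the annotations:

  from torchax.types import TorchCallable, JaxCallable

  def wrap_as_jax(fn: TorchCallable) -> JaxCallable:
    """Hypothetical converter; only the annotations matter for this sketch."""
    ...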
torchax/util.py ADDED
@@ -0,0 +1,102 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Callable
+
+
+ def partition(original: list[Any],
+               func: Callable[[Any], bool]) -> tuple[list[Any], list[Any]]:
+   """Partitions elements into two parallel lists based on a predicate function.
+
+   Iterates through the 'original' list, applying 'func' to each element 'a'.
+   - If `func(a)` returns True, 'a' is appended to the first list ('truthy')
+     and `None` is appended to the second list ('falsy').
+   - If `func(a)` returns False, `None` is appended to the first list ('truthy')
+     and 'a' is appended to the second list ('falsy').
+
+   The result is two lists of the same length as the 'original' list, acting
+   as parallel representations of the partitioned elements, using `None` as
+   placeholders.
+
+   This is useful when we want to mark a group of arguments as static (via
+   static_argnums) or donated (via donate_argnums) when combining with jax.jit
+   and friends.
+
+   Args:
+     original: The list of elements to partition.
+     func: A callable (function or lambda) that accepts an element from
+       'original' and returns a boolean value (True or False).
+
+   Returns:
+     A tuple containing two lists (`truthy`, `falsy`), both of the same
+     length as `original`:
+     - The first list contains elements `x` where `func(x)` was True, and
+       `None` otherwise.
+     - The second list contains elements `x` where `func(x)` was False, and
+       `None` otherwise.
+
+   Example:
+     >>> def is_even(n): return n % 2 == 0
+     >>> nums = [1, 2, 3, 4, 5, 6]
+     >>> truthy_list, falsy_list = partition(nums, is_even)
+     >>> truthy_list
+     [None, 2, None, 4, None, 6]
+     >>> falsy_list
+     [1, None, 3, None, 5, None]
+   """
+   truthy = []
+   falsy = []
+   for a in original:
+     t, f = (a, None) if func(a) else (None, a)
+     truthy.append(t)
+     falsy.append(f)
+   return truthy, falsy
+
+
+ def merge(list1: list[Any], list2: list[Any]) -> list[Any]:
+   """Merges two lists element-wise, prioritizing non-None elements from list1.
+
+   Creates a new list where each element is taken from the corresponding position
+   in 'list1', unless that element is None, in which case the element from the
+   corresponding position in 'list2' is used. Assumes both lists have the
+   same length.
+
+   Invariant: merge(*partition(input_list, predicate)) == input_list for any predicate.
+
+   Args:
+     list1: The primary list. Its elements are preferred unless they are None.
+     list2: The secondary list. Its elements are used as fallbacks when the
+       corresponding element in list1 is None.
+
+   Returns:
+     A new list representing the merged result.
+
+   Raises:
+     AssertionError: If 'list1' and 'list2' do not have the same length.
+
+   Example:
+     >>> l1 = [1, None, 3, None]
+     >>> l2 = [None, 2, None, 4]
+     >>> merge(l1, l2)
+     [1, 2, 3, 4]
+     >>> l3 = [None, 'b', None]
+     >>> l4 = ['a', None, 'c']
+     >>> merge(l3, l4)
+     ['a', 'b', 'c']
+   """
+   assert len(list1) == len(list2)
+   res = []
+   for a, b in zip(list1, list2):
+     res.append(b if a is None else a)
+   return res
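
A small self-contained check of the documented invariant (plain Python, not part of the package); the is_static predicate is an arbitrary choice for illustration:

  from torchax.util import partition, merge

  args = ["config", 1.0, "mesh", 2.0]

  def is_static(a):
    return isinstance(a, str)

  static_args, other_args = partition(args, is_static)
  # static_args == ['config', None, 'mesh', None]
  # other_args  == [None, 1.0, None, 2.0]
  assert merge(static_args, other_args) == args  # merge(*partition(x, p)) == x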
torchax/view.py ADDED
@@ -0,0 +1,391 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.utils._pytree as torch_pytree
+ import jax
+ from enum import Enum
+ from typing import Union, List, Tuple, Optional, Any, cast
+ from abc import ABC, abstractmethod
+
+ # Reference to original PyTorch native functions
+ # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+
+
+ class ViewInfoType(Enum):
+   INVALID = 0
+   NARROW = 1
+   NO_OP = 2
+   PERMUTE = 3
+   RESHAPE = 4
+   RESIZE = 5
+   SELECT = 6
+   AS_STRIDED = 7
+   DIAGONAL = 8
+
+
+ class ViewInfo(ABC):
+   """
+   Abstract base class for all view operations.
+   Defines the interface for applying and updating view transformations.
+   """
+
+   def __init__(
+       self,
+       view_info_type: ViewInfoType = ViewInfoType.INVALID,
+   ):
+     """
+     Initialize a ViewInfo object.
+
+     Args:
+       view_info_type: The type of view operation
+     """
+     self.view_info_type = view_info_type
+
+   @abstractmethod
+   def update_tensor(self, new_value: jax.Array,
+                     jax_array: jax.Array) -> jax.Array:
+     """
+     Apply this view transformation to a JAX array and update its value.
+
+     Args:
+       new_value: The new values to set in the view
+       jax_array: The parent array to update
+
+     Returns:
+       Updated array
+     """
+     pass
+
+   @abstractmethod
+   def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+     """
+     Apply this view transformation to a JAX array.
+
+     Args:
+       jax_array: The array to transform
+
+     Returns:
+       Transformed array
+     """
+     pass
+
+   @abstractmethod
+   def calculate_output_shape(self, source: jax.Array) -> List[int]:
+     """
+     Calculate the resulting shape after applying this view.
+
+     Args:
+       source: Original jax array before transformation
+
+     Returns:
+       Resulting shape after transformation
+     """
+     pass
+
+
+ class NarrowInfo(ViewInfo):
+   """
+   Represents a slicing operation on a tensor.
+   Handles operations like tensor[1:3, :, 2:5:2].
+   """
+
+   def __init__(self, slices: Union[slice, Tuple[slice]]) -> None:
+     """
+     Args:
+       slices: The slice(s) to apply to the tensor.
+         E.g. jax_array[slices] will return the transformed tensor.
+     """
+     super().__init__(ViewInfoType.NARROW)
+     self.slices = slices
+
+   def __eq__(self, other: object) -> bool:
+     if not isinstance(other, NarrowInfo):
+       return False
+     return self.slices == other.slices
+
+   def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+     try:
+       return jax_array[self.slices]
+     except IndexError as e:
+       raise IndexError("Invalid slice operation") from e
+
+   def update_tensor(self, new_value: jax.Array,
+                     jax_array: jax.Array) -> jax.Array:
+     return jax_array.at[self.slices].set(new_value)
+
+   def calculate_output_shape(self, source: jax.Array) -> List[int]:
+     return source[self.slices].shape
+
+
+ class SelectInfo(ViewInfo):
+   """
+   Represents a selection operation on a tensor.
+   Typically used for indexing operations that select specific elements.
+   """
+
+   def __init__(self,
+                dim: int = 0,
+                start: int = 0,
+                end: int = 0,
+                stride: int = 0) -> None:
+     super().__init__(ViewInfoType.SELECT)
+     self.dim: int = dim
+     self.start: int = start
+     self.end: int = end
+     self.stride: int = stride
+
+   def __eq__(self, other: object) -> bool:
+     if not isinstance(other, SelectInfo):
+       return False
+     return (self.dim == other.dim and self.start == other.start and
+             self.end == other.end and self.stride == other.stride)
+
+   def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+     raise NotImplementedError("SelectInfo.apply not implemented")
+
+   def update_tensor(self, new_value: jax.Array,
+                     jax_array: jax.Array) -> jax.Array:
+     raise NotImplementedError("SelectInfo.update not implemented")
+
+   def calculate_output_shape(self, source: jax.Array) -> List[int]:
+     raise NotImplementedError(
+         "SelectInfo.calculate_output_shape not implemented")
+
+
+ class AsStridedInfo(ViewInfo):
+   """
+   Information for as_strided operations.
+   """
+
+   def __init__(self, stride: List[int], offset: int = 0) -> None:
+     super().__init__(ViewInfoType.AS_STRIDED)
+     self.stride: List[int] = stride
+     self.offset: int = offset
+
+   def __eq__(self, other: object) -> bool:
+     if not isinstance(other, AsStridedInfo):
+       return False
+     return self.offset == other.offset and self.stride == other.stride
+
+   def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+     raise NotImplementedError("AsStridedInfo.apply not implemented")
+
+   def update_tensor(self, new_value: jax.Array,
+                     jax_array: jax.Array) -> jax.Array:
+     raise NotImplementedError("AsStridedInfo.update not implemented")
+
+   def calculate_output_shape(self, source: jax.Array) -> List[int]:
+     raise NotImplementedError(
+         "AsStridedInfo.calculate_output_shape not implemented")
+
+
+ class DiagonalInfo(ViewInfo):
+   """
+   Information for diagonal operations.
+   Extracts diagonal elements from a tensor.
+   """
+
+   def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None:
+     """
+     Args:
+       offset: Offset from the main diagonal
+       dim1: First dimension for diagonal extraction
+       dim2: Second dimension for diagonal extraction
+     """
+     super().__init__(ViewInfoType.DIAGONAL)
+     self.offset: int = offset
+     self.dim1: int = dim1
+     self.dim2: int = dim2
+
+   def __eq__(self, other: object) -> bool:
+     if not isinstance(other, DiagonalInfo):
+       return False
+     return (self.offset == other.offset and self.dim1 == other.dim1 and
+             self.dim2 == other.dim2)
+
+   def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
+     raise NotImplementedError("DiagonalInfo.apply not implemented")
+
+   def update_tensor(self, new_value: jax.Array,
+                     jax_array: jax.Array) -> jax.Array:
+     raise NotImplementedError("DiagonalInfo.update not implemented")
+
+   def calculate_output_shape(self, source: jax.Array) -> List[int]:
+     raise NotImplementedError(
+         "DiagonalInfo.calculate_output_shape not implemented")
+
+
+ class View(torch.Tensor):
+   """
+   A View is a reference to another Tensor or another View,
+   with a transformation applied to it.
+   """
+
+   @staticmethod
+   def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo,
+               env: Any) -> "View":
+     """
+     Args:
+       parent: Parent tensor or view
+       view_info: Information about the view transformation
+       env: Environment for tensor operations
+     """
+     shape = view_info.calculate_output_shape(parent.jax())
+     return torch.Tensor._make_wrapper_subclass(
+         cls,
+         shape,
+         device="meta",
+         dtype=parent.dtype,
+         requires_grad=False,
+     )
+
+   def __init__(self, parent: Union["torchax.Tensor", "View"],
+                view_info: ViewInfo, env: Any) -> None:
+     super().__init__()
+     self.parent = parent
+     self.view_info = view_info
+     self._env = env
+
+   def get_transformation_chain(self) -> List[ViewInfo]:
+     """
+     Get all view transformations from the source tensor to this view.
+     """
+     if isinstance(self.parent, View):
+       transformations = self.parent.get_transformation_chain()
+       transformations.append(self.view_info)
+       return transformations
+     else:
+       return [self.view_info]
+
+   __torch_function__ = torch._C._disabled_torch_function_impl
+
+   def source_jax(self) -> jax.Array:
+     """
+     Returns the source tensor.
+     """
+     if isinstance(self.parent, View):
+       return self.parent.source_jax()
+     else:
+       return self.parent.jax()
+
+   def replace_source_jax(self, new_value: jax.Array) -> None:
+     """
+     Update the source tensor with new values.
+     """
+     if isinstance(self.parent, View):
+       self.parent.replace_source_jax(new_value)
+     else:
+       assert new_value.shape == self.parent._elem.shape
+       self.parent._elem = new_value
+
+   def torch(self) -> "torchax.Tensor":
+     """
+     Returns a Torchax tensor representing this view after all transformations.
+     """
+     from torchax.tensor import Tensor
+
+     return Tensor(self.jax(), self._env)
+
+   def update(
+       self,
+       new_values: Union[jax.Array, "View", "torchax.Tensor"],
+       view_infos: Optional[List[ViewInfo]] = None,
+   ) -> None:
+     """
+     Update this view with new values, propagating changes back to the source.
+     If view_infos is None, it will use the transformation chain
+     from the source tensor.
+     """
+     if view_infos is None:
+       view_infos = self.get_transformation_chain()
+
+     # Get the source JAX array
+     source_array = self.source_jax()
+
+     # Get the new value
+     from torchax.tensor import Tensor
+
+     if isinstance(new_values, View) or isinstance(new_values, Tensor):
+       new_values = new_values.jax()
+
+     # Apply all view transformations to the source array
+     # and store the intermediate values
+     intermediate_values = [source_array]
+     for view_info in view_infos[:-1]:
+       intermediate_values.append(
+           view_info.transform_tensor(intermediate_values[-1]))
+
+     # TODO: Investigate efficiency of this algorithm
+     # Update the source array with the new value by
+     # applying inverse transformations in reverse order
+     for view_info, parent_array in zip(
+         reversed(view_infos), reversed(intermediate_values)):
+       # Apply the inverse transformation to propagate changes back
+       new_values = view_info.update_tensor(new_values, parent_array)
+
+     # Update the source tensor with the new values
+     self.replace_source_jax(new_values)
+
+   @classmethod
+   def __torch_dispatch__(
+       cls,
+       func: Any,
+       types: Tuple[Any, ...],
+       args: Tuple[Any, ...] = (),
+       kwargs: Optional[dict] = None,
+   ) -> Any:
+     raise AssertionError(
+         'torchax Tensors can only do math within the torchax environment. '
+         'Please wrap your code with `with torchax.default_env()` or '
+         'call torchax.enable_globally() before.')
+
+   def create_sub_view(self, view_info: ViewInfo) -> "View":
+     """
+     Create a new view that is a child of this view.
+     """
+     return View(self, view_info, self._env)
+
+   def __str__(self) -> str:
+     return f"View({self.torch()})"
+
+   def jax(self) -> jax.Array:
+     """
+     Returns a copy of the source tensor after transformations.
+     """
+     result = self.source_jax()
+     for view_info in self.get_transformation_chain():
+       result = view_info.transform_tensor(result)
+     return result
+
+   def __setitem__(self, indexes, val):
+     view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)]
+     self.update(view_infos=view_infos, new_values=val)
+
+   def dim(self):
+     return self.ndim
+
+   @property
+   def device(self):
+     return torch.device("jax:0")
+
+   @property
+   def jax_device(self):
+     return self.jax().device
+
+   @property
+   def ndim(self):
+     return len(self.shape)
+
+   __repr__ = __str__
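
As a quick orientation, a sketch of NarrowInfo applied to a bare jax.Array (the only ViewInfo subclass with implemented transforms in this version); constructing a full View additionally needs a torchax Tensor parent and an environment, which are omitted here:

  import jax.numpy as jnp
  from torchax.view import NarrowInfo

  arr = jnp.arange(10)
  info = NarrowInfo(slice(2, 5))

  info.calculate_output_shape(arr)   # (3,)
  info.transform_tensor(arr)         # arr[2:5] -> [2, 3, 4]
  updated = info.update_tensor(jnp.zeros(3, dtype=arr.dtype), arr)
  # updated == [0, 1, 0, 0, 0, 5, 6, 7, 8, 9]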