torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/train.py
CHANGED
@@ -7,14 +7,11 @@ from torchax import interop
 from torchax.interop import torch_view, jax_view
 import optax
 
-
 remat = torch_view(jax.remat)
 mark_sharding = torch_view(jax.lax.with_sharding_constraint)
 
 
-def make_train_step(model_fn,
-                    loss_fn, optax_optimizer,
-                    remat_policy=None):
+def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
   """Make a function that do one train step given model and loss.
 
   model_fn: a function representing the model's forward:
@@ -32,7 +29,8 @@ def make_train_step(model_fn,
     to do gradient checkpointing. If None, then it means checkpoint everything.
   """
   env = torchax.default_env()
-
+
+  def loss(weights, buffers, args, label):  # inputs are XLATensor
     with env, jax.named_scope('compute_loss'):
       res = model_fn(weights, buffers, args)
       l = loss_fn(res, label)
@@ -41,26 +39,24 @@
     loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
   grad_fn = interop.jax_value_and_grad(loss)
 
-  def step(weights, buffers, opt_state, args, label):
+  def step(weights, buffers, opt_state, args, label):  #inputs are array
     with jax.named_scope('compute_gradient'):
-
+      loss, gradient = grad_fn(weights, buffers, args, label)
 
     with jax.named_scope("optimizer_updates"):
-
-
-
-      weights = interop.call_jax(optax.apply_updates, weights, updates)
+      updates, opt_state = interop.call_jax(optax_optimizer.update, gradient,
+                                            opt_state, weights)
+      weights = interop.call_jax(optax.apply_updates, weights, updates)
     return loss, weights, opt_state
 
   # TODO: apply jax.jit so the user don't have to.
   return step
 
-
-
-
+
 class Container:
   pass
 
+
 class ScannedModule(torch.nn.Module):
 
   def __init__(self, module_list, checkpoint_policy=None):
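The hunks above flatten make_train_step's signature and route the optimizer update through interop.call_jax inside step(). A minimal usage sketch of the new entry point (editorial illustration, not part of the diff: the model, loss, and optimizer below are hypothetical, and the training-loop wiring that builds weights/buffers/opt_state is elided):

import optax
import torch
from torchax import train

# Hypothetical model; any torch.nn.Module driven through functional
# weights/buffers works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))

def model_fn(weights, buffers, args):
  # Forward pass written against functional weights and buffers.
  return torch.func.functional_call(model, {**weights, **buffers}, args)

def loss_fn(logits, label):
  return torch.nn.functional.cross_entropy(logits, label)

step = train.make_train_step(model_fn, loss_fn, optax.adamw(1e-3))
# step(weights, buffers, opt_state, args, label) -> (loss, weights, opt_state);
# per the TODO in the source, the caller is still expected to wrap step in jax.jit.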
@@ -75,9 +71,9 @@ class ScannedModule(torch.nn.Module):
     weights = self._stack_layer_weights(module_list)
     self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
     self.params = torch.nn.ParameterDict({
-
+        self._param_name_new(k): v for k, v in weights.items()
     })
-
+
   def _stack_layer_weights(self, module_list):
     # Create weights such that, for every [n, m] weights
     # becomes [k, n, m] where k is number of layer
@@ -85,36 +81,37 @@ class ScannedModule(torch.nn.Module):
     temp = collections.defaultdict(list)
     for m in module_list:
       for k, v in m.state_dict().items():
-
+        temp[k].append(v)
     res = {k: torch.stack(v) for k, v in temp.items()}
     return res
 
-
   def _param_name_new(self, old):
-
+    return '___'.join(old.split('.'))
 
   def _param_name_old(self, new):
-
+    return '.'.join(new.split('___'))
 
   def forward(self, *args, **kwargs):
-    ... (21 lines of the previous forward implementation removed; their content is not captured in this diff view)
+    assert not kwargs
+    weights = {
+        k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys
+    }
+    scan = interop.torch_view(jax.lax.scan)
+
+    def eval_one_layer(args, weight):
+      # unpack args
+      h, *rest = args
+      newh = torch.func.functional_call(self.c.one_mod, weight, args)
+      # next layer's input; and residual to be added to list
+      return (newh, *rest), None
+
+    _eval_one_layer = interop.gradient_checkpoint(
+        eval_one_layer,
+        kwargs={'policy': self.checkpoint_policy},
+    )
+    h, _ = scan(
+        _eval_one_layer,
+        args,
+        weights,
+    )
+    return h[0]
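The new forward stacks each layer's state_dict along a leading axis (_stack_layer_weights) and folds the stack with jax.lax.scan under gradient checkpointing. A standalone sketch of that scan pattern in plain JAX (editorial illustration; the tanh matmul merely stands in for the real per-layer computation):

import jax
import jax.numpy as jnp

num_layers, width = 4, 8
# Stacked weights of shape [num_layers, width, width], analogous to the
# [k, n, m] layout produced by _stack_layer_weights.
stacked_w = jnp.stack([jnp.eye(width) * (i + 1) for i in range(num_layers)])

def one_layer(carry, w):
  # One scan step: consume one layer's weights, produce the next activation.
  new_h = jnp.tanh(carry @ w)
  return new_h, None  # carry the activation forward; no per-step output

h0 = jnp.ones((width,))
h_final, _ = jax.lax.scan(one_layer, h0, stacked_w)
print(h_final.shape)  # (8,)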
torchax/util.py
ADDED
@@ -0,0 +1,88 @@
from typing import Any, Callable


def partition(original: list[Any],
              func: Callable[[Any], bool]) -> tuple[list[Any], list[Any]]:
  """Partitions elements into two parallel lists based on a predicate function.

  Iterates through the 'original' list, applying 'func' to each element 'a'.
  - If `func(a)` returns True, 'a' is appended to the first list ('truthy')
    and `None` is appended to the second list ('falsy').
  - If `func(a)` returns False, `None` is appended to the first list ('truthy')
    and 'a' is appended to the second list ('falsy').

  The result is two lists of the same length as the 'original' list, acting
  as parallel representations of the partitioned elements, using `None` as
  placeholders.

  This is useful when we want to mark a group of elements as static (via passing
  static_argnums) or donated (via donate_argnums) when combining with jax.jit
  and friends.

  Args:
    original: The list of elements to partition.
    func: A callable (function or lambda) that accepts an element from
      'original' and returns a boolean value (True or False).

  Returns:
    A tuple containing two lists (`truthy`, `falsy`), both of the same
    length as `original`:
    - The first list contains elements `x` where `func(x)` was True, and
      `None` otherwise.
    - The second list contains elements `x` where `func(x)` was False, and
      `None` otherwise.

  Example:
    >>> def is_even(n): return n % 2 == 0
    >>> nums = [1, 2, 3, 4, 5, 6]
    >>> truthy_list, falsy_list = partition(nums, is_even)
    >>> truthy_list
    [None, 2, None, 4, None, 6]
    >>> falsy_list
    [1, None, 3, None, 5, None]
  """
  truthy = []
  falsy = []
  for a in original:
    t, f = (a, None) if func(a) else (None, a)
    truthy.append(t)
    falsy.append(f)
  return truthy, falsy


def merge(list1: list[Any], list2: list[Any]) -> list[Any]:
  """Merges two lists element-wise, prioritizing non-None elements from list1.

  Creates a new list where each element is taken from the corresponding position
  in 'list1', unless that element is None, in which case the element from the
  corresponding position in 'list2' is used. Assumes both lists have the
  same length.

  Invariant: merge(*partion(input_list, predicate)) == input_list for any predicate

  Args:
    list1: The primary list. Its elements are preferred unless they are None.
    list2: The secondary list. Its elements are used as fallbacks when the
      corresponding element in list1 is None.

  Returns:
    A new list representing the merged result.

  Raises:
    AssertionError: If 'list1' and 'list2' do not have the same length.

  Example:
    >>> l1 = [1, None, 3, None]
    >>> l2 = [None, 2, None, 4]
    >>> merge(l1, l2)
    [1, 2, 3, 4]
    >>> l3 = [None, 'b', None]
    >>> l4 = ['a', None, 'c']
    >>> merge(l3, l4)
    ['a', 'b', 'c']
  """
  assert len(list1) == len(list2)
  res = []
  for a, b in zip(list1, list2):
    res.append(b if a is None else a)
  return res
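A short usage sketch for the two helpers (editorial illustration, not part of util.py; the is_static predicate is an assumption, and the round-trip property is the invariant stated in merge's docstring):

from torchax.util import partition, merge

args = ['cpu', 3.14, 'mean', 2.71]
is_static = lambda a: isinstance(a, str)

static_args, traced_args = partition(args, is_static)
# static_args == ['cpu', None, 'mean', None]
# traced_args == [None, 3.14, None, 2.71]

# Merging the two halves reproduces the original argument list.
assert merge(static_args, traced_args) == args

# A jitted wrapper can close over static_args and trace only traced_args,
# recombining the full argument list with merge(...) inside the traced function.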
torchax/view.py
ADDED
@@ -0,0 +1,377 @@
import torch
import torch.utils._pytree as torch_pytree
import jax
from enum import Enum
from typing import Union, List, Tuple, Optional, Any, cast
from abc import ABC, abstractmethod

# Reference to original PyTorch native functions
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml


class ViewInfoType(Enum):
  INVALID = 0
  NARROW = 1
  NO_OP = 2
  PERMUTE = 3
  RESHAPE = 4
  RESIZE = 5
  SELECT = 6
  AS_STRIDED = 7
  DIAGONAL = 8


class ViewInfo(ABC):
  """
  Abstract base class for all view operations.
  Defines the interface for applying and updating view transformations.
  """

  def __init__(
      self,
      view_info_type: ViewInfoType = ViewInfoType.INVALID,
  ):
    """
    Initialize a ViewInfo object.

    Args:
        view_info_type: The type of view operation
    """
    self.view_info_type = view_info_type

  @abstractmethod
  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    """
    Apply this view transformation to a JAX array and update its value.

    Args:
        new_value: The new values to set in the view
        jax_array: The parent array to update

    Returns:
        Updated array
    """
    pass

  @abstractmethod
  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    """
    Apply this view transformation to a JAX array.

    Args:
        jax_array: The array to transform

    Returns:
        Transformed array
    """
    pass

  @abstractmethod
  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    """
    Calculate the resulting shape after applying this view.

    Args:
        source: Original jax array before transformation

    Returns:
        Resulting shape after transformation
    """
    pass


class NarrowInfo(ViewInfo):
  """
  Represents a slicing operation on a tensor.
  Handles operations like tensor[1:3, :, 2:5:2].
  """

  def __init__(self, slices: Union[slice, Tuple[slice]]) -> None:
    """
    Args:
        slices: The slice(s) to apply to the tensor.
          E.g. jax_array.at[slices] will return the transformed tensor.
    """
    super().__init__(ViewInfoType.NARROW)
    self.slices = slices

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, NarrowInfo):
      return False
    return self.slices == other.slices

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    try:
      return jax_array[self.slices]
    except IndexError as e:
      raise IndexError("Invalid slice operation") from e

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    return jax_array.at[self.slices].set(new_value)

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    return source[self.slices].shape


class SelectInfo(ViewInfo):
  """
  Represents a selection operation on a tensor.
  Typically used for indexing operations that select specific elements.
  """

  def __init__(self,
               dim: int = 0,
               start: int = 0,
               end: int = 0,
               stride: int = 0) -> None:
    super().__init__(ViewInfoType.SELECT)
    self.dim: int = dim
    self.start: int = start
    self.end: int = end
    self.stride: int = stride

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, SelectInfo):
      return False
    return (self.dim == other.dim and self.start == other.start and
            self.end == other.end and self.stride == other.stride)

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("SelectInfo.apply not implemented")

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("SelectInfo.update not implemented")

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    raise NotImplementedError(
        "SelectInfo.calculate_output_shape not implemented")


class AsStridedInfo(ViewInfo):
  """
  Information for as_strided operations.
  """

  def __init__(self, stride: List[int], offset: int = 0) -> None:
    super().__init__(ViewInfoType.AS_STRIDED)
    self.stride: List[int] = stride
    self.offset: int = offset

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, AsStridedInfo):
      return False
    return self.offset == other.offset and self.stride == other.stride

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("AsStridedInfo.apply not implemented")

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("AsStridedInfo.update not implemented")

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    raise NotImplementedError(
        "AsStridedInfo.calculate_output_shape not implemented")


class DiagonalInfo(ViewInfo):
  """
  Information for diagonal operations.
  Extracts diagonal elements from a tensor.
  """

  def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None:
    """
    Args:
        offset: Offset from the main diagonal
        dim1: First dimension for diagonal extraction
        dim2: Second dimension for diagonal extraction
    """
    super().__init__(ViewInfoType.DIAGONAL)
    self.offset: int = offset
    self.dim1: int = dim1
    self.dim2: int = dim2

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, DiagonalInfo):
      return False
    return (self.offset == other.offset and self.dim1 == other.dim1 and
            self.dim2 == other.dim2)

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("DiagonalInfo.apply not implemented")

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("DiagonalInfo.update not implemented")

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    raise NotImplementedError(
        "DiagonalInfo.calculate_output_shape not implemented")


class View(torch.Tensor):
  """
  A View is a reference to another Tensor or another View,
  with a transformation applied to it.
  """

  @staticmethod
  def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo,
              env: Any) -> "View":
    """
    Args:
        parent: Parent tensor or view
        view_info: Information about the view transformation
        env: Environment for tensor operations
    """
    shape = view_info.calculate_output_shape(parent.jax())
    return torch.Tensor._make_wrapper_subclass(
        cls,
        shape,
        device="meta",
        dtype=parent.dtype,
        requires_grad=False,
    )

  def __init__(self, parent: Union["torchax.Tensor", "View"],
               view_info: ViewInfo, env: Any) -> None:
    super().__init__()
    self.parent = parent
    self.view_info = view_info
    self._env = env

  def get_transformation_chain(self) -> List[ViewInfo]:
    """
    Get all view transformations from the source tensor to this view.
    """
    if isinstance(self.parent, View):
      transformations = self.parent.get_transformation_chain()
      transformations.append(self.view_info)
      return transformations
    else:
      return [self.view_info]

  __torch_function__ = torch._C._disabled_torch_function_impl

  def source_jax(self) -> jax.Array:
    """
    Returns the source tensor.
    """
    if isinstance(self.parent, View):
      return self.parent.source_jax()
    else:
      return self.parent.jax()

  def replace_source_jax(self, new_value: jax.Array) -> None:
    """
    Update the source tensor with new values.
    """
    if isinstance(self.parent, View):
      self.parent.replace_source_jax(new_value)
    else:
      assert new_value.shape == self.parent._elem.shape
      self.parent._elem = new_value

  def torch(self) -> "torchax.Tensor":
    """
    Returns a Torchax tensor representing this view after all transformations
    """
    from torchax.tensor import Tensor

    return Tensor(self.jax(), self._env)

  def update(
      self,
      new_values: Union[jax.Array, "View", "torchax.Tensor"],
      view_infos: Optional[List[ViewInfo]] = None,
  ) -> None:
    """
    Update this view with new values, propagating changes back to source.
    If view_infos is None, it will use the transformation chain
    from the source tensor.
    """
    if view_infos is None:
      view_infos = self.get_transformation_chain()

    # Get the source JAX array
    source_array = self.source_jax()

    # Get the new value
    from torchax.tensor import Tensor

    if isinstance(new_values, View) or isinstance(new_values, Tensor):
      new_values = new_values.jax()

    # Apply all view transformations to the source array
    # And store intermediate values
    intermediate_values = [source_array]
    for view_info in view_infos[:-1]:
      intermediate_values.append(
          view_info.transform_tensor(intermediate_values[-1]))

    # TODO: Investigate efficiency of this algorithm
    # Update the source array with the new value by
    # applying inverse transformations in reverse order
    for view_info, parent_array in zip(
        reversed(view_infos), reversed(intermediate_values)):
      # Apply the inverse transformation to propagate changes back
      new_values = view_info.update_tensor(new_values, parent_array)

    # Update the source tensor with the new values
    self.replace_source_jax(new_values)

  @classmethod
  def __torch_dispatch__(
      cls,
      func: Any,
      types: Tuple[Any, ...],
      args: Tuple[Any, ...] = (),
      kwargs: Optional[dict] = None,
  ) -> Any:
    raise AssertionError(
        'torchax Tensors can only do math within the torchax environment.'
        'Please wrap your code with `with torchax.default_env()` or '
        'call torchax.enable_globally() before.')

  def create_sub_view(self, view_info: ViewInfo) -> "View":
    """
    Create a new view that is a child of this view.
    """
    return View(self, view_info, self._env)

  def __str__(self) -> str:
    return f"View({self.torch()})"

  def jax(self) -> jax.Array:
    """
    Returns a copy of the source tensor after transformations.
    """
    result = self.source_jax()
    for view_info in self.get_transformation_chain():
      result = view_info.transform_tensor(result)
    return result

  def __setitem__(self, indexes, val):
    view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)]
    self.update(view_infos=view_infos, new_values=val)

  def dim(self):
    return self.ndim

  @property
  def device(self):
    return torch.device("jax:0")

  @property
  def jax_device(self):
    return self.jax().device

  @property
  def ndim(self):
    return len(self.shape)

  __repr__ = __str__
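For the one fully implemented ViewInfo here (NarrowInfo), View.update reduces to a forward pass that materializes each intermediate slice followed by a backward pass of .at[...].set(...) writes. A standalone sketch of that round trip (editorial illustration; it drives NarrowInfo directly rather than going through View and a torchax Tensor):

import jax.numpy as jnp
from torchax.view import NarrowInfo

base = jnp.zeros((4, 4))
chain = [
    NarrowInfo(slice(1, 3)),        # rows 1..2 of the base array
    NarrowInfo((slice(None), 0)),   # column 0 of that slice
]

# Forward: materialize every intermediate view except the last.
intermediates = [base]
for info in chain[:-1]:
  intermediates.append(info.transform_tensor(intermediates[-1]))

# Backward: push the new values through update_tensor in reverse order.
new_values = jnp.array([7.0, 8.0])
for info, parent in zip(reversed(chain), reversed(intermediates)):
  new_values = info.update_tensor(new_values, parent)

print(new_values[1:3, 0])  # [7. 8.] -- the base array with the view written back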