torchax 0.0.10.dev20251117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +153 -0
- torchax/amp.py +346 -0
- torchax/checkpoint.py +79 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +790 -0
- torchax/device_module.py +47 -0
- torchax/export.py +259 -0
- torchax/flax.py +53 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +234 -0
- torchax/ops/__init__.py +24 -0
- torchax/ops/jaten.py +5937 -0
- torchax/ops/jax_reimplement.py +185 -0
- torchax/ops/jc10d.py +66 -0
- torchax/ops/jimage.py +127 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +631 -0
- torchax/ops/jtorchvision_nms.py +248 -0
- torchax/ops/mappings.py +161 -0
- torchax/ops/op_base.py +145 -0
- torchax/ops/ops_registry.py +69 -0
- torchax/tensor.py +736 -0
- torchax/train.py +132 -0
- torchax/types.py +26 -0
- torchax/util.py +102 -0
- torchax/view.py +391 -0
- torchax-0.0.10.dev20251117.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251117.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251117.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251117.dist-info/licenses/LICENSE +201 -0
torchax/train.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import collections
|
|
16
|
+
import functools
|
|
17
|
+
import torch
|
|
18
|
+
import jax
|
|
19
|
+
import torchax
|
|
20
|
+
from torchax import interop
|
|
21
|
+
from torchax.interop import torch_view, jax_view
|
|
22
|
+
import optax
|
|
23
|
+
|
|
24
|
+
# Torch-facing versions of two JAX primitives. `torch_view` (from
# torchax.interop) presumably converts torch-side arguments/results to and
# from JAX values around the wrapped callable — see torchax.interop.
remat = interop.torch_view(jax.remat)
mark_sharding = interop.torch_view(jax.lax.with_sharding_constraint)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
  """Make a function that does one train step given a model and a loss.

  Args:
    model_fn: a function representing the model's forward pass, i.e. it has
      signature Callable[weights, buffers, args] -> result, where
        weights is a pytree of trainable parameters,
        buffers is a pytree of non-trainable parameters / constants,
        args is the input data loaded from the data set, and
        result is the return value of the model.
    loss_fn: a function to compute the loss, i.e. it has signature
      Callable[result, label] -> loss, where result is what model_fn returned
      and label is loaded from the dataloader.
    optax_optimizer: an optimizer from the optax library, for example
      optax.adam(...).
    remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifying how
      to do gradient checkpointing of the loss computation. If None (the
      default), no gradient checkpointing is applied.

  Returns:
    A function ``step(weights, buffers, opt_state, args, label)`` that returns
    ``(loss, new_weights, new_opt_state)``. The returned function is not
    jitted.
  """
  env = torchax.default_env()

  def loss(weights, buffers, args, label):  # inputs are XLATensor
    with env, jax.named_scope("compute_loss"):
      res = model_fn(weights, buffers, args)
      l = loss_fn(res, label)
      return l

  if remat_policy is not None:
    # Honour an explicitly requested checkpoint policy. The default (None)
    # keeps the previous behaviour of not rematerializing anything.
    loss = interop.gradient_checkpoint(loss, kwargs={"policy": remat_policy})

  grad_fn = interop.jax_value_and_grad(loss)

  def step(weights, buffers, opt_state, args, label):  # inputs are array
    with jax.named_scope("compute_gradient"):
      loss_value, gradient = grad_fn(weights, buffers, args, label)

    with jax.named_scope("optimizer_updates"):
      updates, opt_state = interop.call_jax(
          optax_optimizer.update, gradient, opt_state, weights
      )
      weights = interop.call_jax(optax.apply_updates, weights, updates)
    return loss_value, weights, opt_state

  # TODO: apply jax.jit so the user don't have to.
  return step
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Container:
  """An empty, plain-Python attribute holder.

  Because it is not a ``torch.nn.Module``, objects assigned onto an instance
  are not registered as submodules of a module that holds the container.
  """
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ScannedModule(torch.nn.Module):
  """Runs a stack of layers as one ``jax.lax.scan`` over stacked weights.

  The first module of ``module_list`` is kept as a template. The state dicts
  of all modules are stacked along a new leading (layer) dimension and stored
  as parameters of this module. ``forward`` scans over those per-layer slices
  and evaluates the template with each slice via
  ``torch.func.functional_call``.

  NOTE(review): stacking assumes every module has the same state_dict keys
  and tensor shapes as the first one (``torch.stack`` requires equal shapes);
  TODO confirm callers guarantee this.
  """

  def __init__(self, module_list, checkpoint_policy=None):
    """
    Args:
      module_list: non-empty sequence of modules, in execution order.
      checkpoint_policy: forwarded as ``policy`` to
        ``interop.gradient_checkpoint`` around each layer's evaluation.
    """
    super().__init__()
    assert module_list
    # The template lives on a plain Container (not an nn.Module), so it is not
    # registered as a submodule and its own weights stay out of this module.
    self.c = Container()
    self.c.one_mod = module_list[0]
    self.checkpoint_policy = checkpoint_policy

    stacked = self._stack_layer_weights(module_list)
    self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
    renamed = {}
    for name, tensor in stacked.items():
      renamed[self._param_name_new(name)] = tensor
    self.params = torch.nn.ParameterDict(renamed)

  def _stack_layer_weights(self, module_list):
    """Stack same-named state_dict tensors across layers.

    Every [n, m] tensor becomes [k, n, m], where k is the number of layers.
    Keys keep the order in which they are first seen.
    """
    per_key = {}
    for layer in module_list:
      for name, tensor in layer.state_dict().items():
        per_key.setdefault(name, []).append(tensor)
    return {name: torch.stack(tensors) for name, tensors in per_key.items()}

  def _param_name_new(self, old):
    """Map a dotted state_dict key to a name usable as a parameter key.

    Registered parameter names may not contain '.', so each '.' becomes '___'.
    """
    return old.replace(".", "___")

  def _param_name_old(self, new):
    """Inverse of ``_param_name_new``: turn every '___' back into '.'."""
    return new.replace("___", ".")

  def forward(self, *args, **kwargs):
    """Apply every stacked layer in sequence and return the final hidden state.

    ``args[0]`` is the hidden state threaded from layer to layer. Each layer
    is called with all of ``args``, and its output replaces ``args[0]`` while
    the remaining positional arguments are passed along unchanged. Keyword
    arguments are not supported.
    """
    assert not kwargs
    stacked_weights = {
        key: self.params[self._param_name_new(key)]
        for key in self.layer_weights_keys
    }
    scan = interop.torch_view(jax.lax.scan)

    def eval_one_layer(carry, layer_weights):
      # Only the first element of the carry is replaced by the layer output.
      _, *passthrough = carry
      new_hidden = torch.func.functional_call(
          self.c.one_mod, layer_weights, carry)
      return (new_hidden, *passthrough), None

    checkpointed = interop.gradient_checkpoint(
        eval_one_layer,
        kwargs={"policy": self.checkpoint_policy},
    )
    final_carry, _ = scan(checkpointed, args, stacked_weights)
    return final_carry[0]
|
torchax/types.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Callable, Any, Union, ParamSpec, TypeAlias
|
|
16
|
+
import torch
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
import sys
|
|
20
|
+
|
|
21
|
+
# ParamSpec shared by the callable aliases below, so their parameter lists
# stay generic.
P = ParamSpec('P')

# Values on the torch side: a tensor, a dtype, a TorchCallable (quoted as a
# forward reference because it is defined on the next line), or anything
# else. Including `Any` makes this union effectively unconstrained for type
# checkers.
TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]
# A callable with arbitrary parameters that returns a TorchValue.
TorchCallable: TypeAlias = Callable[P, TorchValue]
# JAX-side counterparts of the two aliases above.
JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any]
JaxCallable: TypeAlias = Callable[P, JaxValue]
|
torchax/util.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Callable
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def partition(original: list[Any],
              func: Callable[[Any], bool]) -> tuple[list[Any], list[Any]]:
  """Split elements into two same-length lists according to a predicate.

  ``func`` is called exactly once per element, in order. Position ``i`` of
  the first returned list holds ``original[i]`` when ``func`` was truthy for
  it and ``None`` otherwise. The second list is the complement: it holds the
  element where ``func`` was falsy and ``None`` elsewhere.

  This is handy for marking a subset of arguments as static
  (``static_argnums``) or donated (``donate_argnums``) with ``jax.jit`` and
  friends; ``merge`` reverses the split.

  Args:
    original: The elements to partition.
    func: Predicate applied to each element; only its truthiness matters.

  Returns:
    A ``(truthy, falsy)`` tuple of lists, each ``len(original)`` long, using
    ``None`` as the placeholder.

  Example:
    >>> def is_even(n): return n % 2 == 0
    >>> partition([1, 2, 3, 4, 5, 6], is_even)
    ([None, 2, None, 4, None, 6], [1, None, 3, None, 5, None])
  """
  items = list(original)
  flags = [bool(func(item)) for item in items]
  truthy = [item if flag else None for item, flag in zip(items, flags)]
  falsy = [None if flag else item for item, flag in zip(items, flags)]
  return truthy, falsy
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def merge(list1: list[Any], list2: list[Any]) -> list[Any]:
  """Combine two equal-length lists position by position.

  Each output element is the one from ``list1`` unless it is ``None``, in
  which case the element at the same position in ``list2`` is used.

  Invariant: ``merge(*partition(input_list, predicate)) == input_list`` for
  any predicate.

  Args:
    list1: Primary list; its non-None elements win.
    list2: Fallback list, consulted where ``list1`` holds ``None``.

  Returns:
    A new list with the merged elements.

  Raises:
    AssertionError: If ``list1`` and ``list2`` differ in length.

  Example:
    >>> merge([1, None, 3, None], [None, 2, None, 4])
    [1, 2, 3, 4]
    >>> merge([None, 'b', None], ['a', None, 'c'])
    ['a', 'b', 'c']
  """
  assert len(list1) == len(list2)
  return [
      fallback if primary is None else primary
      for primary, fallback in zip(list1, list2)
  ]
|
torchax/view.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.utils._pytree as torch_pytree
|
|
17
|
+
import jax
|
|
18
|
+
from enum import Enum
|
|
19
|
+
from typing import Union, List, Tuple, Optional, Any, cast
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
|
|
22
|
+
# Reference to original PyTorch native functions
|
|
23
|
+
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ViewInfoType(Enum):
  """Tag stored on each ViewInfo identifying which kind of view it describes."""
  # Default tag used by ViewInfo.__init__ when no kind is given.
  INVALID = 0
  # Used by NarrowInfo (basic slicing).
  NARROW = 1
  # NO_OP, PERMUTE, RESHAPE and RESIZE are not used by any ViewInfo subclass
  # in this module.
  NO_OP = 2
  PERMUTE = 3
  RESHAPE = 4
  RESIZE = 5
  # Used by SelectInfo.
  SELECT = 6
  # Used by AsStridedInfo.
  AS_STRIDED = 7
  # Used by DiagonalInfo.
  DIAGONAL = 8
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ViewInfo(ABC):
  """Interface shared by every view transformation.

  A concrete subclass describes one step from a parent array to a view of it.
  It must implement three operations:
    * ``transform_tensor``: derive the view from the parent.
    * ``update_tensor``: write new view values back into the parent.
    * ``calculate_output_shape``: report the view's shape.
  """

  def __init__(
      self,
      view_info_type: ViewInfoType = ViewInfoType.INVALID,
  ):
    """Remember which kind of view this object describes.

    Args:
      view_info_type: Tag identifying the view operation.
    """
    self.view_info_type = view_info_type

  @abstractmethod
  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    """Write ``new_value`` into the part of ``jax_array`` this view covers.

    Args:
      new_value: Values the view should take.
      jax_array: Parent array to update.

    Returns:
      The updated parent array.
    """

  @abstractmethod
  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    """Derive this view from ``jax_array``.

    Args:
      jax_array: Parent array.

    Returns:
      The transformed array.
    """

  @abstractmethod
  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    """Compute the shape this view has when applied to ``source``.

    Args:
      source: Parent array before the transformation.

    Returns:
      Shape of the transformed array.
    """
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class NarrowInfo(ViewInfo):
  """View produced by basic slicing, e.g. ``tensor[1:3, :, 2:5:2]``."""

  def __init__(self, slices: Union[slice, Tuple[slice]]) -> None:
    """
    Args:
      slices: Index expression for the view. ``jax_array[slices]`` reads the
        view and ``jax_array.at[slices]`` addresses it for updates.
    """
    super().__init__(ViewInfoType.NARROW)
    self.slices = slices

  def __eq__(self, other: object) -> bool:
    return isinstance(other, NarrowInfo) and self.slices == other.slices

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    try:
      view = jax_array[self.slices]
    except IndexError as err:
      raise IndexError("Invalid slice operation") from err
    return view

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    target = jax_array.at[self.slices]
    return target.set(new_value)

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    sliced = source[self.slices]
    return sliced.shape
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class SelectInfo(ViewInfo):
  """
  Represents a selection operation on a tensor.
  Typically used for indexing operations that select specific elements.

  NOTE(review): only construction and equality are implemented. All three
  ViewInfo operations raise NotImplementedError.
  """

  def __init__(self,
               dim: int = 0,
               start: int = 0,
               end: int = 0,
               stride: int = 0) -> None:
    """
    Args:
      dim, start, end, stride: Stored as given and compared by ``__eq__``.
        Presumably these are the dimension, start index, end index and step
        of the selection; TODO confirm once the transformation is implemented.
    """
    super().__init__(ViewInfoType.SELECT)
    self.dim: int = dim
    self.start: int = start
    self.end: int = end
    self.stride: int = stride

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, SelectInfo):
      return False
    return (self.dim == other.dim and self.start == other.start and
            self.end == other.end and self.stride == other.stride)

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("SelectInfo.transform_tensor not implemented")

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("SelectInfo.update_tensor not implemented")

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    raise NotImplementedError(
        "SelectInfo.calculate_output_shape not implemented")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class AsStridedInfo(ViewInfo):
  """
  Information for as_strided operations.

  NOTE(review): only construction and equality are implemented. All three
  ViewInfo operations raise NotImplementedError.
  """

  def __init__(self, stride: List[int], offset: int = 0) -> None:
    """
    Args:
      stride: Per-dimension strides, stored as given.
      offset: Storage offset, stored as given.
    """
    super().__init__(ViewInfoType.AS_STRIDED)
    self.stride: List[int] = stride
    self.offset: int = offset

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, AsStridedInfo):
      return False
    return self.offset == other.offset and self.stride == other.stride

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("AsStridedInfo.transform_tensor not implemented")

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("AsStridedInfo.update_tensor not implemented")

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    raise NotImplementedError(
        "AsStridedInfo.calculate_output_shape not implemented")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class DiagonalInfo(ViewInfo):
  """
  Information for diagonal operations.
  Extracts diagonal elements from a tensor.

  NOTE(review): only construction and equality are implemented. All three
  ViewInfo operations raise NotImplementedError.
  """

  def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None:
    """
    Args:
      offset: Offset from the main diagonal.
      dim1: First dimension for diagonal extraction.
      dim2: Second dimension for diagonal extraction.
    """
    super().__init__(ViewInfoType.DIAGONAL)
    self.offset: int = offset
    self.dim1: int = dim1
    self.dim2: int = dim2

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, DiagonalInfo):
      return False
    return (self.offset == other.offset and self.dim1 == other.dim1 and
            self.dim2 == other.dim2)

  def transform_tensor(self, jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("DiagonalInfo.transform_tensor not implemented")

  def update_tensor(self, new_value: jax.Array,
                    jax_array: jax.Array) -> jax.Array:
    raise NotImplementedError("DiagonalInfo.update_tensor not implemented")

  def calculate_output_shape(self, source: jax.Array) -> List[int]:
    raise NotImplementedError(
        "DiagonalInfo.calculate_output_shape not implemented")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class View(torch.Tensor):
  """
  A View is a reference to another Tensor or another View,
  with a transformation applied to it.

  Only the parent and the ViewInfo deriving this view from it are stored.
  The view's values are recomputed from the root (source) array on every
  ``jax()`` call. Writes through ``update`` / ``__setitem__`` are pushed back
  into the source tensor's ``_elem``.
  """

  @staticmethod
  def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo,
              env: Any) -> "View":
    """Create the torch-side wrapper, which carries only shape and dtype.

    Args:
      parent: Parent tensor or view.
      view_info: Information about the view transformation.
      env: Environment for tensor operations.
    """
    shape = view_info.calculate_output_shape(parent.jax())
    return torch.Tensor._make_wrapper_subclass(
        cls,
        shape,
        device="meta",
        dtype=parent.dtype,
        requires_grad=False,
    )

  def __init__(self, parent: Union["torchax.Tensor", "View"],
               view_info: ViewInfo, env: Any) -> None:
    super().__init__()
    self.parent = parent
    self.view_info = view_info
    self._env = env

  def get_transformation_chain(self) -> List[ViewInfo]:
    """
    Get all view transformations from the source tensor to this view.

    The list is ordered from the transformation applied to the source first
    to this view's own ``view_info`` last. A new list is returned on every
    call.
    """
    if isinstance(self.parent, View):
      transformations = self.parent.get_transformation_chain()
      transformations.append(self.view_info)
      return transformations
    else:
      return [self.view_info]

  __torch_function__ = torch._C._disabled_torch_function_impl

  def source_jax(self) -> jax.Array:
    """
    Returns the JAX array of the root (non-View) tensor of this view chain.
    """
    if isinstance(self.parent, View):
      return self.parent.source_jax()
    else:
      return self.parent.jax()

  def replace_source_jax(self, new_value: jax.Array) -> None:
    """
    Replace the root tensor's array with ``new_value``.

    ``new_value`` must have the same shape as the array it replaces.
    """
    if isinstance(self.parent, View):
      self.parent.replace_source_jax(new_value)
    else:
      assert new_value.shape == self.parent._elem.shape
      self.parent._elem = new_value

  def torch(self) -> "torchax.Tensor":
    """
    Returns a Torchax tensor representing this view after all transformations.
    """
    from torchax.tensor import Tensor

    return Tensor(self.jax(), self._env)

  def update(
      self,
      new_values: Union[jax.Array, "View", "torchax.Tensor"],
      view_infos: Optional[List[ViewInfo]] = None,
  ) -> None:
    """
    Update this view with new values, propagating changes back to the source.

    If view_infos is None, the transformation chain from the source tensor to
    this view is used.
    """
    if view_infos is None:
      view_infos = self.get_transformation_chain()

    # Get the source JAX array
    source_array = self.source_jax()

    # Unwrap torch-side containers to their JAX arrays.
    from torchax.tensor import Tensor

    if isinstance(new_values, (View, Tensor)):
      new_values = new_values.jax()

    # Record the input of every transformation: intermediate_values[i] is the
    # array that view_infos[i] is applied to.
    intermediate_values = [source_array]
    for view_info in view_infos[:-1]:
      intermediate_values.append(
          view_info.transform_tensor(intermediate_values[-1]))

    # TODO: Investigate efficiency of this algorithm
    # Walk the chain backwards, writing each level's new values into its
    # parent array until the source array itself has been rebuilt.
    for view_info, parent_array in zip(
        reversed(view_infos), reversed(intermediate_values)):
      new_values = view_info.update_tensor(new_values, parent_array)

    # Update the source tensor with the new values
    self.replace_source_jax(new_values)

  @classmethod
  def __torch_dispatch__(
      cls,
      func: Any,
      types: Tuple[Any, ...],
      args: Tuple[Any, ...] = (),
      kwargs: Optional[dict] = None,
  ) -> Any:
    # Views are not meant to be dispatched directly; always refuse.
    raise AssertionError(
        'torchax Tensors can only do math within the torchax environment. '
        'Please wrap your code with `with torchax.default_env()` or '
        'call torchax.enable_globally() before.')

  def create_sub_view(self, view_info: ViewInfo) -> "View":
    """
    Create a new view that is a child of this view.
    """
    return View(self, view_info, self._env)

  def __str__(self) -> str:
    return f"View({self.torch()})"

  def jax(self) -> jax.Array:
    """
    Returns the source array with every transformation of the chain applied.
    """
    result = self.source_jax()
    for view_info in self.get_transformation_chain():
      result = view_info.transform_tensor(result)
    return result

  def __setitem__(self, indexes, val):
    # Writing to view[indexes] is an update through one extra slicing step.
    view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)]
    self.update(view_infos=view_infos, new_values=val)

  def dim(self):
    return self.ndim

  @property
  def device(self):
    return torch.device("jax:0")

  @property
  def jax_device(self):
    return self.jax().device

  @property
  def ndim(self):
    return len(self.shape)

  __repr__ = __str__
|