torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202612.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251116.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/licenses/LICENSE +0 -0
torchax/checkpoint.py
CHANGED
@@ -12,32 +12,39 @@
  12   12     # See the License for the specific language governing permissions and
  13   13     # limitations under the License.
  14   14
  15         - import torch
  16   15     import os
  17         - from typing import Any
  18         -
       16    + from typing import Any
       17    +
  19   18     import jax
  20   19     import jax.numpy as jnp
  21   20     import numpy as np
       21    + import torch
       22    + from flax.training import checkpoints
       23    +
  22   24     from . import tensor
  23   25
       26    +
  24   27     def _to_jax(pytree):
  25   28         def to_jax_array(x):
  26   29             if isinstance(x, tensor.Tensor):
  27         -
       30    +             return x.jax()
  28   31             elif isinstance(x, torch.Tensor):
  29         -
       32    +             return jnp.asarray(x.cpu().numpy())
  30   33             return x
       34    +
  31   35         return jax.tree_util.tree_map(to_jax_array, pytree)
  32   36
  33   37
  34   38     def _to_torch(pytree):
  35   39         return jax.tree_util.tree_map(
  36   40             lambda x: torch.from_numpy(np.asarray(x))
  37         -       if isinstance(x, (jnp.ndarray, jax.Array))
       41    +         if isinstance(x, (jnp.ndarray, jax.Array))
       42    +         else x,
       43    +         pytree,
       44    +     )
  38   45
  39   46
  40         - def save_checkpoint(state: Dict[str, Any], path: str, step: int):
       47    + def save_checkpoint(state: dict[str, Any], path: str, step: int):
  41   48         """Saves a checkpoint to a file in JAX style.
  42   49
  43   50         Args:
@@ -50,7 +57,7 @@ def save_checkpoint(state: Dict[str, Any], path: str, step: int):
  50   57         checkpoints.save_checkpoint(path, state, step=step, overwrite=True)
  51   58
  52   59
  53         - def load_checkpoint(path: str) -> Dict[str, Any]:
       60    + def load_checkpoint(path: str) -> dict[str, Any]:
  54   61         """Loads a checkpoint and returns it in JAX format.
  55   62
  56   63         This function can load both PyTorch-style (single file) and JAX-style
@@ -76,4 +83,3 @@ def load_checkpoint(path: str) -> Dict[str, Any]:
  76   83             return _to_jax(state)
  77   84         else:
  78   85             raise FileNotFoundError(f"No such file or directory: {path}")
  79         -
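Note: the visible API change in checkpoint.py is the move from typing Dict[str, Any] to the builtin dict[str, Any] annotations (plus import reshuffling); save_checkpoint still converts the torch pytree to jax arrays before writing a Flax-style checkpoint, and load_checkpoint still returns the state as a JAX pytree. A minimal usage sketch; the path and state contents are illustrative, not taken from the package:

import torch
from torchax import checkpoint

# Any pytree of torch.Tensors works; leaves are converted to jax arrays
# (via _to_jax) before flax.training.checkpoints writes the files.
state = {"weight": torch.randn(4, 4), "step": torch.tensor(0)}
checkpoint.save_checkpoint(state, "/tmp/torchax_ckpt", step=0)

# The loaded pytree comes back with jax.Array leaves, not torch.Tensors.
restored = checkpoint.load_checkpoint("/tmp/torchax_ckpt")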
torchax/config.py
CHANGED
@@ -35,10 +35,6 @@ class Configuration:
  35   35         # Use DLPack for converting jax.Arrays <-> and torch.Tensor
  36   36         use_dlpack_for_data_conversion: bool = False
  37   37
  38         -     # Flash attention
  39         -     use_tpu_flash_attention: bool = False
  40         -     shmap_flash_attention: bool = False
  41         -
  42   38         # device
  43   39         treat_cuda_as_jax_device: bool = True
  44   40         internal_respect_torch_return_dtypes: bool = False
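Note: only the two flash-attention fields are dropped from Configuration; the remaining fields keep their defaults, so downstream code that toggled the removed flags should simply delete those assignments. A hedged sketch of the surviving knobs, assuming the torchax.default_env() accessor and its .config attribute behave as in previous releases:

import torchax

env = torchax.default_env()
# use_tpu_flash_attention and shmap_flash_attention no longer exist on Configuration.
env.config.use_dlpack_for_data_conversion = False
env.config.treat_cuda_as_jax_device = True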
torchax/decompositions.py
CHANGED
@@ -22,21 +22,21 @@ Can also contain decompositions of a torch op in terms of other torch ops.
  22   22     """
  23   23
  24   24     import functools
  25         - from
       25    + from collections.abc import Callable
       26    + from typing import Any
  26   27
  27   28     import torch
  28         - from torch import Tensor
  29   29     import torch._decomp as decomp
  30         - from torch._decomp import decompositions_for_rng
  31         - from torch._decomp import register_decomposition
  32   30     import torch._prims_common as utils
       31    + from torch import Tensor
       32    + from torch._decomp import decompositions_for_rng, register_decomposition
  33   33     from torch._prims_common.wrappers import out_wrapper
  34   34
  35   35     DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
  36   36
  37   37     # None of these functions are publicly accessible; get at them
  38   38     # from torch._decomps
  39         - __all__:
       39    + __all__: list[str] = []
  40   40
  41   41     aten = torch._ops.ops.aten
  42   42
@@ -44,21 +44,21 @@ aten = torch._ops.ops.aten
  44   44     def _try_register(op, impl):
  45   45         try:
  46   46             register_decomposition(op)(impl)
  47         -
       47    +
       48    +     except Exception:
  48   49             pass
  49   50
  50   51
  51   52     @out_wrapper()
  52         - def _reflection_pad(a: Tensor, padding:
  53         -
       53    + def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
  54   54         def idx(left, middle, right):
  55   55             dim_idx = torch.arange(-left, middle + right, device=a.device)
  56   56             return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
  57   57
  58   58         return _reflection_or_replication_pad(
  59         -
  60         -
  61         -
       59    +         a,
       60    +         padding,
       61    +         idx,
  62   62         )
  63   63
  64   64
@@ -68,32 +68,32 @@ _try_register(aten.reflection_pad3d, _reflection_pad)
  68   68
  69   69
  70   70     @out_wrapper()
  71         - def _replication_pad(a: Tensor, padding:
  72         -
       71    + def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
  73   72         def idx(left, middle, right):
  74   73             dim_idx = torch.arange(-left, middle + right, device=a.device)
  75   74             return torch.clamp(dim_idx, 0, middle - 1)
  76   75
  77   76         return _reflection_or_replication_pad(
  78         -
  79         -
  80         -
       77    +         a,
       78    +         padding,
       79    +         idx,
  81   80         )
  82   81
  83   82
  84         - decomp.global_decomposition_table["post_autograd"][
  85         -
       83    + decomp.global_decomposition_table["post_autograd"][aten.replication_pad2d.default] = (
       84    +     _replication_pad
       85    + )
  86   86
  87   87
  88   88     def _reflection_or_replication_pad(
  89         -
  90         -
  91         -
       89    +     a: Tensor,
       90    +     padding: tuple[int, ...],
       91    +     idx_fn: Callable[[int, int, int], Tensor],
  92   92     ) -> Tensor:
  93   93         dim = len(padding) // 2
  94   94         torch._check(
  95         -
  96         -
       95    +         a.dim() in (dim + 1, dim + 2),
       96    +         lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
  97   97         )
  98   98         inp_shape = a.shape[-dim:]
  99   99         nc_dim = a.dim() - dim
@@ -103,7 +103,7 @@ def _reflection_or_replication_pad(
 103  103
 104  104         result = a
 105  105         for i in range(dim):
 106         -         idx:
      106    +         idx: list[Any] = [None] * result.dim()
 107  107             idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
 108  108             result = aten._unsafe_index(result, idx)
 109  109
@@ -158,11 +158,11 @@ def _sum_tensors(ts) -> Tensor:
 158  158
 159  159     @register_decomposition(aten.grid_sampler_3d)
 160  160     def _grid_sampler_3d(
 161         -
 162         -
 163         -
 164         -
 165         -
      161    +     a: torch.Tensor,
      162    +     grid: torch.Tensor,
      163    +     interpolation_mode: int = 0,
      164    +     padding_mode: int = 0,
      165    +     align_corners: bool = False,
 166  166     ) -> Tensor:
 167  167         """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
 168  168
@@ -170,11 +170,12 @@ def _grid_sampler_3d(
 170  170         """
 171  171         _expand_grid = False
 172  172         torch._check(
 173         -
 174         -
      173    +         interpolation_mode in (0, 1),
      174    +         lambda: f"Invalid interpolation mode {interpolation_mode}",
 175  175         )
 176  176         torch._check(
 177         -
      177    +         padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
      178    +     )
 178  179
 179  180         # a is 5D: [B, C, D, H, W]
 180  181
@@ -189,8 +190,7 @@ def _grid_sampler_3d(
 189  190         # Reflects coordinates until they fall between low and high (inclusive).
 190  191         # The bounds are passed as twice their value so that half-integer values
 191  192         # can be represented as ints.
 192         -     def reflect_coordinates(coords: Tensor, twice_low: int,
 193         -                             twice_high: int) -> Tensor:
      193    +     def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
 194  194             if twice_low == twice_high:
 195  195                 return torch.zeros_like(coords)
 196  196             coords_min = twice_low / 2
@@ -198,8 +198,9 @@ def _grid_sampler_3d(
 198  198             coords2 = (coords - coords_min).abs()
 199  199             extra = torch.fmod(coords2, coords_span)
 200  200             flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
 201         -         return torch.where(
 202         -
      201    +         return torch.where(
      202    +             flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
      203    +         )
 203  204
 204  205         def compute_coordinates(coords: Tensor, size: int) -> Tensor:
 205  206             if padding_mode == 0: # Zero
@@ -219,7 +220,7 @@ def _grid_sampler_3d(
 219  220
 220  221         N, C, iD, iH, iW = a.shape
 221  222         _, oD, oH, oW, three = grid.shape
 222         -     assert three == 3, "Last dim of grid must be 3. got {}"
      223    +     assert three == 3, f"Last dim of grid must be 3. got {three}"
 223  224
 224  225         def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
 225  226             xcheck = torch.logical_and(0 <= xs, xs < iW)
@@ -238,15 +239,16 @@ def _grid_sampler_3d(
 238  239             # broadcasting with N_idx, C_idx for the purposes of advanced indexing
 239  240             c = C if _expand_grid else 1
 240  241             return tuple(
 241         -
 242         -
 243         -
 244         -
 245         -
 246         -
 247         -
 248         -
 249         -
      242    +             torch.where(cond, t, 0).view(N, c, oD, oH, oW)
      243    +             for t in (
      244    +                 xs.to(dtype=torch.int64),
      245    +                 ys.to(dtype=torch.int64),
      246    +                 zs.to(dtype=torch.int64),
      247    +                 ws,
      248    +             )
      249    +         )
      250    +
      251    +     def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor:
 250  252             # Perform clipping, index into input tensor and multiply by weight
 251  253             idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
 252  254             return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_
@@ -279,16 +281,18 @@ def _grid_sampler_3d(
 279  281             w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)
 280  282
 281  283             return _sum_tensors(
 282         -
 283         -
 284         -
 285         -
 286         -
 287         -
 288         -
 289         -
 290         -
 291         -         )
      284    +             get_summand(ix, iy, id_, w)
      285    +             for (ix, iy, id_, w) in (
      286    +                 (ix_nwf, iy_nwf, id_nwf, w_nwf),
      287    +                 (ix_nef, iy_nef, id_nef, w_nef),
      288    +                 (ix_swf, iy_swf, id_swf, w_swf),
      289    +                 (ix_sef, iy_sef, id_sef, w_sef),
      290    +                 (ix_nwb, iy_nwb, id_nwb, w_nwb),
      291    +                 (ix_neb, iy_neb, id_neb, w_neb),
      292    +                 (ix_swb, iy_swb, id_swb, w_swb),
      293    +                 (ix_seb, iy_seb, id_seb, w_seb),
      294    +             )
      295    +         )
 292  296         else: # interpolation_mode == 1: # Nearest
 293  297             ix = compute_source_index(x, iW)
 294  298             iy = compute_source_index(y, iH)
@@ -301,7 +305,8 @@ def _grid_sampler_3d(
 301  305             return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)
 302  306
 303  307
 304         - DECOMPOSITIONS = decomp.get_decompositions(
      308    + DECOMPOSITIONS = decomp.get_decompositions(
      309    +     [
 305  310             torch.ops.aten.upsample_bicubic2d,
 306  311             torch.ops.aten.upsample_nearest1d,
 307  312             torch.ops.aten.upsample_nearest2d,
@@ -782,9 +787,10 @@ DECOMPOSITIONS = decomp.get_decompositions([
 782  787             torch.ops.aten.__irshift__.Tensor,
 783  788             torch.ops.aten.__irshift__.Scalar,
 784  789             torch.ops.aten.__ior__.Tensor,
 785         - ]
      790    +     ]
      791    + )
 786  792
 787  793     MUTABLE_DECOMPOSITION = [
 788         -
 789         -
      794    +     torch.ops.aten.bernoulli_.Tensor,
      795    +     torch.ops.aten.bernoulli_.float,
 790  796     ]
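Note: the changes above are import reordering and formatter reflow; DECOMPOSITIONS is still the table of aten ops (upsample_*, grid_sampler_3d, reflection/replication pads, in-place bit ops, ...) that torchax feeds into ExportedProgram.run_decompositions, as the export.py diff below shows. A small sketch of that pattern, with a made-up module purely for illustration:

import torch
from torchax import decompositions

class Upsample(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.interpolate(x, scale_factor=2, mode="bicubic")

exported = torch.export.export(Upsample(), (torch.randn(1, 3, 8, 8),))
# Rewrites ops such as aten.upsample_bicubic2d into primitives torchax can lower to JAX.
exported = exported.run_decompositions(decompositions.DECOMPOSITIONS)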
torchax/export.py
CHANGED
@@ -14,17 +14,21 @@
  14   14
  15   15     # pylint: disable
  16   16     """Utilities for exporting a torch program to jax/stablehlo."""
       17    +
  17   18     import copy
  18         - from typing import Any
       19    + from typing import Any
       20    +
       21    + import jax
       22    + import jax.export
       23    + import sympy
  19   24     import torch
       25    + import torch._refs
       26    + from torch._decomp import get_decompositions
  20   27     from torch.utils import _pytree as pytree
       28    +
  21   29     import torchax
  22         - from torchax import tensor
  23         - from torchax.ops import ops_registry, mappings
  24   30     from torchax import decompositions
  25         - import
  26         - import jax.export
  27         - import sympy
       31    + from torchax.ops import mappings, ops_registry
  28   32
  29   33     DEBUG = False
  30   34
@@ -34,16 +38,13 @@ class JaxInterpreter(torch.fx.Interpreter):
  34   38
  35   39         def __init__(self, graph_module):
  36   40             super().__init__(graph_module)
  37         -         import torchax.ops.jaten
  38         -         import torchax.ops.jtorch
  39   41
  40         -     def call_function(self, target, args:
  41         -         if not isinstance(target,
  42         -                           (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
       42    +     def call_function(self, target, args: tuple, kwargs: dict) -> Any:
       43    +         if not isinstance(target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
  43   44                 return super().call_function(target, args, kwargs)
  44   45
  45   46             if DEBUG:
  46         -             print(
       47    +             print("Running ", target.name(), "--------")
  47   48
  48   49             op = ops_registry.all_aten_ops.get(target)
  49   50             if op is None:
@@ -54,31 +55,30 @@ class JaxInterpreter(torch.fx.Interpreter):
  54   55                 op = ops_registry.all_aten_ops.get(target.overloadpacket)
  55   56             if op is None:
  56   57                 print(target.name(), target.tags)
  57         -             raise RuntimeError(
       58    +             raise RuntimeError("No lowering found for", target.name())
  58   59             return op.func(*args, **kwargs)
  59   60
  60   61         def run_node(self, n) -> Any:
  61   62             res = super().run_node(n)
  62   63             if DEBUG:
  63         -             if n.op ==
  64         -                 if hasattr(res,
  65         -                     print(
       64    +             if n.op == "call_function":
       65    +                 if hasattr(res, "shape"):
       66    +                     print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape)
  66   67             return res
  67   68
  68   69
  69         - from torch._decomp import get_decompositions
  70         - import torch._refs
  71         -
  72   70     _extra_decomp = get_decompositions([torch.ops.aten.unfold])
  73   71
  74   72
  75   73     def _extract_states_from_exported_program(exported_model):
  76   74         # NOTE call convention: (parameters, buffers, user_inputs)
  77         -     param_and_buffer_keys =
       75    +     param_and_buffer_keys = (
       76    +         exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
       77    +     )
  78   78         state_dict = copy.copy(exported_model.state_dict)
  79         -     if (constants := getattr(exported_model,
       79    +     if (constants := getattr(exported_model, "constants", None)) is not None:
  80   80             state_dict.update(constants)
  81         -     param_buffer_values =
       81    +     param_buffer_values = [state_dict[key] for key in param_and_buffer_keys]
  82   82
  83   83         if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
  84   84             for name in exported_model.graph_signature.lifted_tensor_constants:
@@ -94,19 +94,19 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
  94   94
  95   95         func(state, input) would be how you call it.
  96   96         """
  97         -     if torch.__version__ >=
       97    +     if torch.__version__ >= "2.2":
  98   98             # torch version 2.1 didn't expose this yet
  99   99             exported_program = exported_program.run_decompositions()
 100  100             exported_program = exported_program.run_decompositions(
 101         -
      101    +             decompositions.DECOMPOSITIONS
      102    +         )
 102  103         if DEBUG:
 103  104             print(exported_program.graph_module.code)
 104  105
 105  106         names, states = _extract_states_from_exported_program(exported_program)
 106  107
 107  108         def _extract_args(args, kwargs):
 108         -         flat_args, received_spec = pytree.tree_flatten(
 109         -             (args, kwargs)) # type: ignore[possibly-undefined]
      109    +         flat_args, received_spec = pytree.tree_flatten((args, kwargs)) # type: ignore[possibly-undefined]
 110  110             return flat_args
 111  111
 112  112         num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
@@ -114,9 +114,9 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
 114  114         def func(states, inputs):
 115  115             args = _extract_args(inputs, {})
 116  116             res = JaxInterpreter(exported_program.graph_module).run(
 117         -
 118         -
 119         -
      117    +             *states,
      118    +             *args,
      119    +             enable_io_processing=False,
 120  120             )
 121  121             res = res[num_mutations:]
 122  122             return res
@@ -134,21 +134,19 @@ def extract_avals(exported):
 134  134         """
 135  135
 136  136         def _to_aval(arg_meta, symbolic_shapes):
 137         -         """Convet from torch type to jax abstract value for export tracing
 138         -         """
      137    +         """Convet from torch type to jax abstract value for export tracing"""
 139  138
 140  139             def _get_dim(d):
 141  140                 if isinstance(d, torch.SymInt):
 142  141                     return symbolic_shapes[str(d)]
 143  142                 return d
 144  143
 145         -         val = arg_meta[
 146         -         is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
 147         -             val, bool)
      144    +         val = arg_meta["val"]
      145    +         is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
 148  146             if is_scalar:
 149         -             return jax.ShapeDtypeStruct([], type(arg_meta[
      147    +             return jax.ShapeDtypeStruct([], type(arg_meta["val"]))
 150  148
 151         -         tensor_meta = arg_meta[
      149    +         tensor_meta = arg_meta["tensor_meta"]
 152  150             shape = [_get_dim(d) for d in tensor_meta.shape]
 153  151             return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
 154  152
@@ -156,8 +154,9 @@ def extract_avals(exported):
 156  154         """Return placeholders with input metadata"""
 157  155         placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
 158  156         input_placeholders = [
 159         -
 160         -
      157    +         p
      158    +         for p, s in zip(placeholders, exported.graph_signature.input_specs, strict=False)
      159    +         if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
 161  160         ]
 162  161         return input_placeholders
 163  162
@@ -179,17 +178,22 @@ def extract_avals(exported):
 179  178             torch.export.Dim("a", min=5, max=10)
 180  179             ==> ("a >= 5", "a <= 10",)
 181  180         """
 182         -     if
 183         -
 184         -
 185         -
      181    +     if (
      182    +         not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges)
      183    +         or torch_constraint.is_bool
      184    +     ):
      185    +         raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")
 186  186
 187  187         constraints = []
 188  188         symbol = sympy.Symbol(symbol_name)
 189  189         if torch_constraint.lower != 2:
 190  190             constraints.append(symbol >= torch_constraint.lower)
 191  191         from sympy.core.singleton import S
 192         -
      192    +
      193    +     if (
      194    +         not torch_constraint.upper.is_infinite
      195    +         and torch_constraint.upper is not S.IntInfinity
      196    +     ):
 193  197             constraints.append(symbol <= torch_constraint.upper)
 194  198
 195  199         return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)
@@ -208,8 +212,7 @@ def extract_avals(exported):
 208  212         symbol_name = str(sym)
 209  213         constraints = _build_symbolic_constraints(symbol_name, constraint)
 210  214         if sym.is_symbol:
 211         -         symbolic_shape = jax.export.symbolic_shape(
 212         -             symbol_name, constraints=constraints)
      215    +         symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
 213  216         else:
 214  217             assert len(sym.free_symbols) > 0
 215  218             scope = free_symbols[str(list(sym.free_symbols)[0])].scope
@@ -222,12 +225,8 @@ def extract_avals(exported):
 222  225         # integer compuations on symbol variables, so each symbol variable is OK to
 223  226         # have its own scope.
 224  227         symbolic_shapes = {}
 225         -     symbol_variables = [
 226         -
 227         -     ]
 228         -     symbol_exprs = [
 229         -         (s, v) for s, v in range_constraints.items() if not s.is_symbol
 230         -     ]
      228    +     symbol_variables = [(s, v) for s, v in range_constraints.items() if s.is_symbol]
      229    +     symbol_exprs = [(s, v) for s, v in range_constraints.items() if not s.is_symbol]
 231  230         for sym, constraint in symbol_variables + symbol_exprs:
 232  231             symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
 233  232             symbolic_shapes[str(sym)] = symbolic_shape
@@ -237,10 +236,10 @@ def extract_avals(exported):
 237  236         args = _get_inputs(exported)
 238  237
 239  238         if DEBUG:
 240         -         print(
 241         -         print(
      239    +         print("Inputs to aval:", args, "--------")
      240    +         print("Symbolic shapes:", symbolic_shapes)
 242  241             for arg in args:
 243         -             print(
      242    +             print("Meta2Aval", arg.meta, "--> ", _to_aval(arg.meta, symbolic_shapes))
 244  243
 245  244         return [_to_aval(arg.meta, symbolic_shapes) for arg in args]
 246  245
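Note: most of the export.py churn is mechanical (double-quoted strings, re-wrapped call sites, import ordering); the flow is unchanged: torch.export the model, run the torchax decompositions, then build a JAX callable that takes (states, inputs). A hedged sketch of that flow, assuming the (weights, function) return pairing implied by the docstring ("func(state, input) would be how you call it"); the module and shapes are invented:

import jax.numpy as jnp
import torch
import torchax.export

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

exported = torch.export.export(Linear(), (torch.randn(4, 8),))
weights, jax_func = torchax.export.exported_program_to_jax(exported)

# weights is a pytree of jax arrays; inputs are passed as a (possibly nested) tuple.
out = jax_func(weights, (jnp.ones((4, 8), dtype=jnp.float32),))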
torchax/flax.py
CHANGED
@@ -15,18 +15,19 @@
  15   15     """Flax interop."""
  16   16
  17   17     import torch
       18    +
  18   19     import torchax as tx
  19   20     import torchax.interop
  20   21
  21   22
  22   23     class FlaxNNModule(torch.nn.Module):
  23         -
  24   24         def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
  25   25             super().__init__()
  26   26             prng = env.prng_key
  27   27             sample_kwargs = sample_kwargs or {}
  28         -         parameter_dict = tx.interop.call_jax(
  29         -
       28    +         parameter_dict = tx.interop.call_jax(
       29    +             flax_module.init, prng, *sample_args, **sample_kwargs
       30    +         )
  30   31
  31   32             self._params = self._encode_nested_dict(parameter_dict)
  32   33
@@ -49,5 +50,6 @@ class FlaxNNModule(torch.nn.Module):
  49   50
  50   51         def forward(self, *args, **kwargs):
  51   52             nested_dict_params = self._decode_nested_dict(self._params)
  52         -         return tx.interop.call_jax(
  53         -
       53    +         return tx.interop.call_jax(
       54    +             self._flax_module.apply, nested_dict_params, *args, **kwargs
       55    +         )
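Note: the flax.py edits only re-wrap the two tx.interop.call_jax(...) call sites; FlaxNNModule still initialises the Flax module once from the sample args and replays flax_module.apply in forward. A hedged sketch of driving it (the Dense layer and shapes are illustrative, and torchax.default_env() providing an env with a prng_key is an assumption, not something this diff shows):

import flax.linen as nn
import torch
import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()   # assumed accessor for the torchax environment
flax_layer = nn.Dense(features=16)

with env:
    # sample args are used once so flax_layer.init can build the parameter pytree
    wrapped = FlaxNNModule(env, flax_layer, (torch.zeros(1, 8),))
    x = torch.randn(1, 8)   # jax-backed tensor, since it is created inside the env
    y = wrapped(x)          # runs flax_layer.apply through tx.interop.call_jax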