torchax 0.0.10.dev20251117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +153 -0
- torchax/amp.py +346 -0
- torchax/checkpoint.py +79 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +790 -0
- torchax/device_module.py +47 -0
- torchax/export.py +259 -0
- torchax/flax.py +53 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +234 -0
- torchax/ops/__init__.py +24 -0
- torchax/ops/jaten.py +5937 -0
- torchax/ops/jax_reimplement.py +185 -0
- torchax/ops/jc10d.py +66 -0
- torchax/ops/jimage.py +127 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +631 -0
- torchax/ops/jtorchvision_nms.py +248 -0
- torchax/ops/mappings.py +161 -0
- torchax/ops/op_base.py +145 -0
- torchax/ops/ops_registry.py +69 -0
- torchax/tensor.py +736 -0
- torchax/train.py +132 -0
- torchax/types.py +26 -0
- torchax/util.py +102 -0
- torchax/view.py +391 -0
- torchax-0.0.10.dev20251117.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251117.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251117.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251117.dist-info/licenses/LICENSE +201 -0
torchax/ops/jax_reimplement.py
ADDED
@@ -0,0 +1,185 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from jax._src.numpy.util import promote_dtypes_inexact
import numpy as np
import jax
from jax import numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src import core
from jax._src.image.scale import _kernels, ResizeMethod
from jax import lax
from typing import Callable

# TODO: This block of code needs to be revisited based on https://github.com/jax-ml/jax/issues/24106
# START ----------------- JAX code copied for fixing scale_and_translate -----------------------------

# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52


def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
                       scale, translation, kernel: Callable, antialias: bool):
  dtype = jnp.result_type(scale, translation)
  inv_scale = 1. / scale
  # When downsampling the kernel should be scaled since we want to low pass
  # filter and interpolate, but when upsampling it should not be since we only
  # want to interpolate.
  kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
  sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
              translation * inv_scale - 0.5)
  x = (
      jnp.abs(sample_f[jnp.newaxis, :] -
              jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
      kernel_scale)
  weights = kernel(x)

  total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
  weights = jnp.where(
      jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum,
                                    1)), 0)
  # Zero out weights where the sample location is completely outside the input
  # range.
  # Note sample_f has already had the 0.5 removed, hence the weird range below.

  # (barney-s) -------------- returning weights without zeroing ---------------------
  return weights
  input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
  return jnp.where(
      jnp.logical_and(sample_f >= -0.5, sample_f
                      <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
  # (barney-s) -------------- END returning weights without zeroing ---------------------


# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86


def _scale_and_translate(x, output_shape: core.Shape,
                         spatial_dims: Sequence[int], scale, translation,
                         kernel, antialias: bool, precision):
  input_shape = x.shape
  assert len(input_shape) == len(output_shape)
  assert len(spatial_dims) == len(scale)
  assert len(spatial_dims) == len(translation)
  if len(spatial_dims) == 0:
    return x
  contractions = []
  in_indices = list(range(len(output_shape)))
  out_indices = list(range(len(output_shape)))
  for i, d in enumerate(spatial_dims):
    d = canonicalize_axis(d, x.ndim)
    m = input_shape[d]
    n = output_shape[d]
    w = compute_weight_mat(m, n, scale[i], translation[i], kernel,
                           antialias).astype(x.dtype)
    contractions.append(w)
    contractions.append([d, len(output_shape) + i])
    out_indices[d] = len(output_shape) + i
  contractions.append(out_indices)
  return jnp.einsum(x, in_indices, *contractions, precision=precision)


# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172


# scale and translation here are scalar elements of an np.array, what is the
# correct type annotation?
def scale_and_translate(
    image,
    shape: core.Shape,
    spatial_dims: Sequence[int],
    scale,
    translation,
    # (barney-s) use string
    method: str,  # (barney-s) | ResizeMethod,
    antialias: bool = True,
    precision=lax.Precision.HIGHEST):
  """Apply a scale and translation to an image.

  Generates a new image of shape 'shape' by resampling from the input image
  using the sampling method corresponding to method. For 2D images, this
  operation transforms a location in the input image, (x, y), to a location
  in the output image according to::

    (x * scale[1] + translation[1], y * scale[0] + translation[0])

  (Note the *inverse* warp is used to generate the sample locations.)
  Assumes half-centered pixels, i.e. the pixel at integer location ``row, col``
  has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input
  image dimensions.

  If an output location (pixel) maps to an input sample location that is outside
  the input boundaries then the value for the output location will be set to
  zero.

  The ``method`` argument expects one of the following resize methods:

  ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``,
  ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses a
  triangular filter when downsampling.

  ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
  `Cubic interpolation`_, using the Keys cubic kernel.

  ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
  `Lanczos resampling`_, using a kernel of radius 3.

  ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
  `Lanczos resampling`_, using a kernel of radius 5.

  .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
  .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
  .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

  Args:
    image: a JAX array.
    shape: the output shape, as a sequence of integers with length equal to the
      number of dimensions of `image`.
    spatial_dims: A length K tuple specifying the spatial dimensions that the
      passed scale and translation should be applied to.
    scale: A [K] array with the same number of dimensions as image, containing
      the scale to apply in each dimension.
    translation: A [K] array with the same number of dimensions as image,
      containing the translation to apply in each dimension.
    method: the resizing method to use; either a ``ResizeMethod`` instance or a
      string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
    antialias: Should an antialiasing filter be used when downsampling? Defaults
      to ``True``. Has no effect when upsampling.

  Returns:
    The scaled and translated image.
  """
  shape = core.canonicalize_shape(shape)
  if len(shape) != image.ndim:
    msg = ('shape must have length equal to the number of dimensions of x; '
           f' {shape} vs {image.shape}')
    raise ValueError(msg)
  if isinstance(method, str):
    method = ResizeMethod.from_string(method)
  if method == ResizeMethod.NEAREST:
    # Nearest neighbor is currently special-cased for straight resize, so skip
    # for now.
    raise ValueError('Nearest neighbor resampling is not currently supported '
                     'for scale_and_translate.')
  assert isinstance(method, ResizeMethod)

  kernel = _kernels[method]
  image, = promote_dtypes_inexact(image)
  scale, translation = promote_dtypes_inexact(scale, translation)
  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                              kernel, antialias, precision)


# END ----------------- END JAX code copied for testing -----------------------------
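For orientation, a minimal usage sketch of the patched `scale_and_translate` above (not part of the wheel; the import path `torchax.ops.jax_reimplement` is assumed from the file list): doubling a 4x4 image with the bilinear kernel. Note that the (barney-s) patch deliberately skips the boundary zeroing that upstream JAX performs, so out-of-range samples are not forced to zero.

import jax.numpy as jnp
from torchax.ops.jax_reimplement import scale_and_translate  # assumed module path

image = jnp.arange(16.0).reshape(1, 1, 4, 4)  # (B, C, H, W)
scale = jnp.array([2.0, 2.0])                 # per spatial dim: (H, W)
translation = jnp.array([0.0, 0.0])
out = scale_and_translate(image, (1, 1, 8, 8), spatial_dims=(2, 3),
                          scale=scale, translation=translation,
                          method="bilinear", antialias=True)
print(out.shape)  # (1, 1, 8, 8)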
torchax/ops/jc10d.py
ADDED
@@ -0,0 +1,66 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import jax
import jax.numpy as jnp

from torchax.ops import ops_registry


def op(*aten, **kwargs):

  def inner(func):
    for a in aten:
      ops_registry.register_torch_dispatch_op(a, func, **kwargs)
    return func

  return inner


@op(torch.ops._c10d_functional.all_gather_into_tensor)
def _c10d_all_gather(input, group_size: int, group_name: str):
  return jax.lax.all_gather(input, "torch_dist")


@op(torch.ops._c10d_functional.all_reduce)
def _c10d_all_reduce(self, reduceOp: str, group_name: str):

  if reduceOp == "sum":
    res = jax.lax.psum(self, axis_name="torch_dist")
  elif reduceOp == "avg":
    res = jax.lax.pmean(self, axis_name="torch_dist")
  elif reduceOp == "min":
    res = jax.lax.pmin(self, axis_name="torch_dist")
  elif reduceOp == "max":
    res = jax.lax.pmax(self, axis_name="torch_dist")
  else:
    raise RuntimeError(f"Reduce op {reduceOp} not implemented")
  return res


@op(torch.ops._c10d_functional.broadcast)
def _c10d_broadcast(self, src: int, group_name: str):
  masked = jnp.where(
      jax.lax.axis_index("torch_dist") == src,
      self,
      jnp.zeros_like(self),
  )
  return jax.lax.psum(masked, "torch_dist")


@op(torch.ops._c10d_functional.wait_tensor)
def _c10d_wait_tensor(tensor):
  # Async tensor is already `wait`ed by dispatcher
  return tensor
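These lowerings only make sense inside a JAX mapping construct that binds the collective axis name "torch_dist" (presumably set up by torchax's distributed integration). A standalone sketch of the same primitive that `_c10d_all_reduce` uses for `reduceOp == "sum"`, exercised under `jax.vmap` purely for illustration:

import jax
import jax.numpy as jnp

def allreduce_sum(x):
  # Same collective used above; the named axis supplies the "group".
  return jax.lax.psum(x, axis_name="torch_dist")

xs = jnp.arange(4.0)  # one value per axis index, standing in for 4 ranks
print(jax.vmap(allreduce_sum, axis_name="torch_dist")(xs))  # [6. 6. 6. 6.]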
torchax/ops/jimage.py
ADDED
@@ -0,0 +1,127 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp


def cubic_kernel(x, a=-0.75):
  """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
  absx = jnp.abs(x)
  x2 = absx * absx
  x3 = x2 * absx
  cond1 = (absx <= 1)
  cond2 = (absx > 1) & (absx < 2)
  f1 = (a + 2) * x3 - (a + 3) * x2 + 1
  f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
  return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))


def compute_contribs(in_size,
                     out_size,
                     scale,
                     support=2.0,
                     align_corners=False,
                     dtype=None):
  if align_corners:
    if out_size == 1:
      in_coords = jnp.zeros((1,), dtype=dtype)
    else:
      in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)
  else:
    out_coords = jnp.arange(out_size, dtype=dtype) + 0.5
    in_coords = out_coords / scale - 0.5

  left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
  idxs = left_idx[:, None] + jnp.arange(4)

  dx = in_coords[:, None] - idxs

  weights = cubic_kernel(dx)

  weights = weights / jnp.sum(weights, axis=1, keepdims=True)
  return idxs, weights


def gather_weights(img, idxs, axis):
  """Safely gather with boundary handling"""
  idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
  return jnp.take(img, idxs, axis=axis)


def interpolate_along_axis_bchw(img, idxs, weights, axis):
  """
  Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
  idxs: (out_size, 4) int32 indices
  weights: (out_size, 4) float32 weights
  """
  assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
  out_size = idxs.shape[0]
  k = idxs.shape[1]  # Typically 4 for cubic

  # Clip to input bounds
  idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)  # (out_size, 4)

  def gather_and_weight(i):
    idx = idxs[i]  # (4,)
    w = weights[i]  # (4,)

    def gather_one(offset):
      return jnp.take(img, idx[offset], axis=axis)  # shape (B, C, H, W)

    gathered = jnp.stack([gather_one(o) for o in range(k)],
                         axis=0)  # (4, B, C, H, W)
    weighted = jnp.tensordot(w, gathered, axes=(0, 0))  # (B, C, H, W)
    return weighted

  out = jax.vmap(gather_and_weight)(
      jnp.arange(out_size))  # (out_size, B, C, H, W)

  # Move the interpolated axis back into place
  if axis == 2:  # interpolated over H
    return jnp.moveaxis(out, 0, 2)  # (B, C, out_H, W)
  else:  # axis == 3, interpolated over W
    return jnp.moveaxis(out, 0, 3)  # (B, C, H, out_W)


def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
  h, w = img.shape[-2:]
  if align_corners and out_h > 1:
    scale_y = (h - 1) / (out_h - 1)
  else:
    scale_y = out_h / h

  if align_corners and out_w > 1:
    scale_x = (w - 1) / (out_w - 1)
  else:
    scale_x = out_w / w

  idxs_y, weights_y = compute_contribs(
      h,
      out_h,
      scale_y,
      align_corners=align_corners,
      dtype=img.dtype,
  )
  tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)

  idxs_x, weights_x = compute_contribs(
      w,
      out_w,
      scale_x,
      align_corners=align_corners,
      dtype=img.dtype,
  )
  out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
  return out
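A short usage sketch (the import path `torchax.ops.jimage` is assumed from the file list): bicubic upsampling of a random (B, C, H, W) batch, following the separable two-pass structure above (interpolate over H, then over W), which appears intended to mirror PyTorch-style bicubic resizing without antialiasing.

import jax
from torchax.ops.jimage import interpolate_bicubic_no_aa  # assumed module path

img = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 16, 16))  # (B, C, H, W)
out = interpolate_bicubic_no_aa(img, out_h=32, out_w=32, align_corners=False)
print(out.shape)  # (2, 3, 32, 32)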
torchax/ops/jlibrary.py
ADDED
@@ -0,0 +1,94 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The `jlibrary` module provides functions that help preserve torch.library
ops during export. This includes aten ops and custom operations.
"""

import torch
import torch.nn as nn
import torchax
from torchax.ops import jaten
import jax
import functools


def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
  """Wrap a jaxpr in a jitted function with the proper composite name.

  TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op.
  """

  def composite_impl(*args):
    return jaxpr_impl(*args)

  composite_impl.__name__ = composite_name
  composite_impl.__qualname__ = composite_name
  return jax.jit(composite_impl, **jit_args)


def register_jax_composite(composite_name, impl, *ops, **jit_args):
  """Register a composite using a JAX implementation.

  composite_name - The name of the library op to use in the exported composite
  impl - A JAX lowering for the library operation
  *ops - Variadic torch.ops to lower using `impl`.
  **jit_args - Additional parameters to forward to JAX jit.

  This is used to register custom lowerings with an explicit jaxpr
  implementation, such as preserving a specific aten op using a jaten impl.

  For custom torch op registration with a decomposition written in torch,
  use `register_torch_composite`.

  For jit params and troubleshooting see:
  https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
  """

  @jaten.op(*ops)
  def _composite_impl(*args):
    return _jit_composite_impl(composite_name, impl, **jit_args)(*args)


def register_torch_composite(composite_name, impl, *ops, **jit_args):
  """Register a torch decomposition as a composite.

  This is useful for registering custom torch op libraries as composite ops.

  The `impl` can be the `@impl` used to define the torch custom library op.
  This must be a function or module impl that provides the decompositions, and
  not an instance of the custom op.

  TODO: Better error handling, or can we make this an instance of the op as a param?

  For jit params and troubleshooting see:
  https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
  """

  @jaten.op(*ops)
  def _composite_impl(*args):

    class ImplWrapper(torch.nn.Module):

      def __init__(self):
        super().__init__()

      def forward(self, *args):
        return impl(*args)

    # Note: avoid refactoring to share code with register_jax_composite.
    # The `extract_jax` call must live in the `@jaten.op` handler. If called
    # outside of the handler, we would build the jaxpr representation of the
    # module once during registration, potentially missing op registrations
    # that come after. That is, we may miss nested abstractions if we build
    # the jaxpr ahead of time.
    state, jfn = torchax.extract_jax(ImplWrapper())
    jaxpr_impl = lambda *args: jfn(state, tuple([*args]))
    return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args)