torchax 0.0.10.dev20251114__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202612.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251114.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251114.dist-info → torchax-0.0.11.dev202612.dist-info}/licenses/LICENSE +0 -0
torchax/ops/jax_reimplement.py
CHANGED

@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Sequence
-
+from collections.abc import Callable, Sequence
+
 import numpy as np
-import jax
+from jax import lax
 from jax import numpy as jnp
-from jax._src.util import canonicalize_axis
 from jax._src import core
-from jax._src.image.scale import _kernels, ResizeMethod
-from jax import lax
-from typing import Callable
+from jax._src.image.scale import ResizeMethod, _kernels
+from jax._src.numpy.util import promote_dtypes_inexact
+from jax._src.util import canonicalize_axis
 
 # TODO: This block of code needs to be revisited based on https://github.com/jax-ml/jax/issues/24106
 # START ----------------- JAX code copied for fixing scale_and_translate -----------------------------
@@ -29,27 +28,39 @@ from typing import Callable
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52
 
 
-def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
-                       scale, translation, kernel: Callable, antialias: bool):
+def compute_weight_mat(
+  input_size: core.DimSize,
+  output_size: core.DimSize,
+  scale,
+  translation,
+  kernel: Callable,
+  antialias: bool,
+):
   dtype = jnp.result_type(scale, translation)
-  inv_scale = 1. / scale
+  inv_scale = 1.0 / scale
   # When downsampling the kernel should be scaled since we want to low pass
   # filter and interpolate, but when upsampling it should not be since we only
   # want to interpolate.
-  kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
-  sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
-              translation * inv_scale - 0.5)
+  kernel_scale = jnp.maximum(inv_scale, 1.0) if antialias else 1.0
+  sample_f = (
+    (jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale
+    - translation * inv_scale
+    - 0.5
+  )
   x = (
-      jnp.abs(sample_f[jnp.newaxis, :] -
-              jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
-      kernel_scale)
+    jnp.abs(
+      sample_f[jnp.newaxis, :] - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]
+    )
+    / kernel_scale
+  )
   weights = kernel(x)
 
   total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
   weights = jnp.where(
-      jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
-      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
-      0)
+    jnp.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps),
+    jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
+    0,
+  )
   # Zero out weights where the sample location is completely outside the input
   # range.
   # Note sample_f has already had the 0.5 removed, hence the weird range below.
@@ -58,17 +69,26 @@ def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
     return weights
   input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
   return jnp.where(
-      jnp.logical_and(sample_f >= -0.5,
-                      sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
+    jnp.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[jnp.newaxis, :],
+    weights,
+    0,
+  )
   # (barney-s) -------------- END returning weights without zeroing ---------------------
 
 
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86
 
 
-def _scale_and_translate(x, output_shape: core.Shape,
-                         spatial_dims: Sequence[int], scale, translation,
-                         kernel, antialias: bool, precision):
+def _scale_and_translate(
+  x,
+  output_shape: core.Shape,
+  spatial_dims: Sequence[int],
+  scale,
+  translation,
+  kernel,
+  antialias: bool,
+  precision,
+):
   input_shape = x.shape
   assert len(input_shape) == len(output_shape)
   assert len(spatial_dims) == len(scale)
@@ -82,8 +102,9 @@ def _scale_and_translate(x, output_shape: core.Shape,
     d = canonicalize_axis(d, x.ndim)
     m = input_shape[d]
     n = output_shape[d]
-    w = compute_weight_mat(m, n, scale[i], translation[i], kernel,
-                           antialias).astype(x.dtype)
+    w = compute_weight_mat(m, n, scale[i], translation[i], kernel, antialias).astype(
+      x.dtype
+    )
     contractions.append(w)
     contractions.append([d, len(output_shape) + i])
     out_indices[d] = len(output_shape) + i
@@ -97,15 +118,16 @@ def _scale_and_translate(x, output_shape: core.Shape,
 # scale and translation here are scalar elements of an np.array, what is the
 # correct type annotation?
 def scale_and_translate(
-    image,
-    shape: core.Shape,
-    spatial_dims: Sequence[int],
-    scale,
-    translation,
-    # (barney-s) use string
-    method: str,  # (barney-s) | ResizeMethod,
-    antialias: bool = True,
-    precision=lax.Precision.HIGHEST):
+  image,
+  shape: core.Shape,
+  spatial_dims: Sequence[int],
+  scale,
+  translation,
+  # (barney-s) use string
+  method: str,  # (barney-s) | ResizeMethod,
+  antialias: bool = True,
+  precision=lax.Precision.HIGHEST,
+):
   """Apply a scale and translation to an image.
 
   Generates a new image of shape 'shape' by resampling from the input image
@@ -163,23 +185,27 @@ def scale_and_translate(
   """
   shape = core.canonicalize_shape(shape)
   if len(shape) != image.ndim:
-    msg = ('shape must have length equal to the number of dimensions of x; '
-           f' {shape} vs {image.shape}')
+    msg = (
+      "shape must have length equal to the number of dimensions of x; "
+      f" {shape} vs {image.shape}"
+    )
     raise ValueError(msg)
   if isinstance(method, str):
     method = ResizeMethod.from_string(method)
   if method == ResizeMethod.NEAREST:
     # Nearest neighbor is currently special-cased for straight resize, so skip
    # for now.
-    raise ValueError(
-        'Nearest neighbor resampling is not currently supported for scale_and_translate.')
+    raise ValueError(
+      "Nearest neighbor resampling is not currently supported for scale_and_translate."
+    )
   assert isinstance(method, ResizeMethod)
 
   kernel = _kernels[method]
-  image, = promote_dtypes_inexact(image)
+  (image,) = promote_dtypes_inexact(image)
   scale, translation = promote_dtypes_inexact(scale, translation)
-  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
-                              kernel, antialias, precision)
+  return _scale_and_translate(
+    image, shape, spatial_dims, scale, translation, kernel, antialias, precision
+  )
 
 
 # END ----------------- END JAX code copied for testing -----------------------------
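The copied-and-patched `scale_and_translate` above keeps the signature of `jax.image.scale_and_translate`, with `method` accepted as a string. A minimal sketch of calling it directly, with illustrative shapes and values rather than anything taken from the package's own tests:

import jax.numpy as jnp
from torchax.ops.jax_reimplement import scale_and_translate

# Upsample a (B, C, H, W) image from 4x4 to 8x8 over the two spatial dims.
img = jnp.arange(16.0).reshape(1, 1, 4, 4)
out = scale_and_translate(
  img,
  shape=(1, 1, 8, 8),
  spatial_dims=(2, 3),
  scale=jnp.array([2.0, 2.0]),
  translation=jnp.array([0.0, 0.0]),
  method="linear",
  antialias=True,
)
print(out.shape)  # (1, 1, 8, 8)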
torchax/ops/jc10d.py
CHANGED

@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 import jax
 import jax.numpy as jnp
+import torch
 
 from torchax.ops import ops_registry
 
 
 def op(*aten, **kwargs):
-
   def inner(func):
     for a in aten:
       ops_registry.register_torch_dispatch_op(a, func, **kwargs)
@@ -36,7 +35,6 @@ def _c10d_all_gather(input, group_size: int, group_name: str):
 
 @op(torch.ops._c10d_functional.all_reduce)
 def _c10d_all_reduce(self, reduceOp: str, group_name: str):
-
   if reduceOp == "sum":
     res = jax.lax.psum(self, axis_name="torch_dist")
   elif reduceOp == "avg":
@@ -53,9 +51,9 @@ def _c10d_all_reduce(self, reduceOp: str, group_name: str):
 @op(torch.ops._c10d_functional.broadcast)
 def _c10d_broadcast(self, src: int, group_name: str):
   masked = jnp.where(
-      jax.lax.axis_index("torch_dist") == src,
-      self,
-      jnp.zeros_like(self),
+    jax.lax.axis_index("torch_dist") == src,
+    self,
+    jnp.zeros_like(self),
   )
   return jax.lax.psum(masked, "torch_dist")
 
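The `_c10d_broadcast` lowering above implements broadcast as mask-then-psum: every replica except `src` contributes zeros, so after the sum each replica holds the source replica's value. A self-contained sketch of the same trick, using `jax.vmap` with an `axis_name` to stand in for the real "torch_dist" device axis:

import jax
import jax.numpy as jnp

def broadcast_from(x, src):
  # Keep the value only on the source replica, zeros elsewhere ...
  masked = jnp.where(jax.lax.axis_index("torch_dist") == src, x, jnp.zeros_like(x))
  # ... then sum across the axis: every replica ends up with src's value.
  return jax.lax.psum(masked, "torch_dist")

xs = jnp.array([10.0, 20.0, 30.0, 40.0])  # one value per "replica"
out = jax.vmap(lambda x: broadcast_from(x, 2), axis_name="torch_dist")(xs)
print(out)  # [30. 30. 30. 30.]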
torchax/ops/jimage.py
CHANGED

@@ -21,19 +21,16 @@ def cubic_kernel(x, a=-0.75):
   absx = jnp.abs(x)
   x2 = absx * absx
   x3 = x2 * absx
-  cond1 = (absx <= 1)
+  cond1 = absx <= 1
   cond2 = (absx > 1) & (absx < 2)
   f1 = (a + 2) * x3 - (a + 3) * x2 + 1
   f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
   return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
 
 
-def compute_contribs(in_size,
-                     out_size,
-                     scale,
-                     support=2.0,
-                     align_corners=False,
-                     dtype=None):
+def compute_contribs(
+  in_size, out_size, scale, support=2.0, align_corners=False, dtype=None
+):
   if align_corners:
     if out_size == 1:
       in_coords = jnp.zeros((1,), dtype=dtype)
@@ -62,10 +59,10 @@ def gather_weights(img, idxs, axis):
 
 def interpolate_along_axis_bchw(img, idxs, weights, axis):
   """
-    Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
-    idxs: (out_size, 4) int32 indices
-    weights: (out_size, 4) float32 weights
-    """
+  Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
+  idxs: (out_size, 4) int32 indices
+  weights: (out_size, 4) float32 weights
+  """
   assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
   out_size = idxs.shape[0]
   k = idxs.shape[1]  # Typically 4 for cubic
@@ -80,13 +77,11 @@ def interpolate_along_axis_bchw(img, idxs, weights, axis):
     def gather_one(offset):
       return jnp.take(img, idx[offset], axis=axis)  # shape (B, C, H, W)
 
-    gathered = jnp.stack([gather_one(o) for o in range(k)],
-                         axis=0)  # (4, B, C, H, W)
+    gathered = jnp.stack([gather_one(o) for o in range(k)], axis=0)  # (4, B, C, H, W)
     weighted = jnp.tensordot(w, gathered, axes=(0, 0))  # (B, C, H, W)
     return weighted
 
-  out = jax.vmap(gather_and_weight)(
-      jnp.arange(out_size))  # (out_size, B, C, H, W)
+  out = jax.vmap(gather_and_weight)(jnp.arange(out_size))  # (out_size, B, C, H, W)
 
   # Move the interpolated axis back into place
   if axis == 2:  # interpolated over H
@@ -108,20 +103,20 @@ def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
   scale_x = out_w / w
 
   idxs_y, weights_y = compute_contribs(
-      h,
-      out_h,
-      scale_y,
-      align_corners=align_corners,
-      dtype=img.dtype,
+    h,
+    out_h,
+    scale_y,
+    align_corners=align_corners,
+    dtype=img.dtype,
   )
   tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
 
   idxs_x, weights_x = compute_contribs(
-      w,
-      out_w,
-      scale_x,
-      align_corners=align_corners,
-      dtype=img.dtype,
+    w,
+    out_w,
+    scale_x,
+    align_corners=align_corners,
+    dtype=img.dtype,
   )
   out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
   return out
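For reference, `cubic_kernel` is the Keys cubic convolution kernel with a = -0.75 (PyTorch's bicubic default), and `interpolate_bicubic_no_aa` applies it separably over H then W, which mirrors the non-antialiased path of `torch.nn.functional.interpolate(mode="bicubic")`. A quick sketch exercising both, with illustrative shapes:

import jax.numpy as jnp
from torchax.ops.jimage import cubic_kernel, interpolate_bicubic_no_aa

# The kernel is 1 at x = 0 and 0 at the integer offsets 1 and 2.
print(cubic_kernel(jnp.array([0.0, 1.0, 2.0])))  # [1. 0. 0.]

# Separable bicubic upsample of a (B, C, H, W) image, no antialiasing.
img = jnp.arange(16.0).reshape(1, 1, 4, 4)
out = interpolate_bicubic_no_aa(img, out_h=8, out_w=8, align_corners=False)
print(out.shape)  # (1, 1, 8, 8)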
torchax/ops/jlibrary.py
CHANGED

@@ -16,12 +16,11 @@
 during export. This includes aten ops, and custom operations.
 """
 
+import jax
 import torch
-
+
 import torchax
 from torchax.ops import jaten
-import jax
-import functools
 
 
 def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
@@ -75,9 +74,7 @@ def register_torch_composite(composite_name, impl, *ops, **jit_args):
 
   @jaten.op(*ops)
   def _composite_impl(*args):
-
     class ImplWrapper(torch.nn.Module):
-
       def __init__(self):
         super().__init__()
 
@@ -90,5 +87,8 @@ def register_torch_composite(composite_name, impl, *ops, **jit_args):
     # module once during registration, potentially missing op registrations that
     # come after. I.e. may miss nested abstractions if we build jaxpr AoT.
     state, jfn = torchax.extract_jax(ImplWrapper())
-    jaxpr_impl = lambda *args: jfn(state, (*args,))
+
+    def jaxpr_impl(*args):
+      return jfn(state, (*args,))
+
     return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args)