torchax 0.0.10.dev20251118__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. Click here for more details.

@@ -0,0 +1,211 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections.abc import Callable, Sequence
16
+
17
+ import numpy as np
18
+ from jax import lax
19
+ from jax import numpy as jnp
20
+ from jax._src import core
21
+ from jax._src.image.scale import ResizeMethod, _kernels
22
+ from jax._src.numpy.util import promote_dtypes_inexact
23
+ from jax._src.util import canonicalize_axis
24
+
25
+ # TODO: This block of code needs to be revisited based on https://github.com/jax-ml/jax/issues/24106
26
+ # START ----------------- JAX code copied for fixing scale_and_translate -----------------------------
27
+
28
+ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52
29
+
30
+
31
def compute_weight_mat(
    input_size: core.DimSize,
    output_size: core.DimSize,
    scale,
    translation,
    kernel: Callable,
    antialias: bool,
):
    """Build the normalized resampling weight matrix for one spatial axis.

    Args:
      input_size: number of samples along the axis in the input.
      output_size: number of samples along the axis in the output.
      scale: scale factor applied along the axis.
      translation: translation applied along the axis.
      kernel: callable mapping an array of absolute sample distances (already
        divided by the kernel scale) to kernel weights.
      antialias: when true the kernel is widened by ``max(1 / scale, 1)`` so it
        also low-pass filters when downsampling; otherwise the kernel scale is 1.

    Returns:
      The kernel weights laid out as ``(input_size, output_size)`` (assuming
      ``kernel`` is elementwise), with every column divided by its sum. Columns
      whose sum is numerically ~0 are set to 0.
    """
    dtype = jnp.result_type(scale, translation)
    inv_scale = 1.0 / scale
    # When downsampling the kernel should be scaled since we want to low pass
    # filter and interpolate, but when upsampling it should not be since we only
    # want to interpolate.
    kernel_scale = jnp.maximum(inv_scale, 1.0) if antialias else 1.0
    # Input-space location of every output sample (half-pixel centers).
    sample_f = (
        (jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale
        - translation * inv_scale
        - 0.5
    )
    # Distance between every input index (rows) and every sample (columns).
    x = (
        jnp.abs(
            sample_f[jnp.newaxis, :]
            - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]
        )
        / kernel_scale
    )
    weights = kernel(x)

    total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
    # Normalize each column; the inner `where` avoids dividing by exactly zero
    # and the outer one zeroes columns whose total weight is negligible.
    weights = jnp.where(
        jnp.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps),
        jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
        0,
    )
    # (barney-s) Deliberate deviation from the upstream JAX implementation:
    # upstream additionally zeroes the columns whose sample location falls
    # outside [-0.5, input_size - 0.5]. This copy returns the normalized weights
    # as-is; the masking code used to sit here after the return and was
    # unreachable, so it has been removed.
    return weights
77
+
78
+
79
+ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86
80
+
81
+
82
+ def _scale_and_translate(
83
+ x,
84
+ output_shape: core.Shape,
85
+ spatial_dims: Sequence[int],
86
+ scale,
87
+ translation,
88
+ kernel,
89
+ antialias: bool,
90
+ precision,
91
+ ):
92
+ input_shape = x.shape
93
+ assert len(input_shape) == len(output_shape)
94
+ assert len(spatial_dims) == len(scale)
95
+ assert len(spatial_dims) == len(translation)
96
+ if len(spatial_dims) == 0:
97
+ return x
98
+ contractions = []
99
+ in_indices = list(range(len(output_shape)))
100
+ out_indices = list(range(len(output_shape)))
101
+ for i, d in enumerate(spatial_dims):
102
+ d = canonicalize_axis(d, x.ndim)
103
+ m = input_shape[d]
104
+ n = output_shape[d]
105
+ w = compute_weight_mat(m, n, scale[i], translation[i], kernel, antialias).astype(
106
+ x.dtype
107
+ )
108
+ contractions.append(w)
109
+ contractions.append([d, len(output_shape) + i])
110
+ out_indices[d] = len(output_shape) + i
111
+ contractions.append(out_indices)
112
+ return jnp.einsum(x, in_indices, *contractions, precision=precision)
113
+
114
+
115
+ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172
116
+
117
+
118
+ # scale and translation here are scalar elements of an np.array, what is the
119
+ # correct type annotation?
120
def scale_and_translate(
    image,
    shape: core.Shape,
    spatial_dims: Sequence[int],
    scale,
    translation,
    # (barney-s) use string
    method: str,  # (barney-s) | ResizeMethod,
    antialias: bool = True,
    precision=lax.Precision.HIGHEST,
):
    """Apply a scale and translation to an image.

    Produces a new image of shape ``shape`` by resampling ``image`` with the
    kernel selected by ``method``. For 2D images, a location ``(x, y)`` in the
    input image is mapped to the output location::

      (x * scale[1] + translation[1], y * scale[0] + translation[0])

    (Note the *inverse* warp is used to generate the sample locations.)
    Pixels are half-centered, i.e. the pixel at integer location ``row, col``
    has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other
    input image dimensions.

    Note: in this copy ``compute_weight_mat`` returns its normalized weights
    without masking sample locations that fall outside the input range, so
    out-of-range output pixels are not forced to zero here.

    The ``method`` argument expects one of the following resize methods:

    ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``,
      ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses
      a triangular filter when downsampling.

    ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
      `Cubic interpolation`_, using the Keys cubic kernel.

    ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
      `Lanczos resampling`_, using a kernel of radius 3.

    ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
      `Lanczos resampling`_, using a kernel of radius 5.

    .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
    .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
    .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

    Args:
      image: a JAX array.
      shape: the output shape, as a sequence of integers with length equal to
        the number of dimensions of ``image``.
      spatial_dims: a length-K sequence naming the dimensions that the scale
        and translation are applied to.
      scale: a length-K array with the scale for each entry of
        ``spatial_dims``.
      translation: a length-K array with the translation for each entry of
        ``spatial_dims``.
      method: the resizing method; a string, or a ``ResizeMethod`` instance.
        Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
      antialias: whether an antialiasing filter is used when downsampling.
        Defaults to ``True``. Has no effect when upsampling.
      precision: precision forwarded to the underlying ``jnp.einsum``.

    Returns:
      The scaled and translated image.

    Raises:
      ValueError: if ``shape`` does not have one entry per dimension of
        ``image``, or if nearest-neighbor resampling is requested.
    """
    shape = core.canonicalize_shape(shape)
    if len(shape) != image.ndim:
        raise ValueError(
            "shape must have length equal to the number of dimensions of x; "
            f" {shape} vs {image.shape}"
        )

    resize_method = (
        ResizeMethod.from_string(method) if isinstance(method, str) else method
    )
    if resize_method == ResizeMethod.NEAREST:
        # Nearest neighbor is currently special-cased for straight resize, so
        # it is rejected here for now.
        raise ValueError(
            "Nearest neighbor resampling is not currently supported for scale_and_translate."
        )
    assert isinstance(resize_method, ResizeMethod)

    kernel = _kernels[resize_method]
    # Promote to inexact dtypes before building weights and contracting.
    (image,) = promote_dtypes_inexact(image)
    scale, translation = promote_dtypes_inexact(scale, translation)
    return _scale_and_translate(
        image, shape, spatial_dims, scale, translation, kernel, antialias, precision
    )
209
+
210
+
211
+ # END ----------------- JAX code copied for fixing scale_and_translate -----------------------------
torchax/ops/jc10d.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+ import torch
18
+
19
+ from torchax.ops import ops_registry
20
+
21
+
22
def op(*aten, **kwargs):
    """Return a decorator that registers a function for the given torch ops.

    Args:
      *aten: the torch ops the decorated function should handle.
      **kwargs: extra keyword arguments forwarded to
        ``ops_registry.register_torch_dispatch_op`` for every op.

    Returns:
      A decorator that registers its argument once per op in ``aten`` and then
      returns that argument unchanged.
    """

    def register(handler):
        for target in aten:
            ops_registry.register_torch_dispatch_op(target, handler, **kwargs)
        return handler

    return register
29
+
30
+
31
@op(torch.ops._c10d_functional.all_gather_into_tensor)
def _c10d_all_gather(input, group_size: int, group_name: str):
    """Lower ``all_gather_into_tensor`` to ``jax.lax.all_gather``.

    The gather runs over the ``"torch_dist"`` axis name. ``group_size`` and
    ``group_name`` are accepted to match the op signature but are not used.
    """
    # NOTE(review): jax.lax.all_gather is called with its defaults here;
    # confirm the resulting layout matches what callers of the torch op expect.
    return jax.lax.all_gather(input, axis_name="torch_dist")
34
+
35
+
36
@op(torch.ops._c10d_functional.all_reduce)
def _c10d_all_reduce(self, reduceOp: str, group_name: str):
    """Lower ``all_reduce`` to the matching JAX collective.

    Args:
      self: the value to reduce.
      reduceOp: one of ``"sum"``, ``"avg"``, ``"min"`` or ``"max"``.
      group_name: accepted to match the op signature; not used.

    Returns:
      The value reduced across the ``"torch_dist"`` axis name.

    Raises:
      RuntimeError: if ``reduceOp`` is not one of the supported names.
    """
    collectives = {
        "sum": jax.lax.psum,
        "avg": jax.lax.pmean,
        "min": jax.lax.pmin,
        "max": jax.lax.pmax,
    }
    if reduceOp not in collectives:
        raise RuntimeError(f"Reduce op {reduceOp} not implemented")
    return collectives[reduceOp](self, axis_name="torch_dist")
49
+
50
+
51
@op(torch.ops._c10d_functional.broadcast)
def _c10d_broadcast(self, src: int, group_name: str):
    """Lower ``broadcast`` to a masked ``psum`` over ``"torch_dist"``.

    Every participant whose axis index differs from ``src`` contributes zeros,
    so the sum across the axis equals the value held at index ``src``.
    ``group_name`` is accepted to match the op signature but is not used.
    """
    is_source = jax.lax.axis_index("torch_dist") == src
    contribution = jnp.where(is_source, self, jnp.zeros_like(self))
    return jax.lax.psum(contribution, "torch_dist")
59
+
60
+
61
@op(torch.ops._c10d_functional.wait_tensor)
def _c10d_wait_tensor(tensor):
    """Handle ``wait_tensor`` as a no-op that returns ``tensor`` unchanged."""
    # The async tensor is already `wait`ed by the dispatcher.
    return tensor
torchax/ops/jimage.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+
18
+
19
def cubic_kernel(x, a=-0.75):
    """Evaluate the piecewise cubic kernel on ``x`` (PyTorch-like Keys kernel).

    Args:
      x: distance(s) at which to evaluate the kernel; the sign is ignored.
      a: kernel coefficient, ``-0.75`` by default.

    Returns:
      ``(a + 2)|x|^3 - (a + 3)|x|^2 + 1`` where ``|x| <= 1``,
      ``a|x|^3 - 5a|x|^2 + 8a|x| - 4a`` where ``1 < |x| < 2``, and ``0``
      everywhere else.
    """
    dist = jnp.abs(x)
    dist_sq = dist * dist
    dist_cu = dist_sq * dist
    near_poly = (a + 2) * dist_cu - (a + 3) * dist_sq + 1
    far_poly = a * dist_cu - 5 * a * dist_sq + 8 * a * dist - 4 * a
    # Resolve the outer interval first, then overlay the inner one.
    beyond_unit = jnp.where((dist > 1) & (dist < 2), far_poly, 0.0)
    return jnp.where(dist <= 1, near_poly, beyond_unit)
29
+
30
+
31
def compute_contribs(
    in_size, out_size, scale, support=2.0, align_corners=False, dtype=None
):
    """Compute the 4-tap source indices and cubic weights for one axis.

    Args:
      in_size: length of the input axis (only read when ``align_corners``).
      out_size: length of the output axis.
      scale: output/input scale factor (only read when not ``align_corners``).
      support: accepted but not used by this function.
      align_corners: if true, output coordinates are spread evenly over
        ``[0, in_size - 1]`` (a single output sample maps to coordinate 0);
        otherwise half-pixel centers are used.
      dtype: dtype used when building the coordinate arrays.

    Returns:
      ``(idxs, weights)``, both shaped ``(out_size, 4)``. ``idxs`` are the
      unclipped source indices ``floor(coord) - 1 .. floor(coord) + 2`` and
      ``weights`` are the ``cubic_kernel`` values normalized to sum to 1 per
      row.
    """
    if not align_corners:
        in_coords = (jnp.arange(out_size, dtype=dtype) + 0.5) / scale - 0.5
    elif out_size == 1:
        in_coords = jnp.zeros((1,), dtype=dtype)
    else:
        in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)

    # The four taps start one sample to the left of the coordinate's floor.
    first_tap = jnp.floor(in_coords).astype(jnp.int32) - 1
    idxs = first_tap[:, None] + jnp.arange(4)

    raw_weights = cubic_kernel(in_coords[:, None] - idxs)
    weights = raw_weights / jnp.sum(raw_weights, axis=1, keepdims=True)
    return idxs, weights
52
+
53
+
54
def gather_weights(img, idxs, axis):
    """Gather slices of ``img`` along ``axis``, clamping indices into range.

    Indices below 0 or above the last valid position on ``axis`` are clipped
    to the nearest valid index before the gather.
    """
    last_valid = img.shape[axis] - 1
    return jnp.take(img, jnp.clip(idxs, 0, last_valid), axis=axis)
58
+
59
+
60
def interpolate_along_axis_bchw(img, idxs, weights, axis):
    """Interpolate a (B, C, H, W) tensor along H (axis=2) or W (axis=3).

    Args:
      img: input tensor; the docstring contract is a 4-D (B, C, H, W) layout.
      idxs: ``(out_size, k)`` integer source indices (k is typically 4 for
        cubic); they are clipped to the valid range of ``axis``.
      weights: ``(out_size, k)`` weights matching ``idxs``.
      axis: 2 to interpolate over H, 3 to interpolate over W.

    Returns:
      ``img`` with the chosen axis replaced by one of length ``out_size``.
    """
    assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
    num_taps = idxs.shape[1]

    # Clamp taps that fall outside the input to the border samples.
    clipped = jnp.clip(idxs, 0, img.shape[axis] - 1)

    def blend_output_position(pos):
        taps = clipped[pos]
        tap_weights = weights[pos]
        # Each take uses a scalar index, so `axis` is dropped from the slice;
        # the slices are stacked on a new leading tap axis and contracted with
        # the weights.
        slices = jnp.stack(
            [jnp.take(img, taps[t], axis=axis) for t in range(num_taps)], axis=0
        )
        return jnp.tensordot(tap_weights, slices, axes=(0, 0))

    # vmap adds the output positions as a new leading axis.
    stacked = jax.vmap(blend_output_position)(jnp.arange(idxs.shape[0]))

    # Move the interpolated axis back into the slot it was taken from.
    destination = 2 if axis == 2 else 3
    return jnp.moveaxis(stacked, 0, destination)
91
+
92
+
93
def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
    """Resize the last two axes of ``img`` to ``(out_h, out_w)`` bicubically.

    The resize is separable: axis 2 is interpolated first, then axis 3, each
    with contributions from ``compute_contribs`` applied by
    ``interpolate_along_axis_bchw``. No antialiasing filter is applied.

    Args:
      img: input tensor; height and width are read from its last two axes.
      out_h: output height.
      out_w: output width.
      align_corners: forwarded to ``compute_contribs``; also selects the
        ``(in - 1) / (out - 1)`` scale when the output length exceeds 1.

    Returns:
      The resized tensor.
    """
    h, w = img.shape[-2:]

    def axis_scale(in_len, out_len):
        if align_corners and out_len > 1:
            return (in_len - 1) / (out_len - 1)
        return out_len / in_len

    scale_y = axis_scale(h, out_h)
    scale_x = axis_scale(w, out_w)

    # Pass 1: interpolate over H.
    row_idxs, row_weights = compute_contribs(
        h, out_h, scale_y, align_corners=align_corners, dtype=img.dtype
    )
    resized_rows = interpolate_along_axis_bchw(img, row_idxs, row_weights, axis=2)

    # Pass 2: interpolate over W on the result of pass 1.
    col_idxs, col_weights = compute_contribs(
        w, out_w, scale_x, align_corners=align_corners, dtype=img.dtype
    )
    return interpolate_along_axis_bchw(resized_rows, col_idxs, col_weights, axis=3)
@@ -0,0 +1,94 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """The `jlibrary` module has functions which help to preserve torch.library ops
16
+ during export. This includes aten ops, and custom operations.
17
+ """
18
+
19
+ import jax
20
+ import torch
21
+
22
+ import torchax
23
+ from torchax.ops import jaten
24
+
25
+
26
+ def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
27
+ """Wrap a jaxpr in a jitted function with the proper composite name
28
+ TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op.
29
+ """
30
+
31
+ def composite_impl(*args):
32
+ return jaxpr_impl(*args)
33
+
34
+ composite_impl.__name__ = composite_name
35
+ composite_impl.__qualname__ = composite_name
36
+ return jax.jit(composite_impl, **jit_args)
37
+
38
+
39
def register_jax_composite(composite_name, impl, *ops, **jit_args):
    """Register a composite using a JAX implementation.

    Args:
      composite_name: the name of the library op to use in the exported
        composite.
      impl: a JAX lowering for the library operation.
      *ops: the torch.ops to lower using ``impl``.
      **jit_args: additional parameters to forward to JAX jit.

    This is used to register custom lowerings with an explicit jaxpr
    implementation, such as preserving a specific aten op using a jaten impl.

    For custom torch op registration with a decomposition written in torch,
    use `register_torch_composite`.

    For jit params and troubleshooting see:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
    """

    def _composite_impl(*args):
        # The jitted wrapper is built on every invocation of the handler.
        jitted = _jit_composite_impl(composite_name, impl, **jit_args)
        return jitted(*args)

    jaten.op(*ops)(_composite_impl)
59
+
60
+
61
def register_torch_composite(composite_name, impl, *ops, **jit_args):
    """Register a torch decomposition as a composite.

    This is useful for registering custom torch op libraries as composite ops.

    The `impl` can be the `@impl` used to define the torch custom library op.
    This must be a function or module impl that provides the decompositions,
    and not an instance of the custom op.

    TODO: Better error handling, or can we make this an instance of the op as a param?

    For jit params and troubleshooting see:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
    """

    def _composite_impl(*args):
        class ImplWrapper(torch.nn.Module):
            """Module shell whose forward simply delegates to `impl`."""

            def forward(self, *fwd_args):
                return impl(*fwd_args)

        # Note: avoid refactoring to share code with register_jax_composite.
        # The `extract_jax` call must live in this `@jaten.op` handler. If it
        # ran outside of the handler, the jaxpr representation of the module
        # would be built once during registration, potentially missing op
        # registrations that come after, i.e. nested abstractions could be
        # missed if the jaxpr were built ahead of time.
        state, jfn = torchax.extract_jax(ImplWrapper())

        def jaxpr_impl(*call_args):
            return jfn(state, tuple(call_args))

        jitted = _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)
        return jitted(*args)

    jaten.op(*ops)(_composite_impl)