torchax 0.0.10.dev20251117__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries, and is provided for informational purposes only.
@@ -0,0 +1,185 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ from collections.abc import Sequence
+ from typing import Callable
+ 
+ import numpy as np
+ import jax
+ from jax import lax
+ from jax import numpy as jnp
+ from jax._src import core
+ from jax._src.image.scale import _kernels, ResizeMethod
+ from jax._src.numpy.util import promote_dtypes_inexact
+ from jax._src.util import canonicalize_axis
+ 
+ # TODO: This block of code needs to be revisited based on
+ # https://github.com/jax-ml/jax/issues/24106
+ # START ----------------- JAX code copied for fixing scale_and_translate -----------------------------
+ 
+ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52
+ 
+ 
+ def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
+                        scale, translation, kernel: Callable, antialias: bool):
+   dtype = jnp.result_type(scale, translation)
+   inv_scale = 1. / scale
+   # When downsampling the kernel should be scaled since we want to low pass
+   # filter and interpolate, but when upsampling it should not be since we only
+   # want to interpolate.
+   kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
+   sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
+               translation * inv_scale - 0.5)
+   x = (
+       jnp.abs(sample_f[jnp.newaxis, :] -
+               jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
+       kernel_scale)
+   weights = kernel(x)
+ 
+   total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
+   weights = jnp.where(
+       jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
+       jnp.divide(weights,
+                  jnp.where(total_weight_sum != 0, total_weight_sum, 1)), 0)
+   # Zero out weights where the sample location is completely outside the input
+   # range.
+   # Note sample_f has already had the 0.5 removed, hence the weird range below.
+ 
+   # (barney-s) -------------- returning weights without zeroing ---------------------
+   # The upstream zeroing below is intentionally skipped; it is kept commented
+   # out for comparison with the original JAX implementation.
+   return weights
+   # input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
+   # return jnp.where(
+   #     jnp.logical_and(sample_f >= -0.5,
+   #                     sample_f <= input_size_minus_0_5)[jnp.newaxis, :],
+   #     weights, 0)
+   # (barney-s) -------------- END returning weights without zeroing ---------------------
+ 
+ 
+ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86
+ 
+ 
+ def _scale_and_translate(x, output_shape: core.Shape,
+                          spatial_dims: Sequence[int], scale, translation,
+                          kernel, antialias: bool, precision):
+   input_shape = x.shape
+   assert len(input_shape) == len(output_shape)
+   assert len(spatial_dims) == len(scale)
+   assert len(spatial_dims) == len(translation)
+   if len(spatial_dims) == 0:
+     return x
+   contractions = []
+   in_indices = list(range(len(output_shape)))
+   out_indices = list(range(len(output_shape)))
+   for i, d in enumerate(spatial_dims):
+     d = canonicalize_axis(d, x.ndim)
+     m = input_shape[d]
+     n = output_shape[d]
+     w = compute_weight_mat(m, n, scale[i], translation[i], kernel,
+                            antialias).astype(x.dtype)
+     contractions.append(w)
+     contractions.append([d, len(output_shape) + i])
+     out_indices[d] = len(output_shape) + i
+   contractions.append(out_indices)
+   return jnp.einsum(x, in_indices, *contractions, precision=precision)
+ 
+ 
+ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172
+ 
+ 
+ # scale and translation here are scalar elements of an np.array, what is the
+ # correct type annotation?
+ def scale_and_translate(
+     image,
+     shape: core.Shape,
+     spatial_dims: Sequence[int],
+     scale,
+     translation,
+     # (barney-s) use string
+     method: str,  # (barney-s) | ResizeMethod,
+     antialias: bool = True,
+     precision=lax.Precision.HIGHEST):
+   """Apply a scale and translation to an image.
+ 
+   Generates a new image of shape 'shape' by resampling from the input image
+   using the sampling method corresponding to method. For 2D images, this
+   operation transforms a location in the input images, (x, y), to a location
+   in the output image according to::
+ 
+     (x * scale[1] + translation[1], y * scale[0] + translation[0])
+ 
+   (Note the *inverse* warp is used to generate the sample locations.)
+   Assumes half-centered pixels, i.e. the pixel at integer location
+   ``row, col`` has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly
+   for other input image dimensions.
+ 
+   If an output location (pixel) maps to an input sample location that is
+   outside the input boundaries, then the value for the output location will
+   be set to zero.
+ 
+   The ``method`` argument expects one of the following resize methods:
+ 
+   ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``,
+   ``"triangle"``
+     `Linear interpolation`_. If ``antialias`` is ``True``, uses a triangular
+     filter when downsampling.
+ 
+   ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
+     `Cubic interpolation`_, using the Keys cubic kernel.
+ 
+   ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
+     `Lanczos resampling`_, using a kernel of radius 3.
+ 
+   ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
+     `Lanczos resampling`_, using a kernel of radius 5.
+ 
+   .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
+   .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
+   .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling
+ 
+   Args:
+     image: a JAX array.
+     shape: the output shape, as a sequence of integers with length equal to the
+       number of dimensions of `image`.
+     spatial_dims: A length K tuple specifying the spatial dimensions that the
+       passed scale and translation should be applied to.
+     scale: A [K] array with the same number of dimensions as image, containing
+       the scale to apply in each dimension.
+     translation: A [K] array with the same number of dimensions as image,
+       containing the translation to apply in each dimension.
+     method: the resizing method to use; either a ``ResizeMethod`` instance or a
+       string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
+     antialias: Should an antialiasing filter be used when downsampling?
+       Defaults to ``True``. Has no effect when upsampling.
+ 
+   Returns:
+     The scaled and translated image.
+   """
+   shape = core.canonicalize_shape(shape)
+   if len(shape) != image.ndim:
+     msg = ('shape must have length equal to the number of dimensions of x; '
+            f' {shape} vs {image.shape}')
+     raise ValueError(msg)
+   if isinstance(method, str):
+     method = ResizeMethod.from_string(method)
+   if method == ResizeMethod.NEAREST:
+     # Nearest neighbor is currently special-cased for straight resize, so skip
+     # for now.
+     raise ValueError('Nearest neighbor resampling is not currently supported '
+                      'for scale_and_translate.')
+   assert isinstance(method, ResizeMethod)
+ 
+   kernel = _kernels[method]
+   image, = promote_dtypes_inexact(image)
+   scale, translation = promote_dtypes_inexact(scale, translation)
+   return _scale_and_translate(image, shape, spatial_dims, scale, translation,
+                               kernel, antialias, precision)
+ 
+ 
+ # END ----------------- END JAX code copied for fixing scale_and_translate -----------------------------
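
For orientation, the patched `scale_and_translate` above keeps the upstream JAX signature except that `method` is a plain string. A minimal smoke-test sketch (the shapes and the 2x scale are illustrative, not taken from the package's tests):

    import jax.numpy as jnp

    # Hypothetical example: upscale a 4x4 image (NCHW) by 2x per spatial dim.
    image = jnp.arange(16.0).reshape(1, 1, 4, 4)
    out = scale_and_translate(
        image,
        shape=(1, 1, 8, 8),              # output shape, same rank as `image`
        spatial_dims=(2, 3),             # scale/translate H and W only
        scale=jnp.array([2.0, 2.0]),
        translation=jnp.array([0.0, 0.0]),
        method="linear",                 # this patched version takes a string
    )
    assert out.shape == (1, 1, 8, 8)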
torchax/ops/jc10d.py ADDED
@@ -0,0 +1,66 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import torch
+ import jax
+ import jax.numpy as jnp
+ 
+ from torchax.ops import ops_registry
+ 
+ 
+ def op(*aten, **kwargs):
+ 
+   def inner(func):
+     for a in aten:
+       ops_registry.register_torch_dispatch_op(a, func, **kwargs)
+     return func
+ 
+   return inner
+ 
+ 
+ @op(torch.ops._c10d_functional.all_gather_into_tensor)
+ def _c10d_all_gather(input, group_size: int, group_name: str):
+   return jax.lax.all_gather(input, "torch_dist")
+ 
+ 
+ @op(torch.ops._c10d_functional.all_reduce)
+ def _c10d_all_reduce(self, reduceOp: str, group_name: str):
+   if reduceOp == "sum":
+     res = jax.lax.psum(self, axis_name="torch_dist")
+   elif reduceOp == "avg":
+     res = jax.lax.pmean(self, axis_name="torch_dist")
+   elif reduceOp == "min":
+     res = jax.lax.pmin(self, axis_name="torch_dist")
+   elif reduceOp == "max":
+     res = jax.lax.pmax(self, axis_name="torch_dist")
+   else:
+     raise RuntimeError(f"Reduce op {reduceOp} not implemented")
+   return res
+ 
+ 
+ @op(torch.ops._c10d_functional.broadcast)
+ def _c10d_broadcast(self, src: int, group_name: str):
+   masked = jnp.where(
+       jax.lax.axis_index("torch_dist") == src,
+       self,
+       jnp.zeros_like(self),
+   )
+   return jax.lax.psum(masked, "torch_dist")
+ 
+ 
+ @op(torch.ops._c10d_functional.wait_tensor)
+ def _c10d_wait_tensor(tensor):
+   # The async tensor has already been `wait`ed by the dispatcher.
+   return tensor
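
These lowerings only make sense when traced under a mapped axis named "torch_dist", which torchax's distributed runtime is expected to set up. Outside that runtime, their behavior can be sketched with a plain `pmap` over the same axis name (illustrative only, not part of the package):

    import jax
    import jax.numpy as jnp

    # What the "sum" branch of _c10d_all_reduce computes under an axis
    # named "torch_dist".
    def allreduce_sum(x):
      return jax.lax.psum(x, axis_name="torch_dist")

    n = jax.local_device_count()
    xs = jnp.arange(n, dtype=jnp.float32)  # one value per device
    out = jax.pmap(allreduce_sum, axis_name="torch_dist")(xs)
    # Every device now holds 0 + 1 + ... + (n - 1).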
torchax/ops/jimage.py ADDED
@@ -0,0 +1,127 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import jax
+ import jax.numpy as jnp
+ 
+ 
+ def cubic_kernel(x, a=-0.75):
+   """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)."""
+   absx = jnp.abs(x)
+   x2 = absx * absx
+   x3 = x2 * absx
+   cond1 = (absx <= 1)
+   cond2 = (absx > 1) & (absx < 2)
+   f1 = (a + 2) * x3 - (a + 3) * x2 + 1
+   f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
+   return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
+ 
+ 
+ def compute_contribs(in_size,
+                      out_size,
+                      scale,
+                      support=2.0,
+                      align_corners=False,
+                      dtype=None):
+   if align_corners:
+     if out_size == 1:
+       in_coords = jnp.zeros((1,), dtype=dtype)
+     else:
+       in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)
+   else:
+     out_coords = jnp.arange(out_size, dtype=dtype) + 0.5
+     in_coords = out_coords / scale - 0.5
+ 
+   left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
+   idxs = left_idx[:, None] + jnp.arange(4)
+ 
+   dx = in_coords[:, None] - idxs
+ 
+   weights = cubic_kernel(dx)
+ 
+   weights = weights / jnp.sum(weights, axis=1, keepdims=True)
+   return idxs, weights
+ 
+ 
+ def gather_weights(img, idxs, axis):
+   """Safely gather with boundary handling."""
+   idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
+   return jnp.take(img, idxs, axis=axis)
+ 
+ 
+ def interpolate_along_axis_bchw(img, idxs, weights, axis):
+   """Interpolate along H (axis=2) or W (axis=3) for a tensor (B, C, H, W).
+ 
+   idxs: (out_size, 4) int32 indices
+   weights: (out_size, 4) float32 weights
+   """
+   assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
+   out_size = idxs.shape[0]
+   k = idxs.shape[1]  # Typically 4 for cubic
+ 
+   # Clip to input bounds
+   idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)  # (out_size, 4)
+ 
+   def gather_and_weight(i):
+     idx = idxs[i]  # (4,)
+     w = weights[i]  # (4,)
+ 
+     def gather_one(offset):
+       return jnp.take(img, idx[offset], axis=axis)  # shape (B, C, H, W)
+ 
+     gathered = jnp.stack([gather_one(o) for o in range(k)],
+                          axis=0)  # (4, B, C, H, W)
+     weighted = jnp.tensordot(w, gathered, axes=(0, 0))  # (B, C, H, W)
+     return weighted
+ 
+   out = jax.vmap(gather_and_weight)(
+       jnp.arange(out_size))  # (out_size, B, C, H, W)
+ 
+   # Move the interpolated axis back into place
+   if axis == 2:  # interpolated over H
+     return jnp.moveaxis(out, 0, 2)  # (B, C, out_H, W)
+   else:  # axis == 3, interpolated over W
+     return jnp.moveaxis(out, 0, 3)  # (B, C, H, out_W)
+ 
+ 
+ def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
+   h, w = img.shape[-2:]
+   if align_corners and out_h > 1:
+     scale_y = (h - 1) / (out_h - 1)
+   else:
+     scale_y = out_h / h
+ 
+   if align_corners and out_w > 1:
+     scale_x = (w - 1) / (out_w - 1)
+   else:
+     scale_x = out_w / w
+ 
+   idxs_y, weights_y = compute_contribs(
+       h,
+       out_h,
+       scale_y,
+       align_corners=align_corners,
+       dtype=img.dtype,
+   )
+   tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
+ 
+   idxs_x, weights_x = compute_contribs(
+       w,
+       out_w,
+       scale_x,
+       align_corners=align_corners,
+       dtype=img.dtype,
+   )
+   out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
+   return out
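
A shape-level sketch of how `interpolate_bicubic_no_aa` is driven on NCHW tensors, mirroring PyTorch's non-antialiased bicubic interpolation (the concrete sizes below are illustrative):

    import jax.numpy as jnp

    img = jnp.ones((2, 3, 16, 16), dtype=jnp.float32)  # (B, C, H, W)
    out = interpolate_bicubic_no_aa(img, out_h=32, out_w=24, align_corners=False)
    assert out.shape == (2, 3, 32, 24)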
torchax/ops/jlibrary.py ADDED
@@ -0,0 +1,94 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ """The `jlibrary` module provides helpers that preserve torch.library ops
+ during export. This includes aten ops and custom operations.
+ """
+ 
+ import torch
+ import torch.nn as nn
+ import torchax
+ from torchax.ops import jaten
+ import jax
+ import functools
+ 
+ 
+ def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
+   """Wrap a jaxpr in a jitted function with the proper composite name.
+ 
+   TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op.
+   """
+ 
+   def composite_impl(*args):
+     return jaxpr_impl(*args)
+ 
+   composite_impl.__name__ = composite_name
+   composite_impl.__qualname__ = composite_name
+   return jax.jit(composite_impl, **jit_args)
+ 
+ 
+ def register_jax_composite(composite_name, impl, *ops, **jit_args):
+   """Register a composite using a JAX implementation.
+ 
+   composite_name - The name of the library op to use in the exported composite.
+   impl - A JAX lowering for the library operation.
+   *ops - Variadic torch.ops to lower using `impl`.
+   **jit_args - Additional parameters to forward to JAX jit.
+ 
+   This is used to register custom lowerings with an explicit jaxpr
+   implementation, such as preserving a specific aten op using a jaten impl.
+ 
+   For custom torch op registration with a decomposition written in torch,
+   use `register_torch_composite`.
+ 
+   For jit params and troubleshooting see:
+   https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
+   """
+ 
+   @jaten.op(*ops)
+   def _composite_impl(*args):
+     return _jit_composite_impl(composite_name, impl, **jit_args)(*args)
+ 
+ 
+ def register_torch_composite(composite_name, impl, *ops, **jit_args):
+   """Register a torch decomposition as a composite.
+ 
+   This is useful for registering custom torch op libraries as composite ops.
+ 
+   The `impl` can be the `@impl` used to define the torch custom library op.
+   This must be a function or module impl that provides the decompositions, and
+   not an instance of the custom op.
+ 
+   TODO: Better error handling, or can we make this an instance of the op as a
+   param?
+ 
+   For jit params and troubleshooting see:
+   https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
+   """
+ 
+   @jaten.op(*ops)
+   def _composite_impl(*args):
+ 
+     class ImplWrapper(torch.nn.Module):
+ 
+       def __init__(self):
+         super().__init__()
+ 
+       def forward(self, *args):
+         return impl(*args)
+ 
+     # Note: avoid refactoring to share code with register_jax_composite.
+     # The `extract_jax` call must live in the `@jaten.op` handler. If called
+     # outside of the handler, we would build the jaxpr representation of the
+     # module once during registration, potentially missing op registrations
+     # that come after, i.e. we may miss nested abstractions if we build the
+     # jaxpr ahead of time.
+     state, jfn = torchax.extract_jax(ImplWrapper())
+     jaxpr_impl = lambda *args: jfn(state, tuple(args))
+     return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args)
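
To make the registration flow concrete, here is a hypothetical use of `register_torch_composite`; the `mylib::gelu_tanh` op, its decomposition, and the composite name are all invented for illustration and do not exist in this package:

    import torch

    # Hypothetical torch decomposition for a custom op (tanh-approximated GELU).
    def gelu_tanh_impl(x):
      return 0.5 * x * (1.0 + torch.tanh(
          0.7978845608 * (x + 0.044715 * x * x * x)))

    # Preserve the op as a named composite when lowering through torchax.
    # `torch.ops.mylib.gelu_tanh` stands in for a previously registered
    # torch.library custom op.
    register_torch_composite(
        "mylib.gelu_tanh",   # composite name in the exported program
        gelu_tanh_impl,      # the decomposition function, not an op instance
        torch.ops.mylib.gelu_tanh,
    )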