torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202617__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.


@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Sequence
-from jax._src.numpy.util import promote_dtypes_inexact
+from collections.abc import Callable, Sequence
+
 import numpy as np
-import jax
+from jax import lax
 from jax import numpy as jnp
-from jax._src.util import canonicalize_axis
 from jax._src import core
-from jax._src.image.scale import _kernels, ResizeMethod
-from jax import lax
-from typing import Callable
+from jax._src.image.scale import ResizeMethod, _kernels
+from jax._src.numpy.util import promote_dtypes_inexact
+from jax._src.util import canonicalize_axis
 
 # TODO: This block of code needs to be revisited based on https://github.com/jax-ml/jax/issues/24106
 # START ----------------- JAX code copied for fixing scale_and_translate -----------------------------
@@ -29,27 +28,39 @@ from typing import Callable
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52
 
 
-def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
-                       scale, translation, kernel: Callable, antialias: bool):
+def compute_weight_mat(
+  input_size: core.DimSize,
+  output_size: core.DimSize,
+  scale,
+  translation,
+  kernel: Callable,
+  antialias: bool,
+):
   dtype = jnp.result_type(scale, translation)
-  inv_scale = 1. / scale
+  inv_scale = 1.0 / scale
   # When downsampling the kernel should be scaled since we want to low pass
   # filter and interpolate, but when upsampling it should not be since we only
   # want to interpolate.
-  kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
-  sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
-              translation * inv_scale - 0.5)
+  kernel_scale = jnp.maximum(inv_scale, 1.0) if antialias else 1.0
+  sample_f = (
+    (jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale
+    - translation * inv_scale
+    - 0.5
+  )
   x = (
-      jnp.abs(sample_f[jnp.newaxis, :] -
-              jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
-      kernel_scale)
+    jnp.abs(
+      sample_f[jnp.newaxis, :] - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]
+    )
+    / kernel_scale
+  )
   weights = kernel(x)
 
   total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
   weights = jnp.where(
-      jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
-      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum,
-                 1)), 0)
+    jnp.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps),
+    jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
+    0,
+  )
   # Zero out weights where the sample location is completely outside the input
   # range.
   # Note sample_f has already had the 0.5 removed, hence the weird range below.
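A note for readers skimming the hunk above: `compute_weight_mat` builds a dense (input_size, output_size) matrix of normalized kernel weights, so resampling along one axis reduces to a single matrix product. A tiny self-contained sketch of the same idea (the linear `triangle` kernel, sizes, and signal are illustrative, not from the diff):

import jax.numpy as jnp

def triangle(x):
  # Linear-interpolation kernel: 1 - |x|, clamped at zero.
  return jnp.maximum(0.0, 1.0 - jnp.abs(x))

in_size, out_size, inv_scale = 4, 8, 0.5
sample_f = (jnp.arange(out_size) + 0.5) * inv_scale - 0.5
x = jnp.abs(sample_f[jnp.newaxis, :] - jnp.arange(in_size)[:, jnp.newaxis])
w = triangle(x)                       # (in_size, out_size) weight matrix
w = w / w.sum(axis=0, keepdims=True)  # normalize columns, as above
signal = jnp.arange(4.0)
upsampled = signal @ w                # 2x upsample via one matmul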
@@ -58,17 +69,26 @@ def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
     return weights
   input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
   return jnp.where(
-      jnp.logical_and(sample_f >= -0.5, sample_f
-                      <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
+    jnp.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[jnp.newaxis, :],
+    weights,
+    0,
+  )
   # (barney-s) -------------- END returning weights without zeroing ---------------------
 
 
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86
 
 
-def _scale_and_translate(x, output_shape: core.Shape,
-                         spatial_dims: Sequence[int], scale, translation,
-                         kernel, antialias: bool, precision):
+def _scale_and_translate(
+  x,
+  output_shape: core.Shape,
+  spatial_dims: Sequence[int],
+  scale,
+  translation,
+  kernel,
+  antialias: bool,
+  precision,
+):
   input_shape = x.shape
   assert len(input_shape) == len(output_shape)
   assert len(spatial_dims) == len(scale)
@@ -82,8 +102,9 @@ def _scale_and_translate(x, output_shape: core.Shape,
     d = canonicalize_axis(d, x.ndim)
     m = input_shape[d]
     n = output_shape[d]
-    w = compute_weight_mat(m, n, scale[i], translation[i], kernel,
-                           antialias).astype(x.dtype)
+    w = compute_weight_mat(m, n, scale[i], translation[i], kernel, antialias).astype(
+      x.dtype
+    )
     contractions.append(w)
     contractions.append([d, len(output_shape) + i])
     out_indices[d] = len(output_shape) + i
@@ -97,15 +118,16 @@ def _scale_and_translate(x, output_shape: core.Shape,
 # scale and translation here are scalar elements of an np.array, what is the
 # correct type annotation?
 def scale_and_translate(
-    image,
-    shape: core.Shape,
-    spatial_dims: Sequence[int],
-    scale,
-    translation,
-    # (barney-s) use string
-    method: str, #(barney-s) | ResizeMethod,
-    antialias: bool = True,
-    precision=lax.Precision.HIGHEST):
+  image,
+  shape: core.Shape,
+  spatial_dims: Sequence[int],
+  scale,
+  translation,
+  # (barney-s) use string
+  method: str,  # (barney-s) | ResizeMethod,
+  antialias: bool = True,
+  precision=lax.Precision.HIGHEST,
+):
   """Apply a scale and translation to an image.
 
   Generates a new image of shape 'shape' by resampling from the input image
@@ -163,23 +185,27 @@ def scale_and_translate(
   """
   shape = core.canonicalize_shape(shape)
   if len(shape) != image.ndim:
-    msg = ('shape must have length equal to the number of dimensions of x; '
-           f' {shape} vs {image.shape}')
+    msg = (
+      "shape must have length equal to the number of dimensions of x; "
+      f" {shape} vs {image.shape}"
+    )
     raise ValueError(msg)
   if isinstance(method, str):
     method = ResizeMethod.from_string(method)
   if method == ResizeMethod.NEAREST:
     # Nearest neighbor is currently special-cased for straight resize, so skip
     # for now.
-    raise ValueError('Nearest neighbor resampling is not currently supported '
-                     'for scale_and_translate.')
+    raise ValueError(
+      "Nearest neighbor resampling is not currently supported for scale_and_translate."
+    )
   assert isinstance(method, ResizeMethod)
 
   kernel = _kernels[method]
-  image, = promote_dtypes_inexact(image)
+  (image,) = promote_dtypes_inexact(image)
   scale, translation = promote_dtypes_inexact(scale, translation)
-  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
-                              kernel, antialias, precision)
+  return _scale_and_translate(
+    image, shape, spatial_dims, scale, translation, kernel, antialias, precision
+  )
 
 
 # END ----------------- END JAX code copied for testing -----------------------------
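The block above is a vendored copy of upstream `jax.image.scale_and_translate` (see the JAX links in the comments), with small local changes around weight zeroing and accepting `method` as a string. For orientation, a minimal sketch of the call shape against the public JAX API; the shapes, scale, and method here are illustrative:

import jax.numpy as jnp
from jax import image as jax_image

img = jnp.arange(16.0).reshape(1, 1, 4, 4)  # (B, C, H, W)
out = jax_image.scale_and_translate(
  img,
  shape=(1, 1, 8, 8),            # 2x upsample of H and W
  spatial_dims=(2, 3),
  scale=jnp.array([2.0, 2.0]),
  translation=jnp.array([0.0, 0.0]),
  method="linear",
  antialias=True,
)
print(out.shape)  # (1, 1, 8, 8)

Note that, per the hunk above, `method="nearest"` raises a ValueError in this code path.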
torchax/ops/jc10d.py CHANGED
@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 import jax
 import jax.numpy as jnp
+import torch
 
 from torchax.ops import ops_registry
 
 
 def op(*aten, **kwargs):
-
   def inner(func):
     for a in aten:
       ops_registry.register_torch_dispatch_op(a, func, **kwargs)
@@ -36,7 +35,6 @@ def _c10d_all_gather(input, group_size: int, group_name: str):
 
 @op(torch.ops._c10d_functional.all_reduce)
 def _c10d_all_reduce(self, reduceOp: str, group_name: str):
-
   if reduceOp == "sum":
     res = jax.lax.psum(self, axis_name="torch_dist")
   elif reduceOp == "avg":
@@ -53,9 +51,9 @@ def _c10d_all_reduce(self, reduceOp: str, group_name: str):
 @op(torch.ops._c10d_functional.broadcast)
 def _c10d_broadcast(self, src: int, group_name: str):
   masked = jnp.where(
-      jax.lax.axis_index("torch_dist") == src,
-      self,
-      jnp.zeros_like(self),
+    jax.lax.axis_index("torch_dist") == src,
+    self,
+    jnp.zeros_like(self),
   )
   return jax.lax.psum(masked, "torch_dist")
 
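The broadcast above uses a masking trick: every participant zeroes its value except rank `src`, and the following `psum` over the `torch_dist` axis then leaves the source's values on all ranks. A standalone JAX sketch of the same pattern; the `pmap` harness and shapes are illustrative, not how torchax wires it up:

import jax
import jax.numpy as jnp

def broadcast_from(x, src):
  # Keep only the source rank's value, then sum across the axis.
  masked = jnp.where(jax.lax.axis_index("torch_dist") == src, x, jnp.zeros_like(x))
  return jax.lax.psum(masked, "torch_dist")

n = jax.local_device_count()
xs = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
out = jax.pmap(lambda x: broadcast_from(x, 0), axis_name="torch_dist")(xs)
# Every row of `out` now equals xs[0].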
torchax/ops/jimage.py CHANGED
@@ -21,19 +21,16 @@ def cubic_kernel(x, a=-0.75):
   absx = jnp.abs(x)
   x2 = absx * absx
   x3 = x2 * absx
-  cond1 = (absx <= 1)
+  cond1 = absx <= 1
   cond2 = (absx > 1) & (absx < 2)
   f1 = (a + 2) * x3 - (a + 3) * x2 + 1
   f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
   return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
 
 
-def compute_contribs(in_size,
-                     out_size,
-                     scale,
-                     support=2.0,
-                     align_corners=False,
-                     dtype=None):
+def compute_contribs(
+  in_size, out_size, scale, support=2.0, align_corners=False, dtype=None
+):
   if align_corners:
     if out_size == 1:
       in_coords = jnp.zeros((1,), dtype=dtype)
@@ -62,10 +59,10 @@ def gather_weights(img, idxs, axis):
 
 def interpolate_along_axis_bchw(img, idxs, weights, axis):
   """
-    Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
-    idxs: (out_size, 4) int32 indices
-    weights: (out_size, 4) float32 weights
-    """
+  Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
+  idxs: (out_size, 4) int32 indices
+  weights: (out_size, 4) float32 weights
+  """
   assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
   out_size = idxs.shape[0]
   k = idxs.shape[1]  # Typically 4 for cubic
@@ -80,13 +77,11 @@ def interpolate_along_axis_bchw(img, idxs, weights, axis):
     def gather_one(offset):
       return jnp.take(img, idx[offset], axis=axis)  # shape (B, C, H, W)
 
-    gathered = jnp.stack([gather_one(o) for o in range(k)],
-                         axis=0)  # (4, B, C, H, W)
+    gathered = jnp.stack([gather_one(o) for o in range(k)], axis=0)  # (4, B, C, H, W)
     weighted = jnp.tensordot(w, gathered, axes=(0, 0))  # (B, C, H, W)
     return weighted
 
-  out = jax.vmap(gather_and_weight)(
-      jnp.arange(out_size))  # (out_size, B, C, H, W)
+  out = jax.vmap(gather_and_weight)(jnp.arange(out_size))  # (out_size, B, C, H, W)
 
   # Move the interpolated axis back into place
   if axis == 2:  # interpolated over H
@@ -108,20 +103,20 @@ def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
   scale_x = out_w / w
 
   idxs_y, weights_y = compute_contribs(
-      h,
-      out_h,
-      scale_y,
-      align_corners=align_corners,
-      dtype=img.dtype,
+    h,
+    out_h,
+    scale_y,
+    align_corners=align_corners,
+    dtype=img.dtype,
   )
   tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
 
   idxs_x, weights_x = compute_contribs(
-      w,
-      out_w,
-      scale_x,
-      align_corners=align_corners,
-      dtype=img.dtype,
+    w,
+    out_w,
+    scale_x,
+    align_corners=align_corners,
+    dtype=img.dtype,
   )
   out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
   return out
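Taken together: `compute_contribs` computes, per output pixel, the source indices and cubic weights along one axis, and `interpolate_along_axis_bchw` applies them separably over H and then W. A hedged usage sketch, assuming the patched module is importable as `torchax.ops.jimage`:

import jax.numpy as jnp
from torchax.ops import jimage  # assumed import path, matching this diff

img = jnp.ones((1, 3, 32, 32), dtype=jnp.float32)  # (B, C, H, W)
out = jimage.interpolate_bicubic_no_aa(img, out_h=64, out_w=48, align_corners=False)
assert out.shape == (1, 3, 64, 48)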
torchax/ops/jlibrary.py CHANGED
@@ -16,12 +16,11 @@
 during export. This includes aten ops, and custom operations.
 """
 
+import jax
 import torch
-import torch.nn as nn
+
 import torchax
 from torchax.ops import jaten
-import jax
-import functools
 
 
 def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
@@ -75,9 +74,7 @@ def register_torch_composite(composite_name, impl, *ops, **jit_args):
 
   @jaten.op(*ops)
   def _composite_impl(*args):
-
     class ImplWrapper(torch.nn.Module):
-
       def __init__(self):
         super().__init__()
 
@@ -90,5 +87,8 @@ def register_torch_composite(composite_name, impl, *ops, **jit_args):
     # module once during registration, potentially missing op registrations that
     # come after. I.e. may miss nested abstractions if we build jaxpr AoT.
     state, jfn = torchax.extract_jax(ImplWrapper())
-    jaxpr_impl = lambda *args: jfn(state, tuple([*args]))
+
+    def jaxpr_impl(*args):
+      return jfn(state, (*args,))
+
     return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args)
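The last hunk replaces an assigned lambda with a named `def`; behavior is unchanged, but the named form satisfies lint rules against lambda assignment (E731) and reads better in tracebacks. For context, a hedged sketch of how `register_torch_composite` might be invoked, based only on the signature visible above; the composite name, module, and op are illustrative placeholders:

import torch
from torchax.ops import jlibrary

class MyGelu(torch.nn.Module):  # hypothetical implementation module
  def forward(self, x):
    return torch.nn.functional.gelu(x)

# Register the listed aten op(s) so they export as a single composite
# backed by MyGelu's jaxpr.
jlibrary.register_torch_composite(
  "mylib.gelu",         # composite_name (placeholder)
  MyGelu(),             # impl
  torch.ops.aten.gelu,  # *ops to intercept
)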