torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl


@@ -1,4 +1,3 @@
-
 from collections.abc import Sequence
 from jax._src.numpy.util import promote_dtypes_inexact
 import numpy as np
@@ -15,12 +14,9 @@ from typing import Callable
 
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52
 
-def compute_weight_mat(input_size: core.DimSize,
-                       output_size: core.DimSize,
-                       scale,
-                       translation,
-                       kernel: Callable,
-                       antialias: bool):
+
+def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
+                       scale, translation, kernel: Callable, antialias: bool):
   dtype = jnp.result_type(scale, translation)
   inv_scale = 1. / scale
   # When downsampling the kernel should be scaled since we want to low pass
@@ -38,8 +34,8 @@ def compute_weight_mat(input_size: core.DimSize,
   total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
   weights = jnp.where(
       jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
-      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
-      0)
+      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum,
+                                    1)), 0)
   # Zero out weights where the sample location is completely outside the input
   # range.
   # Note sample_f has already had the 0.5 removed, hence the weird range below.
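
The normalization in the hunk above uses JAX's double-`where` idiom: the inner `jnp.where` swaps a zero denominator for 1 so the division never produces inf/NaN (which would also poison gradients through `where`), and the outer `jnp.where` masks those entries back to zero. A minimal self-contained illustration of the pattern (not part of the diff):

import jax.numpy as jnp

w = jnp.array([[0.2, 0.0], [0.8, 0.0]])
s = jnp.sum(w, axis=0, keepdims=True)               # [[1.0, 0.0]]
safe = jnp.where(s != 0, s, 1)                      # swap 0 denominators for 1
normalized = jnp.where(s != 0, jnp.divide(w, safe), 0.0)
# -> [[0.2, 0.0], [0.8, 0.0]], with no NaN from the 0/0 column
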
@@ -48,12 +44,14 @@ def compute_weight_mat(input_size: core.DimSize,
     return weights
   input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
   return jnp.where(
-      jnp.logical_and(sample_f >= -0.5,
-                      sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
+      jnp.logical_and(sample_f >= -0.5, sample_f
+                      <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
 # (barney-s) -------------- END returning weights without zeroing ---------------------
 
+
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86
 
+
 def _scale_and_translate(x, output_shape: core.Shape,
                          spatial_dims: Sequence[int], scale, translation,
                          kernel, antialias: bool, precision):
@@ -70,8 +68,8 @@ def _scale_and_translate(x, output_shape: core.Shape,
     d = canonicalize_axis(d, x.ndim)
     m = input_shape[d]
     n = output_shape[d]
-    w = compute_weight_mat(m, n, scale[i], translation[i],
-                           kernel, antialias).astype(x.dtype)
+    w = compute_weight_mat(m, n, scale[i], translation[i], kernel,
+                           antialias).astype(x.dtype)
     contractions.append(w)
     contractions.append([d, len(output_shape) + i])
     out_indices[d] = len(output_shape) + i
@@ -81,15 +79,19 @@ def _scale_and_translate(x, output_shape: core.Shape,
 
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172
 
+
 # scale and translation here are scalar elements of an np.array, what is the
 # correct type annotation?
-def scale_and_translate(image, shape: core.Shape,
-                        spatial_dims: Sequence[int],
-                        scale, translation,
-                        # (barney-s) use string
-                        method: str, #(barney-s) | ResizeMethod,
-                        antialias: bool = True,
-                        precision=lax.Precision.HIGHEST):
+def scale_and_translate(
+    image,
+    shape: core.Shape,
+    spatial_dims: Sequence[int],
+    scale,
+    translation,
+    # (barney-s) use string
+    method: str,  #(barney-s) | ResizeMethod,
+    antialias: bool = True,
+    precision=lax.Precision.HIGHEST):
   """Apply a scale and translation to an image.
 
   Generates a new image of shape 'shape' by resampling from the input image
@@ -165,5 +167,5 @@ def scale_and_translate(image, shape: core.Shape,
   return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                               kernel, antialias, precision)
 
-# END ----------------- END JAX code copied for testing -----------------------------
 
+# END ----------------- END JAX code copied for testing -----------------------------
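
Every hunk in this file is a formatting-only reflow (yapf-style line wrapping) of the JAX `scale_and_translate` code copied for testing; behavior is unchanged. For orientation, here is a minimal usage sketch of the equivalent public API, `jax.image.scale_and_translate`, from which this code was copied (shapes and the "linear" kernel are illustrative choices, not from the diff):

import jax.numpy as jnp
from jax import image as jax_image

img = jnp.arange(16.0).reshape(1, 1, 4, 4)   # (B, C, H, W)
out = jax_image.scale_and_translate(
    img,
    shape=(1, 1, 8, 8),                      # target shape
    spatial_dims=(2, 3),                     # resample H and W only
    scale=jnp.array([2.0, 2.0]),
    translation=jnp.array([0.0, 0.0]),
    method="linear",
    antialias=True)
print(out.shape)                             # (1, 1, 8, 8)
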
torchax/ops/jc10d.py CHANGED
@@ -6,6 +6,7 @@ from torchax.ops import ops_registry
 
 
 def op(*aten, **kwargs):
+
   def inner(func):
     for a in aten:
       ops_registry.register_torch_dispatch_op(a, func, **kwargs)
@@ -21,7 +22,7 @@ def _c10d_all_gather(input, group_size: int, group_name: str):
 
 @op(torch.ops._c10d_functional.all_reduce)
 def _c10d_all_reduce(self, reduceOp: str, group_name: str):
-
+
   if reduceOp == "sum":
     res = jax.lax.psum(self, axis_name="torch_dist")
   elif reduceOp == "avg":
@@ -38,9 +39,9 @@ def _c10d_all_reduce(self, reduceOp: str, group_name: str):
 @op(torch.ops._c10d_functional.broadcast)
 def _c10d_broadcast(self, src: int, group_name: str):
   masked = jnp.where(
-      jax.lax.axis_index("torch_dist") == src,
-      self,
-      jnp.zeros_like(self),
+      jax.lax.axis_index("torch_dist") == src,
+      self,
+      jnp.zeros_like(self),
   )
   return jax.lax.psum(masked, "torch_dist")
 
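The jc10d.py hunks are whitespace-only; the collectives themselves are implemented with JAX primitives over a mapped axis named "torch_dist". A minimal sketch of the broadcast-via-psum trick from the last hunk, run under `jax.pmap` (the axis name "i" and the pmap harness are illustrative assumptions; torchax wires the axis up through its own distributed runtime):

import jax
import jax.numpy as jnp

def broadcast_from_0(x):
  # Zero out every shard except the source, then sum across the axis:
  # each device ends up holding the source device's value.
  masked = jnp.where(jax.lax.axis_index("i") == 0, x, jnp.zeros_like(x))
  return jax.lax.psum(masked, "i")

xs = jnp.arange(float(jax.local_device_count())) + 5.0
print(jax.pmap(broadcast_from_0, axis_name="i")(xs))  # all entries == 5.0
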
torchax/ops/jimage.py ADDED
@@ -0,0 +1,113 @@
+import jax
+import jax.numpy as jnp
+
+
+def cubic_kernel(x, a=-0.75):
+  """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
+  absx = jnp.abs(x)
+  x2 = absx * absx
+  x3 = x2 * absx
+  cond1 = (absx <= 1)
+  cond2 = (absx > 1) & (absx < 2)
+  f1 = (a + 2) * x3 - (a + 3) * x2 + 1
+  f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
+  return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
+
+
+def compute_contribs(in_size,
+                     out_size,
+                     scale,
+                     support=2.0,
+                     align_corners=False,
+                     dtype=None):
+  if align_corners:
+    if out_size == 1:
+      in_coords = jnp.zeros((1,), dtype=dtype)
+    else:
+      in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)
+  else:
+    out_coords = jnp.arange(out_size, dtype=dtype) + 0.5
+    in_coords = out_coords / scale - 0.5
+
+  left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
+  idxs = left_idx[:, None] + jnp.arange(4)
+
+  dx = in_coords[:, None] - idxs
+
+  weights = cubic_kernel(dx)
+
+  weights = weights / jnp.sum(weights, axis=1, keepdims=True)
+  return idxs, weights
+
+
+def gather_weights(img, idxs, axis):
+  """Safely gather with boundary handling"""
+  idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
+  return jnp.take(img, idxs, axis=axis)
+
+
+def interpolate_along_axis_bchw(img, idxs, weights, axis):
+  """
+  Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
+  idxs: (out_size, 4) int32 indices
+  weights: (out_size, 4) float32 weights
+  """
+  assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
+  out_size = idxs.shape[0]
+  k = idxs.shape[1]  # Typically 4 for cubic
+
+  # Clip to input bounds
+  idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)  # (out_size, 4)
+
+  def gather_and_weight(i):
+    idx = idxs[i]  # (4,)
+    w = weights[i]  # (4,)
+
+    def gather_one(offset):
+      return jnp.take(img, idx[offset], axis=axis)  # shape (B, C, H, W)
+
+    gathered = jnp.stack([gather_one(o) for o in range(k)],
+                         axis=0)  # (4, B, C, H, W)
+    weighted = jnp.tensordot(w, gathered, axes=(0, 0))  # (B, C, H, W)
+    return weighted
+
+  out = jax.vmap(gather_and_weight)(
+      jnp.arange(out_size))  # (out_size, B, C, H, W)
+
+  # Move the interpolated axis back into place
+  if axis == 2:  # interpolated over H
+    return jnp.moveaxis(out, 0, 2)  # (B, C, out_H, W)
+  else:  # axis == 3, interpolated over W
+    return jnp.moveaxis(out, 0, 3)  # (B, C, H, out_W)
+
+
+def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
+  h, w = img.shape[-2:]
+  if align_corners and out_h > 1:
+    scale_y = (h - 1) / (out_h - 1)
+  else:
+    scale_y = out_h / h
+
+  if align_corners and out_w > 1:
+    scale_x = (w - 1) / (out_w - 1)
+  else:
+    scale_x = out_w / w
+
+  idxs_y, weights_y = compute_contribs(
+      h,
+      out_h,
+      scale_y,
+      align_corners=align_corners,
+      dtype=img.dtype,
+  )
+  tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
+
+  idxs_x, weights_x = compute_contribs(
+      w,
+      out_w,
+      scale_x,
+      align_corners=align_corners,
+      dtype=img.dtype,
+  )
+  out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
+  return out
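
jimage.py is a new file: a JAX implementation of bicubic interpolation without antialiasing, using the Keys cubic kernel with a = -0.75 to match PyTorch's `interpolate(..., mode="bicubic", antialias=False)`. A usage sketch, assuming the module is importable as `torchax.ops.jimage` (the input values are illustrative):

import jax.numpy as jnp
from torchax.ops import jimage

img = jnp.arange(64.0).reshape(1, 1, 8, 8)   # (B, C, H, W), float
out = jimage.interpolate_bicubic_no_aa(img, out_h=16, out_w=16,
                                       align_corners=False)
print(out.shape)                             # (1, 1, 16, 16)
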
torchax/ops/jlibrary.py CHANGED
@@ -14,19 +14,22 @@ def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
   """Wrap a jaxpr in a jitted function with the proper composite name
   TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op.
   """
+
   def composite_impl(*args):
     return jaxpr_impl(*args)
+
   composite_impl.__name__ = composite_name
   composite_impl.__qualname__ = composite_name
   return jax.jit(composite_impl, **jit_args)
 
+
 def register_jax_composite(composite_name, impl, *ops, **jit_args):
   """Register a composite using a JAX implementation.
   composite_name - The name of the library op to use in the exported composite
   impl - A JAX lowering for the library operation
   *ops - Variadic torch.ops to lower using `impl`.
   **jit_args - Additional parameters to forward to JAX jit.
-
+
   This is used to register custom lowerings with an explicit jaxpr
   implementation, such as preserving a specific aten op using a jaten impl.
 
@@ -36,10 +39,12 @@ def register_jax_composite(composite_name, impl, *ops, **jit_args):
   For jit params and troubleshooting see:
   https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
   """
+
   @jaten.op(*ops)
   def _composite_impl(*args):
     return _jit_composite_impl(composite_name, impl, **jit_args)(*args)
 
+
 def register_torch_composite(composite_name, impl, *ops, **jit_args):
   """Register a torch decomposition as a composite.
   This is useful for registerring custom torch op libraries as composite ops.
@@ -53,10 +58,12 @@ def register_torch_composite(composite_name, impl, *ops, **jit_args):
   For jit params and troubleshooting see:
   https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
   """
-
+
   @jaten.op(*ops)
   def _composite_impl(*args):
+
     class ImplWrapper(torch.nn.Module):
+
       def __init__(self):
         super().__init__()
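
The jlibrary.py hunks only add blank lines around the composite-registration helpers. For context, a hedged sketch of how `register_torch_composite` might be called, per its docstring; the op, implementation, and composite name below are hypothetical, not from this diff:

import torch
from torchax.ops import jlibrary

def my_gelu(x):
  # Hypothetical torch-level decomposition to preserve as one composite op.
  return torch.nn.functional.gelu(x)

# Lowers uses of aten.gelu through the decomposition above so they can be
# exported as a named StableHLO composite.
jlibrary.register_torch_composite(
    "mylib.gelu",         # composite_name (hypothetical)
    my_gelu,              # impl: the torch decomposition
    torch.ops.aten.gelu,  # *ops to capture
)
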