torchax 0.0.10.dev20251118__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,268 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
17
+ """
18
+
19
+ import jax.numpy as jnp
20
+ import torch
21
+ from jax import lax
22
+
23
+ from . import ops_registry
24
+
25
+ _NMS_TILE_SIZE = 256
26
+
27
+
28
+ def _bbox_overlap(boxes, gt_boxes):
29
+ """Find Bounding box overlap.
30
+
31
+ Args:
32
+ boxes: first set of bounding boxes
33
+ gt_boxes: second set of boxes against which to compute IOU
34
+
35
+ Returns:
36
+ iou: Intersection over union matrix of all input bounding boxes
37
+ """
38
+ bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split(
39
+ ary=boxes, indices_or_sections=4, axis=2
40
+ )
41
+ gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split(
42
+ ary=gt_boxes, indices_or_sections=4, axis=2
43
+ )
44
+
45
+ # Calculates the intersection area.
46
+ i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1]))
47
+ i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1]))
48
+ i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1]))
49
+ i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1]))
50
+ i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0)
51
+
52
+ # Calculates the union area.
53
+ bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
54
+ gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
55
+ # Adds a small epsilon to avoid divide-by-zero.
56
+ u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8
57
+
58
+ # Calculates IoU.
59
+ iou = i_area / u_area
60
+
61
+ return iou
62
+
63
+
64
+ def _self_suppression(in_args):
65
+ iou, _, iou_sum = in_args
66
+ batch_size = iou.shape[0]
67
+ can_suppress_others = jnp.reshape(jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(
68
+ iou.dtype
69
+ )
70
+ iou_suppressed = (
71
+ jnp.reshape(
72
+ (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype),
73
+ [batch_size, -1, 1],
74
+ )
75
+ * iou
76
+ )
77
+ iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
78
+ return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new
79
+
80
+
81
+ def _cross_suppression(in_args):
82
+ boxes, box_slice, iou_threshold, inner_idx = in_args
83
+ batch_size = boxes.shape[0]
84
+ new_slice = lax.dynamic_slice(
85
+ boxes, [0, inner_idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4]
86
+ )
87
+ iou = _bbox_overlap(new_slice, box_slice)
88
+ ret_slice = (
89
+ jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2)
90
+ * box_slice
91
+ )
92
+ return boxes, ret_slice, iou_threshold, inner_idx + 1
93
+
94
+
95
+ def _suppression_loop_body(in_args):
96
+ """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).
97
+
98
+ Args:
99
+ in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx
100
+
101
+ Returns:
102
+ boxes: updated boxes.
103
+ iou_threshold: pass down iou_threshold to the next iteration.
104
+ output_size: the updated output_size.
105
+ idx: the updated induction variable.
106
+ """
107
+ boxes, iou_threshold, output_size, idx = in_args
108
+ num_tiles = boxes.shape[1] // _NMS_TILE_SIZE
109
+ batch_size = boxes.shape[0]
110
+
111
+ # Iterates over tiles that can possibly suppress the current tile.
112
+ box_slice = lax.dynamic_slice(
113
+ boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4]
114
+ )
115
+
116
+ def _loop_cond(in_args):
117
+ _, _, _, inner_idx = in_args
118
+ return inner_idx < idx
119
+
120
+ _, box_slice, _, _ = lax.while_loop(
121
+ _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0)
122
+ )
123
+
124
+ # Iterates over the current tile to compute self-suppression.
125
+ iou = _bbox_overlap(box_slice, box_slice)
126
+ mask = jnp.expand_dims(
127
+ jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
128
+ > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]),
129
+ 0,
130
+ )
131
+ iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)
132
+
133
+ def _loop_cond2(in_args):
134
+ _, loop_condition, _ = in_args
135
+ return loop_condition
136
+
137
+ suppressed_iou, _, _ = lax.while_loop(
138
+ _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2]))
139
+ )
140
+ suppressed_box = jnp.sum(suppressed_iou, 1) > 0
141
+ box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)
142
+
143
+ # Uses box_slice to update the input boxes.
144
+ mask = jnp.reshape(
145
+ (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), [1, -1, 1, 1]
146
+ )
147
+ boxes = jnp.tile(
148
+ jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1]
149
+ ) * mask + jnp.reshape(boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
150
+ boxes = jnp.reshape(boxes, [batch_size, -1, 4])
151
+
152
+ # Updates output_size.
153
+ output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
154
+ return boxes, iou_threshold, output_size, idx + 1
155
+
156
+
157
+ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
158
+ """A wrapper that handles non-maximum suppression.
159
+
160
+ Assumptions:
161
+ * The boxes are sorted by scores unless the box is a dot (all coordinates
162
+ are zero).
163
+ * Boxes with higher scores can be used to suppress boxes with lower scores.
164
+
165
+ The overall design of the algorithm is to handle boxes tile-by-tile:
166
+
167
+ boxes = boxes.pad_to_multiple_of(tile_size)
168
+ num_tiles = len(boxes) // tile_size
169
+ output_boxes = []
170
+ for i in range(num_tiles):
171
+ box_tile = boxes[i*tile_size : (i+1)*tile_size]
172
+ for j in range(i):
173
+ suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
174
+ iou = _bbox_overlap(box_tile, suppressing_tile)
175
+ # if the box is suppressed in iou, clear it to a dot
176
+ box_tile *= _update_boxes(iou)
177
+ # Iteratively handle the diagonal tile.
178
+ iou = _bbox_overlap(box_tile, box_tile)
179
+ iou_changed = True
180
+ while iou_changed:
181
+ # boxes that are not suppressed by anything else
182
+ suppressing_boxes = _get_suppressing_boxes(iou)
183
+ # boxes that are suppressed by suppressing_boxes
184
+ suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
185
+ # clear iou to 0 for boxes that are suppressed, as they cannot be used
186
+ # to suppress other boxes any more
187
+ new_iou = _clear_iou(iou, suppressed_boxes)
188
+ iou_changed = (new_iou != iou)
189
+ iou = new_iou
190
+ # remaining boxes that can still suppress others are the selected boxes.
191
+ output_boxes.append(_get_suppressing_boxes(iou))
192
+ if len(output_boxes) >= max_output_size:
193
+ break
194
+
195
+ Args:
196
+ scores: a tensor with a shape of [batch_size, anchors].
197
+ boxes: a tensor with a shape of [batch_size, anchors, 4].
198
+ max_output_size: a scalar integer `Tensor` representing the maximum number
199
+ of boxes to be selected by non-max suppression.
200
+ iou_threshold: a float representing the threshold for deciding whether boxes
201
+ overlap too much with respect to IOU.
202
+ Returns:
203
+ nms_scores: a tensor with a shape of [batch_size, anchors]. It has the same
204
+ dtype as input scores.
205
+ nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has
206
+ the same dtype as input boxes.
207
+ """
208
+ batch_size = boxes.shape[0]
209
+ num_boxes = boxes.shape[1]
210
+ pad = int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
211
+ boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
212
+ scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
213
+ num_boxes += pad
214
+
215
+ def _loop_cond(in_args):
216
+ unused_boxes, unused_threshold, output_size, idx = in_args
217
+ return jnp.logical_and(
218
+ jnp.min(output_size) < max_output_size, idx < num_boxes // _NMS_TILE_SIZE
219
+ )
220
+
221
+ selected_boxes, _, output_size, _ = lax.while_loop(
222
+ _loop_cond,
223
+ _suppression_loop_body,
224
+ (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0),
225
+ )
226
+ idx = num_boxes - lax.top_k(
227
+ jnp.any(selected_boxes > 0, [2]).astype(jnp.int32)
228
+ * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
229
+ max_output_size,
230
+ )[0].astype(jnp.int32)
231
+ idx = jnp.minimum(idx, num_boxes - 1)
232
+ idx = jnp.reshape(
233
+ idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1]
234
+ )
235
+
236
+ return idx
237
+ boxes = jnp.reshape(
238
+ (jnp.reshape(boxes, [-1, 4]))[idx], [batch_size, max_output_size, 4]
239
+ )
240
+ boxes = boxes * (
241
+ jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
242
+ < jnp.reshape(output_size, [-1, 1, 1])
243
+ ).astype(boxes.dtype)
244
+ scores = jnp.reshape(jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
245
+ scores = scores * (
246
+ jnp.reshape(jnp.arange(max_output_size), [1, -1])
247
+ < jnp.reshape(output_size, [-1, 1])
248
+ ).astype(scores.dtype)
249
+ return scores, boxes
250
+
251
+
252
+ # registry:
253
+
254
+
255
+ def nms(boxes, scores, iou_threshold):
256
+ max_output_size = boxes.shape[0]
257
+ boxes = boxes.reshape((1, *boxes.shape))
258
+ scores = scores.reshape((1, *scores.shape))
259
+ res = non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold)
260
+ return res
261
+
262
+
263
+ try:
264
+ import torch
265
+
266
+ ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
267
+ except Exception:
268
+ pass
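
A minimal usage sketch (editorial illustration, not part of the package diff): once the registration at the end of this file has run, torch.ops.torchvision.nms dispatches to the JAX-backed `nms` wrapper above. The JAX-level entry point can also be exercised directly with jnp arrays; the box values below are illustrative, and the coordinate layout follows the [y_min, x_min, y_max, x_max] split used by `_bbox_overlap`.

    import jax.numpy as jnp

    boxes = jnp.array([
        [0.0, 0.0, 10.0, 10.0],    # box 0
        [0.0, 0.0, 9.0, 9.0],      # overlaps box 0 heavily -> suppressed
        [20.0, 20.0, 30.0, 30.0],  # disjoint -> kept
    ])
    scores = jnp.array([0.9, 0.8, 0.7])

    # nms() adds a batch dimension and, because non_max_suppression_padded
    # returns early with `idx`, yields the flattened indices of the selected
    # boxes, padded up to max_output_size (= number of input boxes).
    kept = nms(boxes, scores, iou_threshold=0.5)
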
@@ -0,0 +1,139 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax.numpy as jnp
16
+ import numpy
17
+ import torch
18
+ import torch.func
19
+ import torch.utils._mode_utils as mode_utils
20
+ import torch.utils.dlpack as torchdl
21
+ from jax import dlpack as jaxdl
22
+
23
+ NUMPY_UNSUPPORTED_DTYPES = {
24
+ torch.bfloat16: jnp.bfloat16,
25
+ torch.float8_e4m3fn: jnp.float8_e4m3fn,
26
+ torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
27
+ torch.float8_e5m2: jnp.float8_e5m2,
28
+ torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
29
+ }
30
+
31
+
32
+ def t2j(t, use_dlpack=True):
33
+ is_bool = False
34
+ if t.dtype == torch.bool:
35
+ is_bool = True
36
+ t = t.to(torch.int8)
37
+
38
+ t = t.to_dense()
39
+
40
+ if not t.is_contiguous():
41
+ t = t.contiguous()
42
+
43
+ res = None
44
+ if use_dlpack:
45
+ try:
46
+ res = jaxdl.from_dlpack(t)
47
+ except Exception:
48
+ pass
49
+
50
+ if res is None:
51
+ # https://github.com/google/jax/issues/7657
52
+ # https://github.com/google/jax/issues/17784
53
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
54
+ nparray = (
55
+ t.cpu().detach().to(torch.float32).numpy()
56
+ ) # handle dtypes not supported by numpy
57
+ else:
58
+ nparray = t.cpu().detach().numpy()
59
+ res = jnp.asarray(nparray)
60
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
61
+ res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
62
+
63
+ if is_bool:
64
+ res = res.astype(jnp.bool_)
65
+ return res
66
+
67
+
68
+ def j2t(x, use_dlpack=True):
69
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
70
+ res = None
71
+ if use_dlpack:
72
+ try:
73
+ dl = jaxdl.to_dlpack(x)
74
+ res = torchdl.from_dlpack(dl)
75
+ except Exception:
76
+ res = None
77
+
78
+ orig_dtype = None
79
+ if res is None:
80
+ orig_dtype = None
81
+ if x.dtype == jnp.bfloat16.dtype:
82
+ orig_dtype = x.dtype
83
+ x = x.astype(jnp.float32.dtype)
84
+ res = torch.from_numpy(numpy.asarray(x))
85
+
86
+ if x.dtype == jnp.bool_:
87
+ res = res.to(torch.bool)
88
+
89
+ if orig_dtype is not None:
90
+ res = res.to(j2t_dtype(orig_dtype))
91
+ return res
92
+
93
+
94
+ TORCH_DTYPE_TO_JAX = {
95
+ # NO_MAPPING : jnp.float0.dtype (signless scalar int),
96
+ torch.bool: jnp.bool_.dtype,
97
+ # NO_MAPPING : jnp.int4.dtype,
98
+ torch.int8: jnp.int8.dtype,
99
+ torch.int16: jnp.int16.dtype,
100
+ torch.int32: jnp.int32.dtype,
101
+ torch.int64: jnp.int64.dtype,
102
+ torch.long: jnp.int64.dtype,
103
+ # NO_MAPPING : jnp.uint4
104
+ torch.uint8: jnp.uint8.dtype,
105
+ torch.uint16: jnp.uint16.dtype,
106
+ torch.uint32: jnp.uint32.dtype,
107
+ torch.uint64: jnp.uint64.dtype,
108
+ # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
109
+ torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype,
110
+ # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
111
+ torch.float8_e5m2: jnp.float8_e5m2.dtype,
112
+ # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
113
+ torch.bfloat16: jnp.bfloat16.dtype,
114
+ torch.half: jnp.float16.dtype,
115
+ torch.float16: jnp.float16.dtype,
116
+ torch.float32: jnp.float32.dtype,
117
+ torch.float64: jnp.float64.dtype,
118
+ torch.double: jnp.double.dtype,
119
+ torch.complex64: jnp.complex64.dtype,
120
+ torch.complex128: jnp.complex128.dtype,
121
+ None: None,
122
+ }
123
+
124
+ JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
125
+ # Add imprecise mappings for some JAX dtypes which don't have torch analogues
126
+ JAX_DTYPE_TO_TORCH[jnp.dtype("int4")] = torch.int8
127
+ JAX_DTYPE_TO_TORCH[jnp.dtype("uint4")] = torch.uint8
128
+
129
+
130
+ def t2j_dtype(dtype):
131
+ if dtype not in TORCH_DTYPE_TO_JAX:
132
+ raise RuntimeError(f"Attempting to convert unknown type: {dtype} to jax type,")
133
+ return TORCH_DTYPE_TO_JAX[dtype]
134
+
135
+
136
+ def j2t_dtype(dtype):
137
+ if dtype not in JAX_DTYPE_TO_TORCH:
138
+ raise RuntimeError(f"Attempting to convert unknown type: {dtype} to torch type,")
139
+ return JAX_DTYPE_TO_TORCH[dtype]
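
A minimal round-trip sketch (editorial illustration, not part of the package diff): converting a torch tensor to a jax.Array and back with the helpers above. The DLPack path is tried first; the numpy fallback covers dtypes such as bfloat16 that numpy cannot represent directly.

    import torch

    t = torch.ones(2, 3, dtype=torch.bfloat16)
    x = t2j(t)                        # jax.Array with a bfloat16 dtype
    assert x.dtype == t2j_dtype(torch.bfloat16)

    t_back = j2t(x)                   # back to a torch.Tensor
    assert t_back.dtype == torch.bfloat16
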
torchax/ops/op_base.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from collections.abc import Callable
17
+ from typing import Concatenate, ParamSpec
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ import torch
23
+
24
+ from torchax import types
25
+ from torchax.ops import mappings
26
+ from torchax.view import View
27
+
28
+
29
+ class InplaceOp:
30
+ def __init__(
31
+ self, functional_op, replace=False, position_to_mutate=0, is_jax_func=False
32
+ ):
33
+ self.functional = functional_op
34
+ self.replace = replace
35
+ self.position_to_mutate = position_to_mutate
36
+ self.is_jax_func = is_jax_func
37
+
38
+ def __call__(self, *args, **kwargs):
39
+ to_mutate = args[self.position_to_mutate]
40
+ view_value = to_mutate
41
+ if isinstance(to_mutate, View):
42
+ view_value = to_mutate.torch()
43
+ # Convert the target View to a Tensor, and
44
+ # leave the rest of the args as is. If other args are
45
+ # also Views, they will be converted to tensors
46
+ # in the self.functional dispatch.
47
+ env = view_value._env
48
+ if self.is_jax_func:
49
+ view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs))
50
+ new_value_jax = self.functional(view_value, *args[1:], **kwargs)
51
+ new_value = env.j2t_iso(new_value_jax)
52
+ else:
53
+ new_value = self.functional(view_value, *args[1:], **kwargs)
54
+
55
+ if isinstance(to_mutate, View):
56
+ to_mutate.update(new_value)
57
+ else:
58
+ if self.replace:
59
+ to_mutate._elem = new_value._elem
60
+ else:
61
+ to_mutate.copy_(new_value)
62
+ return to_mutate
63
+
64
+
65
+ class OutVariant:
66
+ def __call__(self, *args, **kwargs):
67
+ to_mutate = kwargs["out"]
68
+ del kwargs["out"]
69
+ to_mutate._elem = self.functional(*args, **kwargs)._elem
70
+ return to_mutate
71
+
72
+
73
+ P = ParamSpec("P")
74
+
75
+
76
+ def convert_dtype(use_default_dtype: bool = True):
77
+ """Converts `dtype` kwarg of function from torch to JAX.
78
+
79
+ Args:
80
+ use_default_dtype: Whether to use torch default dtype if none is provided.
81
+
82
+ Returns:
83
+ A decorator that wraps a JAX implementation of a torch function.
84
+ """
85
+
86
+ def decorator(func: types.TorchCallable):
87
+ @functools.wraps(func)
88
+ def wrapper(*args: P.args, dtype: torch.dtype | None = None, **kwargs: P.kwargs):
89
+ if not dtype and use_default_dtype:
90
+ dtype = torch.get_default_dtype()
91
+ if isinstance(dtype, torch.dtype):
92
+ jax_dtype = mappings.t2j_dtype(dtype)
93
+ else:
94
+ jax_dtype = dtype
95
+
96
+ return func(*args, dtype=jax_dtype, **kwargs)
97
+
98
+ return wrapper
99
+
100
+ return decorator
101
+
102
+
103
+ def maybe_convert_constant_dtype(val: types.JaxValue | None, dtype: jnp.dtype | None):
104
+ """Optionally converts scalar constant's dtype using `numpy`
105
+
106
+ Use in cases where you require a constant and can't handle a traced array.
107
+ """
108
+ if val and dtype:
109
+ if isinstance(val, jax.Array):
110
+ return maybe_convert_constant_dtype(val.item(), dtype)
111
+
112
+ return np.array(val, dtype)
113
+
114
+ return val
115
+
116
+
117
+ def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
118
+ """If the first argument is an int array, promote it to float32."""
119
+
120
+ @functools.wraps(f)
121
+ def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
122
+ if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
123
+ x = x.astype(mappings.t2j_dtype(torch.get_default_dtype()))
124
+
125
+ return f(x, *args, **kwargs)
126
+
127
+ return wrapper
128
+
129
+
130
+ def foreach_loop(
131
+ seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0
132
+ ):
133
+ """Run `fn` for each element of 1D array `seq`.
134
+
135
+ Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
136
+ assert len(seq.shape) == 1
137
+ return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val)
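
A small sketch of how these helpers compose (the name `example_full` and the literal values are hypothetical, not part of the package): `convert_dtype` translates a torch dtype kwarg into the matching jnp dtype before the JAX implementation runs, and `foreach_loop` folds over a 1-D array with `jax.lax.fori_loop`.

    import jax.numpy as jnp
    import torch

    @convert_dtype()
    def example_full(size, fill_value, *, dtype=None):
      # `dtype` arrives here already converted to a jnp dtype (or None)
      return jnp.full(size, fill_value, dtype=dtype)

    arr = example_full((2, 2), 1.5, dtype=torch.bfloat16)  # bfloat16 jax.Array

    total = foreach_loop(jnp.arange(5.0), lambda carry, x: carry + x)  # == 10.0
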
@@ -0,0 +1,74 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import logging
17
+
18
+ from torchax.types import JaxCallable, TorchCallable
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class Operator:
23
+ torch_op: TorchCallable
24
+ func: TorchCallable | JaxCallable
25
+ is_jax_function: bool
26
+ is_user_defined: bool
27
+ needs_env: bool
28
+ is_view_op: bool
29
+
30
+
31
+ all_aten_ops: dict[TorchCallable, Operator] = {}
32
+ all_torch_functions: dict[TorchCallable, Operator] = {}
33
+
34
+
35
+ def register_torch_dispatch_op(
36
+ aten_op,
37
+ impl_callable,
38
+ is_jax_function=True,
39
+ is_user_defined=False,
40
+ needs_env=False,
41
+ is_view_op=False,
42
+ ):
43
+ op = Operator(
44
+ aten_op,
45
+ impl_callable,
46
+ is_jax_function=is_jax_function,
47
+ is_user_defined=is_user_defined,
48
+ needs_env=needs_env,
49
+ is_view_op=is_view_op,
50
+ )
51
+ if aten_op in all_aten_ops:
52
+ logging.warning(f"Duplicate op registration for {aten_op}")
53
+ all_aten_ops[aten_op] = op
54
+ return impl_callable
55
+
56
+
57
+ def register_torch_function_op(
58
+ torch_func,
59
+ impl_callable,
60
+ is_jax_function=True,
61
+ is_user_defined=False,
62
+ needs_env=False,
63
+ is_view_op=False,
64
+ ):
65
+ op = Operator(
66
+ torch_func,
67
+ impl_callable,
68
+ is_jax_function=is_jax_function,
69
+ is_user_defined=is_user_defined,
70
+ needs_env=needs_env,
71
+ is_view_op=is_view_op,
72
+ )
73
+ all_torch_functions[torch_func] = op
74
+ return impl_callable
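
A minimal registration sketch (editorial illustration only), mirroring the pattern used by the nms module earlier in this diff: a JAX callable is registered as the implementation of a torch op so that torchax can dispatch to it. The aten op chosen here is just an example.

    import jax.numpy as jnp
    import torch

    def _relu_jax(x):
      return jnp.maximum(x, 0)

    register_torch_dispatch_op(torch.ops.aten.relu, _relu_jax, is_jax_function=True)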