torchax 0.0.10.dev20251117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,248 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
+ """
+
+ import functools
+ from typing import List, Union, Optional, Tuple
+
+ import torch
+ from jax import lax
+ import jax.numpy as jnp
+ from . import ops_registry
+
+ _NMS_TILE_SIZE = 256
+
+
+ def _bbox_overlap(boxes, gt_boxes):
+   """Finds bounding box overlap.
+
+   Args:
+     boxes: first set of bounding boxes
+     gt_boxes: second set of boxes to compute IOU
+
+   Returns:
+     iou: Intersection over union matrix of all input bounding boxes
+   """
+   bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split(
+       ary=boxes, indices_or_sections=4, axis=2)
+   gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split(
+       ary=gt_boxes, indices_or_sections=4, axis=2)
+
+   # Calculates the intersection area.
+   i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1]))
+   i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1]))
+   i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1]))
+   i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1]))
+   i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0)
+
+   # Calculates the union area.
+   bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
+   gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
+   # Adds a small epsilon to avoid divide-by-zero.
+   u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8
+
+   # Calculates IoU.
+   iou = i_area / u_area
+
+   return iou
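For orientation, here is a small, self-contained sketch of what `_bbox_overlap` computes: boxes are `[batch, n, 4]` arrays in `(y_min, x_min, y_max, x_max)` order, and the result is a `[batch, n, m]` IoU matrix. The values below are illustrative only and are not taken from the package.

import jax.numpy as jnp

boxes = jnp.array([[[0., 0., 2., 2.],
                    [0., 0., 1., 1.]]])      # shape [1, 2, 4]
gt_boxes = jnp.array([[[0., 0., 2., 2.]]])   # shape [1, 1, 4]
iou = _bbox_overlap(boxes, gt_boxes)         # shape [1, 2, 1]
# iou[0, 0, 0] ~ 1.0 (identical boxes); iou[0, 1, 0] ~ 0.25 (1x1 box inside 2x2)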
+
+
+ def _self_suppression(in_args):
+   iou, _, iou_sum = in_args
+   batch_size = iou.shape[0]
+   can_suppress_others = jnp.reshape(
+       jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
+   iou_suppressed = jnp.reshape(
+       (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
+           iou.dtype), [batch_size, -1, 1]) * iou
+   iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
+   return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new
+
+
+ def _cross_suppression(in_args):
+   boxes, box_slice, iou_threshold, inner_idx = in_args
+   batch_size = boxes.shape[0]
+   new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
+                                 [batch_size, _NMS_TILE_SIZE, 4])
+   iou = _bbox_overlap(new_slice, box_slice)
+   ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(
+       box_slice.dtype), 2) * box_slice
+   return boxes, ret_slice, iou_threshold, inner_idx + 1
+
+
+ def _suppression_loop_body(in_args):
+   """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).
+
+   Args:
+     in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx
+
+   Returns:
+     boxes: updated boxes.
+     iou_threshold: pass down iou_threshold to the next iteration.
+     output_size: the updated output_size.
+     idx: the updated induction variable.
+   """
+   boxes, iou_threshold, output_size, idx = in_args
+   num_tiles = boxes.shape[1] // _NMS_TILE_SIZE
+   batch_size = boxes.shape[0]
+
+   # Iterates over tiles that can possibly suppress the current tile.
+   box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
+                                 [batch_size, _NMS_TILE_SIZE, 4])
+
+   def _loop_cond(in_args):
+     _, _, _, inner_idx = in_args
+     return inner_idx < idx
+
+   _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression,
+                                       (boxes, box_slice, iou_threshold, 0))
+
+   # Iterates over the current tile to compute self-suppression.
+   iou = _bbox_overlap(box_slice, box_slice)
+   mask = jnp.expand_dims(
+       jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+       > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
+   iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)
+
+   def _loop_cond2(in_args):
+     _, loop_condition, _ = in_args
+     return loop_condition
+
+   suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression,
+                                         (iou, True, jnp.sum(iou, [1, 2])))
+   suppressed_box = jnp.sum(suppressed_iou, 1) > 0
+   box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)
+
+   # Uses box_slice to update the input boxes.
+   mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles),
+                                 idx)).astype(boxes.dtype), [1, -1, 1, 1])
+   boxes = jnp.tile(jnp.expand_dims(
+       box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
+           boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
+   boxes = jnp.reshape(boxes, [batch_size, -1, 4])
+
+   # Updates output_size.
+   output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
+   return boxes, iou_threshold, output_size, idx + 1
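The three loops above all use the same `lax.while_loop` shape: the condition and the body both receive and return the full carry tuple, and iteration stops once the condition returns False. A minimal, unrelated sketch of that pattern (purely illustrative):

from jax import lax

def _cond(carry):
  total, i = carry
  return i < 5

def _body(carry):
  total, i = carry
  return total + i, i + 1

total, _ = lax.while_loop(_cond, _body, (0, 0))  # total == 0 + 1 + 2 + 3 + 4 == 10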
+
+
+ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
+   """A wrapper that handles non-maximum suppression.
+
+   Assumption:
+     * The boxes are sorted by scores unless the box is a dot (all coordinates
+       are zero).
+     * Boxes with higher scores can be used to suppress boxes with lower scores.
+
+   The overall design of the algorithm is to handle boxes tile-by-tile:
+
+   boxes = boxes.pad_to_multiple_of(tile_size)
+   num_tiles = len(boxes) // tile_size
+   output_boxes = []
+   for i in range(num_tiles):
+     box_tile = boxes[i*tile_size : (i+1)*tile_size]
+     for j in range(i - 1):
+       suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
+       iou = _bbox_overlap(box_tile, suppressing_tile)
+       # if the box is suppressed in iou, clear it to a dot
+       box_tile *= _update_boxes(iou)
+     # Iteratively handle the diagonal tile.
+     iou = _bbox_overlap(box_tile, box_tile)
+     iou_changed = True
+     while iou_changed:
+       # boxes that are not suppressed by anything else
+       suppressing_boxes = _get_suppressing_boxes(iou)
+       # boxes that are suppressed by suppressing_boxes
+       suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
+       # clear iou to 0 for boxes that are suppressed, as they cannot be used
+       # to suppress other boxes any more
+       new_iou = _clear_iou(iou, suppressed_boxes)
+       iou_changed = (new_iou != iou)
+       iou = new_iou
+     # remaining boxes that can still suppress others are the selected boxes.
+     output_boxes.append(_get_suppressing_boxes(iou))
+     if len(output_boxes) >= max_output_size:
+       break
+
+   Args:
+     scores: a tensor with a shape of [batch_size, anchors].
+     boxes: a tensor with a shape of [batch_size, anchors, 4].
+     max_output_size: a scalar integer `Tensor` representing the maximum number
+       of boxes to be selected by non max suppression.
+     iou_threshold: a float representing the threshold for deciding whether boxes
+       overlap too much with respect to IOU.
+
+   Returns:
+     nms_scores: a tensor with a shape of [batch_size, anchors]. It has the same
+       dtype as the input scores.
+     nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has the
+       same dtype as the input boxes.
+   """
+   batch_size = boxes.shape[0]
+   num_boxes = boxes.shape[1]
+   pad = int(jnp.ceil(
+       float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
+   boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
+   scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
+   num_boxes += pad
+
+   def _loop_cond(in_args):
+     unused_boxes, unused_threshold, output_size, idx = in_args
+     return jnp.logical_and(
+         jnp.min(output_size) < max_output_size, idx
+         < num_boxes // _NMS_TILE_SIZE)
+
+   selected_boxes, _, output_size, _ = lax.while_loop(
+       _loop_cond, _suppression_loop_body,
+       (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
+   idx = num_boxes - lax.top_k(
+       jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
+       jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
+       max_output_size)[0].astype(jnp.int32)
+   idx = jnp.minimum(idx, num_boxes - 1)
+   idx = jnp.reshape(
+       idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
+
+   return idx
+   boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
+                       [batch_size, max_output_size, 4])
+   boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
+                    < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
+   scores = jnp.reshape(
+       jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
+   scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1])
+                      < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
+   return scores, boxes
+
+
+ # registry:
+
+
+ def nms(boxes, scores, iou_threshold):
+   max_output_size = boxes.shape[0]
+   boxes = boxes.reshape((1, *boxes.shape))
+   scores = scores.reshape((1, *scores.shape))
+   res = non_max_suppression_padded(scores, boxes, max_output_size,
+                                    iou_threshold)
+   return res
+
+
+ try:
+   import torch
+   import torchvision
+   ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
+ except Exception:
+   pass
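Note that, as written, `non_max_suppression_padded` returns the flattened indices of the selected boxes: the score/box gathering after the early `return idx` is unreachable, so the `nms` wrapper registered for `torch.ops.torchvision.nms` yields indices. A rough sketch of a direct call with JAX arrays (the values are illustrative, and the comment on the index layout is an inference from the code above, not a documented contract):

import jax.numpy as jnp

scores = jnp.array([[0.9, 0.8, 0.7]])
boxes = jnp.array([[[0., 0., 10., 10.],
                    [0., 0., 9., 9.],       # heavily overlaps the first box
                    [20., 20., 30., 30.]]])
idx = non_max_suppression_padded(scores, boxes, max_output_size=3,
                                 iou_threshold=0.5)
# idx is a flat int array of length batch_size * max_output_size; slots beyond
# the surviving boxes point at zero-padded positions rather than being dropped.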
@@ -0,0 +1,161 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from jax import dlpack as jaxdl
+ import jax.numpy as jnp
+ import numpy
+ import torch
+ import torch.func
+ import torch.utils.dlpack as torchdl
+ import torch.utils._mode_utils as mode_utils
+
+ NUMPY_UNSUPPORTED_DTYPES = {
+     torch.bfloat16: jnp.bfloat16,
+     torch.float8_e4m3fn: jnp.float8_e4m3fn,
+     torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
+     torch.float8_e5m2: jnp.float8_e5m2,
+     torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
+ }
+
+
+ def t2j(t, use_dlpack=True):
+   is_bool = False
+   if t.dtype == torch.bool:
+     is_bool = True
+     t = t.to(torch.int8)
+
+   t = t.to_dense()
+
+   if not t.is_contiguous():
+     t = t.contiguous()
+
+   res = None
+   if use_dlpack:
+     try:
+       res = jaxdl.from_dlpack(t)
+     except Exception:
+       pass
+
+   if res is None:
+     # https://github.com/google/jax/issues/7657
+     # https://github.com/google/jax/issues/17784
+     if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
+       nparray = (t.cpu().detach().to(torch.float32).numpy()
+                 )  # handle dtypes not supported by numpy
+     else:
+       nparray = t.cpu().detach().numpy()
+     res = jnp.asarray(nparray)
+     if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
+       res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
+
+   if is_bool:
+     res = res.astype(jnp.bool_)
+   return res
+
+
+ def j2t(x, use_dlpack=True):
+   with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+     res = None
+     if use_dlpack:
+       try:
+         dl = jaxdl.to_dlpack(x)
+         res = torchdl.from_dlpack(dl)
+       except Exception:
+         res = None
+
+     orig_dtype = None
+     if res is None:
+       orig_dtype = None
+       if x.dtype == jnp.bfloat16.dtype:
+         orig_dtype = x.dtype
+         x = x.astype(jnp.float32.dtype)
+       res = torch.from_numpy(numpy.asarray(x))
+
+     if x.dtype == jnp.bool_:
+       res = res.to(torch.bool)
+
+     if orig_dtype is not None:
+       res = res.to(j2t_dtype(orig_dtype))
+     return res
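A rough usage sketch of the two conversion helpers above (not taken from the package's tests): round-trip a tensor through JAX. Boolean tensors are staged through int8 and cast back to `jnp.bool_`, and dtypes NumPy cannot represent (bfloat16, float8) are staged through float32 when the DLPack path fails.

import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
x = t2j(t)              # jax.Array; zero-copy via DLPack when that path succeeds
t2 = j2t(x * 2)         # back to a torch.Tensor

b = t2j(torch.tensor([True, False]))   # comes back with dtype jnp.bool_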
+
+
+ TORCH_DTYPE_TO_JAX = {
+     # NO_MAPPING : jnp.float0.dtype (signless scalar int),
+     torch.bool: jnp.bool_.dtype,
+     # NO_MAPPING : jnp.int4.dtype,
+     torch.int8: jnp.int8.dtype,
+     torch.int16: jnp.int16.dtype,
+     torch.int32: jnp.int32.dtype,
+     torch.int64: jnp.int64.dtype,
+     torch.long: jnp.int64.dtype,
+     # NO_MAPPING : jnp.uint4
+     torch.uint8: jnp.uint8.dtype,
+     torch.uint16: jnp.uint16.dtype,
+     torch.uint32: jnp.uint32.dtype,
+     torch.uint64: jnp.uint64.dtype,
+     # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
+     torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype,
+     # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
+     torch.float8_e5m2: jnp.float8_e5m2.dtype,
+     # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
+     torch.bfloat16: jnp.bfloat16.dtype,
+     torch.half: jnp.float16.dtype,
+     torch.float16: jnp.float16.dtype,
+     torch.float32: jnp.float32.dtype,
+     torch.float64: jnp.float64.dtype,
+     torch.double: jnp.double.dtype,
+     torch.complex64: jnp.complex64.dtype,
+     torch.complex128: jnp.complex128.dtype,
+     None: None,
+ }
+
+ JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
+ # Add imprecise mappings for some JAX dtypes which don't have torch analogues
+ JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
+ JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8
+
+
+ def t2j_dtype(dtype):
+   if dtype not in TORCH_DTYPE_TO_JAX:
+     raise RuntimeError(
+         f'Attempting to convert unknown type: {dtype} to jax type.')
+   return TORCH_DTYPE_TO_JAX[dtype]
+
+
+ def j2t_dtype(dtype):
+   if dtype not in JAX_DTYPE_TO_TORCH:
+     raise RuntimeError(
+         f'Attempting to convert unknown type: {dtype} to torch type.')
+   return JAX_DTYPE_TO_TORCH[dtype]
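The dtype helpers above are plain dictionary lookups; a few illustrative checks (the int4 case relies on the imprecise mapping added above):

import torch
import jax.numpy as jnp

assert t2j_dtype(torch.bfloat16) == jnp.bfloat16.dtype
assert j2t_dtype(jnp.float32.dtype) == torch.float32
assert j2t_dtype(jnp.dtype('int4')) == torch.int8   # imprecise by design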
torchax/ops/op_base.py ADDED
@@ -0,0 +1,145 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import torch
+ from torchax.ops import mappings
+ from torchax.view import View
+ from torchax import types
+ import sys
+
+ from typing import Callable, Optional, ParamSpec, Concatenate
+
+
+ class InplaceOp:
+
+   def __init__(self,
+                functional_op,
+                replace=False,
+                position_to_mutate=0,
+                is_jax_func=False):
+     self.functional = functional_op
+     self.replace = replace
+     self.position_to_mutate = position_to_mutate
+     self.is_jax_func = is_jax_func
+
+   def __call__(self, *args, **kwargs):
+     to_mutate = args[self.position_to_mutate]
+     view_value = to_mutate
+     if isinstance(to_mutate, View):
+       view_value = to_mutate.torch()
+       # Convert the target View to a Tensor, and
+       # leave the remaining args as-is. If other args are
+       # also Views, they will be converted to tensors
+       # in the self.functional dispatch.
+     env = view_value._env
+     if self.is_jax_func:
+       view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs))
+       new_value_jax = self.functional(view_value, *args[1:], **kwargs)
+       new_value = env.j2t_iso(new_value_jax)
+     else:
+       new_value = self.functional(view_value, *args[1:], **kwargs)
+
+     if isinstance(to_mutate, View):
+       to_mutate.update(new_value)
+     else:
+       if self.replace:
+         to_mutate._elem = new_value._elem
+       else:
+         to_mutate.copy_(new_value)
+     return to_mutate
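A sketch of how `InplaceOp` might be wired up. The specific op (`aten.add_` delegating to the functional `aten.add`) and the `ops_registry` import path are assumptions for illustration, not a statement of what torchax actually registers this way.

import torch
from torchax.ops import ops_registry   # assumed import path

# In-place add implemented by running the functional op and copying the
# result back into the mutated first argument.
add_inplace = InplaceOp(torch.ops.aten.add)
ops_registry.register_torch_dispatch_op(
    torch.ops.aten.add_, add_inplace, is_jax_function=False)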
+
+
+ class OutVariant:
+
+   def __call__(self, *args, **kwargs):
+     to_mutate = kwargs['out']
+     del kwargs['out']
+     to_mutate._elem = self.functional(*args, **kwargs)._elem
+     return to_mutate
+
+
+ P = ParamSpec('P')
+
+
+ def convert_dtype(use_default_dtype: bool = True):
+   """Converts the `dtype` kwarg of a function from torch to JAX.
+
+   Args:
+     use_default_dtype: Whether to use the torch default dtype if none is provided.
+
+   Returns:
+     A decorator that wraps a JAX implementation of a torch function.
+   """
+
+   def decorator(func: types.TorchCallable):
+
+     @functools.wraps(func)
+     def wrapper(*args: P.args,
+                 dtype: Optional[torch.dtype] = None,
+                 **kwargs: P.kwargs):
+       if not dtype and use_default_dtype:
+         dtype = torch.get_default_dtype()
+       if isinstance(dtype, torch.dtype):
+         jax_dtype = mappings.t2j_dtype(dtype)
+       else:
+         jax_dtype = dtype
+
+       return func(*args, dtype=jax_dtype, **kwargs)
+
+     return wrapper
+
+   return decorator
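A hypothetical use of `convert_dtype` (the `_ones_impl` helper below is not part of this file): the wrapper receives a torch dtype, or falls back to `torch.get_default_dtype()`, and hands the JAX equivalent to the implementation.

@convert_dtype()
def _ones_impl(size, *, dtype=None):
  return jnp.ones(size, dtype=dtype)

_ones_impl((2, 3))                        # dtype defaults to torch float32 -> jnp.float32
_ones_impl((2, 3), dtype=torch.bfloat16)  # called with dtype=jnp.bfloat16.dtype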
+
+
+ def maybe_convert_constant_dtype(val: Optional[types.JaxValue],
+                                  dtype: Optional[jnp.dtype]):
+   """Optionally converts a scalar constant's dtype using `numpy`.
+
+   Use in cases where you require a constant and can't handle a traced array.
+   """
+   if val and dtype:
+     if isinstance(val, jax.Array):
+       return maybe_convert_constant_dtype(val.item(), dtype)
+
+     return np.array(val, dtype)
+
+   return val
+
+
+ def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
+   """If the first argument is an int array, promote it to the default float dtype."""
+
+   @functools.wraps(f)
+   def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
+     if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
+       x = x.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+
+     return f(x, *args, **kwargs)
+
+   return wrapper
+
+
+ def foreach_loop(seq: jax.Array,
+                  fn: Callable[[jax.Array, jax.Array], jax.Array],
+                  init_val=0.0):
+   """Run `fn` for each element of 1D array `seq`.
+
+   Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
+   assert len(seq.shape) == 1
+   return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]),
+                            init_val)
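An illustrative reduction with `foreach_loop` (not from the package): accumulate a running sum over a 1-D array.

import jax.numpy as jnp

total = foreach_loop(jnp.arange(5.0), lambda carry, x: carry + x)
# total == 10.0, i.e. functools.reduce-style accumulation staged as one lax.fori_loop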
@@ -0,0 +1,69 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+ import logging
+ from torchax.types import JaxCallable, TorchCallable
+
+ from typing import Union, Dict
+
+
+ @dataclasses.dataclass
+ class Operator:
+   torch_op: TorchCallable
+   func: Union[TorchCallable, JaxCallable]
+   is_jax_function: bool
+   is_user_defined: bool
+   needs_env: bool
+   is_view_op: bool
+
+
+ all_aten_ops: Dict[TorchCallable, Operator] = {}
+ all_torch_functions: Dict[TorchCallable, Operator] = {}
+
+
+ def register_torch_dispatch_op(aten_op,
+                                impl_callable,
+                                is_jax_function=True,
+                                is_user_defined=False,
+                                needs_env=False,
+                                is_view_op=False):
+   op = Operator(
+       aten_op,
+       impl_callable,
+       is_jax_function=is_jax_function,
+       is_user_defined=is_user_defined,
+       needs_env=needs_env,
+       is_view_op=is_view_op)
+   if aten_op in all_aten_ops:
+     logging.warning(f'Duplicate op registration for {aten_op}')
+   all_aten_ops[aten_op] = op
+   return impl_callable
+
+
+ def register_torch_function_op(torch_func,
+                                impl_callable,
+                                is_jax_function=True,
+                                is_user_defined=False,
+                                needs_env=False,
+                                is_view_op=False):
+   op = Operator(
+       torch_func,
+       impl_callable,
+       is_jax_function=is_jax_function,
+       is_user_defined=is_user_defined,
+       needs_env=needs_env,
+       is_view_op=is_view_op)
+   all_torch_functions[torch_func] = op
+   return impl_callable
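A rough sketch of how this registry is consumed; the choice of `aten.abs` is illustrative only and not a claim about which ops torchax implements in this module.

import jax.numpy as jnp
import torch

def _abs_impl(x):
  # Receives jax.Arrays because is_jax_function defaults to True.
  return jnp.abs(x)

register_torch_dispatch_op(torch.ops.aten.abs, _abs_impl)
# A second registration for the same aten op would log a warning and
# overwrite the earlier Operator entry.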