torchax 0.0.10.dev20251117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +153 -0
- torchax/amp.py +346 -0
- torchax/checkpoint.py +79 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +790 -0
- torchax/device_module.py +47 -0
- torchax/export.py +259 -0
- torchax/flax.py +53 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +234 -0
- torchax/ops/__init__.py +24 -0
- torchax/ops/jaten.py +5937 -0
- torchax/ops/jax_reimplement.py +185 -0
- torchax/ops/jc10d.py +66 -0
- torchax/ops/jimage.py +127 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +631 -0
- torchax/ops/jtorchvision_nms.py +248 -0
- torchax/ops/mappings.py +161 -0
- torchax/ops/op_base.py +145 -0
- torchax/ops/ops_registry.py +69 -0
- torchax/tensor.py +736 -0
- torchax/train.py +132 -0
- torchax/types.py +26 -0
- torchax/util.py +102 -0
- torchax/view.py +391 -0
- torchax-0.0.10.dev20251117.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251117.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251117.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251117.dist-info/licenses/LICENSE +201 -0
torchax/ops/jtorchvision_nms.py
ADDED

@@ -0,0 +1,248 @@

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
"""

import functools
from typing import List, Union, Optional, Tuple

import torch
from jax import lax
import jax.numpy as jnp
from . import ops_registry

_NMS_TILE_SIZE = 256


def _bbox_overlap(boxes, gt_boxes):
  """Find Bounding box overlap.

  Args:
    boxes: first set of bounding boxes
    gt_boxes: second set of boxes to compute IOU

  Returns:
    iou: Intersection over union matrix of all input bounding boxes
  """
  bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split(
      ary=boxes, indices_or_sections=4, axis=2)
  gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split(
      ary=gt_boxes, indices_or_sections=4, axis=2)

  # Calculates the intersection area.
  i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1]))
  i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1]))
  i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1]))
  i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1]))
  i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0)

  # Calculates the union area.
  bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
  gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
  # Adds a small epsilon to avoid divide-by-zero.
  u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8

  # Calculates IoU.
  iou = i_area / u_area

  return iou


def _self_suppression(in_args):
  iou, _, iou_sum = in_args
  batch_size = iou.shape[0]
  can_suppress_others = jnp.reshape(
      jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
  iou_suppressed = jnp.reshape(
      (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(
          iou.dtype), [batch_size, -1, 1]) * iou
  iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
  return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new


def _cross_suppression(in_args):
  boxes, box_slice, iou_threshold, inner_idx = in_args
  batch_size = boxes.shape[0]
  new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
                                [batch_size, _NMS_TILE_SIZE, 4])
  iou = _bbox_overlap(new_slice, box_slice)
  ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(
      box_slice.dtype), 2) * box_slice
  return boxes, ret_slice, iou_threshold, inner_idx + 1


def _suppression_loop_body(in_args):
  """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).

  Args:
    in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx

  Returns:
    boxes: updated boxes.
    iou_threshold: pass down iou_threshold to the next iteration.
    output_size: the updated output_size.
    idx: the updated induction variable.
  """
  boxes, iou_threshold, output_size, idx = in_args
  num_tiles = boxes.shape[1] // _NMS_TILE_SIZE
  batch_size = boxes.shape[0]

  # Iterates over tiles that can possibly suppress the current tile.
  box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                                [batch_size, _NMS_TILE_SIZE, 4])

  def _loop_cond(in_args):
    _, _, _, inner_idx = in_args
    return inner_idx < idx

  _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression,
                                      (boxes, box_slice, iou_threshold, 0))

  # Iterates over the current tile to compute self-suppression.
  iou = _bbox_overlap(box_slice, box_slice)
  mask = jnp.expand_dims(
      jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
      > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
  iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)

  def _loop_cond2(in_args):
    _, loop_condition, _ = in_args
    return loop_condition

  suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression,
                                        (iou, True, jnp.sum(iou, [1, 2])))
  suppressed_box = jnp.sum(suppressed_iou, 1) > 0
  box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)

  # Uses box_slice to update the input boxes.
  mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles),
                                idx)).astype(boxes.dtype), [1, -1, 1, 1])
  boxes = jnp.tile(jnp.expand_dims(
      box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
          boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
  boxes = jnp.reshape(boxes, [batch_size, -1, 4])

  # Updates output_size.
  output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
  return boxes, iou_threshold, output_size, idx + 1


def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
  """A wrapper that handles non-maximum suppression.

  Assumption:
    * The boxes are sorted by scores unless the box is a dot (all coordinates
      are zero).
    * Boxes with higher scores can be used to suppress boxes with lower scores.

  The overall design of the algorithm is to handle boxes tile-by-tile:

  boxes = boxes.pad_to_multiple_of(tile_size)
  num_tiles = len(boxes) // tile_size
  output_boxes = []
  for i in range(num_tiles):
    box_tile = boxes[i*tile_size : (i+1)*tile_size]
    for j in range(i - 1):
      suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
      iou = _bbox_overlap(box_tile, suppressing_tile)
      # if the box is suppressed in iou, clear it to a dot
      box_tile *= _update_boxes(iou)
    # Iteratively handle the diagonal tile.
    iou = _bbox_overlap(box_tile, box_tile)
    iou_changed = True
    while iou_changed:
      # boxes that are not suppressed by anything else
      suppressing_boxes = _get_suppressing_boxes(iou)
      # boxes that are suppressed by suppressing_boxes
      suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
      # clear iou to 0 for boxes that are suppressed, as they cannot be used
      # to suppress other boxes any more
      new_iou = _clear_iou(iou, suppressed_boxes)
      iou_changed = (new_iou != iou)
      iou = new_iou
    # remaining boxes that can still suppress others, are selected boxes.
    output_boxes.append(_get_suppressing_boxes(iou))
    if len(output_boxes) >= max_output_size:
      break

  Args:
    scores: a tensor with a shape of [batch_size, anchors].
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    max_output_size: a scalar integer `Tensor` representing the maximum number
      of boxes to be selected by non max suppression.
    iou_threshold: a float representing the threshold for deciding whether boxes
      overlap too much with respect to IOU.
  Returns:
    nms_scores: a tensor with a shape of [batch_size, anchors]. It has same
      dtype as input scores.
    nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has
      same dtype as input boxes.
  """
  batch_size = boxes.shape[0]
  num_boxes = boxes.shape[1]
  pad = int(jnp.ceil(
      float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
  boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
  scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
  num_boxes += pad

  def _loop_cond(in_args):
    unused_boxes, unused_threshold, output_size, idx = in_args
    return jnp.logical_and(
        jnp.min(output_size) < max_output_size, idx
        < num_boxes // _NMS_TILE_SIZE)

  selected_boxes, _, output_size, _ = lax.while_loop(
      _loop_cond, _suppression_loop_body,
      (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0))
  idx = num_boxes - lax.top_k(
      jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
      jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
      max_output_size)[0].astype(jnp.int32)
  idx = jnp.minimum(idx, num_boxes - 1)
  idx = jnp.reshape(
      idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])

  return idx
  boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx],
                      [batch_size, max_output_size, 4])
  boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
                   < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype)
  scores = jnp.reshape(
      jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
  scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1])
                     < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype)
  return scores, boxes


# registry:


def nms(boxes, scores, iou_threshold):
  max_output_size = boxes.shape[0]
  boxes = boxes.reshape((1, *boxes.shape))
  scores = scores.reshape((1, *scores.shape))
  res = non_max_suppression_padded(scores, boxes, max_output_size,
                                   iou_threshold)
  return res


try:
  import torch
  import torchvision
  ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
except Exception:
  pass
torchax/ops/mappings.py
ADDED

@@ -0,0 +1,161 @@

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import dlpack as jaxdl
import jax.numpy as jnp
import numpy
import torch
import torch.func
import torch.utils.dlpack as torchdl
import torch.utils._mode_utils as mode_utils

NUMPY_UNSUPPORTED_DTYPES = {
    torch.bfloat16: jnp.bfloat16,
    torch.float8_e4m3fn: jnp.float8_e4m3fn,
    torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
    torch.float8_e5m2: jnp.float8_e5m2,
    torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
}


def t2j(t, use_dlpack=True):
  is_bool = False
  if t.dtype == torch.bool:
    is_bool = True
    t = t.to(torch.int8)

  t = t.to_dense()

  if not t.is_contiguous():
    t = t.contiguous()

  res = None
  if use_dlpack:
    try:
      res = jaxdl.from_dlpack(t)
    except Exception:
      pass

  if res is None:
    # https://github.com/google/jax/issues/7657
    # https://github.com/google/jax/issues/17784
    if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
      nparray = (t.cpu().detach().to(torch.float32).numpy()
                )  # handle dtypes not supported by numpy
    else:
      nparray = t.cpu().detach().numpy()
    res = jnp.asarray(nparray)
    if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
      res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])

  if is_bool:
    res = res.astype(jnp.bool_)
  return res


def j2t(x, use_dlpack=True):
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
    res = None
    if use_dlpack:
      try:
        dl = jaxdl.to_dlpack(x)
        res = torchdl.from_dlpack(dl)
      except Exception:
        res = None

    orig_dtype = None
    if res is None:
      orig_dtype = None
      if x.dtype == jnp.bfloat16.dtype:
        orig_dtype = x.dtype
        x = x.astype(jnp.float32.dtype)
      res = torch.from_numpy(numpy.asarray(x))

    if x.dtype == jnp.bool_:
      res = res.to(torch.bool)

    if orig_dtype is not None:
      res = res.to(j2t_dtype(orig_dtype))
  return res


TORCH_DTYPE_TO_JAX = {
    # NO_MAPPING        : jnp.float0.dtype (signless scalar int),
    torch.bool: jnp.bool_.dtype,
    # NO_MAPPING        : jnp.int4.dtype,
    torch.int8: jnp.int8.dtype,
    torch.int16: jnp.int16.dtype,
    torch.int32: jnp.int32.dtype,
    torch.int64: jnp.int64.dtype,
    torch.long: jnp.int64.dtype,
    # NO_MAPPING        : jnp.uint4
    torch.uint8: jnp.uint8.dtype,
    torch.uint16: jnp.uint16.dtype,
    torch.uint32: jnp.uint32.dtype,
    torch.uint64: jnp.uint64.dtype,
    # NO_MAPPING        : jnp.float8_e4m3b11fnuz.dtype,
    torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype,
    # NO_MAPPING        : jnp.float8_e4m3fnuz.dtype,
    torch.float8_e5m2: jnp.float8_e5m2.dtype,
    # NO_MAPPING        : jnp.float8_e5m2fnuz.dtype,
    torch.bfloat16: jnp.bfloat16.dtype,
    torch.half: jnp.float16.dtype,
    torch.float16: jnp.float16.dtype,
    torch.float32: jnp.float32.dtype,
    torch.float64: jnp.float64.dtype,
    torch.double: jnp.double.dtype,
    torch.complex64: jnp.complex64.dtype,
    torch.complex128: jnp.complex128.dtype,
    None: None,
}

JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
# Add imprecise mappings for some JAX dtypes which don't have torch analogues
JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8


def t2j_dtype(dtype):
  if dtype not in TORCH_DTYPE_TO_JAX:
    raise RuntimeError(
        f'Attempting to convert unknown type: {dtype} to jax type,')
  return TORCH_DTYPE_TO_JAX[dtype]


def j2t_dtype(dtype):
  if dtype not in JAX_DTYPE_TO_TORCH:
    raise RuntimeError(
        f'Attempting to convert unknown type: {dtype} to torch type,')
  return JAX_DTYPE_TO_TORCH[dtype]
torchax/ops/op_base.py
ADDED

@@ -0,0 +1,145 @@

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import jax
import jax.numpy as jnp
import numpy as np
import torch
from torchax.ops import mappings
from torchax.view import View
from torchax import types
import sys

from typing import Callable, Optional, ParamSpec, Concatenate


class InplaceOp:

  def __init__(self,
               functional_op,
               replace=False,
               position_to_mutate=0,
               is_jax_func=False):
    self.functional = functional_op
    self.replace = replace
    self.position_to_mutate = position_to_mutate
    self.is_jax_func = is_jax_func

  def __call__(self, *args, **kwargs):
    to_mutate = args[self.position_to_mutate]
    view_value = to_mutate
    if isinstance(to_mutate, View):
      view_value = to_mutate.torch()
    # Convert the target View to a Tensor, and
    # leave the rest args as is. If other args are
    # also View, they will be converted to tensors
    # in the self.functional dispatch.
    env = view_value._env
    if self.is_jax_func:
      view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs))
      new_value_jax = self.functional(view_value, *args[1:], **kwargs)
      new_value = env.j2t_iso(new_value_jax)
    else:
      new_value = self.functional(view_value, *args[1:], **kwargs)

    if isinstance(to_mutate, View):
      to_mutate.update(new_value)
    else:
      if self.replace:
        to_mutate._elem = new_value._elem
      else:
        to_mutate.copy_(new_value)
    return to_mutate


class OutVariant:

  def __call__(self, *args, **kwargs):
    to_mutate = kwargs['out']
    del kwargs['out']
    to_mutate._elem = self.functional(*args, **kwargs)._elem
    return to_mutate


P = ParamSpec('P')


def convert_dtype(use_default_dtype: bool = True):
  """Converts `dtype` kwarg of function from torch to JAX.

  Args:
    use_default_dtype: Whether to use torch default dtype if none is provided.

  Returns:
    A decorator that wraps a JAX implementation of a torch function.
  """

  def decorator(func: types.TorchCallable):

    @functools.wraps(func)
    def wrapper(*args: P.args,
                dtype: Optional[torch.dtype] = None,
                **kwargs: P.kwargs):
      if not dtype and use_default_dtype:
        dtype = torch.get_default_dtype()
      if isinstance(dtype, torch.dtype):
        jax_dtype = mappings.t2j_dtype(dtype)
      else:
        jax_dtype = dtype

      return func(*args, dtype=jax_dtype, **kwargs)

    return wrapper

  return decorator


def maybe_convert_constant_dtype(val: Optional[types.JaxValue],
                                 dtype: Optional[jnp.dtype]):
  """Optionally converts scalar constant's dtype using `numpy`

  Use in cases where you require a constant and can't handle a traced array.
  """
  if val and dtype:
    if isinstance(val, jax.Array):
      return maybe_convert_constant_dtype(val.item(), dtype)

    return np.array(val, dtype)

  return val


def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
  """If the first argument is an int array, promote it to float32."""

  @functools.wraps(f)
  def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
    if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
      x = x.astype(mappings.t2j_dtype(torch.get_default_dtype()))

    return f(x, *args, **kwargs)

  return wrapper


def foreach_loop(seq: jax.Array,
                 fn: Callable[[jax.Array, jax.Array], jax.Array],
                 init_val=0.0):
  """Run `fn` for each element of 1D array `seq`.

  Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
  assert len(seq.shape) == 1
  return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]),
                           init_val)
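
A minimal sketch (not from the package; `full_like_jax` is an illustrative name) of two of the op_base.py helpers in use, assuming torchax is installed:

import jax.numpy as jnp
import torch
from torchax.ops import op_base


@op_base.convert_dtype()
def full_like_jax(x, fill_value, *, dtype=None):
  # By the time the body runs, `dtype` has been translated from torch.dtype
  # to the corresponding JAX dtype (or filled with torch's default dtype).
  return jnp.full_like(x, fill_value, dtype=dtype)


y = full_like_jax(jnp.zeros((2, 3)), 1.0, dtype=torch.bfloat16)
assert y.dtype == jnp.bfloat16

# foreach_loop folds over a 1-D array with jax.lax.fori_loop:
total = op_base.foreach_loop(jnp.arange(5.0), lambda carry, v: carry + v)
assert float(total) == 10.0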
torchax/ops/ops_registry.py
ADDED

@@ -0,0 +1,69 @@

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import logging
from torchax.types import JaxCallable, TorchCallable

from typing import Union, Dict


@dataclasses.dataclass
class Operator:
  torch_op: TorchCallable
  func: Union[TorchCallable, JaxCallable]
  is_jax_function: bool
  is_user_defined: bool
  needs_env: bool
  is_view_op: bool


all_aten_ops: Dict[TorchCallable, Operator] = {}
all_torch_functions: Dict[TorchCallable, Operator] = {}


def register_torch_dispatch_op(aten_op,
                               impl_callable,
                               is_jax_function=True,
                               is_user_defined=False,
                               needs_env=False,
                               is_view_op=False):
  op = Operator(
      aten_op,
      impl_callable,
      is_jax_function=is_jax_function,
      is_user_defined=is_user_defined,
      needs_env=needs_env,
      is_view_op=is_view_op)
  if aten_op in all_aten_ops:
    logging.warning(f'Duplicate op registration for {aten_op}')
  all_aten_ops[aten_op] = op
  return impl_callable


def register_torch_function_op(torch_func,
                               impl_callable,
                               is_jax_function=True,
                               is_user_defined=False,
                               needs_env=False,
                               is_view_op=False):
  op = Operator(
      torch_func,
      impl_callable,
      is_jax_function=is_jax_function,
      is_user_defined=is_user_defined,
      needs_env=needs_env,
      is_view_op=is_view_op)
  all_torch_functions[torch_func] = op
  return impl_callable