onnx2fx-0.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/quantization.py
ADDED

@@ -0,0 +1,524 @@
# SPDX-License-Identifier: Apache-2.0
"""Quantization operators."""

from typing import TYPE_CHECKING

import onnx
import torch

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import get_optional_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Basic quantization operators
# =============================================================================

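Every converter in this file follows the same recurring shape: the handler receives the GraphBuilder and the ONNX node, resolves inputs with builder.get_value (or get_optional_input for trailing optional inputs), reads attributes with get_attribute, defines a small eager PyTorch function, and wraps that function into the FX graph via builder.call_function. A minimal sketch of that pattern (the op name "MyScale" and its body are illustrative only, not part of the package):

    @register("MyScale")  # hypothetical op name, for illustration
    def my_scale(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
        x = builder.get_value(node.input[0])       # resolve the first input value
        alpha = get_attribute(node, "alpha", 1.0)  # read an attribute with a default

        def _my_scale(inp: torch.Tensor, a: float) -> torch.Tensor:
            return inp * a  # eager implementation captured as an FX call_function node

        return builder.call_function(_my_scale, args=(x, alpha))
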
@register("QuantizeLinear")
|
|
23
|
+
def quantize_linear(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
24
|
+
"""Quantize input tensor using scale and zero_point."""
|
|
25
|
+
x = builder.get_value(node.input[0])
|
|
26
|
+
y_scale = builder.get_value(node.input[1])
|
|
27
|
+
y_zero_point = get_optional_input(builder, node, 2)
|
|
28
|
+
|
|
29
|
+
if y_zero_point is not None:
|
|
30
|
+
|
|
31
|
+
def _quantize_uint8(
|
|
32
|
+
inp: torch.Tensor, s: torch.Tensor, zp: torch.Tensor
|
|
33
|
+
) -> torch.Tensor:
|
|
34
|
+
return torch.clamp(torch.round(inp / s) + zp.float(), 0, 255).to(
|
|
35
|
+
torch.uint8
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
return builder.call_function(_quantize_uint8, args=(x, y_scale, y_zero_point))
|
|
39
|
+
else:
|
|
40
|
+
|
|
41
|
+
def _quantize_int8(inp: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
|
42
|
+
return torch.clamp(torch.round(inp / s), -128, 127).to(torch.int8)
|
|
43
|
+
|
|
44
|
+
return builder.call_function(_quantize_int8, args=(x, y_scale))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@register("DequantizeLinear")
|
|
48
|
+
def dequantize_linear(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
49
|
+
"""Dequantize input tensor using scale and zero_point."""
|
|
50
|
+
x = builder.get_value(node.input[0])
|
|
51
|
+
x_scale = builder.get_value(node.input[1])
|
|
52
|
+
x_zero_point = get_optional_input(builder, node, 2)
|
|
53
|
+
|
|
54
|
+
if x_zero_point is not None:
|
|
55
|
+
|
|
56
|
+
def _dequantize(
|
|
57
|
+
inp: torch.Tensor, s: torch.Tensor, zp: torch.Tensor
|
|
58
|
+
) -> torch.Tensor:
|
|
59
|
+
return (inp.float() - zp.float()) * s
|
|
60
|
+
|
|
61
|
+
return builder.call_function(_dequantize, args=(x, x_scale, x_zero_point))
|
|
62
|
+
else:
|
|
63
|
+
|
|
64
|
+
def _dequantize_no_zp(inp: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
|
65
|
+
return inp.float() * s
|
|
66
|
+
|
|
67
|
+
return builder.call_function(_dequantize_no_zp, args=(x, x_scale))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
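QuantizeLinear and DequantizeLinear above implement the standard affine scheme: q = clamp(round(x / scale) + zero_point, qmin, qmax) on the way in, and x ≈ (q - zero_point) * scale on the way out (the output is assumed uint8 when a zero point is present, int8 otherwise). A quick eager round-trip check of those formulas, with made-up values:

    import torch

    x = torch.tensor([-1.0, 0.0, 0.5, 2.0])
    scale = torch.tensor(0.02)
    zero_point = torch.tensor(128, dtype=torch.uint8)

    q = torch.clamp(torch.round(x / scale) + zero_point.float(), 0, 255).to(torch.uint8)
    x_hat = (q.float() - zero_point.float()) * scale
    assert torch.all((x_hat - x).abs() <= scale / 2)  # recovered within one quantization step
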
@register("DynamicQuantizeLinear")
|
|
71
|
+
def dynamic_quantize_linear(
|
|
72
|
+
builder: "GraphBuilder", node: onnx.NodeProto
|
|
73
|
+
) -> torch.fx.Node:
|
|
74
|
+
"""Dynamic quantization of input tensor to uint8.
|
|
75
|
+
|
|
76
|
+
Returns tuple of (y, y_scale, y_zero_point).
|
|
77
|
+
"""
|
|
78
|
+
x = builder.get_value(node.input[0])
|
|
79
|
+
|
|
80
|
+
def _dynamic_quantize(
|
|
81
|
+
inp: torch.Tensor,
|
|
82
|
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
83
|
+
x_min = torch.min(inp)
|
|
84
|
+
x_max = torch.max(inp)
|
|
85
|
+
scale = (x_max - x_min) / 255.0
|
|
86
|
+
zero_point = torch.clamp(torch.round(-x_min / scale), 0, 255).to(torch.uint8)
|
|
87
|
+
y = torch.clamp(torch.round(inp / scale) + zero_point.float(), 0, 255).to(
|
|
88
|
+
torch.uint8
|
|
89
|
+
)
|
|
90
|
+
return y, scale, zero_point
|
|
91
|
+
|
|
92
|
+
return builder.call_function(_dynamic_quantize, args=(x,))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
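The handler above derives scale and zero point directly from min(x) and max(x). The ONNX DynamicQuantizeLinear description additionally adjusts the range so it always contains zero before computing the scale; a spec-leaning eager sketch of that variant is shown below (the zero-range guard is an extra assumption to avoid dividing by zero on constant inputs, and is not part of the package code):

    import torch

    def dynamic_quantize_reference(x: torch.Tensor):
        x_min = torch.clamp(x.min(), max=0.0)  # range adjusted to include zero
        x_max = torch.clamp(x.max(), min=0.0)
        scale = (x_max - x_min) / 255.0
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)  # constant-input guard (assumption)
        zero_point = torch.clamp(torch.round(-x_min / scale), 0, 255).to(torch.uint8)
        y = torch.clamp(torch.round(x / scale) + zero_point.float(), 0, 255).to(torch.uint8)
        return y, scale, zero_point
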
# =============================================================================
# QLinear operators
# =============================================================================

@register("QLinearMatMul")
|
|
101
|
+
def qlinear_matmul(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
102
|
+
"""Quantized MatMul with scales and zero points."""
|
|
103
|
+
a = builder.get_value(node.input[0])
|
|
104
|
+
a_scale = builder.get_value(node.input[1])
|
|
105
|
+
a_zero_point = builder.get_value(node.input[2])
|
|
106
|
+
b = builder.get_value(node.input[3])
|
|
107
|
+
b_scale = builder.get_value(node.input[4])
|
|
108
|
+
b_zero_point = builder.get_value(node.input[5])
|
|
109
|
+
y_scale = builder.get_value(node.input[6])
|
|
110
|
+
y_zero_point = builder.get_value(node.input[7])
|
|
111
|
+
|
|
112
|
+
def _qlinear_matmul(
|
|
113
|
+
a: torch.Tensor,
|
|
114
|
+
a_s: torch.Tensor,
|
|
115
|
+
a_zp: torch.Tensor,
|
|
116
|
+
b: torch.Tensor,
|
|
117
|
+
b_s: torch.Tensor,
|
|
118
|
+
b_zp: torch.Tensor,
|
|
119
|
+
y_s: torch.Tensor,
|
|
120
|
+
y_zp: torch.Tensor,
|
|
121
|
+
) -> torch.Tensor:
|
|
122
|
+
# Dequantize
|
|
123
|
+
a_dq = (a.float() - a_zp.float()) * a_s
|
|
124
|
+
b_dq = (b.float() - b_zp.float()) * b_s
|
|
125
|
+
# MatMul
|
|
126
|
+
result = torch.matmul(a_dq, b_dq)
|
|
127
|
+
# Quantize output
|
|
128
|
+
return torch.clamp(torch.round(result / y_s) + y_zp.float(), 0, 255).to(
|
|
129
|
+
torch.uint8
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
return builder.call_function(
|
|
133
|
+
_qlinear_matmul,
|
|
134
|
+
args=(
|
|
135
|
+
a,
|
|
136
|
+
a_scale,
|
|
137
|
+
a_zero_point,
|
|
138
|
+
b,
|
|
139
|
+
b_scale,
|
|
140
|
+
b_zero_point,
|
|
141
|
+
y_scale,
|
|
142
|
+
y_zero_point,
|
|
143
|
+
),
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
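QLinearMatMul, like the remaining QLinear handlers, follows a dequantize → float compute → requantize pattern rather than using true integer kernels. The same arithmetic stands alone in eager PyTorch; a tiny worked example with invented values:

    import torch

    a = torch.tensor([[130, 120]], dtype=torch.uint8)     # shape (1, 2)
    b = torch.tensor([[132], [129]], dtype=torch.uint8)   # shape (2, 1)
    a_s, a_zp = torch.tensor(0.1), torch.tensor(128, dtype=torch.uint8)
    b_s, b_zp = torch.tensor(0.2), torch.tensor(128, dtype=torch.uint8)
    y_s, y_zp = torch.tensor(0.05), torch.tensor(128, dtype=torch.uint8)

    a_dq = (a.float() - a_zp.float()) * a_s                # [[0.2, -0.8]]
    b_dq = (b.float() - b_zp.float()) * b_s                # [[0.8], [0.2]]
    result = torch.matmul(a_dq, b_dq)                      # [[0.0]]
    y = torch.clamp(torch.round(result / y_s) + y_zp.float(), 0, 255).to(torch.uint8)
    # y == [[128]]: a zero float result maps back to the output zero point
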
@register("QLinearConv")
|
|
148
|
+
def qlinear_conv(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
149
|
+
"""Quantized 2D convolution with scales and zero points.
|
|
150
|
+
|
|
151
|
+
Inputs: x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, [B]
|
|
152
|
+
"""
|
|
153
|
+
x = builder.get_value(node.input[0])
|
|
154
|
+
x_scale = builder.get_value(node.input[1])
|
|
155
|
+
x_zero_point = builder.get_value(node.input[2])
|
|
156
|
+
w = builder.get_value(node.input[3])
|
|
157
|
+
w_scale = builder.get_value(node.input[4])
|
|
158
|
+
w_zero_point = builder.get_value(node.input[5])
|
|
159
|
+
y_scale = builder.get_value(node.input[6])
|
|
160
|
+
y_zero_point = builder.get_value(node.input[7])
|
|
161
|
+
bias = get_optional_input(builder, node, 8)
|
|
162
|
+
|
|
163
|
+
# Get convolution attributes
|
|
164
|
+
# Note: kernel_shape is inferred from weight tensor, not from attribute
|
|
165
|
+
auto_pad = get_attribute(node, "auto_pad", "NOTSET")
|
|
166
|
+
dilations = get_attribute(node, "dilations", [1, 1])
|
|
167
|
+
group = get_attribute(node, "group", 1)
|
|
168
|
+
pads = get_attribute(node, "pads", [0, 0, 0, 0])
|
|
169
|
+
strides = get_attribute(node, "strides", [1, 1])
|
|
170
|
+
|
|
171
|
+
if auto_pad != "NOTSET":
|
|
172
|
+
# Handle auto_pad - for simplicity, assume SAME_UPPER
|
|
173
|
+
pass
|
|
174
|
+
|
|
175
|
+
# Convert pads from ONNX format [H_begin, W_begin, H_end, W_end] to PyTorch format
|
|
176
|
+
if len(pads) == 4:
|
|
177
|
+
padding = (pads[0], pads[1]) # Symmetric padding
|
|
178
|
+
else:
|
|
179
|
+
padding = tuple(pads)
|
|
180
|
+
|
|
181
|
+
def _qlinear_conv(
|
|
182
|
+
x: torch.Tensor,
|
|
183
|
+
x_s: torch.Tensor,
|
|
184
|
+
x_zp: torch.Tensor,
|
|
185
|
+
w: torch.Tensor,
|
|
186
|
+
w_s: torch.Tensor,
|
|
187
|
+
w_zp: torch.Tensor,
|
|
188
|
+
y_s: torch.Tensor,
|
|
189
|
+
y_zp: torch.Tensor,
|
|
190
|
+
bias: torch.Tensor | None,
|
|
191
|
+
stride: tuple,
|
|
192
|
+
padding: tuple,
|
|
193
|
+
dilation: tuple,
|
|
194
|
+
groups: int,
|
|
195
|
+
) -> torch.Tensor:
|
|
196
|
+
# Dequantize input and weight
|
|
197
|
+
x_dq = (x.float() - x_zp.float()) * x_s
|
|
198
|
+
w_dq = (w.float() - w_zp.float()) * w_s
|
|
199
|
+
|
|
200
|
+
# Perform convolution
|
|
201
|
+
result = torch.nn.functional.conv2d(
|
|
202
|
+
x_dq,
|
|
203
|
+
w_dq,
|
|
204
|
+
bias=bias,
|
|
205
|
+
stride=stride,
|
|
206
|
+
padding=padding,
|
|
207
|
+
dilation=dilation,
|
|
208
|
+
groups=groups,
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
# Quantize output
|
|
212
|
+
return torch.clamp(torch.round(result / y_s) + y_zp.float(), 0, 255).to(
|
|
213
|
+
torch.uint8
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
return builder.call_function(
|
|
217
|
+
_qlinear_conv,
|
|
218
|
+
args=(
|
|
219
|
+
x,
|
|
220
|
+
x_scale,
|
|
221
|
+
x_zero_point,
|
|
222
|
+
w,
|
|
223
|
+
w_scale,
|
|
224
|
+
w_zero_point,
|
|
225
|
+
y_scale,
|
|
226
|
+
y_zero_point,
|
|
227
|
+
bias,
|
|
228
|
+
tuple(strides),
|
|
229
|
+
padding,
|
|
230
|
+
tuple(dilations),
|
|
231
|
+
group,
|
|
232
|
+
),
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
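Two simplifications in the handler above are worth noting: only pads[0] and pads[1] are used, which is exact only when the ONNX pads are symmetric (begin == end per axis), and the optional bias is passed straight to conv2d even though the ONNX QLinearConv description defines it as an int32 tensor quantized with scale x_scale * w_scale. A hedged sketch of how asymmetric pads and the bias could be handled in the dequantized domain (the helper name is illustrative and not part of the package):

    import torch
    import torch.nn.functional as F

    def conv2d_with_onnx_pads(x_dq, w_dq, bias_i32, x_s, w_s, pads, stride, dilation, groups):
        # ONNX pads are [H_begin, W_begin, H_end, W_end]; F.pad expects (W_left, W_right, H_top, H_bottom)
        h_b, w_b, h_e, w_e = pads
        x_dq = F.pad(x_dq, (w_b, w_e, h_b, h_e))
        # Per the ONNX description, the bias is int32 in units of x_scale * w_scale
        bias = bias_i32.float() * (x_s * w_s) if bias_i32 is not None else None
        return F.conv2d(x_dq, w_dq, bias=bias, stride=stride, padding=0,
                        dilation=dilation, groups=groups)
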
@register("QLinearAdd")
|
|
237
|
+
def qlinear_add(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
238
|
+
"""Quantized addition with scales and zero points (com.microsoft domain).
|
|
239
|
+
|
|
240
|
+
Inputs: A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point
|
|
241
|
+
"""
|
|
242
|
+
a = builder.get_value(node.input[0])
|
|
243
|
+
a_scale = builder.get_value(node.input[1])
|
|
244
|
+
a_zero_point = builder.get_value(node.input[2])
|
|
245
|
+
b = builder.get_value(node.input[3])
|
|
246
|
+
b_scale = builder.get_value(node.input[4])
|
|
247
|
+
b_zero_point = builder.get_value(node.input[5])
|
|
248
|
+
c_scale = builder.get_value(node.input[6])
|
|
249
|
+
c_zero_point = builder.get_value(node.input[7])
|
|
250
|
+
|
|
251
|
+
def _qlinear_add(
|
|
252
|
+
a: torch.Tensor,
|
|
253
|
+
a_s: torch.Tensor,
|
|
254
|
+
a_zp: torch.Tensor,
|
|
255
|
+
b: torch.Tensor,
|
|
256
|
+
b_s: torch.Tensor,
|
|
257
|
+
b_zp: torch.Tensor,
|
|
258
|
+
c_s: torch.Tensor,
|
|
259
|
+
c_zp: torch.Tensor,
|
|
260
|
+
) -> torch.Tensor:
|
|
261
|
+
# Dequantize
|
|
262
|
+
a_dq = (a.float() - a_zp.float()) * a_s
|
|
263
|
+
b_dq = (b.float() - b_zp.float()) * b_s
|
|
264
|
+
# Add
|
|
265
|
+
result = a_dq + b_dq
|
|
266
|
+
# Quantize output
|
|
267
|
+
return torch.clamp(torch.round(result / c_s) + c_zp.float(), 0, 255).to(
|
|
268
|
+
torch.uint8
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
return builder.call_function(
|
|
272
|
+
_qlinear_add,
|
|
273
|
+
args=(
|
|
274
|
+
a,
|
|
275
|
+
a_scale,
|
|
276
|
+
a_zero_point,
|
|
277
|
+
b,
|
|
278
|
+
b_scale,
|
|
279
|
+
b_zero_point,
|
|
280
|
+
c_scale,
|
|
281
|
+
c_zero_point,
|
|
282
|
+
),
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
@register("QLinearMul")
|
|
287
|
+
def qlinear_mul(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
288
|
+
"""Quantized multiplication with scales and zero points."""
|
|
289
|
+
a = builder.get_value(node.input[0])
|
|
290
|
+
a_scale = builder.get_value(node.input[1])
|
|
291
|
+
a_zero_point = builder.get_value(node.input[2])
|
|
292
|
+
b = builder.get_value(node.input[3])
|
|
293
|
+
b_scale = builder.get_value(node.input[4])
|
|
294
|
+
b_zero_point = builder.get_value(node.input[5])
|
|
295
|
+
c_scale = builder.get_value(node.input[6])
|
|
296
|
+
c_zero_point = builder.get_value(node.input[7])
|
|
297
|
+
|
|
298
|
+
def _qlinear_mul(
|
|
299
|
+
a: torch.Tensor,
|
|
300
|
+
a_s: torch.Tensor,
|
|
301
|
+
a_zp: torch.Tensor,
|
|
302
|
+
b: torch.Tensor,
|
|
303
|
+
b_s: torch.Tensor,
|
|
304
|
+
b_zp: torch.Tensor,
|
|
305
|
+
c_s: torch.Tensor,
|
|
306
|
+
c_zp: torch.Tensor,
|
|
307
|
+
) -> torch.Tensor:
|
|
308
|
+
# Dequantize
|
|
309
|
+
a_dq = (a.float() - a_zp.float()) * a_s
|
|
310
|
+
b_dq = (b.float() - b_zp.float()) * b_s
|
|
311
|
+
# Multiply
|
|
312
|
+
result = a_dq * b_dq
|
|
313
|
+
# Quantize output
|
|
314
|
+
return torch.clamp(torch.round(result / c_s) + c_zp.float(), 0, 255).to(
|
|
315
|
+
torch.uint8
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
return builder.call_function(
|
|
319
|
+
_qlinear_mul,
|
|
320
|
+
args=(
|
|
321
|
+
a,
|
|
322
|
+
a_scale,
|
|
323
|
+
a_zero_point,
|
|
324
|
+
b,
|
|
325
|
+
b_scale,
|
|
326
|
+
b_zero_point,
|
|
327
|
+
c_scale,
|
|
328
|
+
c_zero_point,
|
|
329
|
+
),
|
|
330
|
+
)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@register("QLinearSigmoid")
|
|
334
|
+
def qlinear_sigmoid(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
335
|
+
"""Quantized sigmoid."""
|
|
336
|
+
x = builder.get_value(node.input[0])
|
|
337
|
+
x_scale = builder.get_value(node.input[1])
|
|
338
|
+
x_zero_point = builder.get_value(node.input[2])
|
|
339
|
+
y_scale = builder.get_value(node.input[3])
|
|
340
|
+
y_zero_point = builder.get_value(node.input[4])
|
|
341
|
+
|
|
342
|
+
def _qlinear_sigmoid(
|
|
343
|
+
x: torch.Tensor,
|
|
344
|
+
x_s: torch.Tensor,
|
|
345
|
+
x_zp: torch.Tensor,
|
|
346
|
+
y_s: torch.Tensor,
|
|
347
|
+
y_zp: torch.Tensor,
|
|
348
|
+
) -> torch.Tensor:
|
|
349
|
+
# Dequantize
|
|
350
|
+
x_dq = (x.float() - x_zp.float()) * x_s
|
|
351
|
+
# Sigmoid
|
|
352
|
+
result = torch.sigmoid(x_dq)
|
|
353
|
+
# Quantize output
|
|
354
|
+
return torch.clamp(torch.round(result / y_s) + y_zp.float(), 0, 255).to(
|
|
355
|
+
torch.uint8
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
return builder.call_function(
|
|
359
|
+
_qlinear_sigmoid,
|
|
360
|
+
args=(x, x_scale, x_zero_point, y_scale, y_zero_point),
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@register("QLinearLeakyRelu")
|
|
365
|
+
def qlinear_leaky_relu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
366
|
+
"""Quantized Leaky ReLU."""
|
|
367
|
+
x = builder.get_value(node.input[0])
|
|
368
|
+
x_scale = builder.get_value(node.input[1])
|
|
369
|
+
x_zero_point = builder.get_value(node.input[2])
|
|
370
|
+
y_scale = builder.get_value(node.input[3])
|
|
371
|
+
y_zero_point = builder.get_value(node.input[4])
|
|
372
|
+
alpha = get_attribute(node, "alpha", 0.01)
|
|
373
|
+
|
|
374
|
+
def _qlinear_leaky_relu(
|
|
375
|
+
x: torch.Tensor,
|
|
376
|
+
x_s: torch.Tensor,
|
|
377
|
+
x_zp: torch.Tensor,
|
|
378
|
+
y_s: torch.Tensor,
|
|
379
|
+
y_zp: torch.Tensor,
|
|
380
|
+
alpha: float,
|
|
381
|
+
) -> torch.Tensor:
|
|
382
|
+
# Dequantize
|
|
383
|
+
x_dq = (x.float() - x_zp.float()) * x_s
|
|
384
|
+
# LeakyReLU
|
|
385
|
+
result = torch.nn.functional.leaky_relu(x_dq, negative_slope=alpha)
|
|
386
|
+
# Quantize output
|
|
387
|
+
return torch.clamp(torch.round(result / y_s) + y_zp.float(), 0, 255).to(
|
|
388
|
+
torch.uint8
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
return builder.call_function(
|
|
392
|
+
_qlinear_leaky_relu,
|
|
393
|
+
args=(x, x_scale, x_zero_point, y_scale, y_zero_point, alpha),
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
@register("QLinearGlobalAveragePool")
|
|
398
|
+
def qlinear_global_avg_pool(
|
|
399
|
+
builder: "GraphBuilder", node: onnx.NodeProto
|
|
400
|
+
) -> torch.fx.Node:
|
|
401
|
+
"""Quantized Global Average Pooling."""
|
|
402
|
+
x = builder.get_value(node.input[0])
|
|
403
|
+
x_scale = builder.get_value(node.input[1])
|
|
404
|
+
x_zero_point = builder.get_value(node.input[2])
|
|
405
|
+
y_scale = builder.get_value(node.input[3])
|
|
406
|
+
y_zero_point = builder.get_value(node.input[4])
|
|
407
|
+
|
|
408
|
+
def _qlinear_global_avg_pool(
|
|
409
|
+
x: torch.Tensor,
|
|
410
|
+
x_s: torch.Tensor,
|
|
411
|
+
x_zp: torch.Tensor,
|
|
412
|
+
y_s: torch.Tensor,
|
|
413
|
+
y_zp: torch.Tensor,
|
|
414
|
+
) -> torch.Tensor:
|
|
415
|
+
# Dequantize
|
|
416
|
+
x_dq = (x.float() - x_zp.float()) * x_s
|
|
417
|
+
# Global Average Pool
|
|
418
|
+
result = torch.nn.functional.adaptive_avg_pool2d(x_dq, (1, 1))
|
|
419
|
+
# Quantize output
|
|
420
|
+
return torch.clamp(torch.round(result / y_s) + y_zp.float(), 0, 255).to(
|
|
421
|
+
torch.uint8
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
return builder.call_function(
|
|
425
|
+
_qlinear_global_avg_pool,
|
|
426
|
+
args=(x, x_scale, x_zero_point, y_scale, y_zero_point),
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
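QLinearSigmoid, QLinearLeakyRelu, and QLinearGlobalAveragePool above differ only in the float operation sandwiched between identical dequantize and requantize steps. For readers following the pattern, the shared structure can be expressed as a single helper; this is an illustrative refactoring sketch, not how the package is actually organized:

    from typing import Callable

    import torch

    def requantize_through(
        op: Callable[[torch.Tensor], torch.Tensor],
        x: torch.Tensor,
        x_s: torch.Tensor,
        x_zp: torch.Tensor,
        y_s: torch.Tensor,
        y_zp: torch.Tensor,
    ) -> torch.Tensor:
        x_dq = (x.float() - x_zp.float()) * x_s  # dequantize
        result = op(x_dq)                        # float op: sigmoid, leaky_relu, pooling, ...
        return torch.clamp(torch.round(result / y_s) + y_zp.float(), 0, 255).to(torch.uint8)  # requantize
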
# =============================================================================
# Integer arithmetic operators
# =============================================================================

@register("ConvInteger")
|
|
436
|
+
def conv_integer(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
437
|
+
"""Integer convolution (returns int32)."""
|
|
438
|
+
x = builder.get_value(node.input[0])
|
|
439
|
+
w = builder.get_value(node.input[1])
|
|
440
|
+
x_zero_point = get_optional_input(builder, node, 2)
|
|
441
|
+
w_zero_point = get_optional_input(builder, node, 3)
|
|
442
|
+
|
|
443
|
+
# Get convolution attributes
|
|
444
|
+
# Note: auto_pad is not implemented; use explicit pads instead
|
|
445
|
+
dilations = get_attribute(node, "dilations", [1, 1])
|
|
446
|
+
group = get_attribute(node, "group", 1)
|
|
447
|
+
pads = get_attribute(node, "pads", [0, 0, 0, 0])
|
|
448
|
+
strides = get_attribute(node, "strides", [1, 1])
|
|
449
|
+
|
|
450
|
+
if len(pads) == 4:
|
|
451
|
+
padding = (pads[0], pads[1])
|
|
452
|
+
else:
|
|
453
|
+
padding = tuple(pads)
|
|
454
|
+
|
|
455
|
+
def _conv_integer(
|
|
456
|
+
x: torch.Tensor,
|
|
457
|
+
w: torch.Tensor,
|
|
458
|
+
x_zp: torch.Tensor | None,
|
|
459
|
+
w_zp: torch.Tensor | None,
|
|
460
|
+
stride: tuple,
|
|
461
|
+
padding: tuple,
|
|
462
|
+
dilation: tuple,
|
|
463
|
+
groups: int,
|
|
464
|
+
) -> torch.Tensor:
|
|
465
|
+
# Subtract zero points
|
|
466
|
+
x_int = x.int()
|
|
467
|
+
w_int = w.int()
|
|
468
|
+
if x_zp is not None:
|
|
469
|
+
x_int = x_int - x_zp.int()
|
|
470
|
+
if w_zp is not None:
|
|
471
|
+
w_int = w_int - w_zp.int()
|
|
472
|
+
|
|
473
|
+
# Perform convolution in float (PyTorch doesn't support int conv)
|
|
474
|
+
result = torch.nn.functional.conv2d(
|
|
475
|
+
x_int.float(),
|
|
476
|
+
w_int.float(),
|
|
477
|
+
stride=stride,
|
|
478
|
+
padding=padding,
|
|
479
|
+
dilation=dilation,
|
|
480
|
+
groups=groups,
|
|
481
|
+
)
|
|
482
|
+
return result.int()
|
|
483
|
+
|
|
484
|
+
return builder.call_function(
|
|
485
|
+
_conv_integer,
|
|
486
|
+
args=(
|
|
487
|
+
x,
|
|
488
|
+
w,
|
|
489
|
+
x_zero_point,
|
|
490
|
+
w_zero_point,
|
|
491
|
+
tuple(strides),
|
|
492
|
+
padding,
|
|
493
|
+
tuple(dilations),
|
|
494
|
+
group,
|
|
495
|
+
),
|
|
496
|
+
)
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
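ConvInteger runs the integer arithmetic through a float32 conv2d and casts the result back to int32. For int8/uint8 operands that is exact only while every accumulated value stays within float32's contiguous integer range of 2**24; a rough back-of-the-envelope bound on when that holds (an observation about float rounding, not a documented limit of the package):

    # Each product is at most 128 * 128 = 16384, so k accumulated taps per output element
    # stay exact in float32 while k * 16384 <= 2**24, i.e. roughly k <= 1024 taps
    # (k = C_in / groups * kernel_h * kernel_w).
    max_exact_taps = (2 ** 24) // (128 * 128)  # == 1024
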
@register("MatMulInteger")
|
|
500
|
+
def matmul_integer(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
501
|
+
"""Integer matrix multiplication (returns int32)."""
|
|
502
|
+
a = builder.get_value(node.input[0])
|
|
503
|
+
b = builder.get_value(node.input[1])
|
|
504
|
+
a_zero_point = get_optional_input(builder, node, 2)
|
|
505
|
+
b_zero_point = get_optional_input(builder, node, 3)
|
|
506
|
+
|
|
507
|
+
def _matmul_integer(
|
|
508
|
+
a: torch.Tensor,
|
|
509
|
+
b: torch.Tensor,
|
|
510
|
+
a_zp: torch.Tensor | None,
|
|
511
|
+
b_zp: torch.Tensor | None,
|
|
512
|
+
) -> torch.Tensor:
|
|
513
|
+
a_int = a.int()
|
|
514
|
+
b_int = b.int()
|
|
515
|
+
if a_zp is not None:
|
|
516
|
+
a_int = a_int - a_zp.int()
|
|
517
|
+
if b_zp is not None:
|
|
518
|
+
b_int = b_int - b_zp.int()
|
|
519
|
+
return torch.matmul(a_int.float(), b_int.float()).int()
|
|
520
|
+
|
|
521
|
+
return builder.call_function(
|
|
522
|
+
_matmul_integer,
|
|
523
|
+
args=(a, b, a_zero_point, b_zero_point),
|
|
524
|
+
)
|
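MatMulInteger takes the same float32 detour, so it can be sanity-checked against an exact integer reference; a small eager sketch with made-up shapes and random data:

    import torch

    a = torch.randint(0, 256, (4, 8), dtype=torch.uint8)
    b = torch.randint(-128, 128, (8, 3), dtype=torch.int8)
    a_zp = torch.tensor(128, dtype=torch.uint8)

    via_float = torch.matmul((a.int() - a_zp.int()).float(), b.int().float()).int()
    reference = torch.matmul(a.long() - a_zp.long(), b.long())  # exact int64 reference
    assert torch.equal(via_float.long(), reference)
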
onnx2fx/ops/random.py
ADDED
@@ -0,0 +1,102 @@

# SPDX-License-Identifier: Apache-2.0
"""Random number generation operators.

This module implements ONNX operators for generating random tensors,
including normal and uniform distributions.

Note: Window functions (HannWindow, HammingWindow, BlackmanWindow) have been
moved to signal.py as they are used for signal processing.
"""

from typing import TYPE_CHECKING

import onnx
import torch

from ..op_registry import register
from ..utils.attributes import get_attribute

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Random number generation operators
# =============================================================================

@register("RandomNormal")
|
|
29
|
+
def random_normal(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
30
|
+
"""Generate random values from normal distribution.
|
|
31
|
+
|
|
32
|
+
Note: The seed attribute is not supported; use torch.manual_seed() instead.
|
|
33
|
+
"""
|
|
34
|
+
mean = get_attribute(node, "mean", 0.0)
|
|
35
|
+
scale = get_attribute(node, "scale", 1.0)
|
|
36
|
+
shape = get_attribute(node, "shape")
|
|
37
|
+
|
|
38
|
+
def _random_normal(m: float, s: float, sh: list) -> torch.Tensor:
|
|
39
|
+
return torch.randn(sh) * s + m
|
|
40
|
+
|
|
41
|
+
return builder.call_function(_random_normal, args=(mean, scale, list(shape)))
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
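Since the seed attribute is ignored (as the docstring notes), reproducibility comes from seeding PyTorch's global generator before running the converted graph. The formula itself is the usual location-scale transform of a standard normal; a quick eager illustration with made-up values:

    import torch

    torch.manual_seed(0)
    sample = torch.randn([2, 3]) * 2.0 + 5.0  # scale=2.0, mean=5.0, shape=[2, 3]
    # sample is drawn from N(mean=5, std=2); re-running with the same seed reproduces it
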
@register("RandomNormalLike")
|
|
45
|
+
def random_normal_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
46
|
+
"""Generate random values like input tensor."""
|
|
47
|
+
x = builder.get_value(node.input[0])
|
|
48
|
+
|
|
49
|
+
mean = get_attribute(node, "mean", 0.0)
|
|
50
|
+
scale = get_attribute(node, "scale", 1.0)
|
|
51
|
+
|
|
52
|
+
def _random_normal_like(t: torch.Tensor, m: float, s: float) -> torch.Tensor:
|
|
53
|
+
return torch.randn_like(t) * s + m
|
|
54
|
+
|
|
55
|
+
return builder.call_function(_random_normal_like, args=(x, mean, scale))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@register("RandomUniform")
|
|
59
|
+
def random_uniform(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
60
|
+
"""Generate random values from uniform distribution."""
|
|
61
|
+
low = get_attribute(node, "low", 0.0)
|
|
62
|
+
high = get_attribute(node, "high", 1.0)
|
|
63
|
+
shape = get_attribute(node, "shape")
|
|
64
|
+
|
|
65
|
+
def _random_uniform(lo: float, hi: float, sh: list) -> torch.Tensor:
|
|
66
|
+
return torch.rand(sh) * (hi - lo) + lo
|
|
67
|
+
|
|
68
|
+
return builder.call_function(_random_uniform, args=(low, high, list(shape)))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@register("RandomUniformLike")
|
|
72
|
+
def random_uniform_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
73
|
+
"""Generate random values like input tensor."""
|
|
74
|
+
x = builder.get_value(node.input[0])
|
|
75
|
+
|
|
76
|
+
low = get_attribute(node, "low", 0.0)
|
|
77
|
+
high = get_attribute(node, "high", 1.0)
|
|
78
|
+
|
|
79
|
+
def _random_uniform_like(t: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
|
|
80
|
+
return torch.rand_like(t) * (hi - lo) + lo
|
|
81
|
+
|
|
82
|
+
return builder.call_function(_random_uniform_like, args=(x, low, high))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@register("Multinomial")
|
|
86
|
+
def multinomial(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
87
|
+
"""Sample from multinomial distribution."""
|
|
88
|
+
x = builder.get_value(node.input[0])
|
|
89
|
+
|
|
90
|
+
sample_size = get_attribute(node, "sample_size", 1)
|
|
91
|
+
|
|
92
|
+
return builder.call_function(
|
|
93
|
+
torch.multinomial, args=(x, sample_size), kwargs={"replacement": True}
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
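One caveat on Multinomial: torch.multinomial samples in proportion to nonnegative weights, whereas the ONNX operator is usually described as taking unnormalized log-probabilities (and as exposing a dtype attribute for the output). If that reading of the spec is correct, aligning the two would mean normalizing first, roughly as in this sketch rather than passing the input through unchanged:

    import torch

    def multinomial_from_logits(logits: torch.Tensor, sample_size: int) -> torch.Tensor:
        probs = torch.softmax(logits, dim=-1)  # turn log-probabilities into sampling weights
        return torch.multinomial(probs, sample_size, replacement=True)
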
@register("Bernoulli")
|
|
98
|
+
def bernoulli(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
99
|
+
"""Sample from Bernoulli distribution."""
|
|
100
|
+
x = builder.get_value(node.input[0])
|
|
101
|
+
|
|
102
|
+
return builder.call_function(torch.bernoulli, args=(x,))
|