emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/average_pool.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 from dataclasses import dataclass
 
 from ..ir.ops import AveragePoolOp
@@ -12,16 +13,26 @@ from .registry import register_lowering
 class _AveragePoolSpec:
     batch: int
     channels: int
+    spatial_rank: int
+    in_d: int
     in_h: int
     in_w: int
+    out_d: int
     out_h: int
     out_w: int
+    kernel_d: int
     kernel_h: int
     kernel_w: int
+    dilation_d: int
+    dilation_h: int
+    dilation_w: int
+    stride_d: int
     stride_h: int
     stride_w: int
+    pad_front: int
     pad_top: int
     pad_left: int
+    pad_back: int
     pad_bottom: int
     pad_right: int
     count_include_pad: bool
@@ -54,6 +65,7 @@ def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
         "auto_pad",
         "ceil_mode",
         "count_include_pad",
+        "dilations",
         "kernel_shape",
         "pads",
         "strides",
@@ -63,11 +75,9 @@ def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
     auto_pad = node.attrs.get("auto_pad", b"NOTSET")
     if isinstance(auto_pad, bytes):
         auto_pad = auto_pad.decode("utf-8", errors="ignore")
-    if auto_pad not in ("", "NOTSET"):
-        raise UnsupportedOpError("AveragePool supports auto_pad=NOTSET only")
     ceil_mode = int(node.attrs.get("ceil_mode", 0))
-    if ceil_mode:
-        raise UnsupportedOpError("AveragePool supports ceil_mode=0 only")
+    if ceil_mode not in (0, 1):
+        raise UnsupportedOpError("AveragePool supports ceil_mode=0 or 1 only")
     count_include_pad = int(node.attrs.get("count_include_pad", 0))
     if count_include_pad not in (0, 1):
         raise UnsupportedOpError("AveragePool supports count_include_pad 0 or 1")
@@ -75,47 +85,128 @@ def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
     if kernel_shape is None:
         raise UnsupportedOpError("AveragePool requires kernel_shape")
     kernel_shape = tuple(int(value) for value in kernel_shape)
-    if len(kernel_shape) != 2:
-        raise UnsupportedOpError("AveragePool expects 2D kernel_shape")
-    kernel_h, kernel_w = kernel_shape
-    strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
-    if len(strides) != 2:
-        raise UnsupportedOpError("AveragePool expects 2D strides")
-    pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
-    if len(pads) != 4:
-        raise UnsupportedOpError("AveragePool expects 4D pads")
-    pad_top, pad_left, pad_bottom, pad_right = pads
     input_shape = _value_shape(graph, node.inputs[0], node)
-    if out_h < 0 or out_w < 0:
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("AveragePool expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("AveragePool supports 1D/2D/3D inputs only")
+    if len(kernel_shape) != spatial_rank:
         raise ShapeInferenceError(
-            "AveragePool output shape must be non-negative"
+            "AveragePool kernel_shape must have "
+            f"{spatial_rank} dims, got {kernel_shape}"
         )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("AveragePool stride rank mismatch")
+    dilations = tuple(
+        int(value)
+        for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("AveragePool dilation rank mismatch")
+    pads = tuple(
+        int(value) for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("AveragePool pads rank mismatch")
+    if auto_pad in ("", "NOTSET"):
+        pad_begin = pads[:spatial_rank]
+        pad_end = pads[spatial_rank:]
+    elif auto_pad == "VALID":
+        pad_begin = (0,) * spatial_rank
+        pad_end = (0,) * spatial_rank
+    elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+        pad_begin = []
+        pad_end = []
+        for dim, stride, dilation, kernel in zip(
+            input_shape[2:], strides, dilations, kernel_shape
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = math.ceil(dim / stride)
+            pad_needed = max(0, (out_dim - 1) * stride + effective_kernel - dim)
+            if auto_pad == "SAME_UPPER":
+                pad_start = pad_needed // 2
+            else:
+                pad_start = (pad_needed + 1) // 2
+            pad_begin.append(pad_start)
+            pad_end.append(pad_needed - pad_start)
+        pad_begin = tuple(pad_begin)
+        pad_end = tuple(pad_end)
+    else:
+        raise UnsupportedOpError("AveragePool has unsupported auto_pad mode")
+    batch, channels = input_shape[:2]
+    in_spatial = input_shape[2:]
+    out_spatial = []
+    for dim, stride, dilation, kernel, pad_start, pad_finish in zip(
+        in_spatial, strides, dilations, kernel_shape, pad_begin, pad_end
+    ):
+        effective_kernel = dilation * (kernel - 1) + 1
+        numerator = dim + pad_start + pad_finish - effective_kernel
+        if ceil_mode:
+            out_dim = (numerator + stride - 1) // stride + 1
+            if (out_dim - 1) * stride >= dim + pad_start:
+                out_dim -= 1
+        else:
+            out_dim = numerator // stride + 1
+        if out_dim < 0:
+            raise ShapeInferenceError(
+                "AveragePool output shape must be non-negative"
+            )
+        out_spatial.append(out_dim)
     output_shape = _value_shape(graph, node.outputs[0], node)
-    expected_output_shape = (batch, channels, out_h, out_w)
+    expected_output_shape = (batch, channels, *out_spatial)
     if output_shape != expected_output_shape:
         raise ShapeInferenceError(
             "AveragePool output shape must be "
             f"{expected_output_shape}, got {output_shape}"
         )
+    in_d = in_spatial[0] if spatial_rank == 3 else 1
+    in_h = in_spatial[-2] if spatial_rank >= 2 else 1
+    in_w = in_spatial[-1]
+    out_d = out_spatial[0] if spatial_rank == 3 else 1
+    out_h = out_spatial[-2] if spatial_rank >= 2 else 1
+    out_w = out_spatial[-1]
+    kernel_d = kernel_shape[0] if spatial_rank == 3 else 1
+    kernel_h = kernel_shape[-2] if spatial_rank >= 2 else 1
+    kernel_w = kernel_shape[-1]
+    dilation_d = dilations[0] if spatial_rank == 3 else 1
+    dilation_h = dilations[-2] if spatial_rank >= 2 else 1
+    dilation_w = dilations[-1]
+    stride_d = strides[0] if spatial_rank == 3 else 1
+    stride_h = strides[-2] if spatial_rank >= 2 else 1
+    stride_w = strides[-1]
+    pad_front = pad_begin[0] if spatial_rank == 3 else 0
+    pad_top = pad_begin[-2] if spatial_rank >= 2 else 0
+    pad_left = pad_begin[-1]
+    pad_back = pad_end[0] if spatial_rank == 3 else 0
+    pad_bottom = pad_end[-2] if spatial_rank >= 2 else 0
+    pad_right = pad_end[-1]
     return _AveragePoolSpec(
         batch=batch,
         channels=channels,
+        spatial_rank=spatial_rank,
+        in_d=in_d,
         in_h=in_h,
         in_w=in_w,
+        out_d=out_d,
         out_h=out_h,
         out_w=out_w,
+        kernel_d=kernel_d,
         kernel_h=kernel_h,
         kernel_w=kernel_w,
+        dilation_d=dilation_d,
+        dilation_h=dilation_h,
+        dilation_w=dilation_w,
+        stride_d=stride_d,
        stride_h=stride_h,
        stride_w=stride_w,
+        pad_front=pad_front,
        pad_top=pad_top,
        pad_left=pad_left,
+        pad_back=pad_back,
        pad_bottom=pad_bottom,
        pad_right=pad_right,
        count_include_pad=bool(count_include_pad),
@@ -128,29 +219,48 @@ def _resolve_global_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
     if node.attrs:
         raise UnsupportedOpError("GlobalAveragePool has unsupported attributes")
     input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError(
+            "GlobalAveragePool expects NCHW inputs with spatial dims"
+        )
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("GlobalAveragePool supports 1D/2D/3D inputs only")
+    batch, channels = input_shape[:2]
+    in_spatial = input_shape[2:]
     output_shape = _value_shape(graph, node.outputs[0], node)
-    expected_output_shape = (batch, channels, 1, 1)
+    expected_output_shape = (batch, channels, *([1] * spatial_rank))
     if output_shape != expected_output_shape:
         raise ShapeInferenceError(
             "GlobalAveragePool output shape must be "
             f"{expected_output_shape}, got {output_shape}"
         )
+    in_d = in_spatial[0] if spatial_rank == 3 else 1
+    in_h = in_spatial[-2] if spatial_rank >= 2 else 1
+    in_w = in_spatial[-1]
     return _AveragePoolSpec(
         batch=batch,
         channels=channels,
+        spatial_rank=spatial_rank,
+        in_d=in_d,
         in_h=in_h,
         in_w=in_w,
+        out_d=1,
         out_h=1,
         out_w=1,
+        kernel_d=in_d,
         kernel_h=in_h,
         kernel_w=in_w,
+        dilation_d=1,
+        dilation_h=1,
+        dilation_w=1,
+        stride_d=1,
         stride_h=1,
         stride_w=1,
+        pad_front=0,
         pad_top=0,
         pad_left=0,
+        pad_back=0,
         pad_bottom=0,
         pad_right=0,
         count_include_pad=False,
@@ -176,16 +286,26 @@ def lower_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
         output=node.outputs[0],
         batch=spec.batch,
         channels=spec.channels,
+        spatial_rank=spec.spatial_rank,
+        in_d=spec.in_d,
         in_h=spec.in_h,
         in_w=spec.in_w,
+        out_d=spec.out_d,
         out_h=spec.out_h,
         out_w=spec.out_w,
+        kernel_d=spec.kernel_d,
         kernel_h=spec.kernel_h,
         kernel_w=spec.kernel_w,
+        dilation_d=spec.dilation_d,
+        dilation_h=spec.dilation_h,
+        dilation_w=spec.dilation_w,
+        stride_d=spec.stride_d,
         stride_h=spec.stride_h,
         stride_w=spec.stride_w,
+        pad_front=spec.pad_front,
         pad_top=spec.pad_top,
         pad_left=spec.pad_left,
+        pad_back=spec.pad_back,
         pad_bottom=spec.pad_bottom,
         pad_right=spec.pad_right,
         count_include_pad=spec.count_include_pad,
@@ -212,16 +332,26 @@ def lower_global_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
         output=node.outputs[0],
         batch=spec.batch,
         channels=spec.channels,
+        spatial_rank=spec.spatial_rank,
+        in_d=spec.in_d,
         in_h=spec.in_h,
         in_w=spec.in_w,
+        out_d=spec.out_d,
         out_h=spec.out_h,
         out_w=spec.out_w,
+        kernel_d=spec.kernel_d,
         kernel_h=spec.kernel_h,
         kernel_w=spec.kernel_w,
+        dilation_d=spec.dilation_d,
+        dilation_h=spec.dilation_h,
+        dilation_w=spec.dilation_w,
+        stride_d=spec.stride_d,
         stride_h=spec.stride_h,
         stride_w=spec.stride_w,
+        pad_front=spec.pad_front,
         pad_top=spec.pad_top,
         pad_left=spec.pad_left,
+        pad_back=spec.pad_back,
         pad_bottom=spec.pad_bottom,
         pad_right=spec.pad_right,
         count_include_pad=spec.count_include_pad,
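Note: the output-size arithmetic added above is the usual ONNX pooling formula (effective kernel = dilation * (kernel - 1) + 1, floor or ceil division by the stride, with ceil-mode windows that would start entirely inside the end padding dropped). A minimal standalone sketch, not package code; the helper name and example values are illustrative:

def pool_out_dim(in_dim, kernel, stride=1, dilation=1,
                 pad_begin=0, pad_end=0, ceil_mode=False):
    # Effective kernel extent once dilation is applied.
    effective_kernel = dilation * (kernel - 1) + 1
    span = in_dim + pad_begin + pad_end - effective_kernel
    if ceil_mode:
        out_dim = (span + stride - 1) // stride + 1  # ceil division
        # Drop a window that would start entirely inside the end padding.
        if (out_dim - 1) * stride >= in_dim + pad_begin:
            out_dim -= 1
    else:
        out_dim = span // stride + 1
    return out_dim

# A 6-wide axis, kernel 3, stride 2: floor mode keeps 2 windows, ceil mode adds a third.
assert pool_out_dim(6, 3, stride=2) == 2
assert pool_out_dim(6, 3, stride=2, ceil_mode=True) == 3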
emx_onnx_cgen/lowering/bernoulli.py
ADDED
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..dtypes import dtype_info
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..ir.ops import BernoulliOp
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+_SUPPORTED_INPUT_DTYPES = {ScalarType.F16, ScalarType.F32, ScalarType.F64}
+_SUPPORTED_OUTPUT_DTYPES = {
+    ScalarType.U8,
+    ScalarType.U16,
+    ScalarType.U32,
+    ScalarType.U64,
+    ScalarType.I8,
+    ScalarType.I16,
+    ScalarType.I32,
+    ScalarType.I64,
+    ScalarType.F16,
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.BOOL,
+}
+
+
+@register_lowering("Bernoulli")
+def lower_bernoulli(graph: Graph, node: Node) -> BernoulliOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Bernoulli must have 1 input and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise ShapeInferenceError(
+            "Bernoulli output shape must match input shape, "
+            f"got {output_shape} for input {input_shape}"
+        )
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    if input_dtype not in _SUPPORTED_INPUT_DTYPES:
+        raise UnsupportedOpError(
+            "Bernoulli input dtype must be float, "
+            f"got {input_dtype.onnx_name}"
+        )
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    dtype_attr = node.attrs.get("dtype")
+    if dtype_attr is not None:
+        attr_dtype = dtype_info(int(dtype_attr))
+        if attr_dtype != output_dtype:
+            raise UnsupportedOpError(
+                "Bernoulli dtype attribute does not match output dtype"
+            )
+    if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
+        raise UnsupportedOpError(
+            "Bernoulli output dtype must be numeric or bool, "
+            f"got {output_dtype.onnx_name}"
+        )
+    seed_value = node.attrs.get("seed")
+    seed = None
+    if seed_value is not None:
+        seed = int(seed_value)
+    return BernoulliOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        input_dtype=input_dtype,
+        dtype=output_dtype,
+        seed=seed,
+    )
emx_onnx_cgen/lowering/common.py
CHANGED
@@ -50,6 +50,8 @@ def value_shape(
     if isinstance(graph, GraphContext):
         shape = graph.shape(name, node)
         value = graph.find_value(name)
+        if graph.has_shape(name):
+            return shape
     else:
         try:
             value = graph.find_value(name)
@@ -219,6 +221,37 @@ def _shape_values_from_input(
         return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
     if source_node.op_type == "Mod":
         return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+    if source_node.op_type in {"Add", "Sub", "Mul"}:
+        if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{source_node.op_type} must have 2 inputs and 1 output"
+            )
+        left = _shape_values_from_input(
+            graph,
+            source_node.inputs[0],
+            node,
+            _visited=_visited,
+        )
+        right = _shape_values_from_input(
+            graph,
+            source_node.inputs[1],
+            node,
+            _visited=_visited,
+        )
+        if left is None or right is None:
+            return None
+        if len(left) == 1 and len(right) != 1:
+            left = left * len(right)
+        if len(right) == 1 and len(left) != 1:
+            right = right * len(left)
+        if len(left) != len(right):
+            return None
+        if source_node.op_type == "Add":
+            return [l + r for l, r in zip(left, right)]
+        if source_node.op_type == "Sub":
+            return [l - r for l, r in zip(left, right)]
+        if source_node.op_type == "Mul":
+            return [l * r for l, r in zip(left, right)]
     if source_node.op_type == "Not":
         if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
             raise UnsupportedOpError("Not must have 1 input and 1 output")
@@ -465,3 +498,18 @@ def optional_name(names: Sequence[str], index: int) -> str | None:
         return None
     name = names[index]
     return name or None
+
+
+def resolve_int_list_from_value(
+    graph: Graph | GraphContext,
+    name: str,
+    node: Node | None = None,
+) -> list[int] | None:
+    return _shape_values_from_input(graph, name, node)
+
+
+def value_has_dim_params(
+    graph: Graph | GraphContext,
+    name: str,
+) -> bool:
+    return any(graph.find_value(name).type.dim_params)
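Note: the Add/Sub/Mul folding added to _shape_values_from_input broadcasts a length-1 operand against the other before combining element-wise. A tiny illustration with made-up values (not package code), mirroring the branch above for a Mul over a shape expression:

left, right = [2, 3, 4], [2]            # e.g. folded Shape(x) and a scalar constant
if len(right) == 1 and len(left) != 1:
    right = right * len(left)           # [2] -> [2, 2, 2]
assert [l * r for l, r in zip(left, right)] == [4, 6, 8]   # folded Mul result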
emx_onnx_cgen/lowering/concat.py
CHANGED
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
-from ..ir.ops import ConcatOp
 from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.ops import ConcatOp
 from .common import node_dtype as _node_dtype
+from .common import value_has_dim_params as _value_has_dim_params
 from .common import value_shape as _value_shape
 from .registry import register_lowering
-from ..validation import validate_concat_shapes
+from ..validation import normalize_concat_axis, validate_concat_shapes
 
 
 @register_lowering("Concat")
@@ -15,12 +17,44 @@ def lower_concat(graph: Graph, node: Node) -> ConcatOp:
         raise UnsupportedOpError("Concat must have at least 1 input and 1 output")
     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
     output_shape = _value_shape(graph, node.outputs[0], node)
+    if _value_has_dim_params(graph, node.outputs[0]):
+        output_shape = ()
     input_shapes = tuple(_value_shape(graph, name, node) for name in node.inputs)
+    axis = int(node.attrs.get("axis", 0))
+    if output_shape:
+        axis = validate_concat_shapes(
+            input_shapes,
+            output_shape,
+            axis,
+        )
+    else:
+        ranks = {len(shape) for shape in input_shapes}
+        if len(ranks) != 1:
+            raise UnsupportedOpError(
+                f"Concat inputs must have matching ranks, got {input_shapes}"
+            )
+        rank = ranks.pop()
+        axis = normalize_concat_axis(axis, rank)
+        base_shape = list(input_shapes[0])
+        axis_dim = 0
+        for shape in input_shapes:
+            if len(shape) != rank:
+                raise UnsupportedOpError(
+                    f"Concat inputs must have matching ranks, got {input_shapes}"
+                )
+            for dim_index, dim in enumerate(shape):
+                if dim_index == axis:
+                    continue
+                if dim != base_shape[dim_index]:
+                    raise UnsupportedOpError(
+                        "Concat inputs must match on non-axis dimensions, "
+                        f"got {input_shapes}"
+                    )
+            axis_dim += shape[axis]
+        base_shape[axis] = axis_dim
+        output_shape = tuple(base_shape)
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], output_shape)
     return ConcatOp(
         inputs=node.inputs,
         output=node.outputs[0],
emx_onnx_cgen/lowering/conv.py
CHANGED
@@ -26,9 +26,14 @@ class ConvSpec:
     group: int
 
 
-def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
+def resolve_conv_spec(
+    graph: Graph,
+    node: Node,
+    *,
+    input_name: str,
+    weight_name: str,
+    bias_name: str | None,
+) -> ConvSpec:
     supported_attrs = {
         "auto_pad",
         "dilations",
@@ -39,8 +44,8 @@ def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
     }
     if set(node.attrs) - supported_attrs:
         raise UnsupportedOpError("Conv has unsupported attributes")
-    input_shape = _value_shape(graph, node.inputs[0], node)
-    weight_shape = _value_shape(graph, node.inputs[1], node)
+    input_shape = _value_shape(graph, input_name, node)
+    weight_shape = _value_shape(graph, weight_name, node)
     if len(input_shape) < 3:
         raise UnsupportedOpError("Conv expects NCHW inputs with spatial dims")
     spatial_rank = len(input_shape) - 2
@@ -79,8 +84,8 @@ def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
             "Conv input channels must match weight channels, "
             f"got {in_channels} and {weight_in_channels * group}"
         )
-    if len(node.inputs) == 3:
-        bias_shape = _value_shape(graph, node.inputs[2], node)
+    if bias_name is not None:
+        bias_shape = _value_shape(graph, bias_name, node)
         if bias_shape != (out_channels,):
             raise ShapeInferenceError(
                 f"Conv bias shape must be {(out_channels,)}, got {bias_shape}"
@@ -171,7 +176,13 @@ def lower_conv(graph: Graph, node: Node) -> ConvOp:
         raise UnsupportedOpError(
             "Conv supports float16, float, and double inputs only"
         )
-    spec = resolve_conv_spec(graph, node)
+    spec = resolve_conv_spec(
+        graph,
+        node,
+        input_name=node.inputs[0],
+        weight_name=node.inputs[1],
+        bias_name=node.inputs[2] if len(node.inputs) == 3 else None,
+    )
     return ConvOp(
         input0=node.inputs[0],
         weights=node.inputs[1],
emx_onnx_cgen/lowering/conv_integer.py
ADDED
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..ir.ops import ConvIntegerOp
+from .common import optional_name, value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .conv import resolve_conv_spec
+from .registry import register_lowering
+
+
+def _ensure_scalar_shape(shape: tuple[int, ...], label: str) -> None:
+    if shape not in {(), (1,)}:
+        raise UnsupportedOpError(
+            f"ConvInteger {label} must be a scalar, got shape {shape}"
+        )
+
+
+def _resolve_w_zero_point_shape(
+    shape: tuple[int, ...], out_channels: int
+) -> bool:
+    if shape in {(), (1,)}:
+        return False
+    if shape == (out_channels,):
+        return True
+    raise UnsupportedOpError(
+        "ConvInteger w_zero_point must be scalar or 1D per output channel, "
+        f"got shape {shape}"
+    )
+
+
+@register_lowering("ConvInteger")
+def lower_conv_integer(graph: Graph, node: Node) -> ConvIntegerOp:
+    if len(node.inputs) not in {2, 3, 4} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvInteger must have 2 to 4 inputs and 1 output"
+        )
+    input_name = node.inputs[0]
+    weight_name = node.inputs[1]
+    x_zero_point_name = optional_name(node.inputs, 2)
+    w_zero_point_name = optional_name(node.inputs, 3)
+    input_dtype = _value_dtype(graph, input_name, node)
+    weight_dtype = _value_dtype(graph, weight_name, node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("ConvInteger supports uint8/int8 inputs only")
+    if weight_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("ConvInteger supports uint8/int8 weights only")
+    if output_dtype != ScalarType.I32:
+        raise UnsupportedOpError("ConvInteger expects int32 outputs only")
+    x_zero_shape = None
+    if x_zero_point_name is not None:
+        x_zero_shape = _value_shape(graph, x_zero_point_name, node)
+        _ensure_scalar_shape(x_zero_shape, "x_zero_point")
+        if _value_dtype(graph, x_zero_point_name, node) != input_dtype:
+            raise UnsupportedOpError(
+                "ConvInteger x_zero_point dtype must match input dtype"
+            )
+    w_zero_shape = None
+    w_zero_point_per_channel = False
+    if w_zero_point_name is not None:
+        w_zero_shape = _value_shape(graph, w_zero_point_name, node)
+        if _value_dtype(graph, w_zero_point_name, node) != weight_dtype:
+            raise UnsupportedOpError(
+                "ConvInteger w_zero_point dtype must match weight dtype"
+            )
+    spec = resolve_conv_spec(
+        graph,
+        node,
+        input_name=input_name,
+        weight_name=weight_name,
+        bias_name=None,
+    )
+    if w_zero_shape is not None:
+        w_zero_point_per_channel = _resolve_w_zero_point_shape(
+            w_zero_shape, spec.out_channels
+        )
+    return ConvIntegerOp(
+        input0=input_name,
+        weights=weight_name,
+        x_zero_point=x_zero_point_name,
+        w_zero_point=w_zero_point_name,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        group=spec.group,
+        input_dtype=input_dtype,
+        weight_dtype=weight_dtype,
+        dtype=output_dtype,
+        x_zero_point_shape=x_zero_shape,
+        w_zero_point_shape=w_zero_shape,
+        w_zero_point_per_channel=w_zero_point_per_channel,
+    )
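Note: the ConvInteger lowering above only validates shapes, dtypes, and zero points; the arithmetic it targets is the standard ONNX ConvInteger accumulation, roughly as sketched below (illustrative helper and values, not package code): zero points are subtracted from the uint8/int8 operands and the products are accumulated in int32.

def conv_integer_dot(xs, ws, x_zero_point=0, w_zero_point=0):
    # One int32 output element over a flattened receptive field.
    acc = 0
    for x, w in zip(xs, ws):
        acc += (int(x) - x_zero_point) * (int(w) - w_zero_point)
    return acc

assert conv_integer_dot([10, 20], [3, 4], x_zero_point=10, w_zero_point=2) == 20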