onnx2tf 1.29.14__py3-none-any.whl → 1.29.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +141 -0
- onnx2tf/ops/Concat.py +67 -41
- onnx2tf/ops/DequantizeLinear.py +76 -34
- onnx2tf/ops/DynamicQuantizeLinear.py +18 -17
- onnx2tf/ops/QLinearConcat.py +245 -26
- onnx2tf/ops/QLinearConv.py +70 -75
- onnx2tf/ops/QLinearMatMul.py +77 -20
- onnx2tf/ops/QuantizeLinear.py +117 -44
- onnx2tf/ops/Shape.py +2 -0
- onnx2tf/ops/Split.py +33 -8
- {onnx2tf-1.29.14.dist-info → onnx2tf-1.29.16.dist-info}/METADATA +3 -3
- {onnx2tf-1.29.14.dist-info → onnx2tf-1.29.16.dist-info}/RECORD +15 -15
- {onnx2tf-1.29.14.dist-info → onnx2tf-1.29.16.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.14.dist-info → onnx2tf-1.29.16.dist-info}/entry_points.txt +0 -0
onnx2tf/__init__.py
CHANGED
onnx2tf/onnx2tf.py
CHANGED
@@ -62,6 +62,146 @@ from onnx2tf.utils.enums import (
 from onnx2tf.utils.logging import *
 from sng4onnx import generate as op_name_auto_generate
 
+def fuse_expanded_qdq_to_qdq(
+    *,
+    graph: gs.Graph,
+):
+    def _get_const_value(tensor):
+        if isinstance(tensor, gs.Constant):
+            return tensor.values
+        if isinstance(tensor, gs.Variable) and len(tensor.inputs) == 1:
+            producer = tensor.inputs[0]
+            if producer.op == 'Constant' and 'value' in producer.attrs:
+                return producer.attrs['value'].values
+        return None
+
+    def _split_const_and_var(inputs):
+        if len(inputs) != 2:
+            return None, None
+        const_val = _get_const_value(inputs[0])
+        if const_val is not None:
+            return const_val, inputs[1]
+        const_val = _get_const_value(inputs[1])
+        if const_val is not None:
+            return const_val, inputs[0]
+        return None, None
+
+    nodes_to_remove = []
+    nodes_to_add = []
+
+    for round_node in list(graph.nodes):
+        if round_node.op != 'Round' or len(round_node.inputs) < 1:
+            continue
+
+        round_in = round_node.inputs[0]
+        if len(round_in.inputs) != 1:
+            continue
+        mul1_node = round_in.inputs[0]
+        if mul1_node.op != 'Mul':
+            continue
+        if len(mul1_node.outputs) != 1 or len(mul1_node.outputs[0].outputs) != 1:
+            continue
+
+        inv_scale, x = _split_const_and_var(mul1_node.inputs)
+        if inv_scale is None or x is None:
+            continue
+
+        relu_node = round_node.outputs[0].outputs[0] if round_node.outputs else None
+        if relu_node is None:
+            continue
+        if relu_node.op == 'Relu':
+            relu_out = relu_node.outputs[0]
+        elif relu_node.op in ['Max', 'Maximum']:
+            max_const, max_var = _split_const_and_var(relu_node.inputs)
+            if max_const is None or max_var != round_node.outputs[0]:
+                continue
+            if np.asarray(max_const).size != 1 or float(np.asarray(max_const).item()) != 0.0:
+                continue
+            relu_out = relu_node.outputs[0]
+        else:
+            continue
+
+        if len(relu_out.outputs) != 1:
+            continue
+        min_node = relu_out.outputs[0]
+        if min_node.op not in ['Min', 'Minimum']:
+            continue
+
+        qmax, min_var = _split_const_and_var(min_node.inputs)
+        if qmax is None or min_var != relu_out:
+            continue
+        if np.asarray(qmax).size != 1:
+            continue
+
+        if len(min_node.outputs) != 1 or len(min_node.outputs[0].outputs) != 1:
+            continue
+        mul2_node = min_node.outputs[0].outputs[0]
+        if mul2_node.op != 'Mul':
+            continue
+
+        scale, min_out = _split_const_and_var(mul2_node.inputs)
+        if scale is None or min_out != min_node.outputs[0]:
+            continue
+        if np.asarray(scale).size != 1:
+            continue
+
+        scale_val = float(np.asarray(scale).item())
+        inv_scale_val = float(np.asarray(inv_scale).item())
+        if scale_val == 0.0 or not np.isfinite(scale_val) or not np.isfinite(inv_scale_val):
+            continue
+        if not np.isclose(scale_val * inv_scale_val, 1.0, rtol=1e-3, atol=1e-6):
+            continue
+
+        if len(mul2_node.outputs) != 1:
+            continue
+        output_var = mul2_node.outputs[0]
+
+        # Require linear chain
+        chain_nodes = [mul1_node, round_node, relu_node, min_node, mul2_node]
+        if any(len(n.outputs) == 0 for n in chain_nodes):
+            continue
+        if len(round_node.outputs[0].outputs) != 1 or len(relu_out.outputs) != 1 or len(min_node.outputs[0].outputs) != 1:
+            continue
+
+        # Build QDQ
+        scale_const = gs.Constant(
+            name=f"{mul2_node.name}_scale",
+            values=np.asarray(scale_val, dtype=np.float32),
+        )
+        zero_const = gs.Constant(
+            name=f"{mul2_node.name}_zero_point",
+            values=np.asarray(0, dtype=np.uint8),
+        )
+        quant_out = gs.Variable(
+            name=f"{output_var.name}_quant",
+            dtype=np.uint8,
+            shape=output_var.shape,
+        )
+        q_node = gs.Node(
+            op="QuantizeLinear",
+            name=f"{mul2_node.name}_QuantizeLinear",
+            inputs=[x, scale_const, zero_const],
+            outputs=[quant_out],
+        )
+        dq_node = gs.Node(
+            op="DequantizeLinear",
+            name=f"{mul2_node.name}_DequantizeLinear",
+            inputs=[quant_out, scale_const, zero_const],
+            outputs=[output_var],
+        )
+        output_var.inputs = [dq_node]
+
+        nodes_to_add.extend([q_node, dq_node])
+        nodes_to_remove.extend(chain_nodes)
+
+    if nodes_to_add:
+        graph.nodes.extend(nodes_to_add)
+    if nodes_to_remove:
+        for n in nodes_to_remove:
+            if n in graph.nodes:
+                graph.nodes.remove(n)
+    graph.cleanup().toposort()
+
 def apply_nonzero_passthrough(
     *,
     graph: gs.Graph,

@@ -848,6 +988,7 @@ def convert(
     if hasattr(onnx_graph, 'metadata_props'):
         metadata_props = onnx_graph.metadata_props
     graph = gs.import_onnx(onnx_graph)
+    fuse_expanded_qdq_to_qdq(graph=graph)
 
     # Cut the ONNX graph when an input name is specified that interrupts the conversion
     if not input_names_to_interrupt_model_conversion:
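The new fuse_expanded_qdq_to_qdq pass above recognizes the "expanded" fake-quantization chain Mul(1/scale) → Round → Relu (or Max with 0) → Min(qmax) → Mul(scale) and collapses it into a QuantizeLinear/DequantizeLinear pair sharing one float32 scale and a uint8 zero point; convert() now runs it right after gs.import_onnx(). A minimal usage sketch, not part of the diff; the import path and file names below are assumptions:

import onnx
import onnx_graphsurgeon as gs
from onnx2tf.onnx2tf import fuse_expanded_qdq_to_qdq  # assumed import path for the new pass

# Load a model, fuse expanded QDQ chains in place, and save the result.
onnx_graph = onnx.load("model.onnx")                   # placeholder file name
graph = gs.import_onnx(onnx_graph)
fuse_expanded_qdq_to_qdq(graph=graph)                  # Mul/Round/Relu-or-Max/Min/Mul -> Q/DQ
onnx.save(gs.export_onnx(graph), "model_qdq.onnx")     # placeholder file name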
onnx2tf/ops/Concat.py
CHANGED
@@ -234,6 +234,31 @@ def make_node(
             and len(value.shape) > 0 else tf.reshape(value, [1]) for value in values
     ]
 
+    def _infer_concat_axis_runtime(values, fallback_axis):
+        if not values:
+            return fallback_axis
+        shapes = [tf.shape(v) for v in values]
+        shapes = tf.stack(shapes)
+        equal_mask = tf.reduce_all(tf.equal(shapes, shapes[0]), axis=0)
+        diff_mask = tf.cast(tf.logical_not(equal_mask), tf.int32)
+        candidate_count = tf.reduce_sum(diff_mask)
+        axis_from_diff = tf.argmax(diff_mask, axis=0, output_type=tf.int32)
+        fallback_axis_tensor = tf.cast(fallback_axis, tf.int32)
+        is_single = tf.cast(tf.equal(candidate_count, 1), tf.int32)
+        return axis_from_diff * is_single + fallback_axis_tensor * (1 - is_single)
+
+    axis_is_dynamic = False
+    if len(values) > 0:
+        all_none = True
+        for value in values:
+            if value.shape is not None and value.shape != tf.TensorShape(None):
+                if not all([s is None for s in value.shape]):
+                    all_none = False
+                    break
+        if all_none:
+            axis_is_dynamic = True
+    axis_for_concat = _infer_concat_axis_runtime(values, axis) if axis_is_dynamic else axis
+
     # Generation of TF OP
     tf_type = None
     if simple_resize:

@@ -271,7 +296,7 @@ def make_node(
         tf_layers_dict[graph_node_output.name]['tf_node'] = \
             tf.concat(
                 values=values,
-                axis=
+                axis=axis_for_concat,
                 name=graph_node.name,
             )
     except:

@@ -311,51 +336,52 @@ def make_node(
     # This workaround is useful when automatic axis correction is practically difficult,
     # such as when all tensors to be combined originate from Transpose or Reshape.
     # https://github.com/PINTO0309/onnx2tf/issues/473
-[…]
-                        tf.concat(
-                            values=values,
-                            axis=matched_axes[0],
-                            name=graph_node.name,
-                        )
-                    axis = matched_axes[0]
-                elif not nhwc_judge:
-                    onnx_axis = int(graph_node.attrs.get('axis', 0))
-                    onnx_axis = output_tensor_rank - 1 if onnx_axis == -1 else onnx_axis
-                    if onnx_axis == output_tensor_rank - 1 \
-                        and onnx_axis in matched_axes:
+    if not axis_is_dynamic:
+        output_tensor_shape = tf_layers_dict[graph_node_output.name]['tf_node'].shape
+        if output_tensor_shape != tf.TensorShape(None):
+            output_tensor_rank = len(output_tensor_shape)
+            if graph_node.outputs[0].shape is not None \
+                and axis != 0 \
+                and output_tensor_rank >= 2 \
+                and before_axis == axis:
+
+                # Search for valid Concat patterns
+                if not shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(output_tensor_shape)):
+                    matched_axes = []
+                    for dummy_axis in range(1, output_tensor_rank):
+                        try:
+                            dummy_concat_tensor = \
+                                tf.concat(
+                                    values=values,
+                                    axis=dummy_axis,
+                                    name=graph_node.name,
+                                )
+                            dummy_output_shape = dummy_concat_tensor.shape
+                            if shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(dummy_output_shape)):
+                                matched_axes.append(dummy_axis)
+                        except:
+                            pass
+                    # Review Concat axes only if there is one valid join pattern
+                    if len(matched_axes) == 1:
                         tf_layers_dict[graph_node_output.name]['tf_node'] = \
                             tf.concat(
                                 values=values,
-                                axis=
+                                axis=matched_axes[0],
                                 name=graph_node.name,
                             )
-                        axis =
+                        axis = matched_axes[0]
+                    elif not nhwc_judge:
+                        onnx_axis = int(graph_node.attrs.get('axis', 0))
+                        onnx_axis = output_tensor_rank - 1 if onnx_axis == -1 else onnx_axis
+                        if onnx_axis == output_tensor_rank - 1 \
+                            and onnx_axis in matched_axes:
+                            tf_layers_dict[graph_node_output.name]['tf_node'] = \
+                                tf.concat(
+                                    values=values,
+                                    axis=onnx_axis,
+                                    name=graph_node.name,
+                                )
+                            axis = onnx_axis
 
     # Workaround for post-concat accuracy degradation issues
     # Process only in the case of a Concat of two tensors because the process is too redundant.
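In the Concat change above, when every input shape is completely unknown the new _infer_concat_axis_runtime helper stacks the runtime shapes and, if the inputs disagree on exactly one dimension, uses that dimension as the concat axis, otherwise it falls back to the ONNX axis. A self-contained sketch of that idea; the function and variable names here are illustrative, not taken from the diff:

import tensorflow as tf

def infer_concat_axis(values, fallback_axis):
    # Stack each input's runtime shape into a [num_inputs, rank] matrix.
    shapes = tf.stack([tf.shape(v) for v in values])
    # A dimension is a concat-axis candidate if any input disagrees with the first one.
    equal_mask = tf.reduce_all(tf.equal(shapes, shapes[0]), axis=0)
    diff_mask = tf.cast(tf.logical_not(equal_mask), tf.int32)
    candidate_count = tf.reduce_sum(diff_mask)
    axis_from_diff = tf.argmax(diff_mask, axis=0, output_type=tf.int32)
    # Take the differing dimension only when it is unique; otherwise keep the fallback.
    is_single = tf.cast(tf.equal(candidate_count, 1), tf.int32)
    return axis_from_diff * is_single + tf.cast(fallback_axis, tf.int32) * (1 - is_single)

a = tf.zeros([2, 3, 4])
b = tf.zeros([2, 5, 4])
print(int(infer_concat_axis([a, b], fallback_axis=0)))  # 1: only dimension 1 differs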
onnx2tf/ops/DequantizeLinear.py
CHANGED
@@ -15,6 +15,43 @@ from onnx2tf.utils.common_functions import (
     post_process_transpose,
 )
 
+def _expand_scale_or_zero_point(
+    *,
+    value,
+    input_tensor,
+    axis: int,
+    block_size: int,
+):
+    value_rank = len(value.shape)
+    input_rank = len(input_tensor.shape)
+
+    if value_rank == 0:
+        return value
+
+    if input_rank <= 0:
+        return value
+
+    if axis < 0 or axis >= input_rank:
+        axis = 0
+
+    # Blocked quantization: expand along axis then slice to input shape
+    if block_size > 0 and value_rank == input_rank:
+        if value.shape[axis] is None \
+            or input_tensor.shape[axis] is None \
+            or value.shape[axis] != input_tensor.shape[axis]:
+            expanded = tf.repeat(value, repeats=block_size, axis=axis)
+            expanded = tf.slice(expanded, [0] * input_rank, tf.shape(input_tensor))
+            return expanded
+        return value
+
+    # Per-axis quantization: reshape 1-D to broadcast
+    if value_rank == 1 and input_rank is not None:
+        shape = [1] * input_rank
+        shape[axis] = -1
+        return tf.reshape(value, shape)
+
+    return value
+
 
 @print_node_info
 @inverted_operation_enable_disable

@@ -63,6 +100,11 @@ def make_node(
 
     input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
         if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+    input_is_dequantized = False
+    input_nhwc = False
+    if isinstance(graph_node_input_1, gs.Variable):
+        input_is_dequantized = tf_layers_dict.get(graph_node_input_1.name, {}).get('is_dequantized', False)
+        input_nhwc = tf_layers_dict.get(graph_node_input_1.name, {}).get('nhwc', False)
 
     # Pre-process transpose
     input_tensor = pre_process_transpose(

@@ -72,12 +114,10 @@ def make_node(
         **kwargs,
     )
 
-[…]
+    input_tensor_rank = len(input_tensor.shape)
+    input_tensor_dtype = input_tensor.dtype
     x_scale = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
-    x_scale_shape = x_scale.shape
-    x_scale_rank = len(x_scale_shape)
     x_zero_point = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
         if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
 

@@ -87,48 +127,50 @@ def make_node(
         tensor_rank=input_tensor_rank,
         before_op_output_shape_trans=before_op_output_shape_trans,
     )
+    if input_tensor_rank == 1:
+        axis = 0
 
     # Preserving Graph Structure (Dict)
     tf_layers_dict[graph_node_output.name] = {
         'optype': graph_node.op,
         'shape': shape,
         'dtype': dtype,
+        'is_dequantized': True,
+        'nhwc': input_nhwc,
     }
 
     # Generation of TF OP
 
     input_tensor = tf.cast(input_tensor, tf.float32)
+    x_scale = tf.cast(x_scale, tf.float32)
 
-[…]
-            x=subed_tensor,
-            y=x_scale,
-        )
+    block_size = int(graph_node.attrs.get('block_size', 0))
+    x_scale = _expand_scale_or_zero_point(
+        value=x_scale,
+        input_tensor=input_tensor,
+        axis=axis,
+        block_size=block_size,
+    )
+
+    if input_is_dequantized:
+        tf_layers_dict[graph_node_output.name]['tf_node'] = input_tensor
+    else:
+        if x_zero_point is None or input_tensor_dtype == tf.int32:
+            x_zero_point = tf.zeros_like(x_scale)
+        else:
+            x_zero_point = tf.cast(x_zero_point, tf.float32)
+            x_zero_point = _expand_scale_or_zero_point(
+                value=x_zero_point,
+                input_tensor=input_tensor,
+                axis=axis,
+                block_size=block_size,
+            )
+
+        tf_layers_dict[graph_node_output.name]['tf_node'] = \
+            tf.multiply(
+                x=tf.subtract(input_tensor, x_zero_point),
+                y=x_scale,
+            )
 
     if hasattr(tf_layers_dict[graph_node_output.name]['tf_node'], 'numpy'):
         tf_layers_dict[graph_node_output.name]['tf_node'] = \
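The new _expand_scale_or_zero_point helper in DequantizeLinear covers the two non-scalar layouts: per-axis quantization, where a 1-D scale is reshaped so it broadcasts along the quantization axis, and blocked quantization, where each scale value is repeated block_size times along the axis and then sliced back to the input shape. A small numeric sketch of both cases with made-up shapes (the names below are illustrative):

import tensorflow as tf

x = tf.zeros([2, 6], dtype=tf.float32)        # tensor being dequantized, rank 2, axis = 1

# Per-axis: a length-6 scale becomes [1, 6] so it broadcasts against x.
per_axis_scale = tf.constant([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
per_axis = tf.reshape(per_axis_scale, [1, 6])

# Blocked (block_size = 3): a [2, 2] scale is repeated 3x along axis 1,
# then sliced to x's shape, giving one scale per 3-element block.
blocked_scale = tf.constant([[0.1, 0.2], [0.3, 0.4]])
blocked = tf.slice(tf.repeat(blocked_scale, repeats=3, axis=1), [0, 0], tf.shape(x))

print(per_axis.shape, blocked.shape)          # (1, 6) (2, 6)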
onnx2tf/ops/DynamicQuantizeLinear.py
CHANGED

@@ -43,6 +43,9 @@ def make_node(
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    input_nhwc = False
+    if isinstance(graph_node_input_1, gs.Variable):
+        input_nhwc = tf_layers_dict.get(graph_node_input_1.name, {}).get('nhwc', False)
     graph_node_output_1: gs.Variable = graph_node.outputs[0]
     o1_shape = graph_node_output_1.shape
     o1_dtype = graph_node_output_1.dtype

@@ -58,6 +61,8 @@ def make_node(
         'optype': graph_node.op,
         'shape': o1_shape,
         'dtype': o1_dtype,
+        'is_dequantized': True,
+        'nhwc': input_nhwc,
     }
     tf_layers_dict[graph_node_output_2.name] = {
         'optype': graph_node.op,

@@ -82,35 +87,31 @@ def make_node(
     )
 
     # Generation of TF OP
-[…]
-    max_x = tf.math.maximum(0., tf.math.reduce_max(input_tensor_1))
+    qmin = 0.0
+    qmax = 255.0
+    min_x = tf.math.minimum(0.0, tf.math.reduce_min(input_tensor_1))
+    max_x = tf.math.maximum(0.0, tf.math.reduce_max(input_tensor_1))
     y_scale = (max_x - min_x) / (qmax - qmin)
     intermediate_zero_point = qmin - (min_x / y_scale)
-[…]
-        x=intermediate_zero_point
-    ),
+    clipped_zero_point = tf.clip_by_value(
+        intermediate_zero_point,
         clip_value_min=qmin,
         clip_value_max=qmax,
     )
-[…]
-        ),
-        dtype=dtype,
+    y_zero_point = tf.round(clipped_zero_point)
+    y_quant = tf.clip_by_value(
+        tf.round(input_tensor_1 / y_scale) + y_zero_point,
+        clip_value_min=qmin,
+        clip_value_max=qmax,
     )
+    y = (y_quant - y_zero_point) * y_scale
 
     tf_layers_dict[graph_node_output_1.name]['tf_node'] = y
     tf_layers_dict[graph_node_output_2.name]['tf_node'] = y_scale
     tf_layers_dict[graph_node_output_3.name]['tf_node'] = \
         tf.cast(
             x=y_zero_point,
-            dtype=
+            dtype=tf.uint8,
         )
 
     # Post-process transpose