onnx2tf 1.29.14__py3-none-any.whl → 1.29.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from onnx2tf.onnx2tf import convert, main
 
-__version__ = '1.29.14'
+__version__ = '1.29.16'
onnx2tf/onnx2tf.py CHANGED
@@ -62,6 +62,146 @@ from onnx2tf.utils.enums import (
 from onnx2tf.utils.logging import *
 from sng4onnx import generate as op_name_auto_generate
 
+def fuse_expanded_qdq_to_qdq(
+    *,
+    graph: gs.Graph,
+):
+    def _get_const_value(tensor):
+        if isinstance(tensor, gs.Constant):
+            return tensor.values
+        if isinstance(tensor, gs.Variable) and len(tensor.inputs) == 1:
+            producer = tensor.inputs[0]
+            if producer.op == 'Constant' and 'value' in producer.attrs:
+                return producer.attrs['value'].values
+        return None
+
+    def _split_const_and_var(inputs):
+        if len(inputs) != 2:
+            return None, None
+        const_val = _get_const_value(inputs[0])
+        if const_val is not None:
+            return const_val, inputs[1]
+        const_val = _get_const_value(inputs[1])
+        if const_val is not None:
+            return const_val, inputs[0]
+        return None, None
+
+    nodes_to_remove = []
+    nodes_to_add = []
+
+    for round_node in list(graph.nodes):
+        if round_node.op != 'Round' or len(round_node.inputs) < 1:
+            continue
+
+        round_in = round_node.inputs[0]
+        if len(round_in.inputs) != 1:
+            continue
+        mul1_node = round_in.inputs[0]
+        if mul1_node.op != 'Mul':
+            continue
+        if len(mul1_node.outputs) != 1 or len(mul1_node.outputs[0].outputs) != 1:
+            continue
+
+        inv_scale, x = _split_const_and_var(mul1_node.inputs)
+        if inv_scale is None or x is None:
+            continue
+
+        relu_node = round_node.outputs[0].outputs[0] if round_node.outputs else None
+        if relu_node is None:
+            continue
+        if relu_node.op == 'Relu':
+            relu_out = relu_node.outputs[0]
+        elif relu_node.op in ['Max', 'Maximum']:
+            max_const, max_var = _split_const_and_var(relu_node.inputs)
+            if max_const is None or max_var != round_node.outputs[0]:
+                continue
+            if np.asarray(max_const).size != 1 or float(np.asarray(max_const).item()) != 0.0:
+                continue
+            relu_out = relu_node.outputs[0]
+        else:
+            continue
+
+        if len(relu_out.outputs) != 1:
+            continue
+        min_node = relu_out.outputs[0]
+        if min_node.op not in ['Min', 'Minimum']:
+            continue
+
+        qmax, min_var = _split_const_and_var(min_node.inputs)
+        if qmax is None or min_var != relu_out:
+            continue
+        if np.asarray(qmax).size != 1:
+            continue
+
+        if len(min_node.outputs) != 1 or len(min_node.outputs[0].outputs) != 1:
+            continue
+        mul2_node = min_node.outputs[0].outputs[0]
+        if mul2_node.op != 'Mul':
+            continue
+
+        scale, min_out = _split_const_and_var(mul2_node.inputs)
+        if scale is None or min_out != min_node.outputs[0]:
+            continue
+        if np.asarray(scale).size != 1:
+            continue
+
+        scale_val = float(np.asarray(scale).item())
+        inv_scale_val = float(np.asarray(inv_scale).item())
+        if scale_val == 0.0 or not np.isfinite(scale_val) or not np.isfinite(inv_scale_val):
+            continue
+        if not np.isclose(scale_val * inv_scale_val, 1.0, rtol=1e-3, atol=1e-6):
+            continue
+
+        if len(mul2_node.outputs) != 1:
+            continue
+        output_var = mul2_node.outputs[0]
+
+        # Require linear chain
+        chain_nodes = [mul1_node, round_node, relu_node, min_node, mul2_node]
+        if any(len(n.outputs) == 0 for n in chain_nodes):
+            continue
+        if len(round_node.outputs[0].outputs) != 1 or len(relu_out.outputs) != 1 or len(min_node.outputs[0].outputs) != 1:
+            continue
+
+        # Build QDQ
+        scale_const = gs.Constant(
+            name=f"{mul2_node.name}_scale",
+            values=np.asarray(scale_val, dtype=np.float32),
+        )
+        zero_const = gs.Constant(
+            name=f"{mul2_node.name}_zero_point",
+            values=np.asarray(0, dtype=np.uint8),
+        )
+        quant_out = gs.Variable(
+            name=f"{output_var.name}_quant",
+            dtype=np.uint8,
+            shape=output_var.shape,
+        )
+        q_node = gs.Node(
+            op="QuantizeLinear",
+            name=f"{mul2_node.name}_QuantizeLinear",
+            inputs=[x, scale_const, zero_const],
+            outputs=[quant_out],
+        )
+        dq_node = gs.Node(
+            op="DequantizeLinear",
+            name=f"{mul2_node.name}_DequantizeLinear",
+            inputs=[quant_out, scale_const, zero_const],
+            outputs=[output_var],
+        )
+        output_var.inputs = [dq_node]
+
+        nodes_to_add.extend([q_node, dq_node])
+        nodes_to_remove.extend(chain_nodes)
+
+    if nodes_to_add:
+        graph.nodes.extend(nodes_to_add)
+    if nodes_to_remove:
+        for n in nodes_to_remove:
+            if n in graph.nodes:
+                graph.nodes.remove(n)
+        graph.cleanup().toposort()
+
 def apply_nonzero_passthrough(
     *,
     graph: gs.Graph,
@@ -848,6 +988,7 @@ def convert(
     if hasattr(onnx_graph, 'metadata_props'):
         metadata_props = onnx_graph.metadata_props
     graph = gs.import_onnx(onnx_graph)
+    fuse_expanded_qdq_to_qdq(graph=graph)
 
     # Cut the ONNX graph when an input name is specified that interrupts the conversion
     if not input_names_to_interrupt_model_conversion:
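Note: the pattern matched by fuse_expanded_qdq_to_qdq above, Mul(1/s) -> Round -> Relu/Max(0) -> Min(qmax) -> Mul(s), is the algebraically expanded form of a uint8 QuantizeLinear/DequantizeLinear pair with zero point 0. A minimal NumPy sketch of that equivalence (the values of s, qmax, and x are illustrative, not taken from the package):

    import numpy as np

    s, qmax = 0.05, 255.0
    x = np.array([-1.0, 0.0, 0.6037, 20.0], dtype=np.float32)

    # Expanded form: Mul(1/s) -> Round -> Relu -> Min(qmax) -> Mul(s)
    expanded = np.minimum(np.maximum(np.round(x * (1.0 / s)), 0.0), qmax) * s

    # Fused form: uint8 QuantizeLinear followed by DequantizeLinear, zero_point = 0
    q = np.clip(np.round(x / s), 0.0, 255.0).astype(np.uint8)
    fused = q.astype(np.float32) * s

    assert np.allclose(expanded, fused)

This is why the pass requires the two Mul constants to be reciprocals (np.isclose(scale_val * inv_scale_val, 1.0)) and the chain to be strictly linear before replacing it with QDQ nodes.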
onnx2tf/ops/Concat.py CHANGED
@@ -234,6 +234,31 @@ def make_node(
             and len(value.shape) > 0 else tf.reshape(value, [1]) for value in values
     ]
 
+    def _infer_concat_axis_runtime(values, fallback_axis):
+        if not values:
+            return fallback_axis
+        shapes = [tf.shape(v) for v in values]
+        shapes = tf.stack(shapes)
+        equal_mask = tf.reduce_all(tf.equal(shapes, shapes[0]), axis=0)
+        diff_mask = tf.cast(tf.logical_not(equal_mask), tf.int32)
+        candidate_count = tf.reduce_sum(diff_mask)
+        axis_from_diff = tf.argmax(diff_mask, axis=0, output_type=tf.int32)
+        fallback_axis_tensor = tf.cast(fallback_axis, tf.int32)
+        is_single = tf.cast(tf.equal(candidate_count, 1), tf.int32)
+        return axis_from_diff * is_single + fallback_axis_tensor * (1 - is_single)
+
+    axis_is_dynamic = False
+    if len(values) > 0:
+        all_none = True
+        for value in values:
+            if value.shape is not None and value.shape != tf.TensorShape(None):
+                if not all([s is None for s in value.shape]):
+                    all_none = False
+                    break
+        if all_none:
+            axis_is_dynamic = True
+    axis_for_concat = _infer_concat_axis_runtime(values, axis) if axis_is_dynamic else axis
+
     # Generation of TF OP
     tf_type = None
     if simple_resize:
@@ -271,7 +296,7 @@ def make_node(
         tf_layers_dict[graph_node_output.name]['tf_node'] = \
             tf.concat(
                 values=values,
-                axis=axis,
+                axis=axis_for_concat,
                 name=graph_node.name,
             )
     except:
@@ -311,51 +336,52 @@ def make_node(
     # This workaround is useful when automatic axis correction is practically difficult,
     # such as when all tensors to be combined originate from Transpose or Reshape.
     # https://github.com/PINTO0309/onnx2tf/issues/473
-    output_tensor_shape = tf_layers_dict[graph_node_output.name]['tf_node'].shape
-    if output_tensor_shape != tf.TensorShape(None):
-        output_tensor_rank = len(output_tensor_shape)
-        if graph_node.outputs[0].shape is not None \
-            and axis != 0 \
-            and output_tensor_rank >= 2 \
-            and before_axis == axis:
-
-            # Search for valid Concat patterns
-            if not shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(output_tensor_shape)):
-                matched_axes = []
-                for dummy_axis in range(1, output_tensor_rank):
-                    try:
-                        dummy_concat_tensor = \
-                            tf.concat(
-                                values=values,
-                                axis=dummy_axis,
-                                name=graph_node.name,
-                            )
-                        dummy_output_shape = dummy_concat_tensor.shape
-                        if shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(dummy_output_shape)):
-                            matched_axes.append(dummy_axis)
-                    except:
-                        pass
-                # Review Concat axes only if there is one valid join pattern
-                if len(matched_axes) == 1:
-                    tf_layers_dict[graph_node_output.name]['tf_node'] = \
-                        tf.concat(
-                            values=values,
-                            axis=matched_axes[0],
-                            name=graph_node.name,
-                        )
-                    axis = matched_axes[0]
-                elif not nhwc_judge:
-                    onnx_axis = int(graph_node.attrs.get('axis', 0))
-                    onnx_axis = output_tensor_rank - 1 if onnx_axis == -1 else onnx_axis
-                    if onnx_axis == output_tensor_rank - 1 \
-                        and onnx_axis in matched_axes:
-                        tf_layers_dict[graph_node_output.name]['tf_node'] = \
-                            tf.concat(
-                                values=values,
-                                axis=onnx_axis,
-                                name=graph_node.name,
-                            )
-                        axis = onnx_axis
+    if not axis_is_dynamic:
+        output_tensor_shape = tf_layers_dict[graph_node_output.name]['tf_node'].shape
+        if output_tensor_shape != tf.TensorShape(None):
+            output_tensor_rank = len(output_tensor_shape)
+            if graph_node.outputs[0].shape is not None \
+                and axis != 0 \
+                and output_tensor_rank >= 2 \
+                and before_axis == axis:
+
+                # Search for valid Concat patterns
+                if not shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(output_tensor_shape)):
+                    matched_axes = []
+                    for dummy_axis in range(1, output_tensor_rank):
+                        try:
+                            dummy_concat_tensor = \
+                                tf.concat(
+                                    values=values,
+                                    axis=dummy_axis,
+                                    name=graph_node.name,
+                                )
+                            dummy_output_shape = dummy_concat_tensor.shape
+                            if shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(dummy_output_shape)):
+                                matched_axes.append(dummy_axis)
+                        except:
+                            pass
+                    # Review Concat axes only if there is one valid join pattern
+                    if len(matched_axes) == 1:
+                        tf_layers_dict[graph_node_output.name]['tf_node'] = \
+                            tf.concat(
+                                values=values,
+                                axis=matched_axes[0],
+                                name=graph_node.name,
+                            )
+                        axis = matched_axes[0]
+                    elif not nhwc_judge:
+                        onnx_axis = int(graph_node.attrs.get('axis', 0))
+                        onnx_axis = output_tensor_rank - 1 if onnx_axis == -1 else onnx_axis
+                        if onnx_axis == output_tensor_rank - 1 \
+                            and onnx_axis in matched_axes:
+                            tf_layers_dict[graph_node_output.name]['tf_node'] = \
+                                tf.concat(
+                                    values=values,
+                                    axis=onnx_axis,
+                                    name=graph_node.name,
+                                )
+                            axis = onnx_axis
 
     # Workaround for post-concat accuracy degradation issues
     # Process only in the case of a Concat of two tensors because the process is too redundant.
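Note: _infer_concat_axis_runtime, introduced in the first hunk above, only runs when every input shape is fully unknown at conversion time (axis_is_dynamic). It stacks the runtime shapes and picks the single dimension in which they disagree; if zero or several dimensions disagree, the statically derived axis is kept. A standalone eager-mode sketch (the example shapes are illustrative):

    import tensorflow as tf

    a = tf.zeros([2, 3, 4])
    b = tf.zeros([2, 5, 4])

    shapes = tf.stack([tf.shape(a), tf.shape(b)])                    # [[2 3 4] [2 5 4]]
    equal_mask = tf.reduce_all(tf.equal(shapes, shapes[0]), axis=0)  # [True False True]
    diff_mask = tf.cast(tf.logical_not(equal_mask), tf.int32)        # [0 1 0]
    candidate_count = tf.reduce_sum(diff_mask)                       # 1 -> unambiguous
    axis_from_diff = tf.argmax(diff_mask, axis=0, output_type=tf.int32)
    is_single = tf.cast(tf.equal(candidate_count, 1), tf.int32)
    axis = axis_from_diff * is_single + 0 * (1 - is_single)          # fallback axis 0 here
    print(int(axis))  # 1

The arithmetic select (multiplying by is_single) instead of tf.cond keeps the axis computation inside a single traced graph.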
onnx2tf/ops/DequantizeLinear.py CHANGED
@@ -15,6 +15,43 @@ from onnx2tf.utils.common_functions import (
     post_process_transpose,
 )
 
+def _expand_scale_or_zero_point(
+    *,
+    value,
+    input_tensor,
+    axis: int,
+    block_size: int,
+):
+    value_rank = len(value.shape)
+    input_rank = len(input_tensor.shape)
+
+    if value_rank == 0:
+        return value
+
+    if input_rank <= 0:
+        return value
+
+    if axis < 0 or axis >= input_rank:
+        axis = 0
+
+    # Blocked quantization: expand along axis then slice to input shape
+    if block_size > 0 and value_rank == input_rank:
+        if value.shape[axis] is None \
+            or input_tensor.shape[axis] is None \
+            or value.shape[axis] != input_tensor.shape[axis]:
+            expanded = tf.repeat(value, repeats=block_size, axis=axis)
+            expanded = tf.slice(expanded, [0] * input_rank, tf.shape(input_tensor))
+            return expanded
+        return value
+
+    # Per-axis quantization: reshape 1-D to broadcast
+    if value_rank == 1 and input_rank is not None:
+        shape = [1] * input_rank
+        shape[axis] = -1
+        return tf.reshape(value, shape)
+
+    return value
+
 
 @print_node_info
 @inverted_operation_enable_disable
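Note: _expand_scale_or_zero_point covers the blocked-quantization case (the block_size attribute of DequantizeLinear), where the scale has the same rank as the input but a smaller extent along axis. It repeats each scale value block_size times and trims the overshoot when the last block is partial. A NumPy sketch of the same repeat-then-trim idea (shapes are illustrative):

    import numpy as np

    x = np.arange(10, dtype=np.float32).reshape(1, 10)     # input, shape (1, 10)
    scale = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)  # blocked scale, shape (1, 3)
    axis, block_size = 1, 4                                # 3 blocks of 4 cover 10 elements

    expanded = np.repeat(scale, block_size, axis=axis)     # shape (1, 12)
    expanded = expanded[:, :x.shape[axis]]                 # trim to (1, 10), like tf.slice
    # -> [[0.1 0.1 0.1 0.1 0.2 0.2 0.2 0.2 0.3 0.3]]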
@@ -63,6 +100,11 @@ def make_node(
 
     input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
         if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+    input_is_dequantized = False
+    input_nhwc = False
+    if isinstance(graph_node_input_1, gs.Variable):
+        input_is_dequantized = tf_layers_dict.get(graph_node_input_1.name, {}).get('is_dequantized', False)
+        input_nhwc = tf_layers_dict.get(graph_node_input_1.name, {}).get('nhwc', False)
 
     # Pre-process transpose
     input_tensor = pre_process_transpose(
@@ -72,12 +114,10 @@ def make_node(
         **kwargs,
     )
 
-    input_tensor_shape = input_tensor.shape
-    input_tensor_rank = len(input_tensor_shape)
+    input_tensor_rank = len(input_tensor.shape)
+    input_tensor_dtype = input_tensor.dtype
     x_scale = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
-    x_scale_shape = x_scale.shape
-    x_scale_rank = len(x_scale_shape)
     x_zero_point = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
         if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
 
@@ -87,48 +127,50 @@ def make_node(
         tensor_rank=input_tensor_rank,
         before_op_output_shape_trans=before_op_output_shape_trans,
     )
+    if input_tensor_rank == 1:
+        axis = 0
 
     # Preserving Graph Structure (Dict)
     tf_layers_dict[graph_node_output.name] = {
         'optype': graph_node.op,
         'shape': shape,
         'dtype': dtype,
+        'is_dequantized': True,
+        'nhwc': input_nhwc,
     }
 
     # Generation of TF OP
 
     input_tensor = tf.cast(input_tensor, tf.float32)
+    x_scale = tf.cast(x_scale, tf.float32)
 
-    # Reshape process is needed for per-axis dequantization
-    # when scale is a 1-D tensor
-    if x_scale_rank == 1 and x_scale_shape[0] != 1:
-        shape_broadcast = list([1 for _ in range(axis)] + [input_tensor_shape[axis]] + [1 for _ in range(axis + 1, input_tensor_rank)])
-        x_scale = tf.reshape(
-            tensor=x_scale,
-            shape=shape_broadcast,
-        )
-    elif x_scale_rank == 1 and x_scale_shape[0] == 1:
-        shape_broadcast = [1 for i in range(input_tensor_rank)]
-
-    subed_tensor = input_tensor
-    if len(graph_node.inputs) >= 3 and input_tensor.dtype != tf.int32:
-        x_zero_point = tf.cast(
-            x=x_zero_point,
-            dtype=tf.float32,
-        )
-        x_zero_point = tf.reshape(
-            tensor=x_zero_point,
-            shape=shape_broadcast,
-        ) if x_scale_rank == 1 else x_zero_point
-        subed_tensor = tf.subtract(
-            x=input_tensor,
-            y=x_zero_point,
-        )
-    tf_layers_dict[graph_node_output.name]['tf_node'] = \
-        tf.multiply(
-            x=subed_tensor,
-            y=x_scale,
-        )
+    block_size = int(graph_node.attrs.get('block_size', 0))
+    x_scale = _expand_scale_or_zero_point(
+        value=x_scale,
+        input_tensor=input_tensor,
+        axis=axis,
+        block_size=block_size,
+    )
+
+    if input_is_dequantized:
+        tf_layers_dict[graph_node_output.name]['tf_node'] = input_tensor
+    else:
+        if x_zero_point is None or input_tensor_dtype == tf.int32:
+            x_zero_point = tf.zeros_like(x_scale)
+        else:
+            x_zero_point = tf.cast(x_zero_point, tf.float32)
+            x_zero_point = _expand_scale_or_zero_point(
+                value=x_zero_point,
+                input_tensor=input_tensor,
+                axis=axis,
+                block_size=block_size,
+            )
+
+        tf_layers_dict[graph_node_output.name]['tf_node'] = \
+            tf.multiply(
+                x=tf.subtract(input_tensor, x_zero_point),
+                y=x_scale,
+            )
 
     if hasattr(tf_layers_dict[graph_node_output.name]['tf_node'], 'numpy'):
         tf_layers_dict[graph_node_output.name]['tf_node'] = \
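Note: in the per-axis case the rewritten op relies on _expand_scale_or_zero_point reshaping the 1-D scale and zero point to [1, ..., -1, ..., 1] so that (input - zero_point) * scale broadcasts along axis. A NumPy sketch of that broadcast (the NCHW example with axis=1 is illustrative):

    import numpy as np

    x = np.random.randint(0, 256, size=(1, 3, 2, 2)).astype(np.float32)
    scale = np.array([0.1, 0.2, 0.4], dtype=np.float32)  # one scale per channel
    zero_point = np.full(3, 128.0, dtype=np.float32)     # one zero point per channel
    shape = [1, -1, 1, 1]                                # rank 4, axis = 1

    y = (x - zero_point.reshape(shape)) * scale.reshape(shape)

    # Per-channel reference loop
    ref = np.stack([(x[:, c] - 128.0) * s for c, s in enumerate(scale)], axis=1)
    assert np.allclose(y, ref)

The is_dequantized flag stored on the output lets a downstream DequantizeLinear skip this arithmetic entirely when its input is already a dequantized float tensor.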
onnx2tf/ops/DynamicQuantizeLinear.py CHANGED
@@ -43,6 +43,9 @@ def make_node(
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    input_nhwc = False
+    if isinstance(graph_node_input_1, gs.Variable):
+        input_nhwc = tf_layers_dict.get(graph_node_input_1.name, {}).get('nhwc', False)
     graph_node_output_1: gs.Variable = graph_node.outputs[0]
     o1_shape = graph_node_output_1.shape
     o1_dtype = graph_node_output_1.dtype
@@ -58,6 +61,8 @@ def make_node(
         'optype': graph_node.op,
         'shape': o1_shape,
         'dtype': o1_dtype,
+        'is_dequantized': True,
+        'nhwc': input_nhwc,
     }
     tf_layers_dict[graph_node_output_2.name] = {
         'optype': graph_node.op,
@@ -82,35 +87,31 @@ def make_node(
     )
 
     # Generation of TF OP
-    dtype = tf.uint8
-    qmin = dtype.min
-    qmax = dtype.max
-    min_x = tf.math.minimum(0., tf.math.reduce_min(input_tensor_1))
-    max_x = tf.math.maximum(0., tf.math.reduce_max(input_tensor_1))
+    qmin = 0.0
+    qmax = 255.0
+    min_x = tf.math.minimum(0.0, tf.math.reduce_min(input_tensor_1))
+    max_x = tf.math.maximum(0.0, tf.math.reduce_max(input_tensor_1))
     y_scale = (max_x - min_x) / (qmax - qmin)
     intermediate_zero_point = qmin - (min_x / y_scale)
-    y_zero_point = tf.clip_by_value(
-        tf.round(
-            x=intermediate_zero_point
-        ),
+    clipped_zero_point = tf.clip_by_value(
+        intermediate_zero_point,
         clip_value_min=qmin,
         clip_value_max=qmax,
     )
-    y = tf.cast(
-        tf.clip_by_value(
-            (tf.round(input_tensor_1 / y_scale) + y_zero_point),
-            clip_value_min=qmin,
-            clip_value_max=qmax,
-        ),
-        dtype=dtype,
+    y_zero_point = tf.round(clipped_zero_point)
+    y_quant = tf.clip_by_value(
+        tf.round(input_tensor_1 / y_scale) + y_zero_point,
+        clip_value_min=qmin,
+        clip_value_max=qmax,
     )
+    y = (y_quant - y_zero_point) * y_scale
 
     tf_layers_dict[graph_node_output_1.name]['tf_node'] = y
     tf_layers_dict[graph_node_output_2.name]['tf_node'] = y_scale
     tf_layers_dict[graph_node_output_3.name]['tf_node'] = \
         tf.cast(
             x=y_zero_point,
-            dtype=dtype,
+            dtype=tf.uint8,
         )
 
     # Post-process transpose
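Note: the change above stops casting the first output to uint8; y is now the fake-quantized float (y_quant - y_zero_point) * y_scale, which downstream float TF ops can consume directly, while only the zero-point output keeps the uint8 cast. A NumPy sketch of the new computation (the input values are illustrative):

    import numpy as np

    x = np.array([-0.3, 0.0, 0.7, 1.9], dtype=np.float32)
    qmin, qmax = 0.0, 255.0

    min_x = min(0.0, float(x.min()))
    max_x = max(0.0, float(x.max()))
    y_scale = (max_x - min_x) / (qmax - qmin)
    y_zero_point = np.round(np.clip(qmin - min_x / y_scale, qmin, qmax))

    y_quant = np.clip(np.round(x / y_scale) + y_zero_point, qmin, qmax)
    y = (y_quant - y_zero_point) * y_scale  # float32 output, no uint8 cast

The zero point is also now clipped before rounding rather than after, so a saturated intermediate zero point rounds exactly to qmin or qmax.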