onnx2tf 1.29.15__py3-none-any.whl → 1.29.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,6 +10,10 @@ from onnx2tf.utils.common_functions import (
10
10
  print_node_info,
11
11
  inverted_operation_enable_disable,
12
12
  make_tf_node_info,
13
+ pre_process_transpose,
14
+ replace_parameter,
15
+ shape_is_equal_ignore_order,
16
+ transpose_with_flexing_deterrence,
13
17
  )
14
18
 
15
19
 
@@ -35,8 +39,7 @@ def make_node(
35
39
  y_zero_point_list = [i for i in graph_node.inputs[1::3]]
36
40
  input_list = [i for i in graph_node.inputs[2::3]]
37
41
 
38
- input_tensor_shape = input_list[0].shape
39
- input_tensor_rank = len(input_tensor_shape)
42
+ input_tensor_rank = len(input_list[0].shape)
40
43
 
41
44
  before_op_output_shape_trans = True
42
45
  for graph_node_input in input_list:
@@ -46,6 +49,9 @@ def make_node(
46
49
  before_op_output_shape_trans and before_op_output_shape_trans_n
47
50
 
48
51
  got_values = []
52
+ nhwc_flags = []
53
+ same_input_shape_as_onnxs = []
54
+ input_is_dequantized_list = []
49
55
  got_y_scale_list = []
50
56
  got_y_zero_point_list = []
51
57
  for input, y_scale, y_zero_point in zip(input_list, y_scale_list, y_zero_point_list):
@@ -55,8 +61,24 @@ def make_node(
55
61
  )
56
62
  if isinstance(const_or_var, gs.Variable):
57
63
  got_values.append(tf_layers_dict[const_or_var.name]['tf_node'])
64
+ nhwc_flags.append(
65
+ tf_layers_dict[const_or_var.name].get('nhwc', False)
66
+ )
67
+ same_input_shape_as_onnxs.append(
68
+ True if input.shape is not None and len(input.shape) > 0 \
69
+ and input.shape == tf_layers_dict[const_or_var.name]['tf_node'].shape else False
70
+ )
71
+ input_is_dequantized_list.append(
72
+ tf_layers_dict[const_or_var.name].get('is_dequantized', False)
73
+ )
58
74
  else:
59
75
  got_values.append(const_or_var)
76
+ nhwc_flags.append(False)
77
+ same_input_shape_as_onnxs.append(
78
+ True if input.shape is not None and len(input.shape) > 0 \
79
+ and input.shape == const_or_var.shape else False
80
+ )
81
+ input_is_dequantized_list.append(False)
60
82
 
61
83
  const_or_var = get_constant_or_variable(
62
84
  y_scale,
@@ -82,50 +104,247 @@ def make_node(
82
104
  dtype = graph_node_output.dtype
83
105
 
84
106
  axis = graph_node.attrs.get('axis', 0)
85
- # NCHW->NHWC, NCDHW->NDHWC
86
- axis = convert_axis(
87
- axis=axis,
88
- tensor_rank=len(shape) if shape is not None else input_tensor_rank,
89
- before_op_output_shape_trans=before_op_output_shape_trans,
90
- )
107
+
108
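+ # Workaround for mixed input layouts: when some inputs still match their ONNX shape while others are flagged NHWC, transpose the ONNX-shaped (non-NHWC) inputs of rank 3/4/5 to channel-last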
109
+ if True in same_input_shape_as_onnxs and True in nhwc_flags:
110
+ before_op_output_shape_trans = True
111
+ new_values = []
112
+ for same_input_shape_as_onnx, nhwc_flag, value in zip(same_input_shape_as_onnxs, nhwc_flags, got_values):
113
+ if same_input_shape_as_onnx and not nhwc_flag:
114
+ if len(value.shape) == 3:
115
+ new_values.append(
116
+ transpose_with_flexing_deterrence(
117
+ input_tensor=value,
118
+ perm=[0, 2, 1],
119
+ **kwargs,
120
+ )
121
+ )
122
+ elif len(value.shape) == 4:
123
+ new_values.append(
124
+ transpose_with_flexing_deterrence(
125
+ input_tensor=value,
126
+ perm=[0, 2, 3, 1],
127
+ **kwargs,
128
+ )
129
+ )
130
+ elif len(value.shape) == 5:
131
+ new_values.append(
132
+ transpose_with_flexing_deterrence(
133
+ input_tensor=value,
134
+ perm=[0, 2, 3, 4, 1],
135
+ **kwargs,
136
+ )
137
+ )
138
+ else:
139
+ new_values.append(value)
140
+ else:
141
+ new_values.append(value)
142
+ got_values = new_values
91
143
 
92
144
  # Preserving Graph Structure (Dict)
145
+ nhwc_judge = True
146
+ for graph_node_input in input_list:
147
+ if isinstance(graph_node_input, gs.Variable) \
148
+ and tf_layers_dict.get(graph_node_input.name, {}).get('nhwc', False):
149
+ nhwc_judge = nhwc_judge and True
150
+ elif isinstance(graph_node_input, gs.Constant) \
151
+ and hasattr(graph_node_input, 'values') \
152
+ and isinstance(graph_node_input.values, np.ndarray):
153
+ nhwc_judge = nhwc_judge or False
154
+ else:
155
+ nhwc_judge = nhwc_judge and False
156
+
93
157
  tf_layers_dict[graph_node_output.name] = {
94
158
  'optype': graph_node.op,
95
159
  'shape': shape,
96
160
  'dtype': dtype,
161
+ 'is_dequantized': True,
97
162
  }
163
+ if nhwc_judge:
164
+ tf_layers_dict[graph_node_output.name]['nhwc'] = True
98
165
 
99
166
  # Generation of TF OP
100
167
 
168
+ # NCHW->NHWC, NCDHW->NDHWC
169
+ axis = convert_axis(
170
+ axis=axis,
171
+ tensor_rank=len(shape) if shape is not None else input_tensor_rank,
172
+ before_op_output_shape_trans=before_op_output_shape_trans,
173
+ )
174
+
175
+ # Param replacement
176
+ before_axis = axis
177
+ axis = replace_parameter(
178
+ value_before_replacement=axis,
179
+ param_target='attributes',
180
+ param_name='axis',
181
+ **kwargs,
182
+ )
183
+
101
184
  # TensorFlow does not support Concat for scalar values, so convert to tensor
102
- values = [
103
- value if len(value.shape) > 0 else tf.reshape(value, [1]) for value in got_values
104
- ]
185
+ values = []
186
+ for graph_node_input, value in zip(input_list, got_values):
187
+ value = pre_process_transpose(
188
+ value_before_transpose=value,
189
+ param_target='inputs',
190
+ param_name=graph_node_input.name,
191
+ **kwargs,
192
+ )
193
+ values.append(value if len(value.shape) > 0 else tf.reshape(value, [1]))
194
+
195
+ def _infer_concat_axis(values, output_shape):
196
+ if not values:
197
+ return None
198
+ ranks = []
199
+ shapes = []
200
+ for val in values:
201
+ if val.shape is None or val.shape == tf.TensorShape(None):
202
+ return None
203
+ shape_list = list(val.shape)
204
+ ranks.append(len(shape_list))
205
+ shapes.append(shape_list)
206
+ if len(set(ranks)) != 1:
207
+ return None
208
+ rank = ranks[0]
209
+ candidates = []
210
+ for ax in range(rank):
211
+ ok = True
212
+ for dim in range(rank):
213
+ if dim == ax:
214
+ continue
215
+ base = shapes[0][dim]
216
+ for s in shapes[1:]:
217
+ if base is None or s[dim] is None:
218
+ continue
219
+ if base != s[dim]:
220
+ ok = False
221
+ break
222
+ if not ok:
223
+ break
224
+ if not ok:
225
+ continue
226
+ if output_shape is not None and len(output_shape) == rank:
227
+ out_dim = output_shape[ax]
228
+ if out_dim is not None:
229
+ sum_dim = 0
230
+ for s in shapes:
231
+ if s[ax] is None:
232
+ sum_dim = None
233
+ break
234
+ sum_dim += s[ax]
235
+ if sum_dim is None or sum_dim != out_dim:
236
+ continue
237
+ candidates.append(ax)
238
+ if len(candidates) == 1:
239
+ return candidates[0]
240
+ return None
241
+
242
+ inferred_axis = _infer_concat_axis(values, shape if shape is not None else None)
243
+ if inferred_axis is not None:
244
+ axis = inferred_axis
105
245
  # cast all inputs to float32
106
246
  casted_x_list = []
107
247
  casted_y_zero_point_list = []
108
- for x, y_zero_point in zip(values, got_y_zero_point_list):
248
+ casted_y_scale_list = []
249
+ for x, y_scale, y_zero_point in zip(values, got_y_scale_list, got_y_zero_point_list):
109
250
  casted_x_list.append(tf.cast(x, tf.float32))
251
+ casted_y_scale_list.append(tf.cast(y_scale, tf.float32))
110
252
  casted_y_zero_point_list.append(tf.cast(y_zero_point, tf.float32))
111
253
  # dequantize x with y_scale, y_zero_point
112
254
  dequantized_x_list = []
113
- for x, y_scale, y_zero_point in zip(casted_x_list, got_y_scale_list, casted_y_zero_point_list):
114
- dequantized_value = tf.add(
115
- x=tf.divide(
116
- x=x,
255
+ for x, y_scale, y_zero_point, is_dequantized in zip(
256
+ casted_x_list,
257
+ casted_y_scale_list,
258
+ casted_y_zero_point_list,
259
+ input_is_dequantized_list,
260
+ ):
261
+ if is_dequantized:
262
+ dequantized_x_list.append(x)
263
+ else:
264
+ dequantized_value = tf.multiply(
265
+ x=tf.subtract(x, y_zero_point),
117
266
  y=y_scale,
118
- ),
119
- y=y_zero_point,
120
- )
121
- dequantized_x_list.append(dequantized_value)
267
+ )
268
+ dequantized_x_list.append(dequantized_value)
122
269
 
123
- tf_layers_dict[graph_node_output.name]['tf_node'] = \
124
- tf.concat(
125
- values=dequantized_x_list,
126
- axis=axis,
127
- name=graph_node.name,
128
- )
270
+ try:
271
+ tf_layers_dict[graph_node_output.name]['tf_node'] = \
272
+ tf.concat(
273
+ values=dequantized_x_list,
274
+ axis=axis,
275
+ name=graph_node.name,
276
+ )
277
+ except:
278
+ try:
279
+ onnx_axis = int(graph_node.attrs.get('axis', 0))
280
+ tf_layers_dict[graph_node_output.name]['tf_node'] = \
281
+ tf.concat(
282
+ values=dequantized_x_list,
283
+ axis=onnx_axis,
284
+ name=graph_node.name,
285
+ )
286
+ axis = onnx_axis
287
+ except:
288
+ value_rank = len(dequantized_x_list[0].shape)
289
+ succeed = False
290
+ for idx in reversed(range(value_rank)):
291
+ try:
292
+ tf_layers_dict[graph_node_output.name]['tf_node'] = \
293
+ tf.concat(
294
+ values=dequantized_x_list,
295
+ axis=idx,
296
+ name=graph_node.name,
297
+ )
298
+ axis = idx
299
+ succeed = True
300
+ break
301
+ except:
302
+ pass
303
+ if not succeed:
304
+ raise
305
+
306
+ output_tensor_shape = tf_layers_dict[graph_node_output.name]['tf_node'].shape
307
+ if output_tensor_shape != tf.TensorShape(None):
308
+ output_tensor_rank = len(output_tensor_shape)
309
+ if graph_node.outputs[0].shape is not None \
310
+ and axis != 0 \
311
+ and output_tensor_rank >= 2 \
312
+ and before_axis == axis:
313
+ if not shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(output_tensor_shape)):
314
+ matched_axes = []
315
+ for dummy_axis in range(1, output_tensor_rank):
316
+ try:
317
+ dummy_concat_tensor = \
318
+ tf.concat(
319
+ values=dequantized_x_list,
320
+ axis=dummy_axis,
321
+ name=graph_node.name,
322
+ )
323
+ dummy_output_shape = dummy_concat_tensor.shape
324
+ if shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(dummy_output_shape)):
325
+ matched_axes.append(dummy_axis)
326
+ except:
327
+ pass
328
+ if len(matched_axes) == 1:
329
+ tf_layers_dict[graph_node_output.name]['tf_node'] = \
330
+ tf.concat(
331
+ values=dequantized_x_list,
332
+ axis=matched_axes[0],
333
+ name=graph_node.name,
334
+ )
335
+ axis = matched_axes[0]
336
+ elif not nhwc_judge:
337
+ onnx_axis = int(graph_node.attrs.get('axis', 0))
338
+ onnx_axis = output_tensor_rank - 1 if onnx_axis == -1 else onnx_axis
339
+ if onnx_axis == output_tensor_rank - 1 \
340
+ and onnx_axis in matched_axes:
341
+ tf_layers_dict[graph_node_output.name]['tf_node'] = \
342
+ tf.concat(
343
+ values=dequantized_x_list,
344
+ axis=onnx_axis,
345
+ name=graph_node.name,
346
+ )
347
+ axis = onnx_axis
129
348
 
130
349
  # Generation of Debug Info
131
350
  tf_inputs = {f"input{idx}": dequantized_x for idx, dequantized_x in enumerate(dequantized_x_list)}
@@ -1,4 +1,3 @@
1
- import sys
2
1
  import random
3
2
  random.seed(0)
4
3
  import numpy as np
@@ -25,45 +24,57 @@ def _dequantize_tensor(
25
24
  # Do computation in float32
26
25
  base = tf.cast(base, tf.float32)
27
26
  zero_point = tf.cast(zero_point, tf.float32)
27
+ scale = tf.cast(scale, tf.float32)
28
28
  return (base - zero_point) * scale
29
29
 
30
30
 
31
+ def _reshape_per_output_channel(
32
+ *,
33
+ value,
34
+ weights,
35
+ ):
36
+ value_rank = len(value.shape)
37
+ weights_rank = len(weights.shape)
38
+ if value_rank == 0:
39
+ return value
40
+ if value_rank == 1 and weights_rank is not None:
41
+ shape = [1] * weights_rank
42
+ shape[-1] = -1
43
+ return tf.reshape(value, shape)
44
+ return value
45
+
46
+
31
47
  def _dequantize_weights(
32
48
  *,
33
49
  base,
34
50
  zero_point,
35
51
  scale,
36
- is_bias=False,
37
- scale_is_scalar=False,
38
52
  ):
39
53
  # Do computation in float32
40
54
  casted_base = tf.cast(base, tf.float32)
41
55
  casted_zero_point = tf.cast(zero_point, tf.float32)
42
- spartial_shape_len = len(casted_base.shape) - 2
43
- casted_zero_point_shape = casted_zero_point.shape[0]
44
- if casted_zero_point_shape == base.shape[-2]:
45
- reshaped_zero_point = tf.reshape(
46
- tensor=casted_zero_point,
47
- shape=[1 for _ in range(spartial_shape_len)] + [casted_zero_point_shape, 1],
48
- )
49
- if scale_is_scalar:
50
- reshaped_scale = tf.reshape(
51
- tensor=scale,
52
- shape=[1 for _ in range(spartial_shape_len)] + [casted_zero_point_shape, 1],
53
- )
54
- tensor_list = [
55
- (casted_base[..., i:i+1] - reshaped_zero_point) * reshaped_scale
56
- for i in range(base.shape[-1])
57
- ]
58
- out_tensor = tf.concat(tensor_list, axis=-1)
59
- else:
60
- reshaped_scale = scale
61
- out_tensor = (casted_base - reshaped_zero_point) * reshaped_scale
62
- return tf.reshape(out_tensor, base.shape)
63
- else:
64
- reshaped_zero_point = casted_zero_point
65
- reshaped_scale = scale
66
- return (casted_base - reshaped_zero_point) * reshaped_scale
56
+ casted_scale = tf.cast(scale, tf.float32)
57
+ casted_zero_point = _reshape_per_output_channel(
58
+ value=casted_zero_point,
59
+ weights=casted_base,
60
+ )
61
+ casted_scale = _reshape_per_output_channel(
62
+ value=casted_scale,
63
+ weights=casted_base,
64
+ )
65
+ return (casted_base - casted_zero_point) * casted_scale
66
+
67
+
68
def _get_qmin_qmax(dtype: tf.dtypes.DType):
    """Return the representable ``(qmin, qmax)`` range of a quantized dtype.

    Args:
        dtype: Dtype to look up.

    Returns:
        A pair of floats for ``uint8``, ``int8``, ``uint16`` and ``int16``;
        ``(None, None)`` for every other dtype.
    """
    # Ordered table scanned with ``==`` so the comparison semantics stay
    # those of the dtype objects themselves rather than of hashing.
    integer_ranges = (
        (tf.uint8, (0.0, 255.0)),
        (tf.int8, (-128.0, 127.0)),
        (tf.uint16, (0.0, 65535.0)),
        (tf.int16, (-32768.0, 32767.0)),
    )
    for candidate, bounds in integer_ranges:
        if dtype == candidate:
            return bounds
    return None, None
67
78
 
68
79
 
69
80
  @print_node_info
@@ -139,6 +150,11 @@ def make_node(
139
150
 
140
151
  input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
141
152
  if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
153
+ input_is_dequantized = False
154
+ input_nhwc = False
155
+ if isinstance(graph_node_input_1, gs.Variable):
156
+ input_is_dequantized = tf_layers_dict.get(graph_node_input_1.name, {}).get('is_dequantized', False)
157
+ input_nhwc = tf_layers_dict.get(graph_node_input_1.name, {}).get('nhwc', False)
142
158
  input_tensor_scale = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
143
159
  if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
144
160
  input_tensor_zero_point = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
@@ -155,7 +171,7 @@ def make_node(
155
171
  if isinstance(graph_node_input_8, gs.Variable) else graph_node_input_8
156
172
  input_bias = tf_layers_dict[graph_node_input_9.name]['tf_node'] \
157
173
  if isinstance(graph_node_input_9, gs.Variable) else graph_node_input_9
158
- output_dtype = input_tensor.dtype if input_tensor.dtype not in [tf.int8, tf.uint8] else tf.float32
174
+ output_quant_dtype = y_zero_point.dtype
159
175
 
160
176
  input_tensor_shape = input_tensor.shape
161
177
  input_tensor_rank = len(input_tensor_shape)
@@ -172,48 +188,32 @@ def make_node(
172
188
  'optype': graph_node.op,
173
189
  'shape': output_tensor_shape,
174
190
  'dtype': dtype,
191
+ 'is_dequantized': True,
192
+ 'nhwc': input_nhwc,
175
193
  }
176
194
 
177
195
  # Generation of TF OP
178
196
 
179
- # Convert w_zero_point and w_scale to 1-D if scalar
180
- if len(input_weights_zero_point.shape) == 0:
181
- input_weights_zero_point = tf.fill([input_tensor.shape[-1]//group], input_weights_zero_point)
182
- elif len(input_weights_zero_point.shape) > 1:
183
- error(
184
- f'Unsupported zero point: {graph_node.name} {input_weights_zero_point}'
185
- )
186
- sys.exit(1)
187
-
188
- weights_scale_is_scalar = False
189
- if len(input_weights_scale.shape) == 0:
190
- weights_scale_is_scalar = True
191
- input_weights_scale = tf.fill([input_tensor.shape[-1]//group], input_weights_scale)
192
- elif len(input_weights_scale.shape) > 1:
193
- error(
194
- f'Unsupported scalet: {graph_node.name} {input_weights_scale}'
195
- )
196
- sys.exit(1)
197
-
198
197
  # Dequantize variables to float32
199
- input_tensor = _dequantize_tensor(
200
- base=input_tensor,
201
- zero_point=input_tensor_zero_point,
202
- scale=input_tensor_scale,
203
- )
198
+ if input_is_dequantized:
199
+ input_tensor = tf.cast(input_tensor, tf.float32)
200
+ else:
201
+ input_tensor = _dequantize_tensor(
202
+ base=input_tensor,
203
+ zero_point=input_tensor_zero_point,
204
+ scale=input_tensor_scale,
205
+ )
204
206
  input_weights = _dequantize_weights(
205
207
  base=input_weights,
206
208
  zero_point=input_weights_zero_point,
207
209
  scale=input_weights_scale,
208
- scale_is_scalar=weights_scale_is_scalar,
209
210
  )
210
- y_zero_point = tf.cast(y_zero_point, tf.float32)
211
211
 
212
212
  # if bias is defined save it here
213
213
  if input_bias is not None:
214
214
  input_bias = tf.cast(input_bias, tf.float32)
215
- input_bias_scale = input_tensor_scale * input_weights_scale
216
- input_bias = tf.round(input_bias / input_bias_scale)
215
+ input_bias_scale = tf.cast(input_tensor_scale, tf.float32) * tf.cast(input_weights_scale, tf.float32)
216
+ input_bias = input_bias * input_bias_scale
217
217
 
218
218
  """
219
219
  Conv1D
@@ -260,7 +260,7 @@ def make_node(
260
260
  depthwise = bool(group == input_tensor_shape[-1])
261
261
 
262
262
  if depthwise is True:
263
- depthwise_filter_shape = list(input_weights_shape[0:2]) + [-1, input_weights_shape[3] // group]
263
+ depthwise_filter_shape = list(input_weights_shape[0:2]) + [input_weights_shape[2], input_weights_shape[3] // group]
264
264
  input_weights = tf.reshape(input_weights, depthwise_filter_shape)
265
265
 
266
266
  # Conv
@@ -308,27 +308,23 @@ def make_node(
308
308
  )
309
309
  tf_op_type = tf.nn.depthwise_conv2d
310
310
 
311
- # Process output
312
- scaled_conv_node = tf.add(
313
- x=tf.round(
314
- tf.divide(
315
- x=conv_node,
316
- y=y_scale,
317
- ),
318
- ),
319
- y=y_zero_point,
320
- )
321
-
322
- # Add bias to the convolution
311
+ # Add bias to the convolution (float)
323
312
  if input_bias is not None:
324
- scaled_conv_node = tf.add(
325
- x=scaled_conv_node,
313
+ conv_node = tf.add(
314
+ x=conv_node,
326
315
  y=input_bias,
327
316
  )
328
317
 
329
- casted_conv_node = tf.cast(scaled_conv_node, output_dtype)
318
+ # quantize then dequantize to float32
319
+ y_scale = tf.cast(y_scale, tf.float32)
320
+ y_zero_point = tf.cast(y_zero_point, tf.float32)
321
+ quantized = tf.round(tf.divide(conv_node, y_scale)) + y_zero_point
322
+ qmin, qmax = _get_qmin_qmax(output_quant_dtype)
323
+ if qmin is not None and qmax is not None:
324
+ quantized = tf.clip_by_value(quantized, qmin, qmax)
325
+ dequantized = tf.multiply(tf.subtract(quantized, y_zero_point), y_scale)
330
326
 
331
- tf_layers_dict[graph_node_output.name]['tf_node'] = casted_conv_node
327
+ tf_layers_dict[graph_node_output.name]['tf_node'] = dequantized
332
328
 
333
329
  # Generation of Debug Info
334
330
  tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
@@ -349,4 +345,3 @@ def make_node(
349
345
  },
350
346
  }
351
347
  )
352
-