onnx2tf 1.29.15__py3-none-any.whl → 1.29.17__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,4 +1,3 @@
-import sys
 import random
 random.seed(0)
 import numpy as np
@@ -25,45 +24,57 @@ def _dequantize_tensor(
     # Do computation in float32
     base = tf.cast(base, tf.float32)
     zero_point = tf.cast(zero_point, tf.float32)
+    scale = tf.cast(scale, tf.float32)
     return (base - zero_point) * scale
 
 
+def _reshape_per_output_channel(
+    *,
+    value,
+    weights,
+):
+    value_rank = len(value.shape)
+    weights_rank = len(weights.shape)
+    if value_rank == 0:
+        return value
+    if value_rank == 1 and weights_rank is not None:
+        shape = [1] * weights_rank
+        shape[-1] = -1
+        return tf.reshape(value, shape)
+    return value
+
+
 def _dequantize_weights(
     *,
     base,
     zero_point,
     scale,
-    is_bias=False,
-    scale_is_scalar=False,
 ):
     # Do computation in float32
     casted_base = tf.cast(base, tf.float32)
     casted_zero_point = tf.cast(zero_point, tf.float32)
-    spartial_shape_len = len(casted_base.shape) - 2
-    casted_zero_point_shape = casted_zero_point.shape[0]
-    if casted_zero_point_shape == base.shape[-2]:
-        reshaped_zero_point = tf.reshape(
-            tensor=casted_zero_point,
-            shape=[1 for _ in range(spartial_shape_len)] + [casted_zero_point_shape, 1],
-        )
-        if scale_is_scalar:
-            reshaped_scale = tf.reshape(
-                tensor=scale,
-                shape=[1 for _ in range(spartial_shape_len)] + [casted_zero_point_shape, 1],
-            )
-            tensor_list = [
-                (casted_base[..., i:i+1] - reshaped_zero_point) * reshaped_scale
-                for i in range(base.shape[-1])
-            ]
-            out_tensor = tf.concat(tensor_list, axis=-1)
-        else:
-            reshaped_scale = scale
-            out_tensor = (casted_base - reshaped_zero_point) * reshaped_scale
-        return tf.reshape(out_tensor, base.shape)
-    else:
-        reshaped_zero_point = casted_zero_point
-        reshaped_scale = scale
-        return (casted_base - reshaped_zero_point) * reshaped_scale
+    casted_scale = tf.cast(scale, tf.float32)
+    casted_zero_point = _reshape_per_output_channel(
+        value=casted_zero_point,
+        weights=casted_base,
+    )
+    casted_scale = _reshape_per_output_channel(
+        value=casted_scale,
+        weights=casted_base,
+    )
+    return (casted_base - casted_zero_point) * casted_scale
+
+
+def _get_qmin_qmax(dtype: tf.dtypes.DType):
+    if dtype == tf.uint8:
+        return 0.0, 255.0
+    if dtype == tf.int8:
+        return -128.0, 127.0
+    if dtype == tf.uint16:
+        return 0.0, 65535.0
+    if dtype == tf.int16:
+        return -32768.0, 32767.0
+    return None, None
 
 
 @print_node_info
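The hunk above replaces the old ad-hoc shape juggling in `_dequantize_weights` with `_reshape_per_output_channel`, which moves a 1-D per-output-channel scale or zero point to shape `[1, …, 1, -1]` so it broadcasts against the last (output-channel) axis of the weight tensor, and adds `_get_qmin_qmax` for dtype saturation bounds. A minimal standalone sketch of the broadcasting idea (toy shapes and values are illustrative, not from the package):

```python
import tensorflow as tf

# Toy HWIO kernel: 3x3 window, 2 input channels, 4 output channels.
kernel_q = tf.random.uniform([3, 3, 2, 4], minval=0, maxval=255, dtype=tf.int32)
scale = tf.constant([0.1, 0.2, 0.3, 0.4], tf.float32)              # per output channel
zero_point = tf.constant([128.0, 128.0, 120.0, 130.0], tf.float32)  # per output channel

# Reshape the 1-D per-channel parameters to [1, 1, 1, -1] so they broadcast
# along the last (output-channel) axis, mirroring _reshape_per_output_channel.
shape = [1] * len(kernel_q.shape)
shape[-1] = -1
scale = tf.reshape(scale, shape)
zero_point = tf.reshape(zero_point, shape)

dequantized = (tf.cast(kernel_q, tf.float32) - zero_point) * scale
print(dequantized.shape)  # (3, 3, 2, 4)
```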
@@ -139,6 +150,11 @@ def make_node(
 
     input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
         if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+    input_is_dequantized = False
+    input_nhwc = False
+    if isinstance(graph_node_input_1, gs.Variable):
+        input_is_dequantized = tf_layers_dict.get(graph_node_input_1.name, {}).get('is_dequantized', False)
+        input_nhwc = tf_layers_dict.get(graph_node_input_1.name, {}).get('nhwc', False)
     input_tensor_scale = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
     input_tensor_zero_point = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
@@ -155,7 +171,7 @@ def make_node(
         if isinstance(graph_node_input_8, gs.Variable) else graph_node_input_8
     input_bias = tf_layers_dict[graph_node_input_9.name]['tf_node'] \
         if isinstance(graph_node_input_9, gs.Variable) else graph_node_input_9
-    output_dtype = input_tensor.dtype if input_tensor.dtype not in [tf.int8, tf.uint8] else tf.float32
+    output_quant_dtype = y_zero_point.dtype
 
     input_tensor_shape = input_tensor.shape
     input_tensor_rank = len(input_tensor_shape)
@@ -172,48 +188,32 @@ def make_node(
         'optype': graph_node.op,
         'shape': output_tensor_shape,
         'dtype': dtype,
+        'is_dequantized': True,
+        'nhwc': input_nhwc,
     }
 
     # Generation of TF OP
 
-    # Convert w_zero_point and w_scale to 1-D if scalar
-    if len(input_weights_zero_point.shape) == 0:
-        input_weights_zero_point = tf.fill([input_tensor.shape[-1]//group], input_weights_zero_point)
-    elif len(input_weights_zero_point.shape) > 1:
-        error(
-            f'Unsupported zero point: {graph_node.name} {input_weights_zero_point}'
-        )
-        sys.exit(1)
-
-    weights_scale_is_scalar = False
-    if len(input_weights_scale.shape) == 0:
-        weights_scale_is_scalar = True
-        input_weights_scale = tf.fill([input_tensor.shape[-1]//group], input_weights_scale)
-    elif len(input_weights_scale.shape) > 1:
-        error(
-            f'Unsupported scalet: {graph_node.name} {input_weights_scale}'
-        )
-        sys.exit(1)
-
     # Dequantize variables to float32
-    input_tensor = _dequantize_tensor(
-        base=input_tensor,
-        zero_point=input_tensor_zero_point,
-        scale=input_tensor_scale,
-    )
+    if input_is_dequantized:
+        input_tensor = tf.cast(input_tensor, tf.float32)
+    else:
+        input_tensor = _dequantize_tensor(
+            base=input_tensor,
+            zero_point=input_tensor_zero_point,
+            scale=input_tensor_scale,
+        )
     input_weights = _dequantize_weights(
         base=input_weights,
         zero_point=input_weights_zero_point,
         scale=input_weights_scale,
-        scale_is_scalar=weights_scale_is_scalar,
     )
-    y_zero_point = tf.cast(y_zero_point, tf.float32)
 
     # if bias is defined save it here
     if input_bias is not None:
         input_bias = tf.cast(input_bias, tf.float32)
-        input_bias_scale = input_tensor_scale * input_weights_scale
-        input_bias = tf.round(input_bias / input_bias_scale)
+        input_bias_scale = tf.cast(input_tensor_scale, tf.float32) * tf.cast(input_weights_scale, tf.float32)
+        input_bias = input_bias * input_bias_scale
 
     """
     Conv1D
@@ -260,7 +260,7 @@ def make_node(
     depthwise = bool(group == input_tensor_shape[-1])
 
     if depthwise is True:
-        depthwise_filter_shape = list(input_weights_shape[0:2]) + [-1, input_weights_shape[3] // group]
+        depthwise_filter_shape = list(input_weights_shape[0:2]) + [input_weights_shape[2], input_weights_shape[3] // group]
         input_weights = tf.reshape(input_weights, depthwise_filter_shape)
 
     # Conv
@@ -308,27 +308,23 @@ def make_node(
         )
         tf_op_type = tf.nn.depthwise_conv2d
 
-    # Process output
-    scaled_conv_node = tf.add(
-        x=tf.round(
-            tf.divide(
-                x=conv_node,
-                y=y_scale,
-            ),
-        ),
-        y=y_zero_point,
-    )
-
-    # Add bias to the convolution
+    # Add bias to the convolution (float)
     if input_bias is not None:
-        scaled_conv_node = tf.add(
-            x=scaled_conv_node,
+        conv_node = tf.add(
+            x=conv_node,
             y=input_bias,
         )
 
-    casted_conv_node = tf.cast(scaled_conv_node, output_dtype)
+    # quantize then dequantize to float32
+    y_scale = tf.cast(y_scale, tf.float32)
+    y_zero_point = tf.cast(y_zero_point, tf.float32)
+    quantized = tf.round(tf.divide(conv_node, y_scale)) + y_zero_point
+    qmin, qmax = _get_qmin_qmax(output_quant_dtype)
+    if qmin is not None and qmax is not None:
+        quantized = tf.clip_by_value(quantized, qmin, qmax)
+    dequantized = tf.multiply(tf.subtract(quantized, y_zero_point), y_scale)
 
-    tf_layers_dict[graph_node_output.name]['tf_node'] = casted_conv_node
+    tf_layers_dict[graph_node_output.name]['tf_node'] = dequantized
 
     # Generation of Debug Info
     tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
@@ -349,4 +345,3 @@ def make_node(
                 },
             }
         )
-
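Taken together, the hunks above (apparently `onnx2tf/ops/QLinearConv.py`) change the op's contract: instead of casting the result back to an integer dtype, the output is now quantized and immediately dequantized, so the graph stays in float32 while still modeling quantization error, and downstream ops can skip re-dequantizing via the new `is_dequantized` flag. A minimal sketch of that quantize-then-dequantize ("fake quant") pattern, with illustrative parameter values:

```python
import tensorflow as tf

def fake_quantize(x, scale, zero_point, qmin, qmax):
    # Quantize: scale down, round, shift by the zero point.
    q = tf.round(x / scale) + zero_point
    # Saturate to the integer range of the target dtype (e.g. 0..255 for uint8).
    q = tf.clip_by_value(q, qmin, qmax)
    # Dequantize straight back to float32.
    return (q - zero_point) * scale

x = tf.constant([-1.0, 0.05, 0.5, 20.0], tf.float32)
print(fake_quantize(x, scale=0.1, zero_point=128.0, qmin=0.0, qmax=255.0).numpy())
# [-1.   0.   0.5 12.7]  -- values outside the representable range saturate
```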
@@ -11,6 +11,47 @@ from onnx2tf.utils.common_functions import (
     make_tf_node_info,
 )
 
+def _get_qmin_qmax(dtype: tf.dtypes.DType):
+    if dtype == tf.uint8:
+        return 0.0, 255.0
+    if dtype == tf.int8:
+        return -128.0, 127.0
+    if dtype == tf.uint16:
+        return 0.0, 65535.0
+    if dtype == tf.int16:
+        return -32768.0, 32767.0
+    return None, None
+
+
+def _reshape_for_axis(
+    *,
+    value,
+    input_tensor,
+    axis: int,
+):
+    value_rank = len(value.shape)
+    input_rank = len(input_tensor.shape)
+    if value_rank == 1 and input_rank is not None:
+        shape = [1] * input_rank
+        shape[axis] = -1
+        return tf.reshape(value, shape)
+    return value
+
+
+def _reshape_for_output(
+    *,
+    value,
+    output_tensor,
+):
+    value_rank = len(value.shape)
+    output_rank = len(output_tensor.shape)
+    if value_rank == 1 and output_rank is not None and output_rank >= 2:
+        if output_tensor.shape[-2] == value.shape[0]:
+            shape = [1] * output_rank
+            shape[-2] = -1
+            return tf.reshape(value, shape)
+    return value
+
 
 @print_node_info
 @inverted_operation_enable_disable
@@ -76,12 +117,18 @@ def make_node(
 
     a = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
         if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+    a_is_dequantized = False
+    if isinstance(graph_node_input_1, gs.Variable):
+        a_is_dequantized = tf_layers_dict.get(graph_node_input_1.name, {}).get('is_dequantized', False)
     a_scale = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
     a_zero_point = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
         if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
     b = tf_layers_dict[graph_node_input_4.name]['tf_node'] \
         if isinstance(graph_node_input_4, gs.Variable) else graph_node_input_4
+    b_is_dequantized = False
+    if isinstance(graph_node_input_4, gs.Variable):
+        b_is_dequantized = tf_layers_dict.get(graph_node_input_4.name, {}).get('is_dequantized', False)
     b_scale = tf_layers_dict[graph_node_input_5.name]['tf_node'] \
         if isinstance(graph_node_input_5, gs.Variable) else graph_node_input_5
     b_zero_point = tf_layers_dict[graph_node_input_6.name]['tf_node'] \
@@ -90,50 +137,60 @@ def make_node(
         if isinstance(graph_node_input_7, gs.Variable) else graph_node_input_7
     y_zero_point = tf_layers_dict[graph_node_input_8.name]['tf_node'] \
         if isinstance(graph_node_input_8, gs.Variable) else graph_node_input_8
-    y_dtype = y_zero_point.dtype if y_zero_point.dtype not in [tf.int8, tf.uint8] else tf.float32
+    y_dtype = y_zero_point.dtype
 
     # Preserving Graph Structure (Dict)
     tf_layers_dict[graph_node_output.name] = {
         'optype': graph_node.op,
         'shape': shape,
         'dtype': dtype,
+        'is_dequantized': True,
     }
 
     # Generation of TF OP
 
-    # reshape 1-D a_scale, a_zero_point, y_scale and
-    # y_zero_point so it can broadcast in arithmetic
-    # operations later
-    a_scale_shape = a_scale.shape
-    if a_scale_shape and a_scale_shape[0] > 1:
-        a_scale = tf.reshape(a_scale, [a_scale_shape[0], 1])
-        a_zero_point = tf.reshape(a_zero_point, [a_scale_shape[0], 1])
-    y_scale_shape = y_scale.shape
-    if y_scale_shape and y_scale_shape[0] > 1:
-        y_scale = tf.reshape(y_scale, [y_scale_shape[0], 1])
-        y_zero_point = tf.reshape(y_zero_point, [y_scale_shape[0], 1])
+    # reshape a_scale and a_zero_point to broadcast on row axis (second last)
+    a_scale = _reshape_for_axis(value=a_scale, input_tensor=a, axis=-2)
+    a_zero_point = _reshape_for_axis(value=a_zero_point, input_tensor=a, axis=-2)
+    # reshape b_scale and b_zero_point to broadcast on column axis (last)
+    b_scale = _reshape_for_axis(value=b_scale, input_tensor=b, axis=-1)
+    b_zero_point = _reshape_for_axis(value=b_zero_point, input_tensor=b, axis=-1)
 
     # cast all inputs to float32
     a = tf.cast(a, tf.float32)
+    a_scale = tf.cast(a_scale, tf.float32)
     a_zero_point = tf.cast(a_zero_point, tf.float32)
     b = tf.cast(b, tf.float32)
+    b_scale = tf.cast(b_scale, tf.float32)
     b_zero_point = tf.cast(b_zero_point, tf.float32)
+    y_scale = tf.cast(y_scale, tf.float32)
     y_zero_point = tf.cast(y_zero_point, tf.float32)
 
     # dequantize a and b
-    dequantized_a = tf.subtract(a, a_zero_point)
-    dequantized_a = tf.multiply(dequantized_a, a_scale)
-    dequantized_b = tf.subtract(b, b_zero_point)
-    dequantized_b = tf.multiply(dequantized_b, b_scale)
+    if a_is_dequantized:
+        dequantized_a = tf.cast(a, tf.float32)
+    else:
+        dequantized_a = tf.multiply(tf.subtract(a, a_zero_point), a_scale)
+
+    if b_is_dequantized:
+        dequantized_b = tf.cast(b, tf.float32)
+    else:
+        dequantized_b = tf.multiply(tf.subtract(b, b_zero_point), b_scale)
 
     # matmul
     x = tf.matmul(dequantized_a, dequantized_b)
 
-    # quantize x
-    y = tf.divide(x, y_scale)
-    y = tf.round(y)
+    # broadcast output scale/zero_point if needed
+    y_scale = _reshape_for_output(value=y_scale, output_tensor=x)
+    y_zero_point = _reshape_for_output(value=y_zero_point, output_tensor=x)
+
+    # quantize then dequantize to float32
+    y = tf.round(tf.divide(x, y_scale))
     y = tf.add(y, y_zero_point)
-    y = tf.saturate_cast(y, y_dtype)
+    qmin, qmax = _get_qmin_qmax(y_dtype)
+    if qmin is not None and qmax is not None:
+        y = tf.clip_by_value(y, qmin, qmax)
+    y = tf.multiply(tf.subtract(y, y_zero_point), y_scale)
 
     tf_layers_dict[graph_node_output.name]['tf_node'] = y
 
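In this file (apparently `onnx2tf/ops/QLinearMatMul.py`), the old code could only reshape the A-side and output parameters to `[N, 1]`; the new `_reshape_for_axis` helper handles per-row parameters for A and per-column parameters for B symmetrically, and the output is fake-quantized instead of saturate-cast. A standalone sketch of why the two operands need different broadcast shapes (shapes and values are illustrative):

```python
import tensorflow as tf

dequantized_a = tf.random.uniform([2, 3])        # stands in for dequantized A [M, K]
dequantized_b = tf.random.uniform([3, 4])        # stands in for dequantized B [K, N]
a_scale = tf.constant([0.5, 2.0])                # one scale per row of A
b_scale = tf.constant([1.0, 2.0, 3.0, 4.0])      # one scale per column of B

# _reshape_for_axis puts the 1-D vector on the requested axis, 1 elsewhere:
a_scale = tf.reshape(a_scale, [-1, 1])           # axis=-2  ->  [M, 1]
b_scale = tf.reshape(b_scale, [1, -1])           # axis=-1  ->  [1, N]

# Both now broadcast cleanly against their operands before the matmul.
y = tf.matmul(dequantized_a * a_scale, dequantized_b * b_scale)
print(y.shape)  # (2, 4)
```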
@@ -11,6 +11,49 @@ from onnx2tf.utils.common_functions import (
     make_tf_node_info,
     convert_axis,
 )
+from onnx2tf.utils.enums import ONNX_DTYPES_TO_TF_DTYPES
+
+
+def _get_qmin_qmax(dtype: tf.dtypes.DType):
+    if dtype == tf.uint8:
+        return 0.0, 255.0
+    if dtype == tf.int8:
+        return -128.0, 127.0
+    if dtype == tf.uint16:
+        return 0.0, 65535.0
+    if dtype == tf.int16:
+        return -32768.0, 32767.0
+    return None, None
+
+
+def _expand_scale_or_zero_point(
+    *,
+    value,
+    input_tensor,
+    axis: int,
+    block_size: int,
+):
+    value_rank = len(value.shape)
+    input_rank = len(input_tensor.shape)
+
+    if value_rank == 0:
+        return value
+
+    if block_size > 0 and value_rank == input_rank:
+        if value.shape[axis] is None \
+            or input_tensor.shape[axis] is None \
+            or value.shape[axis] != input_tensor.shape[axis]:
+            expanded = tf.repeat(value, repeats=block_size, axis=axis)
+            expanded = tf.slice(expanded, [0] * input_rank, tf.shape(input_tensor))
+            return expanded
+        return value
+
+    if value_rank == 1 and input_rank is not None:
+        shape = [1] * input_rank
+        shape[axis] = -1
+        return tf.reshape(value, shape)
+
+    return value
 
 
 @print_node_info
@@ -60,12 +103,12 @@ def make_node(
 
     input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
         if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
-    input_tensor_shape = input_tensor.shape
-    input_tensor_rank = len(input_tensor_shape)
+    input_nhwc = False
+    if isinstance(graph_node_input_1, gs.Variable):
+        input_nhwc = tf_layers_dict.get(graph_node_input_1.name, {}).get('nhwc', False)
+    input_tensor_rank = len(input_tensor.shape)
     y_scale = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
-    y_scale_shape = y_scale.shape
-    y_scale_rank = len(y_scale_shape)
     y_zero_point = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
         if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
 
@@ -81,6 +124,8 @@ def make_node(
         'optype': graph_node.op,
         'shape': shape,
         'dtype': dtype,
+        'is_dequantized': True,
+        'nhwc': input_nhwc,
     }
 
     # Generation of TF OP
@@ -88,51 +133,79 @@ def make_node(
         x=input_tensor,
         dtype=tf.float32,
     )
-    x_shape = input_tensor_shape
-    x_rank = input_tensor_rank
-    y_scale_shape = y_scale_shape
-
-    # Reshape process is needed for per-axis quantization
-    # when scale is a 1-D tensor
-    if y_scale_rank == 1:
-        shape_broadcast = list(
-            [1 for _ in range(axis)] \
-            + [x_shape[axis]] \
-            + [1 for _ in range(axis + 1, x_rank)]
-        )
-        y_scale = tf.reshape(
-            tensor=y_scale,
-            shape=shape_broadcast,
-        )
-    y = tf.divide(
-        x=input_tensor,
-        y=y_scale,
+
+    # If QuantizeLinear is immediately followed by Cast -> DequantizeLinear
+    # or DequantizeLinear only, bypass fake-quant to avoid generating
+    # Mul/Round/Min/Relu/Mul chains in TF/TFLite.
+    bypass_fake_quant = False
+    if graph_node.outputs and len(graph_node.outputs) > 0:
+        consumers = graph_node.outputs[0].outputs
+        if consumers:
+            bypass_fake_quant = True
+            for consumer in consumers:
+                if consumer.op == 'DequantizeLinear':
+                    continue
+                if consumer.op == 'Cast':
+                    cast_outs = consumer.outputs[0].outputs if consumer.outputs else []
+                    if not cast_outs or any(grand.op != 'DequantizeLinear' for grand in cast_outs):
+                        bypass_fake_quant = False
+                        break
+                else:
+                    bypass_fake_quant = False
+                    break
+
+    if bypass_fake_quant:
+        tf_layers_dict[graph_node_output.name]['tf_node'] = input_tensor
+        tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
+            make_tf_node_info(
+                node_info={
+                    'tf_op_type': 'QuantizeLinear',
+                    'tf_inputs': {
+                        'x': input_tensor,
+                    },
+                    'tf_outputs': {
+                        'output': tf_layers_dict[graph_node_output.name]['tf_node'],
+                    },
+                }
+            )
+        return
+    y_scale = tf.cast(y_scale, tf.float32)
+
+    block_size = int(graph_node.attrs.get('block_size', 0))
+    y_scale = _expand_scale_or_zero_point(
+        value=y_scale,
+        input_tensor=input_tensor,
+        axis=axis,
+        block_size=block_size,
     )
-    y = tf.round(y)
 
-    if y_zero_point is not None:
-        y_dtype = y_zero_point.dtype if y_zero_point.dtype not in [tf.int8, tf.uint8] else tf.float32
-        y_zero_point = tf.cast(
-            x=y_zero_point,
-            dtype=tf.float32,
-        )
-        y_zero_point = tf.reshape(
-            tensor=y_zero_point,
-            shape=shape_broadcast,
-        ) if y_scale_rank == 1 else y_zero_point
-        y = tf.add(
-            x=y,
-            y=y_zero_point,
+    output_dtype_attr = int(graph_node.attrs.get('output_dtype', 0))
+    if y_zero_point is None:
+        output_dtype = ONNX_DTYPES_TO_TF_DTYPES.get(output_dtype_attr, tf.uint8) \
+            if output_dtype_attr != 0 else tf.uint8
+        y_zero_point = tf.zeros_like(y_scale)
+    else:
+        output_dtype = y_zero_point.dtype
+        y_zero_point = tf.cast(y_zero_point, tf.float32)
+    y_zero_point = _expand_scale_or_zero_point(
+        value=y_zero_point,
+        input_tensor=input_tensor,
+        axis=axis,
+        block_size=block_size,
     )
-    else: # y_zero_point default dtype = uint8
-        y_dtype = tf.uint8
 
-    # Generation of TF OP
+    y = tf.round(tf.divide(input_tensor, y_scale))
+    y = tf.add(y, y_zero_point)
+
+    qmin, qmax = _get_qmin_qmax(output_dtype)
+    if qmin is not None and qmax is not None:
+        y = tf.clip_by_value(y, qmin, qmax)
+
+    # dequantize to float32 output
     tf_layers_dict[graph_node_output.name]['tf_node'] = \
-        tf.saturate_cast(
-            value=y,
-            dtype=y_dtype,
-            name=graph_node.name,
+        tf.multiply(
+            x=tf.subtract(y, y_zero_point),
+            y=y_scale,
         )
 
     # Generation of Debug Info
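This file (apparently `onnx2tf/ops/QuantizeLinear.py`) gains two behaviors: a bypass that emits the input unchanged when the node only feeds `DequantizeLinear` (optionally through a `Cast`), and support for blocked quantization via the `block_size` attribute, where each scale entry covers `block_size` elements along `axis`. A standalone sketch of the repeat-and-trim expansion that `_expand_scale_or_zero_point` performs (shapes are illustrative):

```python
import tensorflow as tf

x = tf.random.uniform([2, 7])            # tensor to quantize
scale = tf.constant([[0.1, 0.2, 0.3],    # one scale per block of 3 along axis=1
                     [0.4, 0.5, 0.6]])
block_size, axis = 3, 1

# Repeat each scale entry block_size times along the blocked axis ([2, 9]),
# then trim to the input's runtime shape ([2, 7]) to handle a ragged final block.
expanded = tf.repeat(scale, repeats=block_size, axis=axis)
expanded = tf.slice(expanded, [0, 0], tf.shape(x))

q = tf.round(x / expanded)               # per-block quantization
print(q.shape)  # (2, 7)
```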
onnx2tf/ops/Split.py CHANGED
@@ -124,6 +124,32 @@ def make_node(
         **kwargs,
     )
 
+    def _infer_split_axis_runtime(input_tensor, sum_split, fallback_axis):
+        if sum_split is None:
+            return tf.cast(fallback_axis, tf.int32)
+        shape = tf.shape(input_tensor)
+        eq = tf.equal(shape, tf.cast(sum_split, tf.int32))
+        mask = tf.cast(eq, tf.int32)
+        count = tf.reduce_sum(mask)
+        axis_from = tf.argmax(mask, axis=0, output_type=tf.int32)
+        fallback_axis_tensor = tf.cast(fallback_axis, tf.int32)
+        is_single = tf.cast(tf.equal(count, 1), tf.int32)
+        return axis_from * is_single + fallback_axis_tensor * (1 - is_single)
+
+    axis_for_split = axis
+    sum_split = None
+    split_list = None
+    if isinstance(split, np.ndarray):
+        split_list = list(split)
+    elif isinstance(split, (list, tuple)):
+        split_list = list(split)
+    if split_list is not None and len(split_list) > 1:
+        if len(split_list) == sum([1 for dim in split_list if isinstance(dim, (np.int64, int))]):
+            sum_split = int(np.sum(split_list))
+            axis_dim = input_tensor_shape[axis] if axis < len(input_tensor_shape) else None
+            if axis_dim is None or (isinstance(axis_dim, int) and axis_dim != sum_split):
+                axis_for_split = _infer_split_axis_runtime(input_tensor, sum_split, axis)
+
 
     # Generation of TF OP
     splited_tensors = None
@@ -225,18 +251,17 @@ def make_node(
             num=None,
             name=graph_node.name,
         )
-    elif isinstance(split, np.ndarray) \
+    elif isinstance(split, (list, tuple, np.ndarray)) \
         and len(list(split)) > 1 \
-        and np.prod(split) != 1 \
-        and isinstance(input_tensor_shape[axis], int) \
-        and len(split) == sum([1 for dim in split if isinstance(dim, np.int64) or isinstance(dim, int)]) \
-        and len(split) != sum([1 for dim in split if split[0] == dim]) \
-        and np.sum(split) == input_tensor_shape[axis]:
+        and (np.prod(split) != 1 if isinstance(split, np.ndarray) else True) \
+        and len(list(split)) == sum([1 for dim in list(split) if isinstance(dim, (np.int64, int))]) \
+        and len(list(split)) != sum([1 for dim in list(split) if list(split)[0] == dim]) \
+        and (not isinstance(input_tensor_shape[axis], int) or np.sum(list(split)) == input_tensor_shape[axis]):
         # Suppression of FlexSplitV generation
         # SplitV -> Strided_Slice
         splited_tensors = []
         begin_stock = []
-        for split_idx, split_dim in enumerate(split):
+        for split_idx, split_dim in enumerate(list(split)):
             begin_ = []
             end_ = []
             begin_mask_ = 0
@@ -269,7 +294,7 @@ def make_node(
             tf.split(
                 value=input_tensor,
                 num_or_size_splits=split,
-                axis=axis,
+                axis=axis_for_split,
                 num=num_outputs,
                 name=graph_node.name,
             )
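The `Split.py` changes harden axis handling: when the static shape at `axis` does not match `sum(split)` (for example after an NCHW-to-NHWC transpose has shuffled the dimensions), `_infer_split_axis_runtime` scans the runtime shape for the single dimension equal to the split sum and uses it, falling back to the original axis when the match is absent or ambiguous. A toy re-implementation of that selection logic (names are illustrative):

```python
import tensorflow as tf

def infer_split_axis(x, sum_split, fallback_axis):
    shape = tf.shape(x)
    mask = tf.cast(tf.equal(shape, sum_split), tf.int32)
    count = tf.reduce_sum(mask)                        # how many dims match sum(split)
    candidate = tf.argmax(mask, output_type=tf.int32)  # first matching dim
    unique = tf.cast(tf.equal(count, 1), tf.int32)     # only trust a unique match
    return candidate * unique + fallback_axis * (1 - unique)

x = tf.zeros([1, 5, 9])
print(infer_split_axis(x, sum_split=9, fallback_axis=1).numpy())  # 2 (dim of size 9)
print(infer_split_axis(x, sum_split=3, fallback_axis=1).numpy())  # 1 (no match -> fallback)
```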
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx2tf
-Version: 1.29.15
+Version: 1.29.17
 Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf).
 Keywords: onnx,tensorflow,tflite,keras,deep-learning,machine-learning
 Author: Katsuya Hyodo
@@ -364,7 +364,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
 docker run --rm -it \
 -v `pwd`:/workdir \
 -w /workdir \
-ghcr.io/pinto0309/onnx2tf:1.29.15
+ghcr.io/pinto0309/onnx2tf:1.29.17
 
 or
 
@@ -372,7 +372,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
 docker run --rm -it \
 -v `pwd`:/workdir \
 -w /workdir \
-docker.io/pinto0309/onnx2tf:1.29.15
+docker.io/pinto0309/onnx2tf:1.29.17
 
 or
 