onnx2tf 1.23.3__py3-none-any.whl → 1.25.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from onnx2tf.onnx2tf import convert, main
2
2
 
3
- __version__ = '1.23.3'
3
+ __version__ = '1.25.8'
onnx2tf/onnx2tf.py CHANGED
@@ -25,6 +25,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
25
25
  os.environ["TF_USE_LEGACY_KERAS"] = '1'
26
26
  import tensorflow as tf
27
27
  tf.random.set_seed(0)
28
+ from tensorflow.python.saved_model.load import _WrapperFunction
28
29
  import tf_keras
29
30
  tf_keras.utils.set_random_seed(0)
30
31
  tf.config.experimental.enable_op_determinism()
@@ -81,6 +82,7 @@ def convert(
81
82
  keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]] = None,
82
83
  keep_nwc_or_nhwc_or_ndhwc_input_names: Optional[List[str]] = None,
83
84
  keep_shape_absolutely_input_names: Optional[List[str]] = None,
85
+ input_names_to_interrupt_model_conversion: Optional[List[str]] = None,
84
86
  output_names_to_interrupt_model_conversion: Optional[List[str]] = None,
85
87
  disable_group_convolution: Optional[bool] = False,
86
88
  enable_accumulation_type_float16: Optional[bool] = False,
@@ -285,6 +287,13 @@ def convert(
285
287
  e.g.\n
286
288
  keep_shape_absolutely_input_names=['input0','input1','input2']
287
289
 
290
+ input_names_to_interrupt_model_conversion: Optional[List[str]]
291
+ Input names that interrupt model conversion.\n
292
+ Interrupts model transformation at the specified input name\n
293
+ and inputs the model partitioned into subgraphs.\n\n
294
+ e.g.\n
295
+ input_names_to_interrupt_model_conversion=['input0','input1','input2']
296
+
288
297
  output_names_to_interrupt_model_conversion: Optional[List[str]]
289
298
  Output names that interrupt model conversion.\n
290
299
  Interrupts model transformation at the specified output name\n
@@ -686,7 +695,77 @@ def convert(
686
695
  metadata_props = onnx_graph.metadata_props
687
696
  graph = gs.import_onnx(onnx_graph)
688
697
 
689
- # List Output
698
+ # Cut the ONNX graph when an input name is specified that interrupts the conversion
699
+ if not input_names_to_interrupt_model_conversion:
700
+ input_names = [
701
+ graph_input.name for graph_input in graph.inputs
702
+ ]
703
+ else:
704
+ try:
705
+ from sne4onnx import extraction
706
+ except Exception as ex:
707
+ error(
708
+ f'If --input_names_to_interrupt_model_conversion is specified, ' +\
709
+ f'you must install sne4onnx. pip install sne4onnx'
710
+ )
711
+ sys.exit(1)
712
+ # Cut ONNX graph at specified input position
713
+ input_names = [
714
+ input_op_name \
715
+ for input_op_name in input_names_to_interrupt_model_conversion
716
+ ]
717
+ onnx_graph: onnx.ModelProto = \
718
+ extraction(
719
+ input_op_names=input_names,
720
+ output_op_names=[graph_output.name for graph_output in graph.outputs],
721
+ onnx_graph=onnx_graph,
722
+ )
723
+ # Re-import of onnx_graph
724
+ del graph
725
+ graph = gs.import_onnx(onnx_graph)
726
+
727
+ total_num_nodes = len(graph.nodes)
728
+ check_count = 0
729
+ idx = 0
730
+ while True:
731
+ # Delete unused nodes
732
+ if check_count >= total_num_nodes:
733
+ break
734
+ op_input_names: List[str] = [inp.name for inp in graph.nodes[idx].inputs]
735
+ remove_enable = not any(name in input_names for name in op_input_names)
736
+ if remove_enable:
737
+ try:
738
+ num_input = len(graph.nodes[idx].inputs)
739
+ enable_var_input = False
740
+ for sub_idx, graph_node_input in enumerate(graph.nodes[idx].inputs):
741
+ if isinstance(graph_node_input, gs.Variable):
742
+ enable_var_input = True
743
+ break
744
+ else:
745
+ pass
746
+ if enable_var_input:
747
+ name = graph.nodes[idx].i(sub_idx).name
748
+ if any([graph_node.name == name for graph_node in graph.nodes]):
749
+ idx += 1
750
+ else:
751
+ try:
752
+ del graph.nodes[idx]
753
+ except IndexError:
754
+ break
755
+ else:
756
+ idx += 1
757
+ except:
758
+ try:
759
+ del graph.nodes[idx]
760
+ except IndexError:
761
+ break
762
+ else:
763
+ idx += 1
764
+ check_count += 1
765
+ onnx_graph = gs.export_onnx(graph=graph, do_type_check=False, **meta_data)
766
+ if metadata_props is not None:
767
+ onnx_graph.metadata_props.extend(metadata_props)
768
+
690
769
  # Cut the ONNX graph when an output name is specified that interrupts the conversion
691
770
  if not output_names_to_interrupt_model_conversion:
692
771
  output_names = [
@@ -706,11 +785,12 @@ def convert(
706
785
  output_op_name \
707
786
  for output_op_name in output_names_to_interrupt_model_conversion
708
787
  ]
709
- onnx_graph: onnx.ModelProto = extraction(
710
- input_op_names=[graph_input.name for graph_input in graph.inputs],
711
- output_op_names=output_names,
712
- onnx_graph=onnx_graph,
713
- )
788
+ onnx_graph: onnx.ModelProto = \
789
+ extraction(
790
+ input_op_names=[graph_input.name for graph_input in graph.inputs],
791
+ output_op_names=output_names,
792
+ onnx_graph=onnx_graph,
793
+ )
714
794
  # Re-import of onnx_graph
715
795
  del graph
716
796
  graph = gs.import_onnx(onnx_graph)
@@ -772,6 +852,7 @@ def convert(
772
852
  output_name = re.sub('^/', 'wa/', output_name)
773
853
  new_output_names.append(output_name)
774
854
  output_names = new_output_names
855
+
775
856
  try:
776
857
  onnx_graph = gs.export_onnx(graph=graph, do_type_check=False, **meta_data)
777
858
  if metadata_props is not None:
@@ -1234,19 +1315,56 @@ def convert(
1234
1315
  # Switch to .pb
1235
1316
  info(Color.GREEN(f'Switch to the output of an optimized protocol buffer file (.pb).'))
1236
1317
  except (KeyError, AssertionError) as e:
1318
+ msg_list = [s for s in e.args if isinstance(s, str)]
1319
+ if len(msg_list) > 0:
1320
+ try:
1321
+ for s in msg_list:
1322
+ if 'Failed to add concrete function' in s \
1323
+ or "Tried to export a function which references an 'untracked' resource" in s:
1324
+ export_archive = tf_keras.export.ExportArchive()
1325
+ export_archive.add_endpoint(
1326
+ name=SIGNATURE_KEY,
1327
+ fn=lambda *inputs : model(inputs),
1328
+ input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
1329
+ )
1330
+ export_archive.write_out(output_folder_path)
1331
+ break
1332
+ except ValueError as e:
1333
+ msg_list = [s for s in e.args if isinstance(s, str)]
1334
+ if len(msg_list) > 0:
1335
+ for s in msg_list:
1336
+ if 'A root scope name has to match the following pattern' in s:
1337
+ error(
1338
+ f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
1339
+ )
1340
+ matches = re.findall(r"'([^']*)'", s)
1341
+ error(f'{matches[0]}')
1342
+ error(
1343
+ f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
1344
+ )
1345
+ sys.exit(1)
1346
+ else:
1347
+ error(e)
1348
+ import traceback
1349
+ error(traceback.format_exc(), prefix=False)
1350
+ else:
1351
+ error(e)
1352
+ import traceback
1353
+ error(traceback.format_exc(), prefix=False)
1354
+ except ValueError as e:
1237
1355
  msg_list = [s for s in e.args if isinstance(s, str)]
1238
1356
  if len(msg_list) > 0:
1239
1357
  for s in msg_list:
1240
- if 'Failed to add concrete function' in s \
1241
- or "Tried to export a function which references an 'untracked' resource" in s:
1242
- export_archive = tf_keras.export.ExportArchive()
1243
- export_archive.add_endpoint(
1244
- name=SIGNATURE_KEY,
1245
- fn=lambda *inputs : model(inputs),
1246
- input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
1358
+ if 'A root scope name has to match the following pattern' in s:
1359
+ error(
1360
+ f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
1247
1361
  )
1248
- export_archive.write_out(output_folder_path)
1249
- break
1362
+ matches = re.findall(r"'([^']*)'", s)
1363
+ error(f'{matches[0]}')
1364
+ error(
1365
+ f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
1366
+ )
1367
+ sys.exit(1)
1250
1368
  else:
1251
1369
  error(e)
1252
1370
  import traceback
@@ -1408,19 +1526,23 @@ def convert(
1408
1526
  )
1409
1527
 
1410
1528
  # Quantized TFLite
1411
- MEAN = np.asarray([[[[0.485, 0.456, 0.406]]]], dtype=np.float32)
1412
- STD = np.asarray([[[[0.229, 0.224, 0.225]]]], dtype=np.float32)
1413
1529
  if output_integer_quantized_tflite:
1414
1530
  # Get signatures/input keys
1415
- loaded_saved_model = tf.saved_model.load(
1416
- output_folder_path
1417
- ).signatures[SIGNATURE_KEY]
1418
- input_keys = list(loaded_saved_model.structured_input_signature[1].keys())
1419
- input_shapes = [v.shape for v in loaded_saved_model.structured_input_signature[1].values()]
1420
- input_dtypes = [v.dtype for v in loaded_saved_model.structured_input_signature[1].values()]
1421
- output_keys = list(loaded_saved_model.structured_outputs.keys())
1422
- output_shapes = [v.shape for v in loaded_saved_model.structured_outputs.values()]
1423
- output_dtypes = [v.dtype for v in loaded_saved_model.structured_outputs.values()]
1531
+ trackable_obj = \
1532
+ tf.saved_model.load(
1533
+ output_folder_path
1534
+ )
1535
+ loaded_saved_model: _WrapperFunction = trackable_obj.signatures[SIGNATURE_KEY]
1536
+ structured_input_signature: Dict[str, tf.TensorSpec] = loaded_saved_model.structured_input_signature[1]
1537
+ structured_outputs: Dict[str, tf.TensorSpec] = loaded_saved_model.structured_outputs
1538
+
1539
+ input_keys: List[str] = list(structured_input_signature.keys())
1540
+ input_shapes: List[tf.TensorShape] = [v.shape for v in structured_input_signature.values()]
1541
+ input_dtypes: List[tf.dtypes.DType] = [v.dtype for v in structured_input_signature.values()]
1542
+
1543
+ output_keys: List[str] = list(structured_outputs.keys())
1544
+ output_shapes: List[tf.TensorShape] = [v.shape for v in structured_outputs.values()]
1545
+ output_dtypes: List[tf.dtypes.DType] = [v.dtype for v in structured_outputs.values()]
1424
1546
 
1425
1547
  print('')
1426
1548
  info(Color.BLUE(f'Signature information for quantization'))
@@ -1463,15 +1585,30 @@ def convert(
1463
1585
  for model_input in model.inputs:
1464
1586
  if model_input.dtype != tf.float32 \
1465
1587
  or len(model_input.shape) != 4 \
1466
- or model_input.shape[-1] != 3:
1588
+ or model_input.shape[-1] not in [3, 4]:
1467
1589
  error(
1468
1590
  f'For models that have multiple input OPs and need to perform INT8 quantization calibration '+
1469
- f'using non-rgb-image input tensors, specify the calibration data with '+
1591
+ f'using non-rgb-image/non-rgba-image input tensors, specify the calibration data with '+
1470
1592
  f'--quant_calib_input_op_name_np_data_path. '+
1471
1593
  f'model_input[n].shape: {model_input.shape}'
1472
1594
  )
1473
1595
  sys.exit(1)
1474
1596
 
1597
+ if model_input.shape[-1] == 3:
1598
+ # RGB
1599
+ mean = np.asarray([[[[0.485, 0.456, 0.406]]]], dtype=np.float32)
1600
+ std = np.asarray([[[[0.229, 0.224, 0.225]]]], dtype=np.float32)
1601
+ elif model_input.shape[-1] == 4:
1602
+ # RGBA
1603
+ mean = np.asarray([[[[0.485, 0.456, 0.406, 0.000]]]], dtype=np.float32)
1604
+ std = np.asarray([[[[0.229, 0.224, 0.225, 1.000]]]], dtype=np.float32)
1605
+ new_element_array = np.full((*calib_data.shape[:-1], 1), 0.500, dtype=np.float32)
1606
+ calib_data = np.concatenate((calib_data, new_element_array), axis=-1)
1607
+ else:
1608
+ # Others
1609
+ mean = np.asarray([[[[0.485, 0.456, 0.406]]]], dtype=np.float32)
1610
+ std = np.asarray([[[[0.229, 0.224, 0.225]]]], dtype=np.float32)
1611
+
1475
1612
  calib_data_dict[model_input.name] = \
1476
1613
  [
1477
1614
  tf.image.resize(
@@ -1481,8 +1618,8 @@ def convert(
1481
1618
  model_input.shape[2] if model_input.shape[2] is not None else 640,
1482
1619
  )
1483
1620
  ),
1484
- MEAN,
1485
- STD,
1621
+ mean,
1622
+ std,
1486
1623
  ]
1487
1624
  elif custom_input_op_name_np_data_path is not None:
1488
1625
  for param in custom_input_op_name_np_data_path:
@@ -1846,6 +1983,7 @@ def main():
1846
1983
  '-osd',
1847
1984
  '--output_signaturedefs',
1848
1985
  action='store_true',
1986
+ default=True,
1849
1987
  help=\
1850
1988
  'Signature is added to the output for serving or for conversion \n' +
1851
1989
  'to other model formats. However, this can significantly reduce the speed \n' +
@@ -2101,6 +2239,18 @@ def main():
2101
2239
  'e.g. \n' +
2102
2240
  '--keep_shape_absolutely_input_names "input0" "input1" "input2"'
2103
2241
  )
2242
+ parser.add_argument(
2243
+ '-inimc',
2244
+ '--input_names_to_interrupt_model_conversion',
2245
+ type=str,
2246
+ nargs='+',
2247
+ help=\
2248
+ 'Input names that interrupt model conversion. \n' +
2249
+ 'Interrupts model transformation at the specified input name \n' +
2250
+ 'and inputs the model partitioned into subgraphs. \n\n' +
2251
+ 'e.g. \n' +
2252
+ '--input_names_to_interrupt_model_conversion "input0" "input1" "input2"'
2253
+ )
2104
2254
  parser.add_argument(
2105
2255
  '-onimc',
2106
2256
  '--output_names_to_interrupt_model_conversion',
@@ -2444,6 +2594,7 @@ def main():
2444
2594
  keep_ncw_or_nchw_or_ncdhw_input_names=args.keep_ncw_or_nchw_or_ncdhw_input_names,
2445
2595
  keep_nwc_or_nhwc_or_ndhwc_input_names=args.keep_nwc_or_nhwc_or_ndhwc_input_names,
2446
2596
  keep_shape_absolutely_input_names=args.keep_shape_absolutely_input_names,
2597
+ input_names_to_interrupt_model_conversion=args.input_names_to_interrupt_model_conversion,
2447
2598
  output_names_to_interrupt_model_conversion=args.output_names_to_interrupt_model_conversion,
2448
2599
  disable_group_convolution=args.disable_group_convolution,
2449
2600
  enable_accumulation_type_float16=args.enable_accumulation_type_float16,
onnx2tf/ops/Add.py CHANGED
@@ -168,8 +168,37 @@ def make_node(
168
168
  is_scalar_2_rank = tf.rank(input_tensor_2) == 0
169
169
  if hasattr(is_scalar_2_rank, 'numpy'):
170
170
  is_scalar_2 = is_scalar_2_rank.numpy()
171
+
171
172
  if (is_scalar_1 or is_scalar_2) and graph_node.i().op == 'Gemm':
172
173
  pass
174
+ elif (is_scalar_1 or is_scalar_2) and graph_node.i().op != 'Gemm':
175
+ first_tensor = None
176
+ second_tensor = None
177
+ if is_scalar_1:
178
+ first_tensor = input_tensor_2
179
+ second_tensor = input_tensor_1
180
+ elif is_scalar_2:
181
+ first_tensor = input_tensor_1
182
+ second_tensor = input_tensor_2
183
+ tmp_result = tf.math.add(first_tensor, second_tensor)
184
+ tmp_result_shape = tmp_result.shape
185
+ if first_tensor.shape == tmp_result_shape:
186
+ pass
187
+ else:
188
+ input_tensor_1, input_tensor_2 = \
189
+ pre_explicit_broadcast(
190
+ input_tensor_1=input_tensor_1,
191
+ input_tensor_2=input_tensor_2,
192
+ )
193
+
194
+ input_tensor_1, input_tensor_2 = \
195
+ explicit_broadcast(
196
+ const_or_var_1=input_tensor_1,
197
+ const_or_var_2=input_tensor_2,
198
+ graph_node=graph_node,
199
+ tf_layers_dict= tf_layers_dict,
200
+ )
201
+
173
202
  else:
174
203
  input_tensor_1, input_tensor_2 = \
175
204
  pre_explicit_broadcast(
@@ -268,11 +268,14 @@ def make_node(
268
268
  [list(i) for i in zip(tf_pads[:len(tf_pads) // 2], tf_pads[len(tf_pads) // 2:])] + \
269
269
  [[0, 0]]
270
270
 
271
- padded_tensor = tf.pad(
272
- tensor=input_tensor,
273
- paddings=tf_pads,
274
- mode='CONSTANT',
275
- )
271
+ if spatial_size == 1 and kernel_shape[0] > input_tensor_shape[1]:
272
+ padded_tensor = input_tensor
273
+ else:
274
+ padded_tensor = tf.pad(
275
+ tensor=input_tensor,
276
+ paddings=tf_pads,
277
+ mode='CONSTANT',
278
+ )
276
279
 
277
280
  else:
278
281
  padded_tensor = input_tensor
@@ -306,11 +309,18 @@ def make_node(
306
309
  # Generation of TF OP
307
310
  tf_op_type = None
308
311
  if len(kernel_shape) == 1:
309
- pooled_tensor = AveragePooling1D(
310
- pool_size=kernel_shape,
311
- strides=strides,
312
- padding=tf_pad_mode.upper(),
313
- )(padded_tensor)
312
+ if kernel_shape[0] > padded_tensor.shape[1]:
313
+ pooled_tensor = AveragePooling1D(
314
+ pool_size=[padded_tensor.shape[1]],
315
+ strides=[padded_tensor.shape[1]],
316
+ padding=tf_pad_mode.upper(),
317
+ )(padded_tensor)
318
+ else:
319
+ pooled_tensor = AveragePooling1D(
320
+ pool_size=kernel_shape,
321
+ strides=strides,
322
+ padding=tf_pad_mode.upper(),
323
+ )(padded_tensor)
314
324
  tf_op_type = AveragePooling1D
315
325
 
316
326
  elif len(kernel_shape) == 2: