onnx2tf 1.24.0__py3-none-any.whl → 1.25.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +181 -30
- onnx2tf/ops/Add.py +29 -0
- onnx2tf/ops/AveragePool.py +20 -10
- onnx2tf/ops/BatchNormalization.py +270 -24
- onnx2tf/ops/Concat.py +4 -4
- onnx2tf/ops/DepthToSpace.py +8 -0
- onnx2tf/ops/Div.py +30 -0
- onnx2tf/ops/Expand.py +207 -0
- onnx2tf/ops/Gather.py +67 -18
- onnx2tf/ops/Mod.py +29 -0
- onnx2tf/ops/Mul.py +30 -0
- onnx2tf/ops/ReduceL1.py +3 -0
- onnx2tf/ops/ReduceL2.py +3 -0
- onnx2tf/ops/ReduceLogSum.py +3 -0
- onnx2tf/ops/ReduceLogSumExp.py +3 -0
- onnx2tf/ops/ReduceMax.py +3 -0
- onnx2tf/ops/ReduceMean.py +3 -0
- onnx2tf/ops/ReduceMin.py +3 -0
- onnx2tf/ops/ReduceProd.py +3 -0
- onnx2tf/ops/ReduceSum.py +3 -0
- onnx2tf/ops/ReduceSumSquare.py +3 -0
- onnx2tf/ops/Shape.py +2 -0
- onnx2tf/ops/Sub.py +29 -0
- onnx2tf/ops/Transpose.py +14 -0
- onnx2tf/utils/common_functions.py +15 -8
- {onnx2tf-1.24.0.dist-info → onnx2tf-1.25.9.dist-info}/METADATA +269 -28
- {onnx2tf-1.24.0.dist-info → onnx2tf-1.25.9.dist-info}/RECORD +33 -33
- {onnx2tf-1.24.0.dist-info → onnx2tf-1.25.9.dist-info}/WHEEL +1 -1
- {onnx2tf-1.24.0.dist-info → onnx2tf-1.25.9.dist-info}/LICENSE +0 -0
- {onnx2tf-1.24.0.dist-info → onnx2tf-1.25.9.dist-info}/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.24.0.dist-info → onnx2tf-1.25.9.dist-info}/entry_points.txt +0 -0
- {onnx2tf-1.24.0.dist-info → onnx2tf-1.25.9.dist-info}/top_level.txt +0 -0
onnx2tf/__init__.py
CHANGED
onnx2tf/onnx2tf.py
CHANGED
@@ -25,6 +25,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 os.environ["TF_USE_LEGACY_KERAS"] = '1'
 import tensorflow as tf
 tf.random.set_seed(0)
+from tensorflow.python.saved_model.load import _WrapperFunction
 import tf_keras
 tf_keras.utils.set_random_seed(0)
 tf.config.experimental.enable_op_determinism()
@@ -81,6 +82,7 @@ def convert(
     keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]] = None,
     keep_nwc_or_nhwc_or_ndhwc_input_names: Optional[List[str]] = None,
     keep_shape_absolutely_input_names: Optional[List[str]] = None,
+    input_names_to_interrupt_model_conversion: Optional[List[str]] = None,
     output_names_to_interrupt_model_conversion: Optional[List[str]] = None,
     disable_group_convolution: Optional[bool] = False,
     enable_accumulation_type_float16: Optional[bool] = False,
@@ -285,6 +287,13 @@ def convert(
         e.g.\n
         keep_shape_absolutely_input_names=['input0','input1','input2']
 
+    input_names_to_interrupt_model_conversion: Optional[List[str]]
+        Input names that interrupt model conversion.\n
+        Interrupts model transformation at the specified input name\n
+        and inputs the model partitioned into subgraphs.\n\n
+        e.g.\n
+        input_names_to_interrupt_model_conversion=['input0','input1','input2']
+
     output_names_to_interrupt_model_conversion: Optional[List[str]]
         Output names that interrupt model conversion.\n
         Interrupts model transformation at the specified output name\n
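The new parameter mirrors the existing output-side option: conversion starts from the named inputs and the model is handled as a subgraph. A minimal sketch of the Python API call (the file path and input names are placeholders, not taken from this diff):

import onnx2tf

# Hypothetical invocation; 'model.onnx' and the input names are placeholders.
onnx2tf.convert(
    input_onnx_file_path='model.onnx',
    output_folder_path='saved_model',
    # Conversion is interrupted at these graph inputs and proceeds
    # on the resulting subgraph.
    input_names_to_interrupt_model_conversion=['input0', 'input1'],
)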
@@ -686,7 +695,77 @@ def convert(
     metadata_props = onnx_graph.metadata_props
     graph = gs.import_onnx(onnx_graph)
 
-    #
+    # Cut the ONNX graph when an input name is specified that interrupts the conversion
+    if not input_names_to_interrupt_model_conversion:
+        input_names = [
+            graph_input.name for graph_input in graph.inputs
+        ]
+    else:
+        try:
+            from sne4onnx import extraction
+        except Exception as ex:
+            error(
+                f'If --input_names_to_interrupt_model_conversion is specified, ' +\
+                f'you must install sne4onnx. pip install sne4onnx'
+            )
+            sys.exit(1)
+        # Cut ONNX graph at specified input position
+        input_names = [
+            input_op_name \
+                for input_op_name in input_names_to_interrupt_model_conversion
+        ]
+        onnx_graph: onnx.ModelProto = \
+            extraction(
+                input_op_names=input_names,
+                output_op_names=[graph_output.name for graph_output in graph.outputs],
+                onnx_graph=onnx_graph,
+            )
+        # Re-import of onnx_graph
+        del graph
+        graph = gs.import_onnx(onnx_graph)
+
+        total_num_nodes = len(graph.nodes)
+        check_count = 0
+        idx = 0
+        while True:
+            # Delete unused nodes
+            if check_count >= total_num_nodes:
+                break
+            op_input_names: List[str] = [inp.name for inp in graph.nodes[idx].inputs]
+            remove_enable = not any(name in input_names for name in op_input_names)
+            if remove_enable:
+                try:
+                    num_input = len(graph.nodes[idx].inputs)
+                    enable_var_input = False
+                    for sub_idx, graph_node_input in enumerate(graph.nodes[idx].inputs):
+                        if isinstance(graph_node_input, gs.Variable):
+                            enable_var_input = True
+                            break
+                        else:
+                            pass
+                    if enable_var_input:
+                        name = graph.nodes[idx].i(sub_idx).name
+                        if any([graph_node.name == name for graph_node in graph.nodes]):
+                            idx += 1
+                        else:
+                            try:
+                                del graph.nodes[idx]
+                            except IndexError:
+                                break
+                    else:
+                        idx += 1
+                except:
+                    try:
+                        del graph.nodes[idx]
+                    except IndexError:
+                        break
+            else:
+                idx += 1
+            check_count += 1
+        onnx_graph = gs.export_onnx(graph=graph, do_type_check=False, **meta_data)
+        if metadata_props is not None:
+            onnx_graph.metadata_props.extend(metadata_props)
+
     # Cut the ONNX graph when an output name is specified that interrupts the conversion
     if not output_names_to_interrupt_model_conversion:
         output_names = [
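The input-side cut is delegated to sne4onnx, imported lazily so the dependency is only needed when the option is used. A standalone sketch of the same extraction call (the model path and op names are placeholders):

import onnx
from sne4onnx import extraction  # pip install sne4onnx

model = onnx.load('model.onnx')   # placeholder path
extracted: onnx.ModelProto = extraction(
    input_op_names=['input0'],    # where conversion should start
    output_op_names=['output0'],  # the original graph outputs
    onnx_graph=model,
)
onnx.save(extracted, 'model_subgraph.onnx')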
@@ -706,11 +785,12 @@ def convert(
             output_op_name \
                 for output_op_name in output_names_to_interrupt_model_conversion
         ]
-        onnx_graph: onnx.ModelProto =
-
-
-
-
+        onnx_graph: onnx.ModelProto = \
+            extraction(
+                input_op_names=[graph_input.name for graph_input in graph.inputs],
+                output_op_names=output_names,
+                onnx_graph=onnx_graph,
+            )
         # Re-import of onnx_graph
         del graph
         graph = gs.import_onnx(onnx_graph)
@@ -772,6 +852,7 @@ def convert(
             output_name = re.sub('^/', 'wa/', output_name)
             new_output_names.append(output_name)
         output_names = new_output_names
+
     try:
         onnx_graph = gs.export_onnx(graph=graph, do_type_check=False, **meta_data)
         if metadata_props is not None:
@@ -1234,19 +1315,56 @@ def convert(
         # Switch to .pb
         info(Color.GREEN(f'Switch to the output of an optimized protocol buffer file (.pb).'))
     except (KeyError, AssertionError) as e:
+        msg_list = [s for s in e.args if isinstance(s, str)]
+        if len(msg_list) > 0:
+            try:
+                for s in msg_list:
+                    if 'Failed to add concrete function' in s \
+                        or "Tried to export a function which references an 'untracked' resource" in s:
+                        export_archive = tf_keras.export.ExportArchive()
+                        export_archive.add_endpoint(
+                            name=SIGNATURE_KEY,
+                            fn=lambda *inputs : model(inputs),
+                            input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
+                        )
+                        export_archive.write_out(output_folder_path)
+                        break
+            except ValueError as e:
+                msg_list = [s for s in e.args if isinstance(s, str)]
+                if len(msg_list) > 0:
+                    for s in msg_list:
+                        if 'A root scope name has to match the following pattern' in s:
+                            error(
+                                f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
+                            )
+                            matches = re.findall(r"'([^']*)'", s)
+                            error(f'{matches[0]}')
+                            error(
+                                f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
+                            )
+                            sys.exit(1)
+                        else:
+                            error(e)
+                            import traceback
+                            error(traceback.format_exc(), prefix=False)
+        else:
+            error(e)
+            import traceback
+            error(traceback.format_exc(), prefix=False)
+    except ValueError as e:
         msg_list = [s for s in e.args if isinstance(s, str)]
         if len(msg_list) > 0:
             for s in msg_list:
-                if '
-
-
-                export_archive.add_endpoint(
-                    name=SIGNATURE_KEY,
-                    fn=lambda *inputs : model(inputs),
-                    input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
+                if 'A root scope name has to match the following pattern' in s:
+                    error(
+                        f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
                     )
-
-
+                    matches = re.findall(r"'([^']*)'", s)
+                    error(f'{matches[0]}')
+                    error(
+                        f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
+                    )
+                    sys.exit(1)
                 else:
                     error(e)
                     import traceback
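The first new except branch falls back to tf_keras's ExportArchive when SavedModel export fails on concrete functions or untracked resources. A condensed sketch of that fallback; SIGNATURE_KEY is assumed here to be the usual serving key, matching the converter's constant:

import tensorflow as tf
import tf_keras

SIGNATURE_KEY = 'serving_default'  # assumption: matches the converter's SIGNATURE_KEY

def export_with_archive(model: tf_keras.Model, output_folder_path: str) -> None:
    # Re-export the model through an ExportArchive endpoint.
    export_archive = tf_keras.export.ExportArchive()
    export_archive.add_endpoint(
        name=SIGNATURE_KEY,
        fn=lambda *inputs: model(inputs),
        input_signature=[
            tf.TensorSpec(t.shape, t.dtype, t.name) for t in model.inputs
        ],
    )
    export_archive.write_out(output_folder_path)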
@@ -1408,19 +1526,23 @@ def convert(
         )
 
     # Quantized TFLite
-    MEAN = np.asarray([[[[0.485, 0.456, 0.406]]]], dtype=np.float32)
-    STD = np.asarray([[[[0.229, 0.224, 0.225]]]], dtype=np.float32)
     if output_integer_quantized_tflite:
         # Get signatures/input keys
-
-
-
-
-
-
-
-
-
+        trackable_obj = \
+            tf.saved_model.load(
+                output_folder_path
+            )
+        loaded_saved_model: _WrapperFunction = trackable_obj.signatures[SIGNATURE_KEY]
+        structured_input_signature: Dict[str, tf.TensorSpec] = loaded_saved_model.structured_input_signature[1]
+        structured_outputs: Dict[str, tf.TensorSpec] = loaded_saved_model.structured_outputs
+
+        input_keys: List[str] = list(structured_input_signature.keys())
+        input_shapes: List[tf.TensorShape] = [v.shape for v in structured_input_signature.values()]
+        input_dtypes: List[tf.dtypes.DType] = [v.dtype for v in structured_input_signature.values()]
+
+        output_keys: List[str] = list(structured_outputs.keys())
+        output_shapes: List[tf.TensorShape] = [v.shape for v in structured_outputs.values()]
+        output_dtypes: List[tf.dtypes.DType] = [v.dtype for v in structured_outputs.values()]
 
         print('')
         info(Color.BLUE(f'Signature information for quantization'))
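Signature metadata for quantization is now read back from the exported SavedModel instead of being carried separately. A minimal sketch of that introspection ('saved_model' is a placeholder path; 'serving_default' is assumed to match SIGNATURE_KEY):

import tensorflow as tf

trackable_obj = tf.saved_model.load('saved_model')
fn = trackable_obj.signatures['serving_default']

# structured_input_signature is an (args, kwargs) pair; named inputs are in kwargs.
input_specs = fn.structured_input_signature[1]
output_specs = fn.structured_outputs

input_keys = list(input_specs.keys())
input_shapes = [v.shape for v in input_specs.values()]
input_dtypes = [v.dtype for v in input_specs.values()]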
@@ -1463,15 +1585,30 @@ def convert(
             for model_input in model.inputs:
                 if model_input.dtype != tf.float32 \
                     or len(model_input.shape) != 4 \
-                    or model_input.shape[-1]
+                    or model_input.shape[-1] not in [3, 4]:
                     error(
                         f'For models that have multiple input OPs and need to perform INT8 quantization calibration '+
-                        f'using non-rgb-image input tensors, specify the calibration data with '+
+                        f'using non-rgb-image/non-rgba-image input tensors, specify the calibration data with '+
                         f'--quant_calib_input_op_name_np_data_path. '+
                         f'model_input[n].shape: {model_input.shape}'
                     )
                     sys.exit(1)
 
+                if model_input.shape[-1] == 3:
+                    # RGB
+                    mean = np.asarray([[[[0.485, 0.456, 0.406]]]], dtype=np.float32)
+                    std = np.asarray([[[[0.229, 0.224, 0.225]]]], dtype=np.float32)
+                elif model_input.shape[-1] == 4:
+                    # RGBA
+                    mean = np.asarray([[[[0.485, 0.456, 0.406, 0.000]]]], dtype=np.float32)
+                    std = np.asarray([[[[0.229, 0.224, 0.225, 1.000]]]], dtype=np.float32)
+                    new_element_array = np.full((*calib_data.shape[:-1], 1), 0.500, dtype=np.float32)
+                    calib_data = np.concatenate((calib_data, new_element_array), axis=-1)
+                else:
+                    # Others
+                    mean = np.asarray([[[[0.485, 0.456, 0.406]]]], dtype=np.float32)
+                    std = np.asarray([[[[0.229, 0.224, 0.225]]]], dtype=np.float32)
+
                 calib_data_dict[model_input.name] = \
                     [
                         tf.image.resize(
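The constants are the standard ImageNet statistics; a 4-channel input keeps alpha un-normalized (mean 0.0, std 1.0) and the 3-channel calibration samples are padded with a constant 0.5 alpha plane. A small sketch of the arithmetic these constants imply, assuming the usual (x - mean) / std normalization and a synthetic array:

import numpy as np

calib_data = np.random.rand(10, 224, 224, 3).astype(np.float32)  # N,H,W,C batch

# RGBA: ImageNet mean/std extended with a pass-through alpha channel.
mean = np.asarray([[[[0.485, 0.456, 0.406, 0.000]]]], dtype=np.float32)
std = np.asarray([[[[0.229, 0.224, 0.225, 1.000]]]], dtype=np.float32)

# Pad the 3-channel samples with a constant 0.5 alpha plane.
alpha = np.full((*calib_data.shape[:-1], 1), 0.500, dtype=np.float32)
calib_data = np.concatenate((calib_data, alpha), axis=-1)

normalized = (calib_data - mean) / std  # broadcasts over N,H,W
assert normalized.shape == (10, 224, 224, 4)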
@@ -1481,8 +1618,8 @@ def convert(
                                 model_input.shape[2] if model_input.shape[2] is not None else 640,
                             )
                         ),
-
-
+                        mean,
+                        std,
                     ]
         elif custom_input_op_name_np_data_path is not None:
             for param in custom_input_op_name_np_data_path:
@@ -1846,6 +1983,7 @@ def main():
         '-osd',
         '--output_signaturedefs',
         action='store_true',
+        default=True,
         help=\
             'Signature is added to the output for serving or for conversion \n' +
             'to other model formats. However, this can significantly reduce the speed \n' +
@@ -2101,6 +2239,18 @@ def main():
             'e.g. \n' +
             '--keep_shape_absolutely_input_names "input0" "input1" "input2"'
     )
+    parser.add_argument(
+        '-inimc',
+        '--input_names_to_interrupt_model_conversion',
+        type=str,
+        nargs='+',
+        help=\
+            'Input names that interrupt model conversion. \n' +
+            'Interrupts model transformation at the specified input name \n' +
+            'and inputs the model partitioned into subgraphs. \n\n' +
+            'e.g. \n' +
+            '--input_names_to_interrupt_model_conversion "input0" "input1" "input2"'
+    )
     parser.add_argument(
         '-onimc',
         '--output_names_to_interrupt_model_conversion',
@@ -2444,6 +2594,7 @@ def main():
         keep_ncw_or_nchw_or_ncdhw_input_names=args.keep_ncw_or_nchw_or_ncdhw_input_names,
         keep_nwc_or_nhwc_or_ndhwc_input_names=args.keep_nwc_or_nhwc_or_ndhwc_input_names,
         keep_shape_absolutely_input_names=args.keep_shape_absolutely_input_names,
+        input_names_to_interrupt_model_conversion=args.input_names_to_interrupt_model_conversion,
         output_names_to_interrupt_model_conversion=args.output_names_to_interrupt_model_conversion,
         disable_group_convolution=args.disable_group_convolution,
         enable_accumulation_type_float16=args.enable_accumulation_type_float16,
onnx2tf/ops/Add.py
CHANGED
@@ -168,8 +168,37 @@ def make_node(
     is_scalar_2_rank = tf.rank(input_tensor_2) == 0
     if hasattr(is_scalar_2_rank, 'numpy'):
         is_scalar_2 = is_scalar_2_rank.numpy()
+
     if (is_scalar_1 or is_scalar_2) and graph_node.i().op == 'Gemm':
         pass
+    elif (is_scalar_1 or is_scalar_2) and graph_node.i().op != 'Gemm':
+        first_tensor = None
+        second_tensor = None
+        if is_scalar_1:
+            first_tensor = input_tensor_2
+            second_tensor = input_tensor_1
+        elif is_scalar_2:
+            first_tensor = input_tensor_1
+            second_tensor = input_tensor_2
+        tmp_result = tf.math.add(first_tensor, second_tensor)
+        tmp_result_shape = tmp_result.shape
+        if first_tensor.shape == tmp_result_shape:
+            pass
+        else:
+            input_tensor_1, input_tensor_2 = \
+                pre_explicit_broadcast(
+                    input_tensor_1=input_tensor_1,
+                    input_tensor_2=input_tensor_2,
+                )
+
+            input_tensor_1, input_tensor_2 = \
+                explicit_broadcast(
+                    const_or_var_1=input_tensor_1,
+                    const_or_var_2=input_tensor_2,
+                    graph_node=graph_node,
+                    tf_layers_dict= tf_layers_dict,
+                )
+
     else:
         input_tensor_1, input_tensor_2 = \
             pre_explicit_broadcast(
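The new elif handles scalar operands outside the Gemm special case: it probes whether a plain add already yields the expected shape, and only falls back to the explicit broadcast helpers when it does not (for instance, when static shapes are only partially known). A quick check of why the probe usually passes:

import tensorflow as tf

x = tf.zeros([1, 8, 8, 3])
s = tf.constant(0.5)           # rank-0 scalar, tf.rank(s) == 0

tmp = tf.math.add(x, s)        # a scalar add is shape-preserving
print(x.shape == tmp.shape)    # True -> the explicit broadcast path is skipped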
onnx2tf/ops/AveragePool.py
CHANGED
@@ -268,11 +268,14 @@ def make_node(
         [list(i) for i in zip(tf_pads[:len(tf_pads) // 2], tf_pads[len(tf_pads) // 2:])] + \
         [[0, 0]]
 
-
-
-
-
-
+    if spatial_size == 1 and kernel_shape[0] > input_tensor_shape[1]:
+        padded_tensor = input_tensor
+    else:
+        padded_tensor = tf.pad(
+            tensor=input_tensor,
+            paddings=tf_pads,
+            mode='CONSTANT',
+        )
 
 else:
     padded_tensor = input_tensor
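Padding is now skipped for 1-D pooling whose kernel already exceeds the spatial extent; the next hunk clamps the pool size in the same case. The padding call itself is ordinary tf.pad over the width axis of an NWC tensor, e.g. (shapes are synthetic):

import tensorflow as tf

x = tf.zeros([1, 5, 8])                 # N, W, C
tf_pads = [[0, 0], [1, 1], [0, 0]]      # pad only the spatial (W) axis

padded = tf.pad(tensor=x, paddings=tf_pads, mode='CONSTANT')
print(padded.shape)                     # (1, 7, 8)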
@@ -306,11 +309,18 @@ def make_node(
     # Generation of TF OP
     tf_op_type = None
     if len(kernel_shape) == 1:
-
-
-
-
-
+        if kernel_shape[0] > padded_tensor.shape[1]:
+            pooled_tensor = AveragePooling1D(
+                pool_size=[padded_tensor.shape[1]],
+                strides=[padded_tensor.shape[1]],
+                padding=tf_pad_mode.upper(),
+            )(padded_tensor)
+        else:
+            pooled_tensor = AveragePooling1D(
+                pool_size=kernel_shape,
+                strides=strides,
+                padding=tf_pad_mode.upper(),
+            )(padded_tensor)
         tf_op_type = AveragePooling1D
 
     elif len(kernel_shape) == 2: