onnx2tf 1.29.9__py3-none-any.whl → 1.29.11__py3-none-any.whl
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +83 -73
- onnx2tf/ops/If.py +4 -2
- onnx2tf/ops/Loop.py +392 -0
- onnx2tf/ops/LpPool.py +296 -0
- onnx2tf/ops/MaxRoiPool.py +236 -0
- onnx2tf/utils/common_functions.py +3 -0
- {onnx2tf-1.29.9.dist-info → onnx2tf-1.29.11.dist-info}/METADATA +26 -22
- {onnx2tf-1.29.9.dist-info → onnx2tf-1.29.11.dist-info}/RECORD +11 -12
- {onnx2tf-1.29.9.dist-info → onnx2tf-1.29.11.dist-info}/WHEEL +1 -2
- onnx2tf-1.29.11.dist-info/entry_points.txt +3 -0
- onnx2tf/ops/_Loop.py +0 -306
- onnx2tf/ops/__Loop.py +0 -509
- onnx2tf-1.29.9.dist-info/licenses/LICENSE +0 -21
- onnx2tf-1.29.9.dist-info/licenses/LICENSE_onnx-tensorflow +0 -213
- onnx2tf-1.29.9.dist-info/top_level.txt +0 -1
onnx2tf/__init__.py
CHANGED
onnx2tf/onnx2tf.py
CHANGED
@@ -1467,82 +1467,89 @@ def convert(
     SIGNATURE_KEY = 'serving_default'
 
     # saved_model
+    saved_model_log_level = get_log_level()
     try:
-        # … (removed lines 1471-1493 are not shown in this diff view)
-        export_archive = tf_keras.export.ExportArchive()
-        export_archive.add_endpoint(
-            name=SIGNATURE_KEY,
-            fn=lambda *inputs : model(inputs),
-            input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
-        )
-        export_archive.write_out(output_folder_path)
-        break
-    except ValueError as e:
-        msg_list = [s for s in e.args if isinstance(s, str)]
-        if len(msg_list) > 0:
+        if saved_model_log_level <= LOG_LEVELS['debug']:
+            set_log_level('info')
+        try:
+            # concrete_func
+            info(Color.REVERSE(f'saved_model output started'), '=' * 58)
+            if not output_signaturedefs and not output_integer_quantized_tflite:
+                tf.saved_model.save(model, output_folder_path)
+            else:
+                export_archive = tf_keras.export.ExportArchive()
+                export_archive.add_endpoint(
+                    name=SIGNATURE_KEY,
+                    fn=lambda *inputs : model(inputs),
+                    input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
+                )
+                export_archive.write_out(output_folder_path)
+            info(Color.GREEN(f'saved_model output complete!'))
+        except TypeError as e:
+            # Switch to .pb
+            info(Color.GREEN(f'Switch to the output of an optimized protocol buffer file (.pb).'))
+        except (KeyError, AssertionError) as e:
+            msg_list = [s for s in e.args if isinstance(s, str)]
+            if len(msg_list) > 0:
+                try:
                     for s in msg_list:
-                        if '…  # (removed lines 1506-1508 are truncated in this diff view)
+                        if 'Failed to add concrete function' in s \
+                            or "Tried to export a function which references an 'untracked' resource" in s:
+                            export_archive = tf_keras.export.ExportArchive()
+                            export_archive.add_endpoint(
+                                name=SIGNATURE_KEY,
+                                fn=lambda *inputs : model(inputs),
+                                input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
                             )
-                            # … (removed lines 1510-1538 are not shown in this diff view)
+                            export_archive.write_out(output_folder_path)
+                            break
+                except ValueError as e:
+                    msg_list = [s for s in e.args if isinstance(s, str)]
+                    if len(msg_list) > 0:
+                        for s in msg_list:
+                            if 'A root scope name has to match the following pattern' in s:
+                                error(
+                                    f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
+                                )
+                                matches = re.findall(r"'([^']*)'", s)
+                                error(f'{matches[0]}')
+                                error(
+                                    f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
+                                )
+                                sys.exit(1)
+                            else:
+                                error(e)
+                                import traceback
+                                error(traceback.format_exc(), prefix=False)
+                    else:
+                        error(e)
+                        import traceback
+                        error(traceback.format_exc(), prefix=False)
+        except ValueError as e:
+            msg_list = [s for s in e.args if isinstance(s, str)]
+            if len(msg_list) > 0:
+                for s in msg_list:
+                    if 'A root scope name has to match the following pattern' in s:
+                        error(
+                            f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
+                        )
+                        matches = re.findall(r"'([^']*)'", s)
+                        error(f'{matches[0]}')
+                        error(
+                            f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
+                        )
+                        sys.exit(1)
+                    else:
+                        error(e)
+                        import traceback
+                        error(traceback.format_exc(), prefix=False)
+        except Exception as e:
             error(e)
             import traceback
             error(traceback.format_exc(), prefix=False)
-        # … (removed lines 1542-1544 are not shown in this diff view)
-        error(traceback.format_exc(), prefix=False)
+    finally:
+        if get_log_level() != saved_model_log_level:
+            set_log_level(saved_model_log_level)
 
     # TFv1 .pb
     if output_tfv1_pb:
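For context, the hunk above replaces the unconditional ExportArchive export with a tiered strategy: plain tf.saved_model.save() when neither signaturedefs (-osd) nor integer-quantized TFLite output is requested, an explicitly named 'serving_default' endpoint otherwise, error-specific fallbacks, and a finally block that restores the caller's log level. A minimal sketch of the two export paths outside onnx2tf follows; the toy model and output path are hypothetical, the ExportArchive calls mirror the diff, and the extra track() call is an assumption added so the sketch stands alone (the diff instead catches the 'untracked resource' error after the fact):

import tensorflow as tf
import tf_keras

model = tf_keras.Sequential([tf_keras.layers.Dense(2, input_shape=(4,))])
output_folder_path = '/tmp/exported_saved_model'   # hypothetical
output_signaturedefs = True                        # e.g. the -osd CLI option

if not output_signaturedefs:
    # Plain SavedModel export; signature names are auto-generated.
    tf.saved_model.save(model, output_folder_path)
else:
    # Explicitly named 'serving_default' endpoint, as the else-branch above does.
    export_archive = tf_keras.export.ExportArchive()
    export_archive.track(model)  # track variables so the export is self-contained
    export_archive.add_endpoint(
        name='serving_default',
        fn=lambda *inputs: model(inputs),
        input_signature=[tf.TensorSpec(t.shape, t.dtype, t.name) for t in model.inputs],
    )
    export_archive.write_out(output_folder_path)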
@@ -1581,9 +1588,12 @@ def convert(
         Name: flatbuffers
         Version: 22.10.26
         """
-        # … (removed lines 1584-1586 are not shown in this diff view)
+        try:
+            converter = tf.lite.TFLiteConverter.from_keras_model(model)
+        except Exception as e:
+            converter = tf.lite.TFLiteConverter.from_concrete_functions(
+                [concrete_func]
+            )
         converter.target_spec.supported_ops = [
             tf.lite.OpsSet.TFLITE_BUILTINS,
             tf.lite.OpsSet.SELECT_TF_OPS,
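The second hunk wraps TFLite converter construction in a fallback: from_keras_model() is tried first, and a prebuilt concrete function is used when the Keras path raises (in convert(), `concrete_func` is built earlier; here it is reconstructed inline). A hedged, self-contained sketch of the same pattern, with a made-up model and input spec:

import tensorflow as tf
import tf_keras

model = tf_keras.Sequential([tf_keras.layers.Dense(2, input_shape=(4,))])

try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
except Exception:
    # Fall back to a traced concrete function when the Keras path fails.
    run_model = tf.function(lambda x: model(x))
    concrete_func = run_model.get_concrete_function(tf.TensorSpec([1, 4], tf.float32))
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,   # native TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,     # TF kernel fallback for unsupported ops
]
tflite_model = converter.convert()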
onnx2tf/ops/If.py
CHANGED
@@ -54,6 +54,8 @@ def make_node(
     graph_node_outputs = [] + graph_node.outputs
 
     # Then branch
+    subgraph_kwargs = dict(kwargs)
+    subgraph_kwargs['suppress_log'] = True
     then_branch_graph: gs.Graph = graph_node.attrs['then_branch']
     then_branch_graph_outputs = then_branch_graph.outputs
     for then_branch_graph_node in then_branch_graph.nodes:

@@ -73,7 +75,7 @@ def make_node(
         op.make_node(
             graph_node=then_branch_graph_node,
             tf_layers_dict=tf_layers_dict,
-            **kwargs,
+            **subgraph_kwargs,
         )
         # Then branch - Resister constant
         for output in then_branch_graph_outputs:

@@ -115,7 +117,7 @@ def make_node(
         op.make_node(
             graph_node=else_branch_graph_node,
             tf_layers_dict=tf_layers_dict,
-            **kwargs,
+            **subgraph_kwargs,
         )
         # Else branch - Resister constant
         for output in else_branch_graph_outputs:
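The point of this change: both If branches previously forwarded the caller's **kwargs directly, so setting a subgraph-only flag like suppress_log on it would have leaked back to the caller. Copying the dict first isolates the branch. A minimal, runnable sketch of the pattern; process_branch and handle_node are hypothetical stand-ins for the per-op dispatch in If.py:

def handle_node(node, **kwargs):
    # Hypothetical per-op handler; the real code dispatches to onnx2tf.ops.<OpType>.make_node.
    if not kwargs.get('suppress_log', False):
        print(f'processing {node}')

def process_branch(branch_nodes, **kwargs):
    # Shallow-copy so subgraph-only flags never mutate the caller's kwargs.
    subgraph_kwargs = dict(kwargs)
    subgraph_kwargs['suppress_log'] = True
    for node in branch_nodes:
        handle_node(node, **subgraph_kwargs)

process_branch(['Add', 'Mul'], output_signaturedefs=False)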
onnx2tf/ops/Loop.py
ADDED
@@ -0,0 +1,392 @@
import re
import sys
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
import tf_keras
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
    get_constant_or_variable,
    print_node_info,
    inverted_operation_enable_disable,
    make_tf_node_info,
)
from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
import importlib
from onnx2tf.utils.logging import *


class While_Loop_CustomLayer(tf_keras.layers.Layer):
    def __init__(self):
        super(While_Loop_CustomLayer, self).__init__()

    def call(self, cond, body, loop_vars, shape_invariants, maximum_iterations):
        return tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=loop_vars,
            shape_invariants=shape_invariants,
            maximum_iterations=maximum_iterations,
        )


def _to_tf_dtype(dtype):
    return NUMPY_DTYPES_TO_TF_DTYPES[dtype] if isinstance(dtype, np.dtype) else dtype


def _as_tensor(value):
    if isinstance(value, np.ndarray):
        return tf.convert_to_tensor(value)
    if isinstance(value, (np.generic, int, float, bool)):
        return tf.convert_to_tensor(value)
    return value


def _shape_invariant(value):
    try:
        shape = value.shape
    except Exception:
        return tf.TensorShape(None)
    if shape is None:
        return tf.TensorShape(None)
    if isinstance(shape, tf.TensorShape):
        if shape.rank is None:
            return tf.TensorShape(None)
        return tf.TensorShape([None for _ in range(shape.rank)])
    return tf.TensorShape([None for _ in range(len(shape))])


@print_node_info
@inverted_operation_enable_disable
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """Loop

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_input_n_list = []
    for graph_node_input in graph_node.inputs[2:]:
        graph_node_input_n = get_constant_or_variable(
            graph_node_input,
            before_op_output_shape_trans,
        )
        graph_node_input_n_list.append(graph_node_input_n)

    def _sanitize(name):
        name = name.replace(':', '__')
        if kwargs.get('output_signaturedefs', False):
            name = re.sub('^/', 'wa/', name)
        return name

    # M: maximum trip-count
    M = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
    M = None if isinstance(M, str) and M == "" else M
    M = _as_tensor(M) if M is not None else None
    if M is not None:
        M = tf.cast(M, tf.int64)
        max_i32 = tf.constant(tf.int32.max, dtype=tf.int64)
        M = tf.where(tf.greater(M, max_i32), max_i32, M)
        M = tf.cast(M, tf.int32)
        if M.shape is not None and M.shape.rank not in (None, 0):
            M = tf.reshape(M, [])

    # cond: loop continuation condition (optional)
    cond = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
    cond = None if isinstance(cond, str) and cond == "" else cond
    cond_init = None if cond is None else tf.cast(_as_tensor(cond), tf.bool)
    if cond_init is not None \
        and isinstance(graph_node_input_2, gs.Variable) \
        and graph_node_input_2.shape is not None \
        and len(graph_node_input_2.shape) == 0:
        cond_init = tf.reshape(cond_init, [])
    cond_provided = cond_init is not None
    if not cond_provided:
        cond_init = tf.constant(True, dtype=tf.bool)

    v_init = []
    v_input_meta = []
    for graph_node_input_n in graph_node_input_n_list:
        v_val = tf_layers_dict[graph_node_input_n.name]['tf_node'] \
            if isinstance(graph_node_input_n, gs.Variable) else graph_node_input_n
        v_val = _as_tensor(v_val)
        if isinstance(graph_node_input_n, gs.Variable) \
            and graph_node_input_n.shape is not None \
            and len(graph_node_input_n.shape) == 0:
            v_val = tf.reshape(v_val, [])
        v_init.append(v_val)
        if isinstance(graph_node_input_n, gs.Variable):
            v_input_meta.append(tf_layers_dict.get(graph_node_input_n.name, {}))
        else:
            v_input_meta.append({})

    v_shapes = [_shape_invariant(v) for v in v_init]

    body: gs.Graph = graph_node.attrs["body"]

    iter_cnt_init = tf.constant(0, dtype=tf.int32)

    scan_outputs_start_index = 1 + len(v_init)
    scan_outputs_init = []
    for i in range(scan_outputs_start_index, len(body.outputs)):
        elem_shape = body.outputs[i].shape
        if elem_shape is not None:
            elem_shape = [
                dim if isinstance(dim, int) else None for dim in elem_shape
            ]
        elem_shape = tf.TensorShape(elem_shape)
        scan_outputs_init.append(
            tf.TensorArray(
                dtype=_to_tf_dtype(body.outputs[i].dtype),
                size=0,
                dynamic_size=True,
                element_shape=elem_shape,
            )
        )
    scan_outputs_shapes = [tf.TensorShape(None) for _ in scan_outputs_init]

    graph_node_outputs = list(graph_node.outputs)
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name] = {
            'optype': graph_node.op,
            'shape': graph_node_output.shape,
            'dtype': graph_node_output.dtype,
        }

    def _register_graph_output_constants(graph: gs.Graph):
        for output in graph.outputs:
            if output.name not in tf_layers_dict and isinstance(output, gs.Constant):
                tf_layers_dict[output.name] = {
                    'optype': 'Constant',
                    'shape': output.values.shape,
                    'dtype': output.values.dtype,
                }
                tf_layers_dict[output.name]['tf_node'] = \
                    tf.constant(
                        output.values,
                        dtype=_to_tf_dtype(output.values.dtype),
                    )
        for node in graph.nodes:
            for attr_val in node.attrs.values():
                if isinstance(attr_val, gs.Graph):
                    _register_graph_output_constants(attr_val)
                elif isinstance(attr_val, (list, tuple)):
                    for sub_val in attr_val:
                        if isinstance(sub_val, gs.Graph):
                            _register_graph_output_constants(sub_val)

    # Register subgraph constants outside the while_loop to avoid scope issues.
    _register_graph_output_constants(body)

    def run_subgraph(iter_cnt, cond, v, scan_outputs):
        # Bind loop vars to body graph inputs
        loop_inputs = [iter_cnt, cond] + list(v)
        for idx, (body_input, loop_val) in enumerate(zip(body.inputs, loop_inputs)):
            body_input.name = _sanitize(body_input.name)
            target_dtype = _to_tf_dtype(body_input.dtype) if body_input.dtype is not None else None
            loop_val_cast = loop_val
            if target_dtype is not None \
                and isinstance(loop_val, tf.Tensor) \
                and loop_val.dtype != target_dtype:
                loop_val_cast = tf.cast(loop_val, target_dtype)
            if body_input.shape is not None \
                and len(body_input.shape) == 0:
                loop_val_cast = tf.reshape(loop_val_cast, [])
            tf_layers_dict[body_input.name] = {
                'optype': 'Input',
                'shape': body_input.shape,
                'dtype': body_input.dtype,
                'tf_node': loop_val_cast,
                'before_op_output_shape_trans': True,
            }
            if idx >= 2:
                meta = v_input_meta[idx - 2]
                for key in ('before_op_output_shape_trans', 'nhwc'):
                    if key in meta:
                        tf_layers_dict[body_input.name][key] = meta[key]

        subgraph_kwargs = dict(kwargs)
        subgraph_kwargs['suppress_log'] = True
        for body_node in body.nodes:
            optype = body_node.op
            try:
                op = importlib.import_module(f'onnx2tf.ops.{optype}')
            except ModuleNotFoundError as ex:
                error(
                    f'{optype} OP is not yet implemented.'
                )
                sys.exit(1)
            body_node.name = _sanitize(body_node.name)
            op.make_node(
                graph_node=body_node,
                tf_layers_dict=tf_layers_dict,
                **subgraph_kwargs,
            )
        outputs = [tf_layers_dict[output.name]['tf_node'] for output in body.outputs]
        for i in range(scan_outputs_start_index, len(outputs)):
            s_index = i - scan_outputs_start_index
            scan_outputs[s_index] = scan_outputs[s_index].write(
                scan_outputs[s_index].size(), outputs[i]
            )
        cond_out = outputs[0]
        if isinstance(cond_out, tf.Tensor) and cond_out.dtype != tf.bool:
            cond_out = tf.cast(cond_out, tf.bool)
        iter_cnt = iter_cnt + 1
        return [iter_cnt, cond_out, outputs[1:scan_outputs_start_index], scan_outputs]

    if M is None and not cond_provided:
        error(
            f'Both M and cond in Loop are not set at the same time ' +
            f'Tensorflow.(PS. if you want to create a do-while loop ' +
            f'then please set cond to True or 1)\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    cond_true = tf.constant(True, dtype=tf.bool)
    if M is not None and not cond_provided:
        condition = lambda iter_cnt, cond, v, scan_outputs: cond_true
        while_loop_layer = While_Loop_CustomLayer()
        iter_cnt_final, _, v_final, scan_outputs_final = while_loop_layer(
            cond=condition,
            body=run_subgraph,
            loop_vars=[
                iter_cnt_init,
                cond_init,
                v_init,
                scan_outputs_init,
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape(None),
                v_shapes,
                scan_outputs_shapes,
            ],
            maximum_iterations=M,
        )
    elif M is None and cond_provided:
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(cond)
        while_loop_layer = While_Loop_CustomLayer()
        iter_cnt_final, cond_final, v_final, scan_outputs_final = while_loop_layer(
            cond=condition,
            body=run_subgraph,
            loop_vars=[
                iter_cnt_init,
                cond_init,
                v_init,
                scan_outputs_init,
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape(None),
                v_shapes,
                scan_outputs_shapes,
            ],
        )
    else:
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(cond)
        while_loop_layer = While_Loop_CustomLayer()
        iter_cnt_final, cond_final, v_final, scan_outputs_final = while_loop_layer(
            cond=condition,
            body=run_subgraph,
            loop_vars=[
                iter_cnt_init,
                cond_init,
                v_init,
                scan_outputs_init,
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape(None),
                v_shapes,
                scan_outputs_shapes,
            ],
            maximum_iterations=M,
        )

    if scan_outputs_start_index == len(body.outputs):
        final_outputs = list(v_final)
    else:
        def true_fn():
            return scan_outputs_final

        def false_fn():
            empty_scan_outputs = []
            for ta in scan_outputs_init:
                empty_scan_outputs.append(
                    tf.TensorArray(
                        dtype=ta.dtype,
                        size=0,
                        element_shape=ta.element_shape,
                    )
                )
            return empty_scan_outputs

        scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn, false_fn)
        scan_outputs_tensors = [o.stack() for o in scan_out_final]
        final_outputs = list(v_final) + scan_outputs_tensors

    if len(final_outputs) != len(graph_node_outputs):
        error(
            f'Loop output count mismatch. expected={len(graph_node_outputs)} actual={len(final_outputs)}\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    for idx, (graph_node_output, output_tensor) in enumerate(zip(graph_node_outputs, final_outputs)):
        tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor
        if idx < len(v_init):
            body_output = body.outputs[1 + idx]
        else:
            body_output = body.outputs[scan_outputs_start_index + (idx - len(v_init))]
        body_meta = tf_layers_dict.get(body_output.name, {})
        for key in ('before_op_output_shape_trans', 'nhwc'):
            if key in body_meta:
                tf_layers_dict[graph_node_output.name][key] = body_meta[key]

    tf_outputs = {f"output{idx}": value for idx, value in enumerate(final_outputs)}
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
            make_tf_node_info(
                node_info={
                    'tf_op_type': tf.while_loop,
                    'tf_inputs': {
                        'condition': condition,
                        'M': M,
                        'cond': cond_init,
                        'v_initial': v_init,
                    },
                    'tf_outputs': {
                        'output': tf_outputs,
                    },
                }
            )
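The new Loop.py lowers ONNX Loop onto tf.while_loop: the loop variables are an iteration counter, the boolean condition, the loop-carried values, and one TensorArray per scan output; shape invariants are fully relaxed so carried shapes may change between iterations; and the trip-count input M is enforced via maximum_iterations. A stripped-down, runnable illustration of that construction follows; all values are made up, and the body here is a stand-in for executing the ONNX body subgraph:

import tensorflow as tf

M = tf.constant(5, dtype=tf.int32)                  # maximum trip count
v0 = tf.constant([1.0, 2.0])                        # one loop-carried value
scan0 = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

def body(iter_cnt, cond, v, scan):
    v = v + 1.0                                     # stand-in for the ONNX body subgraph
    scan = scan.write(scan.size(), v)               # accumulate this step's scan output
    return [iter_cnt + 1, tf.constant(True), v, scan]

iter_cnt, cond, v_final, scan_final = tf.while_loop(
    cond=lambda iter_cnt, cond, v, scan: tf.reduce_all(cond),
    body=body,
    loop_vars=[tf.constant(0, tf.int32), tf.constant(True), v0, scan0],
    shape_invariants=[
        tf.TensorShape([]),                         # scalar counter
        tf.TensorShape(None),                       # condition
        tf.TensorShape([None]),                     # carried value may change shape
        tf.TensorShape(None),                       # TensorArray
    ],
    maximum_iterations=M,
)
print(scan_final.stack())                           # shape (5, 2): one row of v per iteration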