onnx2tf 1.29.8__py3-none-any.whl → 1.29.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +83 -73
- onnx2tf/ops/If.py +4 -2
- onnx2tf/ops/Loop.py +392 -0
- onnx2tf/ops/LpPool.py +296 -0
- onnx2tf/ops/MaxRoiPool.py +236 -0
- onnx2tf/ops/Unsqueeze.py +69 -37
- onnx2tf/utils/common_functions.py +3 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/METADATA +6 -6
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/RECORD +14 -13
- onnx2tf/ops/_Loop.py +0 -306
- onnx2tf/ops/__Loop.py +0 -509
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/licenses/LICENSE +0 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/licenses/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/top_level.txt +0 -0
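
The functional core of this release is the set of new operator converters (Loop.py, LpPool.py, MaxRoiPool.py) that replace the experimental _Loop.py/__Loop.py modules. These converters are driven through the same conversion entry point as every other op; below is a minimal, hedged sketch using onnx2tf's documented Python API, where model_with_loop.onnx is a hypothetical model containing one of the newly supported ops:

    # Hedged sketch: convert a hypothetical ONNX model that uses Loop / LpPool / MaxRoiPool
    # into a TensorFlow SavedModel under ./saved_model via onnx2tf's Python API.
    import onnx2tf

    onnx2tf.convert(
        input_onnx_file_path="model_with_loop.onnx",  # placeholder input model
        output_folder_path="saved_model",             # output directory for the converted model
    )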
{onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx2tf
-Version: 1.29.8
+Version: 1.29.10
 Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC).
 Home-page: https://github.com/PINTO0309/onnx2tf
 Author: Katsuya Hyodo
@@ -182,16 +182,16 @@ https://github.com/PINTO0309/onnx2tf/wiki/model_status
 |Less|:heavy_check_mark:|
 |Log|:heavy_check_mark:|
 |LogSoftmax|:heavy_check_mark:|
-|Loop
+|Loop|:heavy_check_mark:|
 |LpNormalization|:heavy_check_mark:|
-|LpPool
+|LpPool|:heavy_check_mark:|
 |LRN|:heavy_check_mark:|
 |LSTM|:heavy_check_mark:|
 |MatMul|:heavy_check_mark:|
 |MatMulInteger|:heavy_check_mark:|
 |MaxPool|:heavy_check_mark:|
 |Max|:heavy_check_mark:|
-|MaxRoiPool
+|MaxRoiPool|:heavy_check_mark:|
 |MaxUnpool|:heavy_check_mark:|
 |Mean|:heavy_check_mark:|
 |MeanVarianceNormalization|:heavy_check_mark:|
@@ -359,7 +359,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
 docker run --rm -it \
 -v `pwd`:/workdir \
 -w /workdir \
-ghcr.io/pinto0309/onnx2tf:1.29.8
+ghcr.io/pinto0309/onnx2tf:1.29.10
 
 or
 
@@ -367,7 +367,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
 docker run --rm -it \
 -v `pwd`:/workdir \
 -w /workdir \
-docker.io/pinto0309/onnx2tf:1.29.8
+docker.io/pinto0309/onnx2tf:1.29.10
 
 or
 
{onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/RECORD

@@ -1,6 +1,6 @@
-onnx2tf/__init__.py,sha256=
+onnx2tf/__init__.py,sha256=E1LYdyUQ9pnnydTNI6NtTfnt7AZXaoxY42NlhsS5Jr0,67
 onnx2tf/__main__.py,sha256=2RSCQ7d4lc6CwD-rlGn9UicPFg-P5du7ZD_yh-kuBEU,57
-onnx2tf/onnx2tf.py,sha256=
+onnx2tf/onnx2tf.py,sha256=yR8aKaEn01Q8dEeYDqHIsuZuG6l5TGQniHDlPiUROx4,152238
 onnx2tf/ops/Abs.py,sha256=V7btmCG_ZvK_qJovUsguq0ZMJ349mhNQ4FHSgzP_Yuo,4029
 onnx2tf/ops/Acos.py,sha256=Fo8YkFKuWq8Fi2xUrBdKcAH1yJ8r5pjSD0wgLttTNdk,4003
 onnx2tf/ops/Acosh.py,sha256=ATQj2cT5JS_mTfXi0kXqJ1yzSZu5J0zHA5VjV3j7uKY,3588
@@ -75,7 +75,7 @@ onnx2tf/ops/HardSigmoid.py,sha256=KDP_t-Z70sDsHMOYxyJ7ZNH31zqkrViOKYCcRG5NJHc,36
 onnx2tf/ops/HardSwish.py,sha256=nEng3LCDQYMZ4XhFZ7pXKGyRsM2_waowi8PlZt_f6Ck,3994
 onnx2tf/ops/Hardmax.py,sha256=tiMch3Tuc8Rvy52hgGSfqfOVyXaEsnxYplRMy7vtpyA,4398
 onnx2tf/ops/Identity.py,sha256=egudADqdhe4BiunYHUTh-AlDAkPpRESRT2eG0Q4rBts,2425
-onnx2tf/ops/If.py,sha256=
+onnx2tf/ops/If.py,sha256=Z3VEMm1mOKomYl1Mw58shc83kNPZsYs-wvhse7PlfTY,7062
 onnx2tf/ops/Input.py,sha256=aRZQ4uLWmMS3q317wZO68qqks8p3QDOINhTEObAhvvY,16225
 onnx2tf/ops/InstanceNormalization.py,sha256=gUixsJ1105tt8UGwoLLdZ4V95GiZwzHm_jJMugqQ1yQ,11997
 onnx2tf/ops/Inverse.py,sha256=YsRs0mpZg6dXWXnM1-UU5PcaUvrUqLmDDCNFpirXqp4,4595
@@ -89,11 +89,14 @@ onnx2tf/ops/Less.py,sha256=YZp5u3cUMU9Gcv_JVqPSIeuaIzVlU0hKy0PnvE6BXFo,4576
 onnx2tf/ops/LessOrEqual.py,sha256=9Lc8qaYUPVC6yZoQluNqcdHnvpUbfWBOI4Ow38RRAJo,4595
 onnx2tf/ops/Log.py,sha256=UZebF3SGq85BnoPgYyN2j-zzFRp67fJnYPNyu33W55o,3582
 onnx2tf/ops/LogSoftmax.py,sha256=j2nhYY7__8ViLFJVLA5tS98QEvGS1gTIW0QCdnZWUPQ,3923
+onnx2tf/ops/Loop.py,sha256=I32CWoex8FMXm9KE2aomADB4jK5BzaMoAKvtPnBJy6A,14593
 onnx2tf/ops/LpNormalization.py,sha256=Uu15HgxFNXb6gNMgdTJyf0SLPaLbcbkOYqY_4hMBxNA,3153
+onnx2tf/ops/LpPool.py,sha256=96eI1FaDgW0M_USWBCHFedvtojHTLL28_lb3mcEV55A,10470
 onnx2tf/ops/MatMul.py,sha256=KHhRyQCyxe6845f-AOI1UJzA3rGTssG6eyKmDw0oegs,21466
 onnx2tf/ops/MatMulInteger.py,sha256=qHqzdJNI9SeJDbW8pR90baYCdGN6FdOez4hi9EzwXoc,6538
 onnx2tf/ops/Max.py,sha256=w5nMciO_6ApYUobHuwMGuS3xhuza7eSvKDRhvMPgAuo,3256
 onnx2tf/ops/MaxPool.py,sha256=_JC4eqBTh-qLkZCMG8RZhthRZ8D2d821zaFMWeGMEWc,15775
+onnx2tf/ops/MaxRoiPool.py,sha256=RYZyjnINqJd6k7KLFJ-D9iHjA2vR-m7WvhrumD9cmDk,8490
 onnx2tf/ops/MaxUnpool.py,sha256=dGIEvC45rFuWoeG1S9j4sjEdEUqiWs_xdY5DZH6X7b4,5743
 onnx2tf/ops/Mean.py,sha256=xfTjKpQntJB8uXAkgDLS77oLXy2yQh1Hlz0K2GltMh0,3132
 onnx2tf/ops/MeanVarianceNormalization.py,sha256=Ne53jlDgAJZ9yhzKOWR-0LnjDdM-fg7DYmRytoP-4IA,3743
@@ -185,22 +188,20 @@ onnx2tf/ops/TopK.py,sha256=f6OG-DcMWneXwSjIkmY935SPyOMD5tMteHnlQHoJwQo,6348
 onnx2tf/ops/Transpose.py,sha256=GwJFp7zVqodEsv5mGWviuFqeK93uVM7dbRQ1N8Ua1hg,9774
 onnx2tf/ops/Trilu.py,sha256=uz2TgdErpo9GDp9n4PCe0_koIpNLgBoPCjv3A6VBTl8,4789
 onnx2tf/ops/Unique.py,sha256=GUuOeTO9px22dHmlAn2SOmRHvBgSXo-SaPWm5rYUtPc,4084
-onnx2tf/ops/Unsqueeze.py,sha256=
+onnx2tf/ops/Unsqueeze.py,sha256=UJun_DYfg7aQaHoeAvWlB85oRtDWq2lP7kvb0njcaC0,12219
 onnx2tf/ops/Upsample.py,sha256=SX3N_wZHD8G5Z0PLcPgX1ZCzOdct-uTzxKeMhhzeBOw,5304
 onnx2tf/ops/Where.py,sha256=MaCcY9g4mKZQqCgh4xtoylicP-xVu9f4boKiu_q9Ow8,7711
 onnx2tf/ops/Xor.py,sha256=2ceqxHSI1Wtez_CIh8gFfvcu45Xboqfyp1iy3v2vuIs,4590
-onnx2tf/ops/_Loop.py,sha256=eo5sNfrfOnKV6_I737AWsM5LJTY9DVOxQEvhanxtP4g,11322
-onnx2tf/ops/__Loop.py,sha256=ClwMcbNS4hqUtW_pzwjMa9Cqg7ONvz9aplke55A0uJ0,19704
 onnx2tf/ops/__init__.py,sha256=jnmUWWa-3dHzBZV9bmPzXu6eoz2dumJTzO7i8JdcgSM,25
 onnx2tf/utils/__init__.py,sha256=E9FM9He68VIASDnYp-OrxvHFVn55GzWqw2OEkCqn1zg,27
-onnx2tf/utils/common_functions.py,sha256=
+onnx2tf/utils/common_functions.py,sha256=TWb_e6i2MjB7C4eh1FWHTIDVlr6-7NgSNcCKwKGhGg8,249765
 onnx2tf/utils/enums.py,sha256=7c5TqetqB07VjyHoxJHfLgtqBqk9ZRyUF33fPOJR1IM,1649
 onnx2tf/utils/iterative_json_optimizer.py,sha256=qqeIxWGxrhcCYk8-ebWnblnOkzDCwi-nseipHzHR_bk,10436
 onnx2tf/utils/json_auto_generator.py,sha256=OC-SfKtUg7zUxaXTAg6kT0ShzIc3ByjDa3FNp173DtA,60302
 onnx2tf/utils/logging.py,sha256=yUCmPuJ_XiUItM3sZMcaMO24JErkQy7zZwVTYWAuiKg,1982
-onnx2tf-1.29.
-onnx2tf-1.29.
-onnx2tf-1.29.
-onnx2tf-1.29.
-onnx2tf-1.29.
-onnx2tf-1.29.
+onnx2tf-1.29.10.dist-info/licenses/LICENSE,sha256=5v_Kxihy8i6mzHVl349ikSREaIdsl9YeUnX1KBDLD2w,1070
+onnx2tf-1.29.10.dist-info/licenses/LICENSE_onnx-tensorflow,sha256=gK4GtS9S5YcyINu6uuNNWdo-kBClyEM4MFLFGiNTeRM,11231
+onnx2tf-1.29.10.dist-info/METADATA,sha256=LMFqdTpJPlqtKrJ0jmOeQh5PKMyUy_XDbSPDcARueyE,153516
+onnx2tf-1.29.10.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+onnx2tf-1.29.10.dist-info/top_level.txt,sha256=WgfPiEy3f6vZ_FOpAIEA2CF3TCx1eYrhGw93Ih6b9Fw,8
+onnx2tf-1.29.10.dist-info/RECORD,,
onnx2tf/ops/_Loop.py DELETED

@@ -1,306 +0,0 @@
-import re
-import sys
-import random
-random.seed(0)
-import numpy as np
-np.random.seed(0)
-import tensorflow as tf
-import tf_keras
-import onnx_graphsurgeon as gs
-from onnx2tf.utils.common_functions import (
-    get_constant_or_variable,
-    print_node_info,
-    inverted_operation_enable_disable,
-    make_tf_node_info,
-)
-from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
-import importlib
-from onnx2tf.utils.logging import *
-
-
-class While_Loop_CustomLayer(tf_keras.layers.Layer):
-    def __init__(self):
-        super(While_Loop_CustomLayer, self).__init__()
-
-    def call(self, cond, body, loop_vars, shape_invariants, maximum_iterations):
-        return tf.while_loop(
-            cond=cond,
-            body=body,
-            loop_vars=loop_vars,
-            shape_invariants=shape_invariants,
-            maximum_iterations=maximum_iterations,
-        )
-
-
-@print_node_info
-@inverted_operation_enable_disable
-def make_node(
-    *,
-    graph_node: gs.Node,
-    tf_layers_dict: dict,
-    **kwargs: dict,
-):
-    """Loop
-
-    Parameters
-    ----------
-    graph_node: gs.Node
-        graph_surgeon Node
-
-    tf_layers_dict: dict
-        optype, shape, dtype, tensorflow graph
-    """
-    before_op_output_shape_trans_1 = \
-        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
-    before_op_output_shape_trans_2 = \
-        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
-    before_op_output_shape_trans = \
-        before_op_output_shape_trans_1 \
-        and before_op_output_shape_trans_2
-
-    graph_node_input_1 = get_constant_or_variable(
-        graph_node.inputs[0],
-        before_op_output_shape_trans,
-    )
-    graph_node_input_2 = get_constant_or_variable(
-        graph_node.inputs[1],
-        before_op_output_shape_trans,
-    )
-    graph_node_input_n_list = []
-    for graph_node_input in graph_node.inputs[2:]:
-        graph_node_input_n = get_constant_or_variable(
-            graph_node_input,
-            before_op_output_shape_trans,
-        )
-        graph_node_input_n_list.append(graph_node_input_n)
-
-    M = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
-        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
-    M = None if isinstance(M, str) and M == "" else M
-    M = tf.where(
-        tf.greater(M, tf.int32.max),
-        tf.constant(tf.int32.max, tf.int32),
-        tf.cast(M, tf.int32)
-    ) if M is not None else M
-    cond = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
-        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
-    cond_init = None if isinstance(cond, str) and cond == "" else tf.cast(cond, tf.bool)
-
-    v_init = [
-        tf_layers_dict[graph_node_input_n.name]['tf_node'] \
-            if isinstance(graph_node_input_n, gs.Variable) else graph_node_input_n \
-                for graph_node_input_n in graph_node_input_n_list
-    ]
-    v_shapes = [
-        tf.TensorShape([None for i in range(len(v.shape))]) for v in v_init
-    ]
-
-    body: gs.Graph = graph_node.attrs["body"]
-
-    iter_cnt_init = np.int64(0)
-
-    scan_outputs_start_index = 1 + len(v_init)
-    scan_outputs_init = [
-        tf.TensorArray(
-            dtype=body.outputs[i].dtype,
-            size=0,
-            dynamic_size=True
-        ) for i in range(scan_outputs_start_index, len(body.outputs))
-    ]
-    scan_outputs_shapes = [tf.TensorShape(None) for o in scan_outputs_init]
-
-    graph_node_output: gs.Variable = graph_node.outputs[0]
-    shape = graph_node_output.shape
-    dtype = graph_node_output.dtype
-
-    # Preserving Graph Structure (Dict)
-    tf_layers_dict[graph_node_output.name] = {
-        'optype': graph_node.op,
-        'shape': shape,
-        'dtype': dtype,
-    }
-
-    # Generation of TF OP
-    def run_subgraph(iter_cnt, cond, v, scan_outputs):
-        for body_input in body.inputs:
-            try:
-                op = importlib.import_module(f'onnx2tf.ops.Input')
-            except ModuleNotFoundError as ex:
-                error(
-                    f'{optype} OP is not yet implemented.'
-                )
-                sys.exit(1)
-            # substitution because saved_model does not allow colons
-            body_input.name = body_input.name.replace(':','__')
-            # Substitution because saved_model does not allow leading slashes in op names
-            if kwargs['output_signaturedefs']:
-                body_input.name = re.sub('^/', 'wa/', body_input.name)
-            op.make_node(
-                graph_input=body_input,
-                tf_layers_dict=tf_layers_dict,
-                keep_ncw_or_nchw_or_ncdhw_input_names=[],
-                keep_nwc_or_nhwc_or_ndhwc_input_names=[],
-                keep_shape_absolutely_input_names=[],
-                **kwargs,
-            )
-        for body_node in body.nodes:
-            optype = body_node.op
-            try:
-                op = importlib.import_module(f'onnx2tf.ops.{optype}')
-            except ModuleNotFoundError as ex:
-                error(
-                    f'{optype} OP is not yet implemented.'
-                )
-                sys.exit(1)
-            # substitution because saved_model does not allow colons
-            body_node.name = body_node.name.replace(':','__')
-            # Substitution because saved_model does not allow leading slashes in op names
-            if kwargs['output_signaturedefs']:
-                body_node.name = re.sub('^/', 'wa/', body_node.name)
-            op.make_node(
-                graph_node=body_node,
-                tf_layers_dict=tf_layers_dict,
-                **kwargs,
-            )
-        # Resister constant
-        for output in body.outputs:
-            if output.name not in tf_layers_dict and isinstance(output, gs.Constant):
-                tf_layers_dict[output.name] = {
-                    'optype': 'Constant',
-                    'shape': output.values.shape,
-                    'dtype': output.values.dtype,
-                }
-                tf_layers_dict[output.name]['tf_node'] = \
-                    tf.constant(
-                        output.values,
-                        dtype=NUMPY_DTYPES_TO_TF_DTYPES[output.values.dtype],
-                    )
-        outputs = [tf_layers_dict[output.name]['tf_node'] for output in body.outputs]
-        for i in range(scan_outputs_start_index, len(outputs)):
-            s_index = i - scan_outputs_start_index
-            insert_index = scan_outputs[s_index].size()
-            scan_outputs[s_index] = scan_outputs[s_index].write(insert_index, outputs[i])
-        iter_cnt += 1
-        return iter_cnt, outputs[0], outputs[1:scan_outputs_start_index], scan_outputs
-
-    # for loop
-    # https://stackoverflow.com/questions/71635459/how-to-use-keras-symbolic-inputs-with-tf-while-loop
-    if M is not None and cond_init is None:
-        condition = lambda iter_cnt, cond, v, scan_outputs: True
-        while_loop_layer = While_Loop_CustomLayer()
-        iter_cnt_final, _, v_final, scan_outputs_final = while_loop_layer(
-            cond=condition,
-            body=run_subgraph,
-            loop_vars=[
-                iter_cnt_init,
-                "",
-                v_init,
-                scan_outputs_init,
-            ],
-            shape_invariants=[
-                tf.TensorShape([]),
-                tf.TensorShape(None),
-                v_shapes,
-                scan_outputs_shapes,
-            ],
-            maximum_iterations=M,
-        )
-    # while and do-while loop
-    # https://stackoverflow.com/questions/71635459/how-to-use-keras-symbolic-inputs-with-tf-while-loop
-    elif M is None and cond_init is not None:
-        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(tf.equal(cond, True))
-        while_loop_layer = While_Loop_CustomLayer()
-        iter_cnt_final, cond_final, v_final, scan_outputs_final = while_loop_layer(
-            cond=condition,
-            body=run_subgraph,
-            loop_vars=[
-                iter_cnt_init,
-                cond_init,
-                v_init,
-                scan_outputs_init,
-            ],
-            shape_invariants=[
-                tf.TensorShape([]),
-                tf.TensorShape(None),
-                v_shapes,
-                scan_outputs_shapes,
-            ],
-        )
-    # combine for loop and while loop together
-    # https://stackoverflow.com/questions/71635459/how-to-use-keras-symbolic-inputs-with-tf-while-loop
-    elif M is not None and cond_init is not None:
-        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(tf.equal(cond, True))
-        while_loop_layer = While_Loop_CustomLayer()
-        iter_cnt_final, cond_final, v_final, scan_outputs_final = while_loop_layer(
-            cond=condition,
-            body=run_subgraph,
-            loop_vars=[
-                tf.constant(iter_cnt_init, dtype=iter_cnt_init.dtype),
-                cond_init,
-                v_init,
-                scan_outputs_init,
-            ],
-            shape_invariants=[
-                tf.TensorShape([]),
-                tf.TensorShape(None),
-                v_shapes,
-                scan_outputs_shapes,
-            ],
-            maximum_iterations=M,
-        )
-    # M is None and cond is None
-    else:
-        error(
-            f'Both M and cond in Loop are not set at the same time ' +
-            f'Tensorflow.(PS. if you want to create a do-while loop ' +
-            f'then please set cond to True or 1)\n' +
-            f'graph_node.name: {graph_node.name}'
-        )
-        sys.exit(1)
-
-
-    if scan_outputs_start_index == len(body.outputs):
-        # there is no scan_output in the body graph
-        tf_layers_dict[graph_node_output.name]['tf_node'] = v_final
-
-    else:
-        def true_fn():
-            return scan_outputs_final
-
-        def false_fn():
-            new_scan_outputs = []
-            for i in range(scan_outputs_start_index, len(body.outputs)):
-                exp_elem_shape = scan_outputs_init[i-scan_outputs_start_index].element_shape
-                elem_shape = []
-                for j in range(exp_elem_shape.rank):
-                    shape_j = 0 if exp_elem_shape[j] is None else exp_elem_shape[j]
-                    elem_shape.append(shape_j)
-                new_scan_outputs.append(
-                    tf.TensorArray(
-                        dtype=body.outputs[i].dtype,
-                        size=0,
-                        element_shape=tf.TensorShape(elem_shape)
-                    )
-                )
-            return new_scan_outputs
-
-        scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn, false_fn)
-        scan_outputs_tensors = [o.stack() for o in scan_out_final]
-        tf_layers_dict[graph_node_output.name]['tf_node'] = v_final + scan_outputs_tensors
-
-    # Generation of Debug Info
-    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
-        make_tf_node_info(
-            node_info={
-                'tf_op_type': tf.while_loop,
-                'tf_inputs': {
-                    'condition': condition,
-                    'M': M,
-                    'cond': cond_init,
-                    'v_initial': v_init,
-                },
-                'tf_outputs': {
-                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
-                },
-            }
-        )