onnx2tf 1.29.8__py3-none-any.whl → 1.29.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +83 -73
- onnx2tf/ops/If.py +4 -2
- onnx2tf/ops/Loop.py +392 -0
- onnx2tf/ops/LpPool.py +296 -0
- onnx2tf/ops/MaxRoiPool.py +236 -0
- onnx2tf/ops/Unsqueeze.py +69 -37
- onnx2tf/utils/common_functions.py +3 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/METADATA +6 -6
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/RECORD +14 -13
- onnx2tf/ops/_Loop.py +0 -306
- onnx2tf/ops/__Loop.py +0 -509
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/licenses/LICENSE +0 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/licenses/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.29.8.dist-info → onnx2tf-1.29.10.dist-info}/top_level.txt +0 -0
onnx2tf/ops/__Loop.py
DELETED
|
@@ -1,509 +0,0 @@
|
|
|
1
|
-
import re
|
|
2
|
-
import sys
|
|
3
|
-
import random
|
|
4
|
-
random.seed(0)
|
|
5
|
-
import numpy as np
|
|
6
|
-
np.random.seed(0)
|
|
7
|
-
import tensorflow as tf
|
|
8
|
-
import tf_keras
|
|
9
|
-
import onnx_graphsurgeon as gs
|
|
10
|
-
from onnx2tf.utils.common_functions import (
|
|
11
|
-
get_constant_or_variable,
|
|
12
|
-
print_node_info,
|
|
13
|
-
inverted_operation_enable_disable,
|
|
14
|
-
make_tf_node_info,
|
|
15
|
-
)
|
|
16
|
-
from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
|
|
17
|
-
import importlib
|
|
18
|
-
from onnx2tf.utils.logging import *
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class While_Loop_CustomLayer(tf_keras.layers.Layer):
    """Keras layer that simply delegates to ``tf.while_loop``.

    In this module it is used as a workaround for running ``tf.while_loop``
    on Keras symbolic inputs.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer options (name, dtype, trainable, ...) instead
        # of silently rejecting them.
        super().__init__(**kwargs)

    def call(
        self,
        cond,
        body,
        loop_vars,
        shape_invariants=None,
        maximum_iterations=None,
    ):
        """Run ``tf.while_loop`` with the given arguments.

        Parameters
        ----------
        cond: callable
            Loop-continuation predicate, called with the loop variables.
        body: callable
            Loop body, called with the loop variables; must return values
            with the same structure as ``loop_vars``.
        loop_vars:
            Initial loop variables.
        shape_invariants: optional
            Shape invariants for the loop variables. Defaults to ``None``
            (the ``tf.while_loop`` default), so callers may omit it.
        maximum_iterations: optional
            Upper bound on the number of iterations. Defaults to ``None``
            (the ``tf.while_loop`` default). Previously this parameter was
            required, so a call that omitted it raised ``TypeError``.

        Returns
        -------
        The final loop variables, as returned by ``tf.while_loop``.
        """
        return tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=loop_vars,
            shape_invariants=shape_invariants,
            maximum_iterations=maximum_iterations,
        )
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
@print_node_info
@inverted_operation_enable_disable
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """Loop

    Convert an ONNX Loop node into TensorFlow ops built around
    ``tf.while_loop`` (draft implementation).

    The node's inputs are read as follows:
    - M, the trip count: ``inputs[0]``
    - cond: ``inputs[1]``
    - the loop-carried initial values: ``inputs[2:]``

    The nodes of the ``body`` subgraph attribute are converted by
    dispatching to ``onnx2tf.ops.<OpType>.make_node``. The result is
    stored in ``tf_layers_dict`` under ``graph_node.outputs[0].name``.

    NOTE(review): this looks like an unfinished draft. As written, several
    code paths below cannot complete (each one is flagged with a
    NOTE(review) comment), and only ``graph_node.outputs[0]`` is ever
    registered in ``tf_layers_dict``.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph

    **kwargs: dict
        Converter options. ``kwargs['output_signaturedefs']`` is read here,
        and all of ``kwargs`` is forwarded to the per-op ``make_node`` calls.

    Returns
    -------
    The function body returns nothing; results are written into
    ``tf_layers_dict``.
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_input_n_list = []
    for graph_node_input in graph_node.inputs[2:]:
        graph_node_input_n = get_constant_or_variable(
            graph_node_input,
            before_op_output_shape_trans,
        )
        graph_node_input_n_list.append(graph_node_input_n)

    # M is the loop termination condition (the trip count).
    # ONNX can treat M purely as a loop counter, but TensorFlow has no
    # argument that corresponds directly to M, so it must be handed over
    # beforehand as part of the graph:
    # 1. If the trip count comes from a previous OP's output variable, it
    #    is already embedded in the graph, so nothing needs to be done.
    # 2. If the trip count is a constant (np.ndarray), it is not yet
    #    embedded in the graph, so embed it.
    M = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
    M = None if isinstance(M, str) and M == "" else M
    M = tf.convert_to_tensor(M) if isinstance(M, np.ndarray) else M
    # Clamp M to int32 range. An omitted M becomes int32 max.
    # NOTE(review): both arms of this conditional expression produce a tf
    # tensor, so M is never None afterwards. The later `M is None` checks
    # are therefore always False.
    M = tf.where(
        tf.greater(M, tf.int32.max),
        tf.constant(tf.int32.max, tf.int32),
        tf.cast(M, tf.int32)
    ) if M is not None else tf.constant(tf.int32.max, tf.int32)
    M_name = None
    if not isinstance(graph_node_input_1, np.ndarray):
        graph_node_input_1.name = graph_node_input_1.name.replace(':','__')
        M_name = graph_node_input_1.name
    else:
        M_name = graph_node.inputs[0].name.replace(':','__')
    # The '_M' suffix keeps this entry from overwriting the input's own entry.
    M_name = f'{M_name}_M'
    if kwargs['output_signaturedefs']:
        M_name = re.sub('^/', 'wa/', M_name)
    tf_layers_dict[f'{M_name}'] = {
        'optype': 'Constant' if hasattr(M, 'numpy') else 'Variable',
        'shape': M.shape,
        'dtype': M.dtype,
    }
    tf_layers_dict[f'{M_name}']['tf_node'] = M

    # Loop-continuation condition value passed to the first iteration:
    # a boolean that decides whether to keep looping.
    # Default: True
    # NOTE(review): the code does not default to True; an omitted cond
    # (empty string) becomes None here.
    cond = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
    cond_init = None if isinstance(cond, str) and cond == "" else tf.cast(cond, tf.bool)
    # NOTE(review): tf.cast never returns an np.ndarray, so this line is a no-op.
    cond_init = tf.convert_to_tensor(cond_init) if isinstance(cond_init, np.ndarray) else cond_init

    cond_init_name = None
    if not isinstance(graph_node_input_2, np.ndarray):
        graph_node_input_2.name = graph_node_input_2.name.replace(':','__')
        cond_init_name = graph_node_input_2.name
    else:
        # NOTE(review): this derives the name from inputs[0] (M), not
        # inputs[1] (cond). It looks like a copy-paste slip; the distinct
        # '_cond_init' suffix avoids a collision with M_name.
        cond_init_name = graph_node.inputs[0].name.replace(':','__')
    cond_init_name = f'{cond_init_name}_cond_init'
    if kwargs['output_signaturedefs']:
        cond_init_name = re.sub('^/', 'wa/', cond_init_name)
    # NOTE(review): when cond is omitted, cond_init is None, so
    # `cond_init.shape` below raises AttributeError.
    tf_layers_dict[f'{cond_init_name}'] = {
        'optype': 'Constant' if hasattr(cond_init, 'numpy') else 'Variable',
        'shape': cond_init.shape,
        'dtype': cond_init.dtype,
    }
    tf_layers_dict[f'{cond_init_name}']['tf_node'] = cond_init

    # Loop-carried variables processed in the body, excluding the loop
    # counter. Each is either a previous OP's output or a constant.
    v_init = [
        tf_layers_dict[graph_node_input_n.name]['tf_node'] \
            if isinstance(graph_node_input_n, gs.Variable) else graph_node_input_n \
                for graph_node_input_n in graph_node_input_n_list
    ]

    # Shapes of all body variables except the loop counter, with every
    # dimension relaxed to None. They are used for shape_invariants, the
    # shape hints tf.while_loop needs when shapes may change inside the
    # loop body.
    v_shapes = [
        tf.TensorShape([None for i in range(len(v.shape))]) for v in v_init
    ]

    # Body subgraph
    body: gs.Graph = graph_node.attrs["body"]

    # Initial value of the loop counter, fixed to zero.
    iter_cnt_init = tf.convert_to_tensor(np.int32(0))

    # Body outputs are [cond, v_1..v_N, scan_1..scan_K]; scan outputs start
    # after cond plus the N loop-carried values.
    scan_outputs_start_index = 1 + len(v_init)
    scan_outputs_init = [
        tf.TensorArray(
            dtype=body.outputs[i].dtype,
            size=0,
            dynamic_size=True
        ) for i in range(scan_outputs_start_index, len(body.outputs))
    ]
    scan_outputs_shapes = [tf.TensorShape(None) for o in scan_outputs_init]

    graph_node_output: gs.Variable = graph_node.outputs[0]
    shape = graph_node_output.shape
    dtype = graph_node_output.dtype

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
    }

    # Generation of TF OP
    # NOTE(review): an earlier, commented-out draft of run_subgraph stood
    # here. It additionally wrote body.outputs[scan_outputs_start_index:]
    # into the scan_outputs TensorArrays and returned them as a 4th value.
    # The live version below does not do that.
    def run_subgraph(iter_cnt, ):
        """Convert the body subgraph's inputs, nodes and constant outputs.

        Returns (iter_cnt + 1, body.outputs[0], body.outputs[1:start]) as
        looked up in tf_layers_dict.

        NOTE(review): it takes a single argument, but the while loops below
        drive it with 4 loop variables. It also returns 3 values, which
        does not match that 4-element structure. It never binds the current
        loop variables to the body inputs either; it re-creates them via
        the Input op on every call.
        """
        for body_input in body.inputs:
            try:
                op = importlib.import_module(f'onnx2tf.ops.Input')
            except ModuleNotFoundError as ex:
                # NOTE(review): `optype` is assigned later in this function,
                # so reaching this line raises UnboundLocalError.
                error(
                    f'{optype} OP is not yet implemented.'
                )
                sys.exit(1)
            # substitution because saved_model does not allow colons
            body_input.name = body_input.name.replace(':','__')
            # Substitution because saved_model does not allow leading slashes in op names
            if kwargs['output_signaturedefs']:
                body_input.name = re.sub('^/', 'wa/', body_input.name)
            # NOTE(review): if kwargs already contains any of the keep_*
            # keys, this call raises TypeError (duplicate keyword). Confirm
            # what the caller passes.
            op.make_node(
                graph_input=body_input,
                tf_layers_dict=tf_layers_dict,
                keep_ncw_or_nchw_or_ncdhw_input_names=[],
                keep_nwc_or_nhwc_or_ndhwc_input_names=[],
                keep_shape_absolutely_input_names=[],
                **kwargs,
            )
        for body_node in body.nodes:
            optype = body_node.op
            try:
                op = importlib.import_module(f'onnx2tf.ops.{optype}')
            except ModuleNotFoundError as ex:
                error(
                    f'{optype} OP is not yet implemented.'
                )
                sys.exit(1)
            # substitution because saved_model does not allow colons
            body_node.name = body_node.name.replace(':','__')
            # Substitution because saved_model does not allow leading slashes in op names
            if kwargs['output_signaturedefs']:
                body_node.name = re.sub('^/', 'wa/', body_node.name)
            op.make_node(
                graph_node=body_node,
                tf_layers_dict=tf_layers_dict,
                **kwargs,
            )
        # Register body outputs that are constants and were not produced by
        # any node.
        for output in body.outputs:
            if output.name not in tf_layers_dict and isinstance(output, gs.Constant):
                tf_layers_dict[output.name] = {
                    'optype': 'Constant',
                    'shape': output.values.shape,
                    'dtype': output.values.dtype,
                }
                tf_layers_dict[output.name]['tf_node'] = \
                    tf.constant(
                        output.values,
                        dtype=NUMPY_DTYPES_TO_TF_DTYPES[output.values.dtype],
                    )
        outputs = [tf_layers_dict[output.name]['tf_node'] for output in body.outputs]
        iter_cnt += 1
        return iter_cnt, outputs[0], outputs[1:scan_outputs_start_index]

    # Register v_initial
    # 1. The list of variables processed by the Loop OP.
    # 2. Each variable is either the preceding OP itself or a constant.

    # Register body - Inputs
    # NOTE(review): this section and the two after it convert the body
    # once, outside any loop. They duplicate what run_subgraph does.
    for body_input in body.inputs:
        try:
            op = importlib.import_module(f'onnx2tf.ops.Input')
        except ModuleNotFoundError as ex:
            # NOTE(review): `optype` is assigned later in make_node, so
            # reaching this line raises UnboundLocalError.
            error(
                f'{optype} OP is not yet implemented.'
            )
            sys.exit(1)
        # substitution because saved_model does not allow colons
        body_input.name = body_input.name.replace(':','__')
        # Substitution because saved_model does not allow leading slashes in op names
        if kwargs['output_signaturedefs']:
            body_input.name = re.sub('^/', 'wa/', body_input.name)
        op.make_node(
            graph_input=body_input,
            tf_layers_dict=tf_layers_dict,
            keep_ncw_or_nchw_or_ncdhw_input_names=[],
            keep_nwc_or_nhwc_or_ndhwc_input_names=[],
            keep_shape_absolutely_input_names=[],
            **kwargs,
        )
    # Register body - Nodes
    for body_node in body.nodes:
        optype = body_node.op
        try:
            op = importlib.import_module(f'onnx2tf.ops.{optype}')
        except ModuleNotFoundError as ex:
            error(
                f'{optype} OP is not yet implemented.'
            )
            sys.exit(1)
        # substitution because saved_model does not allow colons
        body_node.name = body_node.name.replace(':','__')
        # Substitution because saved_model does not allow leading slashes in op names
        if kwargs['output_signaturedefs']:
            body_node.name = re.sub('^/', 'wa/', body_node.name)
        op.make_node(
            graph_node=body_node,
            tf_layers_dict=tf_layers_dict,
            **kwargs,
        )
    # Register body - Constant outputs
    for output in body.outputs:
        if output.name not in tf_layers_dict and isinstance(output, gs.Constant):
            tf_layers_dict[output.name] = {
                'optype': 'Constant',
                'shape': output.values.shape,
                'dtype': output.values.dtype,
            }
            tf_layers_dict[output.name]['tf_node'] = \
                tf.constant(
                    output.values,
                    dtype=NUMPY_DTYPES_TO_TF_DTYPES[output.values.dtype],
                )

    # NOTE(review): leftover experiment. tf.while_loop calls cond/body with
    # one argument per loop variable. Here there is a single loop variable,
    # but both functions take (i, x), so this call raises TypeError and
    # make_node cannot get past this point.
    def condition(i, x):
        return tf.less(i, tf_layers_dict[f'{M_name}']['tf_node'])

    def body_graph(i, x):
        return tf.add(i, 1),

    a = tf.while_loop(
        cond=condition,
        body=body_graph,
        loop_vars=[iter_cnt_init],
    )

    # for loop
    # https://stackoverflow.com/questions/71635459/how-to-use-keras-symbolic-inputs-with-tf-while-loop
    # NOTE(review): this branch needs cond_init to be None, but that case
    # already failed at the `cond_init.shape` access above.
    if M is not None and cond_init is None:
        condition = lambda iter_cnt, cond, v, scan_outputs: True
        while_loop_layer = While_Loop_CustomLayer()
        iter_cnt_final, _, v_final, scan_outputs_final = while_loop_layer(
            cond=condition,
            body=run_subgraph,
            loop_vars=[
                iter_cnt_init,
                "",
                v_init,
                scan_outputs_init,
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape(None),
                v_shapes,
                scan_outputs_shapes,
            ],
            maximum_iterations=M,
        )
    # while and do-while loop
    # https://stackoverflow.com/questions/71635459/how-to-use-keras-symbolic-inputs-with-tf-while-loop
    # NOTE(review): unreachable, because M is never None at this point.
    # maximum_iterations is also not passed here; this call only works if
    # While_Loop_CustomLayer.call gives it a default. Confirm.
    elif M is None and cond_init is not None:
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(tf.equal(cond, True))
        while_loop_layer = While_Loop_CustomLayer()
        iter_cnt_final, cond_final, v_final, scan_outputs_final = while_loop_layer(
            cond=condition,
            body=run_subgraph,
            loop_vars=[
                iter_cnt_init,
                cond_init,
                v_init,
                scan_outputs_init,
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape(None),
                v_shapes,
                scan_outputs_shapes,
            ],
        )
    # combine for loop and while loop together
    # https://stackoverflow.com/questions/71635459/how-to-use-keras-symbolic-inputs-with-tf-while-loop
    elif M is not None and cond_init is not None:
        # NOTE(review): a commented-out attempt that used
        # While_Loop_CustomLayer with maximum_iterations=M stood here.
        #
        # NOTE(review): debugging leftovers.
        # - tf.while_loop requires `cond` to be a callable, but a
        #   bool/ndarray is passed.
        # - `test` and `a` are unused.
        # - v_final / scan_outputs_final are never assigned on this
        #   branch, so the code after the if/elif chain would raise
        #   NameError.


        test = tf.while_loop(
            cond=cond_init.numpy() \
                if hasattr(cond_init, 'numpy') else cond_init,
            body=run_subgraph,
            loop_vars=[iter_cnt_init, *v_init]
        )
        a = 0



    # M is None and cond is None
    else:
        error(
            f'Both M and cond in Loop are not set at the same time ' +
            f'Tensorflow.(PS. if you want to create a do-while loop ' +
            f'then please set cond to True or 1)\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)


    if scan_outputs_start_index == len(body.outputs):
        # there is no scan_output in the body graph
        tf_layers_dict[graph_node_output.name]['tf_node'] = v_final

    else:
        def true_fn():
            return scan_outputs_final

        def false_fn():
            # Zero iterations ran: build empty TensorArrays whose element
            # shape replaces unknown dims with 0.
            # NOTE(review): scan_outputs_init was created without an
            # element_shape, so its rank is presumably unknown (None) and
            # range(None) would raise TypeError. Confirm.
            new_scan_outputs = []
            for i in range(scan_outputs_start_index, len(body.outputs)):
                exp_elem_shape = scan_outputs_init[i-scan_outputs_start_index].element_shape
                elem_shape = []
                for j in range(exp_elem_shape.rank):
                    shape_j = 0 if exp_elem_shape[j] is None else exp_elem_shape[j]
                    elem_shape.append(shape_j)
                new_scan_outputs.append(
                    tf.TensorArray(
                        dtype=body.outputs[i].dtype,
                        size=0,
                        element_shape=tf.TensorShape(elem_shape)
                    )
                )
            return new_scan_outputs

        scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn, false_fn)
        scan_outputs_tensors = [o.stack() for o in scan_out_final]
        # NOTE(review): the list of all final values and stacked scan
        # outputs is stored under outputs[0] only. graph_node.outputs[1:]
        # are never registered.
        tf_layers_dict[graph_node_output.name]['tf_node'] = v_final + scan_outputs_tensors

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': tf.while_loop,
                'tf_inputs': {
                    'condition': condition,
                    'M': M,
                    'cond': cond_init,
                    'v_initial': v_init,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|