onnx2tf 1.29.18__py3-none-any.whl → 1.29.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +967 -27
- onnx2tf/ops/Col2Im.py +108 -64
- onnx2tf/ops/DFT.py +245 -0
- onnx2tf/ops/DeformConv.py +399 -0
- onnx2tf/ops/GatherElements.py +25 -7
- onnx2tf/ops/GatherND.py +28 -1
- onnx2tf/ops/ScatterElements.py +25 -7
- onnx2tf/ops/ScatterND.py +45 -6
- onnx2tf/ops/TensorScatter.py +20 -6
- onnx2tf/utils/common_functions.py +99 -2
- {onnx2tf-1.29.18.dist-info → onnx2tf-1.29.20.dist-info}/METADATA +27 -5
- {onnx2tf-1.29.18.dist-info → onnx2tf-1.29.20.dist-info}/RECORD +15 -13
- {onnx2tf-1.29.18.dist-info → onnx2tf-1.29.20.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.18.dist-info → onnx2tf-1.29.20.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,399 @@
+import sys
+import random
+random.seed(0)
+import numpy as np
+np.random.seed(0)
+import tensorflow as tf
+import onnx_graphsurgeon as gs
+from onnx2tf.utils.common_functions import (
+    get_constant_or_variable,
+    get_weights_constant_or_variable,
+    print_node_info,
+    inverted_operation_enable_disable,
+    make_tf_node_info,
+    get_replacement_parameter,
+    pre_process_transpose,
+    post_process_transpose,
+    transpose_with_flexing_deterrence,
+)
+from onnx2tf.utils.logging import *
+
+INF_INDEX_VALUE: int = 4294967296
+
+
+def _to_int_tensor(value, name=None):
+    if isinstance(value, tf.Tensor):
+        return tf.cast(value, tf.int32)
+    return tf.constant(value, dtype=tf.int32, name=name)
+
+
+def _bilinear_sample_2d(
+    image,
+    coords,
+):
+    """
+    image: [N, H, W, C]
+    coords: [N, oH, oW, kH, kW, 2] in absolute coords (y, x)
+    """
+    coord_dtype = coords.dtype
+    h = tf.shape(image)[1]
+    w = tf.shape(image)[2]
+    h_f = tf.cast(h, coord_dtype)
+    w_f = tf.cast(w, coord_dtype)
+    max_y = h_f - 1.0
+    max_x = w_f - 1.0
+
+    y, x = tf.split(coords, num_or_size_splits=2, axis=-1)
+
+    y0 = tf.floor(y)
+    x0 = tf.floor(x)
+    y1 = y0 + 1.0
+    x1 = x0 + 1.0
+
+    dy = y - y0
+    dx = x - x0
+
+    w00 = (1.0 - dy) * (1.0 - dx)
+    w10 = dy * (1.0 - dx)
+    w11 = dy * dx
+    w01 = (1.0 - dy) * dx
+
+    def _in_bounds(y_idx, x_idx):
+        return tf.logical_and(
+            tf.logical_and(y_idx >= 0.0, y_idx <= max_y),
+            tf.logical_and(x_idx >= 0.0, x_idx <= max_x),
+        )
+
+    m00 = _in_bounds(y0, x0)
+    m10 = _in_bounds(y1, x0)
+    m11 = _in_bounds(y1, x1)
+    m01 = _in_bounds(y0, x1)
+
+    y0c = tf.clip_by_value(y0, 0.0, max_y)
+    x0c = tf.clip_by_value(x0, 0.0, max_x)
+    y1c = tf.clip_by_value(y1, 0.0, max_y)
+    x1c = tf.clip_by_value(x1, 0.0, max_x)
+
+    y0i = tf.cast(y0c, tf.int32)
+    x0i = tf.cast(x0c, tf.int32)
+    y1i = tf.cast(y1c, tf.int32)
+    x1i = tf.cast(x1c, tf.int32)
+
+    input_flat = tf.reshape(image, tf.stack([tf.shape(image)[0], h * w, tf.shape(image)[3]]))
+
+    def _gather(y_idx, x_idx):
+        linear = y_idx * w + x_idx
+        linear = tf.squeeze(linear, axis=-1)
+        return tf.gather(input_flat, linear, batch_dims=1)
+
+    v00 = _gather(y0i, x0i)
+    v10 = _gather(y1i, x0i)
+    v11 = _gather(y1i, x1i)
+    v01 = _gather(y0i, x1i)
+
+    m00 = tf.cast(m00, image.dtype)
+    m10 = tf.cast(m10, image.dtype)
+    m11 = tf.cast(m11, image.dtype)
+    m01 = tf.cast(m01, image.dtype)
+
+    output = w00 * m00 * v00 + w10 * m10 * v10 + w11 * m11 * v11 + w01 * m01 * v01
+    return output
+
+
+@print_node_info
+@inverted_operation_enable_disable
+@get_replacement_parameter
+def make_node(
+    *,
+    graph_node: gs.Node,
+    tf_layers_dict: dict,
+    **kwargs: dict,
+):
+    """DeformConv
+
+    Parameters
+    ----------
+    graph_node: gs.Node
+        graph_surgeon Node
+
+    tf_layers_dict: dict
+        optype, shape, dtype, tensorflow graph
+    """
+    before_op_output_shape_trans_1 = \
+        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_3 = \
+        tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_4 = \
+        tf_layers_dict.get(graph_node.inputs[3].name, {}).get('before_op_output_shape_trans', True) \
+            if len(graph_node.inputs) >= 4 else True
+    before_op_output_shape_trans_5 = \
+        tf_layers_dict.get(graph_node.inputs[4].name, {}).get('before_op_output_shape_trans', True) \
+            if len(graph_node.inputs) >= 5 else True
+
+    graph_node_input_1 = get_constant_or_variable(
+        graph_node.inputs[0],
+        before_op_output_shape_trans_1,
+    )
+
+    kernel_shape = graph_node.attrs.get('kernel_shape', [])
+    if kernel_shape == [] and graph_node.inputs[1].shape is not None:
+        kernel_shape = graph_node.inputs[1].shape[2:]
+    kernel_size = len(kernel_shape) if kernel_shape != [] else 2
+
+    graph_node_input_2 = get_weights_constant_or_variable(
+        const_or_var=graph_node.inputs[1],
+        kernel_size=kernel_size,
+    )
+    graph_node_input_3 = get_constant_or_variable(
+        graph_node.inputs[2],
+        before_op_output_shape_trans_3,
+    )
+    graph_node_input_4 = get_constant_or_variable(
+        graph_node.inputs[3],
+        before_op_output_shape_trans_4,
+    ) if len(graph_node.inputs) >= 4 else None
+    graph_node_input_5 = get_constant_or_variable(
+        graph_node.inputs[4],
+        before_op_output_shape_trans_5,
+    ) if len(graph_node.inputs) >= 5 else None
+
+    graph_node_output: gs.Variable = graph_node.outputs[0]
+    output_tensor_shape = graph_node_output.shape
+    dtype = graph_node_output.dtype
+
+    input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
+        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+    weights = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
+        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
+    offset = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
+        if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
+    bias = tf_layers_dict[graph_node_input_4.name]['tf_node'] \
+        if isinstance(graph_node_input_4, gs.Variable) else graph_node_input_4
+    mask = tf_layers_dict[graph_node_input_5.name]['tf_node'] \
+        if isinstance(graph_node_input_5, gs.Variable) else graph_node_input_5
+
+    input_tensor_shape = input_tensor.shape
+
+    if input_tensor_shape is not None and len(input_tensor_shape) != 4:
+        error('DeformConv currently supports only 2D inputs (N, C, H, W).')
+        sys.exit(1)
+
+    # Preserving Graph Structure (Dict)
+    tf_layers_dict[graph_node_output.name] = {
+        'optype': graph_node.op,
+        'shape': output_tensor_shape,
+        'dtype': dtype,
+        'nhwc': True,
+    }
+
+    # Pre-process transpose
+    input_tensor = pre_process_transpose(
+        value_before_transpose=input_tensor,
+        param_target='inputs',
+        param_name=graph_node.inputs[0].name,
+        **kwargs,
+    )
+    offset = pre_process_transpose(
+        value_before_transpose=offset,
+        param_target='inputs',
+        param_name=graph_node.inputs[2].name,
+        **kwargs,
+    )
+    if mask is not None:
+        mask = pre_process_transpose(
+            value_before_transpose=mask,
+            param_target='inputs',
+            param_name=graph_node.inputs[4].name,
+            **kwargs,
+        )
+
+    input_dtype = input_tensor.dtype
+    if weights is not None and weights.dtype != input_dtype:
+        weights = tf.cast(weights, input_dtype)
+    if offset is not None and offset.dtype != input_dtype:
+        offset = tf.cast(offset, input_dtype)
+    if bias is not None and bias.dtype != input_dtype:
+        bias = tf.cast(bias, input_dtype)
+    if mask is not None and mask.dtype != input_dtype:
+        mask = tf.cast(mask, input_dtype)
+
+    # Workaround to avoid as many conversion failures as possible
+    onnx_input_shape = [
+        dim if isinstance(dim, int) else None for dim in graph_node.inputs[0].shape
+    ] if graph_node.inputs[0].shape is not None else None
+    tf_input_shape = [
+        dim if isinstance(dim, int) else None for dim in input_tensor.shape
+    ]
+    if onnx_input_shape is not None \
+        and len(onnx_input_shape) > 1 and len(tf_input_shape) > 1 \
+        and onnx_input_shape == tf_input_shape:
+
+        shape_for_judging_skip = [
+            dim if dim is not None else INF_INDEX_VALUE for dim in onnx_input_shape[1:]
+        ]
+        if shape_for_judging_skip.count(shape_for_judging_skip[0]) != len(shape_for_judging_skip):
+            input_tensor = transpose_with_flexing_deterrence(
+                input_tensor=input_tensor,
+                perm=[0,2,3,1],
+                **kwargs,
+            )
+            offset = transpose_with_flexing_deterrence(
+                input_tensor=offset,
+                perm=[0,2,3,1],
+                **kwargs,
+            )
+            if mask is not None:
+                mask = transpose_with_flexing_deterrence(
+                    input_tensor=mask,
+                    perm=[0,2,3,1],
+                    **kwargs,
+                )
+
+    # Attributes
+    dilations = graph_node.attrs.get('dilations', [1, 1])
+    group = graph_node.attrs.get('group', 1)
+    offset_group = graph_node.attrs.get('offset_group', 1)
+    pads = graph_node.attrs.get('pads', [0, 0, 0, 0])
+    strides = graph_node.attrs.get('strides', [1, 1])
+
+    dilation_h, dilation_w = dilations
+    stride_h, stride_w = strides
+    pad_top, pad_left, pad_bottom, pad_right = pads
+
+    # Input prep
+    if pad_top != 0 or pad_bottom != 0 or pad_left != 0 or pad_right != 0:
+        input_tensor = tf.pad(
+            input_tensor,
+            paddings=[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
+        )
+
+    batch = tf.shape(input_tensor)[0]
+    in_h = tf.shape(input_tensor)[1]
+    in_w = tf.shape(input_tensor)[2]
+    in_c = tf.shape(input_tensor)[3]
+
+    offset_shape = tf.shape(offset)
+    out_h = offset_shape[1]
+    out_w = offset_shape[2]
+
+    # Kernel shape
+    if kernel_shape != []:
+        kh = _to_int_tensor(kernel_shape[0])
+        kw = _to_int_tensor(kernel_shape[1])
+    else:
+        kh = _to_int_tensor(tf.shape(weights)[0])
+        kw = _to_int_tensor(tf.shape(weights)[1])
+
+    # Base grid: [oH, oW, kH, kW, 2]
+    oy = tf.range(out_h, dtype=input_dtype) * tf.cast(stride_h, input_dtype)
+    ox = tf.range(out_w, dtype=input_dtype) * tf.cast(stride_w, input_dtype)
+    ky = tf.range(kh, dtype=input_dtype) * tf.cast(dilation_h, input_dtype)
+    kx = tf.range(kw, dtype=input_dtype) * tf.cast(dilation_w, input_dtype)
+
+    oy = tf.reshape(oy, tf.stack([out_h, 1, 1, 1]))
+    ox = tf.reshape(ox, tf.stack([1, out_w, 1, 1]))
+    ky = tf.reshape(ky, tf.stack([1, 1, kh, 1]))
+    kx = tf.reshape(kx, tf.stack([1, 1, 1, kw]))
+
+    y = oy + ky
+    x = ox + kx
+    target_shape = tf.stack([out_h, out_w, kh, kw])
+    y = tf.broadcast_to(y, target_shape)
+    x = tf.broadcast_to(x, target_shape)
+    base_grid = tf.stack([y, x], axis=-1)
+
+    # Offset reshape: [N, oH, oW, Goff, kH, kW, 2]
+    offset = tf.reshape(
+        offset,
+        tf.stack([batch, out_h, out_w, offset_group, kh, kw, 2]),
+    )
+
+    coords = base_grid[None, :, :, None, :, :, :] + offset
+    coords = tf.transpose(coords, [0, 3, 1, 2, 4, 5, 6])
+    coords = tf.reshape(coords, tf.stack([batch * offset_group, out_h, out_w, kh, kw, 2]))
+
+    # Input grouping for offset_group
+    c_per_offset = tf.math.floordiv(in_c, offset_group)
+    input_tensor = tf.reshape(
+        input_tensor,
+        tf.stack([batch, in_h, in_w, offset_group, c_per_offset]),
+    )
+    input_tensor = tf.transpose(input_tensor, [0, 3, 1, 2, 4])
+    input_tensor = tf.reshape(
+        input_tensor,
+        tf.stack([batch * offset_group, in_h, in_w, c_per_offset]),
+    )
+
+    sampled = _bilinear_sample_2d(input_tensor, coords)
+    sampled = tf.reshape(
+        sampled,
+        tf.stack([batch, offset_group, out_h, out_w, kh, kw, c_per_offset]),
+    )
+    sampled = tf.transpose(sampled, [0, 2, 3, 1, 4, 5, 6])
+
+    if mask is not None:
+        mask = tf.reshape(
+            mask,
+            tf.stack([batch, out_h, out_w, offset_group, kh, kw, 1]),
+        )
+        sampled = sampled * tf.cast(mask, sampled.dtype)
+
+    # Merge offset_group back to channel dim: [N, oH, oW, kH, kW, C]
+    sampled = tf.reshape(
+        sampled,
+        tf.stack([batch, out_h, out_w, kh, kw, in_c]),
+    )
+
+    # Grouped convolution via batched matmul
+    out_c = tf.shape(weights)[3]
+    c_per_group = tf.math.floordiv(in_c, group)
+    out_c_per_group = tf.math.floordiv(out_c, group)
+
+    cols = tf.reshape(sampled, tf.stack([batch * out_h * out_w, kh * kw * in_c]))
+    cols = tf.reshape(cols, tf.stack([batch * out_h * out_w, group, kh * kw * c_per_group]))
+    cols = tf.transpose(cols, [1, 0, 2])
+
+    weights = tf.reshape(weights, tf.stack([kh, kw, c_per_group, group, out_c_per_group]))
+    weights = tf.transpose(weights, [3, 0, 1, 2, 4])
+    weights = tf.reshape(weights, tf.stack([group, kh * kw * c_per_group, out_c_per_group]))
+
+    output = tf.matmul(cols, weights)
+    output = tf.transpose(output, [1, 0, 2])
+    output = tf.reshape(output, tf.stack([batch, out_h, out_w, out_c]))
+
+    if bias is not None:
+        output += tf.reshape(bias, tf.stack([1, 1, 1, out_c]))
+
+    if output.dtype != input_dtype:
+        output = tf.cast(output, input_dtype)
+
+    # Post-process transpose
+    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
+        value_before_transpose=output,
+        param_target='outputs',
+        param_name=graph_node.outputs[0].name,
+        **kwargs,
+    )
+
+    # Generation of Debug Info
+    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
+        make_tf_node_info(
+            node_info={
+                'tf_op_type': 'DeformConv',
+                'tf_inputs': {
+                    'input_tensor': input_tensor,
+                    'weights': weights,
+                    'offset': offset,
+                    'bias': bias,
+                    'mask': mask,
+                    'strides': strides,
+                    'dilations': dilations,
+                    'pads': pads,
+                    'group': group,
+                    'offset_group': offset_group,
+                },
+                'tf_outputs': {
+                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
+                },
+            }
+        )
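The `_bilinear_sample_2d` helper above gathers the four integer-pixel neighbours of every fractional sampling location and zero-masks any tap that falls outside the image, mirroring the `_in_bounds` checks. As a minimal sketch (illustrative only, not part of the package or of this diff), the same rule in plain NumPy for a single-channel image; the function name and test values are hypothetical:

```python
# Illustrative only -- not part of onnx2tf. Plain-NumPy restatement of the sampling
# rule used by _bilinear_sample_2d: blend the four neighbouring pixels and let any
# tap that falls outside the image contribute zero.
import numpy as np

def bilinear_sample(image, y, x):
    """image: [H, W]; (y, x): absolute float coordinates."""
    h, w = image.shape
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    dy, dx = y - y0, x - x0
    out = 0.0
    for yi, wy in ((y0, 1.0 - dy), (y0 + 1, dy)):
        for xi, wx in ((x0, 1.0 - dx), (x0 + 1, dx)):
            if 0 <= yi < h and 0 <= xi < w:   # out-of-bounds taps are masked to zero
                out += wy * wx * image[yi, xi]
    return out

img = np.arange(16, dtype=np.float32).reshape(4, 4)
print(bilinear_sample(img, 1.5, 2.5))  # 8.5: average of pixels (1,2), (1,3), (2,2), (2,3)
print(bilinear_sample(img, 3.5, 3.0))  # 7.5: only the in-bounds half of the support counts
```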
onnx2tf/ops/GatherElements.py
CHANGED
@@ -57,9 +57,10 @@ def make_node(
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    # Indices must not be layout-transposed.
     graph_node_input_2 = get_constant_or_variable(
         graph_node.inputs[1],
-        before_op_output_shape_trans,
+        False,
     )
     graph_node_output: gs.Variable = graph_node.outputs[0]
     shape = graph_node_output.shape
@@ -77,12 +78,29 @@ def make_node(
         param_name=graph_node.inputs[0].name,
         **kwargs,
     )
-    indices_tensor = pre_process_transpose(
-        value_before_transpose=indices_tensor,
-        param_target='inputs',
-        param_name=graph_node.inputs[1].name,
-        **kwargs,
-    )
+    # If input is transposed by replacement params, align indices tensor shape.
+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm = None
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+    target_perm = indices_perm if indices_perm is not None else params_perm
+    if target_perm is not None:
+        try:
+            rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
+            if rank is None or rank == len(target_perm):
+                indices_tensor = transpose_with_flexing_deterrence(
+                    input_tensor=indices_tensor,
+                    perm=target_perm,
+                    **kwargs,
+                )
+        except Exception:
+            pass

     tensor_rank = len(input_tensor.shape)

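The replacement logic above transposes the indices tensor with the same permutation that was applied to the data input. A minimal sketch, illustrative only and not part of the package, of why a gather along a given axis stays consistent when data and indices are transposed together and the axis index is mapped to its new position; it uses `np.take_along_axis` as a stand-in for ONNX GatherElements and an assumed NCHW to NHWC permutation:

```python
# Illustrative only -- not part of onnx2tf. If data and indices are transposed with the
# same perm, an element-wise gather (ONNX GatherElements, here np.take_along_axis)
# produces the transposed version of the original result.
import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((2, 3, 4, 5))                 # NCHW
indices = rng.integers(0, 3, size=(2, 3, 4, 5))          # gather along axis=1 (C)
perm = [0, 2, 3, 1]                                      # NCHW -> NHWC

out_nchw = np.take_along_axis(data, indices, axis=1)

new_axis = perm.index(1)                                 # the C axis is now last
out_nhwc = np.take_along_axis(data.transpose(perm), indices.transpose(perm), axis=new_axis)

print(np.array_equal(out_nhwc, out_nchw.transpose(perm)))   # True
```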
onnx2tf/ops/GatherND.py
CHANGED
@@ -51,9 +51,10 @@ def make_node(
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    # Indices must not be layout-transposed.
    graph_node_input_2 = get_constant_or_variable(
         graph_node.inputs[1],
-        before_op_output_shape_trans,
+        False,
     )
     graph_node_output: gs.Variable = graph_node.outputs[0]
     shape = graph_node_output.shape
@@ -89,6 +90,32 @@ def make_node(

     replace_gathernd_to_pseudo_gathernd = "gathernd" in kwargs['replace_to_pseudo_operators']

+    # If params is transposed, adjust indices to match the transposed layout.
+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm_specified = False
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            if op_rep_param.get('pre_process_transpose_perm', None) is not None:
+                indices_perm_specified = True
+    if params_perm is not None and not indices_perm_specified:
+        # Only handle standard layout swaps that keep batch dims at the front.
+        if batch_dims <= len(params_perm) \
+            and list(params_perm[:batch_dims]) == list(range(batch_dims)):
+            perm_tail = [p - batch_dims for p in params_perm if p >= batch_dims]
+            try:
+                if isinstance(indices_tensor, np.ndarray):
+                    if indices_tensor.shape and indices_tensor.shape[-1] == len(perm_tail):
+                        indices_tensor = indices_tensor[..., perm_tail]
+                else:
+                    idx_last = indices_tensor.shape[-1] if indices_tensor.shape is not None else None
+                    if idx_last is None or idx_last == len(perm_tail):
+                        indices_tensor = tf.gather(indices_tensor, perm_tail, axis=-1)
+            except Exception:
+                pass
+
     # Preserving Graph Structure (Dict)
     tf_layers_dict[graph_node_output.name] = {
         'optype': graph_node.op,
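The block added above reorders the components of the last dimension of the GatherND indices with `perm_tail` when the params tensor has been layout-transposed. A minimal sketch, illustrative only and not part of the package, with `batch_dims = 0` (so `perm_tail` equals the params permutation) and an assumed NCHW to NHWC swap:

```python
# Illustrative only -- not part of onnx2tf. When params is transposed, each index
# vector along the last axis of the GatherND indices must be reordered the same way
# (batch_dims = 0 here, so perm_tail is just the params permutation).
import numpy as np

params_nchw = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
indices_nchw = np.array([[1, 2, 3, 4], [0, 1, 0, 2]])    # full-rank index vectors

perm = [0, 2, 3, 1]                                      # NCHW -> NHWC
params_nhwc = np.transpose(params_nchw, perm)
indices_nhwc = indices_nchw[..., perm]                   # reorder the index components

a = params_nchw[tuple(indices_nchw.T)]                   # NumPy stand-in for tf.gather_nd
b = params_nhwc[tuple(indices_nhwc.T)]
print(np.array_equal(a, b))                              # True: both address the same elements
```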
onnx2tf/ops/ScatterElements.py
CHANGED
@@ -55,9 +55,10 @@ def make_node(
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    # Indices must not be layout-transposed.
     graph_node_input_2 = get_constant_or_variable(
         graph_node.inputs[1],
-        before_op_output_shape_trans,
+        False,
     )
     graph_node_input_3 = get_constant_or_variable(
         graph_node.inputs[2],
@@ -81,12 +82,29 @@ def make_node(
     indices_tensor = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
     # Pre-process transpose
-    indices_tensor = pre_process_transpose(
-        value_before_transpose=indices_tensor,
-        param_target='inputs',
-        param_name=graph_node.inputs[1].name,
-        **kwargs,
-    )
+    # If input is transposed by replacement params, align indices tensor shape.
+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm = None
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+    target_perm = indices_perm if indices_perm is not None else params_perm
+    if target_perm is not None:
+        try:
+            rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
+            if rank is None or rank == len(target_perm):
+                indices_tensor = transpose_with_flexing_deterrence(
+                    input_tensor=indices_tensor,
+                    perm=target_perm,
+                    **kwargs,
+                )
+        except Exception:
+            pass
     updates_tensor = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
         if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
     # Pre-process transpose
onnx2tf/ops/ScatterND.py
CHANGED
@@ -13,6 +13,7 @@ from onnx2tf.utils.common_functions import (
     get_replacement_parameter,
     pre_process_transpose,
     post_process_transpose,
+    transpose_with_flexing_deterrence,
 )


@@ -79,6 +80,32 @@ def make_node(
             and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
     }

+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm = None
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+
+    def reorder_indices_last_dim(target_indices, perm):
+        if perm is None:
+            return target_indices
+        try:
+            if isinstance(target_indices, np.ndarray):
+                if target_indices.shape and target_indices.shape[-1] == len(perm):
+                    return target_indices[..., perm]
+            else:
+                idx_last = target_indices.shape[-1] if target_indices.shape is not None else None
+                if idx_last is None or idx_last == len(perm):
+                    return tf.gather(target_indices, perm, axis=-1)
+        except Exception:
+            pass
+        return target_indices
+
     # Pre-process transpose
     input_tensor = pre_process_transpose(
         value_before_transpose=input_tensor,
@@ -86,18 +113,26 @@ def make_node(
         param_name=graph_node.inputs[0].name,
         **kwargs,
     )
-    indices_tensor = pre_process_transpose(
-        value_before_transpose=indices_tensor,
-        param_target='inputs',
-        param_name=graph_node.inputs[1].name,
-        **kwargs,
-    )
+    # Indices must not be layout-transposed; apply explicit perm only if specified.
+    if indices_perm is not None:
+        try:
+            rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
+            if rank is None or rank == len(indices_perm):
+                indices_tensor = transpose_with_flexing_deterrence(
+                    input_tensor=indices_tensor,
+                    perm=indices_perm,
+                    **kwargs,
+                )
+        except Exception:
+            pass
     updates_tensor = pre_process_transpose(
         value_before_transpose=updates_tensor,
         param_target='inputs',
         param_name=graph_node.inputs[2].name,
         **kwargs,
     )
+    if params_perm is not None and indices_perm is None:
+        indices_tensor = reorder_indices_last_dim(indices_tensor, params_perm)

     # When NHWC is fixed, return to NCHW format before processing.
     data_nhwc = tf_layers_dict[graph_node_input_1.name]['nhwc'] \
@@ -119,6 +154,8 @@ def make_node(
         and len(input_tensor.shape) >= 3:
         perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
         input_tensor = tf.transpose(a=input_tensor, perm=perm)
+        if indices_perm is None:
+            indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
         nchw = True
     elif not data_nhwc \
         and len(input_tensor.shape) >= 3 \
@@ -126,6 +163,8 @@ def make_node(
         and input_tensor.shape != graph_node.inputs[0].shape:
         perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
         input_tensor = tf.transpose(a=input_tensor, perm=perm)
+        if indices_perm is None:
+            indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
         nchw = True
     ## indices
     if indices_nhwc \
onnx2tf/ops/TensorScatter.py
CHANGED
@@ -14,6 +14,7 @@ from onnx2tf.utils.common_functions import (
     get_replacement_parameter,
     pre_process_transpose,
     post_process_transpose,
+    transpose_with_flexing_deterrence,
 )
 from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
 from onnx2tf.utils.logging import *
@@ -112,12 +113,25 @@ def make_node(
         **kwargs,
     )
     if write_indices is not None:
-        write_indices = pre_process_transpose(
-            value_before_transpose=write_indices,
-            param_target='inputs',
-            param_name=graph_node.inputs[2].name,
-            **kwargs,
-        )
+        # Indices must not be layout-transposed; apply explicit perm only if specified.
+        op_rep_params = kwargs.get('op_rep_params', [])
+        indices_perm = None
+        for op_rep_param in op_rep_params:
+            if op_rep_param['param_target'] == 'inputs' \
+                and op_rep_param['param_name'] == graph_node.inputs[2].name:
+                indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+                break
+        if indices_perm is not None:
+            try:
+                rank = len(write_indices.shape) if hasattr(write_indices, "shape") else None
+                if rank is None or rank == len(indices_perm):
+                    write_indices = transpose_with_flexing_deterrence(
+                        input_tensor=write_indices,
+                        perm=indices_perm,
+                        **kwargs,
+                    )
+            except Exception:
+                pass

     # Generation of TF OP
     past_cache = _as_tensor(past_cache)