onnx2tf 1.29.18__py3-none-any.whl → 1.29.20__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the versions as they appear in their respective public registries.
onnx2tf/ops/DeformConv.py ADDED
@@ -0,0 +1,399 @@
+import sys
+import random
+random.seed(0)
+import numpy as np
+np.random.seed(0)
+import tensorflow as tf
+import onnx_graphsurgeon as gs
+from onnx2tf.utils.common_functions import (
+    get_constant_or_variable,
+    get_weights_constant_or_variable,
+    print_node_info,
+    inverted_operation_enable_disable,
+    make_tf_node_info,
+    get_replacement_parameter,
+    pre_process_transpose,
+    post_process_transpose,
+    transpose_with_flexing_deterrence,
+)
+from onnx2tf.utils.logging import *
+
+INF_INDEX_VALUE: int = 4294967296
+
+
+def _to_int_tensor(value, name=None):
+    if isinstance(value, tf.Tensor):
+        return tf.cast(value, tf.int32)
+    return tf.constant(value, dtype=tf.int32, name=name)
+
+
+def _bilinear_sample_2d(
+    image,
+    coords,
+):
+    """
+    image: [N, H, W, C]
+    coords: [N, oH, oW, kH, kW, 2] in absolute coords (y, x)
+    """
+    coord_dtype = coords.dtype
+    h = tf.shape(image)[1]
+    w = tf.shape(image)[2]
+    h_f = tf.cast(h, coord_dtype)
+    w_f = tf.cast(w, coord_dtype)
+    max_y = h_f - 1.0
+    max_x = w_f - 1.0
+
+    y, x = tf.split(coords, num_or_size_splits=2, axis=-1)
+
+    y0 = tf.floor(y)
+    x0 = tf.floor(x)
+    y1 = y0 + 1.0
+    x1 = x0 + 1.0
+
+    dy = y - y0
+    dx = x - x0
+
+    w00 = (1.0 - dy) * (1.0 - dx)
+    w10 = dy * (1.0 - dx)
+    w11 = dy * dx
+    w01 = (1.0 - dy) * dx
+
+    def _in_bounds(y_idx, x_idx):
+        return tf.logical_and(
+            tf.logical_and(y_idx >= 0.0, y_idx <= max_y),
+            tf.logical_and(x_idx >= 0.0, x_idx <= max_x),
+        )
+
+    m00 = _in_bounds(y0, x0)
+    m10 = _in_bounds(y1, x0)
+    m11 = _in_bounds(y1, x1)
+    m01 = _in_bounds(y0, x1)
+
+    y0c = tf.clip_by_value(y0, 0.0, max_y)
+    x0c = tf.clip_by_value(x0, 0.0, max_x)
+    y1c = tf.clip_by_value(y1, 0.0, max_y)
+    x1c = tf.clip_by_value(x1, 0.0, max_x)
+
+    y0i = tf.cast(y0c, tf.int32)
+    x0i = tf.cast(x0c, tf.int32)
+    y1i = tf.cast(y1c, tf.int32)
+    x1i = tf.cast(x1c, tf.int32)
+
+    input_flat = tf.reshape(image, tf.stack([tf.shape(image)[0], h * w, tf.shape(image)[3]]))
+
+    def _gather(y_idx, x_idx):
+        linear = y_idx * w + x_idx
+        linear = tf.squeeze(linear, axis=-1)
+        return tf.gather(input_flat, linear, batch_dims=1)
+
+    v00 = _gather(y0i, x0i)
+    v10 = _gather(y1i, x0i)
+    v11 = _gather(y1i, x1i)
+    v01 = _gather(y0i, x1i)
+
+    m00 = tf.cast(m00, image.dtype)
+    m10 = tf.cast(m10, image.dtype)
+    m11 = tf.cast(m11, image.dtype)
+    m01 = tf.cast(m01, image.dtype)
+
+    output = w00 * m00 * v00 + w10 * m10 * v10 + w11 * m11 * v11 + w01 * m01 * v01
+    return output
+
+
+@print_node_info
+@inverted_operation_enable_disable
+@get_replacement_parameter
+def make_node(
+    *,
+    graph_node: gs.Node,
+    tf_layers_dict: dict,
+    **kwargs: dict,
+):
+    """DeformConv
+
+    Parameters
+    ----------
+    graph_node: gs.Node
+        graph_surgeon Node
+
+    tf_layers_dict: dict
+        optype, shape, dtype, tensorflow graph
+    """
+    before_op_output_shape_trans_1 = \
+        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_3 = \
+        tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_4 = \
+        tf_layers_dict.get(graph_node.inputs[3].name, {}).get('before_op_output_shape_trans', True) \
+            if len(graph_node.inputs) >= 4 else True
+    before_op_output_shape_trans_5 = \
+        tf_layers_dict.get(graph_node.inputs[4].name, {}).get('before_op_output_shape_trans', True) \
+            if len(graph_node.inputs) >= 5 else True
+
+    graph_node_input_1 = get_constant_or_variable(
+        graph_node.inputs[0],
+        before_op_output_shape_trans_1,
+    )
+
+    kernel_shape = graph_node.attrs.get('kernel_shape', [])
+    if kernel_shape == [] and graph_node.inputs[1].shape is not None:
+        kernel_shape = graph_node.inputs[1].shape[2:]
+    kernel_size = len(kernel_shape) if kernel_shape != [] else 2
+
+    graph_node_input_2 = get_weights_constant_or_variable(
+        const_or_var=graph_node.inputs[1],
+        kernel_size=kernel_size,
+    )
+    graph_node_input_3 = get_constant_or_variable(
+        graph_node.inputs[2],
+        before_op_output_shape_trans_3,
+    )
+    graph_node_input_4 = get_constant_or_variable(
+        graph_node.inputs[3],
+        before_op_output_shape_trans_4,
+    ) if len(graph_node.inputs) >= 4 else None
+    graph_node_input_5 = get_constant_or_variable(
+        graph_node.inputs[4],
+        before_op_output_shape_trans_5,
+    ) if len(graph_node.inputs) >= 5 else None
+
+    graph_node_output: gs.Variable = graph_node.outputs[0]
+    output_tensor_shape = graph_node_output.shape
+    dtype = graph_node_output.dtype
+
+    input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
+        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+    weights = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
+        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
+    offset = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
+        if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
+    bias = tf_layers_dict[graph_node_input_4.name]['tf_node'] \
+        if isinstance(graph_node_input_4, gs.Variable) else graph_node_input_4
+    mask = tf_layers_dict[graph_node_input_5.name]['tf_node'] \
+        if isinstance(graph_node_input_5, gs.Variable) else graph_node_input_5
+
+    input_tensor_shape = input_tensor.shape
+
+    if input_tensor_shape is not None and len(input_tensor_shape) != 4:
+        error('DeformConv currently supports only 2D inputs (N, C, H, W).')
+        sys.exit(1)
+
+    # Preserving Graph Structure (Dict)
+    tf_layers_dict[graph_node_output.name] = {
+        'optype': graph_node.op,
+        'shape': output_tensor_shape,
+        'dtype': dtype,
+        'nhwc': True,
+    }
+
+    # Pre-process transpose
+    input_tensor = pre_process_transpose(
+        value_before_transpose=input_tensor,
+        param_target='inputs',
+        param_name=graph_node.inputs[0].name,
+        **kwargs,
+    )
+    offset = pre_process_transpose(
+        value_before_transpose=offset,
+        param_target='inputs',
+        param_name=graph_node.inputs[2].name,
+        **kwargs,
+    )
+    if mask is not None:
+        mask = pre_process_transpose(
+            value_before_transpose=mask,
+            param_target='inputs',
+            param_name=graph_node.inputs[4].name,
+            **kwargs,
+        )
+
+    input_dtype = input_tensor.dtype
+    if weights is not None and weights.dtype != input_dtype:
+        weights = tf.cast(weights, input_dtype)
+    if offset is not None and offset.dtype != input_dtype:
+        offset = tf.cast(offset, input_dtype)
+    if bias is not None and bias.dtype != input_dtype:
+        bias = tf.cast(bias, input_dtype)
+    if mask is not None and mask.dtype != input_dtype:
+        mask = tf.cast(mask, input_dtype)
+
+    # Workaround to avoid as many conversion failures as possible
+    onnx_input_shape = [
+        dim if isinstance(dim, int) else None for dim in graph_node.inputs[0].shape
+    ] if graph_node.inputs[0].shape is not None else None
+    tf_input_shape = [
+        dim if isinstance(dim, int) else None for dim in input_tensor.shape
+    ]
+    if onnx_input_shape is not None \
+        and len(onnx_input_shape) > 1 and len(tf_input_shape) > 1 \
+        and onnx_input_shape == tf_input_shape:
+
+        shape_for_judging_skip = [
+            dim if dim is not None else INF_INDEX_VALUE for dim in onnx_input_shape[1:]
+        ]
+        if shape_for_judging_skip.count(shape_for_judging_skip[0]) != len(shape_for_judging_skip):
+            input_tensor = transpose_with_flexing_deterrence(
+                input_tensor=input_tensor,
+                perm=[0,2,3,1],
+                **kwargs,
+            )
+            offset = transpose_with_flexing_deterrence(
+                input_tensor=offset,
+                perm=[0,2,3,1],
+                **kwargs,
+            )
+            if mask is not None:
+                mask = transpose_with_flexing_deterrence(
+                    input_tensor=mask,
+                    perm=[0,2,3,1],
+                    **kwargs,
+                )
+
+    # Attributes
+    dilations = graph_node.attrs.get('dilations', [1, 1])
+    group = graph_node.attrs.get('group', 1)
+    offset_group = graph_node.attrs.get('offset_group', 1)
+    pads = graph_node.attrs.get('pads', [0, 0, 0, 0])
+    strides = graph_node.attrs.get('strides', [1, 1])
+
+    dilation_h, dilation_w = dilations
+    stride_h, stride_w = strides
+    pad_top, pad_left, pad_bottom, pad_right = pads
+
+    # Input prep
+    if pad_top != 0 or pad_bottom != 0 or pad_left != 0 or pad_right != 0:
+        input_tensor = tf.pad(
+            input_tensor,
+            paddings=[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
+        )
+
+    batch = tf.shape(input_tensor)[0]
+    in_h = tf.shape(input_tensor)[1]
+    in_w = tf.shape(input_tensor)[2]
+    in_c = tf.shape(input_tensor)[3]
+
+    offset_shape = tf.shape(offset)
+    out_h = offset_shape[1]
+    out_w = offset_shape[2]
+
+    # Kernel shape
+    if kernel_shape != []:
+        kh = _to_int_tensor(kernel_shape[0])
+        kw = _to_int_tensor(kernel_shape[1])
+    else:
+        kh = _to_int_tensor(tf.shape(weights)[0])
+        kw = _to_int_tensor(tf.shape(weights)[1])
+
+    # Base grid: [oH, oW, kH, kW, 2]
+    oy = tf.range(out_h, dtype=input_dtype) * tf.cast(stride_h, input_dtype)
+    ox = tf.range(out_w, dtype=input_dtype) * tf.cast(stride_w, input_dtype)
+    ky = tf.range(kh, dtype=input_dtype) * tf.cast(dilation_h, input_dtype)
+    kx = tf.range(kw, dtype=input_dtype) * tf.cast(dilation_w, input_dtype)
+
+    oy = tf.reshape(oy, tf.stack([out_h, 1, 1, 1]))
+    ox = tf.reshape(ox, tf.stack([1, out_w, 1, 1]))
+    ky = tf.reshape(ky, tf.stack([1, 1, kh, 1]))
+    kx = tf.reshape(kx, tf.stack([1, 1, 1, kw]))
+
+    y = oy + ky
+    x = ox + kx
+    target_shape = tf.stack([out_h, out_w, kh, kw])
+    y = tf.broadcast_to(y, target_shape)
+    x = tf.broadcast_to(x, target_shape)
+    base_grid = tf.stack([y, x], axis=-1)
+
+    # Offset reshape: [N, oH, oW, Goff, kH, kW, 2]
+    offset = tf.reshape(
+        offset,
+        tf.stack([batch, out_h, out_w, offset_group, kh, kw, 2]),
+    )
+
+    coords = base_grid[None, :, :, None, :, :, :] + offset
+    coords = tf.transpose(coords, [0, 3, 1, 2, 4, 5, 6])
+    coords = tf.reshape(coords, tf.stack([batch * offset_group, out_h, out_w, kh, kw, 2]))
+
+    # Input grouping for offset_group
+    c_per_offset = tf.math.floordiv(in_c, offset_group)
+    input_tensor = tf.reshape(
+        input_tensor,
+        tf.stack([batch, in_h, in_w, offset_group, c_per_offset]),
+    )
+    input_tensor = tf.transpose(input_tensor, [0, 3, 1, 2, 4])
+    input_tensor = tf.reshape(
+        input_tensor,
+        tf.stack([batch * offset_group, in_h, in_w, c_per_offset]),
+    )
+
+    sampled = _bilinear_sample_2d(input_tensor, coords)
+    sampled = tf.reshape(
+        sampled,
+        tf.stack([batch, offset_group, out_h, out_w, kh, kw, c_per_offset]),
+    )
+    sampled = tf.transpose(sampled, [0, 2, 3, 1, 4, 5, 6])
+
+    if mask is not None:
+        mask = tf.reshape(
+            mask,
+            tf.stack([batch, out_h, out_w, offset_group, kh, kw, 1]),
+        )
+        sampled = sampled * tf.cast(mask, sampled.dtype)
+
+    # Merge offset_group back to channel dim: [N, oH, oW, kH, kW, C]
+    sampled = tf.reshape(
+        sampled,
+        tf.stack([batch, out_h, out_w, kh, kw, in_c]),
+    )
+
+    # Grouped convolution via batched matmul
+    out_c = tf.shape(weights)[3]
+    c_per_group = tf.math.floordiv(in_c, group)
+    out_c_per_group = tf.math.floordiv(out_c, group)
+
+    cols = tf.reshape(sampled, tf.stack([batch * out_h * out_w, kh * kw * in_c]))
+    cols = tf.reshape(cols, tf.stack([batch * out_h * out_w, group, kh * kw * c_per_group]))
+    cols = tf.transpose(cols, [1, 0, 2])
+
+    weights = tf.reshape(weights, tf.stack([kh, kw, c_per_group, group, out_c_per_group]))
+    weights = tf.transpose(weights, [3, 0, 1, 2, 4])
+    weights = tf.reshape(weights, tf.stack([group, kh * kw * c_per_group, out_c_per_group]))
+
+    output = tf.matmul(cols, weights)
+    output = tf.transpose(output, [1, 0, 2])
+    output = tf.reshape(output, tf.stack([batch, out_h, out_w, out_c]))
+
+    if bias is not None:
+        output += tf.reshape(bias, tf.stack([1, 1, 1, out_c]))
+
+    if output.dtype != input_dtype:
+        output = tf.cast(output, input_dtype)
+
+    # Post-process transpose
+    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
+        value_before_transpose=output,
+        param_target='outputs',
+        param_name=graph_node.outputs[0].name,
+        **kwargs,
+    )
+
+    # Generation of Debug Info
+    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
+        make_tf_node_info(
+            node_info={
+                'tf_op_type': 'DeformConv',
+                'tf_inputs': {
+                    'input_tensor': input_tensor,
+                    'weights': weights,
+                    'offset': offset,
+                    'bias': bias,
+                    'mask': mask,
+                    'strides': strides,
+                    'dilations': dilations,
+                    'pads': pads,
+                    'group': group,
+                    'offset_group': offset_group,
+                },
+                'tf_outputs': {
+                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
+                },
+            }
+        )
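Out-of-range bilinear taps in _bilinear_sample_2d contribute zero through the m00/m10/m11/m01 masks, which is the implicit zero padding deformable-convolution implementations typically assume. A minimal NumPy sketch of that sampling rule for a single point (the helper name is illustrative, not part of the package):

import numpy as np

def bilinear_sample_zero_pad(img, y, x):
    # Sample img[H, W] at fractional (y, x); taps outside the image contribute
    # zero, mirroring the in-bounds masks in _bilinear_sample_2d above.
    h, w = img.shape
    y0, x0 = np.floor(y), np.floor(x)
    dy, dx = y - y0, x - x0
    acc = 0.0
    for yy, xx, wgt in (
        (y0,     x0,     (1 - dy) * (1 - dx)),
        (y0 + 1, x0,     dy       * (1 - dx)),
        (y0 + 1, x0 + 1, dy       * dx),
        (y0,     x0 + 1, (1 - dy) * dx),
    ):
        if 0 <= yy <= h - 1 and 0 <= xx <= w - 1:  # in-bounds mask
            acc += wgt * img[int(yy), int(xx)]
    return acc

img = np.array([[1.0, 2.0], [3.0, 4.0]])
print(bilinear_sample_zero_pad(img, 0.5, 0.5))   # 2.5: average of all four pixels
print(bilinear_sample_zero_pad(img, -0.5, 0.0))  # 0.5: the tap above the image is zeroed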
@@ -57,9 +57,10 @@ def make_node
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    # Indices must not be layout-transposed.
     graph_node_input_2 = get_constant_or_variable(
         graph_node.inputs[1],
-        before_op_output_shape_trans,
+        False,
     )
     graph_node_output: gs.Variable = graph_node.outputs[0]
     shape = graph_node_output.shape
@@ -77,12 +78,29 @@ def make_node
         param_name=graph_node.inputs[0].name,
         **kwargs,
     )
-    indices_tensor = pre_process_transpose(
-        value_before_transpose=indices_tensor,
-        param_target='inputs',
-        param_name=graph_node.inputs[1].name,
-        **kwargs,
-    )
+    # If input is transposed by replacement params, align indices tensor shape.
+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm = None
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+    target_perm = indices_perm if indices_perm is not None else params_perm
+    if target_perm is not None:
+        try:
+            rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
+            if rank is None or rank == len(target_perm):
+                indices_tensor = transpose_with_flexing_deterrence(
+                    input_tensor=indices_tensor,
+                    perm=target_perm,
+                    **kwargs,
+                )
+        except Exception:
+            pass
 
     tensor_rank = len(input_tensor.shape)
 
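The hunk above replaces the unconditional indices transpose with a perm-selection rule that recurs in the ops below: an explicit pre_process_transpose_perm on the indices input wins; otherwise the data input's perm is reused so the indices stay aligned with the transposed data, and nothing is transposed when no perm is specified at all. A condensed, runnable sketch of that rule against a hypothetical op_rep_params list (the field names follow the entries read above; the names and values are made up for illustration):

op_rep_params = [  # hypothetical replacement parameters
    {'param_target': 'inputs', 'param_name': 'data0', 'pre_process_transpose_perm': [0, 2, 3, 1]},
    {'param_target': 'inputs', 'param_name': 'indices0'},  # no explicit perm on the indices
]

def pick_indices_perm(op_rep_params, data_name, indices_name):
    params_perm = indices_perm = None
    for p in op_rep_params:
        if p['param_target'] == 'inputs' and p['param_name'] == data_name:
            params_perm = p.get('pre_process_transpose_perm', None)
        if p['param_target'] == 'inputs' and p['param_name'] == indices_name:
            indices_perm = p.get('pre_process_transpose_perm', None)
    # An explicit indices perm takes precedence; otherwise fall back to the data perm.
    return indices_perm if indices_perm is not None else params_perm

print(pick_indices_perm(op_rep_params, 'data0', 'indices0'))  # [0, 2, 3, 1]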
onnx2tf/ops/GatherND.py CHANGED
@@ -51,9 +51,10 @@ def make_node
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    # Indices must not be layout-transposed.
     graph_node_input_2 = get_constant_or_variable(
         graph_node.inputs[1],
-        before_op_output_shape_trans,
+        False,
    )
     graph_node_output: gs.Variable = graph_node.outputs[0]
     shape = graph_node_output.shape
@@ -89,6 +90,32 @@ def make_node
 
     replace_gathernd_to_pseudo_gathernd = "gathernd" in kwargs['replace_to_pseudo_operators']
 
+    # If params is transposed, adjust indices to match the transposed layout.
+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm_specified = False
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            if op_rep_param.get('pre_process_transpose_perm', None) is not None:
+                indices_perm_specified = True
+    if params_perm is not None and not indices_perm_specified:
+        # Only handle standard layout swaps that keep batch dims at the front.
+        if batch_dims <= len(params_perm) \
+            and list(params_perm[:batch_dims]) == list(range(batch_dims)):
+            perm_tail = [p - batch_dims for p in params_perm if p >= batch_dims]
+            try:
+                if isinstance(indices_tensor, np.ndarray):
+                    if indices_tensor.shape and indices_tensor.shape[-1] == len(perm_tail):
+                        indices_tensor = indices_tensor[..., perm_tail]
+                else:
+                    idx_last = indices_tensor.shape[-1] if indices_tensor.shape is not None else None
+                    if idx_last is None or idx_last == len(perm_tail):
+                        indices_tensor = tf.gather(indices_tensor, perm_tail, axis=-1)
+            except Exception:
+                pass
+
     # Preserving Graph Structure (Dict)
     tf_layers_dict[graph_node_output.name] = {
         'optype': graph_node.op,
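The perm_tail reordering rests on a transpose identity: gathering from transposed params returns the same values as gathering from the original params once each index tuple is permuted by the same perm. A small NumPy check of that identity, with batch_dims = 0 for brevity:

import numpy as np

params = np.arange(24).reshape(2, 3, 4)
perm = [2, 0, 1]                            # some layout transpose of params
params_t = np.transpose(params, perm)

indices = np.array([[1, 2, 3], [0, 0, 1]])  # full-rank index tuples into params
indices_t = indices[..., perm]              # reorder tuple components to match params_t

assert np.array_equal(params[tuple(indices.T)], params_t[tuple(indices_t.T)])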
@@ -55,9 +55,10 @@ def make_node
         graph_node.inputs[0],
         before_op_output_shape_trans,
     )
+    # Indices must not be layout-transposed.
     graph_node_input_2 = get_constant_or_variable(
         graph_node.inputs[1],
-        before_op_output_shape_trans,
+        False,
     )
     graph_node_input_3 = get_constant_or_variable(
         graph_node.inputs[2],
@@ -81,12 +82,29 @@ def make_node
     indices_tensor = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
     # Pre-process transpose
-    indices_tensor = pre_process_transpose(
-        value_before_transpose=indices_tensor,
-        param_target='inputs',
-        param_name=graph_node.inputs[1].name,
-        **kwargs,
-    )
+    # If input is transposed by replacement params, align indices tensor shape.
+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm = None
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+    target_perm = indices_perm if indices_perm is not None else params_perm
+    if target_perm is not None:
+        try:
+            rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
+            if rank is None or rank == len(target_perm):
+                indices_tensor = transpose_with_flexing_deterrence(
+                    input_tensor=indices_tensor,
+                    perm=target_perm,
+                    **kwargs,
+                )
+        except Exception:
+            pass
     updates_tensor = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
         if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
     # Pre-process transpose
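The rank == len(target_perm) guard used here and in the other ops touched by this release simply skips the transpose when a layout perm cannot apply to the indices, e.g. a 4-D NHWC perm against rank-2 [N, k] index tuples. A minimal sketch of that guard (illustrative helper, not package code):

import numpy as np

def maybe_transpose(tensor, perm):
    # Apply a layout perm only when the ranks agree; otherwise leave the tensor as-is.
    if tensor.ndim == len(perm):
        return np.transpose(tensor, perm)
    return tensor

idx = np.zeros((5, 2), dtype=np.int64)           # rank-2 index tuples
print(maybe_transpose(idx, [0, 2, 3, 1]).shape)  # (5, 2): 4-D perm skipped
print(maybe_transpose(idx, [1, 0]).shape)        # (2, 5): matching-rank perm applied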
onnx2tf/ops/ScatterND.py CHANGED
@@ -13,6 +13,7 @@ from onnx2tf.utils.common_functions import (
     get_replacement_parameter,
     pre_process_transpose,
     post_process_transpose,
+    transpose_with_flexing_deterrence,
 )
 
 
@@ -79,6 +80,32 @@ def make_node
             and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
     }
 
+    op_rep_params = kwargs.get('op_rep_params', [])
+    params_perm = None
+    indices_perm = None
+    for op_rep_param in op_rep_params:
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[0].name:
+            params_perm = op_rep_param.get('pre_process_transpose_perm', None)
+        if op_rep_param['param_target'] == 'inputs' \
+            and op_rep_param['param_name'] == graph_node.inputs[1].name:
+            indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+
+    def reorder_indices_last_dim(target_indices, perm):
+        if perm is None:
+            return target_indices
+        try:
+            if isinstance(target_indices, np.ndarray):
+                if target_indices.shape and target_indices.shape[-1] == len(perm):
+                    return target_indices[..., perm]
+            else:
+                idx_last = target_indices.shape[-1] if target_indices.shape is not None else None
+                if idx_last is None or idx_last == len(perm):
+                    return tf.gather(target_indices, perm, axis=-1)
+        except Exception:
+            pass
+        return target_indices
+
     # Pre-process transpose
     input_tensor = pre_process_transpose(
         value_before_transpose=input_tensor,
@@ -86,18 +113,26 @@ def make_node
         param_name=graph_node.inputs[0].name,
         **kwargs,
     )
-    indices_tensor = pre_process_transpose(
-        value_before_transpose=indices_tensor,
-        param_target='inputs',
-        param_name=graph_node.inputs[1].name,
-        **kwargs,
-    )
+    # Indices must not be layout-transposed; apply explicit perm only if specified.
+    if indices_perm is not None:
+        try:
+            rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
+            if rank is None or rank == len(indices_perm):
+                indices_tensor = transpose_with_flexing_deterrence(
+                    input_tensor=indices_tensor,
+                    perm=indices_perm,
+                    **kwargs,
+                )
+        except Exception:
+            pass
     updates_tensor = pre_process_transpose(
         value_before_transpose=updates_tensor,
         param_target='inputs',
         param_name=graph_node.inputs[2].name,
         **kwargs,
     )
+    if params_perm is not None and indices_perm is None:
+        indices_tensor = reorder_indices_last_dim(indices_tensor, params_perm)
 
     # When NHWC is fixed, return to NCHW format before processing.
     data_nhwc = tf_layers_dict[graph_node_input_1.name]['nhwc'] \
@@ -119,6 +154,8 @@
         and len(input_tensor.shape) >= 3:
         perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
         input_tensor = tf.transpose(a=input_tensor, perm=perm)
+        if indices_perm is None:
+            indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
         nchw = True
     elif not data_nhwc \
         and len(input_tensor.shape) >= 3 \
@@ -126,6 +163,8 @@
         and input_tensor.shape != graph_node.inputs[0].shape:
         perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
         input_tensor = tf.transpose(a=input_tensor, perm=perm)
+        if indices_perm is None:
+            indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
         nchw = True
     ## indices
     if indices_nhwc \
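reorder_indices_last_dim keeps ScatterND consistent when the data tensor is moved back toward NCHW: scattering into the transposed data with each index tuple permuted by the same perm produces the transpose of the original scatter result. A NumPy check of that identity, using rank-3 data so the [0, rank-1, 1, ...] perm built above is [0, 2, 1]:

import numpy as np

data = np.zeros((2, 3, 4))
indices = np.array([[0, 1, 2], [1, 2, 3]])
updates = np.array([5.0, 7.0])

ref = data.copy()
ref[tuple(indices.T)] = updates             # scatter in the original layout

perm = [0, 2, 1]                            # the NCHW perm for rank-3 data
out = np.transpose(data, perm).copy()
out[tuple(indices[..., perm].T)] = updates  # same scatter in the transposed layout

assert np.array_equal(np.transpose(ref, perm), out)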
@@ -14,6 +14,7 @@ from onnx2tf.utils.common_functions import (
     get_replacement_parameter,
     pre_process_transpose,
     post_process_transpose,
+    transpose_with_flexing_deterrence,
 )
 from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
 from onnx2tf.utils.logging import *
@@ -112,12 +113,25 @@
         **kwargs,
     )
     if write_indices is not None:
-        write_indices = pre_process_transpose(
-            value_before_transpose=write_indices,
-            param_target='inputs',
-            param_name=graph_node.inputs[2].name,
-            **kwargs,
-        )
+        # Indices must not be layout-transposed; apply explicit perm only if specified.
+        op_rep_params = kwargs.get('op_rep_params', [])
+        indices_perm = None
+        for op_rep_param in op_rep_params:
+            if op_rep_param['param_target'] == 'inputs' \
+                and op_rep_param['param_name'] == graph_node.inputs[2].name:
+                indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
+                break
+        if indices_perm is not None:
+            try:
+                rank = len(write_indices.shape) if hasattr(write_indices, "shape") else None
+                if rank is None or rank == len(indices_perm):
+                    write_indices = transpose_with_flexing_deterrence(
+                        input_tensor=write_indices,
+                        perm=indices_perm,
+                        **kwargs,
+                    )
+            except Exception:
+                pass
 
     # Generation of TF OP
     past_cache = _as_tensor(past_cache)