onnx2tf 1.29.8__py3-none-any.whl → 1.29.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/ops/LpPool.py ADDED
@@ -0,0 +1,296 @@
1
+ import random
2
+ random.seed(0)
3
+ import numpy as np
4
+ np.random.seed(0)
5
+ import tensorflow as tf
6
+ import onnx_graphsurgeon as gs
7
+ from onnx2tf.utils.common_functions import (
8
+ get_constant_or_variable,
9
+ remove_dilations,
10
+ print_node_info,
11
+ inverted_operation_enable_disable,
12
+ make_tf_node_info,
13
+ get_replacement_parameter,
14
+ pre_process_transpose,
15
+ post_process_transpose,
16
+ calc_tf_pooling_pads,
17
+ calc_extra_padding_with_ceil,
18
+ transpose_with_flexing_deterrence,
19
+ )
20
+ from onnx2tf.utils.logging import *
21
+
22
+ INF_INDEX_VALUE: int = 4294967296
23
+
24
+
25
def _kernel_size_const(kernel_shape, dtype):
    """Return the element count of a pooling window as a scalar of ``dtype``.

    ``kernel_shape`` may be a Python list/tuple or numpy array of extents
    (multiplied together after ``int`` conversion), a TF tensor of extents
    (reduced with ``tf.reduce_prod``), or a single scalar extent.
    """
    if tf.is_tensor(kernel_shape):
        return tf.cast(tf.reduce_prod(kernel_shape), dtype)
    if not isinstance(kernel_shape, (list, tuple, np.ndarray)):
        # A lone scalar extent.
        return tf.cast(int(kernel_shape), dtype)
    window_elems = 1
    for extent in kernel_shape:
        window_elems = window_elems * int(extent)
    return tf.cast(window_elems, dtype)
34
+
35
+
36
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """LpPool

    Emits TF ops computing ONNX LpPool:
    ``y = (sum over each pooling window of |x| ** p) ** (1 / p)``.

    The window sum is built as ``avg_pool(|x| ** p) * prod(kernel_shape)``.
    That identity only holds when every window position, padded ones
    included, is counted by the average. TF's ``SAME`` average pooling
    excludes padded positions from the divisor, so every padding mode
    (NOTSET pads, ceil_mode extras, SAME_UPPER, SAME_LOWER) is applied here
    as an explicit zero ``tf.pad`` followed by a ``VALID`` pool. For p > 0
    a zero pad contributes ``|0| ** p == 0`` to the sum.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph

    Raises
    ------
    ValueError
        If ``auto_pad`` is not NOTSET / SAME_UPPER / SAME_LOWER / VALID.
    NotImplementedError
        When the fallback pooling path (non-unit strides together with
        non-unit dilations, or 4+ spatial dims) meets 3+ spatial dims.
    """
    before_op_output_shape_trans = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)

    graph_node_input = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_output: gs.Variable = graph_node.outputs[0]
    shape = graph_node_output.shape
    dtype = graph_node_output.dtype

    input_tensor = tf_layers_dict[graph_node_input.name]['tf_node'] \
        if isinstance(graph_node_input, gs.Variable) else graph_node_input
    # Shape taken BEFORE pre_process_transpose; compared with the ONNX shape below.
    input_tensor_shape = input_tensor.shape

    # Pre-process transpose
    input_tensor = pre_process_transpose(
        value_before_transpose=input_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )

    # Workaround to avoid as many conversion failures as possible
    # for models with useless Transpose immediately before them.
    # If the input geometry of the ONNX and the input geometry of the TF model match,
    # the input geometry on the TF model side is forcibly transposed to the NWC or NHWC or NDHWC format.
    # However, if all dimensions of CW or CHW or CDHW have the same value,
    # the forced transposition process is skipped because it may destroy the structure of the model.
    onnx_input_shape = [
        dim if isinstance(dim, int) else None for dim in graph_node.inputs[0].shape
    ] if graph_node.inputs[0].shape is not None else None
    tf_input_shape = [
        dim if isinstance(dim, int) else None for dim in input_tensor_shape
    ]
    if onnx_input_shape is not None \
        and len(onnx_input_shape) > 1 and len(tf_input_shape) > 1 \
        and onnx_input_shape == tf_input_shape:

        shape_for_judging_skip = [
            dim if dim is not None else INF_INDEX_VALUE for dim in onnx_input_shape[1:]
        ]
        if shape_for_judging_skip.count(shape_for_judging_skip[0]) != len(shape_for_judging_skip):
            # channel-first -> channel-last permutation per input rank (1D / 2D / 3D)
            channel_last_perms = {
                3: [0, 2, 1],
                4: [0, 2, 3, 1],
                5: [0, 2, 3, 4, 1],
            }
            perm = channel_last_perms.get(len(onnx_input_shape))
            if perm is not None:
                input_tensor = transpose_with_flexing_deterrence(
                    input_tensor=input_tensor,
                    perm=perm,
                    **kwargs,
                )

    auto_pad = graph_node.attrs.get('auto_pad', 'NOTSET')
    ceil_mode = bool(graph_node.attrs.get('ceil_mode', 0))
    kernel_shape = graph_node.attrs['kernel_shape']
    spatial_size = len(kernel_shape)
    dilations = graph_node.attrs.get('dilations', [1] * spatial_size)
    pads = graph_node.attrs.get('pads', [0] * spatial_size * 2)
    strides = graph_node.attrs.get('strides', [1] * spatial_size)
    p = float(graph_node.attrs.get('p', 2))

    input_tensor_shape = input_tensor.shape.as_list()
    zero_pads = [0] * spatial_size * 2

    # All padding is materialized with tf.pad (see docstring), so the pooling
    # op itself always runs with VALID padding.
    tf_pad_mode = 'VALID'

    # Effective extent of a dilated window is (k - 1) * d + 1; this reduces
    # to k when d == 1, consistent with the undilated case.
    dilated_kernel_shape = kernel_shape
    if dilations != [1] * spatial_size:
        dilated_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]

    # SAME-style pads in ONNX layout [begin_0, ..., begin_n, end_0, ..., end_n].
    # NOTE(review): assumed to be the SAME_UPPER split (odd remainder on the
    # end side), as TF's SAME mode does -- confirm in calc_tf_pooling_pads.
    tf_pads = calc_tf_pooling_pads(
        input_shape=input_tensor_shape,
        kernel=dilated_kernel_shape,
        strides=strides,
        input_tensor=input_tensor,
    )

    if auto_pad == 'NOTSET':
        # Explicit ONNX pads are used verbatim.
        if ceil_mode:
            # extra padding may be needed for ceiling;
            # it is added to the end side (right, bottom) only
            extra_pads = \
                calc_extra_padding_with_ceil(
                    input_shape=input_tensor_shape[1:-1],
                    kernel=kernel_shape,
                    pads=pads,
                    dilations=dilations,
                    strides=strides,
                )
            half = len(pads) // 2
            pads = pads[:half] + [end + extra for end, extra in zip(pads[half:], extra_pads)]
        tf_pads = pads

    elif auto_pad == 'SAME_UPPER':
        # calc_tf_pooling_pads already yields the SAME_UPPER split.
        pass

    elif auto_pad == 'SAME_LOWER':
        # SAME_LOWER puts the larger share at the beginning: swap the begin
        # and end halves as whole blocks so every axis keeps its own pair.
        half = len(tf_pads) // 2
        tf_pads = list(tf_pads[half:]) + list(tf_pads[:half])

    elif auto_pad == 'VALID':
        tf_pads = zero_pads

    else:
        error_msg = Color.RED(f'ERROR:') + ' ' + \
            f'Wrong auto_pad parameter in LpPool: {auto_pad}.'
        raise ValueError(error_msg)

    # add explicit zero pad layer if needed
    if tf_pads != zero_pads:
        warn(
            f'Tensorflow incompatible padding detected. ' \
            f'Extra pad layer is inserted automatically. '
        )
        half = len(tf_pads) // 2
        # convert ONNX [begins..., ends...] to tensorflow [[begin, end], ...] per axis
        tf_pads = \
            [[0, 0]] + \
            [[begin, end] for begin, end in zip(tf_pads[:half], tf_pads[half:])] + \
            [[0, 0]]

        padded_tensor = tf.pad(
            tensor=input_tensor,
            paddings=tf_pads,
            mode='CONSTANT',
            constant_values=0.0,
        )
    else:
        padded_tensor = input_tensor

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
        'nhwc': True,
    }

    # Generation of TF OP
    abs_p_tensor = tf.abs(padded_tensor)
    if p != 1.0:
        abs_p_tensor = tf.pow(abs_p_tensor, p)
    kernel_size_const = _kernel_size_const(kernel_shape, abs_p_tensor.dtype)

    if spatial_size < 4 and (strides == [1] * spatial_size or dilations == [1] * spatial_size):
        # tf.nn.pool accepts non-unit strides or non-unit dilations, not both.
        pooled_tensor = tf.nn.pool(
            input=abs_p_tensor,
            window_shape=kernel_shape,
            dilations=dilations,
            strides=strides,
            padding=tf_pad_mode,
            pooling_type='AVG',
        )
    else:
        # TODO: Dilated LpPool with strides is broken for 3D and above, need to be fixed
        if spatial_size >= 3:
            error_msg = Color.RED(f'ERROR:') + ' ' + \
                f'Dilated LpPool with strides is not supported for 3D and above for now. '
            print(error_msg)
            raise NotImplementedError(error_msg)

        abs_p_tensor = remove_dilations(
            input_tensor=abs_p_tensor,
            kernel_shape=kernel_shape,
            spatial_size=spatial_size,
            strides=strides,
            dilations=dilations,
        )
        pooled_tensor = tf.nn.pool(
            input=abs_p_tensor,
            window_shape=kernel_shape,
            strides=kernel_shape,
            padding='VALID',
            pooling_type='AVG',
        )
    tf_op_type = tf.nn.pool

    # mean * window size == window sum (every window cell was counted).
    pooled_tensor = pooled_tensor * kernel_size_const
    tf_layers_dict[graph_node_output.name]['tf_node'] = \
        tf.pow(pooled_tensor, 1.0 / p) if p != 1.0 else pooled_tensor

    # Post-process transpose
    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': tf_op_type,
                'tf_inputs': {
                    'input': input_tensor,
                    'kernel_shape': kernel_shape,
                    'strides': strides,
                    'dilations': dilations,
                    'padding': tf_pads,
                    'ceil_mode': ceil_mode,
                    'p': p,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
@@ -0,0 +1,236 @@
1
+ import random
2
+ random.seed(0)
3
+ import numpy as np
4
+ np.random.seed(0)
5
+ import tensorflow as tf
6
+ import onnx_graphsurgeon as gs
7
+ from onnx2tf.utils.common_functions import (
8
+ get_constant_or_variable,
9
+ print_node_info,
10
+ inverted_operation_enable_disable,
11
+ make_tf_node_info,
12
+ get_replacement_parameter,
13
+ pre_process_transpose,
14
+ post_process_transpose,
15
+ transpose_with_flexing_deterrence,
16
+ )
17
+ from onnx2tf.utils.logging import *
18
+
19
+ INF_INDEX_VALUE: int = 4294967296
20
+
21
+
22
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """MaxRoiPool

    Emits TF ops for ONNX MaxRoiPool. Each ROI row is read as
    ``[batch_idx, x1, y1, x2, y2]``. The ROI is divided into a
    ``pooled_shape`` grid of bins, and each bin is max-reduced per channel.
    Empty bins yield zeros. The output layout is
    ``(num_rois, pooled_h, pooled_w, C)``.

    Bin geometry follows the ONNX reference implementation:
    - Corners are scaled by ``spatial_scale`` and rounded (``tf.round``,
      half to even).
    - The ROI width/height (``end - start + 1``, at least 1) comes from the
      UNCLIPPED rounded corners.
    - Only per-bin start/end indices are clipped to the feature map.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph

    Raises
    ------
    ValueError
        If the ``pooled_shape`` attribute is missing or has != 2 entries.
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 and before_op_output_shape_trans_2

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_output: gs.Variable = graph_node.outputs[0]
    shape = graph_node_output.shape
    dtype = graph_node_output.dtype

    input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
    rois = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2

    # Pre-process transpose
    input_tensor = pre_process_transpose(
        value_before_transpose=input_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )

    # Workaround to avoid as many conversion failures as possible
    # for models with useless Transpose immediately before them.
    # If the input geometry of the ONNX and the input geometry of the TF model match,
    # the input geometry on the TF model side is forcibly transposed to the NHWC format.
    # However, if all dimensions of CHW have the same value,
    # the forced transposition process is skipped because it may destroy the structure of the model.
    onnx_input_shape = [
        dim if isinstance(dim, int) else None for dim in graph_node.inputs[0].shape
    ] if graph_node.inputs[0].shape is not None else None
    tf_input_shape = [
        dim if isinstance(dim, int) else None for dim in input_tensor.shape
    ]
    if onnx_input_shape is not None \
        and len(onnx_input_shape) > 1 and len(tf_input_shape) > 1 \
        and onnx_input_shape == tf_input_shape:

        shape_for_judging_skip = [
            dim if dim is not None else INF_INDEX_VALUE for dim in onnx_input_shape[1:]
        ]
        if shape_for_judging_skip.count(shape_for_judging_skip[0]) != len(shape_for_judging_skip):
            if len(onnx_input_shape) == 4:
                # 2D
                input_tensor = transpose_with_flexing_deterrence(
                    input_tensor=input_tensor,
                    perm=[0,2,3,1],
                    **kwargs,
                )

    pooled_shape = graph_node.attrs.get('pooled_shape', None)
    if pooled_shape is None or len(pooled_shape) != 2:
        error_msg = \
            Color.RED(f'ERROR:') + ' ' + \
            f'pooled_shape is required for MaxRoiPool. ' \
            f'graph_node.name: {graph_node.name}, pooled_shape: {pooled_shape}'
        print(error_msg)
        raise ValueError(error_msg)

    pooled_h = int(pooled_shape[0])
    pooled_w = int(pooled_shape[1])
    spatial_scale = float(graph_node.attrs.get('spatial_scale', 1.0))

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
        'nhwc': True,
    }

    # Generation of TF OP
    rois = tf.cast(rois, tf.float32)
    if rois.shape.rank == 1:
        # a single ROI given as a flat vector
        rois = tf.expand_dims(rois, axis=0)

    channels_static = input_tensor.shape[-1]
    channel_spec = tf.TensorSpec(
        shape=(channels_static,) if channels_static is not None else (None,),
        dtype=input_tensor.dtype,
    )
    row_spec = tf.TensorSpec(
        shape=(pooled_w, channels_static) if channels_static is not None else (pooled_w, None),
        dtype=input_tensor.dtype,
    )
    roi_spec = tf.TensorSpec(
        shape=(pooled_h, pooled_w, channels_static) if channels_static is not None else (pooled_h, pooled_w, None),
        dtype=input_tensor.dtype,
    )

    # Loop-invariant feature-map extents (input indexed as [N, H, W, C] below).
    input_dynamic_shape = tf.shape(input_tensor)
    height = input_dynamic_shape[1]
    width = input_dynamic_shape[2]
    channels_dynamic = input_dynamic_shape[-1]
    zero = tf.zeros([channels_dynamic], dtype=input_tensor.dtype)

    def roi_pool_single(roi):
        """Pool one ROI row [batch_idx, x1, y1, x2, y2] into (pooled_h, pooled_w, C)."""
        batch_idx = tf.cast(roi[0], tf.int32)
        x1, y1, x2, y2 = tf.unstack(roi[1:5])
        x1 = x1 * spatial_scale
        y1 = y1 * spatial_scale
        x2 = x2 * spatial_scale
        y2 = y2 * spatial_scale

        roi_start_w = tf.cast(tf.round(x1), tf.int32)
        roi_start_h = tf.cast(tf.round(y1), tf.int32)
        roi_end_w = tf.cast(tf.round(x2), tf.int32)
        roi_end_h = tf.cast(tf.round(y2), tf.int32)

        # Bin sizes come from the unclipped ROI, as in the ONNX reference;
        # clipping happens per bin below.
        roi_width = tf.maximum(roi_end_w - roi_start_w + 1, 1)
        roi_height = tf.maximum(roi_end_h - roi_start_h + 1, 1)

        bin_size_h = tf.cast(roi_height, tf.float32) / tf.cast(pooled_h, tf.float32)
        bin_size_w = tf.cast(roi_width, tf.float32) / tf.cast(pooled_w, tf.float32)

        def pool_bin(ph, pw):
            """Max over bin (ph, pw) of this ROI, or zeros if the bin is empty."""
            ph_f = tf.cast(ph, tf.float32)
            pw_f = tf.cast(pw, tf.float32)
            hstart = tf.cast(tf.math.floor(ph_f * bin_size_h), tf.int32) + roi_start_h
            hend = tf.cast(tf.math.ceil((ph_f + 1.0) * bin_size_h), tf.int32) + roi_start_h
            wstart = tf.cast(tf.math.floor(pw_f * bin_size_w), tf.int32) + roi_start_w
            wend = tf.cast(tf.math.ceil((pw_f + 1.0) * bin_size_w), tf.int32) + roi_start_w

            hstart = tf.clip_by_value(hstart, 0, height)
            hend = tf.clip_by_value(hend, 0, height)
            wstart = tf.clip_by_value(wstart, 0, width)
            wend = tf.clip_by_value(wend, 0, width)

            is_empty = tf.logical_or(hend <= hstart, wend <= wstart)

            def do_max():
                region = input_tensor[batch_idx, hstart:hend, wstart:wend, :]
                return tf.reduce_max(region, axis=[0,1])

            return tf.cond(is_empty, lambda: zero, do_max)

        def pool_row(ph):
            return tf.map_fn(
                lambda pw: pool_bin(ph, pw),
                tf.range(pooled_w),
                fn_output_signature=channel_spec,
            )

        return tf.map_fn(
            pool_row,
            tf.range(pooled_h),
            fn_output_signature=row_spec,
        )

    pooled_tensor = tf.map_fn(
        roi_pool_single,
        rois,
        fn_output_signature=roi_spec,
    )

    tf_layers_dict[graph_node_output.name]['tf_node'] = pooled_tensor

    # Post-process transpose
    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': 'MaxRoiPool',
                'tf_inputs': {
                    'input': input_tensor,
                    'rois': rois,
                    'pooled_shape': pooled_shape,
                    'spatial_scale': spatial_scale,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
onnx2tf/ops/Unsqueeze.py CHANGED
@@ -69,52 +69,58 @@ def make_node(
69
69
  if isinstance(graph_node_input_1, gs.Variable) \
70
70
  and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
71
71
 
72
+ input_tensor_shape = None
73
+ tensor_rank = None
72
74
  if input_tensor.shape != tf.TensorShape(None):
73
75
  input_tensor_shape = list(input_tensor.shape)
74
- tensor_rank = len(input_tensor_shape)
75
- elif graph_node_output.shape is not None:
76
+ elif graph_node_output.shape is not None and axes is not None:
76
77
  input_tensor_shape = [
77
78
  dim for idx, dim in enumerate(graph_node_output.shape) if idx not in axes
78
79
  ]
79
80
  input_tensor_shape = [
80
81
  dim if not isinstance(dim, str) else None for dim in input_tensor_shape
81
82
  ]
83
+ if input_tensor_shape is not None:
82
84
  tensor_rank = len(input_tensor_shape)
83
85
 
84
86
  if isinstance(axes, list) or (isinstance(axes, np.ndarray) and len(axes.shape) > 0):
85
- if nhwc:
86
- axes = [
87
- convert_axis(
88
- axis=idx,
89
- tensor_rank=tensor_rank+len(axes),
90
- before_op_output_shape_trans=True,
91
- ) for idx in axes
92
- ]
93
- elif not nhwc and (isinstance(axes, list) and len(axes) == 1 or isinstance(axes, np.ndarray) and len(axes.shape) == 1) and axes[0] == -1:
94
- axes = [
95
- convert_axis(
96
- axis=idx,
97
- tensor_rank=tensor_rank+len(axes),
98
- before_op_output_shape_trans=before_op_output_shape_trans,
99
- ) for idx in axes
100
- ]
87
+ if tensor_rank is not None:
88
+ if nhwc:
89
+ axes = [
90
+ convert_axis(
91
+ axis=idx,
92
+ tensor_rank=tensor_rank+len(axes),
93
+ before_op_output_shape_trans=True,
94
+ ) for idx in axes
95
+ ]
96
+ elif not nhwc and (isinstance(axes, list) and len(axes) == 1 or isinstance(axes, np.ndarray) and len(axes.shape) == 1) and axes[0] == -1:
97
+ axes = [
98
+ convert_axis(
99
+ axis=idx,
100
+ tensor_rank=tensor_rank+len(axes),
101
+ before_op_output_shape_trans=before_op_output_shape_trans,
102
+ ) for idx in axes
103
+ ]
104
+ else:
105
+ axes = [idx for idx in axes]
101
106
  else:
102
- axes = [idx for idx in axes]
107
+ axes = [int(idx) for idx in axes]
103
108
  elif axes is not None and isinstance(axes, np.ndarray) and len(axes.shape) == 0:
104
- if nhwc:
105
- axes = convert_axis(
106
- axis=axes,
107
- tensor_rank=tensor_rank+1,
108
- before_op_output_shape_trans=True,
109
- )
110
- elif not nhwc and (isinstance(axes, list) and len(axes) == 1 or isinstance(axes, np.ndarray) and len(axes.shape) == 1) and axes[0] == -1:
111
- axes = [
112
- convert_axis(
113
- axis=idx,
114
- tensor_rank=tensor_rank+len(axes),
115
- before_op_output_shape_trans=before_op_output_shape_trans,
116
- ) for idx in axes
117
- ]
109
+ if tensor_rank is not None:
110
+ if nhwc:
111
+ axes = convert_axis(
112
+ axis=axes,
113
+ tensor_rank=tensor_rank+1,
114
+ before_op_output_shape_trans=True,
115
+ )
116
+ elif not nhwc and (isinstance(axes, list) and len(axes) == 1 or isinstance(axes, np.ndarray) and len(axes.shape) == 1) and axes[0] == -1:
117
+ axes = [
118
+ convert_axis(
119
+ axis=idx,
120
+ tensor_rank=tensor_rank+len(axes),
121
+ before_op_output_shape_trans=before_op_output_shape_trans,
122
+ ) for idx in axes
123
+ ]
118
124
  axes = list(axes[np.newaxis])
119
125
 
120
126
  if axes is not None and isinstance(axes, list) and len(axes) > 0:
@@ -128,11 +134,13 @@ def make_node(
128
134
  **kwargs,
129
135
  )
130
136
 
131
- new_shape = copy.deepcopy(input_tensor_shape)
132
- for idx in axes:
133
- new_shape.insert(idx, 1)
137
+ new_shape = None
138
+ if input_tensor_shape is not None and axes is not None:
139
+ new_shape = copy.deepcopy(input_tensor_shape)
140
+ for idx in axes:
141
+ new_shape.insert(idx, 1)
134
142
 
135
- new_shape = [dim if dim is not None else -1 for dim in new_shape]
143
+ new_shape = [dim if dim is not None else -1 for dim in new_shape]
136
144
 
137
145
  # Preserving Graph Structure (Dict)
138
146
  tf_layers_dict[graph_node_output.name] = {
@@ -234,6 +242,29 @@ def make_node(
234
242
  tf.identity(input=input_tensor)
235
243
  tf_type = tf.identity
236
244
 
245
+ elif not shape_replaced \
246
+ and new_shape is None:
247
+ axes_list = axes
248
+ if axes_list is None:
249
+ axes_list = []
250
+ elif isinstance(axes_list, np.ndarray):
251
+ axes_list = axes_list.tolist() if axes_list.shape != () else [int(axes_list)]
252
+ elif not isinstance(axes_list, list):
253
+ axes_list = [int(axes_list)]
254
+
255
+ unsqueeze_tensor = input_tensor
256
+ for axis_idx, axis in enumerate(axes_list):
257
+ axis_val = int(axis) if isinstance(axis, (np.integer, np.int64, np.int32)) else axis
258
+ if isinstance(axis_val, int) and axis_val < 0:
259
+ axis_val = tf.rank(unsqueeze_tensor) + axis_val + 1
260
+ unsqueeze_tensor = tf.expand_dims(
261
+ input=unsqueeze_tensor,
262
+ axis=axis_val,
263
+ name=graph_node.name if axis_idx == len(axes_list) - 1 else None,
264
+ )
265
+ tf_layers_dict[graph_node_output.name]['tf_node'] = unsqueeze_tensor
266
+ tf_type = tf.expand_dims
267
+
237
268
  elif not shape_replaced \
238
269
  and nhwc \
239
270
  and len(axes) == 1 \
@@ -247,6 +278,7 @@ def make_node(
247
278
  tf_type = tf.expand_dims
248
279
 
249
280
  elif not shape_replaced \
281
+ and new_shape is not None \
250
282
  and len(new_shape) >= 2 \
251
283
  and len([dim for dim in new_shape if dim is None or dim == -1]) >= 2 \
252
284
  and not isinstance(axes, int) \
@@ -276,6 +276,8 @@ def make_tf_node_info(**kwargs):
276
276
  def print_node_info(func):
277
277
  @wraps(func)
278
278
  def print_wrapper_func(*args, **kwargs):
279
+ if kwargs.get('suppress_log', False):
280
+ return func(*args, **kwargs)
279
281
  input_onnx_file_path: str = kwargs.get('input_onnx_file_path', None)
280
282
  graph_input: gs.Variable = kwargs.get('graph_input', None)
281
283
  graph_node: gs.Variable = kwargs.get('graph_node', None)
@@ -4051,6 +4053,7 @@ def dummy_tf_inference(
4051
4053
  for idx, dim in enumerate(input_size):
4052
4054
  if idx == 0 and input_sizes[0][0] is not None \
4053
4055
  and len(input_sizes[0]) == len(input_size) \
4056
+ and len(input_size) >= 2 \
4054
4057
  and dim is None:
4055
4058
  # Batch size assignment for input OPs
4056
4059
  new_input_size.append(input_sizes[0][0])