onnx2tf 1.23.3__py3-none-any.whl → 1.25.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +181 -30
- onnx2tf/ops/Add.py +29 -0
- onnx2tf/ops/AveragePool.py +20 -10
- onnx2tf/ops/BatchNormalization.py +270 -24
- onnx2tf/ops/Concat.py +4 -4
- onnx2tf/ops/DepthToSpace.py +8 -0
- onnx2tf/ops/Div.py +30 -0
- onnx2tf/ops/Expand.py +207 -0
- onnx2tf/ops/Gather.py +67 -18
- onnx2tf/ops/Mod.py +29 -0
- onnx2tf/ops/Mul.py +30 -0
- onnx2tf/ops/ReduceL1.py +3 -0
- onnx2tf/ops/ReduceL2.py +3 -0
- onnx2tf/ops/ReduceLogSum.py +3 -0
- onnx2tf/ops/ReduceLogSumExp.py +3 -0
- onnx2tf/ops/ReduceMax.py +3 -0
- onnx2tf/ops/ReduceMean.py +3 -0
- onnx2tf/ops/ReduceMin.py +3 -0
- onnx2tf/ops/ReduceProd.py +3 -0
- onnx2tf/ops/ReduceSum.py +3 -0
- onnx2tf/ops/ReduceSumSquare.py +3 -0
- onnx2tf/ops/Shape.py +2 -0
- onnx2tf/ops/Sub.py +29 -0
- onnx2tf/ops/Transpose.py +14 -0
- onnx2tf/utils/common_functions.py +2 -2
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/METADATA +269 -28
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/RECORD +33 -33
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/WHEEL +1 -1
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/LICENSE +0 -0
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/entry_points.txt +0 -0
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/top_level.txt +0 -0
onnx2tf/ops/Gather.py
CHANGED
|
@@ -64,13 +64,79 @@ def make_node(
|
|
|
64
64
|
indices = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
|
|
65
65
|
if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
|
|
66
66
|
|
|
67
|
+
axis = graph_node.attrs.get("axis", 0)
|
|
68
|
+
axis = convert_axis(
|
|
69
|
+
axis=axis,
|
|
70
|
+
tensor_rank=len(input_tensor.shape),
|
|
71
|
+
before_op_output_shape_trans=before_op_output_shape_trans,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
# Param replacement - axis
|
|
75
|
+
axis = replace_parameter(
|
|
76
|
+
value_before_replacement=axis,
|
|
77
|
+
param_target='attributes',
|
|
78
|
+
param_name='axis',
|
|
79
|
+
**kwargs,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
nhwc = tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
83
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
84
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
85
|
+
|
|
67
86
|
before_cast_indices = None
|
|
68
87
|
if isinstance(indices, np.ndarray) and indices.ndim == 1 and len(indices) == 1 and indices[0] is not None:
|
|
69
88
|
if indices[0] >= 0:
|
|
70
89
|
before_cast_indices = indices[0]
|
|
90
|
+
# 直前が Shape だった場合のみの特別なワークアラウンドで、入力がNHWCで確定しているときはindicesを変換する
|
|
91
|
+
# 1. ind=0 のときはそのまま
|
|
92
|
+
# 2. ind=1 のときは末尾
|
|
93
|
+
# 3. ind=2 のときは1
|
|
94
|
+
# 4. ind=3 のときは2
|
|
95
|
+
# 5. ind=4 のときは3
|
|
96
|
+
#
|
|
97
|
+
# ONNX: ind=2
|
|
98
|
+
# 0,1,2 -> 0,2,1
|
|
99
|
+
# 0,1,2,3 -> 0,2,3,1
|
|
100
|
+
# 0,1,2,3,4 -> 0,2,3,4,1
|
|
101
|
+
# 0,1,2,3,4,5 -> 0,2,3,4,5,1
|
|
102
|
+
if nhwc and graph_node.i().op == 'Shape':
|
|
103
|
+
input_tensor_rank = input_tensor.shape[0]
|
|
104
|
+
if before_cast_indices == 0:
|
|
105
|
+
# batch
|
|
106
|
+
pass
|
|
107
|
+
elif before_cast_indices == 1:
|
|
108
|
+
# channel
|
|
109
|
+
before_cast_indices = input_tensor_rank - 1
|
|
110
|
+
else:
|
|
111
|
+
# spatial dim
|
|
112
|
+
before_cast_indices = before_cast_indices - 1
|
|
113
|
+
|
|
71
114
|
elif isinstance(indices, np.ndarray) and indices.ndim == 0 and indices is not None:
|
|
72
115
|
if indices >= 0:
|
|
73
116
|
before_cast_indices = int(indices)
|
|
117
|
+
# 直前が Shape だった場合のみの特別なワークアラウンドで、入力がNHWCで確定しているときはindicesを変換する
|
|
118
|
+
# 1. ind=0 のときはそのまま
|
|
119
|
+
# 2. ind=1 のときは末尾
|
|
120
|
+
# 3. ind=2 のときは1
|
|
121
|
+
# 4. ind=3 のときは2
|
|
122
|
+
# 5. ind=4 のときは3
|
|
123
|
+
#
|
|
124
|
+
# ONNX: ind=2
|
|
125
|
+
# 0,1,2 -> 0,2,1
|
|
126
|
+
# 0,1,2,3 -> 0,2,3,1
|
|
127
|
+
# 0,1,2,3,4 -> 0,2,3,4,1
|
|
128
|
+
# 0,1,2,3,4,5 -> 0,2,3,4,5,1
|
|
129
|
+
if nhwc and graph_node.i().op == 'Shape':
|
|
130
|
+
input_tensor_rank = input_tensor.shape[0]
|
|
131
|
+
if before_cast_indices == 0:
|
|
132
|
+
# batch
|
|
133
|
+
pass
|
|
134
|
+
elif before_cast_indices == 1:
|
|
135
|
+
# channel
|
|
136
|
+
before_cast_indices = input_tensor_rank - 1
|
|
137
|
+
else:
|
|
138
|
+
# spatial dim
|
|
139
|
+
before_cast_indices = before_cast_indices - 1
|
|
74
140
|
|
|
75
141
|
simple_indices = None
|
|
76
142
|
if isinstance(indices, np.ndarray) and indices.ndim == 1 and None not in indices:
|
|
@@ -81,21 +147,6 @@ def make_node(
|
|
|
81
147
|
shape = graph_node_output.shape
|
|
82
148
|
dtype = graph_node_output.dtype
|
|
83
149
|
|
|
84
|
-
axis = graph_node.attrs.get("axis", 0)
|
|
85
|
-
axis = convert_axis(
|
|
86
|
-
axis=axis,
|
|
87
|
-
tensor_rank=len(input_tensor.shape),
|
|
88
|
-
before_op_output_shape_trans=before_op_output_shape_trans,
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
# Param replacement - axis
|
|
92
|
-
axis = replace_parameter(
|
|
93
|
-
value_before_replacement=axis,
|
|
94
|
-
param_target='attributes',
|
|
95
|
-
param_name='axis',
|
|
96
|
-
**kwargs,
|
|
97
|
-
)
|
|
98
|
-
|
|
99
150
|
optimization_for_gpu_delegate: bool = \
|
|
100
151
|
kwargs['optimization_for_gpu_delegate']
|
|
101
152
|
|
|
@@ -141,9 +192,7 @@ def make_node(
|
|
|
141
192
|
'optype': graph_node.op,
|
|
142
193
|
'shape': shape,
|
|
143
194
|
'dtype': dtype,
|
|
144
|
-
'nhwc':
|
|
145
|
-
if isinstance(graph_node_input_1, gs.Variable) \
|
|
146
|
-
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
195
|
+
'nhwc': nhwc,
|
|
147
196
|
}
|
|
148
197
|
|
|
149
198
|
# Param replacement
|
onnx2tf/ops/Mod.py
CHANGED
|
@@ -157,8 +157,37 @@ def make_node(
|
|
|
157
157
|
is_scalar_2_rank = tf.rank(input_tensor_2) == 0
|
|
158
158
|
if hasattr(is_scalar_2_rank, 'numpy'):
|
|
159
159
|
is_scalar_2 = is_scalar_2_rank.numpy()
|
|
160
|
+
|
|
160
161
|
if (is_scalar_1 or is_scalar_2) and graph_node.i().op == 'Gemm':
|
|
161
162
|
pass
|
|
163
|
+
elif (is_scalar_1 or is_scalar_2) and graph_node.i().op != 'Gemm':
|
|
164
|
+
first_tensor = None
|
|
165
|
+
second_tensor = None
|
|
166
|
+
if is_scalar_1:
|
|
167
|
+
first_tensor = input_tensor_2
|
|
168
|
+
second_tensor = input_tensor_1
|
|
169
|
+
elif is_scalar_2:
|
|
170
|
+
first_tensor = input_tensor_1
|
|
171
|
+
second_tensor = input_tensor_2
|
|
172
|
+
tmp_result = tf.math.mod(first_tensor, second_tensor)
|
|
173
|
+
tmp_result_shape = tmp_result.shape
|
|
174
|
+
if first_tensor.shape == tmp_result_shape:
|
|
175
|
+
pass
|
|
176
|
+
else:
|
|
177
|
+
input_tensor_1, input_tensor_2 = \
|
|
178
|
+
pre_explicit_broadcast(
|
|
179
|
+
input_tensor_1=input_tensor_1,
|
|
180
|
+
input_tensor_2=input_tensor_2,
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
input_tensor_1, input_tensor_2 = \
|
|
184
|
+
explicit_broadcast(
|
|
185
|
+
const_or_var_1=input_tensor_1,
|
|
186
|
+
const_or_var_2=input_tensor_2,
|
|
187
|
+
graph_node=graph_node,
|
|
188
|
+
tf_layers_dict= tf_layers_dict,
|
|
189
|
+
)
|
|
190
|
+
|
|
162
191
|
else:
|
|
163
192
|
input_tensor_1, input_tensor_2 = \
|
|
164
193
|
pre_explicit_broadcast(
|
onnx2tf/ops/Mul.py
CHANGED
|
@@ -152,14 +152,44 @@ def make_node(
|
|
|
152
152
|
try:
|
|
153
153
|
is_scalar_1 = False
|
|
154
154
|
is_scalar_2 = False
|
|
155
|
+
is_partial_scalar = False
|
|
155
156
|
is_scalar_1_rank = tf.rank(input_tensor_1) == 0
|
|
156
157
|
if hasattr(is_scalar_1_rank, 'numpy'):
|
|
157
158
|
is_scalar_1 = is_scalar_1_rank.numpy()
|
|
158
159
|
is_scalar_2_rank = tf.rank(input_tensor_2) == 0
|
|
159
160
|
if hasattr(is_scalar_2_rank, 'numpy'):
|
|
160
161
|
is_scalar_2 = is_scalar_2_rank.numpy()
|
|
162
|
+
|
|
161
163
|
if (is_scalar_1 or is_scalar_2) and graph_node.i().op == 'Gemm':
|
|
162
164
|
pass
|
|
165
|
+
elif (is_scalar_1 or is_scalar_2) and graph_node.i().op != 'Gemm':
|
|
166
|
+
first_tensor = None
|
|
167
|
+
second_tensor = None
|
|
168
|
+
if is_scalar_1:
|
|
169
|
+
first_tensor = input_tensor_2
|
|
170
|
+
second_tensor = input_tensor_1
|
|
171
|
+
elif is_scalar_2:
|
|
172
|
+
first_tensor = input_tensor_1
|
|
173
|
+
second_tensor = input_tensor_2
|
|
174
|
+
tmp_result = tf.math.multiply(first_tensor, second_tensor)
|
|
175
|
+
tmp_result_shape = tmp_result.shape
|
|
176
|
+
if first_tensor.shape == tmp_result_shape:
|
|
177
|
+
pass
|
|
178
|
+
else:
|
|
179
|
+
input_tensor_1, input_tensor_2 = \
|
|
180
|
+
pre_explicit_broadcast(
|
|
181
|
+
input_tensor_1=input_tensor_1,
|
|
182
|
+
input_tensor_2=input_tensor_2,
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
input_tensor_1, input_tensor_2 = \
|
|
186
|
+
explicit_broadcast(
|
|
187
|
+
const_or_var_1=input_tensor_1,
|
|
188
|
+
const_or_var_2=input_tensor_2,
|
|
189
|
+
graph_node=graph_node,
|
|
190
|
+
tf_layers_dict= tf_layers_dict,
|
|
191
|
+
)
|
|
192
|
+
|
|
163
193
|
else:
|
|
164
194
|
input_tensor_1, input_tensor_2 = \
|
|
165
195
|
pre_explicit_broadcast(
|
onnx2tf/ops/ReduceL1.py
CHANGED
|
@@ -106,6 +106,9 @@ def make_node(
|
|
|
106
106
|
'optype': graph_node.op,
|
|
107
107
|
'shape': onnx_output_shape,
|
|
108
108
|
'dtype': dtype,
|
|
109
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
110
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
111
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
109
112
|
}
|
|
110
113
|
|
|
111
114
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceL2.py
CHANGED
|
@@ -106,6 +106,9 @@ def make_node(
|
|
|
106
106
|
'optype': graph_node.op,
|
|
107
107
|
'shape': onnx_output_shape,
|
|
108
108
|
'dtype': dtype,
|
|
109
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
110
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
111
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
109
112
|
}
|
|
110
113
|
|
|
111
114
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceLogSum.py
CHANGED
|
@@ -106,6 +106,9 @@ def make_node(
|
|
|
106
106
|
'optype': graph_node.op,
|
|
107
107
|
'shape': onnx_output_shape,
|
|
108
108
|
'dtype': dtype,
|
|
109
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
110
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
111
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
109
112
|
}
|
|
110
113
|
|
|
111
114
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceLogSumExp.py
CHANGED
|
@@ -106,6 +106,9 @@ def make_node(
|
|
|
106
106
|
'optype': graph_node.op,
|
|
107
107
|
'shape': onnx_output_shape,
|
|
108
108
|
'dtype': dtype,
|
|
109
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
110
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
111
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
109
112
|
}
|
|
110
113
|
|
|
111
114
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceMax.py
CHANGED
|
@@ -113,6 +113,9 @@ def make_node(
|
|
|
113
113
|
'optype': graph_node.op,
|
|
114
114
|
'shape': onnx_output_shape,
|
|
115
115
|
'dtype': dtype,
|
|
116
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
117
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
118
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
116
119
|
}
|
|
117
120
|
|
|
118
121
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceMean.py
CHANGED
|
@@ -107,6 +107,9 @@ def make_node(
|
|
|
107
107
|
'optype': graph_node.op,
|
|
108
108
|
'shape': onnx_output_shape,
|
|
109
109
|
'dtype': dtype,
|
|
110
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
111
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
112
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
110
113
|
}
|
|
111
114
|
|
|
112
115
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceMin.py
CHANGED
|
@@ -107,6 +107,9 @@ def make_node(
|
|
|
107
107
|
'optype': graph_node.op,
|
|
108
108
|
'shape': onnx_output_shape,
|
|
109
109
|
'dtype': dtype,
|
|
110
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
111
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
112
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
110
113
|
}
|
|
111
114
|
|
|
112
115
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceProd.py
CHANGED
|
@@ -107,6 +107,9 @@ def make_node(
|
|
|
107
107
|
'optype': graph_node.op,
|
|
108
108
|
'shape': onnx_output_shape,
|
|
109
109
|
'dtype': dtype,
|
|
110
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
111
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
112
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
110
113
|
}
|
|
111
114
|
|
|
112
115
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceSum.py
CHANGED
|
@@ -106,6 +106,9 @@ def make_node(
|
|
|
106
106
|
'optype': graph_node.op,
|
|
107
107
|
'shape': onnx_output_shape,
|
|
108
108
|
'dtype': dtype,
|
|
109
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
110
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
111
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
109
112
|
}
|
|
110
113
|
|
|
111
114
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/ReduceSumSquare.py
CHANGED
|
@@ -107,6 +107,9 @@ def make_node(
|
|
|
107
107
|
'optype': graph_node.op,
|
|
108
108
|
'shape': onnx_output_shape,
|
|
109
109
|
'dtype': dtype,
|
|
110
|
+
'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
111
|
+
if isinstance(graph_node_input_1, gs.Variable) \
|
|
112
|
+
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
110
113
|
}
|
|
111
114
|
|
|
112
115
|
onnx_tensor_infos_for_validation: Dict[str:np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
onnx2tf/ops/Shape.py
CHANGED
onnx2tf/ops/Sub.py
CHANGED
|
@@ -156,8 +156,37 @@ def make_node(
|
|
|
156
156
|
is_scalar_2_rank = tf.rank(input_tensor_2) == 0
|
|
157
157
|
if hasattr(is_scalar_2_rank, 'numpy'):
|
|
158
158
|
is_scalar_2 = is_scalar_2_rank.numpy()
|
|
159
|
+
|
|
159
160
|
if (is_scalar_1 or is_scalar_2) and graph_node.i().op == 'Gemm':
|
|
160
161
|
pass
|
|
162
|
+
elif (is_scalar_1 or is_scalar_2) and graph_node.i().op != 'Gemm':
|
|
163
|
+
first_tensor = None
|
|
164
|
+
second_tensor = None
|
|
165
|
+
if is_scalar_1:
|
|
166
|
+
first_tensor = input_tensor_2
|
|
167
|
+
second_tensor = input_tensor_1
|
|
168
|
+
elif is_scalar_2:
|
|
169
|
+
first_tensor = input_tensor_1
|
|
170
|
+
second_tensor = input_tensor_2
|
|
171
|
+
tmp_result = tf.math.subtract(first_tensor, second_tensor)
|
|
172
|
+
tmp_result_shape = tmp_result.shape
|
|
173
|
+
if first_tensor.shape == tmp_result_shape:
|
|
174
|
+
pass
|
|
175
|
+
else:
|
|
176
|
+
input_tensor_1, input_tensor_2 = \
|
|
177
|
+
pre_explicit_broadcast(
|
|
178
|
+
input_tensor_1=input_tensor_1,
|
|
179
|
+
input_tensor_2=input_tensor_2,
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
input_tensor_1, input_tensor_2 = \
|
|
183
|
+
explicit_broadcast(
|
|
184
|
+
const_or_var_1=input_tensor_1,
|
|
185
|
+
const_or_var_2=input_tensor_2,
|
|
186
|
+
graph_node=graph_node,
|
|
187
|
+
tf_layers_dict= tf_layers_dict,
|
|
188
|
+
)
|
|
189
|
+
|
|
161
190
|
else:
|
|
162
191
|
input_tensor_1, input_tensor_2 = \
|
|
163
192
|
pre_explicit_broadcast(
|
onnx2tf/ops/Transpose.py
CHANGED
|
@@ -160,10 +160,24 @@ def make_node(
|
|
|
160
160
|
perm = new_perm
|
|
161
161
|
|
|
162
162
|
# Preserving Graph Structure (Dict)
|
|
163
|
+
nwhc = False
|
|
164
|
+
if nwc_nhwc_ndhwc_keep:
|
|
165
|
+
nhwc = True
|
|
166
|
+
elif isinstance(graph_node_input, gs.Variable) \
|
|
167
|
+
and 'nhwc' in tf_layers_dict[graph_node_input.name].keys():
|
|
168
|
+
nhwc = tf_layers_dict[graph_node_input.name]['nhwc']
|
|
169
|
+
if nhwc and not before_op_output_shape_trans and perm == [i for i in range(len(input_tensor_shape))]:
|
|
170
|
+
nhwc = True
|
|
171
|
+
else:
|
|
172
|
+
nhwc = False
|
|
173
|
+
else:
|
|
174
|
+
nhwc = False
|
|
175
|
+
|
|
163
176
|
tf_layers_dict[graph_node_output.name] = {
|
|
164
177
|
'optype': graph_node.op,
|
|
165
178
|
'shape': output_shape,
|
|
166
179
|
'dtype': dtype,
|
|
180
|
+
'nhwc': nhwc,
|
|
167
181
|
'nwc_nhwc_ndhwc_keep': nwc_nhwc_ndhwc_keep,
|
|
168
182
|
}
|
|
169
183
|
|
|
@@ -2185,7 +2185,7 @@ def process_neg_idx(
|
|
|
2185
2185
|
and not isinstance(indices_shape[-1], np.ndarray) \
|
|
2186
2186
|
and not isinstance(indices_shape[-1], tf.Tensor) \
|
|
2187
2187
|
and tf_keras.backend.is_keras_tensor(indices_shape[-1]):
|
|
2188
|
-
if
|
|
2188
|
+
if tf.TensorShape(None) not in data_shape :
|
|
2189
2189
|
max_i = tf.cast(
|
|
2190
2190
|
tf.strided_slice(
|
|
2191
2191
|
input_=data_shape,
|
|
@@ -4584,7 +4584,7 @@ def rewrite_tflite_inout_opname(
|
|
|
4584
4584
|
result = subprocess.check_output(
|
|
4585
4585
|
[
|
|
4586
4586
|
'curl',
|
|
4587
|
-
'https://raw.githubusercontent.com/tensorflow/tensorflow/v2.
|
|
4587
|
+
'https://raw.githubusercontent.com/tensorflow/tensorflow/v2.17.0-rc1/tensorflow/compiler/mlir/lite/schema/schema.fbs',
|
|
4588
4588
|
'-o',
|
|
4589
4589
|
f'{output_folder_path}/schema.fbs'
|
|
4590
4590
|
],
|