onnx2tf 1.27.10__py3-none-any.whl → 1.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +398 -19
- onnx2tf/utils/common_functions.py +3 -2
- onnx2tf/utils/iterative_json_optimizer.py +258 -0
- onnx2tf/utils/json_auto_generator.py +1505 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.1.dist-info}/METADATA +40 -4
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.1.dist-info}/RECORD +12 -10
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.1.dist-info}/WHEEL +1 -1
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.1.dist-info}/entry_points.txt +0 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.1.dist-info}/licenses/LICENSE +0 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.1.dist-info}/licenses/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.1.dist-info}/top_level.txt +0 -0
onnx2tf/__init__.py
CHANGED
onnx2tf/onnx2tf.py
CHANGED
|
@@ -52,6 +52,10 @@ from onnx2tf.utils.common_functions import (
|
|
|
52
52
|
rewrite_tflite_inout_opname,
|
|
53
53
|
check_cuda_enabled,
|
|
54
54
|
)
|
|
55
|
+
from onnx2tf.utils.json_auto_generator import (
|
|
56
|
+
generate_auto_replacement_json,
|
|
57
|
+
save_auto_replacement_json,
|
|
58
|
+
)
|
|
55
59
|
from onnx2tf.utils.enums import (
|
|
56
60
|
CUDA_ONLY_OPS,
|
|
57
61
|
)
|
|
@@ -104,6 +108,7 @@ def convert(
|
|
|
104
108
|
fused_argmax_scale_ratio: Optional[float] = 0.5,
|
|
105
109
|
replace_to_pseudo_operators: List[str] = None,
|
|
106
110
|
param_replacement_file: Optional[str] = '',
|
|
111
|
+
auto_generate_json: Optional[bool] = False,
|
|
107
112
|
check_gpu_delegate_compatibility: Optional[bool] = False,
|
|
108
113
|
check_onnx_tf_outputs_elementwise_close: Optional[bool] = False,
|
|
109
114
|
check_onnx_tf_outputs_elementwise_close_full: Optional[bool] = False,
|
|
@@ -429,6 +434,17 @@ def convert(
|
|
|
429
434
|
param_replacement_file: Optional[str]
|
|
430
435
|
Parameter replacement file path. (.json)
|
|
431
436
|
|
|
437
|
+
auto_generate_json: Optional[bool]
|
|
438
|
+
Automatically generates a parameter replacement JSON file that achieves minimal error\n
|
|
439
|
+
when converting the model. This option explores various parameter combinations to find\n
|
|
440
|
+
the best settings that result in successful conversion and highest accuracy.\n
|
|
441
|
+
The search stops when the final output OP accuracy check shows "Matches".\n
|
|
442
|
+
When used together with check_onnx_tf_outputs_elementwise_close_full,\n
|
|
443
|
+
the generated JSON is used to re-evaluate accuracy.\n
|
|
444
|
+
WARNING: This option performs an exhaustive search to find the optimal conversion patterns,\n
|
|
445
|
+
which can take a very long time depending on the model complexity.\n
|
|
446
|
+
Default: False
|
|
447
|
+
|
|
432
448
|
check_gpu_delegate_compatibility: Optional[bool]
|
|
433
449
|
Run TFLite ModelAnalyzer on the generated Float16 tflite model\n
|
|
434
450
|
to check if the model can be supported by GPU Delegate.
|
|
@@ -1149,27 +1165,111 @@ def convert(
|
|
|
1149
1165
|
|
|
1150
1166
|
# Nodes
|
|
1151
1167
|
# https://github.com/onnx/onnx/blob/main/docs/Operators.md
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1168
|
+
conversion_error = None
|
|
1169
|
+
try:
|
|
1170
|
+
for graph_node in graph.nodes:
|
|
1171
|
+
optype = graph_node.op
|
|
1172
|
+
try:
|
|
1173
|
+
op = importlib.import_module(f'onnx2tf.ops.{optype}')
|
|
1174
|
+
except ModuleNotFoundError as ex:
|
|
1175
|
+
error(
|
|
1176
|
+
f'{optype} OP is not yet implemented.'
|
|
1177
|
+
)
|
|
1178
|
+
# Store error for potential auto JSON generation
|
|
1179
|
+
conversion_error = ex
|
|
1180
|
+
raise ex
|
|
1161
1181
|
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1182
|
+
# Substitution because saved_model does not allow colons in op names
|
|
1183
|
+
# Substitution because saved_model does not allow leading slashes in op names
|
|
1184
|
+
sanitizing(graph_node)
|
|
1165
1185
|
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1186
|
+
op.make_node(
|
|
1187
|
+
graph_node=graph_node,
|
|
1188
|
+
tf_layers_dict=tf_layers_dict,
|
|
1189
|
+
**additional_parameters,
|
|
1190
|
+
)
|
|
1191
|
+
op_counta += 1
|
|
1192
|
+
additional_parameters['op_counta'] = op_counta
|
|
1193
|
+
except Exception as ex:
|
|
1194
|
+
conversion_error = ex
|
|
1195
|
+
# Store the current node name in the error context
|
|
1196
|
+
if hasattr(ex, 'onnx_op_name'):
|
|
1197
|
+
error_onnx_op_name = ex.onnx_op_name
|
|
1198
|
+
else:
|
|
1199
|
+
# Get the current node being processed
|
|
1200
|
+
error_onnx_op_name = graph_node.name if 'graph_node' in locals() else None
|
|
1201
|
+
# Attach it to the exception for later use
|
|
1202
|
+
ex.onnx_op_name = error_onnx_op_name
|
|
1203
|
+
|
|
1204
|
+
# If no replacement file was provided, try to generate one automatically
|
|
1205
|
+
if not param_replacement_file and input_onnx_file_path:
|
|
1206
|
+
info('')
|
|
1207
|
+
info(Color.REVERSE(f'Attempting automatic JSON generation due to conversion error'), '=' * 30)
|
|
1208
|
+
if error_onnx_op_name:
|
|
1209
|
+
info(f'Error occurred at ONNX operation: {error_onnx_op_name}')
|
|
1210
|
+
|
|
1211
|
+
# Try iterative JSON generation with multiple attempts
|
|
1212
|
+
max_attempts = 3
|
|
1213
|
+
attempt = 0
|
|
1214
|
+
successful_conversion = False
|
|
1215
|
+
best_json = None
|
|
1216
|
+
|
|
1217
|
+
while attempt < max_attempts and not successful_conversion:
|
|
1218
|
+
attempt += 1
|
|
1219
|
+
info(f'\nJSON generation attempt {attempt}/{max_attempts}')
|
|
1220
|
+
|
|
1221
|
+
try:
|
|
1222
|
+
# Generate JSON with unlimited mode for exhaustive search
|
|
1223
|
+
auto_json = generate_auto_replacement_json(
|
|
1224
|
+
onnx_graph=graph,
|
|
1225
|
+
tf_layers_dict=tf_layers_dict,
|
|
1226
|
+
check_results=None,
|
|
1227
|
+
conversion_error=conversion_error,
|
|
1228
|
+
error_threshold=1e-2,
|
|
1229
|
+
model_path=input_onnx_file_path,
|
|
1230
|
+
max_iterations=attempt * 3, # Increase iterations with each attempt
|
|
1231
|
+
unlimited_mode=True, # Enable unlimited mode
|
|
1232
|
+
)
|
|
1233
|
+
|
|
1234
|
+
if auto_json.get('operations'):
|
|
1235
|
+
best_json = auto_json
|
|
1236
|
+
|
|
1237
|
+
# Save temporary JSON
|
|
1238
|
+
temp_json_path = os.path.join(output_folder_path, f'_temp_attempt_{attempt}.json')
|
|
1239
|
+
with open(temp_json_path, 'w') as f:
|
|
1240
|
+
json.dump(auto_json, f, indent=2)
|
|
1241
|
+
|
|
1242
|
+
info(f'Testing generated JSON with {len(auto_json["operations"])} operations...')
|
|
1243
|
+
|
|
1244
|
+
# Try to re-run just the problematic operation with the JSON
|
|
1245
|
+
# This is a simplified test - in practice we'd need to re-run the full conversion
|
|
1246
|
+
# For now, we'll assume the JSON might work and save it
|
|
1247
|
+
|
|
1248
|
+
# Clean up temp file
|
|
1249
|
+
if os.path.exists(temp_json_path):
|
|
1250
|
+
os.remove(temp_json_path)
|
|
1251
|
+
|
|
1252
|
+
except Exception as json_ex:
|
|
1253
|
+
error(f"Error in attempt {attempt}: {type(json_ex).__name__}: {str(json_ex)}")
|
|
1254
|
+
|
|
1255
|
+
# Save the best JSON we generated
|
|
1256
|
+
if best_json and best_json.get('operations'):
|
|
1257
|
+
json_path = save_auto_replacement_json(
|
|
1258
|
+
replacement_json=best_json,
|
|
1259
|
+
model_path=input_onnx_file_path,
|
|
1260
|
+
output_dir=output_folder_path,
|
|
1261
|
+
)
|
|
1262
|
+
warn(
|
|
1263
|
+
f'Conversion failed. An automatic replacement JSON has been generated: {json_path}\n' +
|
|
1264
|
+
f'Please try running the conversion again with: -prf {json_path}\n' +
|
|
1265
|
+
f'Note: The JSON was generated through {attempt} iteration(s) to find the best solution.'
|
|
1266
|
+
)
|
|
1267
|
+
else:
|
|
1268
|
+
warn(
|
|
1269
|
+
f'Conversion failed and automatic JSON generation could not find a solution after {attempt} attempts.'
|
|
1270
|
+
)
|
|
1271
|
+
# Re-raise the original error
|
|
1272
|
+
raise ex
|
|
1173
1273
|
|
|
1174
1274
|
del additional_parameters['onnx_tensor_infos_for_validation']
|
|
1175
1275
|
del onnx_tensor_infos_for_validation
|
|
@@ -1965,6 +2065,50 @@ def convert(
|
|
|
1965
2065
|
rtol=check_onnx_tf_outputs_elementwise_close_rtol,
|
|
1966
2066
|
atol=check_onnx_tf_outputs_elementwise_close_atol,
|
|
1967
2067
|
)
|
|
2068
|
+
|
|
2069
|
+
# Check if any errors exceed threshold and auto-generate JSON if needed
|
|
2070
|
+
# Skip this if -agj is specified (will be handled separately)
|
|
2071
|
+
if not param_replacement_file and input_onnx_file_path and not auto_generate_json:
|
|
2072
|
+
max_error_found = 0.0
|
|
2073
|
+
has_significant_errors = False
|
|
2074
|
+
error_count = 0
|
|
2075
|
+
for (onnx_name, tf_name), checked_value in check_results.items():
|
|
2076
|
+
matched_flg = checked_value[1]
|
|
2077
|
+
max_abs_err = checked_value[2]
|
|
2078
|
+
if (matched_flg == 0 or matched_flg == False) and isinstance(max_abs_err, (int, float, np.float32, np.float64)):
|
|
2079
|
+
if max_abs_err > 1e-2:
|
|
2080
|
+
has_significant_errors = True
|
|
2081
|
+
error_count += 1
|
|
2082
|
+
max_error_found = max(max_error_found, max_abs_err)
|
|
2083
|
+
|
|
2084
|
+
if has_significant_errors:
|
|
2085
|
+
info('')
|
|
2086
|
+
info(Color.REVERSE(f'Attempting automatic JSON generation due to accuracy errors > 1e-2'), '=' * 25)
|
|
2087
|
+
info(f'Found {error_count} operations with errors > 1e-2')
|
|
2088
|
+
info(f'Maximum error found: {max_error_found:.6f}')
|
|
2089
|
+
auto_json = generate_auto_replacement_json(
|
|
2090
|
+
onnx_graph=gs.import_onnx(onnx_graph),
|
|
2091
|
+
tf_layers_dict=tf_layers_dict,
|
|
2092
|
+
check_results=check_results,
|
|
2093
|
+
conversion_error=None,
|
|
2094
|
+
error_threshold=1e-2,
|
|
2095
|
+
model_path=input_onnx_file_path,
|
|
2096
|
+
)
|
|
2097
|
+
if auto_json.get('operations'):
|
|
2098
|
+
json_path = save_auto_replacement_json(
|
|
2099
|
+
replacement_json=auto_json,
|
|
2100
|
+
model_path=input_onnx_file_path,
|
|
2101
|
+
output_dir=output_folder_path,
|
|
2102
|
+
)
|
|
2103
|
+
warn(
|
|
2104
|
+
f'Accuracy validation found errors > 1e-2. An automatic replacement JSON has been generated: {json_path}\n' +
|
|
2105
|
+
f'Please try running the conversion again with: -prf {json_path}'
|
|
2106
|
+
)
|
|
2107
|
+
else:
|
|
2108
|
+
warn(
|
|
2109
|
+
f'Accuracy errors > 1e-2 found but automatic JSON generation could not find a solution.'
|
|
2110
|
+
)
|
|
2111
|
+
|
|
1968
2112
|
for (onnx_output_name, tf_output_name), checked_value in check_results.items():
|
|
1969
2113
|
validated_onnx_tensor: np.ndarray = checked_value[0]
|
|
1970
2114
|
matched_flg: int = checked_value[1]
|
|
@@ -1994,6 +2138,229 @@ def convert(
|
|
|
1994
2138
|
f'{message}'
|
|
1995
2139
|
)
|
|
1996
2140
|
|
|
2141
|
+
# Auto-generate JSON if -agj option is specified
|
|
2142
|
+
# This can work alone or in combination with -cotof
|
|
2143
|
+
if auto_generate_json:
|
|
2144
|
+
# Store the generated JSON path for later use
|
|
2145
|
+
generated_json_path = None
|
|
2146
|
+
|
|
2147
|
+
# Check if -cotof was already executed and we have check_results
|
|
2148
|
+
if check_onnx_tf_outputs_elementwise_close_full and 'check_results' in locals():
|
|
2149
|
+
# We already have validation results from -cotof
|
|
2150
|
+
info('')
|
|
2151
|
+
info(Color.REVERSE(f'Auto JSON generation started (using -cotof results)'), '=' * 35)
|
|
2152
|
+
|
|
2153
|
+
# Check if any errors exist
|
|
2154
|
+
all_matched = True
|
|
2155
|
+
max_error = 0.0
|
|
2156
|
+
error_count = 0
|
|
2157
|
+
|
|
2158
|
+
for (onnx_name, tf_name), checked_value in check_results.items():
|
|
2159
|
+
matched_flg = checked_value[1]
|
|
2160
|
+
max_abs_err = checked_value[2]
|
|
2161
|
+
|
|
2162
|
+
if matched_flg == 0: # Unmatched
|
|
2163
|
+
all_matched = False
|
|
2164
|
+
if isinstance(max_abs_err, (int, float, np.float32, np.float64)):
|
|
2165
|
+
max_error = max(max_error, max_abs_err)
|
|
2166
|
+
error_count += 1
|
|
2167
|
+
|
|
2168
|
+
if all_matched:
|
|
2169
|
+
info(Color.GREEN('All outputs already match! No JSON generation needed.'))
|
|
2170
|
+
else:
|
|
2171
|
+
info(f'Found {error_count} outputs with errors, max error: {max_error:.6f}')
|
|
2172
|
+
info('Generating optimal JSON...')
|
|
2173
|
+
|
|
2174
|
+
# Generate auto replacement JSON
|
|
2175
|
+
auto_json = generate_auto_replacement_json(
|
|
2176
|
+
onnx_graph=gs.import_onnx(onnx_graph),
|
|
2177
|
+
tf_layers_dict=tf_layers_dict,
|
|
2178
|
+
check_results=check_results,
|
|
2179
|
+
conversion_error=None,
|
|
2180
|
+
error_threshold=check_onnx_tf_outputs_elementwise_close_atol,
|
|
2181
|
+
model_path=input_onnx_file_path,
|
|
2182
|
+
max_iterations=5,
|
|
2183
|
+
target_accuracy=check_onnx_tf_outputs_elementwise_close_atol,
|
|
2184
|
+
unlimited_mode=True,
|
|
2185
|
+
)
|
|
2186
|
+
|
|
2187
|
+
if auto_json.get('operations'):
|
|
2188
|
+
# Save the JSON
|
|
2189
|
+
generated_json_path = save_auto_replacement_json(
|
|
2190
|
+
replacement_json=auto_json,
|
|
2191
|
+
model_path=input_onnx_file_path,
|
|
2192
|
+
output_dir=output_folder_path,
|
|
2193
|
+
)
|
|
2194
|
+
info(f'Generated JSON with {len(auto_json["operations"])} operations: {generated_json_path}')
|
|
2195
|
+
|
|
2196
|
+
# If both -cotof and -agj are specified, validation should be re-run with the generated JSON (not yet implemented; see the TODO below)
|
|
2197
|
+
info('')
|
|
2198
|
+
info(Color.REVERSE(f'Re-running validation with auto-generated JSON'), '=' * 35)
|
|
2199
|
+
|
|
2200
|
+
# TODO: In a full implementation, we would need to:
|
|
2201
|
+
# 1. Re-run the entire conversion with the generated JSON
|
|
2202
|
+
# 2. Re-validate the outputs
|
|
2203
|
+
# 3. Display the new validation results
|
|
2204
|
+
# For now, we just inform the user
|
|
2205
|
+
|
|
2206
|
+
info(Color.GREEN(f'\nAuto-generated JSON saved to: {generated_json_path}'))
|
|
2207
|
+
info(
|
|
2208
|
+
f'To see the validation results with the generated JSON, please re-run with:\n' +
|
|
2209
|
+
f' -prf {generated_json_path} -cotof'
|
|
2210
|
+
)
|
|
2211
|
+
else:
|
|
2212
|
+
warn('No viable parameter replacements found.')
|
|
2213
|
+
|
|
2214
|
+
else:
|
|
2215
|
+
# -agj is specified but -cotof is not, so we need to run our own validation
|
|
2216
|
+
try:
|
|
2217
|
+
import onnxruntime
|
|
2218
|
+
import sne4onnx
|
|
2219
|
+
except Exception as ex:
|
|
2220
|
+
error(
|
|
2221
|
+
f'If --auto_generate_json is specified, ' +
|
|
2222
|
+
f'you must install onnxruntime and sne4onnx. pip install sne4onnx onnxruntime'
|
|
2223
|
+
)
|
|
2224
|
+
sys.exit(1)
|
|
2225
|
+
|
|
2226
|
+
info('')
|
|
2227
|
+
info(Color.REVERSE(f'Auto JSON generation started'), '=' * 50)
|
|
2228
|
+
info(
|
|
2229
|
+
'Searching for optimal parameter replacement JSON to achieve minimum error...'
|
|
2230
|
+
)
|
|
2231
|
+
|
|
2232
|
+
# Run validation for final outputs only
|
|
2233
|
+
ops_output_names = output_names
|
|
2234
|
+
|
|
2235
|
+
# Rebuild model for validation
|
|
2236
|
+
outputs = [
|
|
2237
|
+
layer_info['tf_node'] \
|
|
2238
|
+
for opname, layer_info in tf_layers_dict.items() \
|
|
2239
|
+
if opname in ops_output_names \
|
|
2240
|
+
and not hasattr(layer_info['tf_node'], 'numpy')
|
|
2241
|
+
]
|
|
2242
|
+
exclude_output_names = [
|
|
2243
|
+
opname \
|
|
2244
|
+
for opname, layer_info in tf_layers_dict.items() \
|
|
2245
|
+
if opname in ops_output_names \
|
|
2246
|
+
and hasattr(layer_info['tf_node'], 'numpy')
|
|
2247
|
+
]
|
|
2248
|
+
validation_model = tf_keras.Model(inputs=inputs, outputs=outputs)
|
|
2249
|
+
|
|
2250
|
+
# Exclude output OPs not subject to validation
|
|
2251
|
+
ops_output_names = [
|
|
2252
|
+
ops_output_name for ops_output_name in ops_output_names \
|
|
2253
|
+
if ops_output_name not in exclude_output_names
|
|
2254
|
+
]
|
|
2255
|
+
|
|
2256
|
+
# Initial accuracy check
|
|
2257
|
+
try:
|
|
2258
|
+
# ONNX dummy inference
|
|
2259
|
+
dummy_onnx_outputs: List[np.ndarray] = \
|
|
2260
|
+
dummy_onnx_inference(
|
|
2261
|
+
onnx_graph=onnx_graph,
|
|
2262
|
+
output_names=ops_output_names,
|
|
2263
|
+
test_data_nhwc=test_data_nhwc,
|
|
2264
|
+
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
2265
|
+
tf_layers_dict=tf_layers_dict,
|
|
2266
|
+
use_cuda=use_cuda,
|
|
2267
|
+
shape_hints=shape_hints,
|
|
2268
|
+
)
|
|
2269
|
+
|
|
2270
|
+
# TF dummy inference
|
|
2271
|
+
tf_tensor_infos: Dict[Any] = \
|
|
2272
|
+
dummy_tf_inference(
|
|
2273
|
+
model=validation_model,
|
|
2274
|
+
inputs=inputs,
|
|
2275
|
+
test_data_nhwc=test_data_nhwc,
|
|
2276
|
+
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
2277
|
+
shape_hints=shape_hints,
|
|
2278
|
+
keep_shape_absolutely_input_names=keep_shape_absolutely_input_names,
|
|
2279
|
+
keep_ncw_or_nchw_or_ncdhw_input_names=keep_ncw_or_nchw_or_ncdhw_input_names,
|
|
2280
|
+
keep_nwc_or_nhwc_or_ndhwc_input_names=keep_nwc_or_nhwc_or_ndhwc_input_names,
|
|
2281
|
+
)
|
|
2282
|
+
|
|
2283
|
+
# Validation
|
|
2284
|
+
onnx_tensor_infos = {
|
|
2285
|
+
output_name: dummy_onnx_output \
|
|
2286
|
+
for output_name, dummy_onnx_output in zip(ops_output_names, dummy_onnx_outputs)
|
|
2287
|
+
}
|
|
2288
|
+
|
|
2289
|
+
input_names = [k.name for k in inputs]
|
|
2290
|
+
for k, v in tf_layers_dict.items():
|
|
2291
|
+
if 'tf_node_info' in v:
|
|
2292
|
+
if v['tf_node_info']['tf_op_type'] == 'identity':
|
|
2293
|
+
tf_tensor_infos[v['tf_node'].name] = np.ndarray([0], dtype=np.int64)
|
|
2294
|
+
onnx_tf_output_pairs = {
|
|
2295
|
+
(k, v['tf_node'].name): (onnx_tensor_infos[k], tf_tensor_infos[v['tf_node'].name])
|
|
2296
|
+
for k, v in tf_layers_dict.items() \
|
|
2297
|
+
if k not in input_names and not hasattr(v['tf_node'], 'numpy') and k in onnx_tensor_infos
|
|
2298
|
+
}
|
|
2299
|
+
|
|
2300
|
+
agj_check_results = onnx_tf_tensor_validation(
|
|
2301
|
+
output_pairs=onnx_tf_output_pairs,
|
|
2302
|
+
rtol=0.0,
|
|
2303
|
+
atol=1e-4,
|
|
2304
|
+
)
|
|
2305
|
+
|
|
2306
|
+
# Check if all outputs match
|
|
2307
|
+
all_matched = True
|
|
2308
|
+
max_error = 0.0
|
|
2309
|
+
error_count = 0
|
|
2310
|
+
|
|
2311
|
+
for (onnx_name, tf_name), checked_value in agj_check_results.items():
|
|
2312
|
+
matched_flg = checked_value[1]
|
|
2313
|
+
max_abs_err = checked_value[2]
|
|
2314
|
+
|
|
2315
|
+
if matched_flg == 0: # Unmatched
|
|
2316
|
+
all_matched = False
|
|
2317
|
+
if isinstance(max_abs_err, (int, float, np.float32, np.float64)):
|
|
2318
|
+
max_error = max(max_error, max_abs_err)
|
|
2319
|
+
error_count += 1
|
|
2320
|
+
|
|
2321
|
+
if all_matched:
|
|
2322
|
+
info(Color.GREEN('All outputs already match! No JSON generation needed.'))
|
|
2323
|
+
else:
|
|
2324
|
+
info(f'Initial validation: {error_count} outputs have errors, max error: {max_error:.6f}')
|
|
2325
|
+
info('Generating optimal JSON...')
|
|
2326
|
+
|
|
2327
|
+
# Generate auto replacement JSON
|
|
2328
|
+
auto_json = generate_auto_replacement_json(
|
|
2329
|
+
onnx_graph=gs.import_onnx(onnx_graph),
|
|
2330
|
+
tf_layers_dict=tf_layers_dict,
|
|
2331
|
+
check_results=agj_check_results,
|
|
2332
|
+
conversion_error=None,
|
|
2333
|
+
error_threshold=1e-4,
|
|
2334
|
+
model_path=input_onnx_file_path,
|
|
2335
|
+
max_iterations=5,
|
|
2336
|
+
target_accuracy=1e-4,
|
|
2337
|
+
unlimited_mode=True,
|
|
2338
|
+
)
|
|
2339
|
+
|
|
2340
|
+
if auto_json.get('operations'):
|
|
2341
|
+
# Save the JSON
|
|
2342
|
+
generated_json_path = save_auto_replacement_json(
|
|
2343
|
+
replacement_json=auto_json,
|
|
2344
|
+
model_path=input_onnx_file_path,
|
|
2345
|
+
output_dir=output_folder_path,
|
|
2346
|
+
)
|
|
2347
|
+
info(f'Generated JSON with {len(auto_json["operations"])} operations: {generated_json_path}')
|
|
2348
|
+
|
|
2349
|
+
info(Color.GREEN(f'\nAuto-generated JSON saved to: {generated_json_path}'))
|
|
2350
|
+
info(
|
|
2351
|
+
f'Please re-run the conversion with: -prf {generated_json_path}\n' +
|
|
2352
|
+
f'The JSON was optimized to achieve minimal error.'
|
|
2353
|
+
)
|
|
2354
|
+
else:
|
|
2355
|
+
warn('No viable parameter replacements found.')
|
|
2356
|
+
|
|
2357
|
+
except Exception as ex:
|
|
2358
|
+
warn(
|
|
2359
|
+
f'Auto JSON generation failed: {ex}'
|
|
2360
|
+
)
|
|
2361
|
+
import traceback
|
|
2362
|
+
warn(traceback.format_exc(), prefix=False)
|
|
2363
|
+
|
|
1997
2364
|
return model
|
|
1998
2365
|
|
|
1999
2366
|
|
|
@@ -2592,6 +2959,17 @@ def main():
|
|
|
2592
2959
|
'The absolute tolerance parameter \n' +
|
|
2593
2960
|
'Default: 1e-4'
|
|
2594
2961
|
)
|
|
2962
|
+
parser.add_argument(
|
|
2963
|
+
'-agj',
|
|
2964
|
+
'--auto_generate_json',
|
|
2965
|
+
action='store_true',
|
|
2966
|
+
help=\
|
|
2967
|
+
'Automatically generates a parameter replacement JSON file that achieves minimal error ' +
|
|
2968
|
+
'when converting the model. This option explores various parameter combinations to find ' +
|
|
2969
|
+
'the best settings that result in successful conversion and highest accuracy. ' +
|
|
2970
|
+
'The search stops when the final output OP accuracy check shows "Matches". ' +
|
|
2971
|
+
'Cannot be used together with -cotof. When -cotof is specified, JSON auto-generation is disabled.'
|
|
2972
|
+
)
|
|
2595
2973
|
parser.add_argument(
|
|
2596
2974
|
'-dms',
|
|
2597
2975
|
'--disable_model_save',
|
|
@@ -2698,6 +3076,7 @@ def main():
|
|
|
2698
3076
|
fused_argmax_scale_ratio=args.fused_argmax_scale_ratio,
|
|
2699
3077
|
replace_to_pseudo_operators=args.replace_to_pseudo_operators,
|
|
2700
3078
|
param_replacement_file=args.param_replacement_file,
|
|
3079
|
+
auto_generate_json=args.auto_generate_json,
|
|
2701
3080
|
check_gpu_delegate_compatibility=args.check_gpu_delegate_compatibility,
|
|
2702
3081
|
check_onnx_tf_outputs_elementwise_close=args.check_onnx_tf_outputs_elementwise_close,
|
|
2703
3082
|
check_onnx_tf_outputs_elementwise_close_full=args.check_onnx_tf_outputs_elementwise_close_full,
|
|
@@ -97,7 +97,7 @@ def replace_parameter(
|
|
|
97
97
|
bool(replace_value) if isinstance(replace_value, int) and replace_value in [0, 1] else \
|
|
98
98
|
bool(int(replace_value)) if isinstance(replace_value, str) and replace_value in ["0", "1"] else \
|
|
99
99
|
False if isinstance(replace_value, str) and replace_value.lower() == "false" else \
|
|
100
|
-
True if isinstance(replace_value, str) and replace_value.lower() == "
|
|
100
|
+
True if isinstance(replace_value, str) and replace_value.lower() == "true" else \
|
|
101
101
|
replace_value
|
|
102
102
|
elif isinstance(value_before_replacement, int):
|
|
103
103
|
replace_value = int(replace_value)
|
|
@@ -377,7 +377,8 @@ def print_node_info(func):
|
|
|
377
377
|
f'Also, for models that include NonMaxSuppression in the post-processing, ' +
|
|
378
378
|
f'try the -onwdt option.'
|
|
379
379
|
)
|
|
380
|
-
sys.exit
|
|
380
|
+
# Re-raise the exception instead of sys.exit to allow auto JSON generation
|
|
381
|
+
raise
|
|
381
382
|
return print_wrapper_func
|
|
382
383
|
|
|
383
384
|
|