onnx2tf 1.27.10__py3-none-any.whl → 1.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +398 -19
- onnx2tf/utils/common_functions.py +2 -1
- onnx2tf/utils/iterative_json_optimizer.py +258 -0
- onnx2tf/utils/json_auto_generator.py +1505 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.0.dist-info}/METADATA +40 -4
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.0.dist-info}/RECORD +12 -10
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.0.dist-info}/WHEEL +1 -1
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.0.dist-info}/entry_points.txt +0 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.0.dist-info}/licenses/LICENSE +0 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.0.dist-info}/licenses/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.27.10.dist-info → onnx2tf-1.28.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Iterative JSON optimizer for automatic parameter replacement.
|
|
3
|
+
This module implements an iterative optimization algorithm that repeatedly
|
|
4
|
+
tests different parameter modifications and evaluates their impact.
|
|
5
|
+
"""
|
|
6
|
+
import os
|
|
7
|
+
import json
|
|
8
|
+
import tempfile
|
|
9
|
+
import subprocess
|
|
10
|
+
import sys
|
|
11
|
+
from typing import Dict, List, Any, Tuple, Optional
|
|
12
|
+
import numpy as np
|
|
13
|
+
from onnx2tf.utils.logging import *
|
|
14
|
+
import onnx_graphsurgeon as gs
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class IterativeJSONOptimizer:
    """
    Iteratively optimize JSON replacements by testing conversions.

    Each candidate JSON is written to a temporary file and handed to a
    child ``python -m onnx2tf`` process; the combined stdout/stderr of
    that process is scanned for ``Max Absolute Error:`` lines to score
    the candidate.

    Attributes:
        model_path: Path passed to the converter via ``-i``.
        output_dir: Directory passed to the converter via ``-o``.
        best_json: Best-scoring JSON seen so far (``None`` before a run).
        best_error: Error associated with ``best_json`` (``inf`` initially).
        tested_combinations: One summary dict per main-loop iteration.
    """

    # Seconds a single child conversion may run before it is abandoned.
    CONVERSION_TIMEOUT_SEC = 300
    # Upper bound on the number of variations tried per iteration.
    MAX_VARIATIONS_PER_ITERATION = 3
    # Text that precedes an error value in the converter output.
    _ERROR_MARKER = 'Max Absolute Error:'
    # (error type, substrings) pairs, checked in order against the
    # lower-cased converter output; the first match wins.
    # NOTE(review): 'mul' is a very broad substring (it also matches e.g.
    # "multiple" or "simulate"), so 'multiply' is selected for most
    # outputs. Kept as-is to preserve the existing heuristic.
    _ERROR_KEYWORDS = (
        ('multiply', ('multiply', 'mul')),
        ('concat', ('concat',)),
        ('transpose', ('transpose',)),
        ('dimension', ('dimension', 'shape')),
    )

    def __init__(self, model_path: str, output_dir: str):
        self.model_path = model_path
        self.output_dir = output_dir
        self.best_json = None
        self.best_error = float('inf')
        self.tested_combinations = []

    @staticmethod
    def _parse_max_error(output: str) -> float:
        """Return the largest value reported on ``Max Absolute Error:`` lines.

        Lines whose value is missing or not a number are skipped. Returns
        0.0 when no parsable line is present.
        """
        marker = IterativeJSONOptimizer._ERROR_MARKER
        max_error = 0.0
        for line in output.split('\n'):
            if marker not in line:
                continue
            try:
                error_val = float(line.split(marker)[1].strip().split()[0])
            except (IndexError, ValueError):
                # Nothing after the marker, or the token is not a float.
                continue
            max_error = max(max_error, error_val)
        return max_error

    @staticmethod
    def _effective_error(success: bool, error: float) -> float:
        """Return ``error`` for a successful run and ``inf`` for a failed one.

        A failed conversion that printed no error line is reported with an
        error of 0.0, which must never rank as better than a real result.
        """
        return error if success else float('inf')

    @staticmethod
    def _remove_file(path: str) -> None:
        """Delete ``path``; a file that is already gone is not an error."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass

    @staticmethod
    def _write_temp_json(data: Dict[str, Any]) -> str:
        """Write ``data`` to a new temporary ``.json`` file and return its path.

        The caller owns the returned file and must delete it. If
        serialization fails, the partially written file is removed before
        the exception propagates.
        """
        handle = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
        try:
            with handle:
                json.dump(data, handle, indent=2)
        except Exception:
            IterativeJSONOptimizer._remove_file(handle.name)
            raise
        return handle.name

    def test_conversion(self, json_path: Optional[str] = None) -> Tuple[bool, float, str]:
        """
        Test conversion with a given JSON file
        Returns: (success, max_error, output_message)

        ``success`` is True when the child process exits with code 0.
        ``max_error`` is the largest parsed ``Max Absolute Error:`` value
        (0.0 if none was printed), or ``inf`` when the process timed out
        or could not be run. ``output_message`` is stdout followed by
        stderr, or a short failure description.
        """
        # Build command
        cmd = [
            sys.executable,
            "-m", "onnx2tf",
            "-i", self.model_path,
            "-o", self.output_dir,
            "-cotof",
            "-n"  # non-verbose
        ]
        if json_path:
            cmd.extend(["-prf", json_path])

        # Run conversion
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.CONVERSION_TIMEOUT_SEC
            )
        except subprocess.TimeoutExpired:
            return False, float('inf'), "Conversion timed out"
        except Exception as e:
            # Deliberate best-effort boundary: any failure to launch the
            # child is reported to the caller instead of being raised.
            return False, float('inf'), str(e)

        # Parse output for accuracy errors
        output = result.stdout + result.stderr
        return result.returncode == 0, self._parse_max_error(output), output

    def optimize_iteratively(
        self,
        initial_json: Dict[str, Any],
        max_iterations: int = 10,
        target_error: float = 1e-2
    ) -> Dict[str, Any]:
        """
        Iteratively optimize the JSON by testing different combinations

        Args:
            initial_json: Starting replacement JSON.
            max_iterations: Maximum number of main-loop iterations.
            target_error: Error at or below which the search stops.

        Returns:
            ``{}`` when the model already meets ``target_error`` without
            any JSON; otherwise the best JSON recorded during the search
            (``initial_json`` if nothing improved on the baseline).
        """
        info(f"Starting iterative optimization for {self.model_path}")
        info(f"Target error: {target_error}")

        # Test baseline (no JSON)
        info("\n=== Testing baseline (no JSON) ===")
        base_success, base_error, _ = self.test_conversion()
        info(f"Baseline: success={base_success}, error={base_error:.6f}")

        if base_success and base_error <= target_error:
            info("Model already meets target accuracy!")
            return {}

        # Initialize with initial JSON. A failed baseline counts as an
        # infinite error so that any later successful run can beat it.
        # NOTE(review): best_error is seeded from the run *without* JSON
        # while best_json is seeded with initial_json - confirm intended.
        current_json = initial_json
        self.best_json = initial_json
        self.best_error = self._effective_error(base_success, base_error)

        # Iterative optimization
        for iteration in range(1, max_iterations + 1):
            info(f"\n=== Iteration {iteration}/{max_iterations} ===")

            # Save current JSON to temp file
            temp_json_path = self._write_temp_json(current_json)
            try:
                # Test current JSON
                success, error, output = self.test_conversion(temp_json_path)
                info(f"Test result: success={success}, error={error:.6f}")

                # Track this attempt (raw values as reported by the run)
                self.tested_combinations.append({
                    'iteration': iteration,
                    'operations': len(current_json.get('operations', [])),
                    'success': success,
                    'error': error
                })

                # From here on a failed run compares as infinitely bad.
                error = self._effective_error(success, error)

                # Update best if improved
                if success and error < self.best_error:
                    self.best_error = error
                    self.best_json = current_json.copy()
                    info(f"✓ New best! Error reduced to {error:.6f}")

                # Check if target reached (only a successful run counts)
                if success and error <= target_error:
                    info("✓ Target accuracy achieved!")
                    break

                # Generate variations for next iteration
                if not success or error > target_error:
                    # Analyze the error from output
                    error_type = self._analyze_error_output(output)

                    # Generate new variations based on error type
                    variations = self._generate_variations(current_json, error_type)
                    candidates = variations[:self.MAX_VARIATIONS_PER_ITERATION]

                    # Test variations
                    best_variation = None
                    best_variation_error = float('inf')

                    for i, variation in enumerate(candidates):
                        info(f" Testing variation {i+1}/{len(candidates)}")

                        var_json_path = self._write_temp_json(variation)
                        try:
                            var_success, var_error, _ = self.test_conversion(var_json_path)
                            if var_success and var_error < best_variation_error:
                                best_variation = variation
                                best_variation_error = var_error
                            info(f" Variation error: {var_error:.6f}")
                        finally:
                            self._remove_file(var_json_path)

                    # Use best variation for next iteration
                    if best_variation is not None and best_variation_error < error:
                        current_json = best_variation
                        info(f" Using variation with error {best_variation_error:.6f}")
                    else:
                        # No improvement, try different strategy
                        current_json = self._modify_strategy(current_json, iteration)
            finally:
                # Clean up temp file (also runs when the loop breaks)
                self._remove_file(temp_json_path)

        # Log summary
        info("\n=== Optimization Summary ===")
        info(f"Tested {len(self.tested_combinations)} configurations")
        info(f"Best error achieved: {self.best_error:.6f}")

        return self.best_json or {}

    def _analyze_error_output(self, output: str) -> str:
        """Analyze error output to determine error type

        Returns the first entry of ``_ERROR_KEYWORDS`` with a substring
        present in the lower-cased ``output``, or ``'unknown'``.
        """
        lowered = output.lower()
        for error_type, keywords in self._ERROR_KEYWORDS:
            if any(keyword in lowered for keyword in keywords):
                return error_type
        return 'unknown'

    def _generate_variations(self, json_data: Dict[str, Any], error_type: str) -> List[Dict[str, Any]]:
        """Generate variations based on error type

        The returned list always ends with ``json_data`` itself as a
        fallback. ``json_data`` and its operation dicts are not modified.
        """
        variations = []
        operations = json_data.get('operations', [])

        if error_type == 'multiply':
            # Try different transpose permutations for Mul operations
            # NOTE(review): these permutations are hard-coded for six axes;
            # presumably only meaningful for ops whose existing perm also
            # has six entries - confirm against the JSON generator.
            new_perms = [
                [0, 1, 2, 3, 4, 5],  # Identity
                [0, 4, 2, 3, 1, 5],  # Swap dims 1 and 4
                [0, 2, 1, 3, 4, 5],  # Swap dims 1 and 2
                [0, 1, 4, 3, 2, 5],  # Swap dims 2 and 4
            ]

            for perm in new_perms:
                variation = json_data.copy()
                new_ops = []

                # Modify existing Mul operations
                for op in operations:
                    new_op = op.copy()
                    if 'Mul' in op.get('op_name', '') and 'pre_process_transpose_perm' in op:
                        # Own copy per op so variations never share a list.
                        new_op['pre_process_transpose_perm'] = list(perm)
                    new_ops.append(new_op)

                variation['operations'] = new_ops
                variations.append(variation)

        elif error_type == 'dimension':
            # Try removing some operations to see if simpler works better
            if len(operations) > 10:
                # Try with half the operations
                variation = json_data.copy()
                variation['operations'] = operations[:len(operations)//2]
                variations.append(variation)

        # Always include original as fallback
        variations.append(json_data)

        return variations

    def _modify_strategy(self, json_data: Dict[str, Any], iteration: int) -> Dict[str, Any]:
        """Modify strategy when variations don't improve

        Returns a shallow copy with a reduced ``operations`` list when one
        of the strategies applies, otherwise ``json_data`` itself.
        """
        operations = json_data.get('operations', [])

        # Strategy 1: Reduce number of operations (keep at least five)
        if iteration % 2 == 0 and len(operations) > 5:
            new_json = json_data.copy()
            new_json['operations'] = operations[:max(5, len(operations)//2)]
            return new_json

        # Strategy 2: Focus on different operation types
        if iteration % 3 == 0:
            new_json = json_data.copy()
            # Keep only Transpose operations
            new_json['operations'] = [op for op in operations if 'Transpose' in op.get('op_name', '')]
            return new_json

        # Default: return as-is
        return json_data
|