onnx2tf 1.27.10__py3-none-any.whl → 1.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,258 @@
1
+ """
2
+ Iterative JSON optimizer for automatic parameter replacement.
3
+ This module implements an iterative optimization algorithm that repeatedly
4
+ tests different parameter modifications and evaluates their impact.
5
+ """
6
+ import os
7
+ import json
8
+ import tempfile
9
+ import subprocess
10
+ import sys
11
+ from typing import Dict, List, Any, Tuple, Optional
12
+ import numpy as np
13
+ from onnx2tf.utils.logging import *
14
+ import onnx_graphsurgeon as gs
15
+
16
+
17
class IterativeJSONOptimizer:
    """
    Iteratively optimize JSON replacements by testing conversions.

    Every candidate parameter-replacement JSON is written to a temporary file
    and handed to a fresh ``python -m onnx2tf ... -cotof`` subprocess. The
    largest "Max Absolute Error:" value printed by that run is the candidate's
    score (lower is better); only runs that exit with return code 0 can
    become the best configuration.
    """

    # Upper bound on the variations tried per iteration (each one is a full
    # conversion run).
    MAX_VARIATIONS_PER_ITERATION = 3
    # Timeout, in seconds, for a single onnx2tf conversion subprocess.
    CONVERSION_TIMEOUT_SEC = 300

    def __init__(self, model_path: str, output_dir: str):
        """
        Args:
            model_path: ONNX model path, passed to onnx2tf via ``-i``.
            output_dir: Output directory, passed to onnx2tf via ``-o``; every
                test conversion targets this same directory.
        """
        self.model_path = model_path
        self.output_dir = output_dir
        # Best configuration found so far and the error it achieved.
        self.best_json = None
        self.best_error = float('inf')
        # One dict per tested configuration (used for the final summary).
        self.tested_combinations = []

    @staticmethod
    def _parse_max_error(output: str) -> float:
        """
        Return the largest value printed after "Max Absolute Error:" in output.

        Lines whose value is missing or not a float are ignored. A NaN value is
        scored as ``inf`` (otherwise ``max`` would silently drop it and a
        NaN-producing conversion would look perfect). Returns 0.0 when no such
        line is present.
        """
        marker = 'Max Absolute Error:'
        max_error = 0.0
        for line in output.splitlines():
            if marker not in line:
                continue
            tokens = line.split(marker)[1].split()
            if not tokens:
                continue
            try:
                value = float(tokens[0])
            except ValueError:
                continue
            if np.isnan(value):
                value = float('inf')
            max_error = max(max_error, value)
        return max_error

    def test_conversion(self, json_path: Optional[str] = None) -> Tuple[bool, float, str]:
        """
        Test conversion with a given JSON file.

        Args:
            json_path: Optional parameter-replacement JSON passed via ``-prf``.
                When falsy, the model is converted without replacements.

        Returns:
            (success, max_error, output_message). ``success`` is True when the
            subprocess exits with return code 0. ``max_error`` is parsed from the
            combined stdout/stderr even when the run failed (so it may be 0.0
            for a failed run). On timeout or when the subprocess cannot be
            launched, returns (False, inf, <reason>).
        """
        cmd = [
            sys.executable,
            "-m", "onnx2tf",
            "-i", self.model_path,
            "-o", self.output_dir,
            "-cotof",
            "-n"  # non-verbose
        ]
        if json_path:
            cmd.extend(["-prf", json_path])

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.CONVERSION_TIMEOUT_SEC,
            )
        except subprocess.TimeoutExpired:
            return False, float('inf'), "Conversion timed out"
        except Exception as e:
            return False, float('inf'), str(e)

        output = result.stdout + result.stderr
        return result.returncode == 0, self._parse_max_error(output), output

    def _evaluate_json(self, json_data: Dict[str, Any]) -> Tuple[bool, float, str]:
        """Write json_data to a temp file, run test_conversion on it, then delete the file."""
        fd, json_path = tempfile.mkstemp(suffix='.json')
        try:
            # The file is closed before the subprocess reads it.
            with os.fdopen(fd, 'w') as f:
                json.dump(json_data, f, indent=2)
            return self.test_conversion(json_path)
        finally:
            # Runs even if json.dump fails, so the temp file never leaks.
            if os.path.exists(json_path):
                os.unlink(json_path)

    def _record_attempt(
        self,
        iteration: int,
        json_data: Dict[str, Any],
        success: bool,
        error: float,
        variation: Optional[int] = None,
    ) -> None:
        """Append one tested configuration to self.tested_combinations."""
        record = {
            'iteration': iteration,
            'operations': len(json_data.get('operations', [])),
            'success': success,
            'error': error,
        }
        if variation is not None:
            # 1-based index of the variation tested within this iteration.
            record['variation'] = variation
        self.tested_combinations.append(record)

    def _update_best(self, json_data: Dict[str, Any], success: bool, error: float) -> bool:
        """Adopt json_data as best if it converted successfully with a lower error; return whether it did."""
        if not success or error >= self.best_error:
            return False
        self.best_error = error
        self.best_json = json_data.copy()
        return True

    def optimize_iteratively(
        self,
        initial_json: Dict[str, Any],
        max_iterations: int = 10,
        target_error: float = 1e-2
    ) -> Dict[str, Any]:
        """
        Iteratively optimize the JSON by testing different combinations.

        Args:
            initial_json: Starting replacement JSON (an ``operations`` list is
                read via ``.get('operations', [])``).
            max_iterations: Maximum number of optimization iterations.
            target_error: Stop once a successful conversion reports a max
                absolute error at or below this value.

        Returns:
            ``{}`` when the baseline (no JSON) already meets the target, or when
            it converted successfully and no tested candidate beat it. Otherwise
            returns the best-scoring successful candidate, or ``initial_json``
            when nothing converted successfully at all.
        """
        info(f"Starting iterative optimization for {self.model_path}")
        info(f"Target error: {target_error}")

        # Test baseline (no JSON)
        info("\n=== Testing baseline (no JSON) ===")
        base_success, base_error, _ = self.test_conversion()
        info(f"Baseline: success={base_success}, error={base_error:.6f}")

        if base_success and base_error <= target_error:
            info("Model already meets target accuracy!")
            return {}

        if base_success:
            # The baseline ({} = no replacements) is a measured, working
            # configuration; only something that measurably beats it may
            # replace it.
            self.best_json = {}
            self.best_error = base_error
        else:
            # A failed baseline's error is meaningless (it may even be 0.0), so
            # any successful attempt must be able to become the new best.
            self.best_json = initial_json
            self.best_error = float('inf')

        current_json = initial_json
        for iteration in range(1, max_iterations + 1):
            info(f"\n=== Iteration {iteration}/{max_iterations} ===")

            success, error, output = self._evaluate_json(current_json)
            info(f"Test result: success={success}, error={error:.6f}")
            self._record_attempt(iteration, current_json, success, error)

            if self._update_best(current_json, success, error):
                info(f"✓ New best! Error reduced to {error:.6f}")

            # Only a successful conversion can satisfy the target.
            if success and error <= target_error:
                info("✓ Target accuracy achieved!")
                break

            # Failed runs are scored as inf so any working variation beats them.
            current_error = error if success else float('inf')

            # Generate variations based on the error type seen in the output.
            error_type = self._analyze_error_output(output)
            variations = self._generate_variations(
                current_json, error_type
            )[:self.MAX_VARIATIONS_PER_ITERATION]

            best_variation = None
            best_variation_error = float('inf')
            for index, variation in enumerate(variations, start=1):
                info(f" Testing variation {index}/{len(variations)}")
                var_success, var_error, _ = self._evaluate_json(variation)
                self._record_attempt(
                    iteration, variation, var_success, var_error, variation=index
                )
                if var_success and var_error < best_variation_error:
                    best_variation = variation
                    best_variation_error = var_error
                    info(f" Variation error: {var_error:.6f}")

            # Keep a variation's improvement even if this is the last iteration.
            if best_variation is not None and self._update_best(
                best_variation, True, best_variation_error
            ):
                info(f"✓ New best! Error reduced to {best_variation_error:.6f}")
                if best_variation_error <= target_error:
                    info("✓ Target accuracy achieved!")
                    break

            # Use the best variation for the next iteration if it beats the
            # current configuration; otherwise try a different strategy.
            if best_variation is not None and best_variation_error < current_error:
                current_json = best_variation
                info(f" Using variation with error {best_variation_error:.6f}")
            else:
                current_json = self._modify_strategy(current_json, iteration)

        # Log summary
        info("\n=== Optimization Summary ===")
        info(f"Tested {len(self.tested_combinations)} configurations")
        info(f"Best error achieved: {self.best_error:.6f}")

        return self.best_json or {}

    def _analyze_error_output(self, output: str) -> str:
        """
        Classify conversion output into a coarse error type by keyword search.

        Returns one of 'multiply', 'concat', 'transpose', 'dimension' or
        'unknown'; the first matching category in that order wins.
        NOTE(review): plain substring matching is loose ('mul' matches words
        such as "cumulative"); kept as-is to preserve behavior.
        """
        lowered = output.lower()  # lower-case once instead of per keyword
        keyword_table = (
            (('multiply', 'mul'), 'multiply'),
            (('concat',), 'concat'),
            (('transpose',), 'transpose'),
            (('dimension', 'shape'), 'dimension'),
        )
        for keywords, error_type in keyword_table:
            if any(keyword in lowered for keyword in keywords):
                return error_type
        return 'unknown'

    def _generate_variations(self, json_data: Dict[str, Any], error_type: str) -> List[Dict[str, Any]]:
        """
        Generate candidate JSON variations based on error type.

        'multiply': for each candidate permutation, overwrite the
        ``pre_process_transpose_perm`` of every op whose ``op_name`` contains
        'Mul'. A perm is only applied to ops whose existing perm has the same
        length, because a transpose perm must match the tensor rank. Variations
        that would leave every op unchanged are skipped.
        'dimension': when there are more than 10 operations, also try the first
        half of them.
        The input dict itself is always appended last as a fallback; it is never
        mutated.
        """
        variations = []
        operations = json_data.get('operations', [])

        if error_type == 'multiply':
            # Try different transpose permutations for Mul operations
            candidate_perms = [
                [0, 1, 2, 3, 4, 5],  # Identity
                [0, 4, 2, 3, 1, 5],  # Swap dims 1 and 4
                [0, 2, 1, 3, 4, 5],  # Swap dims 1 and 2
                [0, 1, 4, 3, 2, 5],  # Swap dims 2 and 4
            ]

            for perm in candidate_perms:
                new_operations = []
                changed = False
                for op in operations:
                    new_op = op.copy()
                    current_perm = op.get('pre_process_transpose_perm')
                    if (
                        'Mul' in op.get('op_name', '')
                        and isinstance(current_perm, (list, tuple))
                        and len(current_perm) == len(perm)
                    ):
                        # Fresh list so variations never share perm objects.
                        new_op['pre_process_transpose_perm'] = list(perm)
                        changed = changed or list(current_perm) != perm
                    new_operations.append(new_op)

                # An unchanged copy would just re-run an already-tested config.
                if changed:
                    variation = json_data.copy()
                    variation['operations'] = new_operations
                    variations.append(variation)

        elif error_type == 'dimension':
            # Try removing some operations to see if simpler works better
            if len(operations) > 10:
                variation = json_data.copy()
                variation['operations'] = operations[:len(operations) // 2]
                variations.append(variation)

        # Always include original as fallback
        variations.append(json_data)

        return variations

    def _modify_strategy(self, json_data: Dict[str, Any], iteration: int) -> Dict[str, Any]:
        """
        Modify the strategy when no variation improved on the current JSON.

        Even iterations with more than 5 operations keep the first
        ``max(5, len // 2)`` operations. Otherwise, iterations divisible by 3
        keep only ops whose ``op_name`` contains 'Transpose', provided at least
        one such op exists; an empty list would merely repeat the measured
        baseline forever. In all other cases json_data is returned unchanged.
        """
        operations = json_data.get('operations', [])

        # Strategy 1: Reduce number of operations
        if iteration % 2 == 0 and len(operations) > 5:
            new_json = json_data.copy()
            new_json['operations'] = operations[:max(5, len(operations) // 2)]
            return new_json

        # Strategy 2: Focus on Transpose operations only
        if iteration % 3 == 0:
            transpose_ops = [op for op in operations if 'Transpose' in op.get('op_name', '')]
            if transpose_ops:
                new_json = json_data.copy()
                new_json['operations'] = transpose_ops
                return new_json

        # Default: return as-is
        return json_data