onnx2tf 1.29.1__py3-none-any.whl → 1.29.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +7 -2
- onnx2tf/ops/AveragePool.py +91 -31
- onnx2tf/ops/Conv.py +27 -0
- onnx2tf/utils/json_auto_generator.py +190 -190
- {onnx2tf-1.29.1.dist-info → onnx2tf-1.29.3.dist-info}/METADATA +8 -8
- {onnx2tf-1.29.1.dist-info → onnx2tf-1.29.3.dist-info}/RECORD +11 -11
- {onnx2tf-1.29.1.dist-info → onnx2tf-1.29.3.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.1.dist-info → onnx2tf-1.29.3.dist-info}/licenses/LICENSE +0 -0
- {onnx2tf-1.29.1.dist-info → onnx2tf-1.29.3.dist-info}/licenses/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.29.1.dist-info → onnx2tf-1.29.3.dist-info}/top_level.txt +0 -0
--- a/onnx2tf/utils/json_auto_generator.py
+++ b/onnx2tf/utils/json_auto_generator.py
@@ -20,11 +20,11 @@ import shutil
 
 class OperationFixer:
     """Base class for operation-specific fixers"""
-
+
     def __init__(self, node: gs.Node, error_info: Dict[str, Any]):
         self.node = node
         self.error_info = error_info
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         """Generate possible fixes for this operation"""
         raise NotImplementedError
@@ -32,15 +32,15 @@ class OperationFixer:
 
 class TransposeFixer(OperationFixer):
     """Fixer for Transpose operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         perm = self.node.attrs.get("perm", [])
         if not perm:
             return fixes
-
+
         ndim = len(perm)
-
+
         # Generate all permutations based on dimension
         if ndim <= 3:
             # For 3D or less, try all permutations (max 6)
@@ -73,11 +73,11 @@ class TransposeFixer(OperationFixer):
                 candidates.append(list(p))
                 if len(candidates) >= 50:
                     break
-
+
         # Add current perm if not in candidates
         if perm not in candidates:
             candidates.insert(1, perm)
-
+
         # Generate fixes
         for candidate in candidates:
             if candidate != perm:
@@ -88,17 +88,17 @@ class TransposeFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.8 if candidate == list(range(ndim)) else 0.5
                 })
-
+
         return fixes
 
 
 class ConcatFixer(OperationFixer):
     """Fixer for Concat operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", 1)
-
+
         # Common axis adjustments for NCHW to NHWC conversion
         axis_mappings = {
             1: [3, -1], # Channel dimension NCHW -> NHWC
@@ -106,9 +106,9 @@ class ConcatFixer(OperationFixer):
             3: [2], # Width dimension
             -1: [3, 1], # Last dimension
         }
-
+
         candidates = axis_mappings.get(axis, [])
-
+
         # Always try the complement dimension
         ndim = 4 # Assume 4D by default, will be refined later
         if hasattr(self.node, 'inputs') and self.node.inputs:
@@ -117,7 +117,7 @@ class ConcatFixer(OperationFixer):
                 if hasattr(inp, 'shape') and inp.shape:
                     ndim = len(inp.shape)
                     break
-
+
         # Add dimension-specific candidates
         if ndim == 4:
             candidates.extend([3, 1, 2, -1])
@@ -125,11 +125,11 @@ class ConcatFixer(OperationFixer):
             candidates.extend([2, 1, -1])
         elif ndim == 5:
             candidates.extend([4, 1, 2, 3, -1])
-
+
         # Remove duplicates and current axis
         candidates = list(dict.fromkeys(candidates))
         candidates = [c for c in candidates if c != axis]
-
+
         for candidate in candidates:
             fixes.append({
                 "op_name": self.node.name,
@@ -138,13 +138,13 @@ class ConcatFixer(OperationFixer):
                 "values": candidate,
                 "confidence": 0.7
             })
-
+
         return fixes
 
 
 class SplitFixer(OperationFixer):
     """Fixer for Split operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         # Similar to ConcatFixer but for Split
         return ConcatFixer(self.node, self.error_info).generate_fixes()
@@ -152,10 +152,10 @@ class SplitFixer(OperationFixer):
 
 class ReshapeFixer(OperationFixer):
     """Fixer for Reshape operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Generate all permutations for different dimensions
         # For performance reasons, limit permutations for higher dimensions
         def get_all_permutations(ndim: int) -> List[List[int]]:
@@ -186,7 +186,7 @@ class ReshapeFixer(OperationFixer):
                     if len(perms) >= 30: # Limit to 30 permutations
                         break
             return perms
-
+
         # Get input shape for pre-process transpose
         input_ndim = 4 # Default
         if hasattr(self.node, 'inputs') and self.node.inputs:
@@ -194,7 +194,7 @@ class ReshapeFixer(OperationFixer):
             if hasattr(input_tensor, 'shape') and input_tensor.shape:
                 input_ndim = len(input_tensor.shape)
                 info(f"ReshapeFixer: Input tensor shape: {input_tensor.shape} (ndim={input_ndim})")
-
+
         # Get output shape for post-process transpose
         output_ndim = 4 # Default
         if hasattr(self.node, 'outputs') and self.node.outputs:
@@ -202,7 +202,7 @@ class ReshapeFixer(OperationFixer):
             if hasattr(output_tensor, 'shape') and output_tensor.shape:
                 output_ndim = len(output_tensor.shape)
                 info(f"ReshapeFixer: Output tensor shape: {output_tensor.shape} (ndim={output_ndim})")
-
+
         # Generate pre-process transpose based on input dimensions
         input_perms = get_all_permutations(input_ndim)
         for perm in input_perms:
@@ -213,7 +213,7 @@ class ReshapeFixer(OperationFixer):
                 "pre_process_transpose_perm": perm,
                 "confidence": 0.6
             })
-
+
         # Generate post-process transpose based on output dimensions
         output_perms = get_all_permutations(output_ndim)
         for perm in output_perms:
@@ -224,11 +224,11 @@ class ReshapeFixer(OperationFixer):
                 "post_process_transpose_perm": perm,
                 "confidence": 0.6
             })
-
+
         # Also try modifying the shape parameter directly
         if len(self.node.inputs) >= 2:
             shape_input = self.node.inputs[1]
-
+
             # Common shape modifications
             # For example, if reshaping from [N,C,H,W] to [N,C*H*W]
             # We might need to transpose to [N,H,W,C] first
@@ -239,20 +239,20 @@ class ReshapeFixer(OperationFixer):
                 "values": [-1, -1], # Let Reshape infer dimensions
                 "confidence": 0.4
             })
-
+
         return fixes
 
 
 class ResizeFixer(OperationFixer):
     """Fixer for Resize operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Try different coordinate transformation modes
         modes = ["asymmetric", "pytorch_half_pixel", "tf_half_pixel_for_nn", "align_corners"]
         current_mode = self.node.attrs.get("coordinate_transformation_mode", "half_pixel")
-
+
         for mode in modes:
             if mode != current_mode:
                 fixes.append({
@@ -262,11 +262,11 @@ class ResizeFixer(OperationFixer):
                     "values": mode,
                     "confidence": 0.5
                 })
-
+
         # Try different interpolation modes
         interp_modes = ["nearest", "linear", "cubic"]
         current_interp = self.node.attrs.get("mode", "nearest")
-
+
         for mode in interp_modes:
             if mode != current_interp:
                 fixes.append({
@@ -276,18 +276,18 @@ class ResizeFixer(OperationFixer):
                     "values": mode,
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class ReduceFixer(OperationFixer):
     """Fixer for Reduce operations (ReduceMax, ReduceMean, etc.)"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axes = self.node.attrs.get("axes", [])
         keepdims = self.node.attrs.get("keepdims", 1)
-
+
         # Try different axes combinations
         if axes:
             # Common axis mappings for dimension conversion
@@ -296,12 +296,12 @@ class ReduceFixer(OperationFixer):
                 2: [1, -2], # Height dimension
                 3: [2, -1], # Width dimension
             }
-
+
             new_axes = []
             for axis in axes:
                 if axis in axis_mappings:
                     new_axes.extend(axis_mappings[axis])
-
+
             if new_axes and new_axes != axes:
                 fixes.append({
                     "op_name": self.node.name,
@@ -310,7 +310,7 @@ class ReduceFixer(OperationFixer):
                     "values": list(dict.fromkeys(new_axes))[:len(axes)],
                     "confidence": 0.6
                 })
-
+
         # Try toggling keepdims
         fixes.append({
             "op_name": self.node.name,
@@ -319,20 +319,20 @@ class ReduceFixer(OperationFixer):
             "values": 1 - keepdims,
             "confidence": 0.4
         })
-
+
         return fixes
 
 
 class SoftmaxFixer(OperationFixer):
     """Fixer for Softmax operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", -1)
-
+
         # Common axis adjustments
         candidates = [-1, 1, 2, 3]
-
+
         for candidate in candidates:
             if candidate != axis:
                 fixes.append({
@@ -342,16 +342,16 @@ class SoftmaxFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.6
                 })
-
+
         return fixes
 
 
 class AddMulDivSubFixer(OperationFixer):
     """Fixer for Add, Mul, Div, Sub operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Generate all permutations for pre/post transpose
         def get_perms_for_ndim(ndim: int) -> List[List[int]]:
             if ndim <= 3:
@@ -386,9 +386,9 @@ class AddMulDivSubFixer(OperationFixer):
                     if len(perms) >= 30:
                         break
             return perms
-
+
         common_perms = {}
-
+
         # Try to determine the dimension
         ndim = 4 # Default
         if hasattr(self.node, 'inputs') and self.node.inputs:
@@ -396,7 +396,7 @@ class AddMulDivSubFixer(OperationFixer):
                 if hasattr(inp, 'shape') and inp.shape:
                     ndim = len(inp.shape)
                     break
-
+
         # Check if dimension mismatch is mentioned in error
         if 'error_msg' in self.error_info:
             error_msg = self.error_info['error_msg']
@@ -407,7 +407,7 @@ class AddMulDivSubFixer(OperationFixer):
                 # Parse shapes
                 shape1 = [int(x.strip()) for x in shapes[0].split(',')]
                 shape2 = [int(x.strip()) for x in shapes[1].split(',')]
-
+
                 # Generate transpose fixes for broadcasting issues
                 if len(shape1) == len(shape2) and len(shape1) == 6:
                     # For 6D tensor broadcasting issues
@@ -416,7 +416,7 @@ class AddMulDivSubFixer(OperationFixer):
                     for i in range(6):
                         if shape1[i] != shape2[i] and shape2[i] != 1 and shape1[i] != 1:
                             mismatches.append(i)
-
+
                     special_perms = []
                     # If we have exactly 2 mismatched dimensions, swap them
                     if len(mismatches) == 2:
@@ -424,14 +424,14 @@ class AddMulDivSubFixer(OperationFixer):
                         perm[mismatches[0]], perm[mismatches[1]] = perm[mismatches[1]], perm[mismatches[0]]
                         special_perms.append(perm)
                         info(f"ExpandFixer: Detected dimension mismatch at dims {mismatches}, suggesting permutation {perm}")
-
+
                     # Also add common patterns
                     special_perms.extend([
                         [0, 4, 2, 3, 1, 5], # Swap dims 1 and 4 (common for this error)
                         [0, 2, 1, 3, 4, 5], # Swap dims 1 and 2
                         [0, 1, 2, 3, 5, 4], # Swap last two
                     ])
-
+
                     for perm in special_perms:
                         for inp in self.node.inputs[:2]: # Both inputs
                             if hasattr(inp, 'name'):
@@ -442,7 +442,7 @@ class AddMulDivSubFixer(OperationFixer):
                                     "pre_process_transpose_perm": perm,
                                     "confidence": 0.7
                                 })
-
+
         # Generate permutations for the detected dimension
         perms = get_perms_for_ndim(ndim)
         if perms:
@@ -456,7 +456,7 @@ class AddMulDivSubFixer(OperationFixer):
                         "pre_process_transpose_perm": perm,
                         "confidence": 0.5
                     })
-
+
                 # Post-process transpose
                 if self.node.outputs:
                     fixes.append({
@@ -466,16 +466,16 @@ class AddMulDivSubFixer(OperationFixer):
                         "post_process_transpose_perm": perm,
                         "confidence": 0.5
                     })
-
+
         return fixes
 
 
 class CastFixer(OperationFixer):
     """Fixer for Cast operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Type mappings from README
         type_values = {
             "float32": 1,
@@ -491,12 +491,12 @@ class CastFixer(OperationFixer):
             "uint32": 12,
             "uint64": 13,
         }
-
+
         current_to = self.node.attrs.get("to", 1)
-
+
         # Try common type conversions
         common_types = [1, 6, 7] # float32, int32, int64
-
+
         for type_val in common_types:
             if type_val != current_to:
                 fixes.append({
@@ -506,20 +506,20 @@ class CastFixer(OperationFixer):
                     "values": type_val,
                     "confidence": 0.4
                 })
-
+
         return fixes
 
 
 class GatherFixer(OperationFixer):
     """Fixer for Gather operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", 0)
-
+
         # Try different axis values
         candidates = [0, 1, 2, 3, -1, -2]
-
+
         for candidate in candidates:
             if candidate != axis:
                 fixes.append({
@@ -529,20 +529,20 @@ class GatherFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class FlattenFixer(OperationFixer):
     """Fixer for Flatten operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", 1)
-
+
         # Try different axis values
         candidates = [0, 1, 2, -1]
-
+
         for candidate in candidates:
             if candidate != axis:
                 fixes.append({
@@ -552,7 +552,7 @@ class FlattenFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.6
                 })
-
+
         # Also try pre-process transpose
         if self.node.inputs and self.node.inputs[0]:
             input_tensor = self.node.inputs[0]
@@ -573,7 +573,7 @@ class FlattenFixer(OperationFixer):
             else:
                 # Default to 4D perms if shape unknown
                 perms = list(itertools.permutations(range(4)))
-
+
             for perm in perms:
                 fixes.append({
                     "op_name": self.node.name,
@@ -582,40 +582,40 @@ class FlattenFixer(OperationFixer):
                     "pre_process_transpose_perm": perm,
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class ExpandFixer(OperationFixer):
     """Fixer for Expand operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Check if dimension mismatch is in error
         if 'error_msg' in self.error_info:
             error_msg = self.error_info['error_msg']
             # Extract shape info from error message
             shape_pattern = r'\[([\d,\s]+)\]'
             shapes = re.findall(shape_pattern, error_msg)
-
+
             if len(shapes) >= 2:
                 # Parse shapes
                 shape1 = [int(x.strip()) for x in shapes[0].split(',')]
                 shape2 = [int(x.strip()) for x in shapes[1].split(',')]
-
+
                 # For custom_spo2 case: [1,2,1,256,32,1] vs [1,1,1,1,2,1]
                 # The issue is dimension 4: shape1[4]=32 but shape2[4]=2
                 # We need to find where in shape1 we have value 2 and move it to position 4
                 if len(shape1) == len(shape2):
                     ndim = len(shape1)
-
+
                     # Find positions where shape2 has non-1 values (broadcast targets)
                     target_positions = []
                     for i in range(ndim):
                         if shape2[i] != 1:
                             target_positions.append((i, shape2[i]))
-
+
                     # For each target position, find matching values in shape1
                     for target_pos, target_val in target_positions:
                         if shape1[target_pos] != target_val:
@@ -624,7 +624,7 @@ class ExpandFixer(OperationFixer):
                                 if shape1[source_pos] == target_val:
                                     # Create permutation that moves source_pos to target_pos
                                     perm = list(range(ndim))
-
+
                                     # Complex permutation to maintain other dimensions
                                     if source_pos != target_pos:
                                         # For [0,1,2,3,4,5] moving 1->4 becomes [0,4,2,3,1,5]
@@ -639,11 +639,11 @@ class ExpandFixer(OperationFixer):
                                             for j in range(source_pos, target_pos, -1):
                                                 perm[j] = perm[j - 1]
                                             perm[target_pos] = temp
-
+
                                     # Actually, for custom_spo2 we know the exact permutation
                                     if ndim == 6 and source_pos == 1 and target_pos == 4:
                                         perm = [0, 4, 2, 3, 1, 5]
-
+
                                     # High confidence fix
                                     if self.node.inputs:
                                         fixes.append({
@@ -655,12 +655,12 @@ class ExpandFixer(OperationFixer):
                                         })
                                         info(f"ExpandFixer: Generated critical permutation {perm} for {self.node.name}")
                                         break
-
+
         # Try modifying the shape input directly
         if len(self.node.inputs) >= 2:
             # Second input is usually the shape
             shape_input = self.node.inputs[1]
-
+
             # Try transposing the shape values
             if hasattr(shape_input, 'shape') and shape_input.shape:
                 # Common shape permutations for 6D - CRITICAL permutation first
@@ -671,7 +671,7 @@ class ExpandFixer(OperationFixer):
                     [0, 1, 4, 3, 2, 5], # Move dim 2 to 4
                     [0, 1, 2, 4, 3, 5], # Move dim 3 to 4
                 ]
-
+
                 for perm in shape_perms:
                     fixes.append({
                         "op_name": self.node.name,
@@ -680,7 +680,7 @@ class ExpandFixer(OperationFixer):
                        "values": perm, # This will modify the shape values
                        "confidence": 0.7
                    })
-
+
        # For Expand, limit permutations to avoid combinatorial explosion
        # Only generate a few strategic permutations
        ndim = 4 # Default
@@ -689,7 +689,7 @@ class ExpandFixer(OperationFixer):
                 if hasattr(inp, 'shape') and inp.shape:
                     ndim = len(inp.shape)
                     break
-
+
         if ndim == 6:
             # For 6D, only add the most critical permutations
             critical_perms = [
@@ -719,26 +719,26 @@ class ExpandFixer(OperationFixer):
                     "pre_process_transpose_perm": list(perm),
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class TileFixer(OperationFixer):
     """Fixer for Tile operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Similar to AddMulDivSubFixer - try pre/post transpose
         return AddMulDivSubFixer(self.node, self.error_info).generate_fixes()
 
 
 class MatMulFixer(OperationFixer):
     """Fixer for MatMul operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # MatMul often needs transpose adjustments
         def get_matmul_perms(ndim: int) -> List[List[int]]:
             if ndim == 2:
@@ -756,14 +756,14 @@ class MatMulFixer(OperationFixer):
                 perm[-1], perm[-1-i] = perm[-1-i], perm[-1]
                 perms.append(perm)
             return perms
-
+
         # Try pre-process transpose
         if self.node.inputs:
             for inp in self.node.inputs[:2]: # First two inputs
                 if hasattr(inp, 'shape') and inp.shape:
                     ndim = len(inp.shape)
                     perms = get_matmul_perms(ndim)
-
+
                     for perm in perms:
                         fixes.append({
                             "op_name": self.node.name,
@@ -772,7 +772,7 @@ class MatMulFixer(OperationFixer):
                             "pre_process_transpose_perm": perm,
                             "confidence": 0.6
                         })
-
+
         return fixes
 
 
@@ -806,11 +806,11 @@ def get_fixer_for_op(node: gs.Node, error_info: Dict[str, Any]) -> Optional[Oper
        "Tile": TileFixer,
        "MatMul": MatMulFixer,
    }
-
+
    fixer_class = fixers.get(node.op)
    if fixer_class:
        return fixer_class(node, error_info)
-
+
    return None
 
 
@@ -825,12 +825,12 @@ def analyze_conversion_error(
        "problematic_ops": [],
        "suggested_op_types": []
    }
-
+
    error_msg = str(error)
-
+
    # Debug: Show first 500 chars of error message
    debug(f"Error message preview: {error_msg[:500]}..." if len(error_msg) > 500 else f"Error message: {error_msg}")
-
+
    # Extract operation name from error message
    patterns = [
        r'onnx_op_name:\s*([^\s]+)',
@@ -841,7 +841,7 @@ def analyze_conversion_error(
        r'tf\.math\.(multiply|add|subtract|divide)_([\d]+)',
        r'wa/lightglue/posenc/Expand', # Specific pattern for custom_spo2
    ]
-
+
    for pattern in patterns:
        matches = re.findall(pattern, error_msg, re.IGNORECASE)
        if matches:
@@ -856,33 +856,33 @@ def analyze_conversion_error(
                    error_info["problematic_ops"].append(match)
            else:
                error_info["problematic_ops"].extend(matches)
-
+
    # Identify operation types that might need fixing
    if "concat" in error_msg.lower():
        error_info["suggested_op_types"].append("Concat")
        error_info["suggested_op_types"].append("Split")
-
+
    if "dimension" in error_msg.lower() or "shape" in error_msg.lower():
        error_info["suggested_op_types"].extend(["Transpose", "Reshape", "Resize"])
-
+
    if "transpose" in error_msg.lower():
        error_info["suggested_op_types"].append("Transpose")
-
+
    if "multiply" in error_msg.lower() or "mul" in error_msg.lower():
        error_info["suggested_op_types"].extend(["Mul", "Transpose", "Reshape"])
-
+
    if "add" in error_msg.lower():
        error_info["suggested_op_types"].extend(["Add", "Transpose", "Reshape"])
-
+
    if "div" in error_msg.lower():
        error_info["suggested_op_types"].extend(["Div", "Transpose", "Reshape"])
-
+
    if "sub" in error_msg.lower():
        error_info["suggested_op_types"].extend(["Sub", "Transpose", "Reshape"])
-
+
    if "expand" in error_msg.lower():
        error_info["suggested_op_types"].extend(["Expand", "Transpose", "Reshape"])
-
+
    # Check if the exception has onnx_op_name attribute
    if hasattr(error, 'onnx_op_name') and error.onnx_op_name:
        error_info["onnx_op_name"] = error.onnx_op_name
@@ -899,7 +899,7 @@ def analyze_conversion_error(
            if onnx_op_name not in error_info["problematic_ops"]:
                error_info["problematic_ops"].append(onnx_op_name)
                info(f"Extracted ONNX op name from error message: {onnx_op_name}")
-
+
    return error_info
 
 
@@ -917,14 +917,14 @@ def analyze_accuracy_errors(
        "max_error": 0.0,
        "error_distribution": {}
    }
-
+
    # Group errors by operation
    op_errors = {}
-
+
    for (onnx_output_name, tf_output_name), checked_value in check_results.items():
        matched_flg = checked_value[1]
        max_abs_err = checked_value[2]
-
+
        if (matched_flg == 0 or matched_flg == False) and isinstance(max_abs_err, (int, float, np.float32, np.float64)):
            if max_abs_err > error_threshold:
                # Find the operation that produces this output
@@ -934,19 +934,19 @@ def analyze_accuracy_errors(
                        op_errors[node.name] = []
                    op_errors[node.name].append(max_abs_err)
                    break
-
+
    # Analyze error distribution
    if op_errors:
        error_info["problematic_ops"] = list(op_errors.keys())
        error_info["max_error"] = max(max(errors) for errors in op_errors.values())
-
+
        # Suggest operation types based on error patterns
        for op_name in op_errors:
            node = next((n for n in onnx_graph.nodes if n.name == op_name), None)
            if node:
                if node.op not in error_info["suggested_op_types"]:
                    error_info["suggested_op_types"].append(node.op)
-
+
    return error_info
 
 
@@ -958,14 +958,14 @@ def generate_candidate_fixes(
    """Generate candidate fixes based on error analysis"""
    if previous_attempts is None:
        previous_attempts = set()
-
+
    candidate_fixes = []
-
+
    # Priority 1: Fix specific problematic operations
    for op_name in error_info.get("problematic_ops", []):
        # Try to find the node directly
        node = next((n for n in onnx_graph.nodes if n.name == op_name), None)
-
+
        # If not found and it's a TF operation name, try to find corresponding ONNX node
        if not node and 'tf.math' in op_name:
            # Extract operation type
@@ -979,7 +979,7 @@ def generate_candidate_fixes(
                op_type = 'Div'
            else:
                op_type = None
-
+
            # For TF operations, we can't directly map to ONNX nodes
            # Skip these for now - they will be handled by the ONNX op name
            pass
@@ -988,7 +988,7 @@ def generate_candidate_fixes(
            if fixer:
                fixes = fixer.generate_fixes()
                candidate_fixes.extend(fixes)
-
+
    # Priority 2: Fix operations of suggested types - LIMIT TO SPECIFIC NODE IF KNOWN
    if onnx_op_name := error_info.get("onnx_op_name"):
        # Only process the specific node mentioned in the error
@@ -1013,7 +1013,7 @@ def generate_candidate_fixes(
                    count += 1
                    if count >= 3: # Limit to first 3 nodes of each type
                        break
-
+
    # Priority 3: Generic fixes for common patterns
    if not candidate_fixes:
        # Look for all Transpose operations
@@ -1022,7 +1022,7 @@ def generate_candidate_fixes(
                fixer = TransposeFixer(node, error_info)
                fixes = fixer.generate_fixes()
                candidate_fixes.extend(fixes)
-
+
    # Priority 4: For concat errors, look more broadly
    if "concat" in str(error_info.get("error_msg", "")).lower():
        # Look for ALL Transpose, Split, and Concat operations that might need fixing
@@ -1031,12 +1031,12 @@ def generate_candidate_fixes(
            # Skip if already processed
            if any(fix["op_name"] == node.name for fix in candidate_fixes):
                continue
-
+
            fixer = get_fixer_for_op(node, error_info)
            if fixer:
                fixes = fixer.generate_fixes()
                candidate_fixes.extend(fixes)
-
+
    # Priority 5: Special handling for errors from specific ONNX operations
    # Use the extracted onnx_op_name if available
    onnx_op_name = error_info.get("onnx_op_name")
@@ -1052,7 +1052,7 @@ def generate_candidate_fixes(
                fix['confidence'] = 0.95
            candidate_fixes.extend(fixes)
            info(f"Found specific node from error: {onnx_op_name} (type: {specific_node.op})")
-
+
            # For Expand operations, also find related operations
            if specific_node.op == 'Expand':
                # Find all Expand operations with similar patterns
@@ -1065,7 +1065,7 @@ def generate_candidate_fixes(
                        fix['confidence'] = 0.9
                    candidate_fixes.extend(fixes)
                    info(f"Added fixes for all Expand operations due to error in {onnx_op_name}")
-
+
    # Filter out previously attempted fixes and validate fixes
    filtered_fixes = []
    for fix in candidate_fixes:
@@ -1073,7 +1073,7 @@ def generate_candidate_fixes(
        if fix_key not in previous_attempts:
            # Validate the fix
            is_valid = True
-
+
            # Check if permutation dimensions match tensor dimensions
            if "pre_process_transpose_perm" in fix or "post_process_transpose_perm" in fix:
                perm = fix.get("pre_process_transpose_perm") or fix.get("post_process_transpose_perm")
@@ -1091,7 +1091,7 @@ def generate_candidate_fixes(
                            info(f"Skipping invalid fix: {fix['op_name']} - perm len {len(perm)} != tensor dims {expected_dims}")
                            is_valid = False
                            break
-
+
                # For post_process, check output dimensions
                if "post_process_transpose_perm" in fix and node.outputs:
                    for out in node.outputs:
@@ -1102,14 +1102,14 @@ def generate_candidate_fixes(
                            info(f"Skipping invalid fix: {fix['op_name']} - perm len {len(perm)} != tensor dims {expected_dims}")
                            is_valid = False
                            break
-
+
            if is_valid:
                filtered_fixes.append(fix)
                previous_attempts.add(fix_key)
-
+
    # Sort by confidence
    filtered_fixes.sort(key=lambda x: x.get("confidence", 0.5), reverse=True)
-
+
    return filtered_fixes
 
 
@@ -1117,7 +1117,7 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
    """Generate combinations of fixes to try together"""
    if not fixes:
        return []
-
+
    # Group fixes by operation type and name
    op_groups = {}
    for fix in fixes:
@@ -1125,19 +1125,19 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
        if op_name not in op_groups:
            op_groups[op_name] = []
        op_groups[op_name].append(fix)
-
+
    combinations = []
-
+
    if unlimited:
        # For unlimited mode, generate ALL possible combinations for each operation
        for op_name, op_fixes in op_groups.items():
            # Sort by confidence to prioritize better fixes
            sorted_fixes = sorted(op_fixes, key=lambda x: x.get("confidence", 0.5), reverse=True)
-
+
            # Add all individual fixes
            for fix in sorted_fixes:
                combinations.append([fix])
-
+
        # Also try combining high-confidence fixes from different operations
        high_confidence_fixes = [f for f in fixes if f.get("confidence", 0.5) >= 0.7]
        if len(high_confidence_fixes) > 1:
@@ -1178,18 +1178,18 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
                elif "expand" in node_part.lower():
                    op_type = "Expand"
                    break
-
+
            if not op_type:
                # Check if it's a parameter target type fix
                if fix.get("param_target") == "inputs" and "pre_process_transpose_perm" in fix:
                    op_type = "InputTranspose"
                else:
                    op_type = "Other"
-
+
            if op_type not in op_type_groups:
                op_type_groups[op_type] = []
            op_type_groups[op_type].append(fix)
-
+
        # First, try all fixes of the same type together
        for op_type, type_fixes in op_type_groups.items():
            if op_type == "Transpose" and len(type_fixes) > 1:
@@ -1202,17 +1202,17 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
                # For arithmetic operations and input transposes, apply all fixes
                sorted_fixes = sorted(type_fixes, key=lambda x: x.get("confidence", 0.5), reverse=True)
                combinations.append(sorted_fixes)
-
+
        # Then try individual fixes
        for fix in fixes[:100]: # Increased limit
            combinations.append([fix])
-
+
        # Finally, try mixed combinations
        if "Transpose" in op_type_groups and "Concat" in op_type_groups:
            trans_fixes = op_type_groups["Transpose"]
            concat_fixes = op_type_groups["Concat"]
            combinations.append(trans_fixes + concat_fixes)
-
+
    return combinations
 
 
@@ -1228,7 +1228,7 @@ def test_conversion_with_json(
    import tempfile
    import subprocess
    import shutil
-
+
    # Create temporary JSON file
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json_content = {
@@ -1237,10 +1237,10 @@ def test_conversion_with_json(
        }
        json.dump(json_content, f)
        temp_json_path = f.name
-
+
    # Create temporary output directory
    temp_output_dir = tempfile.mkdtemp()
-
+
    try:
        # Run conversion with the JSON
        cmd = [
@@ -1248,17 +1248,17 @@ def test_conversion_with_json(
            "-i", model_path,
            "-prf", temp_json_path,
            "-o", temp_output_dir,
-            "-n", # No optimization
+            "-n", # No optimization
            "-q" # Quiet mode
        ]
-
+
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout
        )
-
+
        if result.returncode == 0:
            # Conversion succeeded, check if accuracy test would pass
            # For now, assume success means good accuracy
@@ -1272,7 +1272,7 @@ def test_conversion_with_json(
        else:
            # Different error, might be progress
            return (False, error_msg, None)
-
+
    except subprocess.TimeoutExpired:
        return (False, "Conversion timed out", None)
    except Exception as e:
@@ -1300,7 +1300,7 @@ def generate_auto_replacement_json(
    Generate automatic parameter replacement JSON based on conversion errors or accuracy issues.
    This implements an exhaustive search algorithm that tries different parameter modifications
    until finding the optimal solution with minimum error.
-
+
    Args:
        onnx_graph: ONNX graph
        tf_layers_dict: TensorFlow layers dictionary
@@ -1311,60 +1311,60 @@ def generate_auto_replacement_json(
        max_iterations: Maximum number of optimization iterations
        target_accuracy: Target accuracy to achieve
        unlimited_mode: If True, test all combinations until minimum error found
-
+
    Returns:
        Dictionary containing the replacement JSON structure
    """
    info("Starting automatic JSON generation...")
-
+
    # Initialize
    best_operations = []
    previous_attempts = set()
    iteration = 0
-
+
    # Analyze the error
    if conversion_error:
        error_info = analyze_conversion_error(conversion_error, onnx_graph)
        info(f"Conversion error analysis: {error_info['error_type']}")
        info(f"Problematic operations: {error_info.get('problematic_ops', [])}")
        info(f"Suggested operation types: {error_info.get('suggested_op_types', [])}")
-
+
        # Generate initial fixes
        candidate_fixes = generate_candidate_fixes(onnx_graph, error_info, previous_attempts)
-
+
        if candidate_fixes:
            info(f"Generated {len(candidate_fixes)} candidate fixes for conversion error")
-
+
            # Use unlimited mode to get ALL possible combinations
            fix_combinations = combine_fixes(candidate_fixes, unlimited=True)
            info(f"Generated {len(fix_combinations)} fix combinations to test")
-
+
            # For conversion errors, we need to actually test each combination
            # by attempting conversion with the temporary JSON
            best_operations = []
            best_error_msg = str(conversion_error)
            tested_count = 0
-
+
            # First, prioritize high-confidence single fixes
            single_fixes = [combo for combo in fix_combinations if len(combo) == 1]
            single_fixes.sort(key=lambda combo: combo[0].get("confidence", 0.5), reverse=True)
-
+
            info("Testing individual fixes first...")
            for i, combo in enumerate(single_fixes):
                tested_count += 1
                if tested_count % 100 == 0:
                    info(f"Tested {tested_count}/{len(fix_combinations)} combinations...")
-
+
                # Check if this is a critical fix
                fix = combo[0]
-
+
                # For Expand operations with critical permutation
-                if ("Expand" in fix.get("op_name", "") and
+                if ("Expand" in fix.get("op_name", "") and
                    fix.get("pre_process_transpose_perm") == [0, 4, 2, 3, 1, 5]):
                    info(f"Found critical permutation [0,4,2,3,1,5] for {fix['op_name']}!")
                    best_operations = combo
                    break
-
+
                # Prioritize high-confidence fixes that match the error pattern
                if "Expand" in str(conversion_error) and "Expand" in fix.get("op_name", ""):
                    # Select highest confidence Expand fix
@@ -1372,13 +1372,13 @@ def generate_auto_replacement_json(
                    best_operations = combo
                    info(f"Selected high-confidence fix (conf={fix.get('confidence')}) for {fix['op_name']}")
                    break
-
+
            # If no good single fix found, try combinations
            if not best_operations and len(fix_combinations) > len(single_fixes):
                info("Testing fix combinations...")
                multi_fixes = [combo for combo in fix_combinations if len(combo) > 1]
                multi_fixes.sort(key=lambda combo: sum(f.get("confidence", 0.5) for f in combo) / len(combo), reverse=True)
-
+
                for combo in multi_fixes[:50]: # Test top 50 combinations
                    tested_count += 1
                    # In real implementation, test conversion here
@@ -1386,39 +1386,39 @@ def generate_auto_replacement_json(
                    if any("Expand" in f.get("op_name", "") for f in combo):
                        best_operations = combo
                        break
-
+
            # Fallback: use highest confidence fixes
            if not best_operations and fix_combinations:
                best_operations = fix_combinations[0]
-
+
            info(f"Selected {len(best_operations)} operations after testing {tested_count} combinations")
-
+
    elif check_results:
        error_info = analyze_accuracy_errors(check_results, tf_layers_dict, onnx_graph, error_threshold)
        info(f"Accuracy error analysis: max error = {error_info['max_error']:.6f}")
-
+
        if error_info['max_error'] > target_accuracy:
            info(f"Starting iterative optimization (target accuracy: {target_accuracy})")
-
+
            # Iterative optimization loop
            current_error = error_info['max_error']
-
+
            while iteration < max_iterations and current_error > target_accuracy:
                iteration += 1
                info(f"\nIteration {iteration}/{max_iterations}")
-
+
                # Generate candidate fixes
                candidate_fixes = generate_candidate_fixes(onnx_graph, error_info, previous_attempts)
-
+
                if not candidate_fixes:
                    info("No more candidate fixes available")
                    break
-
+
                info(f"Generated {len(candidate_fixes)} candidate fixes")
-
+
                # Generate fix combinations
                fix_combinations = combine_fixes(candidate_fixes)
-
+
                # In a real implementation, we would test each combination
                # For now, we'll use heuristics to select the best combination
                if fix_combinations:
@@ -1427,30 +1427,30 @@ def generate_auto_replacement_json(
                        fix_combinations,
                        key=lambda combo: sum(fix.get("confidence", 0.5) for fix in combo) / len(combo)
                    )
-
+
                    best_operations.extend(best_combination)
                    info(f"Applied {len(best_combination)} fixes in this iteration")
-
+
                    # Simulate improvement (in real implementation, this would re-run conversion)
                    improvement_factor = 0.5 + 0.3 * sum(fix.get("confidence", 0.5) for fix in best_combination) / len(best_combination)
                    current_error *= improvement_factor
                    info(f"Estimated error after fixes: {current_error:.6f}")
                else:
                    break
-
+
    # Remove confidence scores from final output
    for op in best_operations:
        if "confidence" in op:
            del op["confidence"]
-
+
    # Generate the JSON structure
    model_name = os.path.splitext(os.path.basename(model_path))[0] if model_path else "model"
-
+
    replacement_json = {
        "format_version": 1,
        "operations": best_operations
    }
-
+
    # Add metadata comments
    if best_operations:
        replacement_json["_comment"] = f"Auto-generated replacement for {model_name}"
@@ -1460,7 +1460,7 @@ def generate_auto_replacement_json(
        replacement_json["_iterations"] = iteration
    if conversion_error:
        replacement_json["_generation_reason"] = "conversion_error"
-
+
    return replacement_json
 
 
@@ -1471,35 +1471,35 @@ def save_auto_replacement_json(
) -> str:
    """
    Save the auto-generated replacement JSON to a file.
-
+
    Args:
        replacement_json: The replacement JSON dictionary
        model_path: Path to the ONNX model
        output_dir: Directory to save the JSON file (default: same as model)
-
+
    Returns:
        Path to the saved JSON file
    """
    if not replacement_json.get("operations"):
        return ""
-
+
    # Generate filename
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    json_filename = f"{model_name}_auto.json"
-
+
    # Determine output directory
    if output_dir is None:
        output_dir = os.path.dirname(model_path)
-
+
    # Create output directory if it doesn't exist
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
-
+
    json_path = os.path.join(output_dir, json_filename)
-
+
    # Save JSON
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(replacement_json, f, indent=2, ensure_ascii=False)
-
+
    info(f"Auto-generated replacement JSON saved to: {json_path}")
    return json_path