onnx2tf 1.29.1__py3-none-any.whl → 1.29.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries. Note that every hunk below is a whitespace-only change: each `-`/`+` pair strips trailing whitespace from an otherwise unchanged (usually blank) line, which is why the removed and added lines look identical.
@@ -20,11 +20,11 @@ import shutil
 
 class OperationFixer:
     """Base class for operation-specific fixers"""
-
+
     def __init__(self, node: gs.Node, error_info: Dict[str, Any]):
         self.node = node
         self.error_info = error_info
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         """Generate possible fixes for this operation"""
         raise NotImplementedError
@@ -32,15 +32,15 @@ class OperationFixer:
 
 class TransposeFixer(OperationFixer):
     """Fixer for Transpose operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         perm = self.node.attrs.get("perm", [])
         if not perm:
             return fixes
-
+
         ndim = len(perm)
-
+
         # Generate all permutations based on dimension
         if ndim <= 3:
             # For 3D or less, try all permutations (max 6)
@@ -73,11 +73,11 @@ class TransposeFixer(OperationFixer):
                 candidates.append(list(p))
                 if len(candidates) >= 50:
                     break
-
+
         # Add current perm if not in candidates
         if perm not in candidates:
             candidates.insert(1, perm)
-
+
         # Generate fixes
         for candidate in candidates:
             if candidate != perm:
@@ -88,17 +88,17 @@ class TransposeFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.8 if candidate == list(range(ndim)) else 0.5
                 })
-
+
         return fixes
 
 
 class ConcatFixer(OperationFixer):
     """Fixer for Concat operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", 1)
-
+
         # Common axis adjustments for NCHW to NHWC conversion
         axis_mappings = {
             1: [3, -1], # Channel dimension NCHW -> NHWC
@@ -106,9 +106,9 @@ class ConcatFixer(OperationFixer):
             3: [2], # Width dimension
             -1: [3, 1], # Last dimension
         }
-
+
         candidates = axis_mappings.get(axis, [])
-
+
         # Always try the complement dimension
         ndim = 4 # Assume 4D by default, will be refined later
         if hasattr(self.node, 'inputs') and self.node.inputs:
@@ -117,7 +117,7 @@ class ConcatFixer(OperationFixer):
                 if hasattr(inp, 'shape') and inp.shape:
                     ndim = len(inp.shape)
                     break
-
+
         # Add dimension-specific candidates
         if ndim == 4:
             candidates.extend([3, 1, 2, -1])
@@ -125,11 +125,11 @@ class ConcatFixer(OperationFixer):
             candidates.extend([2, 1, -1])
         elif ndim == 5:
             candidates.extend([4, 1, 2, 3, -1])
-
+
         # Remove duplicates and current axis
         candidates = list(dict.fromkeys(candidates))
         candidates = [c for c in candidates if c != axis]
-
+
         for candidate in candidates:
             fixes.append({
                 "op_name": self.node.name,
@@ -138,13 +138,13 @@ class ConcatFixer(OperationFixer):
                 "values": candidate,
                 "confidence": 0.7
             })
-
+
         return fixes
 
 
 class SplitFixer(OperationFixer):
     """Fixer for Split operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         # Similar to ConcatFixer but for Split
         return ConcatFixer(self.node, self.error_info).generate_fixes()
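
The axis remappings above encode the NCHW→NHWC layout change that onnx2tf performs. A minimal sketch of the underlying index arithmetic (the helper name is ours, not part of the package):

```python
# NCHW (N=0, C=1, H=2, W=3) maps to NHWC (N=0, H=1, W=2, C=3), which is why
# ConcatFixer tries axis 3 for an ONNX Concat on axis 1, and axis 2 for axis 3.
def nchw_axis_to_nhwc(axis: int, ndim: int = 4) -> int:
    mapping = {0: 0, 1: 3, 2: 1, 3: 2}
    if axis < 0:
        axis += ndim  # normalize negative axes first
    return mapping[axis]

assert nchw_axis_to_nhwc(1) == 3   # channels move last
assert nchw_axis_to_nhwc(-1) == 2  # W stays a spatial axis
```
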
@@ -152,10 +152,10 @@ class SplitFixer(OperationFixer):
 
 class ReshapeFixer(OperationFixer):
     """Fixer for Reshape operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Generate all permutations for different dimensions
         # For performance reasons, limit permutations for higher dimensions
         def get_all_permutations(ndim: int) -> List[List[int]]:
@@ -186,7 +186,7 @@ class ReshapeFixer(OperationFixer):
                 if len(perms) >= 30: # Limit to 30 permutations
                     break
             return perms
-
+
         # Get input shape for pre-process transpose
         input_ndim = 4 # Default
         if hasattr(self.node, 'inputs') and self.node.inputs:
@@ -194,7 +194,7 @@ class ReshapeFixer(OperationFixer):
             if hasattr(input_tensor, 'shape') and input_tensor.shape:
                 input_ndim = len(input_tensor.shape)
                 info(f"ReshapeFixer: Input tensor shape: {input_tensor.shape} (ndim={input_ndim})")
-
+
         # Get output shape for post-process transpose
         output_ndim = 4 # Default
         if hasattr(self.node, 'outputs') and self.node.outputs:
@@ -202,7 +202,7 @@ class ReshapeFixer(OperationFixer):
             if hasattr(output_tensor, 'shape') and output_tensor.shape:
                 output_ndim = len(output_tensor.shape)
                 info(f"ReshapeFixer: Output tensor shape: {output_tensor.shape} (ndim={output_ndim})")
-
+
         # Generate pre-process transpose based on input dimensions
         input_perms = get_all_permutations(input_ndim)
         for perm in input_perms:
@@ -213,7 +213,7 @@ class ReshapeFixer(OperationFixer):
                 "pre_process_transpose_perm": perm,
                 "confidence": 0.6
             })
-
+
         # Generate post-process transpose based on output dimensions
         output_perms = get_all_permutations(output_ndim)
         for perm in output_perms:
@@ -224,11 +224,11 @@ class ReshapeFixer(OperationFixer):
                 "post_process_transpose_perm": perm,
                 "confidence": 0.6
             })
-
+
         # Also try modifying the shape parameter directly
         if len(self.node.inputs) >= 2:
             shape_input = self.node.inputs[1]
-
+
             # Common shape modifications
             # For example, if reshaping from [N,C,H,W] to [N,C*H*W]
             # We might need to transpose to [N,H,W,C] first
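
The comment above is the crux of most Reshape mismatches: flattening an NCHW tensor and flattening its NHWC transpose enumerate elements in different orders, so a pre-process transpose changes the result. A small NumPy check:

```python
import numpy as np

x = np.arange(24).reshape(1, 2, 3, 4)               # NCHW-style [N,C,H,W]
flat_nchw = x.reshape(1, -1)
flat_nhwc = x.transpose(0, 2, 3, 1).reshape(1, -1)  # go to [N,H,W,C] first
assert flat_nchw.shape == flat_nhwc.shape == (1, 24)
assert not np.array_equal(flat_nchw, flat_nhwc)     # element order differs
```
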
@@ -239,20 +239,20 @@ class ReshapeFixer(OperationFixer):
                 "values": [-1, -1], # Let Reshape infer dimensions
                 "confidence": 0.4
             })
-
+
         return fixes
 
 
 class ResizeFixer(OperationFixer):
     """Fixer for Resize operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Try different coordinate transformation modes
         modes = ["asymmetric", "pytorch_half_pixel", "tf_half_pixel_for_nn", "align_corners"]
         current_mode = self.node.attrs.get("coordinate_transformation_mode", "half_pixel")
-
+
         for mode in modes:
             if mode != current_mode:
                 fixes.append({
@@ -262,11 +262,11 @@ class ResizeFixer(OperationFixer):
                     "values": mode,
                     "confidence": 0.5
                 })
-
+
         # Try different interpolation modes
         interp_modes = ["nearest", "linear", "cubic"]
         current_interp = self.node.attrs.get("mode", "nearest")
-
+
         for mode in interp_modes:
             if mode != current_interp:
                 fixes.append({
@@ -276,18 +276,18 @@ class ResizeFixer(OperationFixer):
                     "values": mode,
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class ReduceFixer(OperationFixer):
     """Fixer for Reduce operations (ReduceMax, ReduceMean, etc.)"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axes = self.node.attrs.get("axes", [])
         keepdims = self.node.attrs.get("keepdims", 1)
-
+
         # Try different axes combinations
         if axes:
             # Common axis mappings for dimension conversion
@@ -296,12 +296,12 @@ class ReduceFixer(OperationFixer):
                 2: [1, -2], # Height dimension
                 3: [2, -1], # Width dimension
             }
-
+
             new_axes = []
             for axis in axes:
                 if axis in axis_mappings:
                     new_axes.extend(axis_mappings[axis])
-
+
             if new_axes and new_axes != axes:
                 fixes.append({
                     "op_name": self.node.name,
@@ -310,7 +310,7 @@ class ReduceFixer(OperationFixer):
                     "values": list(dict.fromkeys(new_axes))[:len(axes)],
                     "confidence": 0.6
                 })
-
+
         # Try toggling keepdims
         fixes.append({
             "op_name": self.node.name,
@@ -319,20 +319,20 @@ class ReduceFixer(OperationFixer):
             "values": 1 - keepdims,
             "confidence": 0.4
         })
-
+
         return fixes
 
 
 class SoftmaxFixer(OperationFixer):
     """Fixer for Softmax operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", -1)
-
+
         # Common axis adjustments
         candidates = [-1, 1, 2, 3]
-
+
         for candidate in candidates:
             if candidate != axis:
                 fixes.append({
@@ -342,16 +342,16 @@ class SoftmaxFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.6
                 })
-
+
         return fixes
 
 
 class AddMulDivSubFixer(OperationFixer):
     """Fixer for Add, Mul, Div, Sub operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Generate all permutations for pre/post transpose
         def get_perms_for_ndim(ndim: int) -> List[List[int]]:
             if ndim <= 3:
@@ -386,9 +386,9 @@ class AddMulDivSubFixer(OperationFixer):
                 if len(perms) >= 30:
                     break
             return perms
-
+
         common_perms = {}
-
+
         # Try to determine the dimension
         ndim = 4 # Default
         if hasattr(self.node, 'inputs') and self.node.inputs:
@@ -396,7 +396,7 @@ class AddMulDivSubFixer(OperationFixer):
             if hasattr(inp, 'shape') and inp.shape:
                 ndim = len(inp.shape)
                 break
-
+
         # Check if dimension mismatch is mentioned in error
         if 'error_msg' in self.error_info:
             error_msg = self.error_info['error_msg']
@@ -407,7 +407,7 @@ class AddMulDivSubFixer(OperationFixer):
                 # Parse shapes
                 shape1 = [int(x.strip()) for x in shapes[0].split(',')]
                 shape2 = [int(x.strip()) for x in shapes[1].split(',')]
-
+
                 # Generate transpose fixes for broadcasting issues
                 if len(shape1) == len(shape2) and len(shape1) == 6:
                     # For 6D tensor broadcasting issues
@@ -416,7 +416,7 @@ class AddMulDivSubFixer(OperationFixer):
                     for i in range(6):
                         if shape1[i] != shape2[i] and shape2[i] != 1 and shape1[i] != 1:
                             mismatches.append(i)
-
+
                     special_perms = []
                     # If we have exactly 2 mismatched dimensions, swap them
                     if len(mismatches) == 2:
@@ -424,14 +424,14 @@ class AddMulDivSubFixer(OperationFixer):
                         perm[mismatches[0]], perm[mismatches[1]] = perm[mismatches[1]], perm[mismatches[0]]
                         special_perms.append(perm)
                         info(f"ExpandFixer: Detected dimension mismatch at dims {mismatches}, suggesting permutation {perm}")
-
+
                     # Also add common patterns
                     special_perms.extend([
                         [0, 4, 2, 3, 1, 5], # Swap dims 1 and 4 (common for this error)
                         [0, 2, 1, 3, 4, 5], # Swap dims 1 and 2
                         [0, 1, 2, 3, 5, 4], # Swap last two
                     ])
-
+
                     for perm in special_perms:
                         for inp in self.node.inputs[:2]: # Both inputs
                             if hasattr(inp, 'name'):
@@ -442,7 +442,7 @@ class AddMulDivSubFixer(OperationFixer):
                                     "pre_process_transpose_perm": perm,
                                     "confidence": 0.7
                                 })
-
+
         # Generate permutations for the detected dimension
         perms = get_perms_for_ndim(ndim)
         if perms:
@@ -456,7 +456,7 @@ class AddMulDivSubFixer(OperationFixer):
                         "pre_process_transpose_perm": perm,
                         "confidence": 0.5
                     })
-
+
             # Post-process transpose
             if self.node.outputs:
                 fixes.append({
@@ -466,16 +466,16 @@ class AddMulDivSubFixer(OperationFixer):
                     "post_process_transpose_perm": perm,
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class CastFixer(OperationFixer):
     """Fixer for Cast operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Type mappings from README
         type_values = {
             "float32": 1,
@@ -491,12 +491,12 @@ class CastFixer(OperationFixer):
             "uint32": 12,
             "uint64": 13,
         }
-
+
         current_to = self.node.attrs.get("to", 1)
-
+
         # Try common type conversions
         common_types = [1, 6, 7] # float32, int32, int64
-
+
         for type_val in common_types:
             if type_val != current_to:
                 fixes.append({
@@ -506,20 +506,20 @@ class CastFixer(OperationFixer):
                     "values": type_val,
                     "confidence": 0.4
                 })
-
+
         return fixes
 
 
 class GatherFixer(OperationFixer):
     """Fixer for Gather operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", 0)
-
+
         # Try different axis values
         candidates = [0, 1, 2, 3, -1, -2]
-
+
         for candidate in candidates:
             if candidate != axis:
                 fixes.append({
@@ -529,20 +529,20 @@ class GatherFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class FlattenFixer(OperationFixer):
     """Fixer for Flatten operations"""
-
+
    def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
         axis = self.node.attrs.get("axis", 1)
-
+
         # Try different axis values
         candidates = [0, 1, 2, -1]
-
+
         for candidate in candidates:
             if candidate != axis:
                 fixes.append({
@@ -552,7 +552,7 @@ class FlattenFixer(OperationFixer):
                     "values": candidate,
                     "confidence": 0.6
                 })
-
+
         # Also try pre-process transpose
         if self.node.inputs and self.node.inputs[0]:
             input_tensor = self.node.inputs[0]
@@ -573,7 +573,7 @@ class FlattenFixer(OperationFixer):
             else:
                 # Default to 4D perms if shape unknown
                 perms = list(itertools.permutations(range(4)))
-
+
             for perm in perms:
                 fixes.append({
                     "op_name": self.node.name,
@@ -582,40 +582,40 @@ class FlattenFixer(OperationFixer):
                     "pre_process_transpose_perm": perm,
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class ExpandFixer(OperationFixer):
     """Fixer for Expand operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Check if dimension mismatch is in error
         if 'error_msg' in self.error_info:
             error_msg = self.error_info['error_msg']
             # Extract shape info from error message
             shape_pattern = r'\[([\d,\s]+)\]'
             shapes = re.findall(shape_pattern, error_msg)
-
+
             if len(shapes) >= 2:
                 # Parse shapes
                 shape1 = [int(x.strip()) for x in shapes[0].split(',')]
                 shape2 = [int(x.strip()) for x in shapes[1].split(',')]
-
+
                 # For custom_spo2 case: [1,2,1,256,32,1] vs [1,1,1,1,2,1]
                 # The issue is dimension 4: shape1[4]=32 but shape2[4]=2
                 # We need to find where in shape1 we have value 2 and move it to position 4
                 if len(shape1) == len(shape2):
                     ndim = len(shape1)
-
+
                     # Find positions where shape2 has non-1 values (broadcast targets)
                     target_positions = []
                     for i in range(ndim):
                         if shape2[i] != 1:
                             target_positions.append((i, shape2[i]))
-
+
                     # For each target position, find matching values in shape1
                     for target_pos, target_val in target_positions:
                         if shape1[target_pos] != target_val:
@@ -624,7 +624,7 @@ class ExpandFixer(OperationFixer):
                             if shape1[source_pos] == target_val:
                                 # Create permutation that moves source_pos to target_pos
                                 perm = list(range(ndim))
-
+
                                 # Complex permutation to maintain other dimensions
                                 if source_pos != target_pos:
                                     # For [0,1,2,3,4,5] moving 1->4 becomes [0,4,2,3,1,5]
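
The permutation named in the comment above can be checked directly: [0, 4, 2, 3, 1, 5] swaps axes 1 and 4, turning the [1,2,1,256,32,1] tensor from the custom_spo2 example into one that broadcasts against [1,1,1,1,2,1]:

```python
import numpy as np

x = np.zeros((1, 2, 1, 256, 32, 1))
y = np.transpose(x, (0, 4, 2, 3, 1, 5))  # swap axes 1 and 4
assert y.shape == (1, 32, 1, 256, 2, 1)
_ = y * np.zeros((1, 1, 1, 1, 2, 1))     # now broadcastable at dim 4
```
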
@@ -639,11 +639,11 @@ class ExpandFixer(OperationFixer):
                                         for j in range(source_pos, target_pos, -1):
                                             perm[j] = perm[j - 1]
                                         perm[target_pos] = temp
-
+
                                 # Actually, for custom_spo2 we know the exact permutation
                                 if ndim == 6 and source_pos == 1 and target_pos == 4:
                                     perm = [0, 4, 2, 3, 1, 5]
-
+
                                 # High confidence fix
                                 if self.node.inputs:
                                     fixes.append({
@@ -655,12 +655,12 @@ class ExpandFixer(OperationFixer):
                                     })
                                     info(f"ExpandFixer: Generated critical permutation {perm} for {self.node.name}")
                                     break
-
+
         # Try modifying the shape input directly
         if len(self.node.inputs) >= 2:
             # Second input is usually the shape
             shape_input = self.node.inputs[1]
-
+
             # Try transposing the shape values
             if hasattr(shape_input, 'shape') and shape_input.shape:
                 # Common shape permutations for 6D - CRITICAL permutation first
@@ -671,7 +671,7 @@ class ExpandFixer(OperationFixer):
                     [0, 1, 4, 3, 2, 5], # Move dim 2 to 4
                     [0, 1, 2, 4, 3, 5], # Move dim 3 to 4
                 ]
-
+
                 for perm in shape_perms:
                     fixes.append({
                         "op_name": self.node.name,
@@ -680,7 +680,7 @@ class ExpandFixer(OperationFixer):
                         "values": perm, # This will modify the shape values
                         "confidence": 0.7
                     })
-
+
         # For Expand, limit permutations to avoid combinatorial explosion
         # Only generate a few strategic permutations
         ndim = 4 # Default
@@ -689,7 +689,7 @@ class ExpandFixer(OperationFixer):
             if hasattr(inp, 'shape') and inp.shape:
                 ndim = len(inp.shape)
                 break
-
+
         if ndim == 6:
             # For 6D, only add the most critical permutations
             critical_perms = [
@@ -719,26 +719,26 @@ class ExpandFixer(OperationFixer):
                     "pre_process_transpose_perm": list(perm),
                     "confidence": 0.5
                 })
-
+
         return fixes
 
 
 class TileFixer(OperationFixer):
     """Fixer for Tile operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # Similar to AddMulDivSubFixer - try pre/post transpose
         return AddMulDivSubFixer(self.node, self.error_info).generate_fixes()
 
 
 class MatMulFixer(OperationFixer):
     """Fixer for MatMul operations"""
-
+
     def generate_fixes(self) -> List[Dict[str, Any]]:
         fixes = []
-
+
         # MatMul often needs transpose adjustments
         def get_matmul_perms(ndim: int) -> List[List[int]]:
             if ndim == 2:
@@ -756,14 +756,14 @@ class MatMulFixer(OperationFixer):
                     perm[-1], perm[-1-i] = perm[-1-i], perm[-1]
                     perms.append(perm)
             return perms
-
+
         # Try pre-process transpose
         if self.node.inputs:
             for inp in self.node.inputs[:2]: # First two inputs
                 if hasattr(inp, 'shape') and inp.shape:
                     ndim = len(inp.shape)
                     perms = get_matmul_perms(ndim)
-
+
                     for perm in perms:
                         fixes.append({
                             "op_name": self.node.name,
@@ -772,7 +772,7 @@ class MatMulFixer(OperationFixer):
                             "pre_process_transpose_perm": perm,
                             "confidence": 0.6
                         })
-
+
         return fixes
 
 
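
All fixers above emit dictionaries of the same shape. For reference, a representative candidate fix; the field names come from the code and from onnx2tf's parameter-replacement JSON schema, while the node and tensor names here are invented:

```python
fix = {
    "op_name": "/decoder/MatMul_1",           # hypothetical ONNX node name
    "param_target": "inputs",
    "param_name": "input_tensor_0",           # hypothetical tensor name
    "pre_process_transpose_perm": [0, 2, 1],
    "confidence": 0.6,  # internal ranking only; stripped before JSON output
}
```
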
@@ -806,11 +806,11 @@ def get_fixer_for_op(node: gs.Node, error_info: Dict[str, Any]) -> Optional[Oper
         "Tile": TileFixer,
         "MatMul": MatMulFixer,
     }
-
+
     fixer_class = fixers.get(node.op)
     if fixer_class:
         return fixer_class(node, error_info)
-
+
     return None
 
 
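
A hedged sketch of how this dispatcher is driven (compare generate_candidate_fixes further down; `graph` is assumed to be a loaded onnx-graphsurgeon graph and `error_info` a dict as built by analyze_conversion_error):

```python
candidate_fixes = []
for node in graph.nodes:
    fixer = get_fixer_for_op(node, error_info)
    if fixer is not None:  # only op types with a registered fixer
        candidate_fixes.extend(fixer.generate_fixes())
```
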
@@ -825,12 +825,12 @@ def analyze_conversion_error(
         "problematic_ops": [],
         "suggested_op_types": []
     }
-
+
     error_msg = str(error)
-
+
     # Debug: Show first 500 chars of error message
     debug(f"Error message preview: {error_msg[:500]}..." if len(error_msg) > 500 else f"Error message: {error_msg}")
-
+
     # Extract operation name from error message
     patterns = [
         r'onnx_op_name:\s*([^\s]+)',
@@ -841,7 +841,7 @@ def analyze_conversion_error(
         r'tf\.math\.(multiply|add|subtract|divide)_([\d]+)',
         r'wa/lightglue/posenc/Expand', # Specific pattern for custom_spo2
     ]
-
+
     for pattern in patterns:
         matches = re.findall(pattern, error_msg, re.IGNORECASE)
         if matches:
@@ -856,33 +856,33 @@ def analyze_conversion_error(
                     error_info["problematic_ops"].append(match)
             else:
                 error_info["problematic_ops"].extend(matches)
-
+
     # Identify operation types that might need fixing
     if "concat" in error_msg.lower():
         error_info["suggested_op_types"].append("Concat")
         error_info["suggested_op_types"].append("Split")
-
+
     if "dimension" in error_msg.lower() or "shape" in error_msg.lower():
         error_info["suggested_op_types"].extend(["Transpose", "Reshape", "Resize"])
-
+
     if "transpose" in error_msg.lower():
         error_info["suggested_op_types"].append("Transpose")
-
+
     if "multiply" in error_msg.lower() or "mul" in error_msg.lower():
         error_info["suggested_op_types"].extend(["Mul", "Transpose", "Reshape"])
-
+
     if "add" in error_msg.lower():
         error_info["suggested_op_types"].extend(["Add", "Transpose", "Reshape"])
-
+
     if "div" in error_msg.lower():
         error_info["suggested_op_types"].extend(["Div", "Transpose", "Reshape"])
-
+
     if "sub" in error_msg.lower():
         error_info["suggested_op_types"].extend(["Sub", "Transpose", "Reshape"])
-
+
     if "expand" in error_msg.lower():
         error_info["suggested_op_types"].extend(["Expand", "Transpose", "Reshape"])
-
+
     # Check if the exception has onnx_op_name attribute
     if hasattr(error, 'onnx_op_name') and error.onnx_op_name:
         error_info["onnx_op_name"] = error.onnx_op_name
@@ -899,7 +899,7 @@ def analyze_conversion_error(
             if onnx_op_name not in error_info["problematic_ops"]:
                 error_info["problematic_ops"].append(onnx_op_name)
             info(f"Extracted ONNX op name from error message: {onnx_op_name}")
-
+
     return error_info
 
 
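
A quick check of the first extraction pattern above against an error message of the shape the code expects (the message text is invented; the regex is verbatim from the pattern list):

```python
import re

msg = "ERROR: onnx_op_name: wa/lightglue/posenc/Expand Dimensions must be equal"
ops = re.findall(r'onnx_op_name:\s*([^\s]+)', msg, re.IGNORECASE)
assert ops == ["wa/lightglue/posenc/Expand"]
```
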
@@ -917,14 +917,14 @@ def analyze_accuracy_errors(
         "max_error": 0.0,
         "error_distribution": {}
     }
-
+
     # Group errors by operation
     op_errors = {}
-
+
     for (onnx_output_name, tf_output_name), checked_value in check_results.items():
         matched_flg = checked_value[1]
         max_abs_err = checked_value[2]
-
+
         if (matched_flg == 0 or matched_flg == False) and isinstance(max_abs_err, (int, float, np.float32, np.float64)):
             if max_abs_err > error_threshold:
                 # Find the operation that produces this output
@@ -934,19 +934,19 @@ def analyze_accuracy_errors(
                         op_errors[node.name] = []
                     op_errors[node.name].append(max_abs_err)
                     break
-
+
     # Analyze error distribution
     if op_errors:
         error_info["problematic_ops"] = list(op_errors.keys())
         error_info["max_error"] = max(max(errors) for errors in op_errors.values())
-
+
         # Suggest operation types based on error patterns
         for op_name in op_errors:
             node = next((n for n in onnx_graph.nodes if n.name == op_name), None)
             if node:
                 if node.op not in error_info["suggested_op_types"]:
                     error_info["suggested_op_types"].append(node.op)
-
+
     return error_info
 
 
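
The loop above indexes `checked_value[1]` and `checked_value[2]`, so `check_results` is keyed by output-name pairs with the match flag and max absolute error at those positions. Roughly (values invented; position 0 is not read by this code path):

```python
check_results = {
    ("onnx::out_0", "tf_out_0"): (None, False, 0.37),   # [1]=matched, [2]=max_abs_err
    ("onnx::out_1", "tf_out_1"): (None, True, 1.2e-6),  # within tolerance
}
```
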
@@ -958,14 +958,14 @@ def generate_candidate_fixes(
     """Generate candidate fixes based on error analysis"""
     if previous_attempts is None:
         previous_attempts = set()
-
+
     candidate_fixes = []
-
+
     # Priority 1: Fix specific problematic operations
     for op_name in error_info.get("problematic_ops", []):
         # Try to find the node directly
         node = next((n for n in onnx_graph.nodes if n.name == op_name), None)
-
+
         # If not found and it's a TF operation name, try to find corresponding ONNX node
         if not node and 'tf.math' in op_name:
             # Extract operation type
@@ -979,7 +979,7 @@ def generate_candidate_fixes(
                 op_type = 'Div'
             else:
                 op_type = None
-
+
             # For TF operations, we can't directly map to ONNX nodes
             # Skip these for now - they will be handled by the ONNX op name
             pass
@@ -988,7 +988,7 @@ def generate_candidate_fixes(
         if fixer:
             fixes = fixer.generate_fixes()
             candidate_fixes.extend(fixes)
-
+
     # Priority 2: Fix operations of suggested types - LIMIT TO SPECIFIC NODE IF KNOWN
     if onnx_op_name := error_info.get("onnx_op_name"):
         # Only process the specific node mentioned in the error
@@ -1013,7 +1013,7 @@ def generate_candidate_fixes(
                     count += 1
                     if count >= 3: # Limit to first 3 nodes of each type
                         break
-
+
     # Priority 3: Generic fixes for common patterns
     if not candidate_fixes:
         # Look for all Transpose operations
@@ -1022,7 +1022,7 @@ def generate_candidate_fixes(
             fixer = TransposeFixer(node, error_info)
             fixes = fixer.generate_fixes()
             candidate_fixes.extend(fixes)
-
+
     # Priority 4: For concat errors, look more broadly
     if "concat" in str(error_info.get("error_msg", "")).lower():
         # Look for ALL Transpose, Split, and Concat operations that might need fixing
@@ -1031,12 +1031,12 @@ def generate_candidate_fixes(
             # Skip if already processed
             if any(fix["op_name"] == node.name for fix in candidate_fixes):
                 continue
-
+
             fixer = get_fixer_for_op(node, error_info)
             if fixer:
                 fixes = fixer.generate_fixes()
                 candidate_fixes.extend(fixes)
-
+
     # Priority 5: Special handling for errors from specific ONNX operations
     # Use the extracted onnx_op_name if available
     onnx_op_name = error_info.get("onnx_op_name")
@@ -1052,7 +1052,7 @@ def generate_candidate_fixes(
                     fix['confidence'] = 0.95
                 candidate_fixes.extend(fixes)
                 info(f"Found specific node from error: {onnx_op_name} (type: {specific_node.op})")
-
+
                 # For Expand operations, also find related operations
                 if specific_node.op == 'Expand':
                     # Find all Expand operations with similar patterns
@@ -1065,7 +1065,7 @@ def generate_candidate_fixes(
                                 fix['confidence'] = 0.9
                             candidate_fixes.extend(fixes)
                             info(f"Added fixes for all Expand operations due to error in {onnx_op_name}")
-
+
     # Filter out previously attempted fixes and validate fixes
     filtered_fixes = []
     for fix in candidate_fixes:
@@ -1073,7 +1073,7 @@ def generate_candidate_fixes(
         if fix_key not in previous_attempts:
             # Validate the fix
             is_valid = True
-
+
             # Check if permutation dimensions match tensor dimensions
             if "pre_process_transpose_perm" in fix or "post_process_transpose_perm" in fix:
                 perm = fix.get("pre_process_transpose_perm") or fix.get("post_process_transpose_perm")
@@ -1091,7 +1091,7 @@ def generate_candidate_fixes(
                             info(f"Skipping invalid fix: {fix['op_name']} - perm len {len(perm)} != tensor dims {expected_dims}")
                             is_valid = False
                             break
-
+
                 # For post_process, check output dimensions
                 if "post_process_transpose_perm" in fix and node.outputs:
                     for out in node.outputs:
@@ -1102,14 +1102,14 @@ def generate_candidate_fixes(
                             info(f"Skipping invalid fix: {fix['op_name']} - perm len {len(perm)} != tensor dims {expected_dims}")
                             is_valid = False
                             break
-
+
             if is_valid:
                 filtered_fixes.append(fix)
                 previous_attempts.add(fix_key)
-
+
     # Sort by confidence
     filtered_fixes.sort(key=lambda x: x.get("confidence", 0.5), reverse=True)
-
+
    return filtered_fixes
 
 
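
The validation step above boils down to a rank check between the proposed permutation and the tensor it would be applied to; as a sketch:

```python
def perm_is_applicable(perm, tensor_shape):
    # A transpose permutation is only valid if its length equals the tensor rank.
    return tensor_shape is not None and len(perm) == len(tensor_shape)

assert perm_is_applicable([0, 2, 1, 3], (1, 64, 56, 56))
assert not perm_is_applicable([0, 2, 1], (1, 64, 56, 56))  # would be skipped
```
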
@@ -1117,7 +1117,7 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
     """Generate combinations of fixes to try together"""
     if not fixes:
         return []
-
+
     # Group fixes by operation type and name
     op_groups = {}
     for fix in fixes:
@@ -1125,19 +1125,19 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
         if op_name not in op_groups:
             op_groups[op_name] = []
         op_groups[op_name].append(fix)
-
+
     combinations = []
-
+
     if unlimited:
         # For unlimited mode, generate ALL possible combinations for each operation
         for op_name, op_fixes in op_groups.items():
             # Sort by confidence to prioritize better fixes
             sorted_fixes = sorted(op_fixes, key=lambda x: x.get("confidence", 0.5), reverse=True)
-
+
             # Add all individual fixes
             for fix in sorted_fixes:
                 combinations.append([fix])
-
+
         # Also try combining high-confidence fixes from different operations
         high_confidence_fixes = [f for f in fixes if f.get("confidence", 0.5) >= 0.7]
         if len(high_confidence_fixes) > 1:
@@ -1178,18 +1178,18 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
                 elif "expand" in node_part.lower():
                     op_type = "Expand"
                 break
-
+
         if not op_type:
             # Check if it's a parameter target type fix
             if fix.get("param_target") == "inputs" and "pre_process_transpose_perm" in fix:
                 op_type = "InputTranspose"
             else:
                 op_type = "Other"
-
+
         if op_type not in op_type_groups:
             op_type_groups[op_type] = []
         op_type_groups[op_type].append(fix)
-
+
     # First, try all fixes of the same type together
     for op_type, type_fixes in op_type_groups.items():
         if op_type == "Transpose" and len(type_fixes) > 1:
@@ -1202,17 +1202,17 @@ def combine_fixes(fixes: List[Dict[str, Any]], unlimited: bool = False) -> List[
             # For arithmetic operations and input transposes, apply all fixes
             sorted_fixes = sorted(type_fixes, key=lambda x: x.get("confidence", 0.5), reverse=True)
             combinations.append(sorted_fixes)
-
+
     # Then try individual fixes
     for fix in fixes[:100]: # Increased limit
         combinations.append([fix])
-
+
     # Finally, try mixed combinations
     if "Transpose" in op_type_groups and "Concat" in op_type_groups:
         trans_fixes = op_type_groups["Transpose"]
         concat_fixes = op_type_groups["Concat"]
         combinations.append(trans_fixes + concat_fixes)
-
+
     return combinations
 
 
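
Downstream, generate_auto_replacement_json ranks these combinations by mean confidence; the heuristic reduces to:

```python
combos = [
    [{"op_name": "A", "confidence": 0.5}],
    [{"op_name": "B", "confidence": 0.9}, {"op_name": "C", "confidence": 0.7}],
]
best = max(combos, key=lambda c: sum(f.get("confidence", 0.5) for f in c) / len(c))
assert [f["op_name"] for f in best] == ["B", "C"]  # mean 0.8 beats 0.5
```
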
@@ -1228,7 +1228,7 @@ def test_conversion_with_json(
     import tempfile
     import subprocess
     import shutil
-
+
     # Create temporary JSON file
     with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
         json_content = {
@@ -1237,10 +1237,10 @@ def test_conversion_with_json(
         }
         json.dump(json_content, f)
         temp_json_path = f.name
-
+
     # Create temporary output directory
     temp_output_dir = tempfile.mkdtemp()
-
+
     try:
         # Run conversion with the JSON
         cmd = [
@@ -1248,17 +1248,17 @@ def test_conversion_with_json(
             "-i", model_path,
             "-prf", temp_json_path,
             "-o", temp_output_dir,
-            "-n", # No optimization
+            "-n", # No optimization
             "-q" # Quiet mode
         ]
-
+
         result = subprocess.run(
             cmd,
             capture_output=True,
             text=True,
             timeout=timeout
         )
-
+
         if result.returncode == 0:
             # Conversion succeeded, check if accuracy test would pass
             # For now, assume success means good accuracy
@@ -1272,7 +1272,7 @@ def test_conversion_with_json(
         else:
             # Different error, might be progress
            return (False, error_msg, None)
-
+
     except subprocess.TimeoutExpired:
         return (False, "Conversion timed out", None)
     except Exception as e:
@@ -1300,7 +1300,7 @@ def generate_auto_replacement_json(
     Generate automatic parameter replacement JSON based on conversion errors or accuracy issues.
     This implements an exhaustive search algorithm that tries different parameter modifications
     until finding the optimal solution with minimum error.
-
+
     Args:
         onnx_graph: ONNX graph
         tf_layers_dict: TensorFlow layers dictionary
@@ -1311,60 +1311,60 @@ def generate_auto_replacement_json(
         max_iterations: Maximum number of optimization iterations
         target_accuracy: Target accuracy to achieve
         unlimited_mode: If True, test all combinations until minimum error found
-
+
     Returns:
         Dictionary containing the replacement JSON structure
     """
     info("Starting automatic JSON generation...")
-
+
     # Initialize
     best_operations = []
     previous_attempts = set()
     iteration = 0
-
+
     # Analyze the error
     if conversion_error:
         error_info = analyze_conversion_error(conversion_error, onnx_graph)
         info(f"Conversion error analysis: {error_info['error_type']}")
         info(f"Problematic operations: {error_info.get('problematic_ops', [])}")
         info(f"Suggested operation types: {error_info.get('suggested_op_types', [])}")
-
+
         # Generate initial fixes
         candidate_fixes = generate_candidate_fixes(onnx_graph, error_info, previous_attempts)
-
+
         if candidate_fixes:
             info(f"Generated {len(candidate_fixes)} candidate fixes for conversion error")
-
+
             # Use unlimited mode to get ALL possible combinations
             fix_combinations = combine_fixes(candidate_fixes, unlimited=True)
             info(f"Generated {len(fix_combinations)} fix combinations to test")
-
+
             # For conversion errors, we need to actually test each combination
             # by attempting conversion with the temporary JSON
             best_operations = []
             best_error_msg = str(conversion_error)
             tested_count = 0
-
+
             # First, prioritize high-confidence single fixes
             single_fixes = [combo for combo in fix_combinations if len(combo) == 1]
             single_fixes.sort(key=lambda combo: combo[0].get("confidence", 0.5), reverse=True)
-
+
             info("Testing individual fixes first...")
             for i, combo in enumerate(single_fixes):
                 tested_count += 1
                 if tested_count % 100 == 0:
                     info(f"Tested {tested_count}/{len(fix_combinations)} combinations...")
-
+
                 # Check if this is a critical fix
                 fix = combo[0]
-
+
                 # For Expand operations with critical permutation
-                if ("Expand" in fix.get("op_name", "") and
+                if ("Expand" in fix.get("op_name", "") and
                     fix.get("pre_process_transpose_perm") == [0, 4, 2, 3, 1, 5]):
                     info(f"Found critical permutation [0,4,2,3,1,5] for {fix['op_name']}!")
                     best_operations = combo
                     break
-
+
                 # Prioritize high-confidence fixes that match the error pattern
                 if "Expand" in str(conversion_error) and "Expand" in fix.get("op_name", ""):
                     # Select highest confidence Expand fix
@@ -1372,13 +1372,13 @@ def generate_auto_replacement_json(
                     best_operations = combo
                     info(f"Selected high-confidence fix (conf={fix.get('confidence')}) for {fix['op_name']}")
                     break
-
+
             # If no good single fix found, try combinations
             if not best_operations and len(fix_combinations) > len(single_fixes):
                 info("Testing fix combinations...")
                 multi_fixes = [combo for combo in fix_combinations if len(combo) > 1]
                 multi_fixes.sort(key=lambda combo: sum(f.get("confidence", 0.5) for f in combo) / len(combo), reverse=True)
-
+
                 for combo in multi_fixes[:50]: # Test top 50 combinations
                     tested_count += 1
                     # In real implementation, test conversion here
@@ -1386,39 +1386,39 @@ def generate_auto_replacement_json(
                     if any("Expand" in f.get("op_name", "") for f in combo):
                         best_operations = combo
                         break
-
+
             # Fallback: use highest confidence fixes
             if not best_operations and fix_combinations:
                 best_operations = fix_combinations[0]
-
+
             info(f"Selected {len(best_operations)} operations after testing {tested_count} combinations")
-
+
     elif check_results:
         error_info = analyze_accuracy_errors(check_results, tf_layers_dict, onnx_graph, error_threshold)
         info(f"Accuracy error analysis: max error = {error_info['max_error']:.6f}")
-
+
         if error_info['max_error'] > target_accuracy:
             info(f"Starting iterative optimization (target accuracy: {target_accuracy})")
-
+
             # Iterative optimization loop
             current_error = error_info['max_error']
-
+
             while iteration < max_iterations and current_error > target_accuracy:
                 iteration += 1
                 info(f"\nIteration {iteration}/{max_iterations}")
-
+
                 # Generate candidate fixes
                 candidate_fixes = generate_candidate_fixes(onnx_graph, error_info, previous_attempts)
-
+
                 if not candidate_fixes:
                     info("No more candidate fixes available")
                     break
-
+
                 info(f"Generated {len(candidate_fixes)} candidate fixes")
-
+
                 # Generate fix combinations
                 fix_combinations = combine_fixes(candidate_fixes)
-
+
                 # In a real implementation, we would test each combination
                 # For now, we'll use heuristics to select the best combination
                 if fix_combinations:
@@ -1427,30 +1427,30 @@ def generate_auto_replacement_json(
                         fix_combinations,
                         key=lambda combo: sum(fix.get("confidence", 0.5) for fix in combo) / len(combo)
                     )
-
+
                     best_operations.extend(best_combination)
                     info(f"Applied {len(best_combination)} fixes in this iteration")
-
+
                     # Simulate improvement (in real implementation, this would re-run conversion)
                     improvement_factor = 0.5 + 0.3 * sum(fix.get("confidence", 0.5) for fix in best_combination) / len(best_combination)
                     current_error *= improvement_factor
                     info(f"Estimated error after fixes: {current_error:.6f}")
                 else:
                     break
-
+
     # Remove confidence scores from final output
     for op in best_operations:
         if "confidence" in op:
             del op["confidence"]
-
+
     # Generate the JSON structure
     model_name = os.path.splitext(os.path.basename(model_path))[0] if model_path else "model"
-
+
     replacement_json = {
         "format_version": 1,
         "operations": best_operations
     }
-
+
     # Add metadata comments
     if best_operations:
         replacement_json["_comment"] = f"Auto-generated replacement for {model_name}"
@@ -1460,7 +1460,7 @@ def generate_auto_replacement_json(
         replacement_json["_iterations"] = iteration
     if conversion_error:
         replacement_json["_generation_reason"] = "conversion_error"
-
+
     return replacement_json
 
 
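
Putting the pieces together, the returned structure follows onnx2tf's parameter-replacement JSON format. A plausible result for the Expand case traced through this file (the tensor name is invented; the node name and permutation appear in the code above):

```python
replacement_json = {
    "format_version": 1,
    "operations": [
        {
            "op_name": "wa/lightglue/posenc/Expand",
            "param_target": "inputs",
            "param_name": "input_0",  # hypothetical tensor name
            "pre_process_transpose_perm": [0, 4, 2, 3, 1, 5],
        }
    ],
    "_comment": "Auto-generated replacement for model",
    "_generation_reason": "conversion_error",
}
```
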
@@ -1471,35 +1471,35 @@ def save_auto_replacement_json(
 ) -> str:
     """
     Save the auto-generated replacement JSON to a file.
-
+
     Args:
         replacement_json: The replacement JSON dictionary
         model_path: Path to the ONNX model
         output_dir: Directory to save the JSON file (default: same as model)
-
+
     Returns:
         Path to the saved JSON file
     """
     if not replacement_json.get("operations"):
         return ""
-
+
     # Generate filename
     model_name = os.path.splitext(os.path.basename(model_path))[0]
     json_filename = f"{model_name}_auto.json"
-
+
     # Determine output directory
     if output_dir is None:
         output_dir = os.path.dirname(model_path)
-
+
     # Create output directory if it doesn't exist
     if output_dir and not os.path.exists(output_dir):
         os.makedirs(output_dir)
-
+
     json_path = os.path.join(output_dir, json_filename)
-
+
     # Save JSON
     with open(json_path, 'w', encoding='utf-8') as f:
         json.dump(replacement_json, f, indent=2, ensure_ascii=False)
-
+
     info(f"Auto-generated replacement JSON saved to: {json_path}")
     return json_path
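
A hedged end-to-end sketch of the two entry points above; argument names are taken from the docstrings, and the exact signatures, defaults, and surrounding conversion call may differ:

```python
try:
    convert(...)  # a failing onnx2tf conversion, elided
except Exception as err:
    replacement_json = generate_auto_replacement_json(
        onnx_graph=onnx_graph,            # assumed already loaded
        tf_layers_dict=tf_layers_dict,    # assumed available from the converter
        conversion_error=err,
        model_path="model.onnx",
    )
    json_path = save_auto_replacement_json(replacement_json, "model.onnx")
    # Re-run the conversion passing json_path via -prf, as test_conversion_with_json does.
```
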