tinymlc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. TinyMLC/ANG/__init__.py +0 -0
  2. TinyMLC/ANG/args.py +86 -0
  3. TinyMLC/ANG/estimator.py +103 -0
  4. TinyMLC/ANG/estimator_hal.py +184 -0
  5. TinyMLC/ANG/estimator_qemu.py +257 -0
  6. TinyMLC/ANG/estimator_software.py +130 -0
  7. TinyMLC/ANG/model_builder.py +508 -0
  8. TinyMLC/ANG/model_generator.py +439 -0
  9. TinyMLC/ANG/model_info.py +283 -0
  10. TinyMLC/ANG/utils.py +420 -0
  11. TinyMLC/__init__.py +0 -0
  12. TinyMLC/cli.py +126 -0
  13. TinyMLC/codegen.py +877 -0
  14. TinyMLC/converter/__init__.py +0 -0
  15. TinyMLC/converter/export_weights.py +382 -0
  16. TinyMLC/converter/parser_litert.py +757 -0
  17. TinyMLC/converter/parser_onnx.py +649 -0
  18. TinyMLC/generate_lut.py +97 -0
  19. TinyMLC/handlers.py +325 -0
  20. TinyMLC/ops.py +76 -0
  21. TinyMLC/templates/lut.c.tpl +23 -0
  22. TinyMLC/templates/lut.h.tpl +67 -0
  23. TinyMLC/templates/model.c.tpl +314 -0
  24. TinyMLC/templates/model.h.tpl +66 -0
  25. TinyMLC/transform/__init__.py +0 -0
  26. TinyMLC/transform/algebraic.py +286 -0
  27. TinyMLC/transform/base.py +58 -0
  28. TinyMLC/transform/constant_folding.py +260 -0
  29. TinyMLC/transform/cse.py +192 -0
  30. TinyMLC/transform/dce.py +182 -0
  31. TinyMLC/transform/fusion.py +723 -0
  32. TinyMLC/transform/memory.py +200 -0
  33. TinyMLC/transform/pass_manager.py +101 -0
  34. TinyMLC/transform/simplify.py +515 -0
  35. tinymlc-0.1.0.dist-info/METADATA +49 -0
  36. tinymlc-0.1.0.dist-info/RECORD +47 -0
  37. tinymlc-0.1.0.dist-info/WHEEL +4 -0
  38. tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
  39. tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
  40. utils/__init__.py +0 -0
  41. utils/arm-none-eabi-gcc.cmake +53 -0
  42. utils/dump.py +86 -0
  43. utils/generate_onnx_models.py +183 -0
  44. utils/generate_tflite_models.py +236 -0
  45. utils/pack_macos.sh +88 -0
  46. utils/path.py +31 -0
  47. utils/riscv-none-elf-gcc.cmake +50 -0
@@ -0,0 +1,723 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # Operator Fusion.
21
+
22
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from TinyMLC.transform.base import Pass
26
+
27
+
28
class OperatorFusion(Pass):
    """
    Operator Fusion.

    Fuses pairs of adjacent operators (``ops[i]`` directly followed by
    ``ops[i + 1]``) into a single operator:

    - CONV_2D + RELU / RELU6 / HARD_SIGMOID -> CONV_2D
      (``conv_params["activation"]``)
    - DEPTHWISE_CONV_2D + RELU / RELU6 -> DEPTHWISE_CONV_2D
      (``dwconv_params["activation"]``)
    - FULLY_CONNECTED + RELU / RELU6 -> FULLY_CONNECTED
      (``fc_params["activation"]``)
    - FULLY_CONNECTED + SOFTMAX -> FULLY_CONNECTED
      (``fc_params["with_softmax"]``)
    - CONV_2D + ADD(conv output, conv input) -> CONV_2D
      (``conv_params["residual"]``)
    - CONV_2D + BATCH_NORM -> CONV_2D (BN folded into floating-point
      weights and bias)
    - CONV_2D 1x1 + CONV_2D 1x1 -> CONV_2D 1x1 (floating-point weights only)

    A pair is fused only when the first op's output is consumed by nothing
    but the second op, so no other consumer loses the intermediate value.
    Conv weights are assumed to be laid out [H, W, C_in, C_out] (output
    channels on the last axis). NOTE(review): confirm against the parsers.
    NOTE(review): tensors listed as model-level graph outputs are not
    checked here; confirm intermediate tensors are never graph outputs.
    """

    # Activation op_name -> value stored in the producer's "activation".
    _CONV_ACTIVATIONS = {
        "RELU": "relu",
        "RELU6": "relu6",
        "HARD_SIGMOID": "hard_sigmoid",
    }
    _BASIC_ACTIVATIONS = {
        "RELU": "relu",
        "RELU6": "relu6",
    }

    def __init__(self, name: str = "OperatorFusion"):
        super().__init__(name)
        # Number of successful fusions performed by the current run().
        self._fused_count = 0

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """Run operator fusion on a copy of ``model_info``.

        All rules are re-applied until a full sweep fuses nothing, so chains
        such as CONV_2D -> BATCH_NORM -> RELU collapse step by step. Every
        fusion removes one op, which guarantees termination.

        Args:
            model_info: Model description with ``"ops"`` (list of op dicts)
                and ``"tensors"`` / ``"weights"`` (dicts keyed by tensor
                index).

        Returns:
            The fused copy of ``model_info``.
        """
        model_info = self._copy_model(model_info)
        self._fused_count = 0

        # Iteratively fuse until no more changes
        changed = True
        while changed:
            changed = False
            changed |= self._fuse_conv_activation(model_info)
            changed |= self._fuse_fc_activation(model_info)
            changed |= self._fuse_fc_softmax(model_info)
            changed |= self._fuse_dwconv_activation(model_info)
            changed |= self._fuse_conv_add(model_info)
            changed |= self._fuse_conv_batchnorm(model_info)
            changed |= self._fuse_conv_conv(model_info)

        if self._fused_count > 0:
            self._log_change(f"Fused {self._fused_count} operator pairs")

        return model_info

    # ------------------------------------------------------------------
    # Activation / softmax epilogues
    # ------------------------------------------------------------------

    def _fuse_conv_activation(self, model_info: Dict[str, Any]) -> bool:
        """Fuse CONV_2D + RELU/RELU6/HARD_SIGMOID into a single CONV_2D."""
        return self._fuse_activation(model_info, "CONV_2D", "conv_params",
                                     self._CONV_ACTIVATIONS)

    def _fuse_fc_activation(self, model_info: Dict[str, Any]) -> bool:
        """Fuse FULLY_CONNECTED + RELU/RELU6 into a single FULLY_CONNECTED."""
        return self._fuse_activation(model_info, "FULLY_CONNECTED",
                                     "fc_params", self._BASIC_ACTIVATIONS)

    def _fuse_dwconv_activation(self, model_info: Dict[str, Any]) -> bool:
        """Fuse DEPTHWISE_CONV_2D + RELU/RELU6 into a single
        DEPTHWISE_CONV_2D."""
        return self._fuse_activation(model_info, "DEPTHWISE_CONV_2D",
                                     "dwconv_params", self._BASIC_ACTIVATIONS)

    def _fuse_activation(
        self,
        model_info: Dict[str, Any],
        producer_name: str,
        params_key: str,
        activations: Dict[str, str],
    ) -> bool:
        """Absorb activation ops into the producer op directly before them.

        The fused op takes over the activation's output tensor (keeping its
        metadata, e.g. quantization), and the producer's old output tensor
        is removed.

        Args:
            model_info: Model description; its op list is edited in place.
            producer_name: ``op_name`` of the op that absorbs the activation.
            params_key: Key of the producer's params dict; the activation is
                stored as ``op[params_key]["activation"]``.
            activations: Map from activation ``op_name`` to the value stored
                in ``"activation"``.

        Returns:
            True if at least one pair was fused.
        """
        ops = model_info.get("ops", [])
        fused = False
        i = 0

        while i < len(ops) - 1:
            op, next_op = ops[i], ops[i + 1]
            activation = activations.get(next_op.get("op_name"))
            if op.get("op_name") != producer_name or activation is None:
                i += 1
                continue

            params = op.get(params_key, {})
            link = self._single_consumer_link(model_info, op, next_op)
            if link is None or not self._can_take_activation(params):
                i += 1
                continue

            params["activation"] = activation
            op[params_key] = params
            self._absorb_next_op(model_info, ops, i, *link)

            fused = True
            self._fused_count += 1
            self._log_change(f" Fused {producer_name} + {activation}")
            # Don't increment i: re-examine ops[i] against its new successor.

        if fused:
            model_info["ops"] = ops

        return fused

    def _fuse_fc_softmax(self, model_info: Dict[str, Any]) -> bool:
        """Fuse FULLY_CONNECTED + SOFTMAX into FULLY_CONNECTED with the
        ``with_softmax`` flag set.

        NOTE(review): SOFTMAX parameters (e.g. a non-unit beta) are not
        inspected; confirm the parsers only emit plain softmax.
        """
        ops = model_info.get("ops", [])
        fused = False
        i = 0

        while i < len(ops) - 1:
            op, next_op = ops[i], ops[i + 1]
            if (op.get("op_name") != "FULLY_CONNECTED"
                    or next_op.get("op_name") != "SOFTMAX"):
                i += 1
                continue

            fc_params = op.get("fc_params", {})
            link = self._single_consumer_link(model_info, op, next_op)
            # softmax(softmax(x)) != softmax(x): never fuse a second one.
            if link is None or fc_params.get("with_softmax"):
                i += 1
                continue

            fc_params["with_softmax"] = True
            op["fc_params"] = fc_params
            self._absorb_next_op(model_info, ops, i, *link)

            fused = True
            self._fused_count += 1
            self._log_change(" Fused FULLY_CONNECTED + SOFTMAX")

        if fused:
            model_info["ops"] = ops

        return fused

    # ------------------------------------------------------------------
    # Structural fusions
    # ------------------------------------------------------------------

    def _count_uses(self, model_info: Dict[str, Any], tensor_idx: int) -> int:
        """Count how many ops list tensor_idx among their input_indices
        (an op using it several times counts once)."""
        count = 0
        for op in model_info.get("ops", []):
            if tensor_idx in op.get("input_indices", []):
                count += 1
        return count

    def _fuse_conv_add(self, model_info: Dict[str, Any]) -> bool:
        """
        Fuse CONV_2D + ADD (residual connection) into a single CONV_2D.

        Pattern:
            input ──┬──> CONV_2D ──> ADD
                    │                 ↑
                    └─────────────────┘

        Becomes:
            input ──> CONV_2D (with residual=True) ──> output
        """
        ops = model_info.get("ops", [])
        fused = False
        i = 0

        while i < len(ops) - 1:
            op, next_op = ops[i], ops[i + 1]
            if (op.get("op_name") != "CONV_2D"
                    or next_op.get("op_name") != "ADD"):
                i += 1
                continue

            conv_params = op.get("conv_params", {})
            conv_in = self._first_index(op, "input_indices")
            conv_out = self._first_index(op, "output_indices")
            add_out = self._first_index(next_op, "output_indices")
            add_inputs = list(next_op.get("input_indices", []))

            # The ADD must combine exactly {conv output, conv input}, and
            # the conv output must feed nothing else.
            shape_ok = (conv_in is not None and conv_out is not None
                        and add_out is not None and len(add_inputs) == 2
                        and set(add_inputs) == {conv_out, conv_in}
                        and self._count_uses(model_info, conv_out) == 1)
            # Also refuse when:
            # - the conv already has a residual (x would be added twice);
            # - the conv already has a fused activation: the "activation"
            #   and "residual" flags do not record their order, so
            #   act(conv(x)) + x and act(conv(x) + x) would be ambiguous;
            # - the ADD carries its own fused activation (it would be lost).
            add_has_act = any(isinstance(value, dict)
                              and self._has_activation(value)
                              for value in next_op.values())
            if (not shape_ok or conv_params.get("residual")
                    or self._has_activation(conv_params) or add_has_act):
                i += 1
                continue

            conv_params["residual"] = True
            op["conv_params"] = conv_params
            # The conv now writes the ADD's output; conv_out is dropped.
            self._absorb_next_op(model_info, ops, i, conv_out, add_out)

            fused = True
            self._fused_count += 1
            self._log_change(" Fused CONV_2D + ADD (residual)")

        if fused:
            model_info["ops"] = ops

        return fused

    def _fuse_conv_batchnorm(self, model_info: Dict[str, Any]) -> bool:
        """
        Fuse CONV_2D + BATCH_NORM into a single CONV_2D.

        The BN parameters are folded into the conv weights and bias. Only
        floating-point weights are folded; see ``_fold_batchnorm``.
        """
        ops = model_info.get("ops", [])
        fused = False
        i = 0

        while i < len(ops) - 1:
            op, next_op = ops[i], ops[i + 1]
            if (op.get("op_name") != "CONV_2D"
                    or next_op.get("op_name") != "BATCH_NORM"):
                i += 1
                continue

            conv_params = op.get("conv_params", {})
            link = self._single_consumer_link(model_info, op, next_op)
            # BN must act on the raw conv output; it cannot be folded
            # through an already-fused activation or residual add.
            if (link is None or self._has_activation(conv_params)
                    or conv_params.get("residual")):
                i += 1
                continue
            if not self._fold_batchnorm(model_info, op, next_op):
                i += 1
                continue

            self._absorb_next_op(model_info, ops, i, *link)

            fused = True
            self._fused_count += 1
            self._log_change(" Fused CONV_2D + BATCH_NORM")

        if fused:
            model_info["ops"] = ops

        return fused

    def _fold_batchnorm(
        self,
        model_info: Dict[str, Any],
        conv_op: Dict[str, Any],
        bn_op: Dict[str, Any],
    ) -> bool:
        """Fold ``bn_op``'s parameters into ``conv_op``'s weights and bias.

        With f = scale / sqrt(var + epsilon), per output channel:
            w' = w * f
            b' = (b - mean) * f + bias

        Integer (quantized) weights are not folded: that would require
        re-deriving quantization scales, which this pass does not track.
        Nothing is modified unless every check passes.

        Returns:
            True if the conv weights/bias were rewritten.
        """
        weights = model_info.get("weights", {})
        bn_params = bn_op.get("bn_params", {})
        bn_indices = [bn_params.get(key) for key in
                      ("scale_idx", "bias_idx", "mean_idx", "var_idx")]
        if any(idx is None for idx in bn_indices):
            return False
        scale, shift, mean, var = (weights.get(idx) for idx in bn_indices)
        epsilon = bn_params.get("epsilon", 1e-5)

        conv_inputs = conv_op.get("input_indices", [])
        if len(conv_inputs) < 2:
            return False
        weight_idx = conv_inputs[1]
        bias_idx = self._optional_input(conv_inputs, 2)
        conv_weight = weights.get(weight_idx)

        if not self._is_float_array(conv_weight) or conv_weight.ndim == 0:
            return False
        c_out = conv_weight.shape[-1]
        if not all(isinstance(arr, np.ndarray) and arr.size == c_out
                   for arr in (scale, shift, mean, var)):
            return False
        # Weights shared with another op must not be rewritten in place.
        if self._count_uses(model_info, weight_idx) != 1:
            return False

        if bias_idx is None:
            conv_bias = np.zeros(c_out, dtype=np.float64)
            bias_dtype = conv_weight.dtype
        else:
            existing = weights.get(bias_idx)
            if (not self._is_float_array(existing) or existing.size != c_out
                    or self._count_uses(model_info, bias_idx) != 1):
                return False
            conv_bias = existing.astype(np.float64).reshape(-1)
            bias_dtype = existing.dtype

        def as_vector(arr: np.ndarray) -> np.ndarray:
            return arr.astype(np.float64).reshape(-1)

        factor = as_vector(scale) / np.sqrt(as_vector(var) + epsilon)
        # Broadcasting multiplies along the last (output-channel) axis.
        new_weight = conv_weight.astype(np.float64) * factor
        new_bias = (conv_bias - as_vector(mean)) * factor + as_vector(shift)

        weights[weight_idx] = new_weight.astype(conv_weight.dtype)
        self._set_conv_bias(model_info, conv_op, bias_idx,
                            new_bias.astype(bias_dtype))
        return True

    def _fuse_conv_conv(self, model_info: Dict[str, Any]) -> bool:
        """
        Fuse two back-to-back 1x1 CONV_2D ops into one 1x1 CONV_2D.

        For pointwise convs (stride 1, no padding) the composition is exact:
            W = W1 @ W2,    b = b1 @ W2 + b2
        with W1 of shape [C_in, C_mid] and W2 of shape [C_mid, C_out]. The
        second op is kept (params and activation unchanged, new weights) and
        reads the first op's input; the first op is removed.

        3x3 + 3x3 -> 5x5 is intentionally not performed: two zero-padded
        SAME convs differ from one 5x5 conv at the borders, and switching to
        VALID padding changes the output shape.
        """
        ops = model_info.get("ops", [])
        fused = False
        i = 0

        while i < len(ops) - 1:
            op, next_op = ops[i], ops[i + 1]
            if (op.get("op_name") != "CONV_2D"
                    or next_op.get("op_name") != "CONV_2D"):
                i += 1
                continue

            c1 = op.get("conv_params", {})
            c2 = next_op.get("conv_params", {})
            link = self._single_consumer_link(model_info, op, next_op)
            # The first conv must be a plain linear map. The second may keep
            # its activation, but a residual would now add the wrong input.
            if (link is None
                    or not self._is_plain_pointwise(c1)
                    or not self._is_plain_pointwise(c2)
                    or self._has_activation(c1)
                    or c1.get("residual") or c2.get("residual")):
                i += 1
                continue
            if not self._compose_pointwise(model_info, op, next_op):
                i += 1
                continue

            conv_out = link[0]
            # Second conv now reads the first conv's input directly.
            next_op["input_indices"][0] = op["input_indices"][0]
            del ops[i]

            # Clean up the now-unused intermediate tensor.
            tensors = model_info.get("tensors", {})
            weights = model_info.get("weights", {})
            if conv_out in tensors:
                del tensors[conv_out]
            if conv_out in weights:
                del weights[conv_out]

            fused = True
            self._fused_count += 1
            self._log_change(" Fused CONV_2D + CONV_2D (1x1 + 1x1 -> 1x1)")
            # ops[i] is now the fused conv; re-examine it with its successor.

        if fused:
            model_info["ops"] = ops

        return fused

    def _compose_pointwise(
        self,
        model_info: Dict[str, Any],
        first: Dict[str, Any],
        second: Dict[str, Any],
    ) -> bool:
        """Store the composition of two 1x1 convs as ``second``'s weights.

        Nothing is modified unless every check passes. NOTE(review): any
        channel-count fields in ``second``'s conv_params are not updated;
        confirm codegen derives C_in from the weight shape.

        Returns:
            True if ``second``'s weights (and bias) were rewritten.
        """
        weights = model_info.get("weights", {})
        in1 = first.get("input_indices", [])
        in2 = second.get("input_indices", [])
        if len(in1) < 2 or len(in2) < 2:
            return False
        w2_idx = in2[1]
        w1, w2 = weights.get(in1[1]), weights.get(w2_idx)
        # Integer (quantized) weights would need requantization.
        if not (self._is_float_array(w1) and self._is_float_array(w2)):
            return False
        if (w1.ndim != 4 or w2.ndim != 4
                or w1.shape[:2] != (1, 1) or w2.shape[:2] != (1, 1)
                or w1.shape[3] != w2.shape[2]):
            return False
        # w2 is overwritten in place, so it must not be shared.
        if self._count_uses(model_info, w2_idx) != 1:
            return False

        b1_idx = self._optional_input(in1, 2)
        b2_idx = self._optional_input(in2, 2)
        b1 = weights.get(b1_idx) if b1_idx is not None else None
        b2 = weights.get(b2_idx) if b2_idx is not None else None
        if b1_idx is not None and not (self._is_float_array(b1)
                                       and b1.size == w1.shape[3]):
            return False
        if b2_idx is not None and not (
                self._is_float_array(b2) and b2.size == w2.shape[3]
                and self._count_uses(model_info, b2_idx) == 1):
            return False

        m1 = w1[0, 0].astype(np.float64)  # [C_in, C_mid]
        m2 = w2[0, 0].astype(np.float64)  # [C_mid, C_out]
        fused_w = (m1 @ m2).reshape(1, 1, m1.shape[0], m2.shape[1])
        weights[w2_idx] = fused_w.astype(w2.dtype)

        # Keep the weight tensor's recorded shape in sync, if there is one.
        meta = model_info.get("tensors", {}).get(w2_idx)
        if isinstance(meta, dict) and "shape" in meta:
            meta["shape"] = list(fused_w.shape)

        # The first conv's bias propagates through the second conv.
        if b1 is not None:
            fused_b = b1.astype(np.float64).reshape(-1) @ m2
            if b2 is not None:
                fused_b = fused_b + b2.astype(np.float64).reshape(-1)
            bias_dtype = b2.dtype if b2 is not None else w2.dtype
            self._set_conv_bias(model_info, second, b2_idx,
                                fused_b.astype(bias_dtype))
        return True

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _first_index(op: Dict[str, Any], key: str) -> Any:
        """Return ``op[key][0]``, or None when the list is missing/empty."""
        indices = op.get(key)
        if indices is None or len(indices) == 0:
            return None
        return indices[0]

    @staticmethod
    def _optional_input(indices: List[Any], pos: int) -> Any:
        """Return ``indices[pos]``, or None if absent, None or negative
        (negative indices are treated as "no tensor" markers)."""
        if pos >= len(indices):
            return None
        idx = indices[pos]
        if idx is None or (isinstance(idx, (int, np.integer)) and idx < 0):
            return None
        return idx

    @staticmethod
    def _has_activation(params: Dict[str, Any]) -> bool:
        """True if ``params`` records a fused activation ("none" counts as
        no activation)."""
        act = params.get("activation")
        return bool(act) and str(act).lower() != "none"

    def _can_take_activation(self, params: Dict[str, Any]) -> bool:
        """True if an activation can still be appended to this op.

        A second activation would overwrite the first, an activation after
        a fused softmax would be applied in the wrong order, and the
        activation/residual order is not recorded (see _fuse_conv_add).
        """
        return (not self._has_activation(params)
                and not params.get("with_softmax")
                and not params.get("residual"))

    @staticmethod
    def _is_float_array(value: Any) -> bool:
        """True if ``value`` is a NumPy array with a floating dtype."""
        return (isinstance(value, np.ndarray)
                and np.issubdtype(value.dtype, np.floating))

    @staticmethod
    def _is_unit(value: Any) -> bool:
        """True for 1 or a non-empty list/tuple of 1s (kernel/stride)."""
        if isinstance(value, (list, tuple)):
            return len(value) > 0 and all(v == 1 for v in value)
        return value == 1

    @staticmethod
    def _is_zero_padding(padding: Any) -> bool:
        """True if ``padding`` adds no border for a 1x1, stride-1 conv."""
        if isinstance(padding, str):
            return padding.upper() in ("SAME", "VALID")
        if isinstance(padding, (int, np.integer)):
            return padding == 0
        if isinstance(padding, (list, tuple)):
            return all(p == 0 for p in padding)
        return False

    def _is_plain_pointwise(self, params: Dict[str, Any]) -> bool:
        """True for a dense 1x1 conv with stride 1 and no padding."""
        return (self._is_unit(params.get("kernel_size"))
                and self._is_unit(params.get("stride", 1))
                and self._is_zero_padding(params.get("padding", "SAME"))
                and params.get("groups", 1) in (None, 1))

    def _single_consumer_link(
        self,
        model_info: Dict[str, Any],
        op: Dict[str, Any],
        next_op: Dict[str, Any],
    ) -> Optional[Tuple[Any, Any]]:
        """Return (op's output, next_op's output) when op's first output is
        next_op's first input and no other op consumes it; else None."""
        producer_out = self._first_index(op, "output_indices")
        consumer_in = self._first_index(next_op, "input_indices")
        consumer_out = self._first_index(next_op, "output_indices")
        if producer_out is None or consumer_out is None:
            return None
        if producer_out != consumer_in:
            return None
        if self._count_uses(model_info, producer_out) != 1:
            return None
        return producer_out, consumer_out

    def _absorb_next_op(
        self,
        model_info: Dict[str, Any],
        ops: List[Dict[str, Any]],
        i: int,
        producer_out: Any,
        consumer_out: Any,
    ) -> None:
        """Make ops[i] write ``consumer_out`` and delete ops[i + 1].

        The consumer's tensor metadata is kept (copied from the producer's
        tensor if missing); the producer's old output tensor is removed.
        """
        ops[i]["output_indices"] = [consumer_out]
        del ops[i + 1]

        tensors = model_info.setdefault("tensors", {})
        if consumer_out not in tensors and producer_out in tensors:
            meta = tensors[producer_out].copy()
            meta["name"] = f"tensor_{consumer_out}"
            tensors[consumer_out] = meta

        # An in-place consumer (output == input) has nothing to remove.
        if producer_out != consumer_out:
            self._update_tensor_refs_after_removal(model_info, producer_out,
                                                   consumer_out)

    @staticmethod
    def _next_free_index(model_info: Dict[str, Any]) -> int:
        """Return a tensor index not used by any tensor, weight or op."""
        used = list(model_info.get("tensors", {}))
        used.extend(model_info.get("weights", {}))
        for op in model_info.get("ops", []):
            used.extend(op.get("input_indices", []))
            used.extend(op.get("output_indices", []))
        numeric = [int(idx) for idx in used
                   if isinstance(idx, (int, np.integer))]
        return max(numeric, default=0) + 1

    def _set_conv_bias(
        self,
        model_info: Dict[str, Any],
        conv_op: Dict[str, Any],
        bias_idx: Any,
        values: np.ndarray,
    ) -> None:
        """Store ``values`` as ``conv_op``'s bias.

        Overwrites the existing bias when ``bias_idx`` is given; otherwise
        allocates a fresh index and wires it in as input 2. NOTE(review):
        a freshly allocated bias gets no ``tensors`` entry; confirm codegen
        does not require one.
        """
        weights = model_info.setdefault("weights", {})
        if bias_idx is not None:
            weights[bias_idx] = values
            return
        new_idx = self._next_free_index(model_info)
        weights[new_idx] = values
        inputs = conv_op["input_indices"]
        if len(inputs) > 2:
            # Slot 2 holds an "absent" marker (None or a negative index).
            inputs[2] = new_idx
        else:
            inputs.append(new_idx)

    def _update_tensor_refs_after_removal(
        self,
        model_info: Dict[str, Any],
        removed_output: int,
        replaced_by: int
    ) -> None:
        """Replace every op reference to ``removed_output`` with
        ``replaced_by`` and drop ``removed_output`` from tensors/weights."""
        for op in model_info.get("ops", []):
            for idx, input_idx in enumerate(op.get("input_indices", [])):
                if input_idx == removed_output:
                    op["input_indices"][idx] = replaced_by

            for idx, output_idx in enumerate(op.get("output_indices", [])):
                if output_idx == removed_output:
                    op["output_indices"][idx] = replaced_by

        # Remove the dead tensor from tensors / weights dicts
        if removed_output in model_info.get("tensors", {}):
            del model_info["tensors"][removed_output]
        if removed_output in model_info.get("weights", {}):
            del model_info["weights"][removed_output]