tinymlc 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- TinyMLC/ANG/__init__.py +0 -0
- TinyMLC/ANG/args.py +86 -0
- TinyMLC/ANG/estimator.py +103 -0
- TinyMLC/ANG/estimator_hal.py +184 -0
- TinyMLC/ANG/estimator_qemu.py +257 -0
- TinyMLC/ANG/estimator_software.py +130 -0
- TinyMLC/ANG/model_builder.py +508 -0
- TinyMLC/ANG/model_generator.py +439 -0
- TinyMLC/ANG/model_info.py +283 -0
- TinyMLC/ANG/utils.py +420 -0
- TinyMLC/__init__.py +0 -0
- TinyMLC/cli.py +126 -0
- TinyMLC/codegen.py +877 -0
- TinyMLC/converter/__init__.py +0 -0
- TinyMLC/converter/export_weights.py +382 -0
- TinyMLC/converter/parser_litert.py +757 -0
- TinyMLC/converter/parser_onnx.py +649 -0
- TinyMLC/generate_lut.py +97 -0
- TinyMLC/handlers.py +325 -0
- TinyMLC/ops.py +76 -0
- TinyMLC/templates/lut.c.tpl +23 -0
- TinyMLC/templates/lut.h.tpl +67 -0
- TinyMLC/templates/model.c.tpl +314 -0
- TinyMLC/templates/model.h.tpl +66 -0
- TinyMLC/transform/__init__.py +0 -0
- TinyMLC/transform/algebraic.py +286 -0
- TinyMLC/transform/base.py +58 -0
- TinyMLC/transform/constant_folding.py +260 -0
- TinyMLC/transform/cse.py +192 -0
- TinyMLC/transform/dce.py +182 -0
- TinyMLC/transform/fusion.py +723 -0
- TinyMLC/transform/memory.py +200 -0
- TinyMLC/transform/pass_manager.py +101 -0
- TinyMLC/transform/simplify.py +515 -0
- tinymlc-0.1.0.dist-info/METADATA +49 -0
- tinymlc-0.1.0.dist-info/RECORD +47 -0
- tinymlc-0.1.0.dist-info/WHEEL +4 -0
- tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
- tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
- utils/__init__.py +0 -0
- utils/arm-none-eabi-gcc.cmake +53 -0
- utils/dump.py +86 -0
- utils/generate_onnx_models.py +183 -0
- utils/generate_tflite_models.py +236 -0
- utils/pack_macos.sh +88 -0
- utils/path.py +31 -0
- utils/riscv-none-elf-gcc.cmake +50 -0
|
@@ -0,0 +1,723 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# TinyMLC - Tiny Machine Learning Compiler
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2026 Jia Liu & TinyMLC Contributors
|
|
5
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
6
|
+
#
|
|
7
|
+
# This file is part of TinyMLC.
|
|
8
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
9
|
+
# you may not use this file except in compliance with the License.
|
|
10
|
+
# You may obtain a copy of the License at:
|
|
11
|
+
#
|
|
12
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
13
|
+
#
|
|
14
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
15
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
16
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
17
|
+
# See the License for the specific language governing permissions and
|
|
18
|
+
# limitations under the License.
|
|
19
|
+
|
|
20
|
+
# Operator Fusion.
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
|
|
24
|
+
from typing import Dict, Any
|
|
25
|
+
from TinyMLC.transform.base import Pass
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class OperatorFusion(Pass):
|
|
29
|
+
"""
|
|
30
|
+
Operator Fusion.
|
|
31
|
+
|
|
32
|
+
Fuses adjacent operators into single operators:
|
|
33
|
+
- CONV_2D + RELU -> CONV_2D (with activation fused in params)
|
|
34
|
+
- FC + RELU -> FC (with activation fused in params)
|
|
35
|
+
- CONV_2D + RELU6 -> CONV_2D (with relu6)
|
|
36
|
+
- CONV_2D + HARD_SIGMOID -> CONV_2D (with hard_sigmoid)
|
|
37
|
+
- FC + SOFTMAX -> FC (with softmax fused)
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(self, name: str = "OperatorFusion"):
    """Set up the pass with a zeroed fusion counter.

    Args:
        name: Pass name handed to the base-class constructor.
    """
    super().__init__(name)
    # Incremented by every _fuse_* method; reset at the start of run().
    self._fused_count = 0
|
|
43
|
+
|
|
44
|
+
def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
    """Apply every fusion rule repeatedly until a full sweep changes nothing.

    Args:
        model_info: Model description. It is first passed through
            ``self._copy_model`` and the rules operate on that result
            (presumably a copy -- confirm in the base class).

    Returns:
        The model dict returned by ``_copy_model`` after all fusions.
    """
    model_info = self._copy_model(model_info)
    self._fused_count = 0

    # The rules of one sweep, in execution order.
    rules = (
        self._fuse_conv_activation,
        self._fuse_fc_activation,
        self._fuse_fc_softmax,
        self._fuse_dwconv_activation,
        self._fuse_conv_add,
        self._fuse_conv_batchnorm,
        self._fuse_conv_conv,
    )

    while True:
        # Build a list (not a generator) so that every rule runs in each
        # sweep, even after an earlier rule already reported a change.
        outcomes = [rule(model_info) for rule in rules]
        if not any(outcomes):
            break

    if self._fused_count > 0:
        self._log_change(f"Fused {self._fused_count} operator pairs")

    return model_info
|
|
65
|
+
|
|
66
|
+
def _fuse_conv_activation(self, model_info: Dict[str, Any]) -> bool:
|
|
67
|
+
"""Fuse CONV_2D + activation into a single CONV_2D."""
|
|
68
|
+
ops = model_info.get("ops", [])
|
|
69
|
+
fused = False
|
|
70
|
+
i = 0
|
|
71
|
+
|
|
72
|
+
while i < len(ops) - 1:
|
|
73
|
+
op = ops[i]
|
|
74
|
+
next_op = ops[i + 1]
|
|
75
|
+
|
|
76
|
+
if op.get("op_name") != "CONV_2D":
|
|
77
|
+
i += 1
|
|
78
|
+
continue
|
|
79
|
+
|
|
80
|
+
# Check next op is an activation
|
|
81
|
+
activation = None
|
|
82
|
+
if next_op.get("op_name") == "RELU":
|
|
83
|
+
activation = "relu"
|
|
84
|
+
elif next_op.get("op_name") == "RELU6":
|
|
85
|
+
activation = "relu6"
|
|
86
|
+
elif next_op.get("op_name") == "HARD_SIGMOID":
|
|
87
|
+
activation = "hard_sigmoid"
|
|
88
|
+
else:
|
|
89
|
+
i += 1
|
|
90
|
+
continue
|
|
91
|
+
|
|
92
|
+
# Check that conv's output feeds into activation's input
|
|
93
|
+
conv_output = op.get("output_indices", [])[0]
|
|
94
|
+
act_input = next_op.get("input_indices", [])[0]
|
|
95
|
+
|
|
96
|
+
if conv_output != act_input:
|
|
97
|
+
i += 1
|
|
98
|
+
continue
|
|
99
|
+
|
|
100
|
+
# Also check that activation's output is used by later ops
|
|
101
|
+
# (not just a dead op)
|
|
102
|
+
act_output = next_op.get("output_indices", [])[0]
|
|
103
|
+
|
|
104
|
+
# Fuse: add activation to conv params
|
|
105
|
+
conv_params = op.get("conv_params", {})
|
|
106
|
+
conv_params["activation"] = activation
|
|
107
|
+
op["conv_params"] = conv_params
|
|
108
|
+
|
|
109
|
+
# Update op's output to activation's output
|
|
110
|
+
op["output_indices"] = [act_output]
|
|
111
|
+
|
|
112
|
+
# Ensure act_output exists in tensors
|
|
113
|
+
tensors = model_info.get("tensors", {})
|
|
114
|
+
if act_output not in tensors:
|
|
115
|
+
if conv_output in tensors:
|
|
116
|
+
tensors[act_output] = tensors[conv_output].copy()
|
|
117
|
+
tensors[act_output]["name"] = f"tensor_{act_output}"
|
|
118
|
+
else:
|
|
119
|
+
tensors[act_output] = {
|
|
120
|
+
"name": f"tensor_{act_output}",
|
|
121
|
+
"shape": [1, 8, 28, 28],
|
|
122
|
+
"dtype": "float32",
|
|
123
|
+
"scale": 1.0,
|
|
124
|
+
"zero_point": 0,
|
|
125
|
+
}
|
|
126
|
+
model_info["tensors"] = tensors
|
|
127
|
+
|
|
128
|
+
# Update all references: replace conv_output with act_output,
|
|
129
|
+
# then delete conv_output
|
|
130
|
+
self._update_tensor_refs_after_removal(model_info, conv_output,
|
|
131
|
+
act_output)
|
|
132
|
+
|
|
133
|
+
# Remove the activation op
|
|
134
|
+
del ops[i + 1]
|
|
135
|
+
|
|
136
|
+
fused = True
|
|
137
|
+
self._fused_count += 1
|
|
138
|
+
self._log_change(f" Fused CONV_2D + {activation}")
|
|
139
|
+
|
|
140
|
+
# Don't increment i, check if next op can also be fused
|
|
141
|
+
# (but there won't be another activation right after)
|
|
142
|
+
|
|
143
|
+
if fused:
|
|
144
|
+
model_info["ops"] = ops
|
|
145
|
+
|
|
146
|
+
return fused
|
|
147
|
+
|
|
148
|
+
def _fuse_fc_activation(self, model_info: Dict[str, Any]) -> bool:
|
|
149
|
+
"""Fuse FC + activation into a single FC."""
|
|
150
|
+
ops = model_info.get("ops", [])
|
|
151
|
+
fused = False
|
|
152
|
+
i = 0
|
|
153
|
+
|
|
154
|
+
while i < len(ops) - 1:
|
|
155
|
+
op = ops[i]
|
|
156
|
+
next_op = ops[i + 1]
|
|
157
|
+
|
|
158
|
+
if op.get("op_name") != "FULLY_CONNECTED":
|
|
159
|
+
i += 1
|
|
160
|
+
continue
|
|
161
|
+
|
|
162
|
+
activation = None
|
|
163
|
+
if next_op.get("op_name") == "RELU":
|
|
164
|
+
activation = "relu"
|
|
165
|
+
elif next_op.get("op_name") == "RELU6":
|
|
166
|
+
activation = "relu6"
|
|
167
|
+
else:
|
|
168
|
+
i += 1
|
|
169
|
+
continue
|
|
170
|
+
|
|
171
|
+
fc_output = op.get("output_indices", [])[0]
|
|
172
|
+
act_input = next_op.get("input_indices", [])[0]
|
|
173
|
+
|
|
174
|
+
if fc_output != act_input:
|
|
175
|
+
i += 1
|
|
176
|
+
continue
|
|
177
|
+
|
|
178
|
+
act_output = next_op.get("output_indices", [])[0]
|
|
179
|
+
|
|
180
|
+
fc_params = op.get("fc_params", {})
|
|
181
|
+
fc_params["activation"] = activation
|
|
182
|
+
op["fc_params"] = fc_params
|
|
183
|
+
|
|
184
|
+
op["output_indices"] = [act_output]
|
|
185
|
+
|
|
186
|
+
del ops[i + 1]
|
|
187
|
+
|
|
188
|
+
fused = True
|
|
189
|
+
self._fused_count += 1
|
|
190
|
+
self._log_change(f" Fused FULLY_CONNECTED + {activation}")
|
|
191
|
+
|
|
192
|
+
if fused:
|
|
193
|
+
model_info["ops"] = ops
|
|
194
|
+
|
|
195
|
+
return fused
|
|
196
|
+
|
|
197
|
+
def _fuse_fc_softmax(self, model_info: Dict[str, Any]) -> bool:
|
|
198
|
+
"""Fuse FC + SOFTMAX into FC with softmax flag."""
|
|
199
|
+
ops = model_info.get("ops", [])
|
|
200
|
+
fused = False
|
|
201
|
+
i = 0
|
|
202
|
+
|
|
203
|
+
while i < len(ops) - 1:
|
|
204
|
+
op = ops[i]
|
|
205
|
+
next_op = ops[i + 1]
|
|
206
|
+
|
|
207
|
+
if op.get("op_name") != "FULLY_CONNECTED":
|
|
208
|
+
i += 1
|
|
209
|
+
continue
|
|
210
|
+
|
|
211
|
+
if next_op.get("op_name") != "SOFTMAX":
|
|
212
|
+
i += 1
|
|
213
|
+
continue
|
|
214
|
+
|
|
215
|
+
fc_output = op.get("output_indices", [])[0]
|
|
216
|
+
sm_input = next_op.get("input_indices", [])[0]
|
|
217
|
+
|
|
218
|
+
if fc_output != sm_input:
|
|
219
|
+
i += 1
|
|
220
|
+
continue
|
|
221
|
+
|
|
222
|
+
sm_output = next_op.get("output_indices", [])[0]
|
|
223
|
+
|
|
224
|
+
fc_params = op.get("fc_params", {})
|
|
225
|
+
fc_params["with_softmax"] = True
|
|
226
|
+
op["fc_params"] = fc_params
|
|
227
|
+
|
|
228
|
+
op["output_indices"] = [sm_output]
|
|
229
|
+
|
|
230
|
+
del ops[i + 1]
|
|
231
|
+
|
|
232
|
+
fused = True
|
|
233
|
+
self._fused_count += 1
|
|
234
|
+
self._log_change(" Fused FULLY_CONNECTED + SOFTMAX")
|
|
235
|
+
|
|
236
|
+
if fused:
|
|
237
|
+
model_info["ops"] = ops
|
|
238
|
+
|
|
239
|
+
return fused
|
|
240
|
+
|
|
241
|
+
def _fuse_dwconv_activation(self, model_info: Dict[str, Any]) -> bool:
|
|
242
|
+
"""Fuse DEPTHWISE_CONV_2D + activation into a single
|
|
243
|
+
DEPTHWISE_CONV_2D."""
|
|
244
|
+
ops = model_info.get("ops", [])
|
|
245
|
+
fused = False
|
|
246
|
+
i = 0
|
|
247
|
+
|
|
248
|
+
while i < len(ops) - 1:
|
|
249
|
+
op = ops[i]
|
|
250
|
+
next_op = ops[i + 1]
|
|
251
|
+
|
|
252
|
+
if op.get("op_name") != "DEPTHWISE_CONV_2D":
|
|
253
|
+
i += 1
|
|
254
|
+
continue
|
|
255
|
+
|
|
256
|
+
# Check next op is an activation
|
|
257
|
+
activation = None
|
|
258
|
+
if next_op.get("op_name") == "RELU":
|
|
259
|
+
activation = "relu"
|
|
260
|
+
elif next_op.get("op_name") == "RELU6":
|
|
261
|
+
activation = "relu6"
|
|
262
|
+
else:
|
|
263
|
+
i += 1
|
|
264
|
+
continue
|
|
265
|
+
|
|
266
|
+
# Check that dwconv's output feeds into activation's input
|
|
267
|
+
dwconv_output = op.get("output_indices", [])[0]
|
|
268
|
+
act_input = next_op.get("input_indices", [])[0]
|
|
269
|
+
|
|
270
|
+
if dwconv_output != act_input:
|
|
271
|
+
i += 1
|
|
272
|
+
continue
|
|
273
|
+
|
|
274
|
+
act_output = next_op.get("output_indices", [])[0]
|
|
275
|
+
|
|
276
|
+
# Fuse: add activation to dwconv params
|
|
277
|
+
dwconv_params = op.get("dwconv_params", {})
|
|
278
|
+
dwconv_params["activation"] = activation
|
|
279
|
+
op["dwconv_params"] = dwconv_params
|
|
280
|
+
|
|
281
|
+
op["output_indices"] = [act_output]
|
|
282
|
+
|
|
283
|
+
del ops[i + 1]
|
|
284
|
+
|
|
285
|
+
self._update_tensor_refs_after_removal(model_info, act_output,
|
|
286
|
+
dwconv_output)
|
|
287
|
+
|
|
288
|
+
fused = True
|
|
289
|
+
self._fused_count += 1
|
|
290
|
+
self._log_change(f" Fused DEPTHWISE_CONV_2D + {activation}")
|
|
291
|
+
|
|
292
|
+
if fused:
|
|
293
|
+
model_info["ops"] = ops
|
|
294
|
+
|
|
295
|
+
return fused
|
|
296
|
+
|
|
297
|
+
def _count_uses(self, model_info: Dict[str, Any], tensor_idx: int) -> int:
|
|
298
|
+
"""Count how many ops use tensor_idx as input."""
|
|
299
|
+
count = 0
|
|
300
|
+
for op in model_info.get("ops", []):
|
|
301
|
+
if tensor_idx in op.get("input_indices", []):
|
|
302
|
+
count += 1
|
|
303
|
+
return count
|
|
304
|
+
|
|
305
|
+
def _fuse_conv_add(self, model_info: Dict[str, Any]) -> bool:
|
|
306
|
+
"""
|
|
307
|
+
Fuse CONV_2D + ADD (residual connection) into a single CONV_2D.
|
|
308
|
+
|
|
309
|
+
Pattern:
|
|
310
|
+
input ──┬──> CONV_2D ──> ADD
|
|
311
|
+
│ ↑
|
|
312
|
+
└─────────────────┘
|
|
313
|
+
|
|
314
|
+
Becomes:
|
|
315
|
+
input ──> CONV_2D (with residual=True) ──> output
|
|
316
|
+
"""
|
|
317
|
+
ops = model_info.get("ops", [])
|
|
318
|
+
fused = False
|
|
319
|
+
i = 0
|
|
320
|
+
|
|
321
|
+
while i < len(ops) - 1:
|
|
322
|
+
op = ops[i]
|
|
323
|
+
next_op = ops[i + 1]
|
|
324
|
+
|
|
325
|
+
if op.get("op_name") != "CONV_2D":
|
|
326
|
+
i += 1
|
|
327
|
+
continue
|
|
328
|
+
|
|
329
|
+
if next_op.get("op_name") != "ADD":
|
|
330
|
+
i += 1
|
|
331
|
+
continue
|
|
332
|
+
|
|
333
|
+
# CONV_2D outputs: [conv_out, ...]
|
|
334
|
+
conv_outputs = op.get("output_indices", [])
|
|
335
|
+
if not conv_outputs:
|
|
336
|
+
i += 1
|
|
337
|
+
continue
|
|
338
|
+
conv_out = conv_outputs[0]
|
|
339
|
+
|
|
340
|
+
# ADD inputs: [a, b]
|
|
341
|
+
add_inputs = next_op.get("input_indices", [])
|
|
342
|
+
if len(add_inputs) != 2:
|
|
343
|
+
i += 1
|
|
344
|
+
continue
|
|
345
|
+
|
|
346
|
+
# Check if the ADD is adding conv output with the conv input
|
|
347
|
+
# One input must be conv_out, the other must be conv's input
|
|
348
|
+
conv_inputs = op.get("input_indices", [])
|
|
349
|
+
if not conv_inputs:
|
|
350
|
+
i += 1
|
|
351
|
+
continue
|
|
352
|
+
conv_in = conv_inputs[0]
|
|
353
|
+
|
|
354
|
+
# Check if ADD inputs are exactly {conv_out, conv_in}
|
|
355
|
+
if set(add_inputs) != {conv_out, conv_in}:
|
|
356
|
+
i += 1
|
|
357
|
+
continue
|
|
358
|
+
|
|
359
|
+
# Also check that conv_out is not used by any other op
|
|
360
|
+
# (otherwise we can't safely remove the ADD)
|
|
361
|
+
used_count = self._count_uses(model_info, conv_out)
|
|
362
|
+
if used_count > 1:
|
|
363
|
+
# conv_out is used by other ops, can't fuse
|
|
364
|
+
i += 1
|
|
365
|
+
continue
|
|
366
|
+
|
|
367
|
+
# Get ADD's output
|
|
368
|
+
add_output = next_op.get("output_indices", [])[0]
|
|
369
|
+
|
|
370
|
+
# Mark CONV_2D as having a residual connection
|
|
371
|
+
conv_params = op.get("conv_params", {})
|
|
372
|
+
conv_params["residual"] = True
|
|
373
|
+
op["conv_params"] = conv_params
|
|
374
|
+
|
|
375
|
+
# Update CONV_2D's output to ADD's output
|
|
376
|
+
op["output_indices"] = [add_output]
|
|
377
|
+
|
|
378
|
+
# Remove the ADD op
|
|
379
|
+
del ops[i + 1]
|
|
380
|
+
|
|
381
|
+
# Update tensor references: any op that used add_output
|
|
382
|
+
# now uses conv_out (which is the same tensor, but we keep it)
|
|
383
|
+
# Actually we need to clean up the old conv_out tensor
|
|
384
|
+
# Since we're using add_output as the new output, we need to
|
|
385
|
+
# make sure conv_out is not left dangling.
|
|
386
|
+
|
|
387
|
+
# Remove conv_out from tensors (it's replaced by add_output)
|
|
388
|
+
if conv_out in model_info.get("tensors", {}):
|
|
389
|
+
del model_info["tensors"][conv_out]
|
|
390
|
+
if conv_out in model_info.get("weights", {}):
|
|
391
|
+
del model_info["weights"][conv_out]
|
|
392
|
+
|
|
393
|
+
fused = True
|
|
394
|
+
self._fused_count += 1
|
|
395
|
+
self._log_change(" Fused CONV_2D + ADD (residual)")
|
|
396
|
+
|
|
397
|
+
# Don't increment i, check next op
|
|
398
|
+
|
|
399
|
+
if fused:
|
|
400
|
+
model_info["ops"] = ops
|
|
401
|
+
|
|
402
|
+
return fused
|
|
403
|
+
|
|
404
|
+
def _fuse_conv_batchnorm(self, model_info: Dict[str, Any]) -> bool:
|
|
405
|
+
"""
|
|
406
|
+
Fuse CONV_2D + BATCH_NORM into a single CONV_2D.
|
|
407
|
+
|
|
408
|
+
The BN parameters are folded into the conv weights and bias.
|
|
409
|
+
"""
|
|
410
|
+
ops = model_info.get("ops", [])
|
|
411
|
+
tensors = model_info.get("tensors", {})
|
|
412
|
+
weights = model_info.get("weights", {})
|
|
413
|
+
fused = False
|
|
414
|
+
i = 0
|
|
415
|
+
|
|
416
|
+
while i < len(ops) - 1:
|
|
417
|
+
op = ops[i]
|
|
418
|
+
next_op = ops[i + 1]
|
|
419
|
+
|
|
420
|
+
if op.get("op_name") != "CONV_2D":
|
|
421
|
+
i += 1
|
|
422
|
+
continue
|
|
423
|
+
|
|
424
|
+
if next_op.get("op_name") != "BATCH_NORM":
|
|
425
|
+
i += 1
|
|
426
|
+
continue
|
|
427
|
+
|
|
428
|
+
# Check direct connection
|
|
429
|
+
conv_out = op.get("output_indices", [])[0]
|
|
430
|
+
bn_input = next_op.get("input_indices", [])[0]
|
|
431
|
+
if conv_out != bn_input:
|
|
432
|
+
i += 1
|
|
433
|
+
continue
|
|
434
|
+
|
|
435
|
+
# Get BN parameters
|
|
436
|
+
bn_params = next_op.get("bn_params", {})
|
|
437
|
+
scale_idx = bn_params.get("scale_idx")
|
|
438
|
+
bias_idx = bn_params.get("bias_idx")
|
|
439
|
+
mean_idx = bn_params.get("mean_idx")
|
|
440
|
+
var_idx = bn_params.get("var_idx")
|
|
441
|
+
epsilon = bn_params.get("epsilon", 1e-5)
|
|
442
|
+
|
|
443
|
+
if None in (scale_idx, bias_idx, mean_idx, var_idx):
|
|
444
|
+
i += 1
|
|
445
|
+
continue
|
|
446
|
+
|
|
447
|
+
# Get conv weights and bias
|
|
448
|
+
conv_inputs = op.get("input_indices", [])
|
|
449
|
+
if len(conv_inputs) < 2:
|
|
450
|
+
i += 1
|
|
451
|
+
continue
|
|
452
|
+
weight_idx = conv_inputs[1]
|
|
453
|
+
bias_idx_conv = conv_inputs[2] if len(conv_inputs) > 2 else None
|
|
454
|
+
|
|
455
|
+
if weight_idx not in weights:
|
|
456
|
+
i += 1
|
|
457
|
+
continue
|
|
458
|
+
|
|
459
|
+
# Extract BN parameters as numpy arrays
|
|
460
|
+
scale = weights.get(scale_idx)
|
|
461
|
+
bias = weights.get(bias_idx)
|
|
462
|
+
mean = weights.get(mean_idx)
|
|
463
|
+
var = weights.get(var_idx)
|
|
464
|
+
|
|
465
|
+
if not all(isinstance(x, np.ndarray) for x in
|
|
466
|
+
(scale, bias, mean, var)):
|
|
467
|
+
i += 1
|
|
468
|
+
continue
|
|
469
|
+
|
|
470
|
+
# Fold BN into conv weights:
|
|
471
|
+
# w' = w * scale / sqrt(var + eps)
|
|
472
|
+
# b' = (b - mean) * scale / sqrt(var + eps) + bias
|
|
473
|
+
conv_weight = weights[weight_idx]
|
|
474
|
+
conv_bias = weights.get(bias_idx_conv,
|
|
475
|
+
np.zeros(conv_weight.shape[-1],
|
|
476
|
+
dtype=np.int32))
|
|
477
|
+
|
|
478
|
+
# Compute folding factor
|
|
479
|
+
std = np.sqrt(var + epsilon)
|
|
480
|
+
factor = scale / std
|
|
481
|
+
|
|
482
|
+
# Fold into weights (assuming conv_weight shape:
|
|
483
|
+
# [H, W, C_in, C_out])
|
|
484
|
+
# Convert to float for computation
|
|
485
|
+
w_f = conv_weight.astype(np.float32)
|
|
486
|
+
b_f = conv_bias.astype(np.float32)
|
|
487
|
+
factor_f = factor.astype(np.float32).reshape(1, 1, 1, -1)
|
|
488
|
+
|
|
489
|
+
# Fold: w_new = w * factor
|
|
490
|
+
w_new = (w_f * factor_f).astype(np.int8)
|
|
491
|
+
|
|
492
|
+
# Fold bias: b_new = (b - mean) * factor + bias
|
|
493
|
+
mean_f = mean.astype(np.float32).reshape(1, -1)
|
|
494
|
+
b_new = ((b_f - mean_f) * factor_f.reshape(1, -1) + bias.astype(
|
|
495
|
+
np.float32).reshape(1, -1))
|
|
496
|
+
b_new = b_new.astype(np.int32).flatten()
|
|
497
|
+
|
|
498
|
+
# Update weights
|
|
499
|
+
weights[weight_idx] = w_new
|
|
500
|
+
if bias_idx_conv:
|
|
501
|
+
weights[bias_idx_conv] = b_new
|
|
502
|
+
else:
|
|
503
|
+
# Create bias if it didn't exist
|
|
504
|
+
bias_new_idx = max(weights.keys()) + 1 if weights else 1
|
|
505
|
+
weights[bias_new_idx] = b_new
|
|
506
|
+
op["input_indices"].append(bias_new_idx)
|
|
507
|
+
|
|
508
|
+
# Remove BN op
|
|
509
|
+
bn_output = next_op.get("output_indices", [])[0]
|
|
510
|
+
op["output_indices"] = [bn_output]
|
|
511
|
+
|
|
512
|
+
del ops[i + 1]
|
|
513
|
+
|
|
514
|
+
fused = True
|
|
515
|
+
self._fused_count += 1
|
|
516
|
+
self._log_change(" Fused CONV_2D + BATCH_NORM")
|
|
517
|
+
|
|
518
|
+
if fused:
|
|
519
|
+
model_info["ops"] = ops
|
|
520
|
+
|
|
521
|
+
return fused
|
|
522
|
+
|
|
523
|
+
def _fuse_conv_conv(self, model_info: Dict[str, Any]) -> bool:
|
|
524
|
+
"""
|
|
525
|
+
Fuse CONV_2D + CONV_2D into one.
|
|
526
|
+
|
|
527
|
+
Supports:
|
|
528
|
+
1. 1x1 conv + 1x1 conv -> 1x1 conv
|
|
529
|
+
2. 3x3 conv + 3x3 conv -> 5x5 conv
|
|
530
|
+
"""
|
|
531
|
+
ops = model_info.get("ops", [])
|
|
532
|
+
tensors = model_info.get("tensors", {})
|
|
533
|
+
weights = model_info.get("weights", {})
|
|
534
|
+
fused = False
|
|
535
|
+
i = 0
|
|
536
|
+
|
|
537
|
+
while i < len(ops) - 1:
|
|
538
|
+
op = ops[i]
|
|
539
|
+
next_op = ops[i + 1]
|
|
540
|
+
|
|
541
|
+
if op.get("op_name") != "CONV_2D" or next_op.get(
|
|
542
|
+
"op_name") != "CONV_2D":
|
|
543
|
+
i += 1
|
|
544
|
+
continue
|
|
545
|
+
|
|
546
|
+
# Check direct connection
|
|
547
|
+
conv_out = op.get("output_indices", [])[0]
|
|
548
|
+
next_input = next_op.get("input_indices", [])[0]
|
|
549
|
+
if conv_out != next_input:
|
|
550
|
+
i += 1
|
|
551
|
+
continue
|
|
552
|
+
|
|
553
|
+
# Check no activation on first conv
|
|
554
|
+
conv_params = op.get("conv_params", {})
|
|
555
|
+
if conv_params.get("activation"):
|
|
556
|
+
i += 1
|
|
557
|
+
continue
|
|
558
|
+
|
|
559
|
+
# Check conv_out is only used by the next conv
|
|
560
|
+
if self._count_uses(model_info, conv_out) > 1:
|
|
561
|
+
i += 1
|
|
562
|
+
continue
|
|
563
|
+
|
|
564
|
+
# Get params
|
|
565
|
+
c1 = op.get("conv_params", {})
|
|
566
|
+
c2 = next_op.get("conv_params", {})
|
|
567
|
+
|
|
568
|
+
# Get weight indices
|
|
569
|
+
conv1_inputs = op.get("input_indices", [])
|
|
570
|
+
conv2_inputs = next_op.get("input_indices", [])
|
|
571
|
+
if len(conv1_inputs) < 2 or len(conv2_inputs) < 2:
|
|
572
|
+
i += 1
|
|
573
|
+
continue
|
|
574
|
+
w1_idx = conv1_inputs[1]
|
|
575
|
+
w2_idx = conv2_inputs[1]
|
|
576
|
+
|
|
577
|
+
if w1_idx not in weights or w2_idx not in weights:
|
|
578
|
+
i += 1
|
|
579
|
+
continue
|
|
580
|
+
|
|
581
|
+
w1 = weights[w1_idx]
|
|
582
|
+
w2 = weights[w2_idx]
|
|
583
|
+
|
|
584
|
+
if len(w1.shape) != 4 or len(w2.shape) != 4:
|
|
585
|
+
i += 1
|
|
586
|
+
continue
|
|
587
|
+
|
|
588
|
+
if w1.shape[3] != w2.shape[2]:
|
|
589
|
+
i += 1
|
|
590
|
+
continue
|
|
591
|
+
|
|
592
|
+
C_in = w1.shape[2]
|
|
593
|
+
C_mid = w1.shape[3]
|
|
594
|
+
C_out = w2.shape[3]
|
|
595
|
+
|
|
596
|
+
k1 = c1.get("kernel_size")
|
|
597
|
+
k2 = c2.get("kernel_size")
|
|
598
|
+
s1 = c1.get("stride", 1)
|
|
599
|
+
s2 = c2.get("stride", 1)
|
|
600
|
+
p1 = c1.get("padding", "SAME")
|
|
601
|
+
p2 = c2.get("padding", "SAME")
|
|
602
|
+
|
|
603
|
+
# ----------------------------------------------------------------
|
|
604
|
+
# Pattern 1: 1x1 + 1x1 -> 1x1
|
|
605
|
+
# ----------------------------------------------------------------
|
|
606
|
+
if k1 == 1 and k2 == 1 and s1 == 1 and s2 == 1:
|
|
607
|
+
w1_f = w1.astype(np.float32)
|
|
608
|
+
w2_f = w2.astype(np.float32)
|
|
609
|
+
|
|
610
|
+
# w1: [1, 1, C_in, C_mid], w2: [1, 1, C_mid, C_out]
|
|
611
|
+
# w_fused: [1, 1, C_in, C_out]
|
|
612
|
+
w_fused = np.matmul(
|
|
613
|
+
w1_f.transpose(0, 1, 3, 2), # [1, 1, C_mid, C_in]
|
|
614
|
+
w2_f.transpose(0, 1, 2, 3) # [1, 1, C_mid, C_out]
|
|
615
|
+
)
|
|
616
|
+
w_fused = w_fused.astype(np.int8)
|
|
617
|
+
|
|
618
|
+
# Update weights
|
|
619
|
+
weights[w1_idx] = w_fused
|
|
620
|
+
|
|
621
|
+
# Keep second conv's params (but update kernel size)
|
|
622
|
+
next_op["conv_params"]["kernel_size"] = 1
|
|
623
|
+
|
|
624
|
+
# Remove first conv
|
|
625
|
+
op_inputs = op.get("input_indices", [])
|
|
626
|
+
next_op["input_indices"][0] = op_inputs[0]
|
|
627
|
+
|
|
628
|
+
# Clean up
|
|
629
|
+
if conv_out in tensors:
|
|
630
|
+
del tensors[conv_out]
|
|
631
|
+
if conv_out in weights:
|
|
632
|
+
del weights[conv_out]
|
|
633
|
+
|
|
634
|
+
del ops[i]
|
|
635
|
+
|
|
636
|
+
fused = True
|
|
637
|
+
self._fused_count += 1
|
|
638
|
+
self._log_change(" Fused CONV_2D + CONV_2D (1x1 + 1x1 -> 1x1)")
|
|
639
|
+
continue
|
|
640
|
+
|
|
641
|
+
# ----------------------------------------------------------------
|
|
642
|
+
# Pattern 2: 3x3 + 3x3 -> 5x5
|
|
643
|
+
# ----------------------------------------------------------------
|
|
644
|
+
if (k1 == 3 and k2 == 3 and s1 == 1 and s2 == 1
|
|
645
|
+
and p1 == "SAME" and p2 == "SAME"):
|
|
646
|
+
w1_f = w1.astype(np.float32)
|
|
647
|
+
w2_f = w2.astype(np.float32)
|
|
648
|
+
|
|
649
|
+
# w1: [3, 3, C_in, C_mid], w2: [3, 3, C_mid, C_out]
|
|
650
|
+
# w_fused: [5, 5, C_in, C_out]
|
|
651
|
+
w_fused = np.zeros((5, 5, C_in, C_out), dtype=np.float32)
|
|
652
|
+
|
|
653
|
+
for ic in range(C_in):
|
|
654
|
+
for oc in range(C_out):
|
|
655
|
+
# For each output channel, convolve w1 with w2
|
|
656
|
+
# w1[:, :, ic, :] -> [3, 3, C_mid]
|
|
657
|
+
# w2[:, :, :, oc] -> [3, 3, C_mid]
|
|
658
|
+
# Result is a 5x5 kernel
|
|
659
|
+
kernel = np.zeros((5, 5), dtype=np.float32)
|
|
660
|
+
kernel[1:4, 1:4] = w1_f[
|
|
661
|
+
:, :, ic, :] # Place w1 in center
|
|
662
|
+
for mid in range(C_mid):
|
|
663
|
+
w2_mid = w2_f[:, :, mid, oc]
|
|
664
|
+
kernel_tmp = kernel[:, :, mid] * w2_mid
|
|
665
|
+
w_fused[:, :, ic, oc] += kernel_tmp
|
|
666
|
+
|
|
667
|
+
w_fused = w_fused.astype(np.int8)
|
|
668
|
+
|
|
669
|
+
# Update weights
|
|
670
|
+
weights[w1_idx] = w_fused
|
|
671
|
+
|
|
672
|
+
# Update second conv params
|
|
673
|
+
next_op["conv_params"]["kernel_size"] = 5
|
|
674
|
+
next_op["conv_params"]["padding"] = "VALID"
|
|
675
|
+
|
|
676
|
+
# Remove first conv
|
|
677
|
+
op_inputs = op.get("input_indices", [])
|
|
678
|
+
next_op["input_indices"][0] = op_inputs[0]
|
|
679
|
+
|
|
680
|
+
# Clean up
|
|
681
|
+
if conv_out in tensors:
|
|
682
|
+
del tensors[conv_out]
|
|
683
|
+
if conv_out in weights:
|
|
684
|
+
del weights[conv_out]
|
|
685
|
+
|
|
686
|
+
del ops[i]
|
|
687
|
+
|
|
688
|
+
fused = True
|
|
689
|
+
self._fused_count += 1
|
|
690
|
+
self._log_change(" Fused CONV_2D + CONV_2D (3x3 + 3x3 -> 5x5)")
|
|
691
|
+
continue
|
|
692
|
+
|
|
693
|
+
i += 1
|
|
694
|
+
|
|
695
|
+
if fused:
|
|
696
|
+
model_info["ops"] = ops
|
|
697
|
+
|
|
698
|
+
return fused
|
|
699
|
+
|
|
700
|
+
def _update_tensor_refs_after_removal(
|
|
701
|
+
self,
|
|
702
|
+
model_info: Dict[str, Any],
|
|
703
|
+
removed_output: int,
|
|
704
|
+
replaced_by: int
|
|
705
|
+
) -> None:
|
|
706
|
+
"""Update all tensor references after removing an op."""
|
|
707
|
+
# Update all ops
|
|
708
|
+
for op in model_info.get("ops", []):
|
|
709
|
+
# Update input_indices
|
|
710
|
+
for idx, input_idx in enumerate(op.get("input_indices", [])):
|
|
711
|
+
if input_idx == removed_output:
|
|
712
|
+
op["input_indices"][idx] = replaced_by
|
|
713
|
+
|
|
714
|
+
# Update output_indices
|
|
715
|
+
for idx, output_idx in enumerate(op.get("output_indices", [])):
|
|
716
|
+
if output_idx == removed_output:
|
|
717
|
+
op["output_indices"][idx] = replaced_by
|
|
718
|
+
|
|
719
|
+
# Remove the dead tensor from tensors dict
|
|
720
|
+
if removed_output in model_info.get("tensors", {}):
|
|
721
|
+
del model_info["tensors"][removed_output]
|
|
722
|
+
if removed_output in model_info.get("weights", {}):
|
|
723
|
+
del model_info["weights"][removed_output]
|