tinymlc 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- TinyMLC/ANG/__init__.py +0 -0
- TinyMLC/ANG/args.py +86 -0
- TinyMLC/ANG/estimator.py +103 -0
- TinyMLC/ANG/estimator_hal.py +184 -0
- TinyMLC/ANG/estimator_qemu.py +257 -0
- TinyMLC/ANG/estimator_software.py +130 -0
- TinyMLC/ANG/model_builder.py +508 -0
- TinyMLC/ANG/model_generator.py +439 -0
- TinyMLC/ANG/model_info.py +283 -0
- TinyMLC/ANG/utils.py +420 -0
- TinyMLC/__init__.py +0 -0
- TinyMLC/cli.py +126 -0
- TinyMLC/codegen.py +877 -0
- TinyMLC/converter/__init__.py +0 -0
- TinyMLC/converter/export_weights.py +382 -0
- TinyMLC/converter/parser_litert.py +757 -0
- TinyMLC/converter/parser_onnx.py +649 -0
- TinyMLC/generate_lut.py +97 -0
- TinyMLC/handlers.py +325 -0
- TinyMLC/ops.py +76 -0
- TinyMLC/templates/lut.c.tpl +23 -0
- TinyMLC/templates/lut.h.tpl +67 -0
- TinyMLC/templates/model.c.tpl +314 -0
- TinyMLC/templates/model.h.tpl +66 -0
- TinyMLC/transform/__init__.py +0 -0
- TinyMLC/transform/algebraic.py +286 -0
- TinyMLC/transform/base.py +58 -0
- TinyMLC/transform/constant_folding.py +260 -0
- TinyMLC/transform/cse.py +192 -0
- TinyMLC/transform/dce.py +182 -0
- TinyMLC/transform/fusion.py +723 -0
- TinyMLC/transform/memory.py +200 -0
- TinyMLC/transform/pass_manager.py +101 -0
- TinyMLC/transform/simplify.py +515 -0
- tinymlc-0.1.0.dist-info/METADATA +49 -0
- tinymlc-0.1.0.dist-info/RECORD +47 -0
- tinymlc-0.1.0.dist-info/WHEEL +4 -0
- tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
- tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
- utils/__init__.py +0 -0
- utils/arm-none-eabi-gcc.cmake +53 -0
- utils/dump.py +86 -0
- utils/generate_onnx_models.py +183 -0
- utils/generate_tflite_models.py +236 -0
- utils/pack_macos.sh +88 -0
- utils/path.py +31 -0
- utils/riscv-none-elf-gcc.cmake +50 -0
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# TinyMLC - Tiny Machine Learning Compiler
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2026 Jia Liu & TinyMLC Contributors
|
|
5
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
6
|
+
#
|
|
7
|
+
# This file is part of TinyMLC.
|
|
8
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
9
|
+
# you may not use this file except in compliance with the License.
|
|
10
|
+
# You may obtain a copy of the License at:
|
|
11
|
+
#
|
|
12
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
13
|
+
#
|
|
14
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
15
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
16
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
17
|
+
# See the License for the specific language governing permissions and
|
|
18
|
+
# limitations under the License.
|
|
19
|
+
|
|
20
|
+
# Base class for all optimization passes.
|
|
21
|
+
|
|
22
|
+
from abc import ABC, abstractmethod
|
|
23
|
+
from typing import Dict, Any
|
|
24
|
+
import copy
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Pass(ABC):
    """
    Abstract foundation shared by every optimization pass.

    A concrete pass receives a model_info dict through run() and hands
    back the transformed model_info. Bookkeeping about what the pass did
    is kept in a per-instance statistics dict.
    """

    def __init__(self, name: str = None):
        # A missing or empty name falls back to the concrete class name.
        self.name = name if name else type(self).__name__
        # "changes" collects the messages handed to _log_change().
        self._stats = dict(before={}, after={}, changes=[])

    @abstractmethod
    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """Transform model_info and return the resulting model_info."""

    def get_stats(self) -> Dict[str, Any]:
        """Expose the statistics gathered while the pass executed."""
        stats = self._stats
        return stats

    def _log_change(self, msg: str) -> None:
        """Append msg to the list of changes recorded for this pass."""
        recorded = self._stats["changes"]
        recorded.append(msg)

    def _copy_model(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy so the caller's model_info is never mutated."""
        duplicate = copy.deepcopy(model_info)
        return duplicate
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# TinyMLC - Tiny Machine Learning Compiler
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2026 Jia Liu & TinyMLC Contributors
|
|
5
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
6
|
+
#
|
|
7
|
+
# This file is part of TinyMLC.
|
|
8
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
9
|
+
# you may not use this file except in compliance with the License.
|
|
10
|
+
# You may obtain a copy of the License at:
|
|
11
|
+
#
|
|
12
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
13
|
+
#
|
|
14
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
15
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
16
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
17
|
+
# See the License for the specific language governing permissions and
|
|
18
|
+
# limitations under the License.
|
|
19
|
+
|
|
20
|
+
# Constant folding optimization pass.
|
|
21
|
+
|
|
22
|
+
from typing import Dict, Any
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from TinyMLC.transform.base import Pass
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ConstantFolding(Pass):
    """
    Constant folding optimization pass.

    Evaluates an op at compile time when every tensor it reads is a
    constant (i.e. present in model_info["weights"]), stores the result
    as a new weight under the op's output index and drops the op from
    the graph.

    Currently folds:
    - RESHAPE with a static target shape
    - TRANSPOSE with a static permutation
    - ADD, MULTIPLY, SUB with exactly two constant operands
    - MEAN over a constant input

    Not implemented yet:
    - Concat with constant axis
    - Softmax with constant input
    """

    # op_name -> (symbol shown in the log message, numpy implementation)
    _BINARY_OPS = {
        "ADD": ("+", np.add),
        "MULTIPLY": ("*", np.multiply),
        "SUB": ("-", np.subtract),
    }

    def __init__(self, name: str = "ConstantFolding"):
        super().__init__(name)
        # tensor index -> compile-time value, rebuilt on every run()
        self._const_tensors: Dict[int, np.ndarray] = {}

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run constant folding on a deep copy of model_info and return it.

        The caller's model_info is left untouched.
        """
        model_info = self._copy_model(model_info)

        # 1. Find all constant tensors (weights, bias, etc.)
        self._collect_constants(model_info)

        # 2. Scan ops and fold constants
        self._fold_ops(model_info)

        # 3. Prune unused tensors
        self._prune_unused_tensors(model_info)

        return model_info

    @staticmethod
    def _first_index(op: Dict[str, Any], key: str):
        """Return the first entry of op[key], or None if missing/empty."""
        indices = op.get(key)
        if indices is None or len(indices) == 0:
            return None
        return indices[0]

    def _store_folded(
        self, model_info: Dict[str, Any], output_idx: int, value: Any
    ) -> np.ndarray:
        """
        Register a folded result as a constant and as a weight.

        The value is normalised to an ndarray so that later folds and the
        weight writer always see the same type (np.mean with axis=None
        would otherwise yield a numpy scalar).
        """
        value = np.asarray(value)
        self._const_tensors[output_idx] = value
        # Add to weights so it gets written out
        model_info.setdefault("weights", {})[output_idx] = value
        return value

    def _collect_constants(self, model_info: Dict[str, Any]) -> None:
        """Collect every entry of model_info["weights"] as an ndarray."""
        weights = model_info.get("weights", {})
        self._const_tensors = {}

        for idx, weight in weights.items():
            if isinstance(weight, np.ndarray):
                self._const_tensors[idx] = weight
            elif isinstance(weight, list):
                self._const_tensors[idx] = np.array(weight)
            else:
                # Scalar or other type: wrap in a one-element array
                self._const_tensors[idx] = np.array([weight])

        self._log_change(f"Found {len(self._const_tensors)} constant tensors")

    def _fold_ops(self, model_info: Dict[str, Any]) -> None:
        """
        Fold every op whose handler succeeds and drop it from the op list.

        Ops are visited in list order, so a folded result is available as
        a constant to the ops that follow it.
        """
        handlers = {
            "RESHAPE": self._fold_reshape,
            "TRANSPOSE": self._fold_transpose,
            "ADD": self._fold_binary_op,
            "MULTIPLY": self._fold_binary_op,
            "SUB": self._fold_binary_op,
            "MEAN": self._fold_mean,
        }
        ops = model_info.get("ops", [])
        new_ops = []
        folded_count = 0

        for op in ops:
            op_name = op.get("op_name")
            handler = handlers.get(op_name)

            # Only fold if we can evaluate it now
            if handler is not None and handler(model_info, op):
                folded_count += 1
                self._log_change(f"Folded {op_name}")
            else:
                new_ops.append(op)

        if folded_count > 0:
            model_info["ops"] = new_ops
            self._log_change(f"Folded {folded_count} ops")

    def _fold_reshape(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """
        Fold RESHAPE if its first input is constant.

        The target shape is read from op["reshape_params"]["shape"], or
        from "target_shape" when "shape" is absent or empty. Returns True
        when the op was folded, False otherwise.
        """
        input_idx = self._first_index(op, "input_indices")
        output_idx = self._first_index(op, "output_indices")
        if input_idx is None or output_idx is None:
            return False
        if input_idx not in self._const_tensors:
            return False

        try:
            # Get the target shape from params
            params = op.get("reshape_params", {})
            target_shape = None
            for key in ("shape", "target_shape"):
                candidate = params.get(key)
                if candidate is not None and len(candidate) > 0:
                    target_shape = candidate
                    break
            if target_shape is None:
                # No static shape is known. Reshaping to [] here would
                # silently turn a one-element constant into a rank-0
                # tensor, so leave the op in the graph instead.
                return False

            data = self._const_tensors[input_idx]
            folded = self._store_folded(
                model_info, output_idx, data.reshape(target_shape)
            )

            self._log_change(
                f" Reshape constant: {data.shape} -> {folded.shape}"
            )
            return True
        except Exception as e:
            # Best effort: an op that cannot be folded stays in the graph.
            print(f" Warning: failed to fold reshape: {e}")
            return False

    def _fold_transpose(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """
        Fold TRANSPOSE if its first input is constant.

        The permutation comes from op["transpose_params"]["perm"]; a
        missing or empty permutation reverses the axes (numpy default).
        Returns True when the op was folded, False otherwise.
        """
        input_idx = self._first_index(op, "input_indices")
        output_idx = self._first_index(op, "output_indices")
        if input_idx is None or output_idx is None:
            return False
        if input_idx not in self._const_tensors:
            return False

        try:
            params = op.get("transpose_params", {})
            perm = params.get("perm", [])
            # len() instead of truthiness so an ndarray perm also works.
            axes = None if perm is None or len(perm) == 0 else list(perm)

            data = self._const_tensors[input_idx]
            folded = self._store_folded(
                model_info, output_idx, np.transpose(data, axes=axes)
            )

            self._log_change(
                f" Transpose constant: {data.shape} -> {folded.shape}"
            )
            return True
        except Exception as e:
            # Best effort: an op that cannot be folded stays in the graph.
            print(f" Warning: failed to fold transpose: {e}")
            return False

    def _fold_binary_op(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """
        Fold ADD, MULTIPLY or SUB when both operands are constant.

        Only ops with exactly two inputs and at least one output are
        folded. Returns True when the op was folded, False otherwise.
        """
        # NOTE(review): op attributes (e.g. a fused activation or
        # quantization parameters) are not consulted here - confirm the
        # parser never attaches any to these ops.
        op_name = op.get("op_name")
        spec = self._BINARY_OPS.get(op_name)
        input_indices = op.get("input_indices", [])
        output_idx = self._first_index(op, "output_indices")

        if spec is None or output_idx is None or len(input_indices) != 2:
            return False
        if not all(idx in self._const_tensors for idx in input_indices):
            return False

        symbol, func = spec
        try:
            a = self._const_tensors[input_indices[0]]
            b = self._const_tensors[input_indices[1]]
            folded = self._store_folded(model_info, output_idx, func(a, b))

            self._log_change(
                f" {op_name} constant: {a.shape} {symbol} {b.shape} "
                f"-> {folded.shape}"
            )
            return True
        except Exception as e:
            # Best effort: e.g. shapes that do not broadcast.
            print(f" Warning: failed to fold {op_name}: {e}")
            return False

    def _fold_mean(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """
        Fold MEAN if its first input is constant.

        "axis" and "keepdims" are read from op["mean_params"]; a missing
        axis reduces over all elements. Returns True when the op was
        folded, False otherwise.
        """
        input_idx = self._first_index(op, "input_indices")
        output_idx = self._first_index(op, "output_indices")
        if input_idx is None or output_idx is None:
            return False
        if input_idx not in self._const_tensors:
            return False

        try:
            params = op.get("mean_params", {})
            axis = params.get("axis", None)
            keepdims = params.get("keepdims", False)

            if isinstance(axis, (list, tuple, np.ndarray)):
                # np.mean only accepts an int or a tuple of ints.
                axis = tuple(int(a) for a in np.ravel(axis))
                if not axis:
                    # Whether an empty axis list means "all axes" or
                    # "none" is not visible here, so do not fold.
                    return False

            data = self._const_tensors[input_idx]
            # NOTE(review): for integer input np.mean yields float64 -
            # confirm this is acceptable for the weight writer.
            folded = self._store_folded(
                model_info,
                output_idx,
                np.mean(data, axis=axis, keepdims=keepdims),
            )

            self._log_change(
                f" Mean constant: {data.shape} -> {folded.shape}"
            )
            return True
        except Exception as e:
            # Best effort: an op that cannot be folded stays in the graph.
            print(f" Warning: failed to fold mean: {e}")
            return False

    def _prune_unused_tensors(self, model_info: Dict[str, Any]) -> None:
        """
        Remove tensors (and their weights) that nothing references.

        A tensor is kept when it is an input or output of a remaining op,
        or when a graph input/output spec names it via "tensor_index".
        Without the second rule a folded constant that is itself a model
        output would be deleted.
        """
        used_indices = set()
        for op in model_info.get("ops", []):
            used_indices.update(op.get("input_indices", []))
            used_indices.update(op.get("output_indices", []))

        # Graph inputs/outputs; specs without a "tensor_index" are skipped.
        for section in ("input", "output"):
            for spec in model_info.get(section, []):
                if not isinstance(spec, dict):
                    continue
                idx = spec.get("tensor_index")
                if idx is not None:
                    used_indices.add(idx)

        tensors = model_info.get("tensors", {})
        weights = model_info.get("weights", {})
        unused = set(tensors) - used_indices

        for idx in unused:
            del tensors[idx]
            weights.pop(idx, None)

        if unused:
            self._log_change(f"Removed {len(unused)} unused tensors")
|
TinyMLC/transform/cse.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# TinyMLC - Tiny Machine Learning Compiler
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2026 Jia Liu & TinyMLC Contributors
|
|
5
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
6
|
+
#
|
|
7
|
+
# This file is part of TinyMLC.
|
|
8
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
9
|
+
# you may not use this file except in compliance with the License.
|
|
10
|
+
# You may obtain a copy of the License at:
|
|
11
|
+
#
|
|
12
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
13
|
+
#
|
|
14
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
15
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
16
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
17
|
+
# See the License for the specific language governing permissions and
|
|
18
|
+
# limitations under the License.
|
|
19
|
+
|
|
20
|
+
# Common Subexpression Elimination.
|
|
21
|
+
|
|
22
|
+
from typing import Dict, Any
|
|
23
|
+
import hashlib
|
|
24
|
+
import json
|
|
25
|
+
|
|
26
|
+
from TinyMLC.transform.base import Pass
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class CommonSubexpressionElimination(Pass):
    """
    Common Subexpression Elimination.

    Finds and eliminates duplicate computations, i.e. ops that have the
    same op_name, the same inputs (in the same order) and the same
    attributes.

    Strategy:
    1. Compute a signature for each op (op_name + input_indices + attrs)
    2. If two ops have the same signature, keep the first one
    3. Replace all uses of the later op's outputs with
       the first op's outputs
    """

    # Op fields that describe graph wiring and are handled explicitly.
    _STRUCTURAL_KEYS = frozenset(
        {"op_name", "input_indices", "output_indices"}
    )
    # Fields treated as bookkeeping that does not affect the computation.
    _METADATA_KEYS = frozenset({"name", "index", "state", "pass_flags"})

    def __init__(self, name: str = "CommonSubexpressionElimination"):
        super().__init__(name)
        self._signature_map: Dict[str, int] = {}  # signature -> op_index
        self._replace_map: Dict[int, int] = {}    # old_tensor -> new_tensor

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run CSE on a deep copy of model_info and return it.

        Ops are visited in list order; the first op with a given
        signature is kept and later ones are dropped.
        """
        model_info = self._copy_model(model_info)

        self._signature_map.clear()
        self._replace_map.clear()

        ops = model_info.get("ops", [])
        new_ops = []
        eliminated_count = 0

        for op in ops:
            # Point inputs at surviving tensors first, so that consumers
            # of two merged ops are themselves recognised as duplicates.
            if self._replace_map and "input_indices" in op:
                op["input_indices"] = [
                    self._replace_map.get(idx, idx)
                    for idx in op["input_indices"]
                ]

            signature = self._compute_signature(op)
            orig_op_idx = self._signature_map.get(signature)

            if orig_op_idx is None:
                # New op: record its signature and keep it
                self._signature_map[signature] = len(new_ops)
                new_ops.append(op)
                continue

            orig_op = new_ops[orig_op_idx]
            if not self._replace_outputs(op, orig_op):
                # Outputs could not be remapped; dropping the op would
                # leave its consumers without a producer, so keep it.
                new_ops.append(op)
                continue

            eliminated_count += 1
            self._log_change(
                f" Eliminated duplicate {op.get('op_name')} "
                f"(outputs: {op.get('output_indices')} "
                f"-> {orig_op.get('output_indices')})"
            )

        if eliminated_count > 0:
            model_info["ops"] = new_ops
            # Update all tensor references in remaining ops
            self._update_tensor_refs(model_info)
            self._log_change(
                f"Eliminated {eliminated_count} duplicate expressions"
            )

        return model_info

    @staticmethod
    def _json_fallback(value: Any) -> Any:
        """
        Make values json cannot encode natively usable in a signature.

        Raises TypeError for anything it does not understand, so that
        the caller can fall back to a never-matching signature.
        """
        tolist = getattr(value, "tolist", None)
        if callable(tolist):
            # numpy arrays/scalars: keep the dtype so int and float data
            # with equal values do not collide.
            return {
                "__dtype__": str(getattr(value, "dtype", "")),
                "data": tolist(),
            }
        if isinstance(value, (set, frozenset)):
            return sorted(value, key=repr)
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).hex()
        raise TypeError(
            f"cannot build a signature from {type(value).__name__}"
        )

    def _compute_signature(self, op: Dict[str, Any]) -> str:
        """
        Compute a signature that is equal for equivalent ops.

        The signature includes:
        - op_name
        - input_indices, in their original order (ops such as SUB are
          not commutative)
        - every other op field except output_indices and metadata; this
          covers "params" as well as per-op dicts such as
          "reshape_params"

        output_indices are deliberately excluded: two duplicates always
        write to different tensors. If the op cannot be serialised, a
        signature unique to this op is returned so it is never merged.
        """
        def is_metadata(key: Any) -> bool:
            if key in self._METADATA_KEYS:
                return True
            return isinstance(key, str) and key.startswith("_")

        attrs = {}
        for key, value in op.items():
            if key in self._STRUCTURAL_KEYS or is_metadata(key):
                continue
            if key == "params" and isinstance(value, dict):
                # Params: filter out fields that don't affect computation
                value = {
                    k: v for k, v in value.items() if not is_metadata(k)
                }
            attrs[key] = value

        sig = {
            "op_name": op.get("op_name", "UNKNOWN"),
            "input_indices": op.get("input_indices", []),
            "attrs": attrs,
        }

        try:
            sig_str = json.dumps(
                sig, sort_keys=True, default=self._json_fallback
            )
        except (TypeError, ValueError):
            # Unknown attribute type: be conservative and never match.
            return f"unique:{id(op)}"

        # Hash to a short string
        return hashlib.sha256(sig_str.encode()).hexdigest()[:16]

    def _replace_outputs(
        self, dup_op: Dict[str, Any], orig_op: Dict[str, Any]
    ) -> bool:
        """
        Map outputs of dup_op to outputs of orig_op.

        Assumes the output indices are in the same order. Returns True
        when the mapping was recorded and False when the two ops have a
        different number of outputs (nothing is recorded in that case).
        """
        dup_outputs = dup_op.get("output_indices", [])
        orig_outputs = orig_op.get("output_indices", [])

        if len(dup_outputs) != len(orig_outputs):
            # Different number of outputs, can't replace
            self._log_change(
                f" Warning: output count mismatch "
                f"({len(dup_outputs)} vs {len(orig_outputs)})"
            )
            return False

        for dup_idx, orig_idx in zip(dup_outputs, orig_outputs):
            self._replace_map[dup_idx] = orig_idx
        return True

    def _update_tensor_refs(self, model_info: Dict[str, Any]) -> None:
        """
        Rewrite every reference to an eliminated tensor.

        - Remaps input_indices/output_indices of the remaining ops
        - Remaps "tensor_index" of graph input/output specs, so a model
          output produced by an eliminated op follows the surviving one
        - Removes the eliminated tensors from "tensors" and "weights"
        """
        if not self._replace_map:
            return

        remap = self._replace_map

        for op in model_info.get("ops", []):
            op["input_indices"] = [
                remap.get(idx, idx) for idx in op.get("input_indices", [])
            ]
            op["output_indices"] = [
                remap.get(idx, idx) for idx in op.get("output_indices", [])
            ]

        for section in ("input", "output"):
            for spec in model_info.get(section, []):
                if not isinstance(spec, dict):
                    continue
                idx = spec.get("tensor_index")
                if idx is not None and idx in remap:
                    spec["tensor_index"] = remap[idx]

        # Drop the replaced tensors and any weights stored for them
        tensors = model_info.get("tensors", {})
        weights = model_info.get("weights", {})
        for old_idx in remap:
            tensors.pop(old_idx, None)
            weights.pop(old_idx, None)

        self._log_change(
            f" Replaced {len(self._replace_map)} tensor references"
        )
|
TinyMLC/transform/dce.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# TinyMLC - Tiny Machine Learning Compiler
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2026 Jia Liu & TinyMLC Contributors
|
|
5
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
6
|
+
#
|
|
7
|
+
# This file is part of TinyMLC.
|
|
8
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
9
|
+
# you may not use this file except in compliance with the License.
|
|
10
|
+
# You may obtain a copy of the License at:
|
|
11
|
+
#
|
|
12
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
13
|
+
#
|
|
14
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
15
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
16
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
17
|
+
# See the License for the specific language governing permissions and
|
|
18
|
+
# limitations under the License.
|
|
19
|
+
|
|
20
|
+
# Dead Code Elimination pass.
|
|
21
|
+
|
|
22
|
+
from typing import Dict, Any, Set, List
|
|
23
|
+
from TinyMLC.transform.base import Pass
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DeadCodeElimination(Pass):
    """
    Dead Code Elimination.

    Removes:
    - Ops none of whose outputs is consumed by another op or declared
      as a graph output
    - Tensors that are neither consumed, produced by a surviving op,
      nor declared as a graph input/output

    Control flow (unreachable ops) is not handled yet.

    This pass should be run after each pass that may create dead code.
    """

    def __init__(self, name: str = "DeadCodeElimination"):
        super().__init__(name)

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run dead code elimination on a deep copy of model_info.

        Dead ops are removed until a fixed point is reached, because
        removing a consumer can make its producers dead in turn. Tensors
        are cleaned up once afterwards.
        """
        model_info = self._copy_model(model_info)

        iteration = 0
        while True:
            iteration += 1

            # 1. Find all tensor indices that are still consumed
            used_indices = self._collect_used_indices(model_info)

            # 2. Remove ops nobody consumes; stop once nothing changes
            dead_ops = self._remove_dead_ops(model_info, used_indices)
            if not dead_ops:
                break
            self._log_change(
                f"Iteration {iteration}: removed {len(dead_ops)} dead ops"
            )

        # 3. used_indices now reflects the final op list
        dead_tensors = self._remove_dead_tensors(model_info, used_indices)
        if dead_tensors:
            self._log_change(
                f"Iteration {iteration}: removed {len(dead_tensors)} "
                f"dead tensors"
            )

        return model_info

    @staticmethod
    def _spec_indices(model_info: Dict[str, Any], section: str) -> Set[int]:
        """
        Return the "tensor_index" values of model_info[section].

        Specs without a "tensor_index" contribute nothing.
        """
        found = set()
        for spec in model_info.get(section, []):
            if not isinstance(spec, dict):
                continue
            idx = spec.get("tensor_index")
            if idx is not None:
                found.add(idx)
        return found

    def _collect_used_indices(self, model_info: Dict[str, Any]) -> Set[int]:
        """
        Collect all tensor indices that something consumes.

        That is: graph inputs, graph outputs, and every index used as an
        input to any op. Op outputs are NOT included - counting them
        would make every op look alive.
        """
        used = self._spec_indices(model_info, "input")
        used |= self._spec_indices(model_info, "output")

        for op in model_info.get("ops", []):
            used.update(op.get("input_indices", []))

        return used

    def _remove_dead_tensors(
        self,
        model_info: Dict[str, Any],
        used_indices: Set[int]
    ) -> Set[int]:
        """
        Remove tensors that are neither used nor produced by an op.

        Outputs of the ops still in the graph are always kept, even when
        unconsumed, so a multi-output op never loses one of its output
        tensors. Matching entries in "weights" are removed as well.
        Returns the set of removed tensor indices.
        """
        tensors = model_info.get("tensors", {})
        weights = model_info.get("weights", {})

        produced = set()
        for op in model_info.get("ops", []):
            produced.update(op.get("output_indices", []))

        dead = set(tensors) - used_indices - produced

        for idx in dead:
            del tensors[idx]
            weights.pop(idx, None)

        return dead

    def _remove_dead_ops(
        self,
        model_info: Dict[str, Any],
        used_indices: Set[int]
    ) -> List[Dict[str, Any]]:
        """
        Remove ops none of whose outputs appears in used_indices.

        Nothing is removed when no graph output declares a
        "tensor_index": without known roots the final ops of the graph
        would look dead. Ops without any output are kept as well, since
        they cannot be proven dead. Returns the list of removed ops.
        """
        if not self._spec_indices(model_info, "output"):
            return []

        alive_ops = []
        dead_ops = []

        for op in model_info.get("ops", []):
            outputs = op.get("output_indices", [])
            # An op is alive if any of its outputs is used
            if not outputs or any(idx in used_indices for idx in outputs):
                alive_ops.append(op)
            else:
                dead_ops.append(op)

        if dead_ops:
            # Tensors produced only by these ops are cleaned up by
            # _remove_dead_tensors.
            model_info["ops"] = alive_ops

        return dead_ops
|