tinymlc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. TinyMLC/ANG/__init__.py +0 -0
  2. TinyMLC/ANG/args.py +86 -0
  3. TinyMLC/ANG/estimator.py +103 -0
  4. TinyMLC/ANG/estimator_hal.py +184 -0
  5. TinyMLC/ANG/estimator_qemu.py +257 -0
  6. TinyMLC/ANG/estimator_software.py +130 -0
  7. TinyMLC/ANG/model_builder.py +508 -0
  8. TinyMLC/ANG/model_generator.py +439 -0
  9. TinyMLC/ANG/model_info.py +283 -0
  10. TinyMLC/ANG/utils.py +420 -0
  11. TinyMLC/__init__.py +0 -0
  12. TinyMLC/cli.py +126 -0
  13. TinyMLC/codegen.py +877 -0
  14. TinyMLC/converter/__init__.py +0 -0
  15. TinyMLC/converter/export_weights.py +382 -0
  16. TinyMLC/converter/parser_litert.py +757 -0
  17. TinyMLC/converter/parser_onnx.py +649 -0
  18. TinyMLC/generate_lut.py +97 -0
  19. TinyMLC/handlers.py +325 -0
  20. TinyMLC/ops.py +76 -0
  21. TinyMLC/templates/lut.c.tpl +23 -0
  22. TinyMLC/templates/lut.h.tpl +67 -0
  23. TinyMLC/templates/model.c.tpl +314 -0
  24. TinyMLC/templates/model.h.tpl +66 -0
  25. TinyMLC/transform/__init__.py +0 -0
  26. TinyMLC/transform/algebraic.py +286 -0
  27. TinyMLC/transform/base.py +58 -0
  28. TinyMLC/transform/constant_folding.py +260 -0
  29. TinyMLC/transform/cse.py +192 -0
  30. TinyMLC/transform/dce.py +182 -0
  31. TinyMLC/transform/fusion.py +723 -0
  32. TinyMLC/transform/memory.py +200 -0
  33. TinyMLC/transform/pass_manager.py +101 -0
  34. TinyMLC/transform/simplify.py +515 -0
  35. tinymlc-0.1.0.dist-info/METADATA +49 -0
  36. tinymlc-0.1.0.dist-info/RECORD +47 -0
  37. tinymlc-0.1.0.dist-info/WHEEL +4 -0
  38. tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
  39. tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
  40. utils/__init__.py +0 -0
  41. utils/arm-none-eabi-gcc.cmake +53 -0
  42. utils/dump.py +86 -0
  43. utils/generate_onnx_models.py +183 -0
  44. utils/generate_tflite_models.py +236 -0
  45. utils/pack_macos.sh +88 -0
  46. utils/path.py +31 -0
  47. utils/riscv-none-elf-gcc.cmake +50 -0
@@ -0,0 +1,58 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # Base class for all optimization passes.
21
+
22
+ from abc import ABC, abstractmethod
23
+ from typing import Dict, Any
24
+ import copy
25
+
26
+
27
class Pass(ABC):
    """
    Abstract base for every optimization pass.

    A concrete pass implements ``run``, which receives a ``model_info`` dict
    and hands back the transformed ``model_info``. The base class supplies a
    display name, a statistics container, a helper that records change
    messages, and a helper that deep-copies the input so the caller's dict is
    left untouched.
    """

    def __init__(self, name: str = None):
        # Any falsy name (None, "") falls back to the concrete class name.
        if not name:
            name = type(self).__name__
        self.name = name
        # "before"/"after" are never filled in by this base class;
        # "changes" accumulates the messages passed to _log_change.
        self._stats = dict(before={}, after={}, changes=[])

    @abstractmethod
    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Transform ``model_info`` and return the result.

        Subclasses must override this method.
        """

    def get_stats(self) -> Dict[str, Any]:
        """Return this pass's statistics dict (the live object, not a copy)."""
        return self._stats

    def _log_change(self, msg: str) -> None:
        """Append ``msg`` to the recorded list of changes."""
        changes = self._stats["changes"]
        changes.append(msg)

    def _copy_model(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``model_info`` so the original is not mutated."""
        duplicate = copy.deepcopy(model_info)
        return duplicate
@@ -0,0 +1,260 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # Constant folding optimization pass.
21
+
22
+ from typing import Dict, Any
23
+ import numpy as np
24
+
25
+ from TinyMLC.transform.base import Pass
26
+
27
+
28
class ConstantFolding(Pass):
    """
    Constant folding optimization pass.

    Evaluates an op at compile time when the input(s) it reads are constant
    tensors (entries of ``model_info["weights"]``). The result is stored as a
    new weight under the op's first output index and the op is removed from
    ``model_info["ops"]``.

    Currently folds:
        - RESHAPE with a target shape in ``reshape_params``
        - TRANSPOSE with an optional ``perm`` in ``transpose_params``
        - ADD, MULTIPLY, SUB when both operands are constant
        - MEAN with ``axis``/``keepdims`` from ``mean_params``

    Future extensions:
        - CONCAT with constant inputs
        - SOFTMAX with constant input

    NOTE(review): folding uses plain numpy arithmetic on the stored arrays and
    ignores any quantization parameters (MEAN also promotes integer data to
    float) -- confirm folded ops only ever see float constants.
    """

    # op_name -> (numpy ufunc, symbol used in log messages)
    _BINARY_OPS = {
        "ADD": (np.add, "+"),
        "MULTIPLY": (np.multiply, "*"),
        "SUB": (np.subtract, "-"),
    }

    # numpy errors meaning "this op cannot be folded" (bad shape/axis,
    # incompatible broadcast, unsupported dtype). numpy's AxisError derives
    # from both ValueError and IndexError.
    _FOLD_ERRORS = (ValueError, TypeError, IndexError)

    def __init__(self, name: str = "ConstantFolding"):
        super().__init__(name)
        # tensor index -> value known at compile time
        self._const_tensors: Dict[int, np.ndarray] = {}

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """Run constant folding on a deep copy of model_info and return it."""
        model_info = self._copy_model(model_info)

        # 1. Find all constant tensors (weights, bias, etc.)
        self._collect_constants(model_info)

        # 2. Scan ops and fold constants. Folded results become constants
        #    themselves, so (assuming ops are listed in execution order)
        #    chains of foldable ops collapse in a single scan.
        self._fold_ops(model_info)

        # 3. Prune tensors that are no longer referenced
        self._prune_unused_tensors(model_info)

        return model_info

    def _collect_constants(self, model_info: Dict[str, Any]) -> None:
        """Load every entry of model_info["weights"] as a numpy constant."""
        weights = model_info.get("weights", {})
        self._const_tensors = {}

        for idx, weight in weights.items():
            if isinstance(weight, np.ndarray):
                self._const_tensors[idx] = weight
            elif isinstance(weight, list):
                self._const_tensors[idx] = np.array(weight)
            else:
                # Scalar or other type: wrapped as a 1-element array
                self._const_tensors[idx] = np.array([weight])

        self._log_change(f"Found {len(self._const_tensors)} constant tensors")

    def _fold_ops(self, model_info: Dict[str, Any]) -> None:
        """Fold every op whose inputs are constant and drop it from the graph."""
        folders = {
            "RESHAPE": self._fold_reshape,
            "TRANSPOSE": self._fold_transpose,
            "ADD": self._fold_binary_op,
            "MULTIPLY": self._fold_binary_op,
            "SUB": self._fold_binary_op,
            "MEAN": self._fold_mean,
        }
        new_ops = []
        folded_count = 0

        for op in model_info.get("ops", []):
            op_name = op.get("op_name")
            fold = folders.get(op_name)
            if fold is not None and fold(model_info, op):
                folded_count += 1
                self._log_change(f"Folded {op_name}")
            else:
                new_ops.append(op)

        if folded_count > 0:
            model_info["ops"] = new_ops
            self._log_change(f"Folded {folded_count} ops")

    @staticmethod
    def _first_io(op: Dict[str, Any]):
        """
        Return (first input index, first output index) of ``op``.

        Returns None when either index list is missing or empty, so callers
        never write a folded value to a made-up tensor index.
        """
        inputs = op.get("input_indices") or []
        outputs = op.get("output_indices") or []
        if not inputs or not outputs:
            return None
        return inputs[0], outputs[0]

    def _store_constant(
        self, model_info: Dict[str, Any], idx: int, value: Any
    ) -> None:
        """Record a folded value as a constant and emit it as a weight."""
        self._const_tensors[idx] = value
        model_info.setdefault("weights", {})[idx] = value

    def _fold_reshape(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """Fold RESHAPE if its input is constant and a target shape is known."""
        io = self._first_io(op)
        if io is None:
            return False
        input_idx, output_idx = io
        if input_idx not in self._const_tensors:
            return False

        params = op.get("reshape_params") or {}
        if "shape" not in params and "target_shape" not in params:
            # Target shape not known at compile time: nothing safe to fold.
            return False

        data = self._const_tensors[input_idx]
        try:
            # "shape" wins unless it is missing/empty; then "target_shape".
            target_shape = params.get("shape", [])
            if target_shape is None or np.size(target_shape) == 0:
                target_shape = params.get("target_shape", [])
            if target_shape is None:
                return False
            folded = data.reshape(target_shape)
        except self._FOLD_ERRORS as e:
            print(f" Warning: failed to fold reshape: {e}")
            return False

        self._store_constant(model_info, output_idx, folded)
        self._log_change(f" Reshape constant: {data.shape} -> {folded.shape}")
        return True

    def _fold_transpose(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """Fold TRANSPOSE if its input is constant (no perm = reverse axes)."""
        io = self._first_io(op)
        if io is None:
            return False
        input_idx, output_idx = io
        if input_idx not in self._const_tensors:
            return False

        params = op.get("transpose_params") or {}
        perm = params.get("perm")
        data = self._const_tensors[input_idx]
        try:
            # Missing or empty perm -> numpy's default (reverse all axes).
            axes = perm if perm is not None and len(perm) > 0 else None
            folded = np.transpose(data, axes=axes)
        except self._FOLD_ERRORS as e:
            print(f" Warning: failed to fold transpose: {e}")
            return False

        self._store_constant(model_info, output_idx, folded)
        self._log_change(
            f" Transpose constant: {data.shape} -> {folded.shape}"
        )
        return True

    def _fold_binary_op(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """Fold ADD, MULTIPLY or SUB when both (exactly two) inputs are constant."""
        op_name = op.get("op_name")
        entry = self._BINARY_OPS.get(op_name)
        input_indices = op.get("input_indices") or []
        output_indices = op.get("output_indices") or []
        if entry is None or len(input_indices) != 2 or not output_indices:
            return False
        if not all(idx in self._const_tensors for idx in input_indices):
            return False

        ufunc, symbol = entry
        a = self._const_tensors[input_indices[0]]
        b = self._const_tensors[input_indices[1]]
        try:
            folded = ufunc(a, b)
        except self._FOLD_ERRORS as e:
            print(f" Warning: failed to fold {op_name}: {e}")
            return False

        self._store_constant(model_info, output_indices[0], folded)
        self._log_change(
            f" {op_name} constant: {a.shape} {symbol} {b.shape} "
            f"-> {folded.shape}"
        )
        return True

    def _fold_mean(
        self, model_info: Dict[str, Any], op: Dict[str, Any]
    ) -> bool:
        """Fold MEAN if its input is constant."""
        io = self._first_io(op)
        if io is None:
            return False
        input_idx, output_idx = io
        if input_idx not in self._const_tensors:
            return False

        params = op.get("mean_params") or {}
        axis = params.get("axis", None)
        data = self._const_tensors[input_idx]
        try:
            keepdims = bool(params.get("keepdims", False))
            if axis is not None and np.ndim(axis) > 0:
                # numpy accepts only an int or a *tuple* of ints as axis;
                # a list/array of axes would raise TypeError.
                axis = tuple(int(a) for a in np.ravel(axis))
                if not axis:
                    # Empty axes: meaning differs between frontends
                    # (reduce-all vs no-op), so do not guess.
                    return False
            folded = np.mean(data, axis=axis, keepdims=keepdims)
        except self._FOLD_ERRORS as e:
            print(f" Warning: failed to fold mean: {e}")
            return False

        self._store_constant(model_info, output_idx, folded)
        self._log_change(
            f" Mean constant: {data.shape} -> {np.shape(folded)}"
        )
        return True

    def _prune_unused_tensors(self, model_info: Dict[str, Any]) -> None:
        """
        Remove tensors (and matching weights) no longer referenced.

        A tensor is kept when any op reads or writes it, or when a graph
        input/output entry names it via ``tensor_index``. The latter keeps a
        graph output alive even when the op producing it was folded away.
        """
        used_indices = set()
        for op in model_info.get("ops", []):
            used_indices.update(op.get("input_indices", []))
            used_indices.update(op.get("output_indices", []))

        for spec in list(model_info.get("input", [])) + list(
            model_info.get("output", [])
        ):
            idx = spec.get("tensor_index") if isinstance(spec, dict) else None
            if idx is not None:
                used_indices.add(idx)

        tensors = model_info.get("tensors", {})
        weights = model_info.get("weights", {})
        unused = set(tensors) - used_indices

        for idx in unused:
            tensors.pop(idx, None)
            weights.pop(idx, None)

        if unused:
            self._log_change(f"Removed {len(unused)} unused tensors")
@@ -0,0 +1,192 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # Common Subexpression Elimination.
21
+
22
+ from typing import Dict, Any
23
+ import hashlib
24
+ import json
25
+
26
+ from TinyMLC.transform.base import Pass
27
+
28
+
29
class CommonSubexpressionElimination(Pass):
    """
    Common Subexpression Elimination.

    Finds ops that compute the same value -- same ``op_name``, same input
    tensors (after earlier eliminations are applied) and same parameters --
    keeps the first one, and rewires every consumer of a later duplicate to
    read the first op's outputs instead.

    Strategy:
        1. Remap the op's inputs through the replacements made so far, so
           consumers of eliminated duplicates can themselves be matched.
        2. Compute a signature from every computation-relevant op field.
        3. If an earlier op has the same signature, map the duplicate's
           outputs onto that op's outputs and drop the duplicate.

    A duplicate is kept (not eliminated) when its output count differs from
    the original's, or when it produces a tensor named as a graph output via
    ``model_info["output"][i]["tensor_index"]``.
    """

    # Keys treated as metadata (ignored in signatures), both at op level and
    # inside parameter dicts. Keys starting with "_" are ignored as well.
    # NOTE(review): if ops carry a per-instance id field, add it here,
    # otherwise such ops are never considered duplicates.
    _SKIP_KEYS = frozenset({"name", "index", "state", "pass_flags"})

    def __init__(self, name: str = "CommonSubexpressionElimination"):
        super().__init__(name)
        self._signature_map: Dict[str, int] = {}  # signature -> position in kept ops
        self._replace_map: Dict[int, int] = {}  # eliminated tensor -> surviving tensor

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """Run CSE on a deep copy of model_info and return it."""
        model_info = self._copy_model(model_info)

        self._signature_map.clear()
        self._replace_map.clear()

        protected = self._graph_output_indices(model_info)
        new_ops = []
        eliminated_count = 0

        for op in model_info.get("ops", []):
            # Apply replacements first so a consumer of an eliminated
            # duplicate gets the same signature as the consumer of the
            # original (cascaded CSE).
            if self._replace_map and "input_indices" in op:
                op["input_indices"] = [
                    self._replace_map.get(idx, idx)
                    for idx in op["input_indices"]
                ]

            signature = self._compute_signature(op)
            orig_pos = self._signature_map.get(signature)

            if (
                orig_pos is not None
                and not protected.intersection(op.get("output_indices", []))
                and self._replace_outputs(op, new_ops[orig_pos])
            ):
                orig_op = new_ops[orig_pos]
                eliminated_count += 1
                self._log_change(
                    f" Eliminated duplicate {op.get('op_name')} "
                    f"(outputs: {op.get('output_indices')} "
                    f"-> {orig_op.get('output_indices')})"
                )
                continue  # duplicate is not kept

            if orig_pos is None:
                # First op with this signature becomes the canonical one.
                self._signature_map[signature] = len(new_ops)
            new_ops.append(op)

        if eliminated_count > 0:
            model_info["ops"] = new_ops
            # Update all tensor references in remaining ops
            self._update_tensor_refs(model_info)
            self._log_change(
                f"Eliminated {eliminated_count} duplicate expressions"
            )

        return model_info

    @staticmethod
    def _graph_output_indices(model_info: Dict[str, Any]) -> set:
        """Return tensor indices named as graph outputs via ``tensor_index``."""
        indices = set()
        for spec in model_info.get("output", []):
            if isinstance(spec, dict) and spec.get("tensor_index") is not None:
                indices.add(spec["tensor_index"])
        return indices

    @classmethod
    def _is_metadata_key(cls, key: Any) -> bool:
        """True for keys that do not affect what an op computes."""
        return isinstance(key, str) and (
            key in cls._SKIP_KEYS or key.startswith("_")
        )

    @staticmethod
    def _json_default(value: Any) -> Any:
        """json.dumps fallback: convert numpy-like values via ``tolist()``."""
        if hasattr(value, "tolist"):
            return value.tolist()
        raise TypeError(f"cannot serialize {type(value).__name__}")

    def _compute_signature(self, op: Dict[str, Any]) -> str:
        """
        Compute a signature identifying the computation an op performs.

        The signature covers every field of ``op`` except ``output_indices``
        and metadata keys. That includes:
            - op_name
            - input_indices, in order (ops like SUB are not commutative)
            - every parameter dict -- a generic ``params`` dict as well as
              per-op dicts such as ``reshape_params`` -- with metadata keys
              filtered out of those dicts too

        If the fields cannot be serialized, a signature unique to this op
        object is returned, so the op is never merged with another one.
        """
        sig: Dict[str, Any] = {}
        for key, value in op.items():
            if key == "output_indices" or self._is_metadata_key(key):
                continue
            is_param_dict = isinstance(key, str) and (
                key == "params" or key.endswith("_params")
            )
            if is_param_dict and isinstance(value, dict):
                value = {
                    k: v for k, v in value.items()
                    if not self._is_metadata_key(k)
                }
            sig[key] = value
        sig.setdefault("op_name", "UNKNOWN")
        sig.setdefault("input_indices", [])

        try:
            sig_str = json.dumps(sig, sort_keys=True, default=self._json_default)
        except (TypeError, ValueError):
            # Unserializable attributes: never treat this op as a duplicate.
            return f"unique:{id(op)}"
        return hashlib.sha256(sig_str.encode()).hexdigest()

    def _replace_outputs(
        self, dup_op: Dict[str, Any], orig_op: Dict[str, Any]
    ) -> bool:
        """
        Map outputs of dup_op to outputs of orig_op, position by position.

        Returns True when the mapping was recorded, False (with a logged
        warning) when the output counts differ -- the caller must then keep
        dup_op, since nothing else would produce its outputs.
        """
        dup_outputs = dup_op.get("output_indices", [])
        orig_outputs = orig_op.get("output_indices", [])

        if len(dup_outputs) != len(orig_outputs):
            self._log_change(
                f" Warning: output count mismatch "
                f"({len(dup_outputs)} vs {len(orig_outputs)})"
            )
            return False

        for dup_idx, orig_idx in zip(dup_outputs, orig_outputs):
            self._replace_map[dup_idx] = orig_idx
        return True

    def _update_tensor_refs(self, model_info: Dict[str, Any]) -> None:
        """
        Apply the replacement map to the remaining graph:
            - rewrite input/output indices of every op
            - delete replaced tensors from ``tensors`` and ``weights``
        """
        if not self._replace_map:
            return

        for op in model_info.get("ops", []):
            op["input_indices"] = [
                self._replace_map.get(idx, idx)
                for idx in op.get("input_indices", [])
            ]
            op["output_indices"] = [
                self._replace_map.get(idx, idx)
                for idx in op.get("output_indices", [])
            ]

        tensors = model_info.get("tensors", {})
        weights = model_info.get("weights", {})
        for old_idx in self._replace_map:
            tensors.pop(old_idx, None)
            weights.pop(old_idx, None)

        # Graph output specs need no rewrite: ops producing graph outputs
        # are never eliminated (see run).

        self._log_change(
            f" Replaced {len(self._replace_map)} tensor references"
        )
@@ -0,0 +1,182 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # Dead Code Elimination pass.
21
+
22
+ from typing import Dict, Any, Set, List
23
+ from TinyMLC.transform.base import Pass
24
+
25
+
26
class DeadCodeElimination(Pass):
    """
    Dead Code Elimination.

    Removes:
        - Tensors (entries of ``model_info["tensors"]`` and the matching
          ``weights`` entry) that no op reads or writes and that no graph
          input/output names via ``tensor_index``.
        - Ops from which no graph output can be reached. This requires every
          entry of ``model_info["output"]`` to carry a ``tensor_index``;
          otherwise the graph outputs are unknown and no op is removed.

    Unreachable-op analysis for control flow is not implemented.

    This pass should be run after each pass that may create dead code.
    """

    def __init__(self, name: str = "DeadCodeElimination"):
        super().__init__(name)

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """Run dead code elimination on a deep copy of model_info and return it."""
        model_info = self._copy_model(model_info)

        changed = True
        iteration = 0

        # Repeat until a fixed point: removing ops can orphan tensors.
        while changed:
            changed = False
            iteration += 1

            # 1. Find all referenced tensor indices
            used_indices = self._collect_used_indices(model_info)

            # 2. Remove dead tensors
            dead_tensors = self._remove_dead_tensors(model_info, used_indices)
            if dead_tensors:
                changed = True
                self._log_change(
                    f"Iteration {iteration}: removed {len(dead_tensors)} "
                    f"dead tensors"
                )

            # 3. Remove dead ops
            dead_ops = self._remove_dead_ops(model_info, used_indices)
            if dead_ops:
                changed = True
                self._log_change(
                    f"Iteration {iteration}: removed {len(dead_ops)} dead ops"
                )

        return model_info

    def _collect_used_indices(self, model_info: Dict[str, Any]) -> Set[int]:
        """
        Collect every tensor index referenced by the graph: the
        ``tensor_index`` of graph inputs/outputs plus all op inputs and
        outputs.
        """
        used = set()

        for inp in model_info.get("input", []):
            idx = inp.get("tensor_index")
            if idx is not None:
                used.add(idx)

        for out in model_info.get("output", []):
            idx = out.get("tensor_index")
            if idx is not None:
                used.add(idx)

        for op in model_info.get("ops", []):
            used.update(op.get("input_indices", []))
            used.update(op.get("output_indices", []))

        return used

    def _remove_dead_tensors(
        self,
        model_info: Dict[str, Any],
        used_indices: Set[int]
    ) -> Set[int]:
        """
        Delete tensors whose index is not in ``used_indices``.

        Graph outputs with a ``tensor_index`` are part of ``used_indices``
        (see _collect_used_indices) and are therefore preserved. The matching
        ``weights`` entry is deleted as well. Returns the removed indices.
        """
        tensors = model_info.get("tensors", {})
        weights = model_info.get("weights", {})
        dead = set(tensors) - used_indices

        for idx in dead:
            tensors.pop(idx, None)
            weights.pop(idx, None)

        return dead

    @staticmethod
    def _graph_output_indices(model_info: Dict[str, Any]):
        """
        Return the set of graph-output tensor indices, or None if unknown.

        Indices come from the ``tensor_index`` field of each entry in
        ``model_info["output"]``. None is returned when there are no output
        entries or any entry lacks a ``tensor_index``, since op liveness
        cannot then be decided safely.
        """
        outputs = model_info.get("output", [])
        if not outputs:
            return None
        indices = set()
        for out in outputs:
            idx = out.get("tensor_index") if isinstance(out, dict) else None
            if idx is None:
                return None
            indices.add(idx)
        return indices

    def _remove_dead_ops(
        self,
        model_info: Dict[str, Any],
        used_indices: Set[int]
    ) -> List[Dict[str, Any]]:
        """
        Remove ops from which no graph output is reachable.

        Liveness is propagated backwards from the graph-output tensors: an op
        is live if it produces a live tensor, and every input of a live op is
        live. The position of ops in ``model_info["ops"]`` does not matter.
        If graph outputs are unknown, nothing is removed.

        ``used_indices`` is accepted for interface compatibility but cannot
        decide liveness: it also contains every op's own outputs.

        Returns the list of removed ops.
        """
        ops = model_info.get("ops", [])
        graph_outputs = self._graph_output_indices(model_info)
        if graph_outputs is None:
            return []

        # tensor index -> positions of the ops that write it
        producers: Dict[Any, List[int]] = {}
        for pos, op in enumerate(ops):
            for idx in op.get("output_indices", []):
                producers.setdefault(idx, []).append(pos)

        live_positions: Set[int] = set()
        visited = set()
        pending = list(graph_outputs)
        while pending:
            idx = pending.pop()
            if idx in visited:
                continue
            visited.add(idx)
            for pos in producers.get(idx, []):
                if pos not in live_positions:
                    live_positions.add(pos)
                    pending.extend(ops[pos].get("input_indices", []))

        dead_ops = [op for pos, op in enumerate(ops) if pos not in live_positions]
        if dead_ops:
            # Tensors only produced by dead ops are removed by
            # _remove_dead_tensors in the next iteration of run().
            model_info["ops"] = [
                op for pos, op in enumerate(ops) if pos in live_positions
            ]

        return dead_ops