tinymlc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. TinyMLC/ANG/__init__.py +0 -0
  2. TinyMLC/ANG/args.py +86 -0
  3. TinyMLC/ANG/estimator.py +103 -0
  4. TinyMLC/ANG/estimator_hal.py +184 -0
  5. TinyMLC/ANG/estimator_qemu.py +257 -0
  6. TinyMLC/ANG/estimator_software.py +130 -0
  7. TinyMLC/ANG/model_builder.py +508 -0
  8. TinyMLC/ANG/model_generator.py +439 -0
  9. TinyMLC/ANG/model_info.py +283 -0
  10. TinyMLC/ANG/utils.py +420 -0
  11. TinyMLC/__init__.py +0 -0
  12. TinyMLC/cli.py +126 -0
  13. TinyMLC/codegen.py +877 -0
  14. TinyMLC/converter/__init__.py +0 -0
  15. TinyMLC/converter/export_weights.py +382 -0
  16. TinyMLC/converter/parser_litert.py +757 -0
  17. TinyMLC/converter/parser_onnx.py +649 -0
  18. TinyMLC/generate_lut.py +97 -0
  19. TinyMLC/handlers.py +325 -0
  20. TinyMLC/ops.py +76 -0
  21. TinyMLC/templates/lut.c.tpl +23 -0
  22. TinyMLC/templates/lut.h.tpl +67 -0
  23. TinyMLC/templates/model.c.tpl +314 -0
  24. TinyMLC/templates/model.h.tpl +66 -0
  25. TinyMLC/transform/__init__.py +0 -0
  26. TinyMLC/transform/algebraic.py +286 -0
  27. TinyMLC/transform/base.py +58 -0
  28. TinyMLC/transform/constant_folding.py +260 -0
  29. TinyMLC/transform/cse.py +192 -0
  30. TinyMLC/transform/dce.py +182 -0
  31. TinyMLC/transform/fusion.py +723 -0
  32. TinyMLC/transform/memory.py +200 -0
  33. TinyMLC/transform/pass_manager.py +101 -0
  34. TinyMLC/transform/simplify.py +515 -0
  35. tinymlc-0.1.0.dist-info/METADATA +49 -0
  36. tinymlc-0.1.0.dist-info/RECORD +47 -0
  37. tinymlc-0.1.0.dist-info/WHEEL +4 -0
  38. tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
  39. tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
  40. utils/__init__.py +0 -0
  41. utils/arm-none-eabi-gcc.cmake +53 -0
  42. utils/dump.py +86 -0
  43. utils/generate_onnx_models.py +183 -0
  44. utils/generate_tflite_models.py +236 -0
  45. utils/pack_macos.sh +88 -0
  46. utils/path.py +31 -0
  47. utils/riscv-none-elf-gcc.cmake +50 -0
@@ -0,0 +1,314 @@
1
+ /* TinyMLC - Tiny Machine Learning Compiler
2
+ *
3
+ * Copyright (c) 2026 Jia Liu & TinyMLC Contributors
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ *
6
+ * This file is part of TinyMLC.
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at:
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+
20
+ // Auto-generated code, do not modify manually
21
+ // Generated by tinymlc
22
+
23
+ #include "tinymlc.h"
24
+ {{ includes }}
25
+ #include "model.h"
26
+ #include "debug_print.h"
27
+ #include <stddef.h>
28
+
29
// Input/output sizes
{# NOTE(review): model.h (included above) also defines INPUT_SIZE and
   OUTPUT_SIZE. C accepts the redefinition only if both render to identical
   replacement text — confirm both templates receive the same values. #}
#define INPUT_SIZE {{ input_size }}
#define OUTPUT_SIZE {{ output_size }}

// Intermediate tensor memory (static allocation, placed outside function)
{# One file-scope buffer per entry in tensors_to_define, named
   tensor_<index>. For any target other than "host" the buffer is explicitly
   placed in the .bss section via a GCC-style attribute. #}
{% for tensor in tensors_to_define %}
{% if target == "host" %}
{{ tensor.type }} tensor_{{ tensor.index }}[{{ tensor.size }}];
{% else %}
{{ tensor.type }} tensor_{{ tensor.index }}[{{ tensor.size }}] \
__attribute__((section(".bss")));
{% endif %}
{% endfor %}

{% if has_lstm %}
// LSTM parameters
{# NOTE(review): TINYMLC_HAS_LSTM is defined after tinymlc.h, the extra
   includes and model.h have already been included, so it cannot influence
   those headers — confirm this is intended. #}
#define LSTM_TIME_STEPS {{ lstm_time_steps }}
#define LSTM_HIDDEN_SIZE {{ lstm_hidden_size }}
#define TINYMLC_HAS_LSTM
{% endif %}
49
+
50
// Inference function
{# Only 1- and 2-input signatures exist (model.h.tpl declares the same two).
   Any other inputs_count aborts compilation with a clear message instead of
   emitting operator code that has no enclosing function header. #}
{% if inputs_count == 1 %}
void {{ inference_func }}(const int8_t* input, int8_t* output) {
    // Input tensor mapping
    {# NOTE(review): the cast discards const; confirm no generated op writes
       into an input tensor. #}
    int8_t* tensor_{{ input_tensor_indices[0] }} = (int8_t*)input;
{% elif inputs_count == 2 %}
void {{ inference_func }}(
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output
) {
    // Input tensor mapping
    int8_t* tensor_{{ input_tensor_indices[0] }} = (int8_t*)input1;
    int8_t* tensor_{{ input_tensor_indices[1] }} = (int8_t*)input2;
{% else %}
#error "tinymlc: model.c.tpl only supports models with 1 or 2 inputs (got {{ inputs_count }})"
{% endif %}
    // Execute operators in order
    {# An op whose op_name matches none of the branches below emits no code. #}
{% for op in execution_order %}
{% if op.op_name == "UNIDIRECTIONAL_SEQUENCE_LSTM" %}
    tmlc_unidirectional_sequence_lstm_s8(
        tensor_{{ op.input_indices[0] }},
        lstm_input_weights,
        lstm_recurrent_weights,
        lstm_bias,
        tensor_{{ op.output_indices[0] }},
        NULL, NULL,
        {{ op.lstm_params.time_steps }},
        {{ op.lstm_params.batch_size }},
        {{ op.lstm_params.input_size }},
        {{ op.lstm_params.hidden_size }}
    );
{% elif op.op_name == "FULLY_CONNECTED" %}
    tmlc_fully_connected_s8(
        tensor_{{ op.input_indices[0] }},
        fc_weights,
        fc_bias,
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.input_indices[0]] }},
        {{ tensor_sizes[op.output_indices[0]] }},
        {{ fc_multiplier }},
        {{ fc_shift }}
    );
{% elif op.op_name == "SOFTMAX" %}
    tmlc_softmax_s8(
        tensor_{{ op.input_indices[0] }},
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.input_indices[0]] }}
    );
{% elif op.op_name == "RESHAPE" %}
    {
        static const int reshape_target[] = {
{% for s in op.reshape_target_shape %}
            {{ s }}{% if not loop.last %},{% endif %}
{% endfor %}
        };
        int reshape_input_size = {{ tensor_sizes[op.input_indices[0]] }};
        tmlc_reshape_s8(
            tensor_{{ op.input_indices[0] }},
            tensor_{{ op.output_indices[0] }},
            reshape_input_size,
            reshape_target,
            {{ op.reshape_target_shape | length }}
        );
    }
{% elif op.op_name == "ADD" %}
    tmlc_add_s8(
        tensor_{{ op.input_indices[0] }},
        tensor_{{ op.input_indices[1] }},
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.output_indices[0]] }}
    );
{% elif op.op_name == "SVDF" %}
    tmlc_svdf_s8(
        tensor_{{ op.data_input_idx }},
        svdf_weights,
        svdf_bias,
        tensor_{{ op.output_indices[0] }},
        {{ op.svdf_params.time_steps }},
        {{ op.svdf_params.input_size }},
        {{ op.svdf_params.rank }},
        {{ op.svdf_params.units }}
    );
{% elif op.op_name == "CONV_2D" %}
    tmlc_conv2d_s8(
        tensor_{{ op.data_input_idx }},
        conv_weights,
        conv_bias,
        tensor_{{ op.output_indices[0] }},
        {{ op.conv_params.input_h }},
        {{ op.conv_params.input_w }},
        {{ op.conv_params.input_c }},
        {{ op.conv_params.output_h }},
        {{ op.conv_params.output_w }},
        {{ op.conv_params.output_c }},
        {{ op.conv_params.kernel_h }},
        {{ op.conv_params.kernel_w }},
        {{ op.conv_params.stride_h }},
        {{ op.conv_params.stride_w }},
        {{ op.conv_params.padding_h }},
        {{ op.conv_params.padding_w }},
        {{ conv_multiplier }},
        {{ conv_shift }}
    );
{% elif op.op_name == "MAX_POOL_2D" %}
    tmlc_max_pool_2d_s8(
        tensor_{{ op.data_input_idx }},
        tensor_{{ op.output_indices[0] }},
        {{ op.pool_params.input_h }},
        {{ op.pool_params.input_w }},
        {{ op.pool_params.input_c }},
        {{ op.pool_params.output_h }},
        {{ op.pool_params.output_w }},
        {{ op.pool_params.output_c }},
        {{ op.pool_params.pool_size_h }},
        {{ op.pool_params.pool_size_w }},
        {{ op.pool_params.stride_h }},
        {{ op.pool_params.stride_w }},
        0, 0
    );
{% elif op.op_name == "DEPTHWISE_CONV_2D" %}
    tmlc_depthwise_conv_2d_s8(
        tensor_{{ op.data_input_idx }},
        dw_weights,
        dw_bias,
        tensor_{{ op.output_indices[0] }},
        {{ op.dw_params.input_h }},
        {{ op.dw_params.input_w }},
        {{ op.dw_params.input_c }},
        {{ op.dw_params.output_h }},
        {{ op.dw_params.output_w }},
        {{ op.dw_params.output_c }},
        {{ op.dw_params.kernel_h }},
        {{ op.dw_params.kernel_w }},
        {{ op.dw_params.stride_h }},
        {{ op.dw_params.stride_w }},
        {{ op.dw_params.depth_multiplier }},
        {{ op.dw_params.padding_h }},
        {{ op.dw_params.padding_w }},
        {{ conv_multiplier }},
        {{ conv_shift }}
    );
{% elif op.op_name == "RELU" %}
    tmlc_relu_s8(
        tensor_{{ op.input_indices[0] }},
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.input_indices[0]] }}
    );
{% elif op.op_name == "AVERAGE_POOL_2D" %}
    tmlc_avg_pool_2d_s8(
        tensor_{{ op.data_input_idx }},
        tensor_{{ op.output_indices[0] }},
        {{ op.pool_params.input_h }},
        {{ op.pool_params.input_w }},
        {{ op.pool_params.input_c }},
        {{ op.pool_params.output_h }},
        {{ op.pool_params.output_w }},
        {{ op.pool_params.output_c }},
        {{ op.pool_params.pool_size_h }},
        {{ op.pool_params.pool_size_w }},
        {{ op.pool_params.stride_h }},
        {{ op.pool_params.stride_w }},
        0, 0
    );
{% elif op.op_name == "TRANSPOSE" %}
    tmlc_transpose_s8(
        tensor_{{ op.data_input_idx }},
        NULL,
        tensor_{{ op.output_indices[0] }},
        {{ op.transpose_params.input_dims }},
        (const int[]){
{% for s in tensor_shapes[op.data_input_idx] %}
            {{ s }}{% if not loop.last %},{% endif %}
{% endfor %}
        }
    );
{% elif op.op_name == "PAD" %}
    {# The rank argument is taken from the input shape list so it always
       equals the number of elements in the shape arrays passed with it. #}
    tmlc_pad_s8(
        tensor_{{ op.data_input_idx }},
        NULL, // paddings, currently NULL
        tensor_{{ op.output_indices[0] }},
        {{ tensor_shapes[op.data_input_idx] | length }},
        (const int[]){
{% for s in tensor_shapes[op.data_input_idx] %}
            {{ s }}{% if not loop.last %},{% endif %}
{% endfor %}
        },
        (const int[]){
{% for s in tensor_shapes[op.output_indices[0]] %}
            {{ s }}{% if not loop.last %},{% endif %}
{% endfor %}
        }
    );
{% elif op.op_name == "MEAN" %}
    tmlc_mean_s8(
        tensor_{{ op.data_input_idx }},
        tensor_{{ op.output_indices[0] }},
        {{ op.mean_params.input_dims }},
        (const int[]){
{% for s in tensor_shapes[op.data_input_idx] %}
            {{ s }}{% if not loop.last %},{% endif %}
{% endfor %}
        },
        (const int[]){
{% for s in tensor_shapes[op.output_indices[0]] %}
            {{ s }}{% if not loop.last %},{% endif %}
{% endfor %}
        },
        NULL,
        0,
        0
    );
{% elif op.op_name == "MULTIPLY" %}
    tmlc_multiply_s8(
        tensor_{{ op.input_indices[0] }},
        tensor_{{ op.input_indices[1] }},
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.output_indices[0]] }}
    );
{% elif op.op_name == "SIGMOID" %}
    tmlc_sigmoid_s8(
        tensor_{{ op.input_indices[0] }},
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.input_indices[0]] }}
    );
{% elif op.op_name == "CONCAT" %}
    {
        {# Deliberately NOT static: an input may be the function-local input
           pointer mapped at the top of this function. That pointer is not a
           constant expression, so it cannot initialize an object with static
           storage duration. #}
        const int8_t* concat_inputs[] = {
{% for idx in op.input_indices %}
            tensor_{{ idx }}{% if not loop.last %},{% endif %}
{% endfor %}
        };
        static const int concat_sizes[] = {
{% for idx in op.input_indices %}
            {{ tensor_sizes[idx] }}{% if not loop.last %},{% endif %}
{% endfor %}
        };
        tmlc_concat_s8(
            concat_inputs,
            concat_sizes,
            {{ op.input_indices | length }},
            tensor_{{ op.output_indices[0] }}
        );
    }
{% elif op.op_name == "SUB" %}
    tmlc_sub_s8(
        tensor_{{ op.input_indices[0] }},
        tensor_{{ op.input_indices[1] }},
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.output_indices[0]] }}
    );
{% elif op.op_name == "TANH" %}
    tmlc_tanh_s8(
        tensor_{{ op.input_indices[0] }},
        tensor_{{ op.output_indices[0] }},
        {{ tensor_sizes[op.input_indices[0]] }}
    );
{% endif %}
{% endfor %}

    // Output tensor mapping
    if (output != NULL) {
        for (int i = 0; i < OUTPUT_SIZE; i++) {
            output[i] = tensor_{{ last_output_tensor }}[i];
        }
    }
}
@@ -0,0 +1,66 @@
1
+ /* TinyMLC - Tiny Machine Learning Compiler
2
+ *
3
+ * Copyright (c) 2026 Jia Liu & TinyMLC Contributors
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ *
6
+ * This file is part of TinyMLC.
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at:
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+
20
+ // Auto-generated code, do not modify manually
21
+ // Generated by tinymlc
22
+
23
#ifndef TINYMLC_MODEL_H
#define TINYMLC_MODEL_H

#include <stdint.h>

#define INPUT_SIZE {{ input_size }}
#define OUTPUT_SIZE {{ output_size }}

{# Per-input sizes. These come from the upper-case template variables
   INPUT_SIZE_1 / INPUT_SIZE_2, not from input_size. #}
{% if inputs_count >= 1 %}
#define INPUT_SIZE_1 {{ INPUT_SIZE_1 }}
{% endif %}
{% if inputs_count >= 2 %}
#define INPUT_SIZE_2 {{ INPUT_SIZE_2 }}
{% endif %}

{# Emitted only for target == "arm". The FC input/output offsets are fixed at
   0 and SOFTMAX_DIFF_MIN at -128; only the multipliers/shifts come from the
   template context. #}
{% if target == "arm" %}
// FC quantization parameters
#define FC_INPUT_OFFSET 0
#define FC_OUTPUT_OFFSET 0
#define FC_MULTIPLIER {{ fc_multiplier }}
#define FC_SHIFT {{ fc_shift }}

// Softmax quantization parameters
#define SOFTMAX_MULTIPLIER {{ softmax_multiplier }}
#define SOFTMAX_SHIFT {{ softmax_shift }}
#define SOFTMAX_DIFF_MIN -128
{% endif %}

{# NOTE(review): these four macros are emitted unconditionally (not guarded
   by has_lstm), so lstm_shifts presumably must be supplied with at least 4
   entries even for models without an LSTM — confirm against the caller. #}
// LSTM right shift bits (calculated from model quantization parameters)
#define LSTM_SHIFT_I {{ lstm_shifts[0] }}
#define LSTM_SHIFT_F {{ lstm_shifts[1] }}
#define LSTM_SHIFT_G {{ lstm_shifts[2] }}
#define LSTM_SHIFT_O {{ lstm_shifts[3] }}

{# Only 1- and 2-input signatures are declared, matching the definitions in
   model.c.tpl; any other inputs_count declares no inference function. #}
// Inference function declaration
{% if inputs_count == 1 %}
void {{ inference_func }}(const int8_t* input, int8_t* output);
{% elif inputs_count == 2 %}
void {{ inference_func }}(const int8_t* input1,
const int8_t* input2,
int8_t* output);
{% endif %}

#endif // TINYMLC_MODEL_H
File without changes
@@ -0,0 +1,286 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
# Algebraic simplification: rewrites arithmetic expressions into simpler
# equivalent forms.
#
# Patterns currently rewritten:
#   1. SUB(x, x) -> 0
#   2. ADD(a, b) -> 0, when a and b are both scalar constants and b == -a
#
# Patterns considered but not yet rewritten by this pass (left to constant
# folding or future work):
#   3. ADD(x, constant) -> x + constant (evaluate if possible)
#   4. SUB(x, constant) -> x - constant
#   5. ADD(x, ADD(y, z)) -> ADD(ADD(x, y), z) (associativity)
#   6. ADD(x, y) -> SUB(x, -y) (if y is a negative constant)
30
+ import numpy as np
31
+
32
+ from typing import Dict, Any, List, Optional
33
+ from TinyMLC.transform.base import Pass
34
+
35
+
36
class AlgebraicSimplify(Pass):
    """
    Algebraic simplification for arithmetic operations.

    Runs after Simplify to catch patterns that require numeric analysis.

    Rewrites applied by ``run`` (repeated until nothing changes):

    * ``SUB(x, x)`` -> constant zero.
    * ``ADD(a, b)`` -> constant zero when ``a`` and ``b`` are both scalar
      constants and ``b == -a``.

    ``ADD``/``SUB`` with a single scalar-constant operand are intentionally
    left unchanged (see ``_simplify_add_constant`` and
    ``_simplify_sub_constant``).
    """

    def __init__(self, name: str = "AlgebraicSimplify"):
        super().__init__(name)
        # Number of rewrites performed by the most recent ``run`` call.
        self._simplified_count = 0

    def run(self, model_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply the algebraic rewrites until a fixed point is reached.

        Args:
            model_info: Model description. This pass reads ``"ops"`` (a list
                of op dicts with ``"op_name"``, ``"input_indices"`` and
                ``"output_indices"``) and updates ``"tensors"`` and
                ``"weights"`` (dicts keyed by tensor index).

        Returns:
            The object returned by ``self._copy_model(model_info)``, after
            rewriting. NOTE(review): whether the caller's dict stays
            untouched depends on ``_copy_model`` producing a deep copy.
        """
        model_info = self._copy_model(model_info)
        self._simplified_count = 0

        changed = True
        iteration = 0

        # Every successful rewrite deletes one op, so this loop terminates.
        while changed:
            changed = False
            iteration += 1

            # 1. SUB(x, x) -> 0
            changed |= self._simplify_sub_same(model_info)

            # 2. ADD(a, -a) -> 0 (both operands scalar constants)
            changed |= self._simplify_add_neg(model_info)

            # 3. ADD(x, constant): currently a no-op placeholder
            changed |= self._simplify_add_constant(model_info)

            # 4. SUB(x, constant): currently a no-op placeholder
            changed |= self._simplify_sub_constant(model_info)

            if changed:
                self._log_change(f"Iteration {iteration}: applied")

        if self._simplified_count > 0:
            self._log_change(
                f"Total: {self._simplified_count} algebraic simplifications"
            )

        return model_info

    # ================================================================
    # 1. SUB(x, x) -> 0 (forward constant zero)
    # ================================================================

    def _simplify_sub_same(self, model_info: Dict[str, Any]) -> bool:
        """
        Rewrite ``SUB(x, x)`` to a constant zero.

        Consumers of the SUB output are redirected to a new zero constant
        shaped like that output, and the SUB op is deleted. SUB ops without
        an output are skipped.

        Returns:
            True if at least one op was rewritten.
        """
        ops = model_info.get("ops", [])
        changed = False
        i = 0

        while i < len(ops):
            op = ops[i]
            if op.get("op_name") == "SUB":
                input_indices = op.get("input_indices", [])
                output_idx = self._output_index(op)
                if (len(input_indices) == 2
                        and input_indices[0] == input_indices[1]
                        and output_idx is not None):
                    zero_idx = self._create_constant_zero(
                        model_info, like_idx=output_idx
                    )
                    self._remove_op_and_forward(ops, i, zero_idx, model_info)
                    changed = True
                    self._simplified_count += 1
                    self._log_change(" SUB(x, x) -> 0")
                    # ops[i] now holds the following op; re-examine index i.
                    continue

            i += 1

        if changed:
            model_info["ops"] = ops

        return changed

    # ================================================================
    # 2. ADD(x, -x) -> 0 (if -x is a constant)
    # ================================================================

    def _simplify_add_neg(self, model_info: Dict[str, Any]) -> bool:
        """
        Rewrite ``ADD(a, b)`` to a constant zero when ``b == -a``.

        Both operands must be single-element constants (see
        ``_get_constant_value``). ADD ops without an output are skipped.

        NOTE(review): the comparison uses the stored constant values
        directly and ignores each tensor's ``scale``/``zero_point``; for
        quantized constants, raw values being negations does not imply the
        real values cancel — confirm constants are stored dequantized.

        Returns:
            True if at least one op was rewritten.
        """
        ops = model_info.get("ops", [])
        changed = False
        i = 0

        while i < len(ops):
            op = ops[i]
            if op.get("op_name") == "ADD":
                input_indices = op.get("input_indices", [])
                output_idx = self._output_index(op)
                if len(input_indices) == 2 and output_idx is not None:
                    a, b = input_indices[0], input_indices[1]
                    const_val = self._get_constant_value(model_info, a)
                    const_b = (
                        None if const_val is None
                        else self._get_constant_value(model_info, b)
                    )
                    if const_b is not None and const_b == -const_val:
                        zero_idx = self._create_constant_zero(
                            model_info, like_idx=output_idx
                        )
                        self._remove_op_and_forward(
                            ops, i, zero_idx, model_info
                        )
                        changed = True
                        self._simplified_count += 1
                        self._log_change(
                            f" ADD({const_val}, {-const_val}) -> 0"
                        )
                        # ops[i] now holds the following op.
                        continue

            i += 1

        if changed:
            model_info["ops"] = ops

        return changed

    # ================================================================
    # 3. ADD(x, constant) with constant that can be folded
    # ================================================================

    def _simplify_add_constant(self, model_info: Dict[str, Any]) -> bool:
        """
        Placeholder for ``ADD(x, scalar_constant)`` simplification.

        Not implemented: such ops are left unchanged so that the constant
        folding pass can handle them. Never modifies ``model_info``.

        Returns:
            Always False.
        """
        return False

    # ================================================================
    # 4. SUB(x, constant) -> x - constant (if constant is a scalar)
    # ================================================================

    def _simplify_sub_constant(self, model_info: Dict[str, Any]) -> bool:
        """
        Placeholder for ``SUB(x, scalar_constant)`` simplification.

        Not implemented: such ops are left unchanged so that the constant
        folding pass can handle them. Never modifies ``model_info``.

        Returns:
            Always False.
        """
        return False

    # ================================================================
    # Helper functions
    # ================================================================

    @staticmethod
    def _output_index(op: Dict[str, Any]) -> Optional[int]:
        """Return the op's first output tensor index, or None if it has none."""
        outputs = op.get("output_indices")
        if outputs is None or len(outputs) == 0:
            return None
        return outputs[0]

    def _create_constant_zero(
        self,
        model_info: Dict[str, Any],
        like_idx: Optional[int] = None,
    ) -> int:
        """
        Create a new int8 constant tensor filled with zeros.

        Args:
            model_info: Model description. Its ``"tensors"`` and
                ``"weights"`` dicts are created if missing, so the new
                tensor is always recorded in ``model_info``.
            like_idx: Optional index of an existing tensor whose ``"shape"``
                the zero tensor copies, so it can stand in for that tensor.
                Falls back to shape ``[1]`` when omitted, or when that shape
                is missing or not a list/tuple of positive integers.

        Returns:
            The new tensor index: one past the largest index used in either
            ``"tensors"`` or ``"weights"`` (1 when both are empty).

        NOTE(review): the zero tensor always gets ``scale`` 1.0 and
        ``zero_point`` 0 rather than the quantization parameters of the
        tensor it replaces — confirm downstream kernels read it correctly.
        """
        tensors = model_info.setdefault("tensors", {})
        weights = model_info.setdefault("weights", {})
        # Consider both dicts so the new index cannot collide with an
        # existing weight-only entry.
        new_idx = max([*tensors, *weights], default=0) + 1

        shape: List[int] = [1]
        like = tensors.get(like_idx) if like_idx is not None else None
        like_shape = like.get("shape") if isinstance(like, dict) else None
        if isinstance(like_shape, (list, tuple)) and len(like_shape) > 0:
            if all(
                isinstance(d, (int, np.integer)) and d > 0
                for d in like_shape
            ):
                shape = [int(d) for d in like_shape]

        tensors[new_idx] = {
            "name": f"zero_{new_idx}",
            "shape": shape,
            "dtype": "int8",
            "scale": 1.0,
            "zero_point": 0,
        }
        weights[new_idx] = np.zeros(shape, dtype=np.int8)

        return new_idx

    def _get_constant_value(
        self, model_info: Dict, idx: int
    ) -> Optional[float]:
        """
        Return the value of a single-element constant tensor, else None.

        A tensor counts as constant when ``idx`` is a key of
        ``model_info["weights"]``. Arrays with exactly one element yield a
        native Python scalar (int for integer dtypes, float for floating
        dtypes); one-element lists yield their sole item.
        """
        weights = model_info.get("weights", {})
        if idx not in weights:
            return None
        weight = weights[idx]
        if hasattr(weight, "size") and weight.size == 1:
            # Do not coerce with int(): that would truncate e.g. 0.5 to 0 and
            # make unrelated float constants look like negations.
            return weight.item()
        if isinstance(weight, list) and len(weight) == 1:
            return weight[0]
        return None

    def _is_scalar_constant(self, model_info: Dict, idx: int) -> bool:
        """Return True if the tensor at ``idx`` is a single-element constant."""
        return self._get_constant_value(model_info, idx) is not None

    def _remove_op_and_forward(
        self,
        ops: List[Dict],
        op_idx: int,
        forward_idx: int,
        model_info: Dict
    ) -> None:
        """
        Delete ``ops[op_idx]`` and redirect consumers of its output.

        Every op input that referenced the removed op's first output is
        rewritten to ``forward_idx``. The output's entries in
        ``model_info["tensors"]`` and ``model_info["weights"]`` are then
        dropped. An op with no outputs is simply deleted.

        NOTE(review): only op ``input_indices`` are rewritten; if the removed
        output is also referenced as a model-level output elsewhere in
        ``model_info``, that reference is not updated here — confirm.
        """
        output_idx = self._output_index(ops[op_idx])

        if output_idx is not None:
            for other_op in ops:
                inputs = other_op.get("input_indices")
                if not inputs:
                    continue
                for pos, idx in enumerate(inputs):
                    if idx == output_idx:
                        inputs[pos] = forward_idx

        del ops[op_idx]

        if output_idx is None:
            return
        model_info.get("tensors", {}).pop(output_idx, None)
        model_info.get("weights", {}).pop(output_idx, None)