tinymlc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. TinyMLC/ANG/__init__.py +0 -0
  2. TinyMLC/ANG/args.py +86 -0
  3. TinyMLC/ANG/estimator.py +103 -0
  4. TinyMLC/ANG/estimator_hal.py +184 -0
  5. TinyMLC/ANG/estimator_qemu.py +257 -0
  6. TinyMLC/ANG/estimator_software.py +130 -0
  7. TinyMLC/ANG/model_builder.py +508 -0
  8. TinyMLC/ANG/model_generator.py +439 -0
  9. TinyMLC/ANG/model_info.py +283 -0
  10. TinyMLC/ANG/utils.py +420 -0
  11. TinyMLC/__init__.py +0 -0
  12. TinyMLC/cli.py +126 -0
  13. TinyMLC/codegen.py +877 -0
  14. TinyMLC/converter/__init__.py +0 -0
  15. TinyMLC/converter/export_weights.py +382 -0
  16. TinyMLC/converter/parser_litert.py +757 -0
  17. TinyMLC/converter/parser_onnx.py +649 -0
  18. TinyMLC/generate_lut.py +97 -0
  19. TinyMLC/handlers.py +325 -0
  20. TinyMLC/ops.py +76 -0
  21. TinyMLC/templates/lut.c.tpl +23 -0
  22. TinyMLC/templates/lut.h.tpl +67 -0
  23. TinyMLC/templates/model.c.tpl +314 -0
  24. TinyMLC/templates/model.h.tpl +66 -0
  25. TinyMLC/transform/__init__.py +0 -0
  26. TinyMLC/transform/algebraic.py +286 -0
  27. TinyMLC/transform/base.py +58 -0
  28. TinyMLC/transform/constant_folding.py +260 -0
  29. TinyMLC/transform/cse.py +192 -0
  30. TinyMLC/transform/dce.py +182 -0
  31. TinyMLC/transform/fusion.py +723 -0
  32. TinyMLC/transform/memory.py +200 -0
  33. TinyMLC/transform/pass_manager.py +101 -0
  34. TinyMLC/transform/simplify.py +515 -0
  35. tinymlc-0.1.0.dist-info/METADATA +49 -0
  36. tinymlc-0.1.0.dist-info/RECORD +47 -0
  37. tinymlc-0.1.0.dist-info/WHEEL +4 -0
  38. tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
  39. tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
  40. utils/__init__.py +0 -0
  41. utils/arm-none-eabi-gcc.cmake +53 -0
  42. utils/dump.py +86 -0
  43. utils/generate_onnx_models.py +183 -0
  44. utils/generate_tflite_models.py +236 -0
  45. utils/pack_macos.sh +88 -0
  46. utils/path.py +31 -0
  47. utils/riscv-none-elf-gcc.cmake +50 -0
TinyMLC/codegen.py ADDED
@@ -0,0 +1,877 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+
21
+ import stat
22
+ import numpy as np
23
+ import shutil
24
+
25
+ from pathlib import Path
26
+ from jinja2 import Template
27
+ from typing import Dict, Any, List, Optional
28
+
29
+ from TinyMLC.ops import SUPPORTED_OPS
30
+ from utils.dump import fatal_error, info
31
+
32
+
33
# Fallback quantization parameters, used whenever a valid scale cannot be
# read from the model. Both numbers are empirical choices.
DEFAULT_SCALE: float = 1e-2
DEFAULT_SHIFT: int = 8
36
+
37
+
38
def build_execution_order(ops, tensors):
    """Determine operator execution order based on tensor dependencies.

    An operator depends on every *other* operator that produces one of its
    input tensors (an op consuming its own output is not a dependency).
    Operators are ordered with Kahn's topological sort in O(V + E); ties are
    broken by the operators' position in ``ops``.

    Side effect: for every op that has an assigned ``index``, the ``index``,
    ``input_indices`` and ``output_indices`` entries are normalized to plain
    Python ``int``.

    Args:
        ops: list of operator dicts. Each must carry an ``index`` and may carry
            ``input_indices`` / ``output_indices``.
        tensors: tensor table (currently unused; kept for API compatibility).

    Returns:
        The op dicts (the same objects as in ``ops``) in execution order.

    Calls ``fatal_error`` if the dependency graph contains a cycle.
    """
    from collections import deque

    # Convert all indices to Python int (e.g. numpy integer types).
    for op in ops:
        # Skip if index is not yet assigned
        if op.get("index") is None:
            continue
        op["index"] = int(op["index"])
        if "input_indices" in op:
            op["input_indices"] = [int(i) for i in op["input_indices"]]
        if "output_indices" in op:
            op["output_indices"] = [int(i) for i in op["output_indices"]]

    # 1. Build tensor -> producer operator mapping (the last producer wins).
    tensor_producer = {}
    for op in ops:
        for out_idx in op.get("output_indices", []):
            tensor_producer[int(out_idx)] = op

    # 2. Build operator dependency sets, plus an index -> op lookup table so
    #    the sort below does not rescan ``ops`` (first op with an index wins,
    #    as with a linear search).
    op_by_index = {}
    op_deps = {}
    for op in ops:
        op_idx = int(op["index"])
        op_by_index.setdefault(op_idx, op)
        deps = set()
        for inp_idx in op.get("input_indices", []):
            producer = tensor_producer.get(int(inp_idx))
            if producer is None:
                continue
            prod_idx = int(producer["index"])
            if prod_idx != op_idx:
                deps.add(prod_idx)
        op_deps[op_idx] = deps

    # Reverse adjacency: producer -> consumers, recorded in ``ops`` order so
    # newly-ready ops are enqueued in the same order a scan over ``ops``
    # would find them.
    dependents = {op_idx: [] for op_idx in op_deps}
    for op_idx, deps in op_deps.items():
        for prod_idx in deps:
            dependents[prod_idx].append(op_idx)

    # 3. In-degree = number of distinct producers the op waits on.
    in_degree = {op_idx: len(deps) for op_idx, deps in op_deps.items()}

    # 4. Topological sort (Kahn's algorithm)
    queue = deque(op_idx for op_idx, deg in in_degree.items() if deg == 0)
    order = []
    while queue:
        op_idx = queue.popleft()
        order.append(op_by_index[op_idx])
        for next_idx in dependents[op_idx]:
            in_degree[next_idx] -= 1
            if in_degree[next_idx] == 0:
                queue.append(next_idx)

    if len(order) != len(ops):
        fatal_error(
            "Model has cyclic dependencies, cannot determine execution order",
            "Please check if model structure is valid")

    return order
101
+
102
+
103
def calculate_multiplier_shift(input_scale, weight_scale, output_scale):
    """
    Calculate multiplier and shift for int8 quantization

    Quantization formula: output = round((acc * multiplier) >> (31 + shift))
    where acc = sum(input * weight) + bias

    Q31 fixed-point format:
    - 1 << 31 = 2147483648, represents max value in Q31 format
    - multiplier stored in 32-bit signed int, range -2147483648 ~ 2147483647
    - shift adjusts effective_scale * 2^31 to valid range

    Args:
        input_scale: input tensor quantization scale
        weight_scale: weight quantization scale
        output_scale: output tensor quantization scale

    Returns:
        multiplier: Q31 fixed-point scale factor (0 .. 2147483647)
        shift: right shift adjustment bits
        A zero or negative effective scale yields ``(0, 0)``.

    Raises:
        ZeroDivisionError: if ``output_scale`` is zero.
        ValueError: if the effective scale is infinite or NaN (the
            normalization loops below would otherwise never terminate).
    """
    import math

    effective_scale = (input_scale * weight_scale) / output_scale

    if not math.isfinite(effective_scale):
        raise ValueError(
            f"Invalid effective quantization scale {effective_scale} "
            f"(input={input_scale}, weight={weight_scale}, "
            f"output={output_scale})")

    # A non-positive scale cannot be encoded as a non-negative multiplier
    # (the result is clamped to >= 0 below anyway), and a negative value
    # would never leave the "mult < 0.5" loop.
    if effective_scale <= 0:
        return 0, 0

    # Q31 format: multiplier = effective_scale * 2^31
    mult = effective_scale * (1 << 31)
    shift = 0

    # multiplier exceeds int32 range: decrease shift (increase actual scale)
    while mult > 2147483647:
        shift -= 1
        mult /= 2

    # multiplier too small for precision: increase shift (decrease actual scale)
    while mult < 0.5:
        shift += 1
        mult *= 2

    multiplier = int(round(mult))
    multiplier = max(0, min(multiplier, 2147483647))

    return multiplier, shift
147
+
148
+
149
def calculate_multiplier_shift_from_scale(input_scale, weight_scale,
                                          output_scale):
    """Alias of ``calculate_multiplier_shift``.

    Returns the ``(multiplier, shift)`` pair computed from the three
    quantization scales.
    """
    multiplier_and_shift = calculate_multiplier_shift(
        input_scale, weight_scale, output_scale)
    return multiplier_and_shift
153
+
154
+
155
def validate_ops(model_info: Dict[str, Any]) -> None:
    """Ensure the model's operators are ready for code generation.

    Every op must be in state ``"translated"`` or ``"generated"``, and at
    least one op name must appear in ``SUPPORTED_OPS``. Each violation is
    reported through ``fatal_error``.
    """
    all_ops = model_info.get("ops", [])
    ready_states = ("translated", "generated")

    for current in all_ops:
        current_state = current.get("state")
        if current_state in ready_states:
            continue
        fatal_error(
            f"Operator {current['op_name']} state is {current_state}, "
            "cannot generate code",
            f"Pass flags: {current.get('pass_flags', {})}")

    if not any(current.get("op_name") in SUPPORTED_OPS
               for current in all_ops):
        fatal_error(
            "Model does not contain any supported operators",
            f"Supported operators: {', '.join(SUPPORTED_OPS)}")
172
+
173
+
174
# Default activation (input/output) scale, 1/256, used when no usable scale
# can be found for a tensor.
_DEFAULT_ACTIVATION_SCALE = 0.00390625


def _lookup_tensor(tensors: Dict[Any, Any], idx: Any) -> Optional[Dict[str, Any]]:
    """Return the tensor spec for ``idx`` (or None if not found).

    Accepts tensor tables keyed either by int or by the str form of the
    index. NOTE(review): ``build_context`` applies ``int()`` to the table
    keys, which suggests str keys can occur - confirm against the parsers.
    """
    if idx is None:
        return None
    spec = tensors.get(idx)
    if spec is None:
        spec = tensors.get(str(idx))
    return spec


def _tensor_scale(tensors: Dict[Any, Any], idx: Any) -> Any:
    """Return the ``scale`` entry of tensor ``idx``, or None if unavailable."""
    spec = _lookup_tensor(tensors, idx)
    return None if spec is None else spec.get("scale")


def _resolve_layer_quant(model_info, tensors, op, label, scale_key,
                         output_scale_key):
    """Resolve (weight_scale, output_scale, multiplier, shift) for one op.

    Weight scale: ``op[scale_key]``, then ``model_info["quant_scales"]
    [scale_key]``, then the scale of the op's weight tensor (input 1), then
    ``DEFAULT_SCALE``. Output scale: ``op[output_scale_key]``, then the scale
    of the op's first output tensor, then 1/256. Input scale: the scale of
    the op's data tensor (input 0), else 1/256. ``label`` prefixes the log
    messages (e.g. "FC", "CONV_2D").
    """
    input_indices = op.get("input_indices", [])
    output_indices = op.get("output_indices", [])

    weight_scale = op.get(scale_key)
    if weight_scale is None:
        quant_scales = model_info.get("quant_scales") or {}
        weight_scale = quant_scales.get(scale_key)
    if weight_scale is None and len(input_indices) > 1:
        weight_scale = _tensor_scale(tensors, input_indices[1])

    output_scale = op.get(output_scale_key)
    if output_scale is None and output_indices:
        output_scale = _tensor_scale(tensors, output_indices[0])

    # A missing *or* None-valued input scale falls back to the default.
    input_scale = None
    if input_indices:
        input_scale = _tensor_scale(tensors, input_indices[0])
    if input_scale is None:
        input_scale = _DEFAULT_ACTIVATION_SCALE

    if weight_scale is None:
        weight_scale = DEFAULT_SCALE
        info(f"{label} using default weight scale: {weight_scale}")

    if output_scale is None:
        output_scale = _DEFAULT_ACTIVATION_SCALE
        info(f"{label} using default output scale: {output_scale}")

    multiplier, shift = calculate_multiplier_shift_from_scale(
        input_scale, weight_scale, output_scale)

    info(
        f"{label} quantization params: scale={weight_scale}, "
        f"output_scale={output_scale}, "
        f"multiplier={multiplier}, shift={shift}")
    return weight_scale, output_scale, multiplier, shift


def analyze_ops(
    model_info: Dict[str, Any],
    execution_order: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Analyze operator types, LSTM params, FC/Conv quantization params.

    Only the first FULLY_CONNECTED op and the first CONV_2D op (in
    ``model_info["ops"]`` order) determine the FC / CONV parameters. When
    there is no CONV_2D op, the first DEPTHWISE_CONV_2D op supplies the conv
    multiplier/shift. For LSTM, the last UNIDIRECTIONAL_SEQUENCE_LSTM op's
    ``lstm_params`` dict is used, and its ``"shifts"`` entry is overwritten
    in place.

    Args:
        model_info: model description; reads ``ops``, ``tensors`` and
            optionally ``quant_scales``.
        execution_order: ordered ops (currently unused; kept for API
            compatibility).

    Returns:
        {
            "has_fc": bool,
            "has_conv": bool,
            "has_dw": bool,
            "has_svdf": bool,
            "lstm_params": dict,
            "fc_scale": float,          # None when there is no FC op
            "fc_output_scale": float,   # None when there is no FC op
            "fc_multiplier": int,
            "fc_shift": int,
            "conv_multiplier": int,
            "conv_shift": int,
        }
    """
    tensors = model_info.get("tensors", {})
    ops = model_info.get("ops", [])

    # ---- Detect operator types ----
    op_names = {op.get("op_name") for op in ops}
    has_fc = "FULLY_CONNECTED" in op_names
    has_conv = "CONV_2D" in op_names
    has_dw = "DEPTHWISE_CONV_2D" in op_names
    has_svdf = "SVDF" in op_names

    lstm_params = None
    for op in ops:
        if op.get("op_name") == "UNIDIRECTIONAL_SEQUENCE_LSTM":
            lstm_params = op.get("lstm_params")  # last LSTM op wins

    # ---- LSTM params ----
    if lstm_params is None:
        lstm_params = {
            "time_steps": 0,
            "batch_size": 0,
            "input_size": 0,
            "hidden_size": 0,
            "shifts": [8, 8, 8, 8],
            "input_scale": _DEFAULT_ACTIVATION_SCALE,
            "input_zp": 0,
        }
    else:
        default_gate_scales = [DEFAULT_SCALE] * 4
        input_scales = lstm_params.get("input_scales", default_gate_scales)
        recurrent_scales = lstm_params.get("recurrent_scales",
                                           default_gate_scales)

        shifts = []
        for in_s, rec_s in zip(input_scales, recurrent_scales):
            gate_scale = in_s * rec_s
            if gate_scale > 0:
                shift = int(np.log2(1.0 / gate_scale))
            else:
                shift = DEFAULT_SHIFT
            # Clamp to the 4..12 range.
            shifts.append(max(4, min(shift, 12)))

        lstm_params["shifts"] = shifts
        info(
            f"LSTM right shifts: i={shifts[0]}, f={shifts[1]}, "
            f"g={shifts[2]}, o={shifts[3]}")

    # ---- FC quantization ----
    fc_scale = None
    fc_output_scale = None
    fc_op = next(
        (op for op in ops if op.get("op_name") == "FULLY_CONNECTED"), None)
    if fc_op is not None:
        fc_scale, fc_output_scale, fc_multiplier, fc_shift = (
            _resolve_layer_quant(model_info, tensors, fc_op, "FC",
                                 "fc_scale", "fc_output_scale"))
    else:
        fc_multiplier, fc_shift = 213512, -30
        info("Using fallback FC quantization params")

    # ---- CONV quantization ----
    conv_multiplier = None
    conv_shift = None
    conv_op = next(
        (op for op in ops if op.get("op_name") == "CONV_2D"), None)
    if conv_op is not None:
        _, _, conv_multiplier, conv_shift = _resolve_layer_quant(
            model_info, tensors, conv_op, "CONV_2D",
            "conv_scale", "conv_output_scale")

    if conv_multiplier is None and has_dw:
        dw_op = next(
            (op for op in ops if op.get("op_name") == "DEPTHWISE_CONV_2D"),
            None)
        if dw_op is not None:
            # NOTE(review): unlike CONV_2D, the depthwise path ignores tensor
            # scales and always assumes a 1/256 input scale - confirm intent.
            dw_scale = dw_op.get("dw_scale")
            if dw_scale is None:
                dw_scale = DEFAULT_SCALE
            dw_output_scale = dw_op.get("dw_output_scale")
            if dw_output_scale is None:
                dw_output_scale = _DEFAULT_ACTIVATION_SCALE
            conv_input_scale = _DEFAULT_ACTIVATION_SCALE
            conv_multiplier, conv_shift = (
                calculate_multiplier_shift_from_scale(
                    conv_input_scale, dw_scale, dw_output_scale))
            info(
                f"DEPTHWISE_CONV_2D quantization params: "
                f"scale={dw_scale}, output_scale={dw_output_scale}, "
                f"multiplier={conv_multiplier}, shift={conv_shift}")

    if conv_multiplier is None:
        conv_multiplier, conv_shift = 0, 0

    return {
        "has_fc": has_fc,
        "has_conv": has_conv,
        "has_dw": has_dw,
        "has_svdf": has_svdf,
        "lstm_params": lstm_params,
        "fc_scale": fc_scale,
        "fc_output_scale": fc_output_scale,
        "fc_multiplier": fc_multiplier,
        "fc_shift": fc_shift,
        "conv_multiplier": conv_multiplier,
        "conv_shift": conv_shift,
    }
393
+
394
+
395
def _shape_size(shape) -> int:
    """Return the element count of ``shape`` (1 for an empty shape)."""
    size = 1
    for dim in shape:
        size *= int(dim)
    return size


def build_context(
    model_info: Dict[str, Any],
    execution_order: List[Dict[str, Any]],
    op_analysis: Dict[str, Any],
    # stats: Dict[str, Any], # Keep this code for convert might need it
) -> Dict[str, Any]:
    """
    Build the Jinja2 template context from model_info.

    Collects tensor sizes/shapes, weight-header includes, reshape targets,
    FC parameters, the intermediate tensors that need a buffer definition,
    LSTM parameters, and the model input/output sizes.

    Side effects on the op dicts in ``execution_order``:
    - CONV_2D / SVDF ops receive ``conv_params`` / ``svdf_params`` copied
      from the op with the same ``index`` in ``model_info["ops"]``.
    - AVERAGE_POOL_2D / MAX_POOL_2D ops get ``pool_params`` with any missing
      (or None) ``pool_size_h``/``pool_size_w``/``stride_h``/``stride_w``
      set to 2.

    Args:
        model_info: model description (``tensors``, ``ops``, ``input``,
            ``output``, optional ``target`` and ``inference_func``).
        execution_order: ops in execution order; must be non-empty and its
            last op must have ``output_indices``.
        op_analysis: the dict returned by ``analyze_ops``.

    Returns:
        The template context dict.
    """
    tensors = model_info.get("tensors", {})
    inputs = model_info.get("input") or []
    outputs = model_info.get("output") or []
    target = model_info.get("target", "riscv")
    inference_func = model_info.get("inference_func", "tinymlc_inference")

    has_fc = op_analysis["has_fc"]
    has_conv = op_analysis["has_conv"]
    has_dw = op_analysis["has_dw"]
    has_svdf = op_analysis["has_svdf"]
    lstm_params = op_analysis["lstm_params"]
    fc_multiplier = op_analysis["fc_multiplier"]
    fc_shift = op_analysis["fc_shift"]
    fc_scale = op_analysis["fc_scale"]
    fc_output_scale = op_analysis["fc_output_scale"]
    conv_multiplier = op_analysis["conv_multiplier"]
    conv_shift = op_analysis["conv_shift"]
    has_lstm = lstm_params["time_steps"] > 0

    # ---- Tensor sizes and shapes ----
    tensor_sizes = {}
    tensor_shapes = {}
    for idx, spec in tensors.items():
        shape = [int(d) for d in spec.get("shape", [])]
        tensor_sizes[int(idx)] = _shape_size(shape)
        tensor_shapes[int(idx)] = shape

    # ---- Build includes ----
    includes = []
    if has_fc:
        includes.append('#include "fc_weights.h"')
    if has_lstm:
        includes.append('#include "lstm_weights.h"')
    if has_conv:
        includes.append('#include "conv_weights.h"')
    if has_dw:
        includes.append('#include "dw_weights.h"')
    if has_svdf:
        includes.append('#include "svdf_weights.h"')

    # ---- Reshape targets ----
    reshape_targets = []
    for op in execution_order:
        if op.get("op_name") == "RESHAPE":
            target_shape = op.get("reshape_target_shape", [])
            if target_shape:
                reshape_targets.append(
                    "{" + ", ".join(str(int(s)) for s in target_shape) + "}")
            else:
                reshape_targets.append("{0}")

    # ---- FC params ----
    fc_params = {}
    for op in execution_order:
        if op.get("op_name") == "FULLY_CONNECTED":
            input_idx = op["input_indices"][0]
            output_idx = op["output_indices"][0]
            fc_params[op["index"]] = {
                "input_size": tensor_sizes.get(input_idx, 0),
                "output_size": tensor_sizes.get(output_idx, 0),
                "multiplier": fc_multiplier,
                "shift": fc_shift,
                "scale": fc_scale,
                "output_scale": fc_output_scale,
            }

    # ---- Copy conv_params / svdf_params from original ops ----
    # Index the original ops once (first op per index wins) instead of
    # rescanning model_info["ops"] for every op.
    orig_ops_by_index = {}
    for orig_op in model_info.get("ops", []):
        orig_ops_by_index.setdefault(orig_op.get("index"), orig_op)
    params_key_for = {"CONV_2D": "conv_params", "SVDF": "svdf_params"}
    for op in execution_order:
        params_key = params_key_for.get(op.get("op_name"))
        if params_key is None:
            continue
        orig_op = orig_ops_by_index.get(op["index"])
        if orig_op is not None:
            op[params_key] = orig_op.get(params_key, {})

    # ---- Input / output sizes ----
    input_size_1 = _shape_size(inputs[0]["shape"]) if len(inputs) >= 1 else 1
    input_size_2 = _shape_size(inputs[1]["shape"]) if len(inputs) >= 2 else 1
    input_size = input_size_1
    output_size = _shape_size(outputs[0]["shape"]) if outputs else 1

    # ---- Input tensor indices (matched by name; first match wins) ----
    name_to_index = {}
    for idx, spec in tensors.items():
        name_to_index.setdefault(spec.get("name"), int(idx))
    input_tensor_indices = [name_to_index.get(inp.get("name"), 0)
                            for inp in inputs]

    # ---- Tensors to define ----
    tensors_to_define = []
    defined_set = set(input_tensor_indices)

    def _define(idx, dtype="int8_t"):
        """Record a buffer for a known, not-yet-defined tensor."""
        if idx in tensor_sizes and idx not in defined_set:
            tensors_to_define.append({
                "index": idx,
                "size": tensor_sizes[idx],
                "type": dtype,
            })
            defined_set.add(idx)

    for op in execution_order:
        for out_idx in op.get("output_indices", []):
            _define(int(out_idx))

        data_idx = op.get("data_input_idx")
        if data_idx is not None:
            data_idx = int(data_idx)
            if data_idx not in op.get("output_indices", []):
                _define(data_idx)

        op_name = op.get("op_name")
        if op_name == "SVDF":
            for key in ("svdf_weights_idx", "svdf_bias_idx"):
                idx = op.get(key)
                if idx is not None:
                    dtype = "int32_t" if key == "svdf_bias_idx" else "int8_t"
                    _define(int(idx), dtype)
        elif op_name == "ADD":
            for key in ("add_input1_idx", "add_input2_idx"):
                idx = op.get(key)
                if idx is not None:
                    _define(int(idx))

    # ---- Pool params defaults ----
    for op in execution_order:
        if op.get("op_name") in ("AVERAGE_POOL_2D", "MAX_POOL_2D"):
            pool_params = op.get("pool_params")
            if pool_params is None:
                pool_params = {}
            for key, default in [("pool_size_h", 2), ("pool_size_w", 2),
                                 ("stride_h", 2), ("stride_w", 2)]:
                if pool_params.get(key) is None:
                    pool_params[key] = default
            op["pool_params"] = pool_params

    # "stats": stats, # Keep this code for convert might need it
    return {
        "input_size": input_size,
        "output_size": output_size,
        "inference_func": inference_func,
        "includes": "\n".join(includes),
        "has_fc": has_fc,
        "has_lstm": has_lstm,
        "has_conv": has_conv,
        "has_dw": has_dw,
        "has_svdf": has_svdf,
        "target": target,
        "model_header": "model.h",
        "lstm_time_steps": lstm_params["time_steps"],
        "lstm_batch_size": lstm_params["batch_size"],
        "lstm_input_size": lstm_params["input_size"],
        "lstm_hidden_size": lstm_params["hidden_size"],
        "lstm_input_scale": lstm_params.get("input_scale", 0.00390625),
        "lstm_input_zp": lstm_params.get("input_zp", 0),
        "lstm_shifts": lstm_params.get("shifts", [8, 8, 8, 8]),
        "tensor_sizes": tensor_sizes,
        "tensor_shapes": tensor_shapes,
        "execution_order": execution_order,
        "last_output_tensor": execution_order[-1]["output_indices"][0],
        "reshape_targets": reshape_targets,
        "fc_params": fc_params,
        "inputs_count": len(inputs),
        "INPUT_SIZE_1": input_size_1,
        "INPUT_SIZE_2": input_size_2,
        "fc_multiplier": fc_multiplier,
        "fc_shift": fc_shift,
        "conv_multiplier": conv_multiplier,
        "conv_shift": conv_shift,
        "input_tensor_indices": input_tensor_indices,
        "tensors_to_define": tensors_to_define,
    }
625
+
626
+
627
def _render_template(template_path: Path, context: Dict[str, Any]) -> str:
    """Read a Jinja2 template file as UTF-8 and render it with ``context``."""
    with open(template_path, "r", encoding="utf-8") as f:
        tmpl = Template(f.read())
    return tmpl.render(**context)


def render_code(
    context: Dict[str, Any],
    output_dir: Path,
    target: str,
    inference_func: str,
    with_test_main: bool,
    accel_lib_inc: Optional[str] = None,
    accel_lib_lib: Optional[str] = None,
) -> Dict[str, str]:
    """Render the model templates and return the generated sources.

    Creates ``output_dir`` if needed and writes ``model_features.txt`` there
    (one ``HAS_LSTM`` / ``HAS_FC`` line per enabled feature in ``context``).
    The rendered ``model.c`` and ``model.h`` - plus ``main_test.c`` when
    ``with_test_main`` is set - are returned as strings keyed by file name;
    this function does not write them to disk.

    Templates are read as UTF-8 so rendering does not depend on the
    platform's default encoding.

    Args:
        context: template context from ``build_context``.
        output_dir: directory for ``model_features.txt``.
        target: architecture; used to locate ``ops/<target>/main_test.c``
            when no ``main_test.c.tpl`` template exists.
        inference_func: currently unused.
        with_test_main: also produce ``main_test.c``.
        accel_lib_inc: currently unused.
        accel_lib_lib: currently unused.

    Returns:
        Mapping of file name to generated source text.
    """
    template_dir = Path(__file__).parent / "templates"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # ---- Write feature flags ----
    with open(output_dir / "model_features.txt", "w", encoding="utf-8") as f:
        if context.get("has_lstm"):
            f.write("HAS_LSTM\n")
        if context.get("has_fc"):
            f.write("HAS_FC\n")

    # ---- Render model.c / model.h ----
    result = {
        "model.c": _render_template(template_dir / "model.c.tpl", context),
        "model.h": _render_template(template_dir / "model.h.tpl", context),
    }

    # ---- main_test.c ----
    if with_test_main:
        main_test_tpl = template_dir / "main_test.c.tpl"
        if main_test_tpl.exists():
            result["main_test.c"] = _render_template(main_test_tpl, context)
        else:
            # Fallback: copy from architecture-specific directory
            main_test_src = (
                Path(__file__).parent.parent / "ops" / target / "main_test.c")
            if main_test_src.exists():
                with open(main_test_src, "r", encoding="utf-8") as f:
                    result["main_test.c"] = f.read()
            else:
                fatal_error(
                    f"main_test.c template not found: {main_test_tpl}",
                    "Supported architectures: riscv, arm")

    # ---- Write stats to model_info for UI ----
    # Keep this code for convert might need it
    # stats = context.get("stats", {})
    # if stats:
    #     # The stats are already in model_info["quant_scales"] via compute_stats
    #     pass

    return result
690
+
691
+
692
def generate_c_code(
    model_info: Dict[str, Any],
    output_dir: str,
    target: str,
    inference_func: str = "tinymlc_inference",
    with_test_main: bool = False,
    accel_lib_inc: Optional[str] = None,
    accel_lib_lib: Optional[str] = None,
) -> Dict[str, str]:
    """
    Generate C code from model_info.

    This is the main entry point for code generation. It validates the ops,
    orders them by tensor dependencies, analyzes quantization parameters,
    builds the template context, and renders it. Afterwards every op in
    state ``"translated"`` is moved to ``"generated"`` and gets
    ``pass_flags["codegen"] = "success"`` (``pass_flags`` is created if the
    op has none).

    Returns:
        Mapping of generated file name to source text, as returned by
        ``render_code``.
    """
    output_dir = Path(output_dir)

    # ---- Validate ----
    validate_ops(model_info)

    # ---- Build execution order ----
    execution_order = build_execution_order(
        model_info.get("ops", []),
        model_info.get("tensors", {})
    )

    # Log execution order
    info("Operator execution order:")
    for op in execution_order:
        info(f" {op['index']}: {op['op_name']}")

    # ---- Analyze ops ----
    op_analysis = analyze_ops(model_info, execution_order)

    # ---- Compute stats ----
    # Keep this code for convert might need it
    # stats = {
    #     "macs": calculate_macs(model_info),
    #     "params": calculate_params(model_info),
    #     "peak_ram": calculate_peak_ram(model_info),
    #     "flash": calculate_flash(model_info),
    # }

    # ---- Store stats in model_info for UI ----
    # Keep this code for convert might need it
    # if "quant_scales" not in model_info:
    #     model_info["quant_scales"] = {}
    # model_info["quant_scales"]["macs"] = stats["macs"]
    # model_info["quant_scales"]["params"] = stats["params"]
    # model_info["quant_scales"]["peak_ram"] = stats["peak_ram"]
    # model_info["quant_scales"]["flash"] = stats["flash"]

    # ---- Build context ----
    context = build_context(
        model_info,
        execution_order,
        op_analysis,
        # stats, # Keep this code for convert might need it
    )

    # ---- Render ----
    result = render_code(
        context,
        output_dir,
        target,
        inference_func,
        with_test_main,
        accel_lib_inc,
        accel_lib_lib,
    )

    # ---- Update state ----
    for op in model_info.get("ops", []):
        if op.get("state") == "translated":
            op["state"] = "generated"
            # pass_flags is optional on ops (validate_ops reads it with a
            # default), so create it instead of raising KeyError.
            pass_flags = op.get("pass_flags")
            if pass_flags is None:
                pass_flags = {}
                op["pass_flags"] = pass_flags
            pass_flags["codegen"] = "success"

    return result
769
+
770
+
771
def copy_files_to_build(output_dir: Path, target: str, mode: str, accel: str,
                        accel_lib_inc=None, accel_lib_lib=None):
    """
    Copy all files needed for build to tinymlc_generated/

    Sources come from the ``ops`` directory next to the package:
    ``ops/include`` and ``ops/c`` are copied first. Accelerator kernels
    (``ops/<target>/cmsis_nn`` or ``nmsis_nn``) then override ``c/*.c``.
    Next come the target files (``ops/host`` for host; ``*.c``, ``*.S`` and
    ``*.ld`` from ``ops/<target>`` otherwise). Finally the matching build
    script is copied or rendered and marked executable.

    The build script is rendered from ``<script>.sh.tpl`` when that template
    exists and either both accelerator library paths are given or no plain
    ``.sh`` exists; otherwise the ``.sh`` is copied verbatim.

    Args:
        output_dir: output directory (tinymlc_generated); str or Path,
            created if missing
        target: target architecture (riscv / arm / host)
        mode: build mode (debug / release)
        accel: acceleration library ("none", "cmsis-nn", "nmsis-nn", ...)
        accel_lib_inc: accelerator include path passed to the script template
        accel_lib_lib: accelerator library path passed to the script template
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Determine source directory
    ops_root = Path(__file__).parent.parent / "ops"
    src_dir = ops_root / target

    if not src_dir.exists():
        fatal_error(
            f"Architecture directory not found: {src_dir}",
            "Supported architectures: riscv, arm, host")

    # 1. Copy common header files
    include_src = ops_root / "include"
    if include_src.exists():
        shutil.copytree(include_src, output_dir / "include", dirs_exist_ok=True)

    # 2. Copy C operators (ops/c/*.c) to output_dir/c/
    c_dst = output_dir / "c"
    c_src = ops_root / "c"
    if c_src.exists():
        shutil.copytree(c_src, c_dst, dirs_exist_ok=True)

    # 3. Copy accelerator-specific operators (override ops/c/*.c)
    accel_subdir = {"cmsis-nn": "cmsis_nn", "nmsis-nn": "nmsis_nn"}.get(accel)
    if accel_subdir is not None:
        accel_src = src_dir / accel_subdir
        if accel_src.exists():
            # ops/c may be absent, so make sure the destination exists.
            c_dst.mkdir(parents=True, exist_ok=True)
            for file in accel_src.glob("*.c"):
                shutil.copy2(file, c_dst / file.name)

    # 4. Copy target architecture files
    if target == "host":
        # Host only needs its .c sources (no .S, .ld)
        host_src = ops_root / "host"
        if host_src.exists():
            shutil.copytree(host_src, output_dir / "host", dirs_exist_ok=True)
    else:
        # ARM/RISC-V need .c, .S, .ld files
        for pattern in ("*.c", "*.S", "*.ld"):
            for file in src_dir.glob(pattern):
                shutil.copy2(file, output_dir / file.name)

    # 5. Select the build script
    if target == "host":
        # Host only has debug build script
        build_script = src_dir / "build_host_debug.sh"
    elif accel != "none":
        accel_underscore = accel.replace("-", "_")
        build_script = src_dir / f"build_{target}_{accel_underscore}_{mode}.sh"
    else:
        build_script = src_dir / f"build_{target}_{mode}.sh"

    # The output keeps the .sh name even when rendered from the .tpl.
    dest_build_script = output_dir / build_script.name
    tpl_script = src_dir / f"{build_script.name}.tpl"
    have_lib_paths = bool(accel_lib_inc and accel_lib_lib)

    if tpl_script.exists() and (have_lib_paths or not build_script.exists()):
        # Render the template (it is the only file that can carry the
        # accelerator library placeholders).
        source_script = tpl_script
        use_template = True
    elif build_script.exists():
        source_script = build_script
        use_template = False
    else:
        fatal_error(
            f"Build script not found: {build_script}",
            suggestion=f"Please check if accelerator type {accel} "
                       "is supported")
        return

    if use_template:
        # Render template with accel library paths
        with open(source_script, "r", encoding="utf-8") as f:
            tmpl = Template(f.read())
        rendered = tmpl.render(
            accel_lib_inc=accel_lib_inc,
            accel_lib_lib=accel_lib_lib
        )
        # newline="" writes "\n" untranslated so the shell script keeps LF
        # line endings on every platform.
        with open(dest_build_script, "w", encoding="utf-8", newline="") as f:
            f.write(rendered)
    else:
        # Just copy the script
        shutil.copy2(source_script, dest_build_script)

    # Best effort: add execute bits; presumably some filesystems do not
    # support them, so failures are deliberately ignored.
    try:
        current_mode = dest_build_script.stat().st_mode
        dest_build_script.chmod(
            current_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
    except OSError:
        pass