tinymlc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. TinyMLC/ANG/__init__.py +0 -0
  2. TinyMLC/ANG/args.py +86 -0
  3. TinyMLC/ANG/estimator.py +103 -0
  4. TinyMLC/ANG/estimator_hal.py +184 -0
  5. TinyMLC/ANG/estimator_qemu.py +257 -0
  6. TinyMLC/ANG/estimator_software.py +130 -0
  7. TinyMLC/ANG/model_builder.py +508 -0
  8. TinyMLC/ANG/model_generator.py +439 -0
  9. TinyMLC/ANG/model_info.py +283 -0
  10. TinyMLC/ANG/utils.py +420 -0
  11. TinyMLC/__init__.py +0 -0
  12. TinyMLC/cli.py +126 -0
  13. TinyMLC/codegen.py +877 -0
  14. TinyMLC/converter/__init__.py +0 -0
  15. TinyMLC/converter/export_weights.py +382 -0
  16. TinyMLC/converter/parser_litert.py +757 -0
  17. TinyMLC/converter/parser_onnx.py +649 -0
  18. TinyMLC/generate_lut.py +97 -0
  19. TinyMLC/handlers.py +325 -0
  20. TinyMLC/ops.py +76 -0
  21. TinyMLC/templates/lut.c.tpl +23 -0
  22. TinyMLC/templates/lut.h.tpl +67 -0
  23. TinyMLC/templates/model.c.tpl +314 -0
  24. TinyMLC/templates/model.h.tpl +66 -0
  25. TinyMLC/transform/__init__.py +0 -0
  26. TinyMLC/transform/algebraic.py +286 -0
  27. TinyMLC/transform/base.py +58 -0
  28. TinyMLC/transform/constant_folding.py +260 -0
  29. TinyMLC/transform/cse.py +192 -0
  30. TinyMLC/transform/dce.py +182 -0
  31. TinyMLC/transform/fusion.py +723 -0
  32. TinyMLC/transform/memory.py +200 -0
  33. TinyMLC/transform/pass_manager.py +101 -0
  34. TinyMLC/transform/simplify.py +515 -0
  35. tinymlc-0.1.0.dist-info/METADATA +49 -0
  36. tinymlc-0.1.0.dist-info/RECORD +47 -0
  37. tinymlc-0.1.0.dist-info/WHEEL +4 -0
  38. tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
  39. tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
  40. utils/__init__.py +0 -0
  41. utils/arm-none-eabi-gcc.cmake +53 -0
  42. utils/dump.py +86 -0
  43. utils/generate_onnx_models.py +183 -0
  44. utils/generate_tflite_models.py +236 -0
  45. utils/pack_macos.sh +88 -0
  46. utils/path.py +31 -0
  47. utils/riscv-none-elf-gcc.cmake +50 -0
@@ -0,0 +1,649 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ #!/usr/bin/env python3
21
+ """ONNX model parser"""
22
+
23
+ import onnx
24
+ from onnx import numpy_helper
25
+ from utils.dump import warning
26
+
27
# Mapping from ONNX operator type (``NodeProto.op_type``) to TinyMLC IR
# operator name.  Several ONNX operators share one IR operator
# (Gemm/MatMul -> FULLY_CONNECTED, Mean/ReduceMean -> MEAN).
# NOTE(review): ONNX ``Mean`` is an element-wise mean over its inputs, not an
# axis reduction like ``ReduceMean`` -- confirm the MEAN kernel covers both.
# ``SVDF`` is not a standard ONNX operator; presumably it is produced by a
# custom exporter.
OP_MAP = dict(
    Conv="CONV_2D",
    Gemm="FULLY_CONNECTED",
    MatMul="FULLY_CONNECTED",
    Relu="RELU",
    Softmax="SOFTMAX",
    Add="ADD",
    Sub="SUB",
    Mul="MULTIPLY",
    MaxPool="MAX_POOL_2D",
    AveragePool="AVERAGE_POOL_2D",
    Reshape="RESHAPE",
    Transpose="TRANSPOSE",
    Pad="PAD",
    Mean="MEAN",
    ReduceMean="MEAN",
    LSTM="UNIDIRECTIONAL_SEQUENCE_LSTM",
    SVDF="SVDF",
    Concat="CONCAT",
    Sigmoid="SIGMOID",
    Tanh="TANH",
)
50
+
51
+
52
def get_tensor_shape(graph, name, initializer_map=None):
    """Look up the static shape of tensor *name* in an ONNX graph.

    Graph inputs are searched first, then graph outputs, then the
    intermediate ``value_info`` entries; the first match wins.  If none of
    those declare the tensor, *initializer_map* (when given) is consulted,
    which covers tensors that only exist as initializers, such as quantized
    weights in QDQ models.

    Args:
        graph: ONNX ``GraphProto`` to search.
        name: Tensor name to resolve.
        initializer_map: Optional mapping of initializer name to an
            array-like object exposing ``.shape``.

    Returns:
        The shape as a list of ints, or ``[]`` when the tensor is unknown.
        Dimensions without a static value (e.g. a symbolic ``dim_param``)
        read back as 0, the protobuf default of ``dim_value``.
    """
    for declared in (graph.input, graph.output, graph.value_info):
        hit = next((vi for vi in declared if vi.name == name), None)
        if hit is not None:
            return [d.dim_value for d in hit.type.tensor_type.shape.dim]

    if initializer_map is None or name not in initializer_map:
        return []
    return list(initializer_map[name].shape)
70
+
71
+
72
def _value_info_dims(value_info):
    """Return the dims of an ONNX ``ValueInfoProto`` as a list of ints.

    Dimensions without a static value (e.g. a symbolic ``dim_param``) read
    back as 0, the protobuf default of ``dim_value``.
    """
    return [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]


def _num_elements(shape):
    """Return the element count of *shape* (1 for an empty shape)."""
    size = 1
    for dim in shape:
        size *= dim
    return size


def _tensor_entry(name, shape, dtype, size):
    """Build one record of the ``tensors`` table with neutral quant params."""
    return {
        "name": name,
        "shape": shape,
        "dtype": dtype,
        "size": size,
        "scale": 1.0,
        "zero_point": 0,
    }


def _attr_map(node):
    """Map attribute name -> ``AttributeProto`` for *node* (last one wins)."""
    return {attr.name: attr for attr in node.attribute}


def _ints_attr(attrs, name, default):
    """Return attribute *name* as a list of ints, or *default* if absent."""
    return list(attrs[name].ints) if name in attrs else default


def _nchw_hwc(shape):
    """Return ``(h, w, c)`` of an NCHW shape, or ``(1, 1, 1)`` if rank < 4."""
    if len(shape) >= 4:
        return shape[2], shape[3], shape[1]
    return 1, 1, 1


def _input_shape(graph, name, mapped_name, initializer_map):
    """Return the shape of operator input *name*, looking through Q/DQ.

    In QDQ models an operator input is often a DequantizeLinear output with
    no ``value_info`` entry.  Q/DQ nodes preserve shape, so when the direct
    lookup fails the tensor the Q/DQ chain starts from (*mapped_name*) is
    used instead.
    """
    shape = get_tensor_shape(graph, name, initializer_map)
    if not shape and mapped_name != name:
        shape = get_tensor_shape(graph, mapped_name, initializer_map)
    return shape


def _warn_ignored_geometry(node, attrs):
    """Warn about Conv/Pool attributes that the emitted params cannot express."""
    label = node.name or (node.output[0] if node.output else "?")
    dilations = _ints_attr(attrs, "dilations", [])
    if any(d != 1 for d in dilations):
        warning(
            f"{node.op_type} '{label}': dilations {dilations} are ignored"
        )
    auto_pad = (
        attrs["auto_pad"].s.decode("utf-8") if "auto_pad" in attrs
        else "NOTSET"
    )
    if auto_pad.startswith("SAME"):
        warning(
            f"{node.op_type} '{label}': auto_pad={auto_pad} is ignored; "
            "padding is taken from 'pads' only"
        )


def _tensor_details(position, name, tensors, tensor_index_map, graph,
                    initializer_map):
    """Describe one operator input/output for ``input/output_details``."""
    if name in tensor_index_map:
        entry = tensors[tensor_index_map[name]]
        shape = entry.get("shape", [])
        size = entry.get("size", 1)
        scale = entry.get("scale", 1.0)
        zero_point = entry.get("zero_point", 0)
    else:
        shape = get_tensor_shape(graph, name, initializer_map)
        size = _num_elements(shape)
        scale, zero_point = 1.0, 0
    return {
        "index": position,
        "name": name,
        "shape": shape,
        "size": size,
        "scale": scale,
        "zero_point": zero_point,
    }


def _conv_params(node, mapped_inputs, graph, initializer_map):
    """Collect the ``conv_params`` geometry dict for a Conv node."""
    attrs = _attr_map(node)
    _warn_ignored_geometry(node, attrs)

    # Kernel size from the weight tensor; ONNX Conv weights are laid out as
    # (M, C/group, kH, kW).  The Q/DQ-resolved name is used so quantized
    # weights (DequantizeLinear outputs) are found too; the optional
    # kernel_shape attribute is the fallback.
    kernel_h, kernel_w = 1, 1
    weight = initializer_map.get(mapped_inputs[1])
    kernel_shape = _ints_attr(attrs, "kernel_shape", [])
    if weight is not None and len(weight.shape) >= 4:
        kernel_h, kernel_w = weight.shape[2], weight.shape[3]
    elif len(kernel_shape) >= 2:
        kernel_h, kernel_w = kernel_shape[0], kernel_shape[1]

    input_h, input_w, input_c = _nchw_hwc(
        _input_shape(graph, node.input[0], mapped_inputs[0], initializer_map)
    )
    output_h, output_w, output_c = _nchw_hwc(
        get_tensor_shape(graph, node.output[0], initializer_map)
    )

    strides = _ints_attr(attrs, "strides", [1, 1])
    # ONNX pads are [h_begin, w_begin, h_end, w_end]; only the begin values
    # are kept, so asymmetric end padding is not represented.
    pads = _ints_attr(attrs, "pads", [0, 0, 0, 0])

    return {
        "input_h": input_h,
        "input_w": input_w,
        "input_c": input_c,
        "output_h": output_h,
        "output_w": output_w,
        "output_c": output_c,
        "kernel_h": kernel_h,
        "kernel_w": kernel_w,
        "stride_h": strides[0],
        "stride_w": strides[1] if len(strides) >= 2 else strides[0],
        "padding_h": pads[0] if len(pads) >= 2 else 0,
        "padding_w": pads[1] if len(pads) >= 2 else 0,
    }


def _pool_params(node, mapped_inputs, graph, initializer_map):
    """Collect the ``pool_params`` geometry dict for MaxPool/AveragePool."""
    attrs = _attr_map(node)
    _warn_ignored_geometry(node, attrs)

    strides = _ints_attr(attrs, "strides", [1, 1])
    pads = _ints_attr(attrs, "pads", [0, 0, 0, 0])
    # kernel_shape is a required attribute of both pooling operators.
    kernel_shape = _ints_attr(attrs, "kernel_shape", [])

    input_h, input_w, input_c = _nchw_hwc(
        _input_shape(graph, node.input[0], mapped_inputs[0], initializer_map)
    )
    output_h, output_w, output_c = _nchw_hwc(
        get_tensor_shape(graph, node.output[0], initializer_map)
    )

    return {
        "input_h": input_h,
        "input_w": input_w,
        "input_c": input_c,
        "output_h": output_h,
        "output_w": output_w,
        "output_c": output_c,
        "pool_size_h": kernel_shape[0],
        "pool_size_w": (
            kernel_shape[1] if len(kernel_shape) >= 2 else kernel_shape[0]
        ),
        "stride_h": strides[0],
        "stride_w": strides[1] if len(strides) >= 2 else strides[0],
        # Begin paddings, same convention as conv_params.
        "padding_h": pads[0] if len(pads) >= 2 else 0,
        "padding_w": pads[1] if len(pads) >= 2 else 0,
    }


def _softmax_size(node, graph, initializer_map, output_details):
    """Return the length of the softmax axis, or None if it is unknown.

    The axis defaults to -1 (opset >= 13 semantics).
    NOTE(review): opsets < 13 default to axis=1 and flatten trailing dims;
    that form is not modelled here.
    """
    output_shape = get_tensor_shape(graph, node.output[0], initializer_map)
    if output_shape:
        attrs = _attr_map(node)
        axis = attrs["axis"].i if "axis" in attrs else -1
        if axis < 0:
            axis += len(output_shape)
        if axis < len(output_shape):
            return output_shape[axis]
        return output_shape[-1]
    # Fallback: element count recorded for the first output.
    if output_details:
        return output_details[0].get("size", 10)
    return None


def _svdf_params(node, mapped_inputs, graph, initializer_map):
    """Collect the ``svdf_params`` dict for an SVDF node."""
    attrs = _attr_map(node)
    rank = attrs["rank"].i if "rank" in attrs else 1
    activation_function = (
        attrs["activation_function"].s.decode("utf-8")
        if "activation_function" in attrs else "Tanh"
    )
    input_shape = _input_shape(
        graph, node.input[0], mapped_inputs[0], initializer_map
    )
    output_shape = get_tensor_shape(graph, node.output[0], initializer_map)
    return {
        "rank": rank,
        "activation_function": activation_function,
        "time_steps": input_shape[1] if len(input_shape) >= 3 else 1,
        "input_size": input_shape[2] if len(input_shape) >= 3 else 1,
        "units": output_shape[2] // rank if len(output_shape) >= 3 else 1,
    }


def parse_model_onnx(model_path: str):
    """Parse an ONNX model into the TinyMLC ``model_info`` dictionary.

    QuantizeLinear/DequantizeLinear (QDQ) and Constant nodes are not emitted
    as operators.  Each Q/DQ node's scale/zero-point are attached to the
    tensors it touches.  Operator inputs are rewired to the tensor their Q/DQ
    chain starts from.

    Args:
        model_path: Path of the ``.onnx`` file.

    Returns:
        dict with keys:

        - ``input`` / ``output``: graph I/O descriptions (name, shape, dtype).
        - ``ops``: one dict per remaining node, in graph order.
        - ``weights``: raw initializer arrays keyed by initializer name.
        - ``tensors``: tensor index -> record (name, shape, dtype, size,
          scale, zero_point).
        - ``initializer_map``: initializer name -> array, kept for weight
          extraction.
    """
    model = onnx.load(model_path)
    graph = model.graph

    # Decode every initializer once; reused below for tensor indexing, QDQ
    # parameters and operator-specific lookups.
    initializer_map = {
        init.name: numpy_helper.to_array(init) for init in graph.initializer
    }

    # 1. Global tensor table: name -> index and index -> record.
    tensor_index_map = {}
    tensors = {}

    def register(name, shape, dtype, size):
        idx = len(tensors)
        tensor_index_map[name] = idx
        tensors[idx] = _tensor_entry(name, shape, dtype, size)

    # Weights get the lowest indices.
    for init in graph.initializer:
        array = initializer_map[init.name]
        register(init.name, list(array.shape), str(array.dtype), array.size)

    # Then graph inputs, outputs and intermediate value_info, in that order.
    # Names registered earlier (e.g. initializers listed as inputs) are
    # kept.
    for value_info in (*graph.input, *graph.output, *graph.value_info):
        if value_info.name not in tensor_index_map:
            shape = _value_info_dims(value_info)
            register(value_info.name, shape, "float32", _num_elements(shape))

    # Remaining node outputs (e.g. Q/DQ outputs without value_info).  Every
    # initializer is registered already, so the shape is copied from the
    # node's first input.  That is exact for shape-preserving ops such as
    # Q/DQ and only a guess otherwise.
    # NOTE(review): QuantizeLinear outputs are recorded as int8 even when
    # the zero point is uint8.
    for node in graph.node:
        for out_name in node.output:
            if out_name in tensor_index_map:
                continue
            shape = []
            if node.input and node.input[0] in tensor_index_map:
                first = tensors[tensor_index_map[node.input[0]]]
                shape = first.get("shape", [])
            dtype = "int8" if node.op_type == "QuantizeLinear" else "float32"
            register(out_name, shape, dtype, _num_elements(shape))

    # 2. QDQ handling.  qdq_map sends every Q/DQ output to the tensor its
    # Q/DQ chain starts from (x -> Q -> DQ maps both outputs to x), so
    # compute operators can be wired to that tensor directly.
    qdq_map = {}
    for node in graph.node:
        if node.op_type not in ("QuantizeLinear", "DequantizeLinear"):
            continue
        source = node.input[0]
        target = node.output[0]
        while source in qdq_map:
            source = qdq_map[source]
        qdq_map[target] = source

        # Only the first element is kept, so per-axis (per-channel)
        # parameters collapse to their channel-0 value.  Parameters provided
        # by Constant nodes rather than initializers are not picked up.
        scale_name = node.input[1]
        if scale_name in initializer_map:
            scale_val = float(initializer_map[scale_name].flat[0])
            for name in (source, target):
                if name in tensor_index_map:
                    tensors[tensor_index_map[name]]["scale"] = scale_val
        if len(node.input) >= 3 and node.input[2] in initializer_map:
            zp_val = int(initializer_map[node.input[2]].flat[0])
            for name in (source, target):
                if name in tensor_index_map:
                    tensors[tensor_index_map[name]]["zero_point"] = zp_val

    # 3. Graph inputs and outputs.  Inputs that are initializers (some
    # exporters / older IR versions list weights as inputs) are skipped.
    input_details = [
        {"name": inp.name, "shape": _value_info_dims(inp), "dtype": "float32"}
        for inp in graph.input
        if inp.name not in initializer_map
    ]
    output_details = [
        {"name": out.name, "shape": _value_info_dims(out), "dtype": "float32"}
        for out in graph.output
    ]

    # 4. Operators.  QDQ pseudo-ops and Constant need no code generation.
    ops = []
    skip_ops = {"QuantizeLinear", "DequantizeLinear", "Constant"}

    def details(position, name):
        return _tensor_details(
            position, name, tensors, tensor_index_map, graph, initializer_map
        )

    for node in graph.node:
        op_type = node.op_type
        if op_type in skip_ops:
            continue

        # Replace Q/DQ outputs with the tensor their chain starts from.
        mapped_inputs = [qdq_map.get(name, name) for name in node.input]
        outputs = list(node.output)
        op_info = {
            "index": len(ops),
            "op_name": op_type,
            "inputs": mapped_inputs,
            "outputs": outputs,
            "input_indices": [
                tensor_index_map.get(name, -1) for name in mapped_inputs
            ],
            "output_indices": [
                tensor_index_map.get(name, -1) for name in outputs
            ],
            "state": "created",
            "pass_flags": {},
            "input_details": [
                details(i, name) for i, name in enumerate(mapped_inputs)
            ],
            "output_details": [
                details(i, name) for i, name in enumerate(outputs)
            ],
        }

        if op_type in OP_MAP:
            op_info["op_name"] = OP_MAP[op_type]
            op_info["state"] = "translated"
            op_info["pass_flags"]["onnx_parse"] = "success"
        else:
            warning(f"Unknown ONNX operator: {op_type}")
            op_info["pass_flags"]["unknown"] = "needs_implementation"

        # Gemm -> FULLY_CONNECTED.  MatMul maps to the same IR op but gets
        # no fc_* fields here.
        if op_type == "Gemm":
            has_bias = len(node.input) >= 3
            op_info["data_input_idx"] = op_info["input_indices"][0]
            op_info["fc_weights_idx"] = op_info["input_indices"][1]
            op_info["fc_bias_idx"] = (
                op_info["input_indices"][2] if has_bias else None
            )
            # Q/DQ-resolved names, so that in QDQ models they name the
            # actual initializers (as the *_idx fields above already do).
            op_info["weights_name"] = mapped_inputs[1]
            op_info["bias_name"] = mapped_inputs[2] if has_bias else None
            # Placeholders; fc_scale is updated during weight extraction.
            op_info["fc_scale"] = 0.01
            op_info["fc_output_scale"] = 1.0

        if op_type == "Conv":
            op_info["data_input_idx"] = op_info["input_indices"][0]
            op_info["conv_weights_idx"] = op_info["input_indices"][1]
            op_info["conv_bias_idx"] = (
                op_info["input_indices"][2] if len(node.input) >= 3 else None
            )
            op_info["conv_params"] = _conv_params(
                node, mapped_inputs, graph, initializer_map
            )

        if op_type == "Softmax":
            softmax_size = _softmax_size(
                node, graph, initializer_map, op_info["output_details"]
            )
            if softmax_size is not None:
                op_info["softmax_size"] = softmax_size

        if op_type == "Reshape":
            # The target shape is recorded only when it is an initializer.
            target_shape = initializer_map.get(node.input[1])
            if target_shape is not None:
                op_info["reshape_target_shape"] = target_shape.tolist()

        if op_type in ("MaxPool", "AveragePool"):
            op_info["pool_params"] = _pool_params(
                node, mapped_inputs, graph, initializer_map
            )
            op_info["data_input_idx"] = op_info["input_indices"][0]

        if op_type == "SVDF":
            indices = op_info["input_indices"]
            if len(indices) >= 3:
                op_info["data_input_idx"] = indices[0]
                op_info["svdf_weights_idx"] = indices[1]
                op_info["svdf_bias_idx"] = indices[2]
            op_info["svdf_params"] = _svdf_params(
                node, mapped_inputs, graph, initializer_map
            )

        # NOTE(review): ONNX ``Mean`` is an element-wise mean of several
        # inputs; only the first input's shape is recorded here.
        if op_type in ("Mean", "ReduceMean"):
            input_shape = _input_shape(
                graph, node.input[0], mapped_inputs[0], initializer_map
            )
            output_shape = get_tensor_shape(
                graph, node.output[0], initializer_map
            )
            op_info["data_input_idx"] = op_info["input_indices"][0]
            op_info["mean_params"] = {
                "input_dims": len(input_shape),
                "input_shape": input_shape,
                "output_shape": output_shape,
            }

        ops.append(op_info)

    # Raw initializer arrays keyed by name.  They are decoded separately
    # from initializer_map, so the two dicts never share array objects.
    # extract_all_weights_onnx later adds standardized keys to this dict.
    raw_weights = {
        init.name: numpy_helper.to_array(init) for init in graph.initializer
    }

    return {
        "input": input_details,
        "output": output_details,
        "ops": ops,
        "weights": raw_weights,
        "tensors": tensors,
        "initializer_map": initializer_map,
    }
576
+
577
+
578
def _tensor_name_at(tensors, idx):
    """Return the ``name`` recorded in *tensors* for index *idx*, or None.

    *tensors* is the index -> record table produced by parse_model_onnx.
    Unknown indices (including the -1 used for unresolved inputs) and
    ``idx=None`` yield None.
    """
    if idx is None:
        return None
    return tensors.get(idx, {}).get("name")


def extract_all_weights_onnx(model_path, model_info):
    """Add standardized weight keys to ``model_info['weights']``.

    The raw initializer arrays are aliased under source-specific keys:
    ``fc_onnx.weight``/``fc_onnx.bias``, ``conv_onnx.weight``/
    ``conv_onnx.bias`` and ``svdf_onnx.weight``/``svdf_onnx.bias``.  A key is
    only set when the referenced tensor is present in the weights dict.
    Each key is overwritten per matching op, so with several layers of one
    kind the last one wins.

    Args:
        model_path: ONNX model file path.  Unused; it is kept for
            consistency with the LiteRT interface.
        model_info: Model info dict from parse_model_onnx.  Its ``weights``
            dict is updated in place.

    Returns:
        The updated weights dict (also stored back in
        ``model_info['weights']``).
    """
    weights = model_info.get("weights", {})
    tensors = model_info.get("tensors", {})

    def store(key, name):
        # Alias only tensors that really are stored weights.
        if name and name in weights:
            weights[key] = weights[name]

    for op in model_info.get("ops", []):
        op_name = op.get("op_name")
        input_indices = op.get("input_indices", [])

        if op_name == "FULLY_CONNECTED":
            # Gemm ops carry weights_name/bias_name.  When those are not
            # stored weights (e.g. DequantizeLinear outputs in QDQ models),
            # fall back to the Q/DQ-resolved fc_*_idx tensors.
            # MatMul-derived ops carry neither field.
            if len(input_indices) >= 2:
                name = op.get("weights_name")
                if name not in weights:
                    name = _tensor_name_at(tensors, op.get("fc_weights_idx"))
                store("fc_onnx.weight", name)
            if len(input_indices) >= 3:
                name = op.get("bias_name")
                if name not in weights:
                    name = _tensor_name_at(tensors, op.get("fc_bias_idx"))
                store("fc_onnx.bias", name)

        elif op_name == "CONV_2D":
            if len(input_indices) >= 2:
                store(
                    "conv_onnx.weight",
                    _tensor_name_at(tensors, input_indices[1]),
                )
            if len(input_indices) >= 3:
                store(
                    "conv_onnx.bias",
                    _tensor_name_at(tensors, input_indices[2]),
                )

        elif op_name == "SVDF":
            if len(input_indices) >= 3:
                store(
                    "svdf_onnx.weight",
                    _tensor_name_at(tensors, input_indices[1]),
                )
                store(
                    "svdf_onnx.bias",
                    _tensor_name_at(tensors, input_indices[2]),
                )

    model_info["weights"] = weights
    return weights