tinymlc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. TinyMLC/ANG/__init__.py +0 -0
  2. TinyMLC/ANG/args.py +86 -0
  3. TinyMLC/ANG/estimator.py +103 -0
  4. TinyMLC/ANG/estimator_hal.py +184 -0
  5. TinyMLC/ANG/estimator_qemu.py +257 -0
  6. TinyMLC/ANG/estimator_software.py +130 -0
  7. TinyMLC/ANG/model_builder.py +508 -0
  8. TinyMLC/ANG/model_generator.py +439 -0
  9. TinyMLC/ANG/model_info.py +283 -0
  10. TinyMLC/ANG/utils.py +420 -0
  11. TinyMLC/__init__.py +0 -0
  12. TinyMLC/cli.py +126 -0
  13. TinyMLC/codegen.py +877 -0
  14. TinyMLC/converter/__init__.py +0 -0
  15. TinyMLC/converter/export_weights.py +382 -0
  16. TinyMLC/converter/parser_litert.py +757 -0
  17. TinyMLC/converter/parser_onnx.py +649 -0
  18. TinyMLC/generate_lut.py +97 -0
  19. TinyMLC/handlers.py +325 -0
  20. TinyMLC/ops.py +76 -0
  21. TinyMLC/templates/lut.c.tpl +23 -0
  22. TinyMLC/templates/lut.h.tpl +67 -0
  23. TinyMLC/templates/model.c.tpl +314 -0
  24. TinyMLC/templates/model.h.tpl +66 -0
  25. TinyMLC/transform/__init__.py +0 -0
  26. TinyMLC/transform/algebraic.py +286 -0
  27. TinyMLC/transform/base.py +58 -0
  28. TinyMLC/transform/constant_folding.py +260 -0
  29. TinyMLC/transform/cse.py +192 -0
  30. TinyMLC/transform/dce.py +182 -0
  31. TinyMLC/transform/fusion.py +723 -0
  32. TinyMLC/transform/memory.py +200 -0
  33. TinyMLC/transform/pass_manager.py +101 -0
  34. TinyMLC/transform/simplify.py +515 -0
  35. tinymlc-0.1.0.dist-info/METADATA +49 -0
  36. tinymlc-0.1.0.dist-info/RECORD +47 -0
  37. tinymlc-0.1.0.dist-info/WHEEL +4 -0
  38. tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
  39. tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
  40. utils/__init__.py +0 -0
  41. utils/arm-none-eabi-gcc.cmake +53 -0
  42. utils/dump.py +86 -0
  43. utils/generate_onnx_models.py +183 -0
  44. utils/generate_tflite_models.py +236 -0
  45. utils/pack_macos.sh +88 -0
  46. utils/path.py +31 -0
  47. utils/riscv-none-elf-gcc.cmake +50 -0
@@ -0,0 +1,757 @@
1
+ # -*- coding: utf-8 -*-
2
+ # TinyMLC - Tiny Machine Learning Compiler
3
+ #
4
+ # Copyright (c) 2026 Jia Liu & TinyMLC Contributors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # This file is part of TinyMLC.
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at:
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ #!/usr/bin/env python3
21
+ """LiteRT-based TFLite model parser
22
+
23
+ Note: Uses private API _get_ops_details() as LiteRT has no
24
+ public operator list API.
25
+ LSTM input indices [1:5], [5:9], [9:13] follow TFLite
26
+ UNIDIRECTIONAL_SEQUENCE_LSTM spec.
27
+ """
28
+
29
+ from ai_edge_litert.interpreter import Interpreter
30
+ import numpy as np
31
+
32
+ from utils.dump import fatal_error, info, warning
33
+
34
+
35
+ # Weight extraction functions
36
def extract_fc_weights(interpreter, op_info):
    """Fetch the weight and bias arrays of a FULLY_CONNECTED operator.

    Args:
        interpreter: Object exposing ``get_tensor(index)``.
        op_info: Operator dict; ``fc_weights_idx`` and ``fc_bias_idx`` are
            read from it with ``.get``.

    Returns:
        ``(weights, bias)``. Both are ``None`` when no weight index is
        recorded; ``bias`` alone is ``None`` when no bias index is recorded.
    """
    weights_idx = op_info.get("fc_weights_idx")
    bias_idx = op_info.get("fc_bias_idx")

    # Nothing to extract without a weight tensor index.
    if weights_idx is None:
        return None, None

    try:
        weights = interpreter.get_tensor(weights_idx)
        if bias_idx is None:
            bias = None
        else:
            bias = interpreter.get_tensor(bias_idx)
    except ValueError as e:
        # NOTE(review): fatal_error presumably terminates the program;
        # the code below relies on that - confirm in utils.dump.
        fatal_error(f"Cannot get FC tensor: {e}",
                    "Ensure model is loaded correctly")

    info(f"FC weights: shape={weights.shape}, dtype={weights.dtype}")
    if bias is not None:
        info(f"FC bias: shape={bias.shape}, dtype={bias.dtype}")
    return weights, bias
56
+
57
+
58
def extract_lstm_weights(interpreter, op_info):
    """Extract per-gate LSTM weights and biases through ``get_tensor``.

    The tensor indices are taken from ``op_info["lstm_weight_indices"]``,
    a dict with the list-valued keys ``"input"``, ``"recurrent"`` and
    ``"bias"``. Each list is paired positionally with the gate names
    ``i, f, g, o``; extra entries are ignored and missing ones are simply
    absent from the result.

    Args:
        interpreter: Object exposing ``get_tensor(index)``.
        op_info: Operator dict as described above.

    Returns:
        ``None`` when the input or recurrent index list is missing or empty,
        otherwise ``{"input": {...}, "recurrent": {...}, "bias": {...}}``
        mapping gate name to the fetched tensor. A gate whose tensor cannot
        be read (``get_tensor`` raises ``ValueError``) maps to ``None``.
    """
    # ``or {}`` also covers an explicit None stored under the key.
    indices = op_info.get("lstm_weight_indices") or {}
    if not indices.get("input") or not indices.get("recurrent"):
        return None

    gate_order = ('i', 'f', 'g', 'o')
    lstm_weights = {}
    for group in ('input', 'recurrent', 'bias'):
        per_gate = {}
        for gate, idx in zip(gate_order, indices.get(group) or []):
            try:
                per_gate[gate] = interpreter.get_tensor(idx)
            except ValueError:
                # Best effort: keep the gate slot but mark it unavailable.
                per_gate[gate] = None
        lstm_weights[group] = per_gate

    return lstm_weights
96
+
97
+
98
def extract_conv_weights(interpreter, op_info):
    """Fetch the weight and bias arrays of a CONV_2D operator.

    Args:
        interpreter: Object exposing ``get_tensor(index)``.
        op_info: Operator dict; ``conv_weights_idx`` and ``conv_bias_idx``
            are read from it with ``.get``.

    Returns:
        ``(weights, bias)``. Both are ``None`` when no weight index is
        recorded; ``bias`` alone is ``None`` when no bias index is recorded.
    """
    weights_idx = op_info.get("conv_weights_idx")
    bias_idx = op_info.get("conv_bias_idx")

    # Nothing to extract without a weight tensor index.
    if weights_idx is None:
        return None, None

    try:
        weights = interpreter.get_tensor(weights_idx)
        if bias_idx is None:
            bias = None
        else:
            bias = interpreter.get_tensor(bias_idx)
    except ValueError as e:
        # NOTE(review): fatal_error presumably terminates the program;
        # the code below relies on that - confirm in utils.dump.
        fatal_error(f"Cannot get CONV_2D tensor: {e}",
                    "Ensure model is loaded correctly")

    info(f"CONV_2D weights: shape={weights.shape}, dtype={weights.dtype}")
    if bias is not None:
        info(f"CONV_2D bias: shape={bias.shape}, dtype={bias.dtype}")
    return weights, bias
118
+
119
+
120
def extract_dw_weights(interpreter, op_info):
    """Fetch the weight and bias arrays of a DEPTHWISE_CONV_2D operator.

    Args:
        interpreter: Object exposing ``get_tensor(index)``.
        op_info: Operator dict; ``dw_weights_idx`` and ``dw_bias_idx`` are
            read from it with ``.get``.

    Returns:
        ``(weights, bias)``. Both are ``None`` when no weight index is
        recorded; ``bias`` alone is ``None`` when no bias index is recorded.
    """
    weights_idx = op_info.get("dw_weights_idx")
    bias_idx = op_info.get("dw_bias_idx")

    # Nothing to extract without a weight tensor index.
    if weights_idx is None:
        return None, None

    try:
        weights = interpreter.get_tensor(weights_idx)
        if bias_idx is None:
            bias = None
        else:
            bias = interpreter.get_tensor(bias_idx)
    except ValueError as e:
        # NOTE(review): fatal_error presumably terminates the program;
        # the code below relies on that - confirm in utils.dump.
        fatal_error(f"Cannot get DEPTHWISE_CONV_2D tensor: {e}",
                    "Ensure model is loaded correctly")

    info(f"DEPTHWISE_CONV_2D weights: shape={weights.shape}, "
         f"dtype={weights.dtype}")
    if bias is not None:
        info(f"DEPTHWISE_CONV_2D bias: shape={bias.shape}, dtype={bias.dtype}")
    return weights, bias
141
+
142
+
143
def extract_all_weights_litert(model_path, model_info):
    """Extract all weights from a LiteRT model into ``model_info['weights']``.

    A fresh interpreter is created for ``model_path``. ``model_info['weights']``
    is replaced by a new dict using source-specific keys:
    ``fc_tflite.weight`` / ``fc_tflite.bias``,
    ``lstm_tflite.weight_ih`` / ``lstm_tflite.weight_hh`` / ``lstm_tflite.bias``,
    ``conv_tflite.weight`` / ``conv_tflite.bias`` and
    ``dw_tflite.weight`` / ``dw_tflite.bias``. A key is only written when the
    corresponding tensor was found.

    Args:
        model_path: TFLite model file path.
        model_info: Model info dict from parse_model_tflite; its ``"ops"``
            list is read and its ``"weights"`` entry is overwritten.

    Returns:
        Tuple ``(fc_weights, fc_bias, lstm_weights, conv_weights, conv_bias,
        dw_weights, dw_bias)``; entries are ``None`` when not available.
    """
    # Create interpreter internally
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # NOTE(review): only one operator per kind is kept - a later op of the
    # same kind overwrites an earlier one, so models with several FC/CONV
    # layers export only the last layer's weights under these keys.
    selected = {}
    for op in model_info["ops"]:
        if op["op_name"] in ("FULLY_CONNECTED",
                             "UNIDIRECTIONAL_SEQUENCE_LSTM",
                             "CONV_2D",
                             "DEPTHWISE_CONV_2D"):
            selected[op["op_name"]] = op

    fc_op_info = selected.get("FULLY_CONNECTED")
    lstm_op_info = selected.get("UNIDIRECTIONAL_SEQUENCE_LSTM")
    conv_op_info = selected.get("CONV_2D")
    dw_op_info = selected.get("DEPTHWISE_CONV_2D")

    fc_weights, fc_bias = (extract_fc_weights(interpreter, fc_op_info)
                           if fc_op_info else (None, None))
    lstm_weights = (extract_lstm_weights(interpreter, lstm_op_info)
                    if lstm_op_info else None)
    conv_weights, conv_bias = (extract_conv_weights(interpreter, conv_op_info)
                               if conv_op_info else (None, None))
    dw_weights, dw_bias = (extract_dw_weights(interpreter, dw_op_info)
                           if dw_op_info else (None, None))

    weights = {}
    model_info["weights"] = weights

    # Weight and bias are stored independently (same policy as CONV_2D and
    # DEPTHWISE_CONV_2D below) so a bias-less FULLY_CONNECTED op keeps its
    # weights instead of silently dropping them.
    if fc_weights is not None:
        weights["fc_tflite.weight"] = fc_weights
    if fc_bias is not None:
        weights["fc_tflite.bias"] = fc_bias

    if lstm_weights and lstm_weights['input']:
        gates = ['i', 'f', 'g', 'o']
        lstm_keys = (
            ('input', "lstm_tflite.weight_ih"),
            ('recurrent', "lstm_tflite.weight_hh"),
            ('bias', "lstm_tflite.bias"),
        )
        for group, key in lstm_keys:
            per_gate = lstm_weights[group]
            # A group is exported only when all four gates were read;
            # the gates are flattened and concatenated in i, f, g, o order.
            if all(per_gate.get(g) is not None for g in gates):
                weights[key] = np.concatenate(
                    [per_gate[g].flatten() for g in gates])

    if conv_weights is not None:
        weights["conv_tflite.weight"] = conv_weights
    if conv_bias is not None:
        weights["conv_tflite.bias"] = conv_bias

    if dw_weights is not None:
        weights["dw_tflite.weight"] = dw_weights
    if dw_bias is not None:
        weights["dw_tflite.bias"] = dw_bias

    return (fc_weights, fc_bias, lstm_weights, conv_weights,
            conv_bias, dw_weights, dw_bias)
213
+
214
+
215
def _tensor_shape(tensors, idx):
    """Return the recorded shape of tensor ``idx`` ([] when it is unknown)."""
    return tensors.get(idx, {}).get("shape", [])


def _spatial_dims(shape, prefix):
    """Map positions 1, 2, 3 of a rank>=4 shape to ``{prefix}_h/_w/_c``.

    Shapes with fewer than four dimensions yield zeros.
    """
    if len(shape) >= 4:
        h, w, c = shape[1], shape[2], shape[3]
    else:
        h = w = c = 0
    return {f"{prefix}_h": h, f"{prefix}_w": w, f"{prefix}_c": c}


def _kernel_dims(weights_shape):
    """Return ``kernel_h``/``kernel_w`` from positions 1, 2 of the weights."""
    if len(weights_shape) >= 4:
        return {"kernel_h": weights_shape[1], "kernel_w": weights_shape[2]}
    return {"kernel_h": 0, "kernel_w": 0}


def _tensor_detail(tensors, idx):
    """Build the per-tensor summary stored in input/output details."""
    tensor_info = tensors.get(idx, {})
    return {
        "index": idx,
        "name": tensor_info.get("name", "unknown"),
        "shape": tensor_info.get("shape", []),
        "size": tensor_info.get("size", 0),
        "scale": tensor_info.get("scale", 1.0),
        "zero_point": tensor_info.get("zero_point", 0),
    }


def _mark_translated(op_info, flag):
    """Flag the operator as translated and record its check as passed."""
    op_info["state"] = "translated"
    op_info["pass_flags"][flag] = "success"


def _require_io(op_info, label, flag):
    """Validate that an op has at least one input and one output."""
    if len(op_info["input_indices"]) < 1:
        fatal_error(f"{label} missing input", "Check model format")
    if len(op_info["output_indices"]) < 1:
        fatal_error(f"{label} missing output", "Check model format")
    _mark_translated(op_info, flag)


def _translate_add(op_info, tensors):
    """Record the two operand tensor indices of ADD."""
    inputs = op_info["input_indices"]
    if len(inputs) >= 2:
        op_info["add_input1_idx"] = inputs[0]
        op_info["add_input2_idx"] = inputs[1]
    _mark_translated(op_info, "add_check")


def _translate_fully_connected(op_info, tensors):
    """Record data/weights/bias indices of FULLY_CONNECTED (bias optional)."""
    inputs = op_info["input_indices"]
    if not inputs:
        # Report through the same channel as the other operators rather
        # than failing with a bare IndexError below.
        fatal_error("FULLY_CONNECTED missing input", "Check model format")
    op_info["data_input_idx"] = inputs[0]
    op_info["fc_weights_idx"] = inputs[1] if len(inputs) >= 2 else None
    op_info["fc_bias_idx"] = inputs[2] if len(inputs) >= 3 else None
    _mark_translated(op_info, "fc_check")


def _translate_softmax(op_info, tensors):
    """SOFTMAX only needs an input and an output."""
    _require_io(op_info, "SOFTMAX", "softmax_check")


def _translate_relu(op_info, tensors):
    """RELU only needs an input and an output, no extra parameters."""
    _require_io(op_info, "RELU", "relu_check")


def _translate_reshape(op_info, tensors):
    """Resolve the RESHAPE target shape, expanding a single -1 dimension."""
    inputs = op_info["input_indices"]
    if len(inputs) < 2:
        fatal_error(
            "RESHAPE missing target shape parameter",
            "Check model format"
        )

    shape_tensor = tensors.get(inputs[1], {})

    # Prefer reading the actual value from the tensor data.
    if "data" in shape_tensor:
        target_shape = [int(s) for s in shape_tensor["data"].flatten()]
    else:
        # NOTE(review): this fallback is the shape *of* the shape tensor,
        # not the requested target shape - confirm it is intended.
        target_shape = shape_tensor.get("shape", [])

    if not target_shape:
        fatal_error(
            "RESHAPE cannot extract target shape",
            "Check model format"
        )

    # Handle dynamic dimension -1: derive it from the input element count.
    if -1 in target_shape:
        input_size = tensors.get(inputs[0], {}).get("size", 1)
        other_dims = 1
        for s in target_shape:
            if s != -1:
                other_dims *= s
        if other_dims > 0:
            dynamic_size = input_size // other_dims
            target_shape = [dynamic_size if s == -1 else s
                            for s in target_shape]

    op_info["reshape_target_shape"] = [int(s) for s in target_shape]
    _mark_translated(op_info, "reshape_check")


def _translate_lstm(op_info, tensors):
    """Record UNIDIRECTIONAL_SEQUENCE_LSTM weight indices and dimensions.

    Positions are taken from the input list *after* -1 entries were
    removed:
      [0]    input data
      [1-4]  input gate weights (i, f, g, o)
      [5-8]  recurrent weights (i, f, g, o)
      [9-12] biases (i, f, g, o)
      [13+]  other parameters
    """
    inputs = op_info["input_indices"]
    if len(inputs) < 13:
        fatal_error("LSTM input incomplete", "Check model format")
        return

    op_info["lstm_weight_indices"] = {
        "input": inputs[1:5],
        "recurrent": inputs[5:9],
        "bias": inputs[9:13],
    }

    # hidden_size is the third dimension of the output shape.
    output_shape = _tensor_shape(tensors, op_info["output_indices"][0])
    if len(output_shape) >= 3:
        hidden_size = output_shape[2]
    else:
        fatal_error(
            "Cannot extract hidden_size from LSTM output shape",
            "Check model format"
        )

    # Input shape is read as [batch, time_steps, input_size].
    input_shape = _tensor_shape(tensors, inputs[0])
    if len(input_shape) >= 3:
        op_info["lstm_params"] = {
            "time_steps": input_shape[1],
            "batch_size": input_shape[0],
            "input_size": input_shape[2],
            "hidden_size": hidden_size,
        }
    else:
        fatal_error(
            "Cannot extract parameters from LSTM input shape",
            "Check model format"
        )

    _mark_translated(op_info, "lstm_check")


def _translate_svdf(op_info, tensors):
    """Record data/weights/bias indices of SVDF."""
    inputs = op_info["input_indices"]
    if len(inputs) >= 3:
        op_info["data_input_idx"] = inputs[0]
        op_info["svdf_weights_idx"] = inputs[1]
        op_info["svdf_bias_idx"] = inputs[2]
    _mark_translated(op_info, "svdf_check")


def _translate_conv2d(op_info, tensors):
    """Record CONV_2D tensor indices and shape-derived parameters.

    Stride and padding are heuristics inferred from the input, output and
    weight shapes; the operator's own options are not read here.
    """
    inputs = op_info["input_indices"]
    if len(inputs) >= 3:
        op_info["data_input_idx"] = inputs[0]
        op_info["conv_weights_idx"] = inputs[1]
        op_info["conv_bias_idx"] = inputs[2]
    else:
        fatal_error("CONV_2D input incomplete", "Check model format")

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _tensor_shape(tensors, op_info["output_indices"][0])
    # Weight shape: [out_channels, kernel_h, kernel_w, in_channels]
    weights_shape = _tensor_shape(tensors, inputs[1])

    # Default parameters
    stride_h = 1
    stride_w = 1
    padding = "VALID"

    if (len(input_shape) >= 4 and len(output_shape) >= 4
            and len(weights_shape) >= 4):
        input_h, input_w = input_shape[1], input_shape[2]
        output_h, output_w = output_shape[1], output_shape[2]
        kernel_h, kernel_w = weights_shape[1], weights_shape[2]

        if input_h > output_h:
            # The quotient can be 0 or negative (e.g. kernel as large as
            # the input); clamp so the stride is always usable and the
            # padding check below never divides by zero.
            if output_h > 1:
                stride_h = max(1, (input_h - kernel_h) // (output_h - 1))
            if output_w > 1:
                stride_w = max(1, (input_w - kernel_w) // (output_w - 1))

        # If output size == ceil(input / stride), assume SAME padding,
        # otherwise VALID.
        expected_h = (input_h + stride_h - 1) // stride_h
        padding = "SAME" if output_h == expected_h else "VALID"

    op_info["conv_params"] = {
        **_spatial_dims(input_shape, "input"),
        **_spatial_dims(output_shape, "output"),
        **_kernel_dims(weights_shape),
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding": padding,
    }
    _mark_translated(op_info, "conv_check")


def _translate_max_pool(op_info, tensors):
    """Record MAX_POOL_2D parameters; pool size is fixed at 2x2 here."""
    inputs = op_info["input_indices"]
    if len(inputs) < 1:
        fatal_error("MAX_POOL_2D missing input", "Check model format")

    op_info["data_input_idx"] = inputs[0]
    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _tensor_shape(tensors, op_info["output_indices"][0])

    # Default parameters
    stride_h = 2
    stride_w = 2

    if len(input_shape) >= 4 and len(output_shape) >= 4:
        input_h, input_w = input_shape[1], input_shape[2]
        output_h, output_w = output_shape[1], output_shape[2]
        # Infer stride from the input/output ratio.
        if input_h > output_h:
            stride_h = input_h // output_h if output_h > 0 else 1
            stride_w = input_w // output_w if output_w > 0 else 1

    op_info["pool_params"] = {
        **_spatial_dims(input_shape, "input"),
        **_spatial_dims(output_shape, "output"),
        "pool_size_h": 2,
        "pool_size_w": 2,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding": "VALID",
    }
    _mark_translated(op_info, "pool_check")


def _translate_depthwise_conv(op_info, tensors):
    """Record DEPTHWISE_CONV_2D tensor indices and shape-derived parameters.

    Stride is recorded as 1 and padding as 0 unconditionally.
    """
    inputs = op_info["input_indices"]
    if len(inputs) < 3:
        fatal_error(
            "DEPTHWISE_CONV_2D input incomplete",
            "Check model format"
        )

    op_info["data_input_idx"] = inputs[0]
    op_info["dw_weights_idx"] = inputs[1]
    op_info["dw_bias_idx"] = inputs[2]

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _tensor_shape(tensors, op_info["output_indices"][0])
    weights_shape = _tensor_shape(tensors, inputs[1])

    if len(input_shape) >= 4 and len(weights_shape) >= 4:
        depth_multiplier = weights_shape[3] // input_shape[3]
    else:
        depth_multiplier = 1

    op_info["dw_params"] = {
        **_spatial_dims(input_shape, "input"),
        **_spatial_dims(output_shape, "output"),
        **_kernel_dims(weights_shape),
        "depth_multiplier": depth_multiplier,
        "stride_h": 1,
        "stride_w": 1,
        "padding_h": 0,
        "padding_w": 0,
    }
    _mark_translated(op_info, "dw_check")


def _translate_avg_pool(op_info, tensors):
    """Record AVERAGE_POOL_2D parameters; pool size is fixed at 2x2 here."""
    inputs = op_info["input_indices"]
    if len(inputs) < 1:
        fatal_error(
            "AVERAGE_POOL_2D missing input",
            "Check model format"
        )

    op_info["data_input_idx"] = inputs[0]
    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _tensor_shape(tensors, op_info["output_indices"][0])

    stride_h = 2
    stride_w = 2

    if len(input_shape) >= 4 and len(output_shape) >= 4:
        input_h, input_w = input_shape[1], input_shape[2]
        output_h, output_w = output_shape[1], output_shape[2]
        if input_h > output_h and output_h > 0:
            stride_h = input_h // output_h
            stride_w = input_w // output_w

    op_info["pool_params"] = {
        **_spatial_dims(input_shape, "input"),
        **_spatial_dims(output_shape, "output"),
        "pool_h": 2,
        "pool_w": 2,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding": "VALID",
    }
    _mark_translated(op_info, "avg_pool_check")


def _translate_transpose(op_info, tensors):
    """Record TRANSPOSE data/perm indices and input/output ranks."""
    inputs = op_info["input_indices"]
    if len(inputs) < 2:
        fatal_error(
            "TRANSPOSE missing perm parameter",
            "Check model format"
        )

    # inputs[0] = input data, inputs[1] = perm (transpose order)
    op_info["data_input_idx"] = inputs[0]
    op_info["transpose_perm_idx"] = inputs[1]

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _tensor_shape(tensors, op_info["output_indices"][0])
    op_info["transpose_params"] = {
        "input_dims": len(input_shape),
        "output_dims": len(output_shape),
    }
    _mark_translated(op_info, "transpose_check")


def _translate_quantize(op_info, tensors):
    """QUANTIZE needs no extra parameters."""
    _mark_translated(op_info, "quantize_check")


def _translate_pad(op_info, tensors):
    """Record PAD data and paddings tensor indices."""
    inputs = op_info["input_indices"]
    if len(inputs) < 2:
        fatal_error(
            "PAD missing padding parameter",
            "Check model format"
        )

    op_info["data_input_idx"] = inputs[0]
    op_info["pad_paddings_idx"] = inputs[1]
    _mark_translated(op_info, "pad_check")


def _translate_mean(op_info, tensors):
    """Record MEAN data/axis indices and input/output ranks."""
    inputs = op_info["input_indices"]
    if len(inputs) < 1:
        fatal_error("MEAN missing input", "Check model format")

    op_info["data_input_idx"] = inputs[0]
    # Record axis parameter if present
    if len(inputs) >= 2:
        op_info["mean_axis_idx"] = inputs[1]

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _tensor_shape(tensors, op_info["output_indices"][0])
    op_info["mean_params"] = {
        "input_dims": len(input_shape),
        "output_dims": len(output_shape),
    }
    _mark_translated(op_info, "mean_check")


# Operator name -> translator. Each translator takes (op_info, tensors)
# and fills operator-specific keys into op_info in place.
_OP_TRANSLATORS = {
    "ADD": _translate_add,
    "FULLY_CONNECTED": _translate_fully_connected,
    "SOFTMAX": _translate_softmax,
    "RESHAPE": _translate_reshape,
    "UNIDIRECTIONAL_SEQUENCE_LSTM": _translate_lstm,
    "SVDF": _translate_svdf,
    "CONV_2D": _translate_conv2d,
    "MAX_POOL_2D": _translate_max_pool,
    "DEPTHWISE_CONV_2D": _translate_depthwise_conv,
    "RELU": _translate_relu,
    "AVERAGE_POOL_2D": _translate_avg_pool,
    "TRANSPOSE": _translate_transpose,
    "QUANTIZE": _translate_quantize,
    "PAD": _translate_pad,
    "MEAN": _translate_mean,
}


def parse_model_tflite(model_path: str) -> dict:
    """Parse a TFLite model using the LiteRT interpreter.

    Args:
        model_path: TFLite model file path.

    Returns:
        Dict with the keys:
          - ``"input"`` / ``"output"``: interpreter input/output details.
          - ``"ops"``: one dict per operator (DELEGATE ops are skipped) with
            tensor indices, ``state``, ``pass_flags``, operator-specific
            parameters and per-tensor input/output details.
          - ``"tensors"``: tensor index -> name, shape, dtype, size, scale,
            zero_point and, when readable, ``data``.
          - ``"weights"`` / ``"quant_scales"``: empty dicts, kept for a
            unified interface.
    """
    # 1. Load model (using LiteRT Interpreter)
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # 2. Get input/output tensors
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # 3. Collect all tensor info, keyed by tensor index
    tensors = {}
    for tensor in interpreter.get_tensor_details():
        shape = tensor["shape"]
        scale = tensor["quantization"][0]
        zero_point = tensor["quantization"][1]
        has_dims = shape is not None and len(shape) > 0
        tensor_info = {
            "name": tensor["name"],
            "shape": list(tensor["shape"]),
            "dtype": str(tensor["dtype"]),
            "size": int(np.prod(shape)) if has_dims else 1,
            "scale": scale if scale is not None else 1.0,
            "zero_point": zero_point if zero_point is not None else 0,
        }

        # Best effort: read tensor data (needed e.g. for the RESHAPE target
        # shape). Tensors that cannot be read are left without "data".
        # ``Exception`` rather than a bare except so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            tensor_info["data"] = interpreter.get_tensor(tensor["index"])
        except Exception:
            pass

        tensors[tensor["index"]] = tensor_info

    # 4. Build the operator list (private API: no public operator listing)
    ops = []
    for op in interpreter._get_ops_details():
        op_name = op["op_name"]
        if op_name == "DELEGATE":
            continue

        # -1 marks an unused optional tensor slot; drop those entries.
        input_indices = [inp for inp in op["inputs"] if inp != -1]
        output_indices = [out for out in op["outputs"] if out != -1]

        op_info = {
            "index": op["index"],
            "op_name": op_name,
            # Separate list objects so the two views can diverge later.
            "inputs": list(input_indices),
            "outputs": list(output_indices),
            "input_indices": input_indices,
            "output_indices": output_indices,
            "state": "created",
            "pass_flags": {},
            "input_details": [],
            "output_details": [],
        }

        translate = _OP_TRANSLATORS.get(op_name)
        if translate is None:
            # Unknown operator, keep created state
            warning(
                "Unknown operator encountered, "
                "please submit an issue or patch set"
            )
            op_info["state"] = "created"
            op_info["pass_flags"]["unknown"] = "needs_implementation"
        else:
            translate(op_info, tensors)

        op_info["input_details"].extend(
            _tensor_detail(tensors, idx) for idx in op_info["input_indices"])
        op_info["output_details"].extend(
            _tensor_detail(tensors, idx) for idx in op_info["output_indices"])

        ops.append(op_info)

    return {
        "input": input_details,
        "output": output_details,
        "ops": ops,
        "tensors": tensors,
        # LiteRT uses separate weight extraction; kept for unified interface
        "weights": {},
        # Quantization scales (LiteRT stores them per tensor in "tensors")
        "quant_scales": {},
    }