tinymlc 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- TinyMLC/ANG/__init__.py +0 -0
- TinyMLC/ANG/args.py +86 -0
- TinyMLC/ANG/estimator.py +103 -0
- TinyMLC/ANG/estimator_hal.py +184 -0
- TinyMLC/ANG/estimator_qemu.py +257 -0
- TinyMLC/ANG/estimator_software.py +130 -0
- TinyMLC/ANG/model_builder.py +508 -0
- TinyMLC/ANG/model_generator.py +439 -0
- TinyMLC/ANG/model_info.py +283 -0
- TinyMLC/ANG/utils.py +420 -0
- TinyMLC/__init__.py +0 -0
- TinyMLC/cli.py +126 -0
- TinyMLC/codegen.py +877 -0
- TinyMLC/converter/__init__.py +0 -0
- TinyMLC/converter/export_weights.py +382 -0
- TinyMLC/converter/parser_litert.py +757 -0
- TinyMLC/converter/parser_onnx.py +649 -0
- TinyMLC/generate_lut.py +97 -0
- TinyMLC/handlers.py +325 -0
- TinyMLC/ops.py +76 -0
- TinyMLC/templates/lut.c.tpl +23 -0
- TinyMLC/templates/lut.h.tpl +67 -0
- TinyMLC/templates/model.c.tpl +314 -0
- TinyMLC/templates/model.h.tpl +66 -0
- TinyMLC/transform/__init__.py +0 -0
- TinyMLC/transform/algebraic.py +286 -0
- TinyMLC/transform/base.py +58 -0
- TinyMLC/transform/constant_folding.py +260 -0
- TinyMLC/transform/cse.py +192 -0
- TinyMLC/transform/dce.py +182 -0
- TinyMLC/transform/fusion.py +723 -0
- TinyMLC/transform/memory.py +200 -0
- TinyMLC/transform/pass_manager.py +101 -0
- TinyMLC/transform/simplify.py +515 -0
- tinymlc-0.1.0.dist-info/METADATA +49 -0
- tinymlc-0.1.0.dist-info/RECORD +47 -0
- tinymlc-0.1.0.dist-info/WHEEL +4 -0
- tinymlc-0.1.0.dist-info/entry_points.txt +2 -0
- tinymlc-0.1.0.dist-info/licenses/LICENSE +201 -0
- utils/__init__.py +0 -0
- utils/arm-none-eabi-gcc.cmake +53 -0
- utils/dump.py +86 -0
- utils/generate_onnx_models.py +183 -0
- utils/generate_tflite_models.py +236 -0
- utils/pack_macos.sh +88 -0
- utils/path.py +31 -0
- utils/riscv-none-elf-gcc.cmake +50 -0
|
@@ -0,0 +1,757 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# TinyMLC - Tiny Machine Learning Compiler
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2026 Jia Liu & TinyMLC Contributors
|
|
5
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
6
|
+
#
|
|
7
|
+
# This file is part of TinyMLC.
|
|
8
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
9
|
+
# you may not use this file except in compliance with the License.
|
|
10
|
+
# You may obtain a copy of the License at:
|
|
11
|
+
#
|
|
12
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
13
|
+
#
|
|
14
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
15
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
16
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
17
|
+
# See the License for the specific language governing permissions and
|
|
18
|
+
# limitations under the License.
|
|
19
|
+
|
|
20
|
+
#!/usr/bin/env python3
|
|
21
|
+
"""LiteRT-based TFLite model parser
|
|
22
|
+
|
|
23
|
+
Note: Uses private API _get_ops_details() as LiteRT has no
|
|
24
|
+
public operator list API.
|
|
25
|
+
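LSTM gate tensors are located by their position in the TFLite
UNIDIRECTIONAL_SEQUENCE_LSTM input list: slots 1-4 hold the input
weights, 5-8 the recurrent weights, 9-11 the optional peephole weights
and 12-15 the gate biases. For a peephole-free LSTM whose absent (-1)
optional slots are dropped, the biases therefore sit at [9:13].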
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from ai_edge_litert.interpreter import Interpreter
|
|
30
|
+
import numpy as np
|
|
31
|
+
|
|
32
|
+
from utils.dump import fatal_error, info, warning
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Weight extraction functions
|
|
36
|
+
def extract_fc_weights(interpreter, op_info):
    """Read the FULLY_CONNECTED weight and bias tensors through LiteRT.

    Args:
        interpreter: LiteRT interpreter whose tensors are allocated.
        op_info: Operator dict from ``parse_model_tflite``; the tensor
            indices come from its ``fc_weights_idx`` / ``fc_bias_idx``.

    Returns:
        ``(weights, bias)`` as returned by ``interpreter.get_tensor``.
        ``bias`` is ``None`` when no bias index is recorded; both are
        ``None`` when no weight index is recorded.
    """
    weight_index = op_info.get("fc_weights_idx")
    bias_index = op_info.get("fc_bias_idx")

    # Nothing to read without a weight tensor.
    if weight_index is None:
        return None, None

    try:
        weights = interpreter.get_tensor(weight_index)
        if bias_index is None:
            bias = None
        else:
            bias = interpreter.get_tensor(bias_index)
    except ValueError as err:
        # fatal_error is defined elsewhere; the lines below presume it
        # does not return.
        fatal_error(f"Cannot get FC tensor: {err}",
                    "Ensure model is loaded correctly")

    info(f"FC weights: shape={weights.shape}, dtype={weights.dtype}")
    if bias is not None:
        info(f"FC bias: shape={bias.shape}, dtype={bias.dtype}")
    return weights, bias
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def extract_lstm_weights(interpreter, op_info):
    """Read per-gate LSTM weight and bias tensors through LiteRT.

    The tensor indices come from ``op_info["lstm_weight_indices"]``, a dict
    with ``"input"``, ``"recurrent"`` and ``"bias"`` index lists prepared by
    the parser. Each list is paired, in order, with the gate names
    ``i``, ``f``, ``g``, ``o``.

    Args:
        interpreter: LiteRT interpreter whose tensors are allocated.
        op_info: Operator dict from ``parse_model_tflite``.

    Returns:
        ``None`` when the input or recurrent index list is missing/empty.
        Otherwise ``{"input": {...}, "recurrent": {...}, "bias": {...}}``
        mapping gate name -> tensor. A gate whose tensor read raises
        ``ValueError`` maps to ``None``.
    """
    index_groups = op_info.get("lstm_weight_indices", {})
    groups = {
        name: index_groups.get(name, [])
        for name in ("input", "recurrent", "bias")
    }

    if not (groups["input"] and groups["recurrent"]):
        return None

    def _read_or_none(tensor_idx):
        # Unreadable tensors are recorded as None rather than aborting.
        try:
            return interpreter.get_tensor(tensor_idx)
        except ValueError:
            return None

    gate_names = ("i", "f", "g", "o")
    return {
        name: {gate: _read_or_none(idx)
               for gate, idx in zip(gate_names, idx_list)}
        for name, idx_list in groups.items()
    }
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def extract_conv_weights(interpreter, op_info):
    """Read the CONV_2D filter and bias tensors through LiteRT.

    Args:
        interpreter: LiteRT interpreter whose tensors are allocated.
        op_info: Operator dict from ``parse_model_tflite``; the tensor
            indices come from ``conv_weights_idx`` / ``conv_bias_idx``.

    Returns:
        ``(weights, bias)``; ``bias`` is ``None`` without a bias index and
        both are ``None`` without a weight index.
    """
    filter_index = op_info.get("conv_weights_idx")
    bias_index = op_info.get("conv_bias_idx")

    if filter_index is None:
        return None, None

    try:
        weights = interpreter.get_tensor(filter_index)
        if bias_index is None:
            bias = None
        else:
            bias = interpreter.get_tensor(bias_index)
    except ValueError as err:
        # fatal_error is defined elsewhere; the lines below presume it
        # does not return.
        fatal_error(f"Cannot get CONV_2D tensor: {err}",
                    "Ensure model is loaded correctly")

    info(f"CONV_2D weights: shape={weights.shape}, dtype={weights.dtype}")
    if bias is not None:
        info(f"CONV_2D bias: shape={bias.shape}, dtype={bias.dtype}")
    return weights, bias
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def extract_dw_weights(interpreter, op_info):
    """Read the DEPTHWISE_CONV_2D filter and bias tensors through LiteRT.

    Args:
        interpreter: LiteRT interpreter whose tensors are allocated.
        op_info: Operator dict from ``parse_model_tflite``; the tensor
            indices come from ``dw_weights_idx`` / ``dw_bias_idx``.

    Returns:
        ``(weights, bias)``; ``bias`` is ``None`` without a bias index and
        both are ``None`` without a weight index.
    """
    filter_index = op_info.get("dw_weights_idx")
    bias_index = op_info.get("dw_bias_idx")

    if filter_index is None:
        return None, None

    try:
        weights = interpreter.get_tensor(filter_index)
        if bias_index is None:
            bias = None
        else:
            bias = interpreter.get_tensor(bias_index)
    except ValueError as err:
        # fatal_error is defined elsewhere; the lines below presume it
        # does not return.
        fatal_error(f"Cannot get DEPTHWISE_CONV_2D tensor: {err}",
                    "Ensure model is loaded correctly")

    info(f"DEPTHWISE_CONV_2D weights: shape={weights.shape}, "
         f"dtype={weights.dtype}")
    if bias is not None:
        info(f"DEPTHWISE_CONV_2D bias: shape={bias.shape}, dtype={bias.dtype}")
    return weights, bias
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def extract_all_weights_litert(model_path, model_info):
    """Extract supported layer weights into ``model_info["weights"]``.

    A fresh LiteRT interpreter is created for ``model_path``. For each of
    FULLY_CONNECTED, UNIDIRECTIONAL_SEQUENCE_LSTM, CONV_2D and
    DEPTHWISE_CONV_2D, the *last* operator of that type in
    ``model_info["ops"]`` is used. Its tensors are stored under
    source-specific keys: ``fc_tflite.weight``/``fc_tflite.bias``,
    ``lstm_tflite.weight_ih``/``lstm_tflite.weight_hh``/``lstm_tflite.bias``
    (per-gate tensors flattened and concatenated in i, f, g, o order),
    ``conv_tflite.weight``/``conv_tflite.bias`` and
    ``dw_tflite.weight``/``dw_tflite.bias``. A weight or bias is stored
    only when it exists. Any previous ``model_info["weights"]`` is replaced.

    Args:
        model_path: TFLite model file path.
        model_info: Model info dict from ``parse_model_tflite``.

    Returns:
        Tuple ``(fc_weights, fc_bias, lstm_weights, conv_weights,
        conv_bias, dw_weights, dw_bias)``; entries are ``None`` for layer
        types that are absent.
    """
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Weights are exported under fixed keys, so only one op per type can be
    # represented; the last occurrence wins. Warn instead of dropping the
    # others silently.
    supported = ("FULLY_CONNECTED", "UNIDIRECTIONAL_SEQUENCE_LSTM",
                 "CONV_2D", "DEPTHWISE_CONV_2D")
    selected = {}
    counts = {}
    for op in model_info["ops"]:
        op_name = op["op_name"]
        if op_name in supported:
            selected[op_name] = op
            counts[op_name] = counts.get(op_name, 0) + 1
    for op_name, count in counts.items():
        if count > 1:
            warning(f"Model has {count} {op_name} operators; only the "
                    f"weights of the last one are extracted")

    fc_op_info = selected.get("FULLY_CONNECTED")
    lstm_op_info = selected.get("UNIDIRECTIONAL_SEQUENCE_LSTM")
    conv_op_info = selected.get("CONV_2D")
    dw_op_info = selected.get("DEPTHWISE_CONV_2D")

    fc_weights, fc_bias = (extract_fc_weights(interpreter, fc_op_info)
                           if fc_op_info else (None, None))
    lstm_weights = (extract_lstm_weights(interpreter, lstm_op_info)
                    if lstm_op_info else None)
    conv_weights, conv_bias = (extract_conv_weights(interpreter, conv_op_info)
                               if conv_op_info else (None, None))
    dw_weights, dw_bias = (extract_dw_weights(interpreter, dw_op_info)
                           if dw_op_info else (None, None))

    weights = {}
    model_info["weights"] = weights

    # FC weights and bias are stored independently, like CONV_2D and
    # DEPTHWISE_CONV_2D, so a bias-less FC layer no longer loses its weights.
    if fc_weights is not None:
        weights["fc_tflite.weight"] = fc_weights
    if fc_bias is not None:
        weights["fc_tflite.bias"] = fc_bias

    if lstm_weights and lstm_weights["input"]:
        gates = ("i", "f", "g", "o")
        for group, key in (("input", "lstm_tflite.weight_ih"),
                           ("recurrent", "lstm_tflite.weight_hh"),
                           ("bias", "lstm_tflite.bias")):
            per_gate = [lstm_weights[group].get(g) for g in gates]
            # Only emit a fused tensor when every gate was readable.
            if all(t is not None for t in per_gate):
                weights[key] = np.concatenate([t.flatten() for t in per_gate])

    if conv_weights is not None:
        weights["conv_tflite.weight"] = conv_weights
    if conv_bias is not None:
        weights["conv_tflite.bias"] = conv_bias

    if dw_weights is not None:
        weights["dw_tflite.weight"] = dw_weights
    if dw_bias is not None:
        weights["dw_tflite.bias"] = dw_bias

    return fc_weights, fc_bias, lstm_weights, conv_weights, \
        conv_bias, dw_weights, dw_bias
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _tensor_shape(tensors, tensor_idx):
    """Return the recorded shape of ``tensor_idx`` (``[]`` if unknown)."""
    return tensors.get(tensor_idx, {}).get("shape", [])


def _first_output_shape(op_info, tensors):
    """Return the shape of the operator's first output (``[]`` if none)."""
    outputs = op_info["output_indices"]
    return _tensor_shape(tensors, outputs[0]) if outputs else []


def _nhwc(shape):
    """Return ``(shape[1], shape[2], shape[3])`` for rank >= 4, else zeros."""
    if len(shape) >= 4:
        return shape[1], shape[2], shape[3]
    return 0, 0, 0


def _mark_translated(op_info, flag):
    """Mark ``op_info`` as translated and record ``flag`` as successful."""
    op_info["state"] = "translated"
    op_info["pass_flags"][flag] = "success"


def _require_io(op_info, label):
    """Report a fatal error when the op has no input or no output."""
    if len(op_info["input_indices"]) < 1:
        fatal_error(f"{label} missing input", "Check model format")
    if len(op_info["output_indices"]) < 1:
        fatal_error(f"{label} missing output", "Check model format")


def _collect_tensors(interpreter):
    """Build the tensor-index -> info dict from the interpreter's details."""
    tensors = {}
    for tensor in interpreter.get_tensor_details():
        shape = tensor["shape"]
        quantization = tensor["quantization"]
        tensor_info = {
            "name": tensor["name"],
            "shape": list(shape),
            "dtype": str(tensor["dtype"]),
            "size": (int(np.prod(shape))
                     if shape is not None and len(shape) > 0 else 1),
            "scale": (quantization[0]
                      if quantization[0] is not None else 1.0),
            "zero_point": (quantization[1]
                           if quantization[1] is not None else 0),
        }
        # Keep tensor contents when readable (e.g. RESHAPE's target-shape
        # tensor). Tensors that get_tensor refuses with ValueError simply get
        # no "data" entry; other errors are no longer silently swallowed.
        try:
            tensor_info["data"] = interpreter.get_tensor(tensor["index"])
        except ValueError:
            pass
        tensors[tensor["index"]] = tensor_info
    return tensors


def _match_stride(in_dim, out_dim, kernel, padding):
    """Return the smallest stride mapping ``in_dim`` to ``out_dim``, or None.

    Uses the TFLite output-size rules: SAME -> ceil(in / s);
    VALID -> (in - k) // s + 1. Dilation is not modelled.
    """
    for stride in range(1, max(int(in_dim), 1) + 1):
        if padding == "SAME":
            produced = -(-in_dim // stride)
        elif in_dim >= kernel:
            produced = (in_dim - kernel) // stride + 1
        else:
            produced = 0
        if produced == out_dim:
            return stride
    return None


def _infer_stride_and_padding(in_h, in_w, out_h, out_w, k_h, k_w):
    """Infer ``(stride_h, stride_w, padding)`` of a 2-D convolution.

    The real values live in the op's builtin options, which the LiteRT
    Python API does not expose, so they are reconstructed from shapes.
    Both paddings are tried. The one needing the smaller total stride
    wins, and SAME wins ties. Returns ``(1, 1, "VALID")`` when nothing
    matches.
    """
    best = None
    for padding in ("SAME", "VALID"):
        stride_h = _match_stride(in_h, out_h, k_h, padding)
        stride_w = _match_stride(in_w, out_w, k_w, padding)
        if stride_h is None or stride_w is None:
            continue
        if best is None or stride_h + stride_w < best[0] + best[1]:
            best = (stride_h, stride_w, padding)
    return best if best is not None else (1, 1, "VALID")


def _infer_pool_strides(input_shape, output_shape, default=2):
    """Estimate pooling strides as input/output spatial size ratios."""
    stride_h = stride_w = default
    if len(input_shape) >= 4 and len(output_shape) >= 4:
        in_h, in_w = input_shape[1], input_shape[2]
        out_h, out_w = output_shape[1], output_shape[2]
        if in_h > out_h:
            stride_h = in_h // out_h if out_h > 0 else 1
            stride_w = in_w // out_w if out_w > 0 else 1
    return stride_h, stride_w


def _translate_add(op_info, raw_inputs, tensors):
    """ADD: record both operand indices."""
    inputs = op_info["input_indices"]
    if len(inputs) < 2:
        fatal_error("ADD input incomplete", "Check model format")
    op_info["add_input1_idx"] = inputs[0]
    op_info["add_input2_idx"] = inputs[1]
    _mark_translated(op_info, "add_check")


def _translate_fully_connected(op_info, raw_inputs, tensors):
    """FULLY_CONNECTED: record data, weight and (optional) bias indices."""
    inputs = op_info["input_indices"]
    if not inputs:
        fatal_error("FULLY_CONNECTED missing input", "Check model format")
    op_info["data_input_idx"] = inputs[0]
    op_info["fc_weights_idx"] = inputs[1] if len(inputs) >= 2 else None
    op_info["fc_bias_idx"] = inputs[2] if len(inputs) >= 3 else None
    _mark_translated(op_info, "fc_check")


def _translate_softmax(op_info, raw_inputs, tensors):
    """SOFTMAX: only needs an input and an output."""
    _require_io(op_info, "SOFTMAX")
    _mark_translated(op_info, "softmax_check")


def _translate_relu(op_info, raw_inputs, tensors):
    """RELU: only needs an input and an output."""
    _require_io(op_info, "RELU")
    _mark_translated(op_info, "relu_check")


def _translate_reshape(op_info, raw_inputs, tensors):
    """RESHAPE: resolve the target shape, including a single -1 dimension."""
    inputs = op_info["input_indices"]
    shape_tensor = tensors.get(inputs[1], {}) if len(inputs) >= 2 else {}

    if "data" in shape_tensor:
        target_shape = [int(s) for s in shape_tensor["data"].flatten()]
    else:
        # No readable shape tensor (absent, or the shape is stored in the
        # op options). The shape tensor's own "shape" is only its length
        # vector, so use the operator's output shape instead.
        target_shape = [int(s) for s in _first_output_shape(op_info, tensors)]

    if not target_shape:
        if len(inputs) < 2:
            fatal_error("RESHAPE missing target shape parameter",
                        "Check model format")
        fatal_error("RESHAPE cannot extract target shape",
                    "Check model format")

    if -1 in target_shape:
        input_size = (tensors.get(inputs[0], {}).get("size", 1)
                      if inputs else 1)
        known = 1
        for dim in target_shape:
            if dim != -1:
                known *= dim
        if known > 0:
            target_shape = [input_size // known if dim == -1 else dim
                            for dim in target_shape]

    op_info["reshape_target_shape"] = [int(s) for s in target_shape]
    _mark_translated(op_info, "reshape_check")


# UNIDIRECTIONAL_SEQUENCE_LSTM input slots per the TFLite spec
# (lstm_shared.h): 1-4 input weights, 5-8 recurrent weights, 9-11 optional
# peephole weights, 12-15 gate biases. The kernel takes 20 or 24 inputs,
# with absent optional tensors encoded as -1.
_LSTM_SPEC_MIN_INPUTS = 20


def _lstm_weight_indices(raw_inputs, inputs):
    """Return the per-gate LSTM tensor index lists, or None if incomplete."""
    if len(raw_inputs) >= _LSTM_SPEC_MIN_INPUTS:
        # Full positional list: use spec slots directly, so optional
        # peephole tensors are never mistaken for biases.
        return {
            "input": raw_inputs[1:5],
            "recurrent": raw_inputs[5:9],
            "bias": raw_inputs[12:16],
        }
    if len(inputs) >= 13:
        # Legacy layout: with -1 slots removed and no peepholes, the biases
        # occupy [9:13].
        return {
            "input": inputs[1:5],
            "recurrent": inputs[5:9],
            "bias": inputs[9:13],
        }
    return None


def _translate_lstm(op_info, raw_inputs, tensors):
    """UNIDIRECTIONAL_SEQUENCE_LSTM: gate tensor indices and sizes."""
    inputs = op_info["input_indices"]
    weight_indices = _lstm_weight_indices(raw_inputs, inputs)
    if weight_indices is None:
        fatal_error("LSTM input incomplete", "Check model format")
    if any(idx == -1 for group in weight_indices.values() for idx in group):
        warning("LSTM has absent gate tensors (e.g. coupled input/forget "
                "gate); their weights cannot be extracted")
    op_info["lstm_weight_indices"] = weight_indices

    output_shape = _first_output_shape(op_info, tensors)
    if len(output_shape) < 3:
        fatal_error("Cannot extract hidden_size from LSTM output shape",
                    "Check model format")
    hidden_size = output_shape[2]

    # NOTE(review): assumes batch-major [batch, time_steps, input_size];
    # time_major models are not detected here.
    input_shape = _tensor_shape(tensors, inputs[0])
    if len(input_shape) < 3:
        fatal_error("Cannot extract parameters from LSTM input shape",
                    "Check model format")
    op_info["lstm_params"] = {
        "time_steps": input_shape[1],
        "batch_size": input_shape[0],
        "input_size": input_shape[2],
        "hidden_size": hidden_size,
    }
    _mark_translated(op_info, "lstm_check")


def _translate_svdf(op_info, raw_inputs, tensors):
    """SVDF: record data and weight tensor indices."""
    inputs = op_info["input_indices"]
    if len(inputs) < 3:
        fatal_error("SVDF input incomplete", "Check model format")
    op_info["data_input_idx"] = inputs[0]
    op_info["svdf_weights_idx"] = inputs[1]
    # NOTE(review): in the TFLite SVDF spec input 2 is weights_time and the
    # optional bias is input 3; the historical inputs[2] mapping is kept for
    # existing consumers -- confirm against them.
    op_info["svdf_bias_idx"] = inputs[2]
    _mark_translated(op_info, "svdf_check")


def _translate_conv2d(op_info, raw_inputs, tensors):
    """CONV_2D: tensor indices plus shape-derived stride/padding."""
    inputs = op_info["input_indices"]
    if len(inputs) < 3:
        fatal_error("CONV_2D input incomplete", "Check model format")
    op_info["data_input_idx"] = inputs[0]
    op_info["conv_weights_idx"] = inputs[1]
    op_info["conv_bias_idx"] = inputs[2]

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _first_output_shape(op_info, tensors)
    # Filter layout: [out_channels, kernel_h, kernel_w, in_channels]
    weights_shape = _tensor_shape(tensors, inputs[1])

    in_h, in_w, in_c = _nhwc(input_shape)
    out_h, out_w, out_c = _nhwc(output_shape)
    k_h, k_w, _ = _nhwc(weights_shape)

    stride_h, stride_w, padding = 1, 1, "VALID"
    if (len(input_shape) >= 4 and len(output_shape) >= 4
            and len(weights_shape) >= 4):
        stride_h, stride_w, padding = _infer_stride_and_padding(
            in_h, in_w, out_h, out_w, k_h, k_w)

    op_info["conv_params"] = {
        "input_h": in_h,
        "input_w": in_w,
        "input_c": in_c,
        "output_h": out_h,
        "output_w": out_w,
        "output_c": out_c,
        "kernel_h": k_h,
        "kernel_w": k_w,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding": padding,
    }
    _mark_translated(op_info, "conv_check")


def _translate_depthwise_conv2d(op_info, raw_inputs, tensors):
    """DEPTHWISE_CONV_2D: indices, multiplier, stride and leading padding."""
    inputs = op_info["input_indices"]
    if len(inputs) < 3:
        fatal_error("DEPTHWISE_CONV_2D input incomplete",
                    "Check model format")
    op_info["data_input_idx"] = inputs[0]
    op_info["dw_weights_idx"] = inputs[1]
    op_info["dw_bias_idx"] = inputs[2]

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _first_output_shape(op_info, tensors)
    weights_shape = _tensor_shape(tensors, inputs[1])

    in_h, in_w, in_c = _nhwc(input_shape)
    out_h, out_w, out_c = _nhwc(output_shape)
    k_h, k_w, k_c = _nhwc(weights_shape)

    shapes_known = len(input_shape) >= 4 and len(weights_shape) >= 4
    depth_multiplier = k_c // in_c if shapes_known and in_c else 1

    stride_h = stride_w = 1
    pad_h = pad_w = 0
    if shapes_known and len(output_shape) >= 4:
        stride_h, stride_w, padding = _infer_stride_and_padding(
            in_h, in_w, out_h, out_w, k_h, k_w)
        if padding == "SAME":
            # TFLite SAME padding puts total // 2 before the data.
            pad_h = int(max((out_h - 1) * stride_h + k_h - in_h, 0) // 2)
            pad_w = int(max((out_w - 1) * stride_w + k_w - in_w, 0) // 2)

    op_info["dw_params"] = {
        "input_h": in_h,
        "input_w": in_w,
        "input_c": in_c,
        "output_h": out_h,
        "output_w": out_w,
        "output_c": out_c,
        "kernel_h": k_h,
        "kernel_w": k_w,
        "depth_multiplier": depth_multiplier,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding_h": pad_h,
        "padding_w": pad_w,
    }
    _mark_translated(op_info, "dw_check")


def _translate_max_pool(op_info, raw_inputs, tensors):
    """MAX_POOL_2D: shapes plus estimated strides.

    The window size is in the op options, which this API does not expose,
    so 2x2 is assumed.
    """
    inputs = op_info["input_indices"]
    if len(inputs) < 1:
        fatal_error("MAX_POOL_2D missing input", "Check model format")
    op_info["data_input_idx"] = inputs[0]

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _first_output_shape(op_info, tensors)
    stride_h, stride_w = _infer_pool_strides(input_shape, output_shape)
    in_h, in_w, in_c = _nhwc(input_shape)
    out_h, out_w, out_c = _nhwc(output_shape)

    op_info["pool_params"] = {
        "input_h": in_h,
        "input_w": in_w,
        "input_c": in_c,
        "output_h": out_h,
        "output_w": out_w,
        "output_c": out_c,
        "pool_size_h": 2,
        "pool_size_w": 2,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding": "VALID",
    }
    _mark_translated(op_info, "pool_check")


def _translate_avg_pool(op_info, raw_inputs, tensors):
    """AVERAGE_POOL_2D: shapes plus estimated strides (2x2 window assumed)."""
    inputs = op_info["input_indices"]
    if len(inputs) < 1:
        fatal_error("AVERAGE_POOL_2D missing input", "Check model format")
    op_info["data_input_idx"] = inputs[0]

    input_shape = _tensor_shape(tensors, inputs[0])
    output_shape = _first_output_shape(op_info, tensors)
    stride_h, stride_w = _infer_pool_strides(input_shape, output_shape)
    in_h, in_w, in_c = _nhwc(input_shape)
    out_h, out_w, out_c = _nhwc(output_shape)

    op_info["pool_params"] = {
        "input_h": in_h,
        "input_w": in_w,
        "input_c": in_c,
        "output_h": out_h,
        "output_w": out_w,
        "output_c": out_c,
        "pool_h": 2,
        "pool_w": 2,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "padding": "VALID",
    }
    _mark_translated(op_info, "avg_pool_check")


def _translate_transpose(op_info, raw_inputs, tensors):
    """TRANSPOSE: inputs[0] = data, inputs[1] = permutation tensor."""
    inputs = op_info["input_indices"]
    if len(inputs) < 2:
        fatal_error("TRANSPOSE missing perm parameter", "Check model format")
    op_info["data_input_idx"] = inputs[0]
    op_info["transpose_perm_idx"] = inputs[1]
    op_info["transpose_params"] = {
        "input_dims": len(_tensor_shape(tensors, inputs[0])),
        "output_dims": len(_first_output_shape(op_info, tensors)),
    }
    _mark_translated(op_info, "transpose_check")


def _translate_quantize(op_info, raw_inputs, tensors):
    """QUANTIZE: nothing to record beyond the generic details."""
    _mark_translated(op_info, "quantize_check")


def _translate_pad(op_info, raw_inputs, tensors):
    """PAD: record data and paddings tensor indices."""
    inputs = op_info["input_indices"]
    if len(inputs) < 2:
        fatal_error("PAD missing padding parameter", "Check model format")
    op_info["data_input_idx"] = inputs[0]
    op_info["pad_paddings_idx"] = inputs[1]
    _mark_translated(op_info, "pad_check")


def _translate_mean(op_info, raw_inputs, tensors):
    """MEAN: record data and (optional) axis tensor indices."""
    inputs = op_info["input_indices"]
    if len(inputs) < 1:
        fatal_error("MEAN missing input", "Check model format")
    op_info["data_input_idx"] = inputs[0]
    if len(inputs) >= 2:
        op_info["mean_axis_idx"] = inputs[1]
    op_info["mean_params"] = {
        "input_dims": len(_tensor_shape(tensors, inputs[0])),
        "output_dims": len(_first_output_shape(op_info, tensors)),
    }
    _mark_translated(op_info, "mean_check")


# op_name -> translator(op_info, raw_inputs, tensors); each mutates op_info.
_OP_TRANSLATORS = {
    "ADD": _translate_add,
    "FULLY_CONNECTED": _translate_fully_connected,
    "SOFTMAX": _translate_softmax,
    "RESHAPE": _translate_reshape,
    "UNIDIRECTIONAL_SEQUENCE_LSTM": _translate_lstm,
    "SVDF": _translate_svdf,
    "CONV_2D": _translate_conv2d,
    "MAX_POOL_2D": _translate_max_pool,
    "DEPTHWISE_CONV_2D": _translate_depthwise_conv2d,
    "RELU": _translate_relu,
    "AVERAGE_POOL_2D": _translate_avg_pool,
    "TRANSPOSE": _translate_transpose,
    "QUANTIZE": _translate_quantize,
    "PAD": _translate_pad,
    "MEAN": _translate_mean,
}


def _tensor_detail(tensors, tensor_idx):
    """Summarize one tensor for an op's input/output detail lists."""
    tensor_info = tensors.get(tensor_idx, {})
    return {
        "index": tensor_idx,
        "name": tensor_info.get("name", "unknown"),
        "shape": tensor_info.get("shape", []),
        "size": tensor_info.get("size", 0),
        "scale": tensor_info.get("scale", 1.0),
        "zero_point": tensor_info.get("zero_point", 0),
    }


def parse_model_tflite(model_path: str):
    """Parse a TFLite model with LiteRT into TinyMLC's model-info dict.

    Args:
        model_path: Path of the ``.tflite`` file.

    Returns:
        Dict with ``"input"``/``"output"`` (interpreter input/output
        details), ``"ops"`` (one dict per operator, DELEGATE nodes skipped),
        ``"tensors"`` (tensor index -> info dict), and empty ``"weights"``
        and ``"quant_scales"`` dicts kept for a unified interface. Weights
        are filled separately by ``extract_all_weights_litert``; LiteRT
        keeps quantization per tensor in ``"tensors"``.

    Operators without a translator keep state ``"created"`` and get
    ``pass_flags["unknown"] = "needs_implementation"``.
    """
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    tensors = _collect_tensors(interpreter)

    ops = []
    # LiteRT has no public operator-list API, hence the private call.
    for op in interpreter._get_ops_details():
        op_name = op["op_name"]
        if op_name == "DELEGATE":
            continue

        # Raw list keeps -1 placeholders for absent optional tensors; the
        # positional meaning of each slot is needed by some translators.
        raw_inputs = list(op["inputs"])
        present_inputs = [idx for idx in raw_inputs if idx != -1]
        present_outputs = [idx for idx in op["outputs"] if idx != -1]

        op_info = {
            "index": op["index"],
            "op_name": op_name,
            "inputs": present_inputs,
            "outputs": present_outputs,
            "input_indices": list(present_inputs),
            "output_indices": list(present_outputs),
            "state": "created",
            "pass_flags": {},
            "input_details": [],
            "output_details": [],
        }

        translate = _OP_TRANSLATORS.get(op_name)
        if translate is None:
            warning(f"Unknown operator '{op_name}' encountered, "
                    "please submit an issue or patch set")
            op_info["pass_flags"]["unknown"] = "needs_implementation"
        else:
            translate(op_info, raw_inputs, tensors)

        op_info["input_details"].extend(
            _tensor_detail(tensors, idx) for idx in op_info["input_indices"])
        op_info["output_details"].extend(
            _tensor_detail(tensors, idx) for idx in op_info["output_indices"])
        ops.append(op_info)

    return {
        "input": input_details,
        "output": output_details,
        "ops": ops,
        "tensors": tensors,
        "weights": {},
        "quant_scales": {},
    }
|