rknncli 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rknncli/__init__.py +3 -0
- rknncli/cli.py +229 -0
- rknncli/parser.py +320 -0
- rknncli/schema/rknn/Graph.py +371 -0
- rknncli/schema/rknn/Model.py +422 -0
- rknncli/schema/rknn/Node.py +380 -0
- rknncli/schema/rknn/Tensor.py +697 -0
- rknncli/schema/rknn/Type1.py +94 -0
- rknncli/schema/rknn/Type2.py +255 -0
- rknncli/schema/rknn/Type3.py +94 -0
- rknncli/schema/rknn/__init__.py +0 -0
- rknncli-0.2.0.dist-info/METADATA +69 -0
- rknncli-0.2.0.dist-info/RECORD +16 -0
- rknncli-0.2.0.dist-info/WHEEL +5 -0
- rknncli-0.2.0.dist-info/entry_points.txt +2 -0
- rknncli-0.2.0.dist-info/top_level.txt +1 -0
rknncli/__init__.py
ADDED
rknncli/cli.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""RKNN CLI - Command line interface for parsing RKNN models."""
|
|
2
|
+
|
|
3
|
+
import argparse
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

from rknncli.parser import RKNNParser
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def format_shape(size: List) -> List:
    """Return a tensor's shape as a plain list ready for display.

    Entries are passed through unchanged: integer dimensions stay integers and
    symbolic (string) dimensions stay strings. No dimension names are added.

    Args:
        size: Sequence of dimension sizes (ints and/or dimension-name strings),
            or ``None`` when the tensor has no recorded shape.

    Returns:
        A new list holding the entries of *size*; ``[]`` when *size* is
        ``None`` or empty.
    """
    if not size:
        # A JSON ``null`` size would otherwise break the caller's ``join``.
        return []
    # Copy so callers never alias (and accidentally mutate) model_info data.
    return list(size)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_dtype_str(dtype_info: Dict) -> str:
    """Return an upper-case data-type label for a tensor's dtype description.

    ``vx_type`` is preferred; ``qnt_type`` is used when ``vx_type`` is missing
    or blank. Missing, ``None`` and whitespace-only values count as absent.

    Args:
        dtype_info: Dictionary containing dtype information (``vx_type`` and/or
            ``qnt_type`` keys).

    Returns:
        The first usable value, stripped and upper-cased, or ``"FLOAT"`` when
        neither key holds a usable value or *dtype_info* is not a dict.
    """
    if isinstance(dtype_info, dict):
        # Try vx_type first, then qnt_type.
        for key in ("vx_type", "qnt_type"):
            value = dtype_info.get(key)
            # Guard against JSON nulls / non-string values before .strip().
            text = str(value).strip() if value is not None else ""
            if text:
                return text.upper()
    return "FLOAT"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def print_merged_model_info(parser: RKNNParser) -> None:
    """Print the model header section followed by a blank line.

    The model name and target platforms come from the JSON model info. When
    FlatBuffers data is available, its format, source, compiler and runtime
    strings are printed, plus the graph count when it is positive.

    Args:
        parser: RKNNParser instance with FlatBuffers support.
    """
    print(f"Model: {parser.get_model_name()}")

    target_platforms = parser.get_target_platform()
    if target_platforms:
        joined_platforms = ", ".join(target_platforms)
        print(f"Target Platform: {joined_platforms}")

    flatbuffer_meta = parser.get_flatbuffers_info()
    if flatbuffer_meta:
        labelled_keys = (
            ("Format", "format"),
            ("Source", "source"),
            ("Compiler", "compiler"),
            ("Runtime", "runtime"),
        )
        for label, key in labelled_keys:
            print(f"{label}: {flatbuffer_meta.get(key, 'Unknown')}")

        graph_count = flatbuffer_meta.get("num_graphs", 0)
        if graph_count > 0:
            print(f"Number of graphs: {graph_count}")

    print()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _format_value_info(tensor: Dict[str, Any]) -> str:
    """Render one merged IO tensor dict as a single ``ValueInfo`` line."""
    # Fall back to a synthetic name when ``url`` is missing or empty.
    name = tensor.get("url") or f"tensor_{tensor.get('tensor_id', 0)}"

    dtype = tensor.get("dtype") or {}
    dtype_str = get_dtype_str(dtype) if isinstance(dtype, dict) else str(dtype).upper()

    shape = format_shape(tensor.get("size") or [])
    shape_str = "[" + ", ".join(f"'{s}'" if isinstance(s, str) else str(s) for s in shape) + "]"

    output_parts = [f' ValueInfo "{name}": type {dtype_str}, shape {shape_str}']

    # Layout info from the FlatBuffers generator attrs (may be absent/None).
    layout = str(tensor.get("layout") or "").upper()
    if layout:
        layout_ori = str(tensor.get("layout_ori") or "").upper()
        if layout_ori and layout_ori != layout:
            output_parts.append(f"layout {layout}(ori:{layout_ori})")
        else:
            output_parts.append(f"layout {layout}")

    # Quantization info from the FlatBuffers quant_tab; either field may be empty.
    quant = tensor.get("quant_info")
    if isinstance(quant, dict):
        quant_fields = [str(quant.get(key) or "") for key in ("qmethod", "qtype")]
        present = [field for field in quant_fields if field]
        if present:
            output_parts.append("quant " + " ".join(present))

    return ", ".join(output_parts) + ","


def print_merged_io_info(parser: RKNNParser) -> None:
    """Print merged input/output information from both FlatBuffers and JSON.

    Each tensor is printed on one line with its name, dtype and shape, plus
    layout and quantization details when the FlatBuffers data provides them.
    A blank line separates the input section from the output section.

    Args:
        parser: RKNNParser instance with FlatBuffers support.
    """
    inputs, outputs = parser.get_merged_io_info()

    print("Input information")
    print("-" * 80)
    for tensor in inputs:
        print(_format_value_info(tensor))
    print()

    print("Output information")
    print("-" * 80)
    for tensor in outputs:
        print(_format_value_info(tensor))
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def print_model_summary(parser) -> None:
    """Print the model name, version and (if any) target platforms.

    A blank line is printed at the end of the summary.

    Args:
        parser: RKNNParser instance; only ``get_model_name``, ``get_version``
            and ``get_target_platform`` are called.
    """
    model_name = parser.get_model_name()
    print(f"Model: {model_name}")

    model_version = parser.get_version()
    print(f"Version: {model_version}")

    targets = parser.get_target_platform()
    if targets:
        joined_targets = ", ".join(targets)
        print(f"Target Platform: {joined_targets}")
    print()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def main(argv: Optional[List[str]] = None) -> int:
    """Main entry point of the ``rknncli`` command.

    Args:
        argv: Command-line arguments excluding the program name. ``None`` (the
            default, used by the console entry point) parses ``sys.argv``.

    Returns:
        Exit code (0 for success, non-zero for error): 1 when the path does not
        exist, is not a regular file, or the model cannot be parsed/printed.
    """
    # Report the installed distribution version instead of a stale literal;
    # fall back to this release's version when running from a source tree.
    from importlib.metadata import PackageNotFoundError, version as dist_version

    try:
        pkg_version = dist_version("rknncli")
    except PackageNotFoundError:
        pkg_version = "0.2.0"

    parser = argparse.ArgumentParser(
        prog="rknncli",
        description="A command line tool for parsing and displaying RKNN model information.",
    )
    parser.add_argument(
        "model",
        type=str,
        help="Path to the RKNN model file",
    )
    parser.add_argument(
        "-v", "--version",
        action="version",
        version=f"%(prog)s {pkg_version}",
    )

    args = parser.parse_args(argv)

    model_path = Path(args.model)
    if not model_path.exists():
        print(f"Error: File not found: {model_path}", file=sys.stderr)
        return 1

    if not model_path.is_file():
        print(f"Error: Not a file: {model_path}", file=sys.stderr)
        return 1

    try:
        # Always parse both FlatBuffers and JSON data
        rknn_parser = RKNNParser(model_path, parse_flatbuffers=True)

        # Print merged model information
        print_merged_model_info(rknn_parser)

        # Print merged IO information
        print_merged_io_info(rknn_parser)
    except ValueError as e:
        print(f"Error: Failed to parse RKNN file: {e}", file=sys.stderr)
        return 1
    except Exception as e:  # CLI boundary: turn any failure into exit code 1
        print(f"Error: {e}", file=sys.stderr)
        return 1

    return 0
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# Script entry: when this module is executed directly (e.g. `python -m
# rknncli.cli`), main()'s integer return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
rknncli/parser.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""RKNN file parser with FlatBuffers support."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import struct
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Optional, Dict, List, Union, Tuple
|
|
7
|
+
|
|
8
|
+
import flatbuffers
|
|
9
|
+
from rknncli.schema.rknn.Model import Model
|
|
10
|
+
from rknncli.schema.rknn.Graph import Graph
|
|
11
|
+
from rknncli.schema.rknn.Tensor import Tensor
|
|
12
|
+
from rknncli.schema.rknn.Node import Node
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RKNNParser:
    """Parser for RKNN model files with optional FlatBuffers support.

    File layout as read by this parser:

    * a header starting with the ``b"RKNN"`` magic and 4 padding bytes,
      followed by two little-endian uint64 fields, ``file_format`` and
      ``file_length``. The header occupies 64 bytes, or only 24 bytes when
      ``file_format <= 1`` (rknn-v1);
    * ``file_length`` bytes of FlatBuffers model data;
    * a little-endian uint64 JSON size followed by that many bytes of UTF-8
      JSON model info.
    """

    HEADER_SIZE = 64
    # rknn-v1 files (file_format <= 1) store only three uint64 header fields.
    _V1_HEADER_SIZE = 24
    MAGIC_NUMBER = b"RKNN"

    def __init__(self, file_path: Union[str, Path], parse_flatbuffers: bool = True):
        """Initialize the parser and immediately parse *file_path*.

        Args:
            file_path: Path to the RKNN model file.
            parse_flatbuffers: Whether to also load the embedded FlatBuffers
                model into ``fb_model``. Defaults to True.

        Raises:
            ValueError: If the file is not a well-formed RKNN file.
            OSError: If the file cannot be opened or read.
        """
        self.file_path = Path(file_path)
        self.header: Dict[str, Any] = {}
        self.model_info: Dict[str, Any] = {}
        self.fb_model: Optional[Model] = None
        self._parse(parse_flatbuffers)

    def _parse(self, parse_flatbuffers: bool) -> None:
        """Read the header, the optional FlatBuffers payload and the JSON info.

        Args:
            parse_flatbuffers: Whether to parse FlatBuffers data.

        Raises:
            ValueError: On a truncated file, bad magic number, missing or
                invalid JSON size, or JSON that is not a UTF-8 encoded object.
        """
        with open(self.file_path, "rb") as f:
            header_data = f.read(self.HEADER_SIZE)
            # Even the short v1 header needs magic, padding, format and length.
            if len(header_data) < self._V1_HEADER_SIZE:
                raise ValueError(f"File too small: {self.file_path}")

            # Bytes 0-3: magic number "RKNN"
            magic = header_data[0:4]
            if magic != self.MAGIC_NUMBER:
                raise ValueError(f"Invalid magic number: {magic!r}, expected {self.MAGIC_NUMBER!r}")

            # Bytes 4-7: padding; bytes 8-15: file format; bytes 16-23: length
            # of the FlatBuffers payload (both little-endian uint64).
            padding = header_data[4:8]
            file_format, file_length = struct.unpack_from("<QQ", header_data, 8)

            self.header = {
                "magic": magic,
                "padding": padding,
                "file_format": file_format,
                "file_length": file_length,
            }

            real_header_size = self._V1_HEADER_SIZE if file_format <= 1 else self.HEADER_SIZE
            if len(header_data) < real_header_size:
                raise ValueError(f"File too small: {self.file_path}")

            if parse_flatbuffers and file_length > 0:
                # The payload starts right after the *real* header. For v1 files
                # that is offset 24, i.e. inside the 64 bytes already consumed
                # above, so an explicit seek is required.
                f.seek(real_header_size)
                fb_data = f.read(file_length)
                if len(fb_data) == file_length:
                    self.fb_model = Model.GetRootAs(fb_data, 0)

            # JSON model info: uint64 size prefix, then the JSON bytes.
            file_size = self.file_path.stat().st_size
            json_offset = real_header_size + file_length
            f.seek(json_offset)
            size_field = f.read(8)
            if len(size_field) < 8:
                raise ValueError(f"Missing JSON size field at offset {json_offset}")
            json_size = struct.unpack("<Q", size_field)[0]
            if json_size <= 0 or json_size > file_size - json_offset - 8:
                raise ValueError(f"Invalid JSON size: {json_size}")

            json_data = f.read(json_size)

        try:
            model_info = json.loads(json_data.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise ValueError(f"Failed to parse model info JSON: {e}") from e
        if not isinstance(model_info, dict):
            raise ValueError("Failed to parse model info JSON: top-level value is not an object")
        self.model_info = model_info

    @staticmethod
    def _decode(raw: Optional[bytes]) -> str:
        """Decode a FlatBuffers string field; ``None`` or empty yields ``""``.

        Invalid UTF-8 is replaced rather than raised so one malformed name does
        not abort the whole dump.
        """
        return raw.decode("utf-8", errors="replace") if raw else ""

    def get_flatbuffers_info(self) -> Dict[str, Any]:
        """Get FlatBuffers model information.

        Returns:
            Dictionary with the non-empty model strings (``format``,
            ``generator``, ``compiler``, ``runtime``, ``source``,
            ``input_json``, ``output_json``), ``num_graphs``, and a ``graphs``
            list of per-graph summaries. Empty when no FlatBuffers model was
            loaded.
        """
        if self.fb_model is None:
            return {}

        model = self.fb_model
        info: Dict[str, Any] = {}

        # Basic model info
        for key, getter in (
            ("format", model.Format),
            ("generator", model.Generator),
            ("compiler", model.Compiler),
            ("runtime", model.Runtime),
            ("source", model.Source),
        ):
            text = self._decode(getter())
            if text:
                info[key] = text

        info["num_graphs"] = model.GraphsLength()

        # Input/Output JSON strings
        for key, getter in (("input_json", model.InputJson), ("output_json", model.OutputJson)):
            text = self._decode(getter())
            if text:
                info[key] = text

        graphs = []
        for i in range(model.GraphsLength()):
            graph = model.Graphs(i)
            if graph:
                graphs.append(self._graph_info(graph))
        info["graphs"] = graphs

        return info

    def _graph_info(self, graph) -> Dict[str, Any]:
        """Summarize one FlatBuffers graph table: counts, tensors and nodes."""
        tensors = []
        for j in range(graph.TensorsLength()):
            tensor = graph.Tensors(j)
            if tensor:
                tensors.append({
                    "data_type": tensor.DataType(),
                    "kind": tensor.Kind(),
                    "name": self._decode(tensor.Name()),
                    "shape": [tensor.Shape(k) for k in range(tensor.ShapeLength())],
                    "size": tensor.Size(),
                    "index": tensor.Index(),
                })

        nodes = []
        for j in range(graph.NodesLength()):
            node = graph.Nodes(j)
            if node:
                nodes.append({
                    "type": self._decode(node.Type()),
                    "name": self._decode(node.Name()),
                    "num_inputs": node.InputsLength(),
                    "num_outputs": node.OutputsLength(),
                })

        return {
            "num_tensors": graph.TensorsLength(),
            "num_nodes": graph.NodesLength(),
            "num_inputs": graph.InputsLength(),
            "num_outputs": graph.OutputsLength(),
            "tensors": tensors,
            "nodes": nodes,
        }

    def _collect_io_tensors(self, side: str) -> List[Dict[str, Any]]:
        """Return ``norm_tensor`` entries connected to the graph's *side*.

        Args:
            side: ``"input"`` or ``"output"``; matched against each ``graph``
                connection's ``left`` field. The tensor is looked up by the
                connection's ``right_tensor_id``.
        """
        # Entries without a tensor_id cannot be referenced, so skip them.
        norm_tensors = {
            t["tensor_id"]: t
            for t in self.model_info.get("norm_tensor") or []
            if isinstance(t, dict) and "tensor_id" in t
        }
        found = []
        for conn in self.model_info.get("graph") or []:
            if not isinstance(conn, dict) or conn.get("left") != side:
                continue
            tensor = norm_tensors.get(conn.get("right_tensor_id"))
            if tensor is not None:
                found.append(tensor)
        return found

    def get_input_info(self) -> List[Dict[str, Any]]:
        """Get model input information.

        Returns:
            List of input tensor information (``norm_tensor`` JSON entries).
        """
        return self._collect_io_tensors("input")

    def get_output_info(self) -> List[Dict[str, Any]]:
        """Get model output information.

        Returns:
            List of output tensor information (``norm_tensor`` JSON entries).
        """
        return self._collect_io_tensors("output")

    def get_model_name(self) -> str:
        """Get model name (``"Unknown"`` when absent)."""
        return self.model_info.get("name", "Unknown")

    def get_version(self) -> str:
        """Get model version (``"Unknown"`` when absent)."""
        return self.model_info.get("version", "Unknown")

    def get_target_platform(self) -> List[str]:
        """Get the target platforms as a list.

        A single platform stored as a bare string is wrapped in a list, so
        callers can always ``", ".join`` the result.
        """
        platforms = self.model_info.get("target_platform") or []
        if isinstance(platforms, str):
            return [platforms]
        return list(platforms)

    @staticmethod
    def _parse_generator(text: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Extract the ``attrs`` and ``quant_tab`` dicts from the generator string.

        The string is parsed with ``ast.literal_eval`` (it may use single
        quotes), falling back to JSON after swapping quote characters. Empty
        dicts are returned for anything that cannot be parsed into a dict.
        """
        import ast

        if not text:
            return {}, {}

        try:
            gen_data: Any = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Fallback: try JSON parse after fixing quotes
            try:
                gen_data = json.loads(text.replace("'", '"'))
            except ValueError:  # includes json.JSONDecodeError
                gen_data = None

        if not isinstance(gen_data, dict):
            return {}, {}
        attrs = gen_data.get("attrs")
        quant_tab = gen_data.get("quant_tab")
        return (
            attrs if isinstance(attrs, dict) else {},
            quant_tab if isinstance(quant_tab, dict) else {},
        )

    @staticmethod
    def _merge_io_tensor(
        tensor: Dict[str, Any],
        generator_attrs: Dict[str, Any],
        quant_tab: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return a copy of *tensor* enriched with layout and quant info.

        Both tables are keyed by the tensor's ``url``; entries that are not
        dicts are ignored.
        """
        io_name = tensor.get("url", "")
        merged = dict(tensor)

        attrs = generator_attrs.get(io_name)
        if isinstance(attrs, dict):
            merged["layout"] = attrs.get("layout", "")
            merged["layout_ori"] = attrs.get("layout_ori", "")

        quant = quant_tab.get(io_name)
        if isinstance(quant, dict):
            merged["dtype"] = quant.get("dtype", tensor.get("dtype", {}))
            merged["quant_info"] = {
                "qmethod": quant.get("qmethod", ""),
                "qtype": quant.get("qtype", ""),
                "scale": quant.get("scale", []),
                "zero_point": quant.get("zero_point", []),
            }

        return merged

    def get_merged_io_info(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Get merged input/output information from both FlatBuffers and JSON.

        Returns:
            Tuple of (inputs, outputs). Without FlatBuffers data these are the
            JSON entries as-is; otherwise each entry is a copy extended with
            ``layout``/``layout_ori`` from the generator ``attrs`` and
            ``dtype``/``quant_info`` from the generator ``quant_tab``.
        """
        json_inputs = self.get_input_info()
        json_outputs = self.get_output_info()

        # If no FlatBuffers data, return JSON data as is
        if self.fb_model is None:
            return json_inputs, json_outputs

        # Only the generator string is needed, so avoid building the full
        # graph/tensor/node dump that get_flatbuffers_info() would produce.
        generator_attrs, quant_tab = self._parse_generator(self._decode(self.fb_model.Generator()))

        merged_inputs = [self._merge_io_tensor(t, generator_attrs, quant_tab) for t in json_inputs]
        merged_outputs = [self._merge_io_tensor(t, generator_attrs, quant_tab) for t in json_outputs]
        return merged_inputs, merged_outputs
|