rknncli 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rknncli/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """RKNN CLI - A command line tool for parsing and displaying RKNN model information."""
2
+
3
+ __version__ = "0.1.0"
rknncli/cli.py ADDED
@@ -0,0 +1,229 @@
1
+ """RKNN CLI - Command line interface for parsing RKNN models."""
2
+
3
+ import argparse
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import List, Dict, Any
7
+
8
+ from rknncli.parser import RKNNParser
9
+
10
+
11
def format_shape(size: List) -> List:
    """Return the tensor shape to display.

    Currently a passthrough: the dimension list is shown exactly as stored in
    the model metadata. (The previous docstring claimed dimension-name
    formatting that was never implemented; this function is kept as the hook
    for adding it later.)

    Args:
        size: List of dimension sizes.

    Returns:
        The same list of dimension values, unchanged.
    """
    return size
21
+
22
+
23
def get_dtype_str(dtype_info: Dict) -> str:
    """Resolve a display name for a tensor's data type.

    Args:
        dtype_info: Mapping that may carry "vx_type" and/or "qnt_type" keys.

    Returns:
        The upper-cased type name; "FLOAT" when no usable key is present or
        the argument is not a mapping.
    """
    if isinstance(dtype_info, dict):
        # "vx_type" wins over "qnt_type"; blank/whitespace values are skipped.
        for key in ("vx_type", "qnt_type"):
            candidate = dtype_info.get(key, "").strip()
            if candidate:
                return candidate.upper()
    return "FLOAT"
41
+
42
+
43
+
44
+
45
def print_merged_model_info(parser: RKNNParser) -> None:
    """Print merged model information from both FlatBuffers and JSON.

    Args:
        parser: RKNNParser instance with FlatBuffers support.
    """
    # Name and target platforms come from the trailing JSON blob.
    print(f"Model: {parser.get_model_name()}")
    platforms = parser.get_target_platform()
    if platforms:
        print(f"Target Platform: {', '.join(platforms)}")

    # The remaining fields come from the FlatBuffers payload, when present.
    fb_info = parser.get_flatbuffers_info()
    if fb_info:
        for label, key in (
            ("Format", "format"),
            ("Source", "source"),
            ("Compiler", "compiler"),
            ("Runtime", "runtime"),
        ):
            print(f"{label}: {fb_info.get(key, 'Unknown')}")

        if fb_info.get("num_graphs", 0) > 0:
            print(f"Number of graphs: {fb_info['num_graphs']}")

    print()
69
+
70
+
71
def _print_tensor_section(title: str, tensors: List[Dict]) -> None:
    """Print one heading plus one line per tensor (shared by inputs/outputs).

    Args:
        title: Section heading ("Input information" / "Output information").
        tensors: Merged tensor dicts from RKNNParser.get_merged_io_info();
            each may optionally carry "layout", "layout_ori" and "quant_info"
            keys sourced from the FlatBuffers payload.
    """
    print(title)
    print("-" * 80)

    for tensor in tensors:
        # Fall back to a synthetic name when the JSON carries no URL.
        name = tensor.get("url", f"tensor_{tensor.get('tensor_id', 0)}")
        dtype = tensor.get("dtype", {})

        # dtype may be a structured dict or already a plain string.
        if isinstance(dtype, dict):
            dtype_str = get_dtype_str(dtype)
        else:
            dtype_str = str(dtype).upper()

        shape = format_shape(tensor.get("size", []))
        shape_str = "[" + ", ".join(f"'{s}'" if isinstance(s, str) else str(s) for s in shape) + "]"

        output_parts = [f' ValueInfo "{name}": type {dtype_str}, shape {shape_str}']

        # Layout info (from the FlatBuffers generator attrs); show the
        # original layout only when it differs from the current one.
        if "layout" in tensor and tensor["layout"]:
            layout = tensor["layout"].upper()
            layout_ori = tensor.get("layout_ori", "").upper()
            if layout_ori and layout_ori != layout:
                output_parts.append(f"layout {layout}(ori:{layout_ori})")
            else:
                output_parts.append(f"layout {layout}")

        # Quantization info (from the FlatBuffers quant table).
        if "quant_info" in tensor:
            quant = tensor["quant_info"]
            if quant.get("qmethod") or quant.get("qtype"):
                output_parts.append(f"quant {quant['qmethod']} {quant['qtype']}")

        # All merged attributes for a tensor go on one line.
        print(", ".join(output_parts) + ",")


def print_merged_io_info(parser: "RKNNParser") -> None:
    """Print merged input/output information from both FlatBuffers and JSON.

    The input and output sections were previously two verbatim copies of the
    same 40-line loop; both now go through _print_tensor_section.

    Args:
        parser: RKNNParser instance with FlatBuffers support.
    """
    inputs, outputs = parser.get_merged_io_info()

    _print_tensor_section("Input information", inputs)
    print()
    _print_tensor_section("Output information", outputs)
159
+
160
+
161
def print_model_summary(parser) -> None:
    """Print model summary information (name, version, target platforms).

    Args:
        parser: RKNNParser instance.
    """
    lines = [
        f"Model: {parser.get_model_name()}",
        f"Version: {parser.get_version()}",
    ]
    platforms = parser.get_target_platform()
    if platforms:
        # Platform line is omitted entirely when the list is empty.
        lines.append(f"Target Platform: {', '.join(platforms)}")
    for line in lines:
        print(line)
    print()
173
+
174
+
175
+
176
+
177
def main() -> int:
    """Main entry point.

    Returns:
        Exit code (0 for success, non-zero for error).
    """
    arg_parser = argparse.ArgumentParser(
        prog="rknncli",
        description="A command line tool for parsing and displaying RKNN model information.",
    )
    arg_parser.add_argument(
        "model",
        type=str,
        help="Path to the RKNN model file",
    )
    arg_parser.add_argument(
        "-v", "--version",
        action="version",
        version="%(prog)s 0.1.0",
    )
    args = arg_parser.parse_args()

    # Validate the path before touching the parser at all.
    model_path = Path(args.model)
    if not model_path.exists():
        print(f"Error: File not found: {model_path}", file=sys.stderr)
        return 1
    if not model_path.is_file():
        print(f"Error: Not a file: {model_path}", file=sys.stderr)
        return 1

    try:
        # Both the FlatBuffers payload and the trailing JSON are parsed.
        rknn_parser = RKNNParser(model_path, parse_flatbuffers=True)
        print_merged_model_info(rknn_parser)
        print_merged_io_info(rknn_parser)
    except ValueError as e:
        # Raised by the parser for malformed RKNN files.
        print(f"Error: Failed to parse RKNN file: {e}", file=sys.stderr)
        return 1
    except Exception as e:
        # CLI boundary: report anything unexpected and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    return 0
226
+
227
+
228
+ if __name__ == "__main__":
229
+ sys.exit(main())
rknncli/parser.py ADDED
@@ -0,0 +1,320 @@
1
+ """RKNN file parser with FlatBuffers support."""
2
+
3
+ import json
4
+ import struct
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Dict, List, Union, Tuple
7
+
8
+ import flatbuffers
9
+ from rknncli.schema.rknn.Model import Model
10
+ from rknncli.schema.rknn.Graph import Graph
11
+ from rknncli.schema.rknn.Tensor import Tensor
12
+ from rknncli.schema.rknn.Node import Node
13
+
14
+
15
class RKNNParser:
    """Parser for RKNN model files with optional FlatBuffers support.

    File layout (as read by this parser): a header starting with the magic
    b"RKNN", then a FlatBuffers payload of ``file_length`` bytes, then a
    uint64 size followed by a UTF-8 JSON model-info blob.
    """

    HEADER_SIZE = 64
    MAGIC_NUMBER = b"RKNN"

    def __init__(self, file_path: Union[str, Path], parse_flatbuffers: bool = True):
        """Initialize parser with RKNN file path.

        Args:
            file_path: Path to the RKNN model file.
            parse_flatbuffers: Whether to parse FlatBuffers data. Defaults to True.

        Raises:
            ValueError: If the file is malformed (see _parse).
        """
        self.file_path = Path(file_path)
        self.header: Dict[str, Any] = {}
        self.model_info: Dict[str, Any] = {}
        # Forward-referenced so the module loads without the generated schema.
        self.fb_model: Optional["Model"] = None
        self._parse(parse_flatbuffers)

    def _parse(self, parse_flatbuffers: bool) -> None:
        """Parse the RKNN file: header, optional FlatBuffers payload, JSON info.

        Args:
            parse_flatbuffers: Whether to parse FlatBuffers data.

        Raises:
            ValueError: On a truncated file, bad magic number, invalid JSON
                size field, or unparseable model-info JSON.
        """
        with open(self.file_path, "rb") as f:
            header_data = f.read(self.HEADER_SIZE)
            if len(header_data) < self.HEADER_SIZE:
                raise ValueError(f"File too small: {self.file_path}")

            # Bytes 0-3: Magic number "RKNN"
            magic = header_data[0:4]
            if magic != self.MAGIC_NUMBER:
                raise ValueError(f"Invalid magic number: {magic!r}, expected {self.MAGIC_NUMBER!r}")

            # Bytes 4-7: Padding (4 bytes, zeros)
            padding = header_data[4:8]
            # Bytes 8-15: File format version (little-endian uint64)
            file_format = struct.unpack("<Q", header_data[8:16])[0]
            # Bytes 16-23: length of the FlatBuffers payload (little-endian uint64)
            file_length = struct.unpack("<Q", header_data[16:24])[0]

            self.header = {
                "magic": magic,
                "padding": padding,
                "file_format": file_format,
                "file_length": file_length,
            }

            # rknn-v1 headers only contain 3 uint64 values (24 bytes).
            real_header_size = 24 if file_format <= 1 else self.HEADER_SIZE

            if parse_flatbuffers and file_length > 0:
                # BUGFIX: the payload starts right after the real header. For
                # v1 files that is offset 24, but we already consumed 64 bytes
                # above, so seek explicitly (no-op for v2+ files).
                f.seek(real_header_size)
                fb_data = f.read(file_length)
                if len(fb_data) == file_length:
                    self.fb_model = Model.GetRootAs(fb_data, 0)

            # Trailing JSON blob: uint64 size followed by the UTF-8 payload.
            file_size = self.file_path.stat().st_size
            json_offset = real_header_size + file_length
            f.seek(json_offset)
            json_size = struct.unpack("<Q", f.read(8))[0]
            # The blob cannot extend past the end of the file.
            if (
                json_size <= 0 or
                json_size > file_size - real_header_size - file_length - 8
            ):
                raise ValueError(f"Invalid JSON size: {json_size}")

            json_data = f.read(json_size)
            try:
                self.model_info = json.loads(json_data.decode("utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                raise ValueError(f"Failed to parse model info JSON: {e}") from e

    def get_flatbuffers_info(self) -> Dict[str, Any]:
        """Get FlatBuffers model information.

        Returns:
            Dictionary containing FlatBuffers model data; empty when no
            FlatBuffers payload was parsed.
        """
        if not self.fb_model:
            return {}

        info: Dict[str, Any] = {}

        # Optional top-level string fields (stored as bytes in the schema).
        for key, getter in (
            ("format", self.fb_model.Format),
            ("generator", self.fb_model.Generator),
            ("compiler", self.fb_model.Compiler),
            ("runtime", self.fb_model.Runtime),
            ("source", self.fb_model.Source),
        ):
            value = getter()
            if value:
                info[key] = value.decode('utf-8')

        info["num_graphs"] = self.fb_model.GraphsLength()

        # Input/Output JSON strings, when the model carries them.
        if self.fb_model.InputJson():
            info["input_json"] = self.fb_model.InputJson().decode('utf-8')
        if self.fb_model.OutputJson():
            info["output_json"] = self.fb_model.OutputJson().decode('utf-8')

        # Per-graph breakdown: counts plus tensor and node summaries.
        graphs = []
        for i in range(self.fb_model.GraphsLength()):
            graph = self.fb_model.Graphs(i)
            if not graph:
                continue
            graph_info = {
                "num_tensors": graph.TensorsLength(),
                "num_nodes": graph.NodesLength(),
                "num_inputs": graph.InputsLength(),
                "num_outputs": graph.OutputsLength(),
            }

            tensors = []
            for j in range(graph.TensorsLength()):
                tensor = graph.Tensors(j)
                if tensor:
                    tensors.append({
                        "data_type": tensor.DataType(),
                        "kind": tensor.Kind(),
                        "name": tensor.Name().decode('utf-8') if tensor.Name() else "",
                        "shape": [tensor.Shape(k) for k in range(tensor.ShapeLength())],
                        "size": tensor.Size(),
                        "index": tensor.Index(),
                    })
            graph_info["tensors"] = tensors

            nodes = []
            for j in range(graph.NodesLength()):
                node = graph.Nodes(j)
                if node:
                    nodes.append({
                        "type": node.Type().decode('utf-8') if node.Type() else "",
                        "name": node.Name().decode('utf-8') if node.Name() else "",
                        "num_inputs": node.InputsLength(),
                        "num_outputs": node.OutputsLength(),
                    })
            graph_info["nodes"] = nodes

            graphs.append(graph_info)

        info["graphs"] = graphs
        return info

    def _io_tensors(self, side: str) -> List[Dict[str, Any]]:
        """Collect norm_tensor entries wired to one side of the graph.

        Shared implementation of get_input_info/get_output_info (previously
        two verbatim copies).

        Args:
            side: "input" or "output" — matched against each connection's
                "left" field.

        Returns:
            Tensor-info dicts in graph-connection order.
        """
        norm_tensors = {t["tensor_id"]: t for t in self.model_info.get("norm_tensor", [])}
        result = []
        for conn in self.model_info.get("graph", []):
            if conn.get("left") == side:
                tensor_id = conn.get("right_tensor_id")
                if tensor_id in norm_tensors:
                    result.append(norm_tensors[tensor_id])
        return result

    def get_input_info(self) -> List[Dict[str, Any]]:
        """Get model input information.

        Returns:
            List of input tensor information.
        """
        return self._io_tensors("input")

    def get_output_info(self) -> List[Dict[str, Any]]:
        """Get model output information.

        Returns:
            List of output tensor information.
        """
        return self._io_tensors("output")

    def get_model_name(self) -> str:
        """Get model name ("Unknown" when absent from the JSON)."""
        return self.model_info.get("name", "Unknown")

    def get_version(self) -> str:
        """Get model version ("Unknown" when absent from the JSON)."""
        return self.model_info.get("version", "Unknown")

    def get_target_platform(self) -> List[str]:
        """Get target platform list (empty when absent from the JSON)."""
        return self.model_info.get("target_platform", [])

    @staticmethod
    def _parse_generator_blob(generator: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Parse the generator metadata string into (attrs, quant_tab).

        The blob is usually a Python-literal dict (single quotes), so try
        ast.literal_eval first, then fall back to JSON after normalizing
        quotes. Returns ({}, {}) when neither parse succeeds or the result
        is not a dict.
        """
        import ast

        try:
            gen_data = ast.literal_eval(generator)
        except (ValueError, SyntaxError):
            try:
                gen_data = json.loads(generator.replace("'", '"'))
            except (json.JSONDecodeError, ValueError):
                return {}, {}
        if isinstance(gen_data, dict):
            return gen_data.get("attrs", {}), gen_data.get("quant_tab", {})
        return {}, {}

    @staticmethod
    def _augment_tensor(
        tensor: Dict[str, Any],
        generator_attrs: Dict[str, Any],
        quant_tab: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return a copy of *tensor* enriched with layout and quant data.

        Shared implementation of the input/output merge (previously two
        verbatim copies in get_merged_io_info).
        """
        io_name = tensor.get("url", "")
        merged = tensor.copy()

        # Layout info from generator attrs, keyed by the tensor URL.
        if io_name in generator_attrs:
            merged["layout"] = generator_attrs[io_name].get("layout", "")
            merged["layout_ori"] = generator_attrs[io_name].get("layout_ori", "")

        # Quantization info from the quant table; its dtype overrides JSON's.
        if io_name in quant_tab:
            entry = quant_tab[io_name]
            merged["dtype"] = entry.get("dtype", tensor.get("dtype", {}))
            merged["quant_info"] = {
                "qmethod": entry.get("qmethod", ""),
                "qtype": entry.get("qtype", ""),
                "scale": entry.get("scale", []),
                "zero_point": entry.get("zero_point", []),
            }
        return merged

    def get_merged_io_info(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Get merged input/output information from both FlatBuffers and JSON.

        Returns:
            Tuple of (inputs, outputs) with merged information.
        """
        json_inputs = self.get_input_info()
        json_outputs = self.get_output_info()

        # Without FlatBuffers data there is nothing to merge.
        if not self.fb_model:
            return json_inputs, json_outputs

        fb_info = self.get_flatbuffers_info()

        generator_attrs: Dict[str, Any] = {}
        quant_tab: Dict[str, Any] = {}
        if fb_info.get("generator"):
            generator_attrs, quant_tab = self._parse_generator_blob(fb_info["generator"])

        merged_inputs = [
            self._augment_tensor(t, generator_attrs, quant_tab) for t in json_inputs
        ]
        merged_outputs = [
            self._augment_tensor(t, generator_attrs, quant_tab) for t in json_outputs
        ]
        return merged_inputs, merged_outputs