haoline-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/analyzer.py ADDED
@@ -0,0 +1,935 @@
+ # Copyright (c) 2025 HaoLine Contributors
+ # SPDX-License-Identifier: MIT
+
+ """
+ Core analysis engine for HaoLine.
+
+ This module provides:
+ - ONNXGraphLoader: Load ONNX models and extract graph structure
+ - MetricsEngine: Compute parameters, FLOPs, and memory estimates
+ - GraphInfo: Internal representation of the parsed graph
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import pathlib
+ from dataclasses import dataclass, field
+ from typing import Any, ClassVar
+
+ import numpy as np
+ import onnx
+
+
+ # Standalone implementations that work without onnxruntime
+ def get_opsets_imported(model: onnx.ModelProto) -> dict:
+     """Get the opsets imported by the model."""
+     opsets = {}
+     for entry in model.opset_import:
+         domain = entry.domain or "ai.onnx"
+         opsets[domain] = entry.version
+     return opsets
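
A quick sketch of the helper above (the opset version shown is illustrative):

    model = onnx.load("model.onnx")  # placeholder path
    get_opsets_imported(model)       # e.g. {"ai.onnx": 17}
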
+
+
+ def iterate_graph_per_node_func(graph, per_node_func, **func_args):
+     """Walk the graph, including subgraphs, calling per_node_func for each node."""
+     for node in graph.node:
+         per_node_func(node, **func_args)
+         for attr in node.attribute:
+             if attr.HasField("g"):
+                 iterate_graph_per_node_func(attr.g, per_node_func, **func_args)
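
For example, this walker can tally op types across every subgraph (a sketch; the tally callback is illustrative):

    from collections import Counter

    def tally(node, counts):
        counts[node.op_type] += 1

    op_counts: Counter = Counter()
    iterate_graph_per_node_func(model.graph, tally, counts=op_counts)
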
+
+
+ # ORT utilities not available in standalone package - use onnx fallback
+ _HAS_ORT_UTILS = False
+ ModelProtoWithShapeInfo = None  # type: ignore
+
+
+ @dataclass
+ class NodeInfo:
+     """Information about a single ONNX node."""
+
+     name: str
+     op_type: str
+     domain: str
+     inputs: list[str]
+     outputs: list[str]
+     attributes: dict[str, Any]
+     # Computed during analysis
+     param_count: float = 0.0  # Float for fractional shared weight attribution
+     flops: int = 0
+
+
+ @dataclass
+ class GraphInfo:
+     """Parsed graph structure with extracted metadata."""
+
+     name: str
+     nodes: list[NodeInfo]
+     inputs: list[str]
+     outputs: list[str]
+     initializers: dict[str, np.ndarray]  # name -> tensor
+     value_shapes: dict[str, list[int | str]]  # name -> shape (may have symbolic dims)
+
+     # Computed summaries
+     num_nodes: int = 0
+     input_shapes: dict[str, list[int | str]] = field(default_factory=dict)
+     output_shapes: dict[str, list[int | str]] = field(default_factory=dict)
+     op_type_counts: dict[str, int] = field(default_factory=dict)
+
+     # Node lookup
+     node_by_name: dict[str, NodeInfo] = field(default_factory=dict)
+     node_by_output: dict[str, NodeInfo] = field(default_factory=dict)
+
+
+ @dataclass
+ class ParamCounts:
+     """Parameter count breakdown."""
+
+     total: int = 0
+     trainable: int = 0  # Assumed: all initializers are trainable unless marked
+     non_trainable: int = 0
+     by_node: dict[str, float] = field(
+         default_factory=dict
+     )  # Float for fractional shared attribution
+     by_op_type: dict[str, float] = field(
+         default_factory=dict
+     )  # Float for fractional shared attribution
+
+     # Shared weight tracking
+     shared_weights: dict[str, list[str]] = field(
+         default_factory=dict
+     )  # initializer -> nodes using it
+     num_shared_weights: int = 0  # Count of weights used by 2+ nodes
+
+     # Quantization info
+     precision_breakdown: dict[str, int] = field(default_factory=dict)  # dtype -> param count
+     is_quantized: bool = False  # True if model has quantized weights or ops
+     quantized_ops: list[str] = field(default_factory=list)  # Quantized op types detected
+
+     def to_dict(self) -> dict:
+         return {
+             "total": self.total,
+             "trainable": self.trainable,
+             "non_trainable": self.non_trainable,
+             "by_op_type": {k: round(v, 2) for k, v in self.by_op_type.items()},
+             "shared_weights": {
+                 "count": self.num_shared_weights,
+                 "details": {k: v for k, v in self.shared_weights.items() if len(v) > 1},
+             },
+             "precision_breakdown": self.precision_breakdown,
+             "is_quantized": self.is_quantized,
+             "quantized_ops": self.quantized_ops,
+         }
+
+
+ @dataclass
+ class FlopCounts:
+     """FLOP estimate breakdown."""
+
+     total: int = 0
+     by_node: dict[str, int] = field(default_factory=dict)
+     by_op_type: dict[str, int] = field(default_factory=dict)
+
+     def to_dict(self) -> dict:
+         return {
+             "total": self.total,
+             "by_op_type": self.by_op_type,
+         }
+
+
+ @dataclass
+ class MemoryBreakdown:
+     """Detailed memory breakdown by component type."""
+
+     # Weights by operation type
+     weights_by_op_type: dict[str, int] = field(default_factory=dict)  # op -> bytes
+     # Top weight tensors
+     largest_weights: list[tuple[str, int]] = field(default_factory=list)  # (name, bytes)
+     # Activation breakdown
+     activations_by_op_type: dict[str, int] = field(default_factory=dict)  # op -> bytes
+     largest_activations: list[tuple[str, int]] = field(default_factory=list)
+
+     def to_dict(self) -> dict[str, Any]:
+         return {
+             "weights_by_op_type": self.weights_by_op_type,
+             "largest_weights": [
+                 {"name": name, "bytes": size} for name, size in self.largest_weights[:10]
+             ],
+             "activations_by_op_type": self.activations_by_op_type,
+             "largest_activations": [
+                 {"name": name, "bytes": size} for name, size in self.largest_activations[:10]
+             ],
+         }
+
+
+ @dataclass
+ class MemoryEstimates:
+     """Memory usage estimates."""
+
+     model_size_bytes: int = 0  # Size of parameters/initializers
+     peak_activation_bytes: int = 0  # Estimated peak activation memory (batch=1)
+     per_layer_activation_bytes: dict[str, int] = field(default_factory=dict)
+     # KV cache estimates for transformer models
+     kv_cache_bytes_per_token: int = 0  # KV cache per token (for streaming inference)
+     kv_cache_bytes_full_context: int = 0  # Total KV cache at max seq length
+     kv_cache_config: dict[str, int] = field(default_factory=dict)  # num_layers, hidden_dim, etc.
+     # Detailed breakdown
+     breakdown: MemoryBreakdown | None = None
+
+     def to_dict(self) -> dict[str, Any]:
+         result: dict[str, Any] = {
+             "model_size_bytes": self.model_size_bytes,
+             "peak_activation_bytes": self.peak_activation_bytes,
+         }
+         if self.kv_cache_bytes_per_token > 0:
+             result["kv_cache_bytes_per_token"] = self.kv_cache_bytes_per_token
+             result["kv_cache_bytes_full_context"] = self.kv_cache_bytes_full_context
+             result["kv_cache_config"] = self.kv_cache_config
+         if self.breakdown:
+             result["breakdown"] = self.breakdown.to_dict()
+         return result
+
+
+ class ONNXGraphLoader:
+     """
+     Load ONNX models and extract graph structure.
+
+     Handles shape inference and creates a GraphInfo representation
+     suitable for analysis.
+     """
+
+     def __init__(self, logger: logging.Logger | None = None):
+         self.logger = logger or logging.getLogger("haoline.loader")
+
+     def load(self, model_path: str | pathlib.Path) -> tuple[onnx.ModelProto, GraphInfo]:
+         """
+         Load an ONNX model and extract graph information.
+
+         Args:
+             model_path: Path to the ONNX model file.
+
+         Returns:
+             Tuple of (ModelProto, GraphInfo)
+         """
+         model_path = pathlib.Path(model_path)
+         self.logger.debug(f"Loading model from {model_path}")
+
+         # Use ORT's helper if available, otherwise fall back to onnx
+         if _HAS_ORT_UTILS and ModelProtoWithShapeInfo is not None:
+             wrapper = ModelProtoWithShapeInfo(model_path)
+             model = wrapper.model_with_shape_info
+         else:
+             # Fallback: load with onnx and run shape inference
+             model = onnx.load(str(model_path))
+             try:
+                 model = onnx.shape_inference.infer_shapes(model, strict_mode=True)
+             except Exception as e:
+                 self.logger.warning(f"Shape inference failed: {e}. Proceeding without shape info.")
+
+         graph_info = self._extract_graph_info(model.graph, model)
+
+         self.logger.debug(f"Loaded graph with {graph_info.num_nodes} nodes")
+         return model, graph_info
+
+     def _extract_graph_info(self, graph: onnx.GraphProto, model: onnx.ModelProto) -> GraphInfo:
+         """Extract GraphInfo from an ONNX GraphProto."""
+
+         # Extract initializers (weights/biases)
+         initializers = {}
+         for init in graph.initializer:
+             try:
+                 initializers[init.name] = onnx.numpy_helper.to_array(init)
+             except Exception as e:
+                 self.logger.warning(f"Could not convert initializer {init.name}: {e}")
+                 # Store shape info at minimum
+                 initializers[init.name] = np.zeros(init.dims, dtype=np.float32)
+
+         # Build value shape map from value_info, inputs, and outputs
+         value_shapes = {}
+
+         def _extract_shape(value_info: onnx.ValueInfoProto) -> list[int | str]:
+             shape = []
+             if value_info.type.HasField("tensor_type"):
+                 tensor_type = value_info.type.tensor_type
+                 if tensor_type.HasField("shape"):
+                     for dim in tensor_type.shape.dim:
+                         if dim.HasField("dim_value"):
+                             shape.append(dim.dim_value)
+                         elif dim.HasField("dim_param"):
+                             shape.append(dim.dim_param)
+                         else:
+                             shape.append("?")
+             return shape
+
+         for vi in graph.input:
+             value_shapes[vi.name] = _extract_shape(vi)
+         for vi in graph.output:
+             value_shapes[vi.name] = _extract_shape(vi)
+         for vi in graph.value_info:
+             value_shapes[vi.name] = _extract_shape(vi)
+
+         # For initializers without explicit value_info, use their tensor shapes
+         for name, arr in initializers.items():
+             if name not in value_shapes:
+                 value_shapes[name] = list(arr.shape)
+
+         # Extract nodes
+         nodes: list[NodeInfo] = []
+         op_type_counts: dict[str, int] = {}
+         node_by_name: dict[str, NodeInfo] = {}
+         node_by_output: dict[str, NodeInfo] = {}
+
+         for node in graph.node:
+             # Extract attributes
+             attributes = {}
+             for attr in node.attribute:
+                 if attr.HasField("i"):
+                     attributes[attr.name] = attr.i
+                 elif attr.HasField("f"):
+                     attributes[attr.name] = attr.f
+                 elif attr.HasField("s"):
+                     attributes[attr.name] = (
+                         attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                     )
+                 elif attr.ints:
+                     attributes[attr.name] = list(attr.ints)
+                 elif attr.floats:
+                     attributes[attr.name] = list(attr.floats)
+                 # Skip subgraphs and other complex types for now
+
+             node_info = NodeInfo(
+                 name=node.name or f"unnamed_{len(nodes)}",
+                 op_type=node.op_type,
+                 domain=node.domain or "ai.onnx",
+                 inputs=list(node.input),
+                 outputs=list(node.output),
+                 attributes=attributes,
+             )
+             nodes.append(node_info)
+             node_by_name[node_info.name] = node_info
+             for output in node_info.outputs:
+                 node_by_output[output] = node_info
+
+             # Count op types
+             op_type_counts[node.op_type] = op_type_counts.get(node.op_type, 0) + 1
+
+         # Build input/output shape maps (excluding initializers from inputs)
+         input_names = [i.name for i in graph.input if i.name not in initializers]
+         output_names = [o.name for o in graph.output]
+
+         input_shapes = {name: value_shapes.get(name, []) for name in input_names}
+         output_shapes = {name: value_shapes.get(name, []) for name in output_names}
+
+         return GraphInfo(
+             name=graph.name or "main",
+             nodes=nodes,
+             inputs=input_names,
+             outputs=output_names,
+             initializers=initializers,
+             value_shapes=value_shapes,
+             num_nodes=len(nodes),
+             input_shapes=input_shapes,
+             output_shapes=output_shapes,
+             op_type_counts=op_type_counts,
+             node_by_name=node_by_name,
+             node_by_output=node_by_output,
+         )
+
+
+ class MetricsEngine:
+     """
+     Compute model complexity metrics.
+
+     Provides parameter counts, FLOP estimates, and memory estimates
+     for ONNX graphs.
+     """
+
+     # FLOP formula selector per operation type
+     # These are rough estimates; actual FLOPs depend on implementation
+     FLOP_FORMULAS: ClassVar[dict[str, str]] = {
+         # Conv: 2 * K_h * K_w * C_in * C_out * H_out * W_out
+         "Conv": "conv",
+         # MatMul: 2 * M * N * K
+         "MatMul": "matmul",
+         "Gemm": "gemm",
+         # Element-wise ops: N elements
+         "Add": "elementwise",
+         "Sub": "elementwise",
+         "Mul": "elementwise",
+         "Div": "elementwise",
+         "Relu": "elementwise",
+         "Sigmoid": "elementwise",
+         "Tanh": "elementwise",
+         "Sqrt": "elementwise",
+         "Exp": "elementwise",
+         "Log": "elementwise",
+         "Gelu": "elementwise",
+         "Silu": "elementwise",
+         # Softmax: ~5N (exp, sum, div)
+         "Softmax": "softmax",
+         # Reduction ops: N elements
+         "ReduceMean": "elementwise",
+         "ReduceSum": "elementwise",
+         "ReduceMax": "elementwise",
+         # Normalization layers
+         "LayerNormalization": "layernorm",
+         "BatchNormalization": "batchnorm",
+         # Attention ops (ONNX contrib / custom)
+         "Attention": "attention",
+         "MultiHeadAttention": "attention",
+         "com.microsoft.Attention": "attention",
+         "com.microsoft.MultiHeadAttention": "attention",
+     }
+
+     # Quantized operation types in ONNX
+     QUANTIZED_OPS: ClassVar[set[str]] = {
+         "QuantizeLinear",
+         "DequantizeLinear",
+         "QLinearConv",
+         "QLinearMatMul",
+         "QLinearAdd",
+         "QGemm",
+         "ConvInteger",
+         "MatMulInteger",
+         "DynamicQuantizeLinear",
+         "QLinearSigmoid",
+         "QLinearLeakyRelu",
+         "QLinearAveragePool",
+         "QLinearGlobalAveragePool",
+         "QLinearConcat",
+     }
+
+     # Quantized dtypes
+     QUANTIZED_DTYPES: ClassVar[set[type]] = {np.int8, np.uint8, np.int16, np.uint16}
+
+     def __init__(self, logger: logging.Logger | None = None):
+         self.logger = logger or logging.getLogger("haoline.metrics")
+
+     def count_parameters(self, graph_info: GraphInfo) -> ParamCounts:
+         """
+         Count parameters in the model.
+
+         Parameters are counted from initializers. All initializers are
+         assumed trainable unless specifically marked otherwise.
+
+         Handles edge cases:
+         - Shared weights: Uses fractional attribution so by_op_type sums to total
+         - Quantized params: Detects INT8/UINT8 weights and quantized ops
+
+         Args:
+             graph_info: Parsed graph information.
+
+         Returns:
+             ParamCounts with total and per-node breakdowns.
+         """
+         counts = ParamCounts()
+
+         # First pass: build usage map (which nodes use each initializer)
+         usage_map: dict[str, list[str]] = {name: [] for name in graph_info.initializers}
+         for node in graph_info.nodes:
+             for inp in node.inputs:
+                 if inp in graph_info.initializers:
+                     usage_map[inp].append(node.name)
+
+         # Track shared weights (used by 2+ nodes)
+         counts.shared_weights = {k: v for k, v in usage_map.items() if len(v) > 1}
+         counts.num_shared_weights = len(counts.shared_weights)
+
+         # Detect quantized ops in the graph
+         quantized_ops_found = set()
+         for node in graph_info.nodes:
+             if node.op_type in self.QUANTIZED_OPS:
+                 quantized_ops_found.add(node.op_type)
+         counts.quantized_ops = sorted(quantized_ops_found)
+
+         # Second pass: count parameters with fractional attribution
+         for name, tensor in graph_info.initializers.items():
+             param_count = int(np.prod(tensor.shape)) if tensor.shape else 1
+             counts.total += param_count
+             counts.by_node[name] = float(param_count)
+
+             # Track precision breakdown
+             dtype_name = self._get_dtype_name(tensor)
+             counts.precision_breakdown[dtype_name] = (
+                 counts.precision_breakdown.get(dtype_name, 0) + param_count
+             )
+
+             # Check if this is a quantized weight
+             if hasattr(tensor, "dtype") and tensor.dtype in self.QUANTIZED_DTYPES:
+                 counts.is_quantized = True
+
+             # Fractional attribution to nodes sharing this weight
+             using_nodes = usage_map[name]
+             num_users = len(using_nodes) if using_nodes else 1
+             fractional_count = param_count / num_users
+
+             for node in graph_info.nodes:
+                 if node.name in using_nodes:
+                     counts.by_op_type[node.op_type] = (
+                         counts.by_op_type.get(node.op_type, 0.0) + fractional_count
+                     )
+                     node.param_count += fractional_count
+
+         # Mark as quantized if quantized ops are present
+         if counts.quantized_ops:
+             counts.is_quantized = True
+
+         # For now, assume all are trainable
+         # Could be refined with graph analysis (e.g., constants, frozen layers)
+         counts.trainable = counts.total
+         counts.non_trainable = 0
+
+         return counts
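
A worked check of the fractional attribution above (hypothetical numbers): a 1,000-parameter initializer consumed by two Gemm nodes gives fractional_count = 1000 / 2 = 500.0, so each node's param_count grows by 500.0 and by_op_type["Gemm"] by 1,000.0 in total, which keeps by_op_type summing to counts.total.
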
+
+     def _get_dtype_name(self, tensor: np.ndarray) -> str:
+         """Get a human-readable dtype name for a tensor."""
+         if not hasattr(tensor, "dtype"):
+             return "unknown"
+         dtype = tensor.dtype
+         dtype_map = {
+             np.float32: "fp32",
+             np.float64: "fp64",
+             np.float16: "fp16",
+             np.int8: "int8",
+             np.uint8: "uint8",
+             np.int16: "int16",
+             np.uint16: "uint16",
+             np.int32: "int32",
+             np.int64: "int64",
+         }
+         return dtype_map.get(dtype.type, str(dtype))
+
+     def estimate_flops(self, graph_info: GraphInfo) -> FlopCounts:
+         """
+         Estimate FLOPs for each operation in the graph.
+
+         Uses shape information to compute FLOPs. Falls back to
+         rough estimates when shapes are unavailable.
+
+         Args:
+             graph_info: Parsed graph information.
+
+         Returns:
+             FlopCounts with total and per-node breakdowns.
+         """
+         counts = FlopCounts()
+
+         for node in graph_info.nodes:
+             flops = self._estimate_node_flops(node, graph_info)
+             node.flops = flops
+             counts.total += flops
+             counts.by_node[node.name] = flops
+             counts.by_op_type[node.op_type] = counts.by_op_type.get(node.op_type, 0) + flops
+
+         return counts
+
+     def _estimate_node_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
+         """Estimate FLOPs for a single node."""
+         formula_type = self.FLOP_FORMULAS.get(node.op_type, "unknown")
+
+         if formula_type == "conv":
+             return self._estimate_conv_flops(node, graph_info)
+         elif formula_type == "matmul":
+             return self._estimate_matmul_flops(node, graph_info)
+         elif formula_type == "gemm":
+             return self._estimate_gemm_flops(node, graph_info)
+         elif formula_type == "elementwise":
+             return self._estimate_elementwise_flops(node, graph_info)
+         elif formula_type == "softmax":
+             return self._estimate_elementwise_flops(node, graph_info) * 5
+         elif formula_type == "layernorm":
+             return self._estimate_elementwise_flops(node, graph_info) * 5
+         elif formula_type == "batchnorm":
+             return self._estimate_elementwise_flops(node, graph_info) * 2
+         elif formula_type == "attention":
+             return self._estimate_attention_flops(node, graph_info)
+         else:
+             # Unknown op - estimate based on output size
+             return self._estimate_elementwise_flops(node, graph_info)
+
+     def _estimate_conv_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
+         """Estimate FLOPs for Conv: 2 * K_h * K_w * (C_in / groups) * C_out * H_out * W_out"""
+         if len(node.inputs) < 2:
+             return 0
+
+         # Get weight shape
+         weight_name = node.inputs[1]
+         if weight_name in graph_info.initializers:
+             weight_shape = list(graph_info.initializers[weight_name].shape)
+         elif weight_name in graph_info.value_shapes:
+             weight_shape = graph_info.value_shapes[weight_name]
+         else:
+             return 0
+
+         # Weight shape: [C_out, C_in/groups, K_h, K_w] for 2D conv
+         if len(weight_shape) < 4 or not all(isinstance(d, int) for d in weight_shape):
+             return 0
+
+         c_out, c_in_per_group, k_h, k_w = weight_shape[:4]
+
+         # Get output shape
+         if node.outputs and node.outputs[0] in graph_info.value_shapes:
+             output_shape = graph_info.value_shapes[node.outputs[0]]
+             if len(output_shape) >= 4 and all(isinstance(d, int) for d in output_shape[-2:]):
+                 h_out, w_out = output_shape[-2], output_shape[-1]
+             else:
+                 h_out, w_out = 1, 1
+         else:
+             h_out, w_out = 1, 1
+
+         # Grouping is already reflected in the weight layout (c_in_per_group = C_in / groups),
+         # so the "group" attribute needs no separate handling here.
+         flops = 2 * k_h * k_w * c_in_per_group * c_out * h_out * w_out
+
+         # Add bias if present
+         if len(node.inputs) > 2:
+             flops += c_out * h_out * w_out
+
+         return int(flops)
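
A worked check of the Conv formula (hypothetical shapes): weight [128, 64, 3, 3] (C_out=128, C_in/groups=64, 3x3 kernel) with a 56x56 output gives 2 * 3 * 3 * 64 * 128 * 56 * 56 = 462,422,016 FLOPs (~0.46 GFLOP), plus 128 * 56 * 56 = 401,408 more if a bias input is present.
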
+
+     def _estimate_matmul_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
+         """Estimate FLOPs for MatMul: 2 * M * N * K"""
+         if len(node.inputs) < 2:
+             return 0
+
+         # Get shapes of both inputs
+         shapes = []
+         for inp in node.inputs[:2]:
+             if inp in graph_info.initializers:
+                 shapes.append(list(graph_info.initializers[inp].shape))
+             elif inp in graph_info.value_shapes:
+                 shapes.append(graph_info.value_shapes[inp])
+             else:
+                 return 0
+
+         if len(shapes) < 2:
+             return 0
+
+         # MatMul: A[..., M, K] @ B[..., K, N] = C[..., M, N]
+         shape_a, shape_b = shapes[0], shapes[1]
+
+         # Handle broadcasting and get M, K, N
+         if len(shape_a) < 2 or len(shape_b) < 2:
+             return 0
+
+         if not all(isinstance(d, int) for d in shape_a[-2:]) or not all(
+             isinstance(d, int) for d in shape_b[-2:]
+         ):
+             return 0
+
+         m, k = shape_a[-2], shape_a[-1]
+         k2, n = shape_b[-2], shape_b[-1]
+
+         if k != k2:
+             self.logger.warning(f"MatMul shape mismatch in node {node.name}: K={k} vs K={k2}")
+             return 0
+
+         # Handle batch dimensions
+         batch = 1
+         for dim in shape_a[:-2]:
+             if isinstance(dim, int):
+                 batch *= dim
+
+         return int(2 * batch * m * n * k)
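
A worked check of the MatMul formula (hypothetical shapes): A[8, 128, 512] @ B[512, 256] gives batch=8, M=128, K=512, N=256, hence 2 * 8 * 128 * 256 * 512 = 268,435,456 FLOPs (~0.27 GFLOP).
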
+
+     def _estimate_gemm_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
+         """Estimate FLOPs for Gemm: 2 * M * N * K + M * N (bias)"""
+         flops = self._estimate_matmul_flops(node, graph_info)
+
+         # Add bias computation if present
+         if len(node.inputs) > 2 and node.outputs and node.outputs[0] in graph_info.value_shapes:
+             output_shape = graph_info.value_shapes[node.outputs[0]]
+             if output_shape and all(isinstance(d, int) for d in output_shape):
+                 int_shape: list[int] = [d for d in output_shape if isinstance(d, int)]
+                 bias_flops = int(np.prod(int_shape))
+                 flops += bias_flops
+
+         return flops
+
+     def _estimate_elementwise_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
+         """Estimate FLOPs for element-wise operations: N elements"""
+         # Use output shape to determine element count
+         if node.outputs and node.outputs[0] in graph_info.value_shapes:
+             shape = graph_info.value_shapes[node.outputs[0]]
+             if shape and all(isinstance(d, int) for d in shape):
+                 int_shape: list[int] = [d for d in shape if isinstance(d, int)]
+                 return int(np.prod(int_shape))
+
+         # Fallback: use first input shape
+         if node.inputs and node.inputs[0] in graph_info.value_shapes:
+             shape = graph_info.value_shapes[node.inputs[0]]
+             if shape and all(isinstance(d, int) for d in shape):
+                 int_shape2: list[int] = [d for d in shape if isinstance(d, int)]
+                 return int(np.prod(int_shape2))
+
+         return 0
+
+     def _estimate_attention_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
+         """
+         Estimate FLOPs for attention operations.
+
+         Standard multi-head attention FLOPs:
+         - QKV projections: 3 * batch * seq_len * d_model * d_model
+         - Attention scores (Q @ K^T): batch * num_heads * seq_len * seq_len * d_head
+         - Softmax: batch * num_heads * seq_len * seq_len * 5
+         - Attention output (scores @ V): batch * num_heads * seq_len * seq_len * d_head
+         - Output projection: batch * seq_len * d_model * d_model
+
+         Simplified formula: 4 * seq_len * d_model^2 + 2 * num_heads * seq_len^2 * d_head
+         """
+         # Try to get dimensions from node attributes or input shapes
+         num_heads = 1
+         seq_len = 512  # Default assumption
+         d_model = 768  # Default assumption
+
+         # Try to extract from node attributes (ONNX Attention op)
+         for attr_name, attr_value in node.attributes.items():
+             if attr_name == "num_heads" and isinstance(attr_value, int):
+                 num_heads = attr_value
+             elif attr_name == "hidden_size" and isinstance(attr_value, int):
+                 d_model = attr_value
+
+         # Try to infer from input shapes
+         if node.inputs and node.inputs[0] in graph_info.value_shapes:
+             input_shape = graph_info.value_shapes[node.inputs[0]]
+             if input_shape and len(input_shape) >= 2:
+                 # Shape is typically [batch, seq_len, d_model] or [batch, seq_len, ...]
+                 if len(input_shape) >= 3:
+                     if isinstance(input_shape[1], int):
+                         seq_len = input_shape[1]
+                     if isinstance(input_shape[2], int):
+                         d_model = input_shape[2]
+                 elif len(input_shape) == 2:
+                     if isinstance(input_shape[0], int):
+                         seq_len = input_shape[0]
+                     if isinstance(input_shape[1], int):
+                         d_model = input_shape[1]
+
+         d_head = d_model // num_heads if num_heads > 0 else d_model
+
+         # Compute FLOPs using standard attention formula
+         # QKV projections: 3 * seq * d_model * d_model
+         qkv_flops = 3 * seq_len * d_model * d_model
+
+         # Attention scores and output: 2 * num_heads * seq^2 * d_head
+         attention_flops = 2 * num_heads * seq_len * seq_len * d_head
+
+         # Output projection: seq * d_model * d_model
+         output_flops = seq_len * d_model * d_model
+
+         # Softmax on attention scores: 5 * num_heads * seq^2
+         softmax_flops = 5 * num_heads * seq_len * seq_len
+
+         total_flops = qkv_flops + attention_flops + output_flops + softmax_flops
+
+         self.logger.debug(
+             f"Attention FLOPs: seq={seq_len}, d_model={d_model}, "
+             f"heads={num_heads}, total={total_flops:,}"
+         )
+
+         return total_flops
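
A worked check with the defaults above plus a hypothetical num_heads=12 attribute (BERT-base-like sizes): seq_len=512, d_model=768, d_head=64 gives qkv = 3 * 512 * 768^2 = 905,969,664; attention = 2 * 12 * 512^2 * 64 = 402,653,184; output = 512 * 768^2 = 301,989,888; softmax = 5 * 12 * 512^2 = 15,728,640; total = 1,626,341,376 FLOPs (~1.6 GFLOP).
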
+
+     def estimate_memory(self, graph_info: GraphInfo) -> MemoryEstimates:
+         """
+         Estimate memory usage for the model.
+
+         Computes model size (parameters), peak activation memory,
+         KV cache size for transformer models, and detailed breakdown.
+
+         Args:
+             graph_info: Parsed graph information.
+
+         Returns:
+             MemoryEstimates with size, activation memory, KV cache, and breakdown.
+         """
+         estimates = MemoryEstimates()
+         breakdown = MemoryBreakdown()
+
+         # Build mapping: initializer name -> op type that uses it
+         init_to_op: dict[str, str] = {}
+         for node in graph_info.nodes:
+             for inp in node.inputs:
+                 if inp in graph_info.initializers and inp not in init_to_op:
+                     init_to_op[inp] = node.op_type
+
+         # Model size: sum of initializer sizes with breakdown by op type
+         weight_sizes: list[tuple[str, int]] = []
+         for name, tensor in graph_info.initializers.items():
+             # Determine bytes per element based on dtype
+             bytes_per_elem = 4
+             if hasattr(tensor, "dtype"):
+                 if tensor.dtype == np.float16:
+                     bytes_per_elem = 2
+                 elif tensor.dtype == np.float64:
+                     bytes_per_elem = 8
+                 elif tensor.dtype in (np.int8, np.uint8):
+                     bytes_per_elem = 1
+                 elif tensor.dtype in (np.int16, np.uint16):
+                     bytes_per_elem = 2
+
+             tensor_bytes = (
+                 int(np.prod(tensor.shape)) * bytes_per_elem if tensor.shape else bytes_per_elem
+             )
+             estimates.model_size_bytes += tensor_bytes
+             weight_sizes.append((name, tensor_bytes))
+
+             # Categorize by op type
+             op_type = init_to_op.get(name, "Other")
+             breakdown.weights_by_op_type[op_type] = (
+                 breakdown.weights_by_op_type.get(op_type, 0) + tensor_bytes
+             )
+
+         # Store top 10 largest weights
+         breakdown.largest_weights = sorted(weight_sizes, key=lambda x: -x[1])[:10]
+
+         # Peak activation memory: rough estimate based on intermediate tensor sizes
+         # Build mapping: activation name -> op type that produces it
+         activation_to_op: dict[str, str] = {}
+         for node in graph_info.nodes:
+             for out in node.outputs:
+                 activation_to_op[out] = node.op_type
+
+         activation_sizes: list[tuple[str, int]] = []
+         for name, shape in graph_info.value_shapes.items():
+             # Skip initializers (they're counted in model size)
+             if name in graph_info.initializers:
+                 continue
+
+             if shape:
+                 # Handle symbolic dimensions (e.g., 'N' for batch) by treating as 1
+                 int_shape: list[int] = [d if isinstance(d, int) else 1 for d in shape]
+                 # Skip if all dims are symbolic (no meaningful size)
+                 if all(d == 1 for d in int_shape) and len(int_shape) > 1:
+                     continue
+                 # Assume float32 for activations
+                 tensor_bytes = int(np.prod(int_shape)) * 4
+                 activation_sizes.append((name, tensor_bytes))
+                 estimates.per_layer_activation_bytes[name] = tensor_bytes
+
+                 # Categorize by producing op type
+                 op_type = activation_to_op.get(name, "Input")
+                 breakdown.activations_by_op_type[op_type] = (
+                     breakdown.activations_by_op_type.get(op_type, 0) + tensor_bytes
+                 )
+
+         # Peak is approximate: sum of largest activations that might coexist
+         sorted_activations = sorted(activation_sizes, key=lambda x: -x[1])
+         # Rough heuristic: top 3 largest activations might coexist
+         top_n = min(3, len(sorted_activations))
+         estimates.peak_activation_bytes = sum(size for _, size in sorted_activations[:top_n])
+
+         # Store top 10 largest activations
+         breakdown.largest_activations = sorted_activations[:10]
+
+         # Store breakdown
+         estimates.breakdown = breakdown
+
+         # Estimate KV cache for transformer models
+         kv_config = self._estimate_kv_cache_config(graph_info)
+         if kv_config:
+             estimates.kv_cache_config = kv_config
+             estimates.kv_cache_bytes_per_token = self._compute_kv_cache_per_token(kv_config)
+             estimates.kv_cache_bytes_full_context = (
+                 estimates.kv_cache_bytes_per_token * kv_config["seq_len"]
+             )
+
+         return estimates
+
+     def _estimate_kv_cache_config(self, graph_info: GraphInfo) -> dict[str, int]:
+         """
+         Detect transformer architecture and extract KV cache config.
+
+         Returns dict with num_layers, hidden_dim, num_heads, seq_len, bytes_per_elem
+         or empty dict if not a transformer.
+         """
+         # Check for attention ops
+         attention_ops = {"Attention", "MultiHeadAttention", "Softmax"}
+         attention_count = sum(graph_info.op_type_counts.get(op, 0) for op in attention_ops)
+
+         if attention_count == 0:
+             return {}
+
+         # Try to detect transformer parameters
+         num_layers = 0
+         hidden_dim = 768  # Default
+         num_heads = 12  # Default
+         seq_len = 512  # Default
+         bytes_per_elem = 4  # FP32 default
+
+         # Count attention ops to estimate number of layers
+         # Each transformer layer typically has one attention block
+         mha_count = graph_info.op_type_counts.get("Attention", 0) + graph_info.op_type_counts.get(
+             "MultiHeadAttention", 0
+         )
+         softmax_count = graph_info.op_type_counts.get("Softmax", 0)
+
+         # Use MHA count if available, otherwise estimate from Softmax
+         if mha_count > 0:
+             num_layers = mha_count
+         elif softmax_count > 0:
+             # Heuristic: Softmax often appears once per attention block; assume up to
+             # two per layer (e.g. self- plus cross-attention) and halve the count.
+             num_layers = max(1, softmax_count // 2)
+
+         if num_layers == 0:
+             return {}
+
+         # Try to infer hidden_dim from weight shapes
+         for node in graph_info.nodes:
+             if node.op_type in ("MatMul", "Gemm"):
+                 for inp in node.inputs:
+                     if inp in graph_info.initializers:
+                         weight = graph_info.initializers[inp]
+                         if len(weight.shape) == 2:
+                             # Dense layer weights: [in_features, out_features] or vice versa
+                             dim = max(weight.shape)
+                             if 256 <= dim <= 16384 and dim % 64 == 0:
+                                 hidden_dim = dim
+                                 break
+                 break
+
+         # Try to infer sequence length from input shapes
+         for shape in graph_info.input_shapes.values():
+             if len(shape) >= 2:
+                 # Look for typical transformer input shape [batch, seq_len, ...] or [batch, seq_len]
+                 for dim in shape[1:3]:
+                     if isinstance(dim, int) and 16 <= dim <= 32768:
+                         seq_len = dim
+                         break
+
+         # Estimate num_heads from hidden_dim (typical: 64-128 per head)
+         if hidden_dim >= 256:
+             num_heads = max(1, hidden_dim // 64)
+
+         self.logger.debug(
+             f"KV cache config: layers={num_layers}, hidden={hidden_dim}, "
+             f"heads={num_heads}, seq={seq_len}"
+         )
+
+         return {
+             "num_layers": num_layers,
+             "hidden_dim": hidden_dim,
+             "num_heads": num_heads,
+             "seq_len": seq_len,
+             "bytes_per_elem": bytes_per_elem,
+         }
+
+     def _compute_kv_cache_per_token(self, config: dict[str, int]) -> int:
+         """
+         Compute KV cache memory per token.
+
+         Formula: 2 * num_layers * hidden_dim * bytes_per_elem
+         (2 for K and V, each of size [hidden_dim])
+
+         For multi-head attention with head_dim = hidden_dim / num_heads:
+         KV cache per token per layer = 2 * hidden_dim * bytes_per_elem
+
+         Total per token = 2 * num_layers * hidden_dim * bytes_per_elem
+         """
+         num_layers = config.get("num_layers", 0)
+         hidden_dim = config.get("hidden_dim", 0)
+         bytes_per_elem = config.get("bytes_per_elem", 4)
+
+         # KV cache: each layer stores K and V for each token
+         # K and V each have shape [batch, num_heads, seq_len, head_dim]
+         # Per token: 2 * hidden_dim * bytes_per_elem per layer
+         return 2 * num_layers * hidden_dim * bytes_per_elem
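
A worked check of the formula (illustrative GPT-2-small-like config): num_layers=12, hidden_dim=768, fp32 gives 2 * 12 * 768 * 4 = 73,728 bytes per token (~72 KiB); at the default seq_len=512, the full-context cache is 73,728 * 512 = 37,748,736 bytes (~36 MiB).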