haoline-0.3.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/analyzer.py
ADDED
@@ -0,0 +1,935 @@

# Copyright (c) 2025 HaoLine Contributors
# SPDX-License-Identifier: MIT

"""
Core analysis engine for HaoLine.

This module provides:
- ONNXGraphLoader: Load ONNX models and extract graph structure
- MetricsEngine: Compute parameters, FLOPs, and memory estimates
- GraphInfo: Internal representation of the parsed graph
"""

from __future__ import annotations

import logging
import pathlib
from dataclasses import dataclass, field
from typing import Any, ClassVar

import numpy as np
import onnx


# Standalone implementations that work without onnxruntime
def get_opsets_imported(model: onnx.ModelProto) -> dict:
    """Get the opsets imported by the model."""
    opsets = {}
    for entry in model.opset_import:
        domain = entry.domain or "ai.onnx"
        opsets[domain] = entry.version
    return opsets


def iterate_graph_per_node_func(graph, per_node_func, **func_args):
    """Iterate the graph including subgraphs calling the per_node_func for each node."""
    for node in graph.node:
        per_node_func(node, **func_args)
        for attr in node.attribute:
            if attr.HasField("g"):
                iterate_graph_per_node_func(attr.g, per_node_func, **func_args)


# ORT utilities not available in standalone package - use onnx fallback
_HAS_ORT_UTILS = False
ModelProtoWithShapeInfo = None  # type: ignore


@dataclass
class NodeInfo:
    """Information about a single ONNX node."""

    name: str
    op_type: str
    domain: str
    inputs: list[str]
    outputs: list[str]
    attributes: dict[str, Any]
    # Computed during analysis
    param_count: float = 0.0  # Float for fractional shared weight attribution
    flops: int = 0


@dataclass
class GraphInfo:
    """Parsed graph structure with extracted metadata."""

    name: str
    nodes: list[NodeInfo]
    inputs: list[str]
    outputs: list[str]
    initializers: dict[str, np.ndarray]  # name -> tensor
    value_shapes: dict[str, list[int | str]]  # name -> shape (may have symbolic dims)

    # Computed summaries
    num_nodes: int = 0
    input_shapes: dict[str, list[int | str]] = field(default_factory=dict)
    output_shapes: dict[str, list[int | str]] = field(default_factory=dict)
    op_type_counts: dict[str, int] = field(default_factory=dict)

    # Node lookup
    node_by_name: dict[str, NodeInfo] = field(default_factory=dict)
    node_by_output: dict[str, NodeInfo] = field(default_factory=dict)


@dataclass
class ParamCounts:
    """Parameter count breakdown."""

    total: int = 0
    trainable: int = 0  # Assumed: all initializers are trainable unless marked
    non_trainable: int = 0
    by_node: dict[str, float] = field(
        default_factory=dict
    )  # Float for fractional shared attribution
    by_op_type: dict[str, float] = field(
        default_factory=dict
    )  # Float for fractional shared attribution

    # Shared weight tracking
    shared_weights: dict[str, list[str]] = field(
        default_factory=dict
    )  # initializer -> nodes using it
    num_shared_weights: int = 0  # Count of weights used by 2+ nodes

    # Quantization info
    precision_breakdown: dict[str, int] = field(default_factory=dict)  # dtype -> param count
    is_quantized: bool = False  # True if model has quantized weights or ops
    quantized_ops: list[str] = field(default_factory=list)  # Quantized op types detected

    def to_dict(self) -> dict:
        return {
            "total": self.total,
            "trainable": self.trainable,
            "non_trainable": self.non_trainable,
            "by_op_type": {k: round(v, 2) for k, v in self.by_op_type.items()},
            "shared_weights": {
                "count": self.num_shared_weights,
                "details": {k: v for k, v in self.shared_weights.items() if len(v) > 1},
            },
            "precision_breakdown": self.precision_breakdown,
            "is_quantized": self.is_quantized,
            "quantized_ops": self.quantized_ops,
        }


@dataclass
class FlopCounts:
    """FLOP estimate breakdown."""

    total: int = 0
    by_node: dict[str, int] = field(default_factory=dict)
    by_op_type: dict[str, int] = field(default_factory=dict)

    def to_dict(self) -> dict:
        return {
            "total": self.total,
            "by_op_type": self.by_op_type,
        }


@dataclass
class MemoryBreakdown:
    """Detailed memory breakdown by component type."""

    # Weights by operation type
    weights_by_op_type: dict[str, int] = field(default_factory=dict)  # op -> bytes
    # Top weight tensors
    largest_weights: list[tuple[str, int]] = field(default_factory=list)  # (name, bytes)
    # Activation breakdown
    activations_by_op_type: dict[str, int] = field(default_factory=dict)  # op -> bytes
    largest_activations: list[tuple[str, int]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        return {
            "weights_by_op_type": self.weights_by_op_type,
            "largest_weights": [
                {"name": name, "bytes": size} for name, size in self.largest_weights[:10]
            ],
            "activations_by_op_type": self.activations_by_op_type,
            "largest_activations": [
                {"name": name, "bytes": size} for name, size in self.largest_activations[:10]
            ],
        }


@dataclass
class MemoryEstimates:
    """Memory usage estimates."""

    model_size_bytes: int = 0  # Size of parameters/initializers
    peak_activation_bytes: int = 0  # Estimated peak activation memory (batch=1)
    per_layer_activation_bytes: dict[str, int] = field(default_factory=dict)
    # KV cache estimates for transformer models
    kv_cache_bytes_per_token: int = 0  # KV cache per token (for streaming inference)
    kv_cache_bytes_full_context: int = 0  # Total KV cache at max seq length
    kv_cache_config: dict[str, int] = field(default_factory=dict)  # num_layers, hidden_dim, etc.
    # Detailed breakdown
    breakdown: MemoryBreakdown | None = None

    def to_dict(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            "model_size_bytes": self.model_size_bytes,
            "peak_activation_bytes": self.peak_activation_bytes,
        }
        if self.kv_cache_bytes_per_token > 0:
            result["kv_cache_bytes_per_token"] = self.kv_cache_bytes_per_token
            result["kv_cache_bytes_full_context"] = self.kv_cache_bytes_full_context
            result["kv_cache_config"] = self.kv_cache_config
        if self.breakdown:
            result["breakdown"] = self.breakdown.to_dict()
        return result


class ONNXGraphLoader:
    """
    Load ONNX models and extract graph structure.

    Handles shape inference and creates a GraphInfo representation
    suitable for analysis.
    """

    def __init__(self, logger: logging.Logger | None = None):
        self.logger = logger or logging.getLogger("haoline.loader")

    def load(self, model_path: str | pathlib.Path) -> tuple[onnx.ModelProto, GraphInfo]:
        """
        Load an ONNX model and extract graph information.

        Args:
            model_path: Path to the ONNX model file.

        Returns:
            Tuple of (ModelProto, GraphInfo)
        """
        model_path = pathlib.Path(model_path)
        self.logger.debug(f"Loading model from {model_path}")

        # Use ORT's helper if available, otherwise fall back to onnx
        if _HAS_ORT_UTILS and ModelProtoWithShapeInfo is not None:
            wrapper = ModelProtoWithShapeInfo(model_path)
            model = wrapper.model_with_shape_info
        else:
            # Fallback: load with onnx and run shape inference
            model = onnx.load(str(model_path))
            try:
                model = onnx.shape_inference.infer_shapes(model, strict_mode=True)
            except Exception as e:
                self.logger.warning(f"Shape inference failed: {e}. Proceeding without shape info.")

        graph_info = self._extract_graph_info(model.graph, model)

        self.logger.debug(f"Loaded graph with {graph_info.num_nodes} nodes")
        return model, graph_info

    def _extract_graph_info(self, graph: onnx.GraphProto, model: onnx.ModelProto) -> GraphInfo:
        """Extract GraphInfo from an ONNX GraphProto."""

        # Extract initializers (weights/biases)
        initializers = {}
        for init in graph.initializer:
            try:
                initializers[init.name] = onnx.numpy_helper.to_array(init)
            except Exception as e:
                self.logger.warning(f"Could not convert initializer {init.name}: {e}")
                # Store shape info at minimum
                initializers[init.name] = np.zeros(init.dims, dtype=np.float32)

        # Build value shape map from value_info, inputs, and outputs
        value_shapes = {}

        def _extract_shape(value_info: onnx.ValueInfoProto) -> list[int | str]:
            shape = []
            if value_info.type.HasField("tensor_type"):
                tensor_type = value_info.type.tensor_type
                if tensor_type.HasField("shape"):
                    for dim in tensor_type.shape.dim:
                        if dim.HasField("dim_value"):
                            shape.append(dim.dim_value)
                        elif dim.HasField("dim_param"):
                            shape.append(dim.dim_param)
                        else:
                            shape.append("?")
            return shape

        for vi in graph.input:
            value_shapes[vi.name] = _extract_shape(vi)
        for vi in graph.output:
            value_shapes[vi.name] = _extract_shape(vi)
        for vi in graph.value_info:
            value_shapes[vi.name] = _extract_shape(vi)

        # For initializers without explicit value_info, use their tensor shapes
        for name, arr in initializers.items():
            if name not in value_shapes:
                value_shapes[name] = list(arr.shape)

        # Extract nodes
        nodes: list[NodeInfo] = []
        op_type_counts: dict[str, int] = {}
        node_by_name: dict[str, NodeInfo] = {}
        node_by_output: dict[str, NodeInfo] = {}

        for node in graph.node:
            # Extract attributes
            attributes = {}
            for attr in node.attribute:
                if attr.HasField("i"):
                    attributes[attr.name] = attr.i
                elif attr.HasField("f"):
                    attributes[attr.name] = attr.f
                elif attr.HasField("s"):
                    attributes[attr.name] = (
                        attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
                    )
                elif attr.ints:
                    attributes[attr.name] = list(attr.ints)
                elif attr.floats:
                    attributes[attr.name] = list(attr.floats)
                # Skip subgraphs and other complex types for now

            node_info = NodeInfo(
                name=node.name or f"unnamed_{len(nodes)}",
                op_type=node.op_type,
                domain=node.domain or "ai.onnx",
                inputs=list(node.input),
                outputs=list(node.output),
                attributes=attributes,
            )
            nodes.append(node_info)
            node_by_name[node_info.name] = node_info
            for output in node_info.outputs:
                node_by_output[output] = node_info

            # Count op types
            op_type_counts[node.op_type] = op_type_counts.get(node.op_type, 0) + 1

        # Build input/output shape maps (excluding initializers from inputs)
        input_names = [i.name for i in graph.input if i.name not in initializers]
        output_names = [o.name for o in graph.output]

        input_shapes = {name: value_shapes.get(name, []) for name in input_names}
        output_shapes = {name: value_shapes.get(name, []) for name in output_names}

        return GraphInfo(
            name=graph.name or "main",
            nodes=nodes,
            inputs=input_names,
            outputs=output_names,
            initializers=initializers,
            value_shapes=value_shapes,
            num_nodes=len(nodes),
            input_shapes=input_shapes,
            output_shapes=output_shapes,
            op_type_counts=op_type_counts,
            node_by_name=node_by_name,
            node_by_output=node_by_output,
        )
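For orientation, a minimal usage sketch of the loader defined above. It is illustrative only: "model.onnx" is a placeholder path, and the printed fields are the summaries populated by _extract_graph_info.

```python
from haoline.analyzer import ONNXGraphLoader

loader = ONNXGraphLoader()
# Falls back to onnx shape inference, since _HAS_ORT_UTILS is False in this standalone build.
model, graph_info = loader.load("model.onnx")  # placeholder path

print(graph_info.num_nodes)        # total node count
print(graph_info.op_type_counts)   # e.g. {"Conv": 53, "Relu": 49, ...}
print(graph_info.input_shapes)     # symbolic dims (e.g. "N") are kept as strings
```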

class MetricsEngine:
    """
    Compute model complexity metrics.

    Provides parameter counts, FLOP estimates, and memory estimates
    for ONNX graphs.
    """

    # FLOP multipliers per operation type
    # These are rough estimates; actual FLOPs depend on implementation
    FLOP_FORMULAS: ClassVar[dict[str, str]] = {
        # Conv: 2 * K_h * K_w * C_in * C_out * H_out * W_out
        "Conv": "conv",
        # MatMul: 2 * M * N * K
        "MatMul": "matmul",
        "Gemm": "gemm",
        # Element-wise ops: N elements
        "Add": "elementwise",
        "Sub": "elementwise",
        "Mul": "elementwise",
        "Div": "elementwise",
        "Relu": "elementwise",
        "Sigmoid": "elementwise",
        "Tanh": "elementwise",
        "Sqrt": "elementwise",
        "Exp": "elementwise",
        "Log": "elementwise",
        "Gelu": "elementwise",
        "Silu": "elementwise",
        # Softmax: ~5N (exp, sum, div)
        "Softmax": "softmax",
        # Reduction ops: N elements
        "ReduceMean": "elementwise",
        "ReduceSum": "elementwise",
        "ReduceMax": "elementwise",
        # Normalization layers
        "LayerNormalization": "layernorm",
        "BatchNormalization": "batchnorm",
        # Attention ops (ONNX contrib / custom)
        "Attention": "attention",
        "MultiHeadAttention": "attention",
        "com.microsoft.Attention": "attention",
        "com.microsoft.MultiHeadAttention": "attention",
    }

    # Quantized operation types in ONNX
    QUANTIZED_OPS: ClassVar[set[str]] = {
        "QuantizeLinear",
        "DequantizeLinear",
        "QLinearConv",
        "QLinearMatMul",
        "QLinearAdd",
        "QGemm",
        "ConvInteger",
        "MatMulInteger",
        "DynamicQuantizeLinear",
        "QLinearSigmoid",
        "QLinearLeakyRelu",
        "QLinearAveragePool",
        "QLinearGlobalAveragePool",
        "QLinearConcat",
    }

    # Quantized dtypes
    QUANTIZED_DTYPES: ClassVar[set[type]] = {np.int8, np.uint8, np.int16, np.uint16}

    def __init__(self, logger: logging.Logger | None = None):
        self.logger = logger or logging.getLogger("haoline.metrics")

    def count_parameters(self, graph_info: GraphInfo) -> ParamCounts:
        """
        Count parameters in the model.

        Parameters are counted from initializers. All initializers are
        assumed trainable unless specifically marked otherwise.

        Handles edge cases:
        - Shared weights: Uses fractional attribution so by_op_type sums to total
        - Quantized params: Detects INT8/UINT8 weights and quantized ops

        Args:
            graph_info: Parsed graph information.

        Returns:
            ParamCounts with total and per-node breakdowns.
        """
        counts = ParamCounts()

        # First pass: build usage map (which nodes use each initializer)
        usage_map: dict[str, list[str]] = {name: [] for name in graph_info.initializers}
        for node in graph_info.nodes:
            for inp in node.inputs:
                if inp in graph_info.initializers:
                    usage_map[inp].append(node.name)

        # Track shared weights (used by 2+ nodes)
        counts.shared_weights = {k: v for k, v in usage_map.items() if len(v) > 1}
        counts.num_shared_weights = len(counts.shared_weights)

        # Detect quantized ops in the graph
        quantized_ops_found = set()
        for node in graph_info.nodes:
            if node.op_type in self.QUANTIZED_OPS:
                quantized_ops_found.add(node.op_type)
        counts.quantized_ops = sorted(quantized_ops_found)

        # Second pass: count parameters with fractional attribution
        for name, tensor in graph_info.initializers.items():
            param_count = int(np.prod(tensor.shape)) if tensor.shape else 1
            counts.total += param_count
            counts.by_node[name] = float(param_count)

            # Track precision breakdown
            dtype_name = self._get_dtype_name(tensor)
            counts.precision_breakdown[dtype_name] = (
                counts.precision_breakdown.get(dtype_name, 0) + param_count
            )

            # Check if this is a quantized weight
            if hasattr(tensor, "dtype") and tensor.dtype in self.QUANTIZED_DTYPES:
                counts.is_quantized = True

            # Fractional attribution to nodes sharing this weight
            using_nodes = usage_map[name]
            num_users = len(using_nodes) if using_nodes else 1
            fractional_count = param_count / num_users

            for node in graph_info.nodes:
                if node.name in using_nodes:
                    counts.by_op_type[node.op_type] = (
                        counts.by_op_type.get(node.op_type, 0.0) + fractional_count
                    )
                    node.param_count += fractional_count

        # Mark as quantized if quantized ops are present
        if counts.quantized_ops:
            counts.is_quantized = True

        # For now, assume all are trainable
        # Could be refined with graph analysis (e.g., constants, frozen layers)
        counts.trainable = counts.total
        counts.non_trainable = 0

        return counts
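To make the fractional attribution concrete, here is a small hand-built example; the node names, shapes, and the 1,000-parameter weight are made up for illustration. A single initializer consumed by both a MatMul and a Gemm node contributes 500 parameters to each op type, so by_op_type still sums to total.

```python
import numpy as np

from haoline.analyzer import GraphInfo, MetricsEngine, NodeInfo

shared = np.zeros((10, 100), dtype=np.float32)  # 1,000 parameters, shared by two nodes
nodes = [
    NodeInfo("a", "MatMul", "ai.onnx", ["x", "w"], ["y1"], {}),
    NodeInfo("b", "Gemm", "ai.onnx", ["y1", "w"], ["y2"], {}),
]
gi = GraphInfo(
    name="demo", nodes=nodes, inputs=["x"], outputs=["y2"],
    initializers={"w": shared}, value_shapes={},
)

counts = MetricsEngine().count_parameters(gi)
print(counts.total)                # 1000
print(counts.by_op_type)           # {'MatMul': 500.0, 'Gemm': 500.0}
print(counts.num_shared_weights)   # 1
```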

    def _get_dtype_name(self, tensor: np.ndarray) -> str:
        """Get a human-readable dtype name for a tensor."""
        if not hasattr(tensor, "dtype"):
            return "unknown"
        dtype = tensor.dtype
        dtype_map = {
            np.float32: "fp32",
            np.float64: "fp64",
            np.float16: "fp16",
            np.int8: "int8",
            np.uint8: "uint8",
            np.int16: "int16",
            np.uint16: "uint16",
            np.int32: "int32",
            np.int64: "int64",
        }
        return dtype_map.get(dtype.type, str(dtype))

    def estimate_flops(self, graph_info: GraphInfo) -> FlopCounts:
        """
        Estimate FLOPs for each operation in the graph.

        Uses shape information to compute FLOPs. Falls back to
        rough estimates when shapes are unavailable.

        Args:
            graph_info: Parsed graph information.

        Returns:
            FlopCounts with total and per-node breakdowns.
        """
        counts = FlopCounts()

        for node in graph_info.nodes:
            flops = self._estimate_node_flops(node, graph_info)
            node.flops = flops
            counts.total += flops
            counts.by_node[node.name] = flops
            counts.by_op_type[node.op_type] = counts.by_op_type.get(node.op_type, 0) + flops

        return counts

    def _estimate_node_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
        """Estimate FLOPs for a single node."""
        formula_type = self.FLOP_FORMULAS.get(node.op_type, "unknown")

        if formula_type == "conv":
            return self._estimate_conv_flops(node, graph_info)
        elif formula_type == "matmul":
            return self._estimate_matmul_flops(node, graph_info)
        elif formula_type == "gemm":
            return self._estimate_gemm_flops(node, graph_info)
        elif formula_type == "elementwise":
            return self._estimate_elementwise_flops(node, graph_info)
        elif formula_type == "softmax":
            return self._estimate_elementwise_flops(node, graph_info) * 5
        elif formula_type == "layernorm":
            return self._estimate_elementwise_flops(node, graph_info) * 5
        elif formula_type == "batchnorm":
            return self._estimate_elementwise_flops(node, graph_info) * 2
        elif formula_type == "attention":
            return self._estimate_attention_flops(node, graph_info)
        else:
            # Unknown op - estimate based on output size
            return self._estimate_elementwise_flops(node, graph_info)

    def _estimate_conv_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
        """Estimate FLOPs for Conv operation: 2 * K_h * K_w * C_in * C_out * H_out * W_out"""
        if len(node.inputs) < 2:
            return 0

        # Get weight shape
        weight_name = node.inputs[1]
        if weight_name in graph_info.initializers:
            weight_shape = list(graph_info.initializers[weight_name].shape)
        elif weight_name in graph_info.value_shapes:
            weight_shape = graph_info.value_shapes[weight_name]
        else:
            return 0

        # Weight shape: [C_out, C_in/groups, K_h, K_w] for 2D conv
        if len(weight_shape) < 4 or not all(isinstance(d, int) for d in weight_shape):
            return 0

        c_out, c_in_per_group, k_h, k_w = weight_shape[:4]

        # Get output shape
        if node.outputs and node.outputs[0] in graph_info.value_shapes:
            output_shape = graph_info.value_shapes[node.outputs[0]]
            if len(output_shape) >= 4 and all(isinstance(d, int) for d in output_shape[-2:]):
                h_out, w_out = output_shape[-2], output_shape[-1]
            else:
                h_out, w_out = 1, 1
        else:
            h_out, w_out = 1, 1

        node.attributes.get("group", 1)
        flops = 2 * k_h * k_w * c_in_per_group * c_out * h_out * w_out

        # Add bias if present
        if len(node.inputs) > 2:
            flops += c_out * h_out * w_out

        return int(flops)
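A worked instance of the Conv formula above, assuming a hypothetical 3x3 convolution with 64 input channels, 128 output channels, a bias, and a 56x56 output feature map (weight shape [128, 64, 3, 3]):

```python
# Hypothetical layer dimensions, chosen only to illustrate the arithmetic.
k_h, k_w, c_in_per_group, c_out = 3, 3, 64, 128
h_out, w_out = 56, 56

conv_flops = 2 * k_h * k_w * c_in_per_group * c_out * h_out * w_out  # 462,422,016
bias_flops = c_out * h_out * w_out                                   # 401,408

print(conv_flops + bias_flops)  # 462823424 (~0.46 GFLOPs)
```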

    def _estimate_matmul_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
        """Estimate FLOPs for MatMul: 2 * M * N * K"""
        if len(node.inputs) < 2:
            return 0

        # Get shapes of both inputs
        shapes = []
        for inp in node.inputs[:2]:
            if inp in graph_info.initializers:
                shapes.append(list(graph_info.initializers[inp].shape))
            elif inp in graph_info.value_shapes:
                shapes.append(graph_info.value_shapes[inp])
            else:
                return 0

        if len(shapes) < 2:
            return 0

        # MatMul: A[..., M, K] @ B[..., K, N] = C[..., M, N]
        shape_a, shape_b = shapes[0], shapes[1]

        # Handle broadcasting and get M, K, N
        if len(shape_a) < 2 or len(shape_b) < 2:
            return 0

        if not all(isinstance(d, int) for d in shape_a[-2:]) or not all(
            isinstance(d, int) for d in shape_b[-2:]
        ):
            return 0

        m, k = shape_a[-2], shape_a[-1]
        k2, n = shape_b[-2], shape_b[-1]

        if k != k2:
            self.logger.warning(f"MatMul shape mismatch in node {node.name}: K={k} vs K={k2}")
            return 0

        # Handle batch dimensions
        batch = 1
        for dim in shape_a[:-2]:
            if isinstance(dim, int):
                batch *= dim

        return int(2 * batch * m * n * k)

    def _estimate_gemm_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
        """Estimate FLOPs for Gemm: 2 * M * N * K + M * N (bias)"""
        flops = self._estimate_matmul_flops(node, graph_info)

        # Add bias computation if present
        if len(node.inputs) > 2 and node.outputs and node.outputs[0] in graph_info.value_shapes:
            output_shape = graph_info.value_shapes[node.outputs[0]]
            if output_shape and all(isinstance(d, int) for d in output_shape):
                int_shape: list[int] = [d for d in output_shape if isinstance(d, int)]
                bias_flops = int(np.prod(int_shape))
                flops += bias_flops

        return flops

    def _estimate_elementwise_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
        """Estimate FLOPs for element-wise operations: N elements"""
        # Use output shape to determine element count
        if node.outputs and node.outputs[0] in graph_info.value_shapes:
            shape = graph_info.value_shapes[node.outputs[0]]
            if shape and all(isinstance(d, int) for d in shape):
                int_shape: list[int] = [d for d in shape if isinstance(d, int)]
                return int(np.prod(int_shape))

        # Fallback: use first input shape
        if node.inputs and node.inputs[0] in graph_info.value_shapes:
            shape = graph_info.value_shapes[node.inputs[0]]
            if shape and all(isinstance(d, int) for d in shape):
                int_shape2: list[int] = [d for d in shape if isinstance(d, int)]
                return int(np.prod(int_shape2))

        return 0

    def _estimate_attention_flops(self, node: NodeInfo, graph_info: GraphInfo) -> int:
        """
        Estimate FLOPs for attention operations.

        Standard multi-head attention FLOPs:
        - QKV projections: 3 * batch * seq_len * d_model * d_model
        - Attention scores (Q @ K^T): batch * num_heads * seq_len * seq_len * d_head
        - Softmax: batch * num_heads * seq_len * seq_len * 5
        - Attention output (scores @ V): batch * num_heads * seq_len * seq_len * d_head
        - Output projection: batch * seq_len * d_model * d_model

        Simplified formula: 4 * seq_len * d_model^2 + 2 * num_heads * seq_len^2 * d_head
        """
        # Try to get dimensions from node attributes or input shapes
        num_heads = 1
        seq_len = 512  # Default assumption
        d_model = 768  # Default assumption

        # Try to extract from node attributes (ONNX Attention op)
        for attr_name, attr_value in node.attributes.items():
            if attr_name == "num_heads" and isinstance(attr_value, int):
                num_heads = attr_value
            elif attr_name == "hidden_size" and isinstance(attr_value, int):
                d_model = attr_value

        # Try to infer from input shapes
        if node.inputs and node.inputs[0] in graph_info.value_shapes:
            input_shape = graph_info.value_shapes[node.inputs[0]]
            if input_shape and len(input_shape) >= 2:
                # Shape is typically [batch, seq_len, d_model] or [batch, seq_len, ...]
                if len(input_shape) >= 3:
                    if isinstance(input_shape[1], int):
                        seq_len = input_shape[1]
                    if isinstance(input_shape[2], int):
                        d_model = input_shape[2]
                elif len(input_shape) == 2:
                    if isinstance(input_shape[0], int):
                        seq_len = input_shape[0]
                    if isinstance(input_shape[1], int):
                        d_model = input_shape[1]

        d_head = d_model // num_heads if num_heads > 0 else d_model

        # Compute FLOPs using standard attention formula
        # QKV projections: 3 * seq * d_model * d_model
        qkv_flops = 3 * seq_len * d_model * d_model

        # Attention scores and output: 2 * num_heads * seq^2 * d_head
        attention_flops = 2 * num_heads * seq_len * seq_len * d_head

        # Output projection: seq * d_model * d_model
        output_flops = seq_len * d_model * d_model

        # Softmax on attention scores: 5 * num_heads * seq^2
        softmax_flops = 5 * num_heads * seq_len * seq_len

        total_flops = qkv_flops + attention_flops + output_flops + softmax_flops

        self.logger.debug(
            f"Attention FLOPs: seq={seq_len}, d_model={d_model}, "
            f"heads={num_heads}, total={total_flops:,}"
        )

        return total_flops
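A worked instance of the attention estimate, assuming the node reports num_heads=12 and the input shape yields seq_len=512 and d_model=768 (so d_head=64); the dimensions are illustrative, matching the method's default assumptions plus a typical head count:

```python
seq_len, d_model, num_heads = 512, 768, 12
d_head = d_model // num_heads  # 64

qkv_flops = 3 * seq_len * d_model * d_model                    # 905,969,664
attention_flops = 2 * num_heads * seq_len * seq_len * d_head   # 402,653,184
output_flops = seq_len * d_model * d_model                     # 301,989,888
softmax_flops = 5 * num_heads * seq_len * seq_len              # 15,728,640

print(qkv_flops + attention_flops + output_flops + softmax_flops)  # 1626341376 (~1.6 GFLOPs)
```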

    def estimate_memory(self, graph_info: GraphInfo) -> MemoryEstimates:
        """
        Estimate memory usage for the model.

        Computes model size (parameters), peak activation memory,
        KV cache size for transformer models, and detailed breakdown.

        Args:
            graph_info: Parsed graph information.

        Returns:
            MemoryEstimates with size, activation memory, KV cache, and breakdown.
        """
        estimates = MemoryEstimates()
        breakdown = MemoryBreakdown()

        # Build mapping: initializer name -> op type that uses it
        init_to_op: dict[str, str] = {}
        for node in graph_info.nodes:
            for inp in node.inputs:
                if inp in graph_info.initializers and inp not in init_to_op:
                    init_to_op[inp] = node.op_type

        # Model size: sum of initializer sizes with breakdown by op type
        weight_sizes: list[tuple[str, int]] = []
        for name, tensor in graph_info.initializers.items():
            # Determine bytes per element based on dtype
            bytes_per_elem = 4
            if hasattr(tensor, "dtype"):
                if tensor.dtype == np.float16:
                    bytes_per_elem = 2
                elif tensor.dtype == np.float64:
                    bytes_per_elem = 8
                elif tensor.dtype in (np.int8, np.uint8):
                    bytes_per_elem = 1
                elif tensor.dtype in (np.int16, np.uint16):
                    bytes_per_elem = 2

            tensor_bytes = (
                int(np.prod(tensor.shape)) * bytes_per_elem if tensor.shape else bytes_per_elem
            )
            estimates.model_size_bytes += tensor_bytes
            weight_sizes.append((name, tensor_bytes))

            # Categorize by op type
            op_type = init_to_op.get(name, "Other")
            breakdown.weights_by_op_type[op_type] = (
                breakdown.weights_by_op_type.get(op_type, 0) + tensor_bytes
            )

        # Store top 10 largest weights
        breakdown.largest_weights = sorted(weight_sizes, key=lambda x: -x[1])[:10]

        # Peak activation memory: rough estimate based on intermediate tensor sizes
        # Build mapping: activation name -> op type that produces it
        activation_to_op: dict[str, str] = {}
        for node in graph_info.nodes:
            for out in node.outputs:
                activation_to_op[out] = node.op_type

        activation_sizes: list[tuple[str, int]] = []
        for name, shape in graph_info.value_shapes.items():
            # Skip initializers (they're counted in model size)
            if name in graph_info.initializers:
                continue

            if shape:
                # Handle symbolic dimensions (e.g., 'N' for batch) by treating as 1
                int_shape: list[int] = [d if isinstance(d, int) else 1 for d in shape]
                # Skip if all dims are symbolic (no meaningful size)
                if all(d == 1 for d in int_shape) and len(int_shape) > 1:
                    continue
                # Assume float32 for activations
                tensor_bytes = int(np.prod(int_shape)) * 4
                activation_sizes.append((name, tensor_bytes))
                estimates.per_layer_activation_bytes[name] = tensor_bytes

                # Categorize by producing op type
                op_type = activation_to_op.get(name, "Input")
                breakdown.activations_by_op_type[op_type] = (
                    breakdown.activations_by_op_type.get(op_type, 0) + tensor_bytes
                )

        # Peak is approximate: sum of largest activations that might coexist
        sorted_activations = sorted(activation_sizes, key=lambda x: -x[1])
        # Rough heuristic: top 3 largest activations might coexist
        top_n = min(3, len(sorted_activations))
        estimates.peak_activation_bytes = sum(size for _, size in sorted_activations[:top_n])

        # Store top 10 largest activations
        breakdown.largest_activations = sorted_activations[:10]

        # Store breakdown
        estimates.breakdown = breakdown

        # Estimate KV cache for transformer models
        kv_config = self._estimate_kv_cache_config(graph_info)
        if kv_config:
            estimates.kv_cache_config = kv_config
            estimates.kv_cache_bytes_per_token = self._compute_kv_cache_per_token(kv_config)
            estimates.kv_cache_bytes_full_context = (
                estimates.kv_cache_bytes_per_token * kv_config["seq_len"]
            )

        return estimates
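A minimal sketch of consuming the estimate above; the model path is a placeholder, and the fields read here are those defined on the MemoryEstimates and MemoryBreakdown dataclasses earlier in this file.

```python
from haoline.analyzer import MetricsEngine, ONNXGraphLoader

_, graph_info = ONNXGraphLoader().load("model.onnx")  # placeholder path
mem = MetricsEngine().estimate_memory(graph_info)

print(mem.model_size_bytes / 1e6, "MB of weights")
print(mem.peak_activation_bytes / 1e6, "MB peak activations (top-3 heuristic, batch=1)")
if mem.breakdown:
    print(mem.breakdown.largest_weights[:3])  # [(initializer name, bytes), ...]
```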

    def _estimate_kv_cache_config(self, graph_info: GraphInfo) -> dict[str, int]:
        """
        Detect transformer architecture and extract KV cache config.

        Returns dict with num_layers, hidden_dim, num_heads, seq_len, bytes_per_elem
        or empty dict if not a transformer.
        """
        # Check for attention ops
        attention_ops = {"Attention", "MultiHeadAttention", "Softmax"}
        attention_count = sum(graph_info.op_type_counts.get(op, 0) for op in attention_ops)

        if attention_count == 0:
            return {}

        # Try to detect transformer parameters
        num_layers = 0
        hidden_dim = 768  # Default
        num_heads = 12  # Default
        seq_len = 512  # Default
        bytes_per_elem = 4  # FP32 default

        # Count attention ops to estimate number of layers
        # Each transformer layer typically has one attention block
        mha_count = graph_info.op_type_counts.get("Attention", 0) + graph_info.op_type_counts.get(
            "MultiHeadAttention", 0
        )
        softmax_count = graph_info.op_type_counts.get("Softmax", 0)

        # Use MHA count if available, otherwise estimate from Softmax
        if mha_count > 0:
            num_layers = mha_count
        elif softmax_count > 0:
            # Softmax in attention: typically one per layer (or two with cross-attention)
            num_layers = max(1, softmax_count // 2)

        if num_layers == 0:
            return {}

        # Try to infer hidden_dim from weight shapes
        for node in graph_info.nodes:
            if node.op_type in ("MatMul", "Gemm"):
                for inp in node.inputs:
                    if inp in graph_info.initializers:
                        weight = graph_info.initializers[inp]
                        if len(weight.shape) == 2:
                            # Dense layer weights: [in_features, out_features] or vice versa
                            dim = max(weight.shape)
                            if 256 <= dim <= 16384 and dim % 64 == 0:
                                hidden_dim = dim
                                break
                break

        # Try to infer sequence length from input shapes
        for shape in graph_info.input_shapes.values():
            if len(shape) >= 2:
                # Look for typical transformer input shape [batch, seq_len, ...] or [batch, seq_len]
                for dim in shape[1:3]:
                    if isinstance(dim, int) and 16 <= dim <= 32768:
                        seq_len = dim
                        break

        # Estimate num_heads from hidden_dim (typical: 64-128 per head)
        if hidden_dim >= 256:
            num_heads = max(1, hidden_dim // 64)

        self.logger.debug(
            f"KV cache config: layers={num_layers}, hidden={hidden_dim}, "
            f"heads={num_heads}, seq={seq_len}"
        )

        return {
            "num_layers": num_layers,
            "hidden_dim": hidden_dim,
            "num_heads": num_heads,
            "seq_len": seq_len,
            "bytes_per_elem": bytes_per_elem,
        }

    def _compute_kv_cache_per_token(self, config: dict[str, int]) -> int:
        """
        Compute KV cache memory per token.

        Formula: 2 * num_layers * hidden_dim * bytes_per_elem
        (2 for K and V, each of size [hidden_dim])

        For multi-head attention with head_dim = hidden_dim / num_heads:
        KV cache per token per layer = 2 * hidden_dim * bytes_per_elem

        Total per token = 2 * num_layers * hidden_dim * bytes_per_elem
        """
        num_layers = config.get("num_layers", 0)
        hidden_dim = config.get("hidden_dim", 0)
        bytes_per_elem = config.get("bytes_per_elem", 4)

        # KV cache: each layer stores K and V for each token
        # K and V each have shape [batch, num_heads, seq_len, head_dim]
        # Per token: 2 * hidden_dim * bytes_per_elem per layer
        return 2 * num_layers * hidden_dim * bytes_per_elem
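To make the KV-cache formulas concrete, a worked example using an assumed 12-layer configuration with the module's default hidden_dim=768, seq_len=512, and FP32 elements. The layer count is illustrative; the module derives it from the attention/Softmax op counts.

```python
# Assumed config for illustration, in the shape returned by _estimate_kv_cache_config.
config = {"num_layers": 12, "hidden_dim": 768, "num_heads": 12, "seq_len": 512, "bytes_per_elem": 4}

# Per token: 2 (K and V) * num_layers * hidden_dim * bytes_per_elem
per_token = 2 * config["num_layers"] * config["hidden_dim"] * config["bytes_per_elem"]
full_context = per_token * config["seq_len"]

print(per_token)      # 73728 bytes (~72 KiB per token)
print(full_context)   # 37748736 bytes (~36 MiB at seq_len=512)
```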