haoline 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/patterns.py
ADDED
@@ -0,0 +1,1116 @@
# Copyright (c) 2025 HaoLine Contributors
# SPDX-License-Identifier: MIT

"""
Pattern detection for HaoLine.

Detects common architectural patterns in ONNX graphs:
- Conv-BatchNorm-ReLU blocks
- Residual/skip connections
- Transformer blocks (attention + MLP)
- Embedding layers
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
    from .analyzer import GraphInfo, NodeInfo


@dataclass
class Block:
    """
    A detected architectural block (group of related nodes).

    Blocks represent higher-level patterns like "ResidualBlock" or
    "TransformerLayer" that consist of multiple ONNX nodes.
    """

    block_type: str  # e.g., "ConvBNRelu", "ResidualBlock", "TransformerBlock"
    name: str
    nodes: list[str]  # Node names in this block
    start_node: str
    end_node: str
    attributes: dict = field(default_factory=dict)  # Block-specific metadata

    def to_dict(self) -> dict:
        return {
            "block_type": self.block_type,
            "name": self.name,
            "nodes": self.nodes,
            "start_node": self.start_node,
            "end_node": self.end_node,
            "attributes": self.attributes,
        }


class PatternAnalyzer:
    """
    Detect architectural patterns in ONNX graphs.

    Identifies common patterns like Conv-BN-ReLU sequences, residual
    blocks, and transformer attention blocks.
    """

    # Operators that commonly appear together
    CONV_ACTIVATIONS: ClassVar[set[str]] = {
        "Relu",
        "LeakyRelu",
        "Sigmoid",
        "Tanh",
        "Clip",
        "HardSwish",
        "Silu",
    }
    NORM_OPS: ClassVar[set[str]] = {
        "BatchNormalization",
        "InstanceNormalization",
        "LayerNormalization",
        "GroupNormalization",
    }
    ATTENTION_OPS: ClassVar[set[str]] = {"MatMul", "Softmax", "Transpose"}
    EMBEDDING_OPS: ClassVar[set[str]] = {"Gather", "Embedding"}

    # LLM-specific activation functions
    LLM_ACTIVATIONS: ClassVar[set[str]] = {
        "Gelu",
        "FastGelu",
        "QuickGelu",  # GPT-style
        "Silu",
        "Swish",  # LLaMA/Mistral style
        "Relu",
        "LeakyRelu",  # Classic
        "NewGelu",  # Some implementations
    }

    # MoE-related ops
    MOE_OPS: ClassVar[set[str]] = {"TopK", "Scatter", "ScatterND", "GatherND"}

    def __init__(self, logger: logging.Logger | None = None):
        self.logger = logger or logging.getLogger("haoline.patterns")

    def group_into_blocks(self, graph_info: GraphInfo) -> list[Block]:
        """
        Detect all architectural blocks in the graph.

        Args:
            graph_info: Parsed graph information.

        Returns:
            List of detected Block instances.
        """
        blocks: list[Block] = []

        # Detect various patterns - ordered from specific to general
        blocks.extend(self.detect_conv_bn_relu(graph_info))
        blocks.extend(self.detect_residual_blocks(graph_info))
        blocks.extend(self.detect_nonstandard_residual_blocks(graph_info))

        # LLM-specific patterns (Task 5.4.1-5.4.7)
        blocks.extend(self.detect_attention_heads(graph_info))
        blocks.extend(self.detect_mlp_blocks(graph_info))
        blocks.extend(self.detect_embedding_layers(graph_info))
        blocks.extend(self.detect_position_encoding(graph_info))
        blocks.extend(self.detect_moe_routing(graph_info))

        # High-level transformer detection (uses sub-blocks)
        blocks.extend(self.detect_transformer_blocks(graph_info))

        # Detect repeated blocks
        repeated = self.detect_repeated_blocks(graph_info, blocks)
        if repeated:
            blocks.extend(repeated)

        self.logger.debug(f"Detected {len(blocks)} blocks")
        return blocks

    def detect_conv_bn_relu(self, graph_info: GraphInfo) -> list[Block]:
        """
        Find Conv-BatchNorm-ReLU sequences.

        Matches patterns like:
        - Conv -> BatchNorm -> ReLU
        - Conv -> ReLU
        - Conv -> BatchNorm
        """
        blocks: list[Block] = []
        visited: set[str] = set()

        for node in graph_info.nodes:
            if node.name in visited:
                continue

            if node.op_type == "Conv":
                block_nodes = [node.name]
                current_output = node.outputs[0] if node.outputs else None
                block_type_parts = ["Conv"]

                # Look for BatchNorm
                if current_output:
                    next_node = self._find_consumer(current_output, graph_info)
                    if next_node and next_node.op_type in self.NORM_OPS:
                        block_nodes.append(next_node.name)
                        block_type_parts.append("BN")
                        current_output = next_node.outputs[0] if next_node.outputs else None

                        # Look for activation after BN
                        if current_output:
                            act_node = self._find_consumer(current_output, graph_info)
                            if act_node and act_node.op_type in self.CONV_ACTIVATIONS:
                                block_nodes.append(act_node.name)
                                block_type_parts.append(act_node.op_type)
                    elif next_node and next_node.op_type in self.CONV_ACTIVATIONS:
                        # Conv -> ReLU without BN
                        block_nodes.append(next_node.name)
                        block_type_parts.append(next_node.op_type)

                if len(block_nodes) > 1:
                    visited.update(block_nodes)
                    block = Block(
                        block_type="".join(block_type_parts),
                        name=f"conv_block_{len(blocks)}",
                        nodes=block_nodes,
                        start_node=block_nodes[0],
                        end_node=block_nodes[-1],
                    )
                    blocks.append(block)

        return blocks

    def detect_residual_blocks(self, graph_info: GraphInfo) -> list[Block]:
        """
        Find residual/skip connection patterns.

        Looks for Add nodes where one input comes from earlier in the graph
        (skip connection).
        """
        blocks: list[Block] = []

        for node in graph_info.nodes:
            if node.op_type == "Add" and len(node.inputs) >= 2:
                # Check if this could be a residual connection
                # by looking for inputs that come from different depths
                input_nodes = []
                for inp in node.inputs:
                    if inp in graph_info.node_by_output:
                        input_nodes.append(graph_info.node_by_output[inp])

                if len(input_nodes) >= 2:
                    # Heuristic: if one path is longer (more hops), it's likely the residual path
                    # For now, just detect the pattern exists
                    blocks.append(
                        Block(
                            block_type="ResidualAdd",
                            name=f"residual_{len(blocks)}",
                            nodes=[node.name],
                            start_node=node.name,
                            end_node=node.name,
                            attributes={"inputs": node.inputs, "variant": "standard"},
                        )
                    )

        return blocks

    def detect_nonstandard_residual_blocks(self, graph_info: GraphInfo) -> list[Block]:
        """
        Find non-standard residual/skip connection patterns.

        Detects alternative skip connection implementations:
        - Concat-based skip connections (DenseNet-style)
        - Mul-based gating mechanisms (Highway networks, attention gates)
        - Sub-based connections (rare but possible)

        These may indicate custom architectures that need special handling.
        """
        blocks: list[Block] = []

        # Concat-based skip connections (DenseNet-style)
        for node in graph_info.nodes:
            if node.op_type == "Concat" and len(node.inputs) >= 2:
                # Check if inputs come from different depths (skip connection indicator)
                input_depths = self._estimate_input_depths(node.inputs, graph_info)
                if input_depths and max(input_depths) - min(input_depths) >= 2:
                    blocks.append(
                        Block(
                            block_type="ResidualConcat",
                            name=f"dense_skip_{len(blocks)}",
                            nodes=[node.name],
                            start_node=node.name,
                            end_node=node.name,
                            attributes={
                                "inputs": node.inputs,
                                "variant": "concat",
                                "depth_diff": max(input_depths) - min(input_depths),
                            },
                        )
                    )

        # Mul-based gating (Highway networks, attention gates)
        for node in graph_info.nodes:
            if node.op_type == "Mul" and len(node.inputs) >= 2:
                # Look for Sigmoid before Mul (gating pattern)
                has_sigmoid_input = False
                for inp in node.inputs:
                    if inp in graph_info.node_by_output:
                        prev_node = graph_info.node_by_output[inp]
                        if prev_node.op_type == "Sigmoid":
                            has_sigmoid_input = True
                            break

                if has_sigmoid_input:
                    blocks.append(
                        Block(
                            block_type="ResidualGate",
                            name=f"gate_{len(blocks)}",
                            nodes=[node.name],
                            start_node=node.name,
                            end_node=node.name,
                            attributes={"inputs": node.inputs, "variant": "gated"},
                        )
                    )

        # Sub-based connections (rare, but could be learned residual)
        for node in graph_info.nodes:
            if node.op_type == "Sub" and len(node.inputs) >= 2:
                input_nodes = []
                for inp in node.inputs:
                    if inp in graph_info.node_by_output:
                        input_nodes.append(graph_info.node_by_output[inp])

                if len(input_nodes) >= 2:
                    blocks.append(
                        Block(
                            block_type="ResidualSub",
                            name=f"sub_residual_{len(blocks)}",
                            nodes=[node.name],
                            start_node=node.name,
                            end_node=node.name,
                            attributes={"inputs": node.inputs, "variant": "subtract"},
                        )
                    )

        return blocks

    def _estimate_input_depths(
        self, inputs: list[str], graph_info: GraphInfo, max_depth: int = 20
    ) -> list[int]:
        """
        Estimate the graph depth of each input tensor.

        Returns a list of estimated depths (hops from graph inputs).
        Used to detect skip connections where inputs come from very different depths.
        """
        depths = []
        for inp in inputs:
            depth = self._trace_depth(inp, graph_info, 0, max_depth)
            depths.append(depth)
        return depths

    def _trace_depth(
        self,
        tensor_name: str,
        graph_info: GraphInfo,
        current_depth: int,
        max_depth: int,
    ) -> int:
        """Recursively trace back to find the depth of a tensor."""
        if current_depth >= max_depth:
            return current_depth

        # If it's a graph input or initializer, depth is 0
        if tensor_name in graph_info.input_shapes:
            return 0
        if tensor_name in graph_info.initializers:
            return 0

        # Find the node that produces this tensor
        if tensor_name in graph_info.node_by_output:
            producer = graph_info.node_by_output[tensor_name]
            if producer.inputs:
                # Trace back through the first input
                return 1 + self._trace_depth(
                    producer.inputs[0], graph_info, current_depth + 1, max_depth
                )

        return current_depth

    def detect_transformer_blocks(self, graph_info: GraphInfo) -> list[Block]:
        """
        Find transformer attention patterns.

        Looks for the characteristic Softmax in attention computation
        and MatMul patterns for Q, K, V projections.
        """
        blocks: list[Block] = []
        softmax_nodes = [n for n in graph_info.nodes if n.op_type == "Softmax"]

        for softmax in softmax_nodes:
            # Look for attention pattern: MatMul -> Softmax -> MatMul
            before_nodes: list[str] = []
            after_nodes: list[str] = []

            # Find MatMul before softmax
            if softmax.inputs:
                inp = softmax.inputs[0]
                if inp in graph_info.node_by_output:
                    prev = graph_info.node_by_output[inp]
                    if prev.op_type in (
                        "MatMul",
                        "Gemm",
                        "Div",
                        "Mul",
                    ):  # Div for scaling
                        before_nodes.append(prev.name)

            # Find MatMul after softmax
            if softmax.outputs:
                consumer = self._find_consumer(softmax.outputs[0], graph_info)
                if consumer and consumer.op_type in ("MatMul", "Gemm"):
                    after_nodes.append(consumer.name)

            if before_nodes and after_nodes:
                all_nodes = [*before_nodes, softmax.name, *after_nodes]
                blocks.append(
                    Block(
                        block_type="Attention",
                        name=f"attention_{len(blocks)}",
                        nodes=all_nodes,
                        start_node=before_nodes[0],
                        end_node=after_nodes[-1],
                    )
                )

        # Also look for LayerNorm which often brackets transformer blocks
        layernorm_count = sum(1 for n in graph_info.nodes if n.op_type == "LayerNormalization")
        if layernorm_count >= 2 and blocks:
            # Likely a transformer architecture
            self.logger.debug(
                f"Found {len(blocks)} attention blocks with {layernorm_count} LayerNorms"
            )

        return blocks

    def detect_embedding_layers(self, graph_info: GraphInfo) -> list[Block]:
        """
        Find embedding lookup patterns.

        Looks for Gather operations on large weight tensors.
        """
        blocks: list[Block] = []

        for node in graph_info.nodes:
            if node.op_type == "Gather":
                # Check if first input is a large initializer (embedding table)
                if node.inputs and node.inputs[0] in graph_info.initializers:
                    embed_table = graph_info.initializers[node.inputs[0]]
                    if len(embed_table.shape) == 2:
                        vocab_size, embed_dim = embed_table.shape
                        # Token embedding typically has large vocab (>1000)
                        embed_type = "token" if vocab_size > 1000 else "position"
                        blocks.append(
                            Block(
                                block_type="Embedding",
                                name=f"embedding_{len(blocks)}",
                                nodes=[node.name],
                                start_node=node.name,
                                end_node=node.name,
                                attributes={
                                    "vocab_size": int(vocab_size),
                                    "embed_dim": int(embed_dim),
                                    "embed_type": embed_type,
                                },
                            )
                        )

        return blocks

    def detect_attention_heads(self, graph_info: GraphInfo) -> list[Block]:
        """
        Detect attention head patterns with Q/K/V projections.

        Task 5.4.1: Enhanced attention detection for LLMs.

        Patterns detected:
        - Standard MHA: Q, K, V linear -> attention -> output linear
        - MQA: Single K, V shared across Q heads
        - GQA: Grouped K, V (fewer KV heads than Q heads)
        """
        blocks: list[Block] = []
        visited_softmax: set[str] = set()

        # Find all Softmax nodes (attention core)
        softmax_nodes = [n for n in graph_info.nodes if n.op_type == "Softmax"]

        for softmax in softmax_nodes:
            if softmax.name in visited_softmax:
                continue

            attention_info = self._trace_attention_pattern(softmax, graph_info)
            if attention_info:
                visited_softmax.add(softmax.name)

                # Determine attention type
                num_q_heads = attention_info.get("num_q_heads", 0)
                num_kv_heads = attention_info.get("num_kv_heads", 0)

                if num_kv_heads == 1:
                    attention_type = "MQA"  # Multi-Query Attention
                elif num_kv_heads > 0 and num_kv_heads < num_q_heads:
                    attention_type = "GQA"  # Grouped-Query Attention
                else:
                    attention_type = "MHA"  # Standard Multi-Head Attention

                blocks.append(
                    Block(
                        block_type="AttentionHead",
                        name=f"attention_head_{len(blocks)}",
                        nodes=attention_info.get("nodes", [softmax.name]),
                        start_node=attention_info.get("q_proj", softmax.name),
                        end_node=attention_info.get("o_proj", softmax.name),
                        attributes={
                            "attention_type": attention_type,
                            "num_q_heads": num_q_heads,
                            "num_kv_heads": num_kv_heads,
                            "has_scaling": attention_info.get("has_scaling", False),
                            "has_mask": attention_info.get("has_mask", False),
                            "q_proj": attention_info.get("q_proj"),
                            "k_proj": attention_info.get("k_proj"),
                            "v_proj": attention_info.get("v_proj"),
                            "o_proj": attention_info.get("o_proj"),
                        },
                    )
                )

        return blocks

    def _trace_attention_pattern(
        self, softmax: NodeInfo, graph_info: GraphInfo
    ) -> dict[str, Any] | None:
        """Trace back from Softmax to find Q/K/V projections."""
        result: dict[str, Any] = {
            "nodes": [softmax.name],
            "has_scaling": False,
            "has_mask": False,
            "num_q_heads": 0,
            "num_kv_heads": 0,
        }

        # Trace backward from softmax
        # Pattern: (Q @ K^T) / sqrt(d_k) -> Softmax -> @ V

        if not softmax.inputs:
            return None

        # Look for scaling (Div or Mul) and mask (Add) before softmax
        current = softmax.inputs[0]
        if current in graph_info.node_by_output:
            prev = graph_info.node_by_output[current]

            # Check for mask addition
            if prev.op_type == "Add":
                result["has_mask"] = True
                result["nodes"].append(prev.name)
                if prev.inputs:
                    current = prev.inputs[0]
                    if current in graph_info.node_by_output:
                        prev = graph_info.node_by_output[current]

            # Check for scaling
            if prev.op_type in ("Div", "Mul"):
                result["has_scaling"] = True
                result["nodes"].append(prev.name)
                if prev.inputs:
                    current = prev.inputs[0]
                    if current in graph_info.node_by_output:
                        prev = graph_info.node_by_output[current]

            # Look for Q @ K^T (MatMul)
            if prev.op_type == "MatMul":
                result["nodes"].append(prev.name)

                # Trace Q and K projections
                if len(prev.inputs) >= 2:
                    q_input, k_input = prev.inputs[0], prev.inputs[1]

                    # Find Q projection
                    q_proj = self._find_linear_proj(q_input, graph_info)
                    if q_proj:
                        result["q_proj"] = q_proj["name"]
                        result["nodes"].append(q_proj["name"])
                        result["num_q_heads"] = q_proj.get("num_heads", 0)

                    # Find K projection (may go through Transpose)
                    k_proj = self._find_linear_proj(k_input, graph_info, through_transpose=True)
                    if k_proj:
                        result["k_proj"] = k_proj["name"]
                        result["nodes"].append(k_proj["name"])
                        result["num_kv_heads"] = k_proj.get("num_heads", 0)

        # Trace forward from softmax to find V @ attention and output projection
        if softmax.outputs:
            consumer = self._find_consumer(softmax.outputs[0], graph_info)
            if consumer and consumer.op_type == "MatMul":
                result["nodes"].append(consumer.name)

                # Find V projection
                if len(consumer.inputs) >= 2:
                    v_input = consumer.inputs[1]  # Second input is V
                    v_proj = self._find_linear_proj(v_input, graph_info)
                    if v_proj:
                        result["v_proj"] = v_proj["name"]
                        result["nodes"].append(v_proj["name"])

                # Find output projection
                if consumer.outputs:
                    o_consumer = self._find_consumer(consumer.outputs[0], graph_info)
                    if o_consumer and o_consumer.op_type in ("MatMul", "Gemm"):
                        result["o_proj"] = o_consumer.name
                        result["nodes"].append(o_consumer.name)

        # Only return if we found at least the core attention pattern
        if len(result["nodes"]) >= 3:  # softmax + at least 2 matmuls
            return result
        return None

    def _find_linear_proj(
        self, tensor_name: str, graph_info: GraphInfo, through_transpose: bool = False
    ) -> dict | None:
        """Find a linear projection (MatMul/Gemm) producing this tensor."""
        current = tensor_name

        # Optionally look through Transpose (for K^T in attention)
        if through_transpose and current in graph_info.node_by_output:
            node = graph_info.node_by_output[current]
            if node.op_type == "Transpose":
                if node.inputs:
                    current = node.inputs[0]

        # Also look through Reshape (for multi-head splitting)
        if current in graph_info.node_by_output:
            node = graph_info.node_by_output[current]
            if node.op_type == "Reshape":
                if node.inputs:
                    current = node.inputs[0]

        # Find the MatMul/Gemm
        if current in graph_info.node_by_output:
            node = graph_info.node_by_output[current]
            if node.op_type in ("MatMul", "Gemm"):
                # Try to infer number of heads from weight shape
                num_heads = 0
                if len(node.inputs) >= 2:
                    weight_name = node.inputs[1]
                    if weight_name in graph_info.initializers:
                        weight = graph_info.initializers[weight_name]
                        if len(weight.shape) == 2:
                            # Typical shape: [hidden, num_heads * head_dim]
                            out_features = weight.shape[1]
                            # Common head_dim values: 64, 128
                            for head_dim in [64, 128, 96, 80]:
                                if out_features % head_dim == 0:
                                    num_heads = out_features // head_dim
                                    break

                return {"name": node.name, "num_heads": num_heads}

        return None

    def detect_mlp_blocks(self, graph_info: GraphInfo) -> list[Block]:
        """
        Detect MLP/FFN blocks in transformers.

        Task 5.4.2: Detect MLP/FFN patterns.

        Patterns detected:
        - Standard FFN: Linear -> Activation -> Linear
        - SwiGLU/GeGLU: Linear -> (Gate * Activation(Linear)) -> Linear
        - Gated MLP: Uses element-wise gating
        """
        blocks: list[Block] = []
        visited: set[str] = set()

        # Look for activation functions that typically appear in FFN
        for node in graph_info.nodes:
            if node.name in visited:
                continue

            if node.op_type in self.LLM_ACTIVATIONS or node.op_type == "Gelu":
                mlp_info = self._trace_mlp_pattern(node, graph_info)
                if mlp_info:
                    visited.update(mlp_info["nodes"])

                    blocks.append(
                        Block(
                            block_type="MLPBlock",
                            name=f"mlp_{len(blocks)}",
                            nodes=mlp_info["nodes"],
                            start_node=mlp_info["up_proj"],
                            end_node=mlp_info["down_proj"],
                            attributes={
                                "mlp_type": mlp_info["mlp_type"],
                                "hidden_dim": mlp_info.get("hidden_dim", 0),
                                "intermediate_dim": mlp_info.get("intermediate_dim", 0),
                                "activation": node.op_type,
                                "is_gated": mlp_info.get("is_gated", False),
                            },
                        )
                    )

        return blocks

    def _trace_mlp_pattern(
        self, activation: NodeInfo, graph_info: GraphInfo
    ) -> dict[str, Any] | None:
        """Trace MLP pattern from activation function."""
        result: dict[str, Any] = {
            "nodes": [activation.name],
            "mlp_type": "standard",
            "is_gated": False,
        }

        # Trace backward to find up-projection (first linear)
        if activation.inputs:
            inp = activation.inputs[0]
            if inp in graph_info.node_by_output:
                prev = graph_info.node_by_output[inp]
                if prev.op_type in ("MatMul", "Gemm"):
                    result["up_proj"] = prev.name
                    result["nodes"].append(prev.name)

                    # Get dimensions from weight
                    if len(prev.inputs) >= 2:
                        weight_name = prev.inputs[1]
                        if weight_name in graph_info.initializers:
                            weight = graph_info.initializers[weight_name]
                            if len(weight.shape) == 2:
                                result["hidden_dim"] = int(weight.shape[0])
                                result["intermediate_dim"] = int(weight.shape[1])

        # Trace forward to find down-projection
        # Handle gated patterns (SwiGLU): activation output may go through Mul
        if activation.outputs:
            consumer = self._find_consumer(activation.outputs[0], graph_info)
            if consumer:
                if consumer.op_type == "Mul":
                    # Gated pattern (SwiGLU/GeGLU)
                    result["is_gated"] = True
                    result["mlp_type"] = "gated"
                    result["nodes"].append(consumer.name)

                    if consumer.outputs:
                        down_proj = self._find_consumer(consumer.outputs[0], graph_info)
                        if down_proj and down_proj.op_type in ("MatMul", "Gemm"):
                            result["down_proj"] = down_proj.name
                            result["nodes"].append(down_proj.name)
                            return result

                elif consumer.op_type in ("MatMul", "Gemm"):
                    # Standard FFN
                    result["down_proj"] = consumer.name
                    result["nodes"].append(consumer.name)
                    return result

        # Only return if we found both projections
        if "up_proj" in result and "down_proj" in result:
            return result
        return None

    def detect_position_encoding(self, graph_info: GraphInfo) -> list[Block]:
        """
        Detect position encoding patterns.

        Task 5.4.3: Detect embedding patterns (RoPE, ALiBi, learned, sinusoidal).

        Patterns:
        - RoPE: Complex rotations using Sin/Cos on positions
        - ALiBi: Learned linear biases added to attention
        - Learned: Position embedding table (Gather)
        - Sinusoidal: Sin/Cos computations on positions
        """
        blocks: list[Block] = []

        # Check for RoPE pattern (Sin/Cos operations followed by Mul)
        sin_nodes = [n for n in graph_info.nodes if n.op_type == "Sin"]
        cos_nodes = [n for n in graph_info.nodes if n.op_type == "Cos"]

        if sin_nodes and cos_nodes:
            # Likely RoPE or sinusoidal encoding
            # RoPE typically has paired Sin/Cos that multiply with Q and K
            rope_nodes = []
            for sin in sin_nodes:
                rope_nodes.append(sin.name)
            for cos in cos_nodes:
                rope_nodes.append(cos.name)

            # Check if these feed into Mul operations (rotation pattern)
            has_rotation = False
            for node in graph_info.nodes:
                if node.op_type == "Mul":
                    for inp in node.inputs:
                        if inp in graph_info.node_by_output:
                            prev = graph_info.node_by_output[inp]
                            if prev.op_type in ("Sin", "Cos"):
                                has_rotation = True
                                break

            if has_rotation:
                blocks.append(
                    Block(
                        block_type="PositionEncoding",
                        name="rope_encoding",
                        nodes=rope_nodes,
                        start_node=rope_nodes[0] if rope_nodes else "",
                        end_node=rope_nodes[-1] if rope_nodes else "",
                        attributes={
                            "encoding_type": "RoPE",
                            "num_sin": len(sin_nodes),
                            "num_cos": len(cos_nodes),
                        },
                    )
                )
            else:
                blocks.append(
                    Block(
                        block_type="PositionEncoding",
                        name="sinusoidal_encoding",
                        nodes=rope_nodes,
                        start_node=rope_nodes[0] if rope_nodes else "",
                        end_node=rope_nodes[-1] if rope_nodes else "",
                        attributes={
                            "encoding_type": "sinusoidal",
                        },
                    )
                )

        # Check for learned position embeddings (small Gather, separate from token)
        for node in graph_info.nodes:
            if node.op_type == "Gather":
                if node.inputs and node.inputs[0] in graph_info.initializers:
                    table = graph_info.initializers[node.inputs[0]]
                    if len(table.shape) == 2:
                        size, dim = table.shape
                        # Position embeddings are typically smaller than vocab
                        # (e.g., 512, 1024, 2048, 4096 max positions)
                        if 128 <= size <= 8192 and size not in [
                            30000,
                            32000,
                            50257,
                            50304,
                            65536,
                            128000,
                            151936,
                        ]:
                            # Likely position embedding, not token embedding
                            blocks.append(
                                Block(
                                    block_type="PositionEncoding",
                                    name=f"learned_position_{len(blocks)}",
                                    nodes=[node.name],
                                    start_node=node.name,
                                    end_node=node.name,
                                    attributes={
                                        "encoding_type": "learned",
                                        "max_positions": int(size),
                                        "embed_dim": int(dim),
                                    },
                                )
                            )

        return blocks

    def detect_moe_routing(self, graph_info: GraphInfo) -> list[Block]:
        """
        Detect Mixture of Experts (MoE) routing patterns.

        Task 5.4.7: Handle MoE routing patterns.

        MoE pattern:
        - Router: Linear -> Softmax/TopK -> expert selection
        - Experts: Multiple parallel FFN blocks
        - Combine: Scatter/Gather to route tokens to experts
        """
        blocks: list[Block] = []

        # Look for TopK operations (expert selection)
        topk_nodes = [n for n in graph_info.nodes if n.op_type == "TopK"]

        for topk in topk_nodes:
            # Check if this looks like MoE routing
            # Pattern: Linear -> TopK (for top-k expert selection)
            router_proj = None
            if topk.inputs:
                inp = topk.inputs[0]
                # May go through Softmax first
                if inp in graph_info.node_by_output:
                    prev = graph_info.node_by_output[inp]
                    if prev.op_type == "Softmax":
                        if prev.inputs:
                            inp = prev.inputs[0]
                    if inp in graph_info.node_by_output:
                        router_node = graph_info.node_by_output[inp]
                        if router_node.op_type in ("MatMul", "Gemm"):
                            router_proj = router_node.name

            if router_proj:
                # Try to infer number of experts from router output shape
                num_experts = 0
                k_value = 0

                # Check TopK k attribute
                for attr in (
                    getattr(topk, "attributes", {}).items() if hasattr(topk, "attributes") else []
                ):
                    if attr[0] == "k":
                        k_value = attr[1]

                blocks.append(
                    Block(
                        block_type="MoERouter",
                        name=f"moe_router_{len(blocks)}",
                        nodes=[router_proj, topk.name],
                        start_node=router_proj,
                        end_node=topk.name,
                        attributes={
                            "num_experts": num_experts,
                            "top_k": k_value,
                            "router_type": "top_k",
                        },
                    )
                )

        return blocks

    def detect_repeated_blocks(
        self, graph_info: GraphInfo, existing_blocks: list[Block]
    ) -> list[Block]:
        """
        Detect repeated identical blocks (e.g., N transformer layers).

        Task 5.4.5: Detect repetition - N identical blocks -> collapse with xN.
        """
        blocks: list[Block] = []

        # Group existing blocks by type
        blocks_by_type: dict[str, list[Block]] = {}
        for block in existing_blocks:
            if block.block_type not in blocks_by_type:
                blocks_by_type[block.block_type] = []
            blocks_by_type[block.block_type].append(block)

        # Check for repeated patterns
        for block_type, type_blocks in blocks_by_type.items():
            if len(type_blocks) >= 4:  # At least 4 repetitions to be significant
                # Check if blocks have similar structure
                # (simplified: just count them for now)
                blocks.append(
                    Block(
                        block_type="RepeatedBlock",
                        name=f"repeated_{block_type}",
                        nodes=[],  # Meta-block, doesn't own nodes directly
                        start_node=type_blocks[0].start_node,
                        end_node=type_blocks[-1].end_node,
                        attributes={
                            "repeated_type": block_type,
                            "num_repetitions": len(type_blocks),
                            "block_names": [b.name for b in type_blocks],
                        },
                    )
                )

        return blocks

    def detect_normalization_pattern(self, graph_info: GraphInfo) -> dict:
        """
        Detect normalization placement pattern (pre-norm vs post-norm).

        Task 5.4.4: Detect normalization placement.

        Pre-norm (modern, e.g., LLaMA, GPT-3):
            LayerNorm -> Attention -> Residual Add
            LayerNorm -> FFN -> Residual Add

        Post-norm (original transformer, e.g., BERT):
            Attention -> Residual Add -> LayerNorm
            FFN -> Residual Add -> LayerNorm
        """
        result = {
            "pattern": "unknown",
            "num_layernorms": 0,
            "has_rmsnorm": False,
        }

        # Count normalization ops
        ln_count = graph_info.op_type_counts.get("LayerNormalization", 0)
        rms_count = sum(
            1 for n in graph_info.nodes if "rms" in n.name.lower() or "rmsnorm" in n.name.lower()
        )
        result["num_layernorms"] = ln_count
        result["has_rmsnorm"] = rms_count > 0 or any(
            n.op_type == "SimplifiedLayerNormalization" for n in graph_info.nodes
        )

        if ln_count == 0 and not result["has_rmsnorm"]:
            result["pattern"] = "none"
            return result

        # Analyze pattern by checking what comes after Add (residual)
        add_nodes = [n for n in graph_info.nodes if n.op_type == "Add"]

        ln_after_add = 0
        ln_before_matmul = 0

        for add in add_nodes:
            if add.outputs:
                consumer = self._find_consumer(add.outputs[0], graph_info)
                if consumer and consumer.op_type in (
                    "LayerNormalization",
                    "SimplifiedLayerNormalization",
                ):
                    ln_after_add += 1

        # Check if LayerNorm feeds into MatMul (pre-norm pattern)
        for node in graph_info.nodes:
            if node.op_type in ("LayerNormalization", "SimplifiedLayerNormalization"):
                if node.outputs:
                    consumer = self._find_consumer(node.outputs[0], graph_info)
                    if consumer and consumer.op_type in ("MatMul", "Gemm"):
                        ln_before_matmul += 1

        # Classify
        if ln_after_add > ln_before_matmul:
            result["pattern"] = "post_norm"
        elif ln_before_matmul > ln_after_add:
            result["pattern"] = "pre_norm"
        elif ln_count > 0 or result["has_rmsnorm"]:
            result["pattern"] = "mixed"

        return result

    def classify_architecture(self, graph_info: GraphInfo, blocks: list[Block]) -> str:
        """
        Classify the overall architecture type.

        Args:
            graph_info: Parsed graph information.
            blocks: Detected blocks from group_into_blocks().

        Returns:
            Architecture type: "transformer", "cnn", "mlp", "hybrid", "unknown"
        """
        op_counts = graph_info.op_type_counts
        block_types = [b.block_type for b in blocks]

        # Count key indicators
        has_attention = any("Attention" in bt for bt in block_types)
        has_mlp_block = any("MLPBlock" in bt for bt in block_types)
        has_layernorm = op_counts.get("LayerNormalization", 0) > 0
        has_embedding = any("Embedding" in bt for bt in block_types)
        has_moe = any("MoE" in bt for bt in block_types)
        has_rope = any(
            b.block_type == "PositionEncoding" and b.attributes.get("encoding_type") == "RoPE"
            for b in blocks
        )

        # Include quantized variants (ConvInteger, MatMulInteger) for INT8 models
        conv_count = op_counts.get("Conv", 0) + op_counts.get("ConvInteger", 0)
        matmul_count = (
            op_counts.get("MatMul", 0)
            + op_counts.get("Gemm", 0)
            + op_counts.get("MatMulInteger", 0)
        )
        softmax_count = op_counts.get("Softmax", 0)

        # Classification heuristics - more specific
        if has_moe:
            return "moe_transformer"
        elif has_attention or has_mlp_block or (has_layernorm and softmax_count >= 2):
            if has_rope:
                return "decoder_transformer"  # LLaMA-style
            elif has_embedding:
                return "transformer"
            else:
                return "transformer"
        elif conv_count > matmul_count and conv_count >= 5:
            return "cnn"
        elif conv_count > 0 and (has_attention or has_layernorm):
            return "hybrid"
        elif matmul_count > 0:
            return "mlp"
        else:
            return "unknown"

    def get_architecture_summary(self, graph_info: GraphInfo, blocks: list[Block]) -> dict:
        """
        Get a detailed architecture summary for LLMs.

        Returns comprehensive architecture info for reports.
        """
        arch_type = self.classify_architecture(graph_info, blocks)
        norm_pattern = self.detect_normalization_pattern(graph_info)

        # Count block types
        block_counts: dict[str, int] = {}
        for block in blocks:
            bt = block.block_type
            block_counts[bt] = block_counts.get(bt, 0) + 1

        # Get attention info
        attention_blocks = [b for b in blocks if "Attention" in b.block_type]
        attention_type = "unknown"
        num_heads = 0
        num_kv_heads = 0

        if attention_blocks:
            # Use first attention block's info
            first_attn = attention_blocks[0]
            attention_type = first_attn.attributes.get("attention_type", "unknown")
            num_heads = first_attn.attributes.get("num_q_heads", 0)
            num_kv_heads = first_attn.attributes.get("num_kv_heads", 0)

        # Get MLP info
        mlp_blocks = [b for b in blocks if b.block_type == "MLPBlock"]
        mlp_type = "unknown"
        if mlp_blocks:
            mlp_type = mlp_blocks[0].attributes.get("mlp_type", "unknown")

        # Get position encoding
        pos_blocks = [b for b in blocks if b.block_type == "PositionEncoding"]
        pos_encoding = "none"
        if pos_blocks:
            pos_encoding = pos_blocks[0].attributes.get("encoding_type", "unknown")

        # Get repetition info
        repeated_blocks = [b for b in blocks if b.block_type == "RepeatedBlock"]
        num_layers = 0
        for rb in repeated_blocks:
            if rb.attributes.get("repeated_type") in ("AttentionHead", "Attention"):
                num_layers = rb.attributes.get("num_repetitions", 0)
                break

        return {
            "architecture_type": arch_type,
            "normalization": norm_pattern,
            "block_counts": block_counts,
            "attention": {
                "type": attention_type,
                "num_q_heads": num_heads,
                "num_kv_heads": num_kv_heads,
            },
            "mlp": {
                "type": mlp_type,
            },
            "position_encoding": pos_encoding,
            "num_layers": num_layers,
            "total_blocks": len(blocks),
        }

    def _find_consumer(self, output_name: str, graph_info: GraphInfo) -> NodeInfo | None:
        """Find the first node that consumes a given output."""
        for node in graph_info.nodes:
            if output_name in node.inputs:
                return node
        return None