haoline-0.3.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/patterns.py ADDED
@@ -0,0 +1,1116 @@
+ # Copyright (c) 2025 HaoLine Contributors
+ # SPDX-License-Identifier: MIT
+
+ """
+ Pattern detection for HaoLine.
+
+ Detects common architectural patterns in ONNX graphs:
+ - Conv-BatchNorm-ReLU blocks
+ - Residual/skip connections
+ - Transformer blocks (attention + MLP)
+ - Embedding layers
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Any, ClassVar
+
+ if TYPE_CHECKING:
+     from .analyzer import GraphInfo, NodeInfo
+
+
+ @dataclass
+ class Block:
+     """
+     A detected architectural block (group of related nodes).
+
+     Blocks represent higher-level patterns like "ResidualBlock" or
+     "TransformerLayer" that consist of multiple ONNX nodes.
+     """
+
+     block_type: str  # e.g., "ConvBNRelu", "ResidualBlock", "TransformerBlock"
+     name: str
+     nodes: list[str]  # Node names in this block
+     start_node: str
+     end_node: str
+     attributes: dict = field(default_factory=dict)  # Block-specific metadata
+
+     def to_dict(self) -> dict:
+         return {
+             "block_type": self.block_type,
+             "name": self.name,
+             "nodes": self.nodes,
+             "start_node": self.start_node,
+             "end_node": self.end_node,
+             "attributes": self.attributes,
+         }
+
+
+ class PatternAnalyzer:
+     """
+     Detect architectural patterns in ONNX graphs.
+
+     Identifies common patterns like Conv-BN-ReLU sequences, residual
+     blocks, and transformer attention blocks.
+     """
+
+     # Operators that commonly appear together
+     CONV_ACTIVATIONS: ClassVar[set[str]] = {
+         "Relu",
+         "LeakyRelu",
+         "Sigmoid",
+         "Tanh",
+         "Clip",
+         "HardSwish",
+         "Silu",
+     }
+     NORM_OPS: ClassVar[set[str]] = {
+         "BatchNormalization",
+         "InstanceNormalization",
+         "LayerNormalization",
+         "GroupNormalization",
+     }
+     ATTENTION_OPS: ClassVar[set[str]] = {"MatMul", "Softmax", "Transpose"}
+     EMBEDDING_OPS: ClassVar[set[str]] = {"Gather", "Embedding"}
+
+     # LLM-specific activation functions
+     LLM_ACTIVATIONS: ClassVar[set[str]] = {
+         "Gelu",
+         "FastGelu",
+         "QuickGelu",  # GPT-style
+         "Silu",
+         "Swish",  # LLaMA/Mistral style
+         "Relu",
+         "LeakyRelu",  # Classic
+         "NewGelu",  # Some implementations
+     }
+
+     # MoE-related ops
+     MOE_OPS: ClassVar[set[str]] = {"TopK", "Scatter", "ScatterND", "GatherND"}
+
+     def __init__(self, logger: logging.Logger | None = None):
+         self.logger = logger or logging.getLogger("haoline.patterns")
+
+     def group_into_blocks(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Detect all architectural blocks in the graph.
+
+         Args:
+             graph_info: Parsed graph information.
+
+         Returns:
+             List of detected Block instances.
+         """
+         blocks: list[Block] = []
+
+         # Detect various patterns - ordered from specific to general
+         blocks.extend(self.detect_conv_bn_relu(graph_info))
+         blocks.extend(self.detect_residual_blocks(graph_info))
+         blocks.extend(self.detect_nonstandard_residual_blocks(graph_info))
+
+         # LLM-specific patterns (Task 5.4.1-5.4.7)
+         blocks.extend(self.detect_attention_heads(graph_info))
+         blocks.extend(self.detect_mlp_blocks(graph_info))
+         blocks.extend(self.detect_embedding_layers(graph_info))
+         blocks.extend(self.detect_position_encoding(graph_info))
+         blocks.extend(self.detect_moe_routing(graph_info))
+
+         # High-level transformer detection (uses sub-blocks)
+         blocks.extend(self.detect_transformer_blocks(graph_info))
+
+         # Detect repeated blocks
+         repeated = self.detect_repeated_blocks(graph_info, blocks)
+         if repeated:
+             blocks.extend(repeated)
+
+         self.logger.debug(f"Detected {len(blocks)} blocks")
+         return blocks
+
+     def detect_conv_bn_relu(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Find Conv-BatchNorm-ReLU sequences.
+
+         Matches patterns like:
+         - Conv -> BatchNorm -> ReLU
+         - Conv -> ReLU
+         - Conv -> BatchNorm
+         """
+         blocks: list[Block] = []
+         visited: set[str] = set()
+
+         for node in graph_info.nodes:
+             if node.name in visited:
+                 continue
+
+             if node.op_type == "Conv":
+                 block_nodes = [node.name]
+                 current_output = node.outputs[0] if node.outputs else None
+                 block_type_parts = ["Conv"]
+
+                 # Look for BatchNorm
+                 if current_output:
+                     next_node = self._find_consumer(current_output, graph_info)
+                     if next_node and next_node.op_type in self.NORM_OPS:
+                         block_nodes.append(next_node.name)
+                         block_type_parts.append("BN")
+                         current_output = next_node.outputs[0] if next_node.outputs else None
+
+                         # Look for activation after BN
+                         if current_output:
+                             act_node = self._find_consumer(current_output, graph_info)
+                             if act_node and act_node.op_type in self.CONV_ACTIVATIONS:
+                                 block_nodes.append(act_node.name)
+                                 block_type_parts.append(act_node.op_type)
+                     elif next_node and next_node.op_type in self.CONV_ACTIVATIONS:
+                         # Conv -> ReLU without BN
+                         block_nodes.append(next_node.name)
+                         block_type_parts.append(next_node.op_type)
+
+                 if len(block_nodes) > 1:
+                     visited.update(block_nodes)
+                     block = Block(
+                         block_type="".join(block_type_parts),
+                         name=f"conv_block_{len(blocks)}",
+                         nodes=block_nodes,
+                         start_node=block_nodes[0],
+                         end_node=block_nodes[-1],
+                     )
+                     blocks.append(block)
+
+         return blocks
+
+     def detect_residual_blocks(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Find residual/skip connection patterns.
+
+         Looks for Add nodes where one input comes from earlier in the graph
+         (skip connection).
+         """
+         blocks: list[Block] = []
+
+         for node in graph_info.nodes:
+             if node.op_type == "Add" and len(node.inputs) >= 2:
+                 # Check if this could be a residual connection
+                 # by looking for inputs that come from different depths
+                 input_nodes = []
+                 for inp in node.inputs:
+                     if inp in graph_info.node_by_output:
+                         input_nodes.append(graph_info.node_by_output[inp])
+
+                 if len(input_nodes) >= 2:
+                     # Heuristic: if one path is longer (more hops), it's likely the residual path
+                     # For now, just detect the pattern exists
+                     blocks.append(
+                         Block(
+                             block_type="ResidualAdd",
+                             name=f"residual_{len(blocks)}",
+                             nodes=[node.name],
+                             start_node=node.name,
+                             end_node=node.name,
+                             attributes={"inputs": node.inputs, "variant": "standard"},
+                         )
+                     )
+
+         return blocks
+
+     def detect_nonstandard_residual_blocks(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Find non-standard residual/skip connection patterns.
+
+         Detects alternative skip connection implementations:
+         - Concat-based skip connections (DenseNet-style)
+         - Mul-based gating mechanisms (Highway networks, attention gates)
+         - Sub-based connections (rare but possible)
+
+         These may indicate custom architectures that need special handling.
+         """
+         blocks: list[Block] = []
+
+         # Concat-based skip connections (DenseNet-style)
+         for node in graph_info.nodes:
+             if node.op_type == "Concat" and len(node.inputs) >= 2:
+                 # Check if inputs come from different depths (skip connection indicator)
+                 input_depths = self._estimate_input_depths(node.inputs, graph_info)
+                 if input_depths and max(input_depths) - min(input_depths) >= 2:
+                     blocks.append(
+                         Block(
+                             block_type="ResidualConcat",
+                             name=f"dense_skip_{len(blocks)}",
+                             nodes=[node.name],
+                             start_node=node.name,
+                             end_node=node.name,
+                             attributes={
+                                 "inputs": node.inputs,
+                                 "variant": "concat",
+                                 "depth_diff": max(input_depths) - min(input_depths),
+                             },
+                         )
+                     )
+
+         # Mul-based gating (Highway networks, attention gates)
+         for node in graph_info.nodes:
+             if node.op_type == "Mul" and len(node.inputs) >= 2:
+                 # Look for Sigmoid before Mul (gating pattern)
+                 has_sigmoid_input = False
+                 for inp in node.inputs:
+                     if inp in graph_info.node_by_output:
+                         prev_node = graph_info.node_by_output[inp]
+                         if prev_node.op_type == "Sigmoid":
+                             has_sigmoid_input = True
+                             break
+
+                 if has_sigmoid_input:
+                     blocks.append(
+                         Block(
+                             block_type="ResidualGate",
+                             name=f"gate_{len(blocks)}",
+                             nodes=[node.name],
+                             start_node=node.name,
+                             end_node=node.name,
+                             attributes={"inputs": node.inputs, "variant": "gated"},
+                         )
+                     )
+
+         # Sub-based connections (rare, but could be learned residual)
+         for node in graph_info.nodes:
+             if node.op_type == "Sub" and len(node.inputs) >= 2:
+                 input_nodes = []
+                 for inp in node.inputs:
+                     if inp in graph_info.node_by_output:
+                         input_nodes.append(graph_info.node_by_output[inp])
+
+                 if len(input_nodes) >= 2:
+                     blocks.append(
+                         Block(
+                             block_type="ResidualSub",
+                             name=f"sub_residual_{len(blocks)}",
+                             nodes=[node.name],
+                             start_node=node.name,
+                             end_node=node.name,
+                             attributes={"inputs": node.inputs, "variant": "subtract"},
+                         )
+                     )
+
+         return blocks
+
+     def _estimate_input_depths(
+         self, inputs: list[str], graph_info: GraphInfo, max_depth: int = 20
+     ) -> list[int]:
+         """
+         Estimate the graph depth of each input tensor.
+
+         Returns a list of estimated depths (hops from graph inputs).
+         Used to detect skip connections where inputs come from very different depths.
+         """
+         depths = []
+         for inp in inputs:
+             depth = self._trace_depth(inp, graph_info, 0, max_depth)
+             depths.append(depth)
+         return depths
+
+     def _trace_depth(
+         self,
+         tensor_name: str,
+         graph_info: GraphInfo,
+         current_depth: int,
+         max_depth: int,
+     ) -> int:
+         """Recursively trace back to find the depth of a tensor."""
+         if current_depth >= max_depth:
+             return current_depth
+
+         # If it's a graph input or initializer, depth is 0
+         if tensor_name in graph_info.input_shapes:
+             return 0
+         if tensor_name in graph_info.initializers:
+             return 0
+
+         # Find the node that produces this tensor
+         if tensor_name in graph_info.node_by_output:
+             producer = graph_info.node_by_output[tensor_name]
+             if producer.inputs:
+                 # Trace back through the first input
+                 return 1 + self._trace_depth(
+                     producer.inputs[0], graph_info, current_depth + 1, max_depth
+                 )
+
+         return current_depth
+
+     def detect_transformer_blocks(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Find transformer attention patterns.
+
+         Looks for the characteristic Softmax in attention computation
+         and MatMul patterns for Q, K, V projections.
+         """
+         blocks: list[Block] = []
+         softmax_nodes = [n for n in graph_info.nodes if n.op_type == "Softmax"]
+
+         for softmax in softmax_nodes:
+             # Look for attention pattern: MatMul -> Softmax -> MatMul
+             before_nodes: list[str] = []
+             after_nodes: list[str] = []
+
+             # Find MatMul before softmax
+             if softmax.inputs:
+                 inp = softmax.inputs[0]
+                 if inp in graph_info.node_by_output:
+                     prev = graph_info.node_by_output[inp]
+                     if prev.op_type in (
+                         "MatMul",
+                         "Gemm",
+                         "Div",
+                         "Mul",
+                     ):  # Div for scaling
+                         before_nodes.append(prev.name)
+
+             # Find MatMul after softmax
+             if softmax.outputs:
+                 consumer = self._find_consumer(softmax.outputs[0], graph_info)
+                 if consumer and consumer.op_type in ("MatMul", "Gemm"):
+                     after_nodes.append(consumer.name)
+
+             if before_nodes and after_nodes:
+                 all_nodes = [*before_nodes, softmax.name, *after_nodes]
+                 blocks.append(
+                     Block(
+                         block_type="Attention",
+                         name=f"attention_{len(blocks)}",
+                         nodes=all_nodes,
+                         start_node=before_nodes[0],
+                         end_node=after_nodes[-1],
+                     )
+                 )
+
+         # Also look for LayerNorm which often brackets transformer blocks
+         layernorm_count = sum(1 for n in graph_info.nodes if n.op_type == "LayerNormalization")
+         if layernorm_count >= 2 and blocks:
+             # Likely a transformer architecture
+             self.logger.debug(
+                 f"Found {len(blocks)} attention blocks with {layernorm_count} LayerNorms"
+             )
+
+         return blocks
+
+     def detect_embedding_layers(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Find embedding lookup patterns.
+
+         Looks for Gather operations on large weight tensors.
+         """
+         blocks: list[Block] = []
+
+         for node in graph_info.nodes:
+             if node.op_type == "Gather":
+                 # Check if first input is a large initializer (embedding table)
+                 if node.inputs and node.inputs[0] in graph_info.initializers:
+                     embed_table = graph_info.initializers[node.inputs[0]]
+                     if len(embed_table.shape) == 2:
+                         vocab_size, embed_dim = embed_table.shape
+                         # Token embedding typically has large vocab (>1000)
+                         embed_type = "token" if vocab_size > 1000 else "position"
+                         blocks.append(
+                             Block(
+                                 block_type="Embedding",
+                                 name=f"embedding_{len(blocks)}",
+                                 nodes=[node.name],
+                                 start_node=node.name,
+                                 end_node=node.name,
+                                 attributes={
+                                     "vocab_size": int(vocab_size),
+                                     "embed_dim": int(embed_dim),
+                                     "embed_type": embed_type,
+                                 },
+                             )
+                         )
+
+         return blocks
+
+     def detect_attention_heads(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Detect attention head patterns with Q/K/V projections.
+
+         Task 5.4.1: Enhanced attention detection for LLMs.
+
+         Patterns detected:
+         - Standard MHA: Q, K, V linear -> attention -> output linear
+         - MQA: Single K, V shared across Q heads
+         - GQA: Grouped K, V (fewer KV heads than Q heads)
+         """
+         blocks: list[Block] = []
+         visited_softmax: set[str] = set()
+
+         # Find all Softmax nodes (attention core)
+         softmax_nodes = [n for n in graph_info.nodes if n.op_type == "Softmax"]
+
+         for softmax in softmax_nodes:
+             if softmax.name in visited_softmax:
+                 continue
+
+             attention_info = self._trace_attention_pattern(softmax, graph_info)
+             if attention_info:
+                 visited_softmax.add(softmax.name)
+
+                 # Determine attention type
+                 num_q_heads = attention_info.get("num_q_heads", 0)
+                 num_kv_heads = attention_info.get("num_kv_heads", 0)
+
+                 if num_kv_heads == 1:
+                     attention_type = "MQA"  # Multi-Query Attention
+                 elif num_kv_heads > 0 and num_kv_heads < num_q_heads:
+                     attention_type = "GQA"  # Grouped-Query Attention
+                 else:
+                     attention_type = "MHA"  # Standard Multi-Head Attention
+
+                 blocks.append(
+                     Block(
+                         block_type="AttentionHead",
+                         name=f"attention_head_{len(blocks)}",
+                         nodes=attention_info.get("nodes", [softmax.name]),
+                         start_node=attention_info.get("q_proj", softmax.name),
+                         end_node=attention_info.get("o_proj", softmax.name),
+                         attributes={
+                             "attention_type": attention_type,
+                             "num_q_heads": num_q_heads,
+                             "num_kv_heads": num_kv_heads,
+                             "has_scaling": attention_info.get("has_scaling", False),
+                             "has_mask": attention_info.get("has_mask", False),
+                             "q_proj": attention_info.get("q_proj"),
+                             "k_proj": attention_info.get("k_proj"),
+                             "v_proj": attention_info.get("v_proj"),
+                             "o_proj": attention_info.get("o_proj"),
+                         },
+                     )
+                 )
+
+         return blocks
+
+     def _trace_attention_pattern(
+         self, softmax: NodeInfo, graph_info: GraphInfo
+     ) -> dict[str, Any] | None:
+         """Trace back from Softmax to find Q/K/V projections."""
+         result: dict[str, Any] = {
+             "nodes": [softmax.name],
+             "has_scaling": False,
+             "has_mask": False,
+             "num_q_heads": 0,
+             "num_kv_heads": 0,
+         }
+
+         # Trace backward from softmax
+         # Pattern: (Q @ K^T) / sqrt(d_k) -> Softmax -> @ V
+
+         if not softmax.inputs:
+             return None
+
+         # Look for scaling (Div or Mul) and mask (Add) before softmax
+         current = softmax.inputs[0]
+         if current in graph_info.node_by_output:
+             prev = graph_info.node_by_output[current]
+
+             # Check for mask addition
+             if prev.op_type == "Add":
+                 result["has_mask"] = True
+                 result["nodes"].append(prev.name)
+                 if prev.inputs:
+                     current = prev.inputs[0]
+                     if current in graph_info.node_by_output:
+                         prev = graph_info.node_by_output[current]
+
+             # Check for scaling
+             if prev.op_type in ("Div", "Mul"):
+                 result["has_scaling"] = True
+                 result["nodes"].append(prev.name)
+                 if prev.inputs:
+                     current = prev.inputs[0]
+                     if current in graph_info.node_by_output:
+                         prev = graph_info.node_by_output[current]
+
+             # Look for Q @ K^T (MatMul)
+             if prev.op_type == "MatMul":
+                 result["nodes"].append(prev.name)
+
+                 # Trace Q and K projections
+                 if len(prev.inputs) >= 2:
+                     q_input, k_input = prev.inputs[0], prev.inputs[1]
+
+                     # Find Q projection
+                     q_proj = self._find_linear_proj(q_input, graph_info)
+                     if q_proj:
+                         result["q_proj"] = q_proj["name"]
+                         result["nodes"].append(q_proj["name"])
+                         result["num_q_heads"] = q_proj.get("num_heads", 0)
+
+                     # Find K projection (may go through Transpose)
+                     k_proj = self._find_linear_proj(k_input, graph_info, through_transpose=True)
+                     if k_proj:
+                         result["k_proj"] = k_proj["name"]
+                         result["nodes"].append(k_proj["name"])
+                         result["num_kv_heads"] = k_proj.get("num_heads", 0)
+
+         # Trace forward from softmax to find V @ attention and output projection
+         if softmax.outputs:
+             consumer = self._find_consumer(softmax.outputs[0], graph_info)
+             if consumer and consumer.op_type == "MatMul":
+                 result["nodes"].append(consumer.name)
+
+                 # Find V projection
+                 if len(consumer.inputs) >= 2:
+                     v_input = consumer.inputs[1]  # Second input is V
+                     v_proj = self._find_linear_proj(v_input, graph_info)
+                     if v_proj:
+                         result["v_proj"] = v_proj["name"]
+                         result["nodes"].append(v_proj["name"])
+
+                 # Find output projection
+                 if consumer.outputs:
+                     o_consumer = self._find_consumer(consumer.outputs[0], graph_info)
+                     if o_consumer and o_consumer.op_type in ("MatMul", "Gemm"):
+                         result["o_proj"] = o_consumer.name
+                         result["nodes"].append(o_consumer.name)
+
+         # Only return if we found at least the core attention pattern
+         if len(result["nodes"]) >= 3:  # softmax + at least 2 matmuls
+             return result
+         return None
+
+     def _find_linear_proj(
+         self, tensor_name: str, graph_info: GraphInfo, through_transpose: bool = False
+     ) -> dict | None:
+         """Find a linear projection (MatMul/Gemm) producing this tensor."""
+         current = tensor_name
+
+         # Optionally look through Transpose (for K^T in attention)
+         if through_transpose and current in graph_info.node_by_output:
+             node = graph_info.node_by_output[current]
+             if node.op_type == "Transpose":
+                 if node.inputs:
+                     current = node.inputs[0]
+
+         # Also look through Reshape (for multi-head splitting)
+         if current in graph_info.node_by_output:
+             node = graph_info.node_by_output[current]
+             if node.op_type == "Reshape":
+                 if node.inputs:
+                     current = node.inputs[0]
+
+         # Find the MatMul/Gemm
+         if current in graph_info.node_by_output:
+             node = graph_info.node_by_output[current]
+             if node.op_type in ("MatMul", "Gemm"):
+                 # Try to infer number of heads from weight shape
+                 num_heads = 0
+                 if len(node.inputs) >= 2:
+                     weight_name = node.inputs[1]
+                     if weight_name in graph_info.initializers:
+                         weight = graph_info.initializers[weight_name]
+                         if len(weight.shape) == 2:
+                             # Typical shape: [hidden, num_heads * head_dim]
+                             out_features = weight.shape[1]
+                             # Common head_dim values: 64, 128
+                             for head_dim in [64, 128, 96, 80]:
+                                 if out_features % head_dim == 0:
+                                     num_heads = out_features // head_dim
+                                     break
+
+                 return {"name": node.name, "num_heads": num_heads}
+
+         return None
+
+     def detect_mlp_blocks(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Detect MLP/FFN blocks in transformers.
+
+         Task 5.4.2: Detect MLP/FFN patterns.
+
+         Patterns detected:
+         - Standard FFN: Linear -> Activation -> Linear
+         - SwiGLU/GeGLU: Linear -> (Gate * Activation(Linear)) -> Linear
+         - Gated MLP: Uses element-wise gating
+         """
+         blocks: list[Block] = []
+         visited: set[str] = set()
+
+         # Look for activation functions that typically appear in FFN
+         for node in graph_info.nodes:
+             if node.name in visited:
+                 continue
+
+             if node.op_type in self.LLM_ACTIVATIONS or node.op_type == "Gelu":
+                 mlp_info = self._trace_mlp_pattern(node, graph_info)
+                 if mlp_info:
+                     visited.update(mlp_info["nodes"])
+
+                     blocks.append(
+                         Block(
+                             block_type="MLPBlock",
+                             name=f"mlp_{len(blocks)}",
+                             nodes=mlp_info["nodes"],
+                             start_node=mlp_info["up_proj"],
+                             end_node=mlp_info["down_proj"],
+                             attributes={
+                                 "mlp_type": mlp_info["mlp_type"],
+                                 "hidden_dim": mlp_info.get("hidden_dim", 0),
+                                 "intermediate_dim": mlp_info.get("intermediate_dim", 0),
+                                 "activation": node.op_type,
+                                 "is_gated": mlp_info.get("is_gated", False),
+                             },
+                         )
+                     )
+
+         return blocks
+
+     def _trace_mlp_pattern(
+         self, activation: NodeInfo, graph_info: GraphInfo
+     ) -> dict[str, Any] | None:
+         """Trace MLP pattern from activation function."""
+         result: dict[str, Any] = {
+             "nodes": [activation.name],
+             "mlp_type": "standard",
+             "is_gated": False,
+         }
+
+         # Trace backward to find up-projection (first linear)
+         if activation.inputs:
+             inp = activation.inputs[0]
+             if inp in graph_info.node_by_output:
+                 prev = graph_info.node_by_output[inp]
+                 if prev.op_type in ("MatMul", "Gemm"):
+                     result["up_proj"] = prev.name
+                     result["nodes"].append(prev.name)
+
+                     # Get dimensions from weight
+                     if len(prev.inputs) >= 2:
+                         weight_name = prev.inputs[1]
+                         if weight_name in graph_info.initializers:
+                             weight = graph_info.initializers[weight_name]
+                             if len(weight.shape) == 2:
+                                 result["hidden_dim"] = int(weight.shape[0])
+                                 result["intermediate_dim"] = int(weight.shape[1])
+
+         # Trace forward to find down-projection
+         # Handle gated patterns (SwiGLU): activation output may go through Mul
+         if activation.outputs:
+             consumer = self._find_consumer(activation.outputs[0], graph_info)
+             if consumer:
+                 if consumer.op_type == "Mul":
+                     # Gated pattern (SwiGLU/GeGLU)
+                     result["is_gated"] = True
+                     result["mlp_type"] = "gated"
+                     result["nodes"].append(consumer.name)
+
+                     if consumer.outputs:
+                         down_proj = self._find_consumer(consumer.outputs[0], graph_info)
+                         if down_proj and down_proj.op_type in ("MatMul", "Gemm"):
+                             result["down_proj"] = down_proj.name
+                             result["nodes"].append(down_proj.name)
+                             return result
+
+                 elif consumer.op_type in ("MatMul", "Gemm"):
+                     # Standard FFN
+                     result["down_proj"] = consumer.name
+                     result["nodes"].append(consumer.name)
+                     return result
+
+         # Only return if we found both projections
+         if "up_proj" in result and "down_proj" in result:
+             return result
+         return None
+
+     def detect_position_encoding(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Detect position encoding patterns.
+
+         Task 5.4.3: Detect embedding patterns (RoPE, ALiBi, learned, sinusoidal).
+
+         Patterns:
+         - RoPE: Complex rotations using Sin/Cos on positions
+         - ALiBi: Learned linear biases added to attention
+         - Learned: Position embedding table (Gather)
+         - Sinusoidal: Sin/Cos computations on positions
+         """
+         blocks: list[Block] = []
+
+         # Check for RoPE pattern (Sin/Cos operations followed by Mul)
+         sin_nodes = [n for n in graph_info.nodes if n.op_type == "Sin"]
+         cos_nodes = [n for n in graph_info.nodes if n.op_type == "Cos"]
+
+         if sin_nodes and cos_nodes:
+             # Likely RoPE or sinusoidal encoding
+             # RoPE typically has paired Sin/Cos that multiply with Q and K
+             rope_nodes = []
+             for sin in sin_nodes:
+                 rope_nodes.append(sin.name)
+             for cos in cos_nodes:
+                 rope_nodes.append(cos.name)
+
+             # Check if these feed into Mul operations (rotation pattern)
+             has_rotation = False
+             for node in graph_info.nodes:
+                 if node.op_type == "Mul":
+                     for inp in node.inputs:
+                         if inp in graph_info.node_by_output:
+                             prev = graph_info.node_by_output[inp]
+                             if prev.op_type in ("Sin", "Cos"):
+                                 has_rotation = True
+                                 break
+
+             if has_rotation:
+                 blocks.append(
+                     Block(
+                         block_type="PositionEncoding",
+                         name="rope_encoding",
+                         nodes=rope_nodes,
+                         start_node=rope_nodes[0] if rope_nodes else "",
+                         end_node=rope_nodes[-1] if rope_nodes else "",
+                         attributes={
+                             "encoding_type": "RoPE",
+                             "num_sin": len(sin_nodes),
+                             "num_cos": len(cos_nodes),
+                         },
+                     )
+                 )
+             else:
+                 blocks.append(
+                     Block(
+                         block_type="PositionEncoding",
+                         name="sinusoidal_encoding",
+                         nodes=rope_nodes,
+                         start_node=rope_nodes[0] if rope_nodes else "",
+                         end_node=rope_nodes[-1] if rope_nodes else "",
+                         attributes={
+                             "encoding_type": "sinusoidal",
+                         },
+                     )
+                 )
+
+         # Check for learned position embeddings (small Gather, separate from token)
+         for node in graph_info.nodes:
+             if node.op_type == "Gather":
+                 if node.inputs and node.inputs[0] in graph_info.initializers:
+                     table = graph_info.initializers[node.inputs[0]]
+                     if len(table.shape) == 2:
+                         size, dim = table.shape
+                         # Position embeddings are typically smaller than vocab
+                         # (e.g., 512, 1024, 2048, 4096 max positions)
+                         if 128 <= size <= 8192 and size not in [
+                             30000,
+                             32000,
+                             50257,
+                             50304,
+                             65536,
+                             128000,
+                             151936,
+                         ]:
+                             # Likely position embedding, not token embedding
+                             blocks.append(
+                                 Block(
+                                     block_type="PositionEncoding",
+                                     name=f"learned_position_{len(blocks)}",
+                                     nodes=[node.name],
+                                     start_node=node.name,
+                                     end_node=node.name,
+                                     attributes={
+                                         "encoding_type": "learned",
+                                         "max_positions": int(size),
+                                         "embed_dim": int(dim),
+                                     },
+                                 )
+                             )
+
+         return blocks
+
+     def detect_moe_routing(self, graph_info: GraphInfo) -> list[Block]:
+         """
+         Detect Mixture of Experts (MoE) routing patterns.
+
+         Task 5.4.7: Handle MoE routing patterns.
+
+         MoE pattern:
+         - Router: Linear -> Softmax/TopK -> expert selection
+         - Experts: Multiple parallel FFN blocks
+         - Combine: Scatter/Gather to route tokens to experts
+         """
+         blocks: list[Block] = []
+
+         # Look for TopK operations (expert selection)
+         topk_nodes = [n for n in graph_info.nodes if n.op_type == "TopK"]
+
+         for topk in topk_nodes:
+             # Check if this looks like MoE routing
+             # Pattern: Linear -> TopK (for top-k expert selection)
+             router_proj = None
+             if topk.inputs:
+                 inp = topk.inputs[0]
+                 # May go through Softmax first
+                 if inp in graph_info.node_by_output:
+                     prev = graph_info.node_by_output[inp]
+                     if prev.op_type == "Softmax":
+                         if prev.inputs:
+                             inp = prev.inputs[0]
+                 if inp in graph_info.node_by_output:
+                     router_node = graph_info.node_by_output[inp]
+                     if router_node.op_type in ("MatMul", "Gemm"):
+                         router_proj = router_node.name
+
+             if router_proj:
+                 # Try to infer number of experts from router output shape
+                 num_experts = 0
+                 k_value = 0
+
+                 # Check TopK k attribute
+                 for attr in (
+                     getattr(topk, "attributes", {}).items() if hasattr(topk, "attributes") else []
+                 ):
+                     if attr[0] == "k":
+                         k_value = attr[1]
+
+                 blocks.append(
+                     Block(
+                         block_type="MoERouter",
+                         name=f"moe_router_{len(blocks)}",
+                         nodes=[router_proj, topk.name],
+                         start_node=router_proj,
+                         end_node=topk.name,
+                         attributes={
+                             "num_experts": num_experts,
+                             "top_k": k_value,
+                             "router_type": "top_k",
+                         },
+                     )
+                 )
+
+         return blocks
+
+     def detect_repeated_blocks(
+         self, graph_info: GraphInfo, existing_blocks: list[Block]
+     ) -> list[Block]:
+         """
+         Detect repeated identical blocks (e.g., N transformer layers).
+
+         Task 5.4.5: Detect repetition - N identical blocks -> collapse with xN.
+         """
+         blocks: list[Block] = []
+
+         # Group existing blocks by type
+         blocks_by_type: dict[str, list[Block]] = {}
+         for block in existing_blocks:
+             if block.block_type not in blocks_by_type:
+                 blocks_by_type[block.block_type] = []
+             blocks_by_type[block.block_type].append(block)
+
+         # Check for repeated patterns
+         for block_type, type_blocks in blocks_by_type.items():
+             if len(type_blocks) >= 4:  # At least 4 repetitions to be significant
+                 # Check if blocks have similar structure
+                 # (simplified: just count them for now)
+                 blocks.append(
+                     Block(
+                         block_type="RepeatedBlock",
+                         name=f"repeated_{block_type}",
+                         nodes=[],  # Meta-block, doesn't own nodes directly
+                         start_node=type_blocks[0].start_node,
+                         end_node=type_blocks[-1].end_node,
+                         attributes={
+                             "repeated_type": block_type,
+                             "num_repetitions": len(type_blocks),
+                             "block_names": [b.name for b in type_blocks],
+                         },
+                     )
+                 )
+
+         return blocks
+
+     def detect_normalization_pattern(self, graph_info: GraphInfo) -> dict:
+         """
+         Detect normalization placement pattern (pre-norm vs post-norm).
+
+         Task 5.4.4: Detect normalization placement.
+
+         Pre-norm (modern, e.g., LLaMA, GPT-3):
+             LayerNorm -> Attention -> Residual Add
+             LayerNorm -> FFN -> Residual Add
+
+         Post-norm (original transformer, e.g., BERT):
+             Attention -> Residual Add -> LayerNorm
+             FFN -> Residual Add -> LayerNorm
+         """
+         result = {
+             "pattern": "unknown",
+             "num_layernorms": 0,
+             "has_rmsnorm": False,
+         }
+
+         # Count normalization ops
+         ln_count = graph_info.op_type_counts.get("LayerNormalization", 0)
+         rms_count = sum(
+             1 for n in graph_info.nodes if "rms" in n.name.lower() or "rmsnorm" in n.name.lower()
+         )
+         result["num_layernorms"] = ln_count
+         result["has_rmsnorm"] = rms_count > 0 or any(
+             n.op_type == "SimplifiedLayerNormalization" for n in graph_info.nodes
+         )
+
+         if ln_count == 0 and not result["has_rmsnorm"]:
+             result["pattern"] = "none"
+             return result
+
+         # Analyze pattern by checking what comes after Add (residual)
+         add_nodes = [n for n in graph_info.nodes if n.op_type == "Add"]
+
+         ln_after_add = 0
+         ln_before_matmul = 0
+
+         for add in add_nodes:
+             if add.outputs:
+                 consumer = self._find_consumer(add.outputs[0], graph_info)
+                 if consumer and consumer.op_type in (
+                     "LayerNormalization",
+                     "SimplifiedLayerNormalization",
+                 ):
+                     ln_after_add += 1
+
+         # Check if LayerNorm feeds into MatMul (pre-norm pattern)
+         for node in graph_info.nodes:
+             if node.op_type in ("LayerNormalization", "SimplifiedLayerNormalization"):
+                 if node.outputs:
+                     consumer = self._find_consumer(node.outputs[0], graph_info)
+                     if consumer and consumer.op_type in ("MatMul", "Gemm"):
+                         ln_before_matmul += 1
+
+         # Classify
+         if ln_after_add > ln_before_matmul:
+             result["pattern"] = "post_norm"
+         elif ln_before_matmul > ln_after_add:
+             result["pattern"] = "pre_norm"
+         elif ln_count > 0 or result["has_rmsnorm"]:
+             result["pattern"] = "mixed"
+
+         return result
+
+     def classify_architecture(self, graph_info: GraphInfo, blocks: list[Block]) -> str:
+         """
+         Classify the overall architecture type.
+
+         Args:
+             graph_info: Parsed graph information.
+             blocks: Detected blocks from group_into_blocks().
+
+         Returns:
+             Architecture type: "transformer", "cnn", "mlp", "hybrid", "unknown"
+         """
+         op_counts = graph_info.op_type_counts
+         block_types = [b.block_type for b in blocks]
+
+         # Count key indicators
+         has_attention = any("Attention" in bt for bt in block_types)
+         has_mlp_block = any("MLPBlock" in bt for bt in block_types)
+         has_layernorm = op_counts.get("LayerNormalization", 0) > 0
+         has_embedding = any("Embedding" in bt for bt in block_types)
+         has_moe = any("MoE" in bt for bt in block_types)
+         has_rope = any(
+             b.block_type == "PositionEncoding" and b.attributes.get("encoding_type") == "RoPE"
+             for b in blocks
+         )
+
+         # Include quantized variants (ConvInteger, MatMulInteger) for INT8 models
+         conv_count = op_counts.get("Conv", 0) + op_counts.get("ConvInteger", 0)
+         matmul_count = (
+             op_counts.get("MatMul", 0)
+             + op_counts.get("Gemm", 0)
+             + op_counts.get("MatMulInteger", 0)
+         )
+         softmax_count = op_counts.get("Softmax", 0)
+
+         # Classification heuristics - more specific
+         if has_moe:
+             return "moe_transformer"
+         elif has_attention or has_mlp_block or (has_layernorm and softmax_count >= 2):
+             if has_rope:
+                 return "decoder_transformer"  # LLaMA-style
+             elif has_embedding:
+                 return "transformer"
+             else:
+                 return "transformer"
+         elif conv_count > matmul_count and conv_count >= 5:
+             return "cnn"
+         elif conv_count > 0 and (has_attention or has_layernorm):
+             return "hybrid"
+         elif matmul_count > 0:
+             return "mlp"
+         else:
+             return "unknown"
+
+     def get_architecture_summary(self, graph_info: GraphInfo, blocks: list[Block]) -> dict:
+         """
+         Get a detailed architecture summary for LLMs.
+
+         Returns comprehensive architecture info for reports.
+         """
+         arch_type = self.classify_architecture(graph_info, blocks)
+         norm_pattern = self.detect_normalization_pattern(graph_info)
+
+         # Count block types
+         block_counts: dict[str, int] = {}
+         for block in blocks:
+             bt = block.block_type
+             block_counts[bt] = block_counts.get(bt, 0) + 1
+
+         # Get attention info
+         attention_blocks = [b for b in blocks if "Attention" in b.block_type]
+         attention_type = "unknown"
+         num_heads = 0
+         num_kv_heads = 0
+
+         if attention_blocks:
+             # Use first attention block's info
+             first_attn = attention_blocks[0]
+             attention_type = first_attn.attributes.get("attention_type", "unknown")
+             num_heads = first_attn.attributes.get("num_q_heads", 0)
+             num_kv_heads = first_attn.attributes.get("num_kv_heads", 0)
+
+         # Get MLP info
+         mlp_blocks = [b for b in blocks if b.block_type == "MLPBlock"]
+         mlp_type = "unknown"
+         if mlp_blocks:
+             mlp_type = mlp_blocks[0].attributes.get("mlp_type", "unknown")
+
+         # Get position encoding
+         pos_blocks = [b for b in blocks if b.block_type == "PositionEncoding"]
+         pos_encoding = "none"
+         if pos_blocks:
+             pos_encoding = pos_blocks[0].attributes.get("encoding_type", "unknown")
+
+         # Get repetition info
+         repeated_blocks = [b for b in blocks if b.block_type == "RepeatedBlock"]
+         num_layers = 0
+         for rb in repeated_blocks:
+             if rb.attributes.get("repeated_type") in ("AttentionHead", "Attention"):
+                 num_layers = rb.attributes.get("num_repetitions", 0)
+                 break
+
+         return {
+             "architecture_type": arch_type,
+             "normalization": norm_pattern,
+             "block_counts": block_counts,
+             "attention": {
+                 "type": attention_type,
+                 "num_q_heads": num_heads,
+                 "num_kv_heads": num_kv_heads,
+             },
+             "mlp": {
+                 "type": mlp_type,
+             },
+             "position_encoding": pos_encoding,
+             "num_layers": num_layers,
+             "total_blocks": len(blocks),
+         }
+
+     def _find_consumer(self, output_name: str, graph_info: GraphInfo) -> NodeInfo | None:
+         """Find the first node that consumes a given output."""
+         for node in graph_info.nodes:
+             if output_name in node.inputs:
+                 return node
+         return None
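
Usage note (not part of the wheel): because patterns.py imports GraphInfo and NodeInfo only under TYPE_CHECKING, PatternAnalyzer is duck-typed at runtime; any objects exposing the attributes it reads (graph.nodes, graph.node_by_output, graph.initializers, graph.input_shapes, graph.op_type_counts, and per-node name/op_type/inputs/outputs) can be passed in. The sketch below exercises it on a toy Conv -> BatchNorm -> Relu graph using hypothetical stand-in classes; in real use the graph would come from haoline's analyzer (haoline/analyzer.py), whose loading API is not shown in this diff.

from dataclasses import dataclass, field

from haoline.patterns import PatternAnalyzer


@dataclass
class FakeNode:  # hypothetical stand-in for haoline.analyzer.NodeInfo
    name: str
    op_type: str
    inputs: list[str]
    outputs: list[str]


@dataclass
class FakeGraph:  # hypothetical stand-in for haoline.analyzer.GraphInfo
    nodes: list[FakeNode]
    node_by_output: dict[str, FakeNode]
    initializers: dict = field(default_factory=dict)
    input_shapes: dict = field(default_factory=dict)
    op_type_counts: dict = field(default_factory=dict)


nodes = [
    FakeNode("conv0", "Conv", ["x", "w0"], ["t0"]),
    FakeNode("bn0", "BatchNormalization", ["t0"], ["t1"]),
    FakeNode("relu0", "Relu", ["t1"], ["t2"]),
]
graph = FakeGraph(
    nodes=nodes,
    node_by_output={out: n for n in nodes for out in n.outputs},
    input_shapes={"x": [1, 3, 224, 224]},
    op_type_counts={"Conv": 1, "BatchNormalization": 1, "Relu": 1},
)

analyzer = PatternAnalyzer()
blocks = analyzer.group_into_blocks(graph)
print([(b.block_type, b.nodes) for b in blocks])      # one ConvBNRelu block
print(analyzer.classify_architecture(graph, blocks))  # "unknown": toy graph is below the conv_count >= 5 CNN threshold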