haoline-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/risks.py
ADDED
@@ -0,0 +1,521 @@
# Copyright (c) 2025 HaoLine Contributors
# SPDX-License-Identifier: MIT

"""
Risk analysis for HaoLine.

Applies heuristics to detect potentially problematic patterns:
- Deep networks without skip connections
- Oversized dense layers
- Dynamic shapes that may cause issues
- Missing normalization
- Unusual activation patterns
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .analyzer import GraphInfo
    from .patterns import Block


@dataclass
class RiskSignal:
    """
    A detected risk or concern about the model architecture.

    Risk signals are informational - they highlight patterns that
    may cause issues but don't necessarily indicate problems.
    """

    id: str  # e.g., "no_skip_connections", "oversized_dense"
    severity: str  # "info" | "warning" | "high"
    description: str
    nodes: list[str] = field(default_factory=list)
    recommendation: str = ""

    def to_dict(self) -> dict:
        return {
            "id": self.id,
            "severity": self.severity,
            "description": self.description,
            "nodes": self.nodes,
            "recommendation": self.recommendation,
        }


@dataclass
class RiskThresholds:
    """
    Configurable thresholds for risk detection.

    Allows tuning sensitivity based on model type and use case.
    """

    # Minimum thresholds - don't bother analyzing tiny models
    min_params_for_analysis: int = 100_000  # 100K params minimum
    min_flops_for_bottleneck: int = 1_000_000_000  # 1B FLOPs before flagging bottlenecks
    min_nodes_for_depth_check: int = 20  # At least 20 nodes before checking depth

    # Thresholds for risk detection
    deep_network_threshold: int = 50  # nodes before considering "deep"
    oversized_dense_threshold: int = 100_000_000  # 100M params in single layer
    large_embedding_threshold: int = 500_000_000  # 500M params for embedding
    high_flop_ratio_threshold: float = 0.5  # Single op using >50% of FLOPs

    # Minimum trainable layers before flagging missing normalization/activations
    min_trainable_for_norm_check: int = 10
    min_trainable_for_activation_check: int = 5


class RiskAnalyzer:
    """
    Detect architectural risk signals in ONNX graphs.

    Applies various heuristics to identify patterns that may
    cause training, inference, or deployment issues.

    Note: Risk signals are only generated for models above minimum
    complexity thresholds to avoid flagging trivial test models.

    Thresholds can be configured via the `thresholds` parameter.
    """

    # Default thresholds (class-level for backward compatibility)
    MIN_PARAMS_FOR_ANALYSIS = 100_000
    MIN_FLOPS_FOR_BOTTLENECK = 1_000_000_000
    MIN_NODES_FOR_DEPTH_CHECK = 20
    DEEP_NETWORK_THRESHOLD = 50
    OVERSIZED_DENSE_THRESHOLD = 100_000_000
    LARGE_EMBEDDING_THRESHOLD = 500_000_000
    HIGH_FLOP_RATIO_THRESHOLD = 0.5

    def __init__(
        self,
        logger: logging.Logger | None = None,
        thresholds: RiskThresholds | None = None,
    ):
        self.logger = logger or logging.getLogger("haoline.risks")
        self.thresholds = thresholds or RiskThresholds()

        # Also update class-level constants for backward compatibility
        if thresholds:
            self.MIN_PARAMS_FOR_ANALYSIS = thresholds.min_params_for_analysis
            self.MIN_FLOPS_FOR_BOTTLENECK = thresholds.min_flops_for_bottleneck
            self.MIN_NODES_FOR_DEPTH_CHECK = thresholds.min_nodes_for_depth_check
            self.DEEP_NETWORK_THRESHOLD = thresholds.deep_network_threshold
            self.OVERSIZED_DENSE_THRESHOLD = thresholds.oversized_dense_threshold
            self.LARGE_EMBEDDING_THRESHOLD = thresholds.large_embedding_threshold
            self.HIGH_FLOP_RATIO_THRESHOLD = thresholds.high_flop_ratio_threshold

    def analyze(self, graph_info: GraphInfo, blocks: list[Block]) -> list[RiskSignal]:
        """
        Run all risk heuristics and return detected signals.

        Args:
            graph_info: Parsed graph information.
            blocks: Detected architectural blocks.

        Returns:
            List of RiskSignal instances.
        """
        signals = []

        # Run all checks
        signal = self.check_deep_without_skips(graph_info, blocks)
        if signal:
            signals.append(signal)

        signal = self.check_oversized_dense(graph_info)
        if signal:
            signals.append(signal)

        signal = self.check_dynamic_shapes(graph_info)
        if signal:
            signals.append(signal)

        signal = self.check_missing_normalization(graph_info, blocks)
        if signal:
            signals.append(signal)

        signal = self.check_compute_bottleneck(graph_info)
        if signal:
            signals.append(signal)

        signal = self.check_large_embedding(graph_info, blocks)
        if signal:
            signals.append(signal)

        signal = self.check_unusual_activations(graph_info)
        if signal:
            signals.append(signal)

        signal = self.check_nonstandard_residuals(graph_info, blocks)
        if signal:
            signals.append(signal)

        self.logger.debug(f"Detected {len(signals)} risk signals")
        return signals

    def check_deep_without_skips(
        self, graph_info: GraphInfo, blocks: list[Block]
    ) -> RiskSignal | None:
        """
        Flag deep networks that lack skip connections.

        Deep networks without residual connections may suffer from
        vanishing gradients during training.
        """
        # Skip very small models - they don't need skip connections
        if graph_info.num_nodes < self.MIN_NODES_FOR_DEPTH_CHECK:
            return None

        if graph_info.num_nodes < self.DEEP_NETWORK_THRESHOLD:
            return None

        # Count residual blocks
        residual_count = sum(1 for b in blocks if "Residual" in b.block_type)

        if residual_count == 0:
            return RiskSignal(
                id="no_skip_connections",
                severity="warning",
                description=(
                    f"Model has {graph_info.num_nodes} nodes but no detected skip connections. "
                    "Deep networks without residual connections may have training difficulties."
                ),
                nodes=[],
                recommendation=(
                    "Consider adding skip/residual connections if this model will be trained. "
                    "If this is a pre-trained inference model, this may not be a concern."
                ),
            )

        return None

    def check_oversized_dense(self, graph_info: GraphInfo) -> RiskSignal | None:
        """
        Flag excessively large fully-connected layers.

        Very large MatMul/Gemm operations can dominate compute and memory.
        """
        large_ops = []

        for node in graph_info.nodes:
            if node.op_type in ("MatMul", "Gemm"):
                # Check weight size
                for inp in node.inputs:
                    if inp in graph_info.initializers:
                        weight = graph_info.initializers[inp]
                        param_count = int(weight.size) if hasattr(weight, "size") else 0
                        if param_count > self.OVERSIZED_DENSE_THRESHOLD:
                            large_ops.append((node.name, param_count))
                        break

        if large_ops:
            total_large = sum(p for _, p in large_ops)
            return RiskSignal(
                id="oversized_dense",
                severity="info",
                description=(
                    f"Found {len(large_ops)} dense layer(s) with >100M parameters "
                    f"(total: {total_large:,} params). These may dominate compute and memory."
                ),
                nodes=[name for name, _ in large_ops],
                recommendation=(
                    "Consider whether these large layers are necessary. "
                    "Techniques like low-rank factorization or pruning may help reduce size."
                ),
            )

        return None

    def check_dynamic_shapes(self, graph_info: GraphInfo) -> RiskSignal | None:
        """
        Flag inputs with dynamic shapes.

        Dynamic shapes can cause issues with some inference backends
        and prevent certain optimizations.
        """
        dynamic_inputs = []

        for name, shape in graph_info.input_shapes.items():
            has_dynamic = any(not isinstance(d, int) for d in shape)
            if has_dynamic:
                dynamic_inputs.append(name)

        if dynamic_inputs:
            return RiskSignal(
                id="dynamic_input_shapes",
                severity="info",
                description=(
                    f"Model has {len(dynamic_inputs)} input(s) with dynamic/symbolic dimensions: "
                    f"{', '.join(dynamic_inputs)}. "
                    "This is normal for variable-length sequences but may affect optimization."
                ),
                nodes=[],
                recommendation=(
                    "For best performance with hardware accelerators, consider providing "
                    "fixed shapes or using onnxruntime.tools.make_dynamic_shape_fixed."
                ),
            )

        return None

    def check_missing_normalization(
        self, graph_info: GraphInfo, blocks: list[Block]
    ) -> RiskSignal | None:
        """
        Flag deep networks without normalization layers.

        Networks without BatchNorm/LayerNorm may have training instabilities.
        """
        # Skip small models
        if graph_info.num_nodes < self.MIN_NODES_FOR_DEPTH_CHECK:
            return None

        norm_ops = {
            "BatchNormalization",
            "LayerNormalization",
            "InstanceNormalization",
            "GroupNormalization",
        }
        has_norm = any(op in graph_info.op_type_counts for op in norm_ops)

        # Count trainable layers (Conv, MatMul, Gemm)
        trainable_count = (
            graph_info.op_type_counts.get("Conv", 0)
            + graph_info.op_type_counts.get("MatMul", 0)
            + graph_info.op_type_counts.get("Gemm", 0)
        )

        # Need at least N trainable layers to care about normalization
        min_trainable = self.thresholds.min_trainable_for_norm_check
        if not has_norm and trainable_count >= min_trainable:
            return RiskSignal(
                id="missing_normalization",
                severity="info",
                description=(
                    f"Model has {trainable_count} trainable layers but no normalization layers detected. "
                    "This may affect training stability."
                ),
                nodes=[],
                recommendation=(
                    "If this model will be fine-tuned, consider adding normalization layers. "
                    "For inference-only, this is typically not a concern."
                ),
            )

        return None

    def check_compute_bottleneck(self, graph_info: GraphInfo) -> RiskSignal | None:
        """
        Flag single operations that dominate compute.

        If one layer uses >50% of FLOPs, it's a potential bottleneck.
        Only flags models with significant compute (>1B FLOPs) to avoid
        noise on trivial models.
        """
        # Need to compute per-node FLOPs
        total_flops = sum(node.flops for node in graph_info.nodes)

        # Skip tiny models - no point optimizing a model with < 1B FLOPs
        if total_flops < self.MIN_FLOPS_FOR_BOTTLENECK:
            return None

        bottlenecks = []
        for node in graph_info.nodes:
            if node.flops > 0:
                ratio = node.flops / total_flops
                if ratio > self.HIGH_FLOP_RATIO_THRESHOLD:
                    bottlenecks.append((node.name, node.op_type, ratio))

        if bottlenecks:
            desc_parts = [f"{name} ({op}: {ratio:.1%})" for name, op, ratio in bottlenecks]
            total_gflops = total_flops / 1e9
            return RiskSignal(
                id="compute_bottleneck",
                severity="info",
                description=(
                    f"The following operations dominate compute ({total_gflops:.1f} GFLOPs total): "
                    f"{', '.join(desc_parts)}. Optimizing these would have the greatest impact."
                ),
                nodes=[name for name, _, _ in bottlenecks],
                recommendation="Focus optimization efforts (quantization, pruning) on these layers.",
            )

        return None

    def check_large_embedding(
        self, graph_info: GraphInfo, blocks: list[Block]
    ) -> RiskSignal | None:
        """
        Flag very large embedding tables.

        Large vocabulary embeddings can dominate model size.
        """
        embedding_blocks = [b for b in blocks if b.block_type == "Embedding"]

        large_embeddings = []
        for block in embedding_blocks:
            vocab_size = block.attributes.get("vocab_size", 0)
            embed_dim = block.attributes.get("embed_dim", 0)
            param_count = vocab_size * embed_dim

            if param_count > self.LARGE_EMBEDDING_THRESHOLD:
                large_embeddings.append((block.name, vocab_size, embed_dim, param_count))

        if large_embeddings:
            details = [
                f"{name}: vocab={v}, dim={d}, params={p:,}" for name, v, d, p in large_embeddings
            ]
            return RiskSignal(
                id="large_embedding",
                severity="info",
                description=(
                    f"Found {len(large_embeddings)} large embedding table(s): {'; '.join(details)}. "
                    "These dominate model size."
                ),
                nodes=[name for name, _, _, _ in large_embeddings],
                recommendation=(
                    "Consider vocabulary pruning, dimensionality reduction, or "
                    "hash embeddings to reduce size."
                ),
            )

        return None

    def check_unusual_activations(self, graph_info: GraphInfo) -> RiskSignal | None:
        """
        Flag unusual activation function patterns.

        Some activation combinations may indicate issues.
        Only checks models with sufficient complexity.
        """
        # Skip small models
        if graph_info.num_nodes < self.MIN_NODES_FOR_DEPTH_CHECK:
            return None

        # Check for deprecated or unusual activations
        unusual_ops = {"Elu", "Selu", "ThresholdedRelu", "Softsign", "Softplus"}
        found_unusual = []

        for op in unusual_ops:
            if op in graph_info.op_type_counts:
                found_unusual.append(f"{op} (x{graph_info.op_type_counts[op]})")

        # Check for missing activations in deep networks
        standard_activations = {
            "Relu",
            "LeakyRelu",
            "Gelu",
            "Silu",
            "Sigmoid",
            "Tanh",
            "Softmax",
        }
        has_standard = any(op in graph_info.op_type_counts for op in standard_activations)

        trainable_count = (
            graph_info.op_type_counts.get("Conv", 0)
            + graph_info.op_type_counts.get("MatMul", 0)
            + graph_info.op_type_counts.get("Gemm", 0)
        )

        # Need at least N trainable layers to care about missing activations
        min_trainable = self.thresholds.min_trainable_for_activation_check
        if not has_standard and trainable_count >= min_trainable:
            return RiskSignal(
                id="no_activations",
                severity="warning",
                description=(
                    f"Model has {trainable_count} linear layers but no standard activation functions. "
                    "This makes the model effectively linear, limiting expressiveness."
                ),
                nodes=[],
                recommendation="Add activation functions between linear layers.",
            )

        if found_unusual:
            return RiskSignal(
                id="unusual_activations",
                severity="info",
                description=(
                    f"Model uses less common activation functions: {', '.join(found_unusual)}. "
                    "These may have limited hardware acceleration support."
                ),
                nodes=[],
                recommendation=(
                    "Consider using more common activations (ReLU, GELU, SiLU) for better "
                    "hardware support, unless these specific activations are required."
                ),
            )

        return None

    def check_nonstandard_residuals(
        self, graph_info: GraphInfo, blocks: list[Block]
    ) -> RiskSignal | None:
        """
        Flag non-standard residual/skip connection patterns.

        Non-standard patterns include:
        - Concat-based skip connections (DenseNet-style)
        - Gated skip connections (Highway networks)
        - Subtraction-based residuals

        These may require special handling for optimization or deployment.
        """
        # Identify non-standard residual blocks
        nonstandard_types = {
            "ResidualConcat": "concat-based (DenseNet-style)",
            "ResidualGate": "gated (Highway/attention gate)",
            "ResidualSub": "subtraction-based",
        }

        found_nonstandard: dict[str, list[str]] = {}
        for block in blocks:
            if block.block_type in nonstandard_types:
                variant = nonstandard_types[block.block_type]
                if variant not in found_nonstandard:
                    found_nonstandard[variant] = []
                found_nonstandard[variant].append(block.name)

        if not found_nonstandard:
            return None

        # Build description
        details = []
        all_nodes = []
        for variant, block_names in found_nonstandard.items():
            details.append(f"{len(block_names)} {variant}")
            all_nodes.extend(block_names)

        total_count = sum(len(names) for names in found_nonstandard.values())

        # Check if model also has standard residuals
        standard_count = sum(1 for b in blocks if b.block_type == "ResidualAdd")
        mixed_msg = ""
        if standard_count > 0:
            mixed_msg = f" Model also has {standard_count} standard Add-based residuals."

        return RiskSignal(
            id="nonstandard_residuals",
            severity="info",
            description=(
                f"Model uses {total_count} non-standard skip connection(s): "
                f"{', '.join(details)}.{mixed_msg} "
                "These patterns may indicate custom architectures requiring special attention."
            ),
            nodes=all_nodes,
            recommendation=(
                "Non-standard skip connections are valid but may need special handling: "
                "Concat-based patterns increase tensor sizes through the network. "
                "Gated patterns add compute overhead but enable selective information flow. "
                "Ensure your deployment target and optimization tools support these patterns."
            ),
        )