haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70):
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/op_icons.py ADDED
@@ -0,0 +1,618 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Op Type Icon System for graph visualization.
6
+
7
+ Task 5.5: Maps 180+ ONNX operators to visual categories with icons,
8
+ colors, and size scaling based on computational intensity.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import math
14
+ from dataclasses import dataclass
15
+ from enum import Enum
16
+ from typing import ClassVar
17
+
18
+
19
class OpCategory(Enum):
    """Closed set of visual buckets that ONNX operators are grouped into.

    Each member's value is a short lowercase slug; the SVG helpers in this
    module use it as a CSS class suffix (``node-<value>``).
    """

    # --- neural-network layers ---
    CONV = "conv"  # convolutions
    LINEAR = "linear"  # MatMul / Gemm / fully-connected
    ATTENTION = "attention"  # attention patterns
    NORM = "norm"  # normalization layers
    ACTIVATION = "activation"  # non-linearities
    POOL = "pool"  # pooling
    DROPOUT = "dropout"  # regularization

    # --- tensor plumbing ---
    RESHAPE = "reshape"  # shape changes
    TRANSPOSE = "transpose"  # axis reordering
    SLICE = "slice"  # indexing / slicing / gathering
    CONCAT = "concat"  # joining tensors
    SPLIT = "split"  # splitting tensors
    PAD = "pad"  # padding

    # --- math ---
    ELEMENTWISE = "elementwise"  # per-element arithmetic
    REDUCE = "reduce"  # reductions
    COMPARE = "compare"  # comparisons and boolean logic

    # --- specialised ---
    EMBED = "embed"  # embedding lookups
    RECURRENT = "recurrent"  # RNN / LSTM / GRU
    QUANTIZE = "quantize"  # (de)quantization
    CAST = "cast"  # dtype conversion
    CONTROL = "control"  # control flow (If, Loop, ...)

    # --- everything else ---
    CONSTANT = "constant"  # constants, identity, shape queries
    UNKNOWN = "unknown"  # fallback for unrecognized ops
54
+
55
+
56
@dataclass
class OpIcon:
    """How one operator category is drawn: shape, glyph, colour and caption."""

    category: OpCategory
    shape: str  # SVG shape name, e.g. "rect", "circle", "diamond", "hexagon"
    symbol: str  # single Unicode glyph used for text rendering
    color: str  # default fill colour as a hex string
    description: str  # human-readable caption (shown in the legend)

    def to_dict(self) -> dict:
        """Return a plain dict with the category flattened to its string value."""
        return dict(
            category=self.category.value,
            shape=self.shape,
            symbol=self.symbol,
            color=self.color,
            description=self.description,
        )
74
+
75
+
76
+ # Icon definitions for each category
77
# Icon definitions for each category, one row per category:
# (category, shape, symbol, colour, description). Row order is the dict order.
CATEGORY_ICONS: dict[OpCategory, OpIcon] = {
    category: OpIcon(
        category, shape=shape, symbol=symbol, color=color, description=description
    )
    for category, shape, symbol, color, description in (
        (OpCategory.CONV, "rect", "▦", "#4A90D9", "Convolution"),  # blue
        (OpCategory.LINEAR, "diamond", "◆", "#9B59B6", "Linear/MatMul"),  # purple
        (OpCategory.ATTENTION, "hexagon", "◎", "#E67E22", "Attention"),  # orange
        (OpCategory.NORM, "rect", "▬", "#7F8C8D", "Normalization"),  # gray
        (OpCategory.ACTIVATION, "circle", "⚡", "#F1C40F", "Activation"),  # yellow
        (OpCategory.POOL, "trapezoid", "▼", "#1ABC9C", "Pooling"),  # teal
        (OpCategory.DROPOUT, "circle", "◌", "#95A5A6", "Dropout"),  # light gray
        (OpCategory.RESHAPE, "parallelogram", "⬔", "#3498DB", "Reshape"),  # light blue
        (OpCategory.TRANSPOSE, "parallelogram", "↔", "#3498DB", "Transpose"),
        (OpCategory.SLICE, "rect", "✂", "#E74C3C", "Slice/Index"),  # red
        (OpCategory.CONCAT, "rect", "⊕", "#2ECC71", "Concat"),  # green
        (OpCategory.SPLIT, "rect", "⊖", "#E74C3C", "Split"),
        (OpCategory.PAD, "rect", "▭", "#BDC3C7", "Padding"),  # silver
        (OpCategory.ELEMENTWISE, "circle", "±", "#9B59B6", "Elementwise"),  # purple
        (OpCategory.REDUCE, "triangle", "Σ", "#E74C3C", "Reduce"),  # red
        (OpCategory.COMPARE, "diamond", "?", "#F39C12", "Compare"),  # dark yellow
        (OpCategory.EMBED, "rect", "📖", "#8E44AD", "Embedding"),  # dark purple
        (OpCategory.RECURRENT, "rect", "↻", "#16A085", "Recurrent"),  # dark teal
        (OpCategory.QUANTIZE, "octagon", "Q", "#27AE60", "Quantization"),  # dark green
        (OpCategory.CAST, "circle", "⇄", "#95A5A6", "Cast"),
        (OpCategory.CONTROL, "diamond", "◇", "#E74C3C", "Control Flow"),
        (OpCategory.CONSTANT, "circle", "•", "#BDC3C7", "Constant"),
        (OpCategory.UNKNOWN, "rect", "?", "#7F8C8D", "Unknown"),
    )
}
240
+
241
+
242
# Mapping of ONNX op types to categories.
# Covers the commonly used default-domain ONNX operators plus a handful of
# widely exported ONNX Runtime contrib ops (e.g. FastGelu, QuickGelu,
# SimplifiedLayerNormalization, MultiHeadAttention). It is not exhaustive:
# any op type not listed here falls back to OpCategory.UNKNOWN in
# get_op_category().
OP_TO_CATEGORY: dict[str, OpCategory] = {
    # Convolution
    "Conv": OpCategory.CONV,
    "ConvInteger": OpCategory.CONV,
    "ConvTranspose": OpCategory.CONV,
    "DeformConv": OpCategory.CONV,
    # Linear/Matrix
    "MatMul": OpCategory.LINEAR,
    "MatMulInteger": OpCategory.LINEAR,
    "Gemm": OpCategory.LINEAR,
    "QLinearMatMul": OpCategory.LINEAR,
    # Attention (fused single-op forms; multi-op patterns are detected elsewhere)
    "Attention": OpCategory.ATTENTION,
    "MultiHeadAttention": OpCategory.ATTENTION,
    # Normalization
    "BatchNormalization": OpCategory.NORM,
    "InstanceNormalization": OpCategory.NORM,
    "LayerNormalization": OpCategory.NORM,
    "GroupNormalization": OpCategory.NORM,
    "LpNormalization": OpCategory.NORM,
    "MeanVarianceNormalization": OpCategory.NORM,
    "SimplifiedLayerNormalization": OpCategory.NORM,
    "SkipLayerNormalization": OpCategory.NORM,
    "RMSNormalization": OpCategory.NORM,
    "LRN": OpCategory.NORM,  # local response normalization
    # Activation functions
    "Relu": OpCategory.ACTIVATION,
    "LeakyRelu": OpCategory.ACTIVATION,
    "PRelu": OpCategory.ACTIVATION,
    "Selu": OpCategory.ACTIVATION,
    "Elu": OpCategory.ACTIVATION,
    "Celu": OpCategory.ACTIVATION,
    "Sigmoid": OpCategory.ACTIVATION,
    "HardSigmoid": OpCategory.ACTIVATION,
    "Tanh": OpCategory.ACTIVATION,
    "Softmax": OpCategory.ACTIVATION,
    "LogSoftmax": OpCategory.ACTIVATION,
    "Hardmax": OpCategory.ACTIVATION,
    "Softplus": OpCategory.ACTIVATION,
    "Softsign": OpCategory.ACTIVATION,
    "HardSwish": OpCategory.ACTIVATION,
    "Mish": OpCategory.ACTIVATION,
    "Gelu": OpCategory.ACTIVATION,
    "FastGelu": OpCategory.ACTIVATION,
    "QuickGelu": OpCategory.ACTIVATION,
    "Silu": OpCategory.ACTIVATION,
    "Swish": OpCategory.ACTIVATION,
    "ThresholdedRelu": OpCategory.ACTIVATION,
    "Shrink": OpCategory.ACTIVATION,
    # Pooling
    "MaxPool": OpCategory.POOL,
    "AveragePool": OpCategory.POOL,
    "GlobalMaxPool": OpCategory.POOL,
    "GlobalAveragePool": OpCategory.POOL,
    "GlobalLpPool": OpCategory.POOL,
    "LpPool": OpCategory.POOL,
    "MaxRoiPool": OpCategory.POOL,
    "RoiAlign": OpCategory.POOL,
    "MaxUnpool": OpCategory.POOL,
    # Dropout/Regularization
    "Dropout": OpCategory.DROPOUT,
    # Reshape operations
    "Reshape": OpCategory.RESHAPE,
    "Flatten": OpCategory.RESHAPE,
    "Squeeze": OpCategory.RESHAPE,
    "Unsqueeze": OpCategory.RESHAPE,
    "Expand": OpCategory.RESHAPE,
    "Tile": OpCategory.RESHAPE,
    "SpaceToDepth": OpCategory.RESHAPE,
    "DepthToSpace": OpCategory.RESHAPE,
    "Resize": OpCategory.RESHAPE,  # spatial resampling changes output shape
    "Upsample": OpCategory.RESHAPE,  # deprecated predecessor of Resize
    # Transpose operations
    "Transpose": OpCategory.TRANSPOSE,
    "Einsum": OpCategory.TRANSPOSE,
    # Slice/Index operations
    "Slice": OpCategory.SLICE,
    "Gather": OpCategory.SLICE,
    "GatherElements": OpCategory.SLICE,
    "GatherND": OpCategory.SLICE,
    "Scatter": OpCategory.SLICE,  # deprecated alias of ScatterElements
    "ScatterElements": OpCategory.SLICE,
    "ScatterND": OpCategory.SLICE,
    "Compress": OpCategory.SLICE,
    "TopK": OpCategory.SLICE,
    "NonZero": OpCategory.SLICE,
    "NonMaxSuppression": OpCategory.SLICE,
    # Concat/Join operations
    "Concat": OpCategory.CONCAT,
    "ConcatFromSequence": OpCategory.CONCAT,
    # Split operations
    "Split": OpCategory.SPLIT,
    "SplitToSequence": OpCategory.SPLIT,
    "Chunk": OpCategory.SPLIT,
    # Padding
    "Pad": OpCategory.PAD,
    "ConstantOfShape": OpCategory.PAD,
    # Elementwise math
    "Add": OpCategory.ELEMENTWISE,
    "Sub": OpCategory.ELEMENTWISE,
    "Mul": OpCategory.ELEMENTWISE,
    "Div": OpCategory.ELEMENTWISE,
    "Pow": OpCategory.ELEMENTWISE,
    "Sqrt": OpCategory.ELEMENTWISE,
    "Reciprocal": OpCategory.ELEMENTWISE,
    "Exp": OpCategory.ELEMENTWISE,
    "Log": OpCategory.ELEMENTWISE,
    "Erf": OpCategory.ELEMENTWISE,  # frequently appears in decomposed GELU
    "Abs": OpCategory.ELEMENTWISE,
    "Neg": OpCategory.ELEMENTWISE,
    "Sign": OpCategory.ELEMENTWISE,
    "Ceil": OpCategory.ELEMENTWISE,
    "Floor": OpCategory.ELEMENTWISE,
    "Round": OpCategory.ELEMENTWISE,
    "Clip": OpCategory.ELEMENTWISE,
    "Min": OpCategory.ELEMENTWISE,
    "Max": OpCategory.ELEMENTWISE,
    "Mean": OpCategory.ELEMENTWISE,
    "Sum": OpCategory.ELEMENTWISE,
    "Mod": OpCategory.ELEMENTWISE,
    "BitShift": OpCategory.ELEMENTWISE,
    "BitwiseAnd": OpCategory.ELEMENTWISE,
    "BitwiseNot": OpCategory.ELEMENTWISE,
    "BitwiseOr": OpCategory.ELEMENTWISE,
    "BitwiseXor": OpCategory.ELEMENTWISE,
    # Trigonometric
    "Sin": OpCategory.ELEMENTWISE,
    "Cos": OpCategory.ELEMENTWISE,
    "Tan": OpCategory.ELEMENTWISE,
    "Asin": OpCategory.ELEMENTWISE,
    "Acos": OpCategory.ELEMENTWISE,
    "Atan": OpCategory.ELEMENTWISE,
    "Sinh": OpCategory.ELEMENTWISE,
    "Cosh": OpCategory.ELEMENTWISE,  # Also activation
    "Asinh": OpCategory.ELEMENTWISE,
    "Acosh": OpCategory.ELEMENTWISE,
    "Atanh": OpCategory.ELEMENTWISE,
    # Reduction operations
    "ReduceSum": OpCategory.REDUCE,
    "ReduceMean": OpCategory.REDUCE,
    "ReduceMax": OpCategory.REDUCE,
    "ReduceMin": OpCategory.REDUCE,
    "ReduceProd": OpCategory.REDUCE,
    "ReduceL1": OpCategory.REDUCE,
    "ReduceL2": OpCategory.REDUCE,
    "ReduceLogSum": OpCategory.REDUCE,
    "ReduceLogSumExp": OpCategory.REDUCE,
    "ReduceSumSquare": OpCategory.REDUCE,
    "ArgMax": OpCategory.REDUCE,
    "ArgMin": OpCategory.REDUCE,
    "CumSum": OpCategory.REDUCE,  # running (prefix) sum along an axis
    # Comparison/Logic
    "Equal": OpCategory.COMPARE,
    "Greater": OpCategory.COMPARE,
    "GreaterOrEqual": OpCategory.COMPARE,
    "Less": OpCategory.COMPARE,
    "LessOrEqual": OpCategory.COMPARE,
    "And": OpCategory.COMPARE,
    "Or": OpCategory.COMPARE,
    "Xor": OpCategory.COMPARE,
    "Not": OpCategory.COMPARE,
    "Where": OpCategory.COMPARE,
    "IsNaN": OpCategory.COMPARE,
    "IsInf": OpCategory.COMPARE,
    # Embedding
    "Embedding": OpCategory.EMBED,
    # Note: Gather on embedding tables detected separately
    # Recurrent
    "RNN": OpCategory.RECURRENT,
    "LSTM": OpCategory.RECURRENT,
    "GRU": OpCategory.RECURRENT,
    # Quantization
    "QuantizeLinear": OpCategory.QUANTIZE,
    "DequantizeLinear": OpCategory.QUANTIZE,
    "DynamicQuantizeLinear": OpCategory.QUANTIZE,
    "QLinearConv": OpCategory.QUANTIZE,
    # Cast/Type conversion
    "Cast": OpCategory.CAST,
    "CastLike": OpCategory.CAST,
    # Control flow
    "If": OpCategory.CONTROL,
    "Loop": OpCategory.CONTROL,
    "Scan": OpCategory.CONTROL,
    "SequenceAt": OpCategory.CONTROL,
    "SequenceConstruct": OpCategory.CONTROL,
    "SequenceEmpty": OpCategory.CONTROL,
    "SequenceErase": OpCategory.CONTROL,
    "SequenceInsert": OpCategory.CONTROL,
    "SequenceLength": OpCategory.CONTROL,
    "SequenceMap": OpCategory.CONTROL,
    # Constants
    "Constant": OpCategory.CONSTANT,
    "Identity": OpCategory.CONSTANT,
    "Shape": OpCategory.CONSTANT,
    "Size": OpCategory.CONSTANT,
    "Range": OpCategory.CONSTANT,
    "EyeLike": OpCategory.CONSTANT,
    "RandomNormal": OpCategory.CONSTANT,
    "RandomNormalLike": OpCategory.CONSTANT,
    "RandomUniform": OpCategory.CONSTANT,
    "RandomUniformLike": OpCategory.CONSTANT,
    "Bernoulli": OpCategory.CONSTANT,
    "Multinomial": OpCategory.CONSTANT,
    "OneHot": OpCategory.CONSTANT,
}
434
+
435
+
436
def get_op_category(op_type: str) -> OpCategory:
    """Look up the visual category of an ONNX op type.

    Op types that are not present in ``OP_TO_CATEGORY`` map to
    ``OpCategory.UNKNOWN`` instead of raising.
    """
    try:
        return OP_TO_CATEGORY[op_type]
    except KeyError:
        return OpCategory.UNKNOWN
439
+
440
+
441
def get_op_icon(op_type: str) -> OpIcon:
    """Return the ``OpIcon`` used to draw nodes of the given ONNX op type."""
    return CATEGORY_ICONS[get_op_category(op_type)]
445
+
446
+
447
def get_all_categories() -> list[OpIcon]:
    """Return a new list holding every category's icon, in ``CATEGORY_ICONS`` order."""
    return [*CATEGORY_ICONS.values()]
450
+
451
+
452
+ # Size scaling based on FLOPs (log scale)
453
def compute_node_size(flops: int, min_size: float = 20, max_size: float = 80) -> float:
    """
    Compute visual node size based on FLOPs.

    Task 5.5.3: Size scaling function.

    The size is interpolated linearly in log10 space. Non-positive FLOPs and
    1 FLOP map to ``min_size``; 10**12 FLOPs (one trillion) and above saturate
    at ``max_size``.
    """
    if flops <= 0:
        return min_size

    # Fraction of the way from 10**0 to 10**12 FLOPs, capped at 1.0.
    # max(..., 1) keeps fractional inputs in (0, 1) from going negative.
    fraction = min(math.log10(max(flops, 1)) / 12, 1.0)
    return min_size + (max_size - min_size) * fraction
471
+
472
+
473
+ # Color intensity based on compute/memory
474
@dataclass
class ColorMapping:
    """Color mapping configuration for nodes (by precision and by memory intensity)."""

    # Precision-based colors; lookups lowercase the key first.
    PRECISION_COLORS: ClassVar[dict[str, str]] = {
        "fp32": "#4A90D9",  # Blue
        "fp16": "#2ECC71",  # Green
        "bf16": "#9B59B6",  # Purple
        "int8": "#F1C40F",  # Yellow
        "int4": "#E67E22",  # Orange
        "uint8": "#F39C12",  # Dark yellow
    }

    # Memory intensity gradient (low to high)
    MEMORY_GRADIENT: ClassVar[list[str]] = [
        "#2ECC71",  # Green (low)
        "#F1C40F",  # Yellow (medium)
        "#E67E22",  # Orange (high)
        "#E74C3C",  # Red (very high)
    ]

    @staticmethod
    def get_precision_color(precision: str) -> str:
        """Return the color for *precision* (case-insensitive), or gray "#7F8C8D" if unknown."""
        return ColorMapping.PRECISION_COLORS.get(precision.lower(), "#7F8C8D")

    @staticmethod
    def get_memory_color(memory_bytes: int, max_bytes: int) -> str:
        """Get color based on memory usage intensity.

        Args:
            memory_bytes: Memory attributed to the node.
            max_bytes: Reference maximum; a non-positive value yields the
                lowest (green) color.

        Returns:
            A hex color from ``MEMORY_GRADIENT``. The ratio
            ``memory_bytes / max_bytes`` is clamped to [0, 1] and scaled onto
            the gradient indices, so the last (red) color is only reached when
            ``memory_bytes >= max_bytes``.
        """
        gradient = ColorMapping.MEMORY_GRADIENT
        if max_bytes <= 0:
            return gradient[0]

        # Clamp the lower end too: a negative ratio would produce a negative
        # index, which Python resolves from the end of the list (red) or
        # rejects with IndexError.
        ratio = min(max(memory_bytes / max_bytes, 0.0), 1.0)
        return gradient[int(ratio * (len(gradient) - 1))]
510
+
511
+
512
+ # SVG icon templates
513
# SVG element templates keyed by shape name, filled in with str.format().
# Every template takes {color} and {stroke}; geometry placeholders vary by shape.
SVG_ICONS: dict[str, str] = dict(
    # top-left {x},{y}, size {w} x {h}, rounded corners
    rect='<rect x="{x}" y="{y}" width="{w}" height="{h}" rx="4" fill="{color}" stroke="{stroke}" stroke-width="1"/>',
    # centre {cx},{cy}, radius {r}
    circle='<circle cx="{cx}" cy="{cy}" r="{r}" fill="{color}" stroke="{stroke}" stroke-width="1"/>',
    # vertices at the midpoints of the box's top, right, bottom and left edges
    diamond='<polygon points="{cx},{y} {x2},{cy} {cx},{y2} {x},{cy}" fill="{color}" stroke="{stroke}" stroke-width="1"/>',
    # six pre-formatted "x,y" vertex strings
    hexagon='<polygon points="{p1} {p2} {p3} {p4} {p5} {p6}" fill="{color}" stroke="{stroke}" stroke-width="1"/>',
    # apex at top-centre, base along the bottom edge
    triangle='<polygon points="{cx},{y} {x2},{y2} {x},{y2}" fill="{color}" stroke="{stroke}" stroke-width="1"/>',
)
520
+
521
+
522
def _shape_svg(icon: OpIcon, x: float, y: float, node_size: float) -> str:
    """Render the bare SVG shape for *icon* inside the square box at (x, y) of side *node_size*.

    Shapes without a template in ``SVG_ICONS`` (e.g. "trapezoid",
    "parallelogram", "octagon") fall back to the rounded rectangle.
    """
    half = node_size / 2
    cx, cy = x + half, y + half
    # str.format ignores unused keyword arguments, so one geometry dict can
    # fill whichever template is selected.
    params: dict[str, object] = {
        "x": x,
        "y": y,
        "x2": x + node_size,
        "y2": y + node_size,
        "cx": cx,
        "cy": cy,
        "r": half,
        "w": node_size,
        "h": node_size,
        "color": icon.color,
        "stroke": "#333",
    }
    if icon.shape == "hexagon":
        # Regular hexagon inscribed in the box: one vertex every 60 degrees.
        for k in range(6):
            angle = math.pi / 3 * k
            params[f"p{k + 1}"] = (
                f"{cx + half * math.cos(angle):.2f},{cy + half * math.sin(angle):.2f}"
            )
    template = SVG_ICONS.get(icon.shape, SVG_ICONS["rect"])
    return template.format(**params)


def generate_svg_node(
    op_type: str,
    x: float,
    y: float,
    size: float,
    label: str | None = None,
    flops: int = 0,
) -> str:
    """
    Generate SVG markup for a node.

    Task 5.5.5: Create SVG icon for HTML embedding.

    Args:
        op_type: ONNX operator type; selects the icon via ``get_op_icon``.
        x: Left edge of the node's bounding box.
        y: Top edge of the node's bounding box.
        size: Side length of the box, used when *flops* is not positive.
        label: When truthy, the category symbol is drawn centred on the node.
            NOTE(review): the label text itself is not rendered; confirm
            whether that is intended.
        flops: When > 0, the side length is ``compute_node_size(flops)`` and
            *size* is ignored.

    Returns:
        A ``<g class="node node-<category>" data-op="...">`` SVG group string.
    """
    from html import escape

    icon = get_op_icon(op_type)
    node_size = compute_node_size(flops) if flops > 0 else size
    half = node_size / 2

    body = _shape_svg(icon, x, y, node_size)
    if label:
        body += (
            f'<text x="{x + half}" y="{y + half + 4}" text-anchor="middle" '
            f'font-size="10" fill="white">{icon.symbol}</text>'
        )

    # op_type presumably originates from the loaded model file, i.e. untrusted
    # input: escape it so a crafted op name cannot break out of the attribute
    # and inject markup into HTML reports.
    safe_op = escape(op_type, quote=True)
    return f'<g class="node node-{icon.category.value}" data-op="{safe_op}">{body}</g>'


def generate_legend_svg(width: int = 400, height: int = 300) -> str:
    """
    Generate SVG legend showing all categories.

    Task 5.5.6: Add legend/key to visualization.

    Every category except UNKNOWN is listed in two columns, one 30px row per
    pair, each with a 20px swatch drawn in the category's own shape and
    color. *height* acts as a minimum: it is enlarged when the rows would
    otherwise extend past the bottom of the viewBox and be clipped.
    """
    entries = [
        (category, icon)
        for category, icon in CATEGORY_ICONS.items()
        if category != OpCategory.UNKNOWN
    ]
    row_count = (len(entries) + 1) // 2
    # 45px header, then 30px per row (20px swatch plus spacing).
    height = max(height, 45 + row_count * 30)
    col_width = width // 2

    lines = [f'<svg viewBox="0 0 {width} {height}" xmlns="http://www.w3.org/2000/svg">']
    lines.append('<rect width="100%" height="100%" fill="#1a1a2e"/>')
    lines.append(
        '<text x="10" y="25" font-size="14" fill="white" font-weight="bold">Op Type Legend</text>'
    )

    for index, (category, icon) in enumerate(entries):
        row, col = divmod(index, 2)
        x = 10 + col * col_width
        y = 45 + row * 30
        # Draw from the icon directly: routing category.value (e.g. "conv")
        # through get_op_icon() misses OP_TO_CATEGORY and would render every
        # swatch as the grey UNKNOWN rectangle.
        swatch = _shape_svg(icon, x, y, 20)
        lines.append(
            f'<g class="node node-{category.value}" data-op="{category.value}">{swatch}</g>'
        )
        lines.append(
            f'<text x="{x + 30}" y="{y + 15}" font-size="11" fill="#ccc">{icon.description}</text>'
        )

    lines.append("</svg>")
    return "\n".join(lines)