repomap-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
repomap/ranking.py ADDED
@@ -0,0 +1,639 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Repo Map Ranking — PageRank and Analysis Layer
4
+ ================================================
5
+ 负责符号排名、调用链查询、文件分析、AI 摘要生成。
6
+
7
+ 提供:
8
+ - PageRank 重要性计算
9
+ - 调用链追踪(callers/callees)
10
+ - 热点文件识别
11
+ - 模块摘要和推荐阅读顺序
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ from collections import defaultdict
18
+ from collections import deque
19
+ from pathlib import PurePosixPath
20
+ from typing import Any, TYPE_CHECKING
21
+
22
+ from . import Edge, RepoGraph, Symbol, call_reference_parts
23
+ from .topic import is_test_like_file
24
+
25
+ if TYPE_CHECKING:
26
+ from .core import RepoMapEngine
27
+
28
+ logger = logging.getLogger("repomap")
29
+
30
+
31
class GraphAnalyzer:
    """Graph analyzer: runs PageRank and assorted graph queries."""

    # Symbol kinds that carry almost no semantic signal (markup/config noise);
    # they are heavily discounted by _signal_weight.
    LOW_SIGNAL_KINDS = {"element", "selector", "class_selector", "id_selector", "json_key"}
    # Names that are boilerplate rather than meaningful API surface.
    BOILERPLATE_NAMES = {"__init__", "__main__"}

    def __init__(self, graph: RepoGraph) -> None:
        # The dependency graph this analyzer reads and annotates in place
        # (calculate_pagerank writes each symbol's .pagerank attribute).
        self.graph = graph
41
+
42
+ def calculate_pagerank(self, damping: float = 0.85, max_iter: int = 50,
43
+ tol: float = 1e-6) -> None:
44
+ """带收敛检测的 PageRank。"""
45
+ syms = list(self.graph.symbols)
46
+ n = len(syms)
47
+ if n == 0:
48
+ return
49
+
50
+ pr = {s: 1.0 / n for s in syms}
51
+ # 预算 outgoing 权重和,过滤出权重>0的节点
52
+ out_w: dict[str, float] = {
53
+ s: sum(e.weight for e in self.graph.outgoing.get(s, []))
54
+ for s in syms
55
+ }
56
+ # 只保留有出边的节点,避免除零
57
+ active_srcs = {s for s, w in out_w.items() if w > 0}
58
+ # incoming: tgt -> [(src, weight)],只包含有出边的源节点
59
+ inc: dict[str, list[tuple[str, float]]] = defaultdict(list)
60
+ for src, edges in self.graph.outgoing.items():
61
+ if src in active_srcs:
62
+ for e in edges:
63
+ inc[e.target].append((src, e.weight))
64
+
65
+ base = (1 - damping) / n
66
+ for _ in range(max_iter):
67
+ new_pr: dict[str, float] = {}
68
+ for s in syms:
69
+ score = base + sum(
70
+ damping * pr[src] * w / out_w[src]
71
+ for src, w in inc[s]
72
+ )
73
+ new_pr[s] = score
74
+ total = sum(new_pr.values()) or 1.0
75
+ for s in syms:
76
+ new_pr[s] /= total
77
+ # 收敛检测
78
+ delta = max(abs(new_pr[s] - pr[s]) for s in syms)
79
+ pr = new_pr
80
+ if delta < tol:
81
+ break
82
+
83
+ for s, score in pr.items():
84
+ self.graph.symbols[s].pagerank = score
85
+
86
+ def query_symbol(self, name: str, query_context: str | None = None) -> list[Symbol]:
87
+ """按名称模糊查找符号,按 PageRank 降序返回。
88
+
89
+ query_context: 可选的关键词上下文,用于为匹配到的符号做局部 boost。
90
+ 当提供时,符号所在文件名命中关键词的会获得额外排序加分。
91
+ """
92
+ nl = name.lower()
93
+ candidates = [s for s in self.graph.symbols.values() if nl in s.name.lower()]
94
+ if not query_context or len(candidates) <= 1:
95
+ return sorted(candidates, key=lambda s: s.pagerank, reverse=True)
96
+
97
+ # 局部 boost:文件名或路径命中 query_context 关键词的符号获得加分
98
+ context_keywords = query_context.lower().split()
99
+ return sorted(
100
+ candidates,
101
+ key=lambda s: (
102
+ s.pagerank
103
+ + sum(
104
+ 0.05
105
+ for kw in context_keywords
106
+ if kw in s.file.lower()
107
+ )
108
+ ),
109
+ reverse=True,
110
+ )
111
+
112
+ def call_chain(self, symbol_id: str, direction: str = "both",
113
+ max_depth: int = 3) -> dict[str, list[Symbol]]:
114
+ """
115
+ 返回指定符号的调用链。
116
+ direction: "callers" | "callees" | "both"
117
+ """
118
+ result: dict[str, list[Symbol]] = {"callers": [], "callees": []}
119
+ if direction in ("callers", "both"):
120
+ result["callers"] = self._bfs(symbol_id, "incoming", max_depth, {"call"})
121
+ if direction in ("callees", "both"):
122
+ result["callees"] = self._bfs(symbol_id, "outgoing", max_depth, {"call"})
123
+ return result
124
+
125
+ def _bfs(
126
+ self,
127
+ start: str,
128
+ direction: str,
129
+ max_depth: int,
130
+ allowed_kinds: set[str] | None = None,
131
+ ) -> list[Symbol]:
132
+ """用 deque 实现 BFS,O(n) 复杂度。"""
133
+ visited = {start}
134
+ queue: deque[tuple[str, int]] = deque([(start, 0)])
135
+ result: list[Symbol] = []
136
+ edges_map = self.graph.incoming if direction == "incoming" else self.graph.outgoing
137
+
138
+ # 防止内存溢出:队列大小限制
139
+ MAX_QUEUE_SIZE = 10000
140
+ MAX_RESULTS = 1000
141
+
142
+ while queue:
143
+ # 队列大小安全检查
144
+ if len(queue) > MAX_QUEUE_SIZE:
145
+ logger.warning(f"BFS queue size exceeded limit ({MAX_QUEUE_SIZE}), truncating search")
146
+ break
147
+
148
+ cur, depth = queue.popleft()
149
+ if cur != start:
150
+ sym = self.graph.symbols.get(cur)
151
+ if sym:
152
+ result.append(sym)
153
+ # 结果数量限制,防止内存溢出
154
+ if len(result) >= MAX_RESULTS:
155
+ logger.debug(f"BFS reached max results ({MAX_RESULTS}), truncating")
156
+ break
157
+ if depth < max_depth:
158
+ for e in edges_map.get(cur, []):
159
+ if allowed_kinds is not None and e.kind not in allowed_kinds:
160
+ continue
161
+ nxt = e.source if direction == "incoming" else e.target
162
+ if nxt not in visited:
163
+ visited.add(nxt)
164
+ queue.append((nxt, depth + 1))
165
+ return result
166
+
167
+ def _edge_count(
168
+ self,
169
+ symbol_id: str,
170
+ direction: str,
171
+ allowed_kinds: set[str] | None = None,
172
+ ) -> int:
173
+ edges_map = self.graph.incoming if direction == "incoming" else self.graph.outgoing
174
+ return sum(
175
+ 1
176
+ for edge in edges_map.get(symbol_id, [])
177
+ if allowed_kinds is None or edge.kind in allowed_kinds
178
+ )
179
+
180
+ def _signal_weight(self, symbol: Symbol) -> float:
181
+ weight = 1.0
182
+ if symbol.kind in self.LOW_SIGNAL_KINDS:
183
+ return 0.002
184
+ if symbol.name in self.BOILERPLATE_NAMES:
185
+ weight *= 0.35
186
+ elif symbol.name.startswith("_") and symbol.visibility == "private":
187
+ weight *= 0.85
188
+ return weight
189
+
190
+ def _summary_symbol_score(self, symbol: Symbol) -> float:
191
+ incoming_calls = self._edge_count(symbol.id, "incoming", {"call"})
192
+ outgoing_calls = self._edge_count(symbol.id, "outgoing", {"call"})
193
+ incoming_imports = self._edge_count(symbol.id, "incoming", {"import"})
194
+ visibility_bonus = 1.2 if symbol.visibility == "exported" else 0.45 if symbol.visibility == "public" else 0.0
195
+ kind_bonus = 1.0 if symbol.kind == "class" else 0.55 if symbol.kind in {"function", "method"} else 0.2
196
+ import_bonus = min(incoming_imports, 4) * 0.15
197
+ centrality_bonus = symbol.pagerank * 40
198
+ return (
199
+ incoming_calls * 4.0
200
+ + outgoing_calls * 1.5
201
+ + import_bonus
202
+ + visibility_bonus
203
+ + kind_bonus
204
+ + centrality_bonus
205
+ ) * self._signal_weight(symbol)
206
+
207
+ def hotspots(self, limit: int = 15) -> list[dict]:
208
+ """识别高密度文件,优先看高语义密度而不是标签/配置噪音。"""
209
+ analysis = self.file_analysis()
210
+ counts = sorted(
211
+ analysis.values(),
212
+ key=lambda row: (row["is_test_file"], -row["semantic_symbol_count"], -row["score"], row["file"]),
213
+ )
214
+ return [
215
+ {
216
+ "file": row["file"],
217
+ "symbol_count": row["symbol_count"],
218
+ "semantic_symbol_count": round(row["semantic_symbol_count"], 1),
219
+ "risk": (
220
+ "high"
221
+ if row["semantic_symbol_count"] >= 12
222
+ else "medium"
223
+ if row["semantic_symbol_count"] >= 4
224
+ else "low"
225
+ ),
226
+ }
227
+ for row in counts[:limit]
228
+ ]
229
+
230
+ def entry_points(self) -> list[str]:
231
+ """识别常见的入口文件。支持子目录路径匹配。"""
232
+ candidates = [
233
+ "main.py", "app.py", "manage.py", "run.py", "server.py",
234
+ "main.go", "cmd/main.go",
235
+ "src/main.rs", "src/lib.rs",
236
+ "src/main.ts", "src/index.ts", "src/main.tsx", "src/index.tsx",
237
+ "src/main.js", "src/index.js",
238
+ "index.ts", "index.js",
239
+ # 支持 monorepo 子目录结构
240
+ "*/src/main.tsx", "*/src/main.ts", "*/src/index.tsx", "*/src/index.ts",
241
+ "*/src/main.js", "*/src/index.js",
242
+ "*/main.rs", "*/lib.rs",
243
+ ]
244
+ result = []
245
+ for c in candidates:
246
+ if "*" in c:
247
+ # 通配符匹配
248
+ pattern = c.replace("*/", "")
249
+ for f in self.graph.file_symbols:
250
+ if f.endswith(pattern):
251
+ result.append(f)
252
+ elif c in self.graph.file_symbols:
253
+ result.append(c)
254
+ return sorted(set(result))
255
+
256
+ def file_analysis(self) -> dict[str, dict[str, Any]]:
257
+ """分析每个文件的复杂度和连接性。"""
258
+ analysis: dict[str, dict[str, Any]] = {}
259
+
260
+ # 初始化文件分析数据
261
+ for file_path, symbol_ids in self.graph.file_symbols.items():
262
+ symbols = [
263
+ self.graph.symbols[symbol_id]
264
+ for symbol_id in symbol_ids
265
+ if symbol_id in self.graph.symbols
266
+ ]
267
+ ranked_symbols = sorted(
268
+ symbols,
269
+ key=lambda item: (-self._summary_symbol_score(item), item.line, item.name),
270
+ )
271
+ semantic_symbol_count = sum(self._signal_weight(symbol) for symbol in symbols)
272
+ semantic_pagerank_sum = sum(symbol.pagerank * self._signal_weight(symbol) for symbol in symbols)
273
+ weighted_exported_count = sum(
274
+ self._signal_weight(symbol)
275
+ for symbol in symbols
276
+ if symbol.visibility == "exported"
277
+ )
278
+ weighted_public_count = sum(
279
+ self._signal_weight(symbol)
280
+ for symbol in symbols
281
+ if symbol.visibility == "public"
282
+ )
283
+ analysis[file_path] = {
284
+ "file": file_path,
285
+ "symbol_count": len(symbols),
286
+ "semantic_symbol_count": semantic_symbol_count,
287
+ "pagerank_sum": sum(symbol.pagerank for symbol in symbols),
288
+ "semantic_pagerank_sum": semantic_pagerank_sum,
289
+ "implementation_score": sum(self._summary_symbol_score(symbol) for symbol in ranked_symbols[:5]),
290
+ "exported_count": weighted_exported_count,
291
+ "public_count": weighted_public_count,
292
+ "is_test_file": self._is_test_like_file(file_path),
293
+ "call_edges": 0,
294
+ "cross_file_call_edges": 0,
295
+ "import_edges": 0,
296
+ "neighbor_files": set(),
297
+ "top_symbols": [
298
+ symbol.name
299
+ for symbol in ranked_symbols[:3]
300
+ ],
301
+ }
302
+
303
+ # 统计边关系
304
+ for source_id, edge_list in self.graph.outgoing.items():
305
+ source_symbol = self.graph.symbols.get(source_id)
306
+ if not source_symbol:
307
+ continue
308
+ source_file = source_symbol.file
309
+ source_entry = analysis.setdefault(
310
+ source_file,
311
+ {
312
+ "file": source_file,
313
+ "symbol_count": 0,
314
+ "semantic_symbol_count": 0.0,
315
+ "pagerank_sum": 0.0,
316
+ "semantic_pagerank_sum": 0.0,
317
+ "implementation_score": 0.0,
318
+ "exported_count": 0,
319
+ "public_count": 0,
320
+ "is_test_file": self._is_test_like_file(source_file),
321
+ "call_edges": 0,
322
+ "cross_file_call_edges": 0,
323
+ "import_edges": 0,
324
+ "neighbor_files": set(),
325
+ "top_symbols": [],
326
+ },
327
+ )
328
+ for edge in edge_list:
329
+ target_symbol = self.graph.symbols.get(edge.target)
330
+ if not target_symbol:
331
+ continue
332
+ target_file = target_symbol.file
333
+ if edge.kind == "call":
334
+ source_entry["call_edges"] += 1
335
+ if source_file != target_file:
336
+ source_entry["neighbor_files"].add(target_file)
337
+ analysis.setdefault(
338
+ target_file,
339
+ {
340
+ "file": target_file,
341
+ "symbol_count": 0,
342
+ "semantic_symbol_count": 0.0,
343
+ "pagerank_sum": 0.0,
344
+ "semantic_pagerank_sum": 0.0,
345
+ "implementation_score": 0.0,
346
+ "exported_count": 0,
347
+ "public_count": 0,
348
+ "is_test_file": self._is_test_like_file(target_file),
349
+ "call_edges": 0,
350
+ "cross_file_call_edges": 0,
351
+ "import_edges": 0,
352
+ "neighbor_files": set(),
353
+ "top_symbols": [],
354
+ },
355
+ )["neighbor_files"].add(source_file)
356
+ if edge.kind == "call":
357
+ source_entry["cross_file_call_edges"] += 1
358
+ if edge.kind == "import":
359
+ source_entry["import_edges"] += 1
360
+
361
+ # 计算综合得分
362
+ for data in analysis.values():
363
+ neighbor_count = len(data["neighbor_files"])
364
+ data["neighbor_count"] = neighbor_count
365
+ data["score"] = (
366
+ data["implementation_score"]
367
+ + data["exported_count"] * 0.8
368
+ + data["public_count"] * 0.25
369
+ + data["semantic_symbol_count"] * 0.6
370
+ + neighbor_count * 0.45
371
+ + data["cross_file_call_edges"] * 0.25
372
+ + data["call_edges"] * 0.05
373
+ )
374
+ if data["is_test_file"]:
375
+ data["score"] *= 0.55
376
+
377
+ return analysis
378
+
379
+ def module_summary(self, limit: int = 8) -> list[dict[str, Any]]:
380
+ """生成模块级别的摘要。"""
381
+ modules: dict[str, list[dict[str, Any]]] = defaultdict(list)
382
+ analysis = self.file_analysis()
383
+ for file_path, file_data in analysis.items():
384
+ modules[self._module_bucket_for_file(file_path)].append(file_data)
385
+
386
+ rows: list[dict[str, Any]] = []
387
+ for module_name, file_rows in modules.items():
388
+ ordered_files = sorted(file_rows, key=lambda row: (-row["score"], row["file"]))
389
+ representative = ordered_files[0] if ordered_files else None
390
+ rows.append(
391
+ {
392
+ "module": module_name,
393
+ "file_count": len(file_rows),
394
+ "symbol_count": sum(row["symbol_count"] for row in file_rows),
395
+ "semantic_symbol_count": round(sum(row["semantic_symbol_count"] for row in file_rows), 1),
396
+ "pagerank_sum": sum(row["pagerank_sum"] for row in file_rows),
397
+ "semantic_pagerank_sum": sum(row["semantic_pagerank_sum"] for row in file_rows),
398
+ "representative_file": representative["file"] if representative else "",
399
+ "highlights": representative["top_symbols"][:3] if representative else [],
400
+ }
401
+ )
402
+ rows.sort(key=lambda row: (-row["semantic_pagerank_sum"], -row["semantic_symbol_count"], row["module"]))
403
+ return rows[:limit]
404
+
405
+ def suggested_reading_order(self, limit: int = 8) -> list[dict[str, Any]]:
406
+ """为 AI 生成推荐阅读顺序。"""
407
+ analysis = self.file_analysis()
408
+ suggestions: list[dict[str, Any]] = []
409
+ seen_files: set[str] = set()
410
+
411
+ # 首先推荐入口点
412
+ for entry in self.entry_points():
413
+ if entry not in analysis or entry in seen_files:
414
+ continue
415
+ file_data = analysis[entry]
416
+ suggestions.append(
417
+ {
418
+ "file": entry,
419
+ "reason": "入口点,适合先建立运行路径",
420
+ "top_symbols": file_data["top_symbols"][:3],
421
+ "symbol_count": file_data["symbol_count"],
422
+ "semantic_symbol_count": round(file_data["semantic_symbol_count"], 1),
423
+ }
424
+ )
425
+ seen_files.add(entry)
426
+ if len(suggestions) >= limit:
427
+ return suggestions
428
+
429
+ # 然后按重要性排序推荐其他文件
430
+ ordered_files = sorted(
431
+ analysis.values(),
432
+ key=lambda row: (row["is_test_file"], -row["score"], row["file"]),
433
+ )
434
+ for file_data in ordered_files:
435
+ file_path = file_data["file"]
436
+ if file_path in seen_files:
437
+ continue
438
+ if file_data["symbol_count"] <= 0:
439
+ continue
440
+ reason_parts: list[str] = []
441
+ if file_data["neighbor_count"] >= 3:
442
+ reason_parts.append("跨模块枢纽")
443
+ if file_data["exported_count"] >= 2:
444
+ reason_parts.append("导出面大")
445
+ if file_data["semantic_symbol_count"] >= 5:
446
+ reason_parts.append("逻辑密集")
447
+ if file_data["is_test_file"]:
448
+ reason_parts.append("测试验证入口")
449
+ if not reason_parts:
450
+ reason_parts.append("重要符号集中")
451
+ suggestions.append(
452
+ {
453
+ "file": file_path,
454
+ "reason": ",".join(reason_parts),
455
+ "top_symbols": file_data["top_symbols"][:3],
456
+ "symbol_count": file_data["symbol_count"],
457
+ "semantic_symbol_count": round(file_data["semantic_symbol_count"], 1),
458
+ }
459
+ )
460
+ seen_files.add(file_path)
461
+ if len(suggestions) >= limit:
462
+ break
463
+ return suggestions
464
+
465
+ def summary_symbols(
466
+ self,
467
+ limit_files: int = 6,
468
+ per_file: int = 4,
469
+ include_tests: bool = False,
470
+ ) -> list[dict[str, Any]]:
471
+ """给 overview 提供更适合阅读的关键实现符号。"""
472
+ analysis = self.file_analysis()
473
+ suggestion_rows = self.suggested_reading_order(max(limit_files * 2, limit_files))
474
+ reasons = {row["file"]: row["reason"] for row in suggestion_rows}
475
+ ordered_files = [row["file"] for row in suggestion_rows]
476
+ ordered_files.extend(
477
+ row["file"]
478
+ for row in sorted(
479
+ analysis.values(),
480
+ key=lambda row: (row["is_test_file"], -row["score"], row["file"]),
481
+ )
482
+ if row["file"] not in reasons
483
+ )
484
+
485
+ sections: list[dict[str, Any]] = []
486
+ for file_path in ordered_files:
487
+ file_data = analysis.get(file_path)
488
+ if not file_data:
489
+ continue
490
+ if file_data["is_test_file"] and not include_tests:
491
+ continue
492
+ symbols = [
493
+ self.graph.symbols[symbol_id]
494
+ for symbol_id in self.graph.file_symbols.get(file_path, [])
495
+ if symbol_id in self.graph.symbols
496
+ ]
497
+ ranked_symbols = sorted(
498
+ symbols,
499
+ key=lambda item: (-self._summary_symbol_score(item), item.line, item.name),
500
+ )
501
+ if not ranked_symbols:
502
+ continue
503
+ sections.append(
504
+ {
505
+ "file": file_path,
506
+ "reason": reasons.get(file_path, ""),
507
+ "symbol_count": file_data["symbol_count"],
508
+ "semantic_symbol_count": round(file_data["semantic_symbol_count"], 1),
509
+ "symbols": [
510
+ {
511
+ "name": symbol.name,
512
+ "kind": symbol.kind,
513
+ "line": symbol.line,
514
+ "visibility": symbol.visibility,
515
+ "signature": symbol.signature,
516
+ "pagerank": symbol.pagerank,
517
+ "summary_score": round(self._summary_symbol_score(symbol), 2),
518
+ "incoming_calls": self._edge_count(symbol.id, "incoming", {"call"}),
519
+ "outgoing_calls": self._edge_count(symbol.id, "outgoing", {"call"}),
520
+ }
521
+ for symbol in ranked_symbols[:per_file]
522
+ ],
523
+ }
524
+ )
525
+ if len(sections) >= limit_files:
526
+ break
527
+ return sections
528
+
529
+ @staticmethod
530
+ def _module_bucket_for_file(file_path: str) -> str:
531
+ """将文件路径归类到模块。"""
532
+ parts = [part for part in PurePosixPath(file_path).parts if part not in ("", ".")]
533
+ if not parts:
534
+ return "(root)"
535
+ if len(parts) == 1:
536
+ return "(root)"
537
+ if parts[0] in {"src", "app", "apps", "packages", "services", "modules", "libs", "lib"}:
538
+ return "/".join(parts[:2]) if len(parts) > 1 else parts[0]
539
+ return parts[0]
540
+
541
    @staticmethod
    def _is_test_like_file(file_path: str) -> bool:
        """Whether the path looks like a test file (delegates to topic heuristics)."""
        return is_test_like_file(file_path)
545
+
546
+
547
class EdgeBuilder:
    """Edge builder: constructs dependency-graph edges from symbols and call info."""

    # Base weights for the two edge kinds.
    IMPORT_WEIGHT = 0.35
    CALL_WEIGHT = 0.50

    # Ranking weights used to pick each file's most representative symbols
    # when creating edges (see _top_symbol_ids).
    _VISIBILITY_RANK = {"exported": 3, "public": 2, "private": 1}
    _KIND_RANK = {"class": 4, "function": 3, "method": 3, "anonymous_function": 2, "struct": 4, "interface": 4, "trait": 4, "enum": 4, "module": 3}

    def __init__(self, graph: RepoGraph, resolver: Any) -> None:
        self.graph = graph
        # Import/call resolver (duck-typed; see build_edges for the methods used).
        self.resolver = resolver
        # De-dup set of (source, target, kind) so each edge is added once.
        self._edge_set: set[tuple[str, str, str]] = set()
563
+
564
+ def _top_symbol_ids(self, file: str, max_count: int = 3) -> list[str]:
565
+ """按语义重要性选文件中最具代表性的符号 ID。"""
566
+ ids = self.graph.file_symbols.get(file, [])
567
+ if len(ids) <= max_count:
568
+ return list(ids)
569
+ scored = []
570
+ for sid in ids:
571
+ sym = self.graph.symbols.get(sid)
572
+ if sym is None:
573
+ scored.append((0, sid))
574
+ continue
575
+ vis = self._VISIBILITY_RANK.get(sym.visibility, 0)
576
+ kind = self._KIND_RANK.get(sym.kind, 1)
577
+ scored.append((vis + kind, sid))
578
+ scored.sort(key=lambda x: -x[0])
579
+ return [sid for _, sid in scored[:max(max_count, int(len(ids) * 0.3))]]
580
+
581
    def build_edges(self) -> None:
        """Build import edges and call edges.

        Walks each file's imports and calls, resolving them through
        ``self.resolver``, and adds de-duplicated edges to the graph.
        Files are processed in sorted order so output is deterministic.
        """
        self.resolver.build_indices()

        # file -> set of files it imports (narrows call-target resolution later)
        import_targets_by_file: dict[str, set[str]] = defaultdict(set)
        # file -> local import name -> candidate symbol ids
        import_symbol_targets_by_file: dict[str, dict[str, set[str]]] = defaultdict(lambda: defaultdict(set))

        # Import edges: connect each file's representative symbols to the
        # representative symbols of every file it imports.
        for file, imports in sorted(self.graph.file_imports.items()):
            src_ids = self._top_symbol_ids(file)
            for imp in imports:
                target_files = self.resolver.resolve_import_targets(file, imp)
                if not target_files:
                    continue
                import_targets_by_file[file].update(target_files)
                for target_file in target_files:
                    tgt_ids = self._top_symbol_ids(target_file)
                    for s in src_ids:
                        for t in tgt_ids:
                            self._add_edge(s, t, self.IMPORT_WEIGHT, "import")

            # Named import bindings resolve directly to symbol ids.
            for binding in self.graph.file_import_bindings.get(file, []):
                target_ids = self.resolver.resolve_import_binding_targets(file, binding)
                if not target_ids:
                    continue
                import_symbol_targets_by_file[file][binding.local_name].update(target_ids)
                for target_id in sorted(target_ids):
                    import_targets_by_file[file].add(self.graph.symbols[target_id].file)
                    for source_id in src_ids:
                        self._add_edge(source_id, target_id, self.IMPORT_WEIGHT, "import")

        # Call edges: resolve each call site to its enclosing symbol (source)
        # and its best-guess target symbol.
        for file, calls in sorted(self.graph.file_calls.items()):
            for call_ref in calls:
                call_name, call_line, call_kind = call_reference_parts(call_ref)
                caller_id = self.resolver.resolve_calling_symbol(file, call_line)
                if not caller_id:
                    continue
                target_id = self.resolver.resolve_call_target(
                    file=file,
                    call_name=call_name,
                    call_line=call_line,
                    call_kind=call_kind,
                    import_targets_by_file=import_targets_by_file,
                    import_symbol_targets_by_file=import_symbol_targets_by_file,
                )
                if target_id:
                    self._add_edge(caller_id, target_id, self.CALL_WEIGHT, "call")
629
+
630
+ def _add_edge(self, src: str, tgt: str, weight: float, kind: str) -> None:
631
+ if src == tgt:
632
+ return
633
+ key = (src, tgt, kind)
634
+ if key in self._edge_set:
635
+ return
636
+ self._edge_set.add(key)
637
+ e = Edge(src, tgt, weight, kind)
638
+ self.graph.outgoing[src].append(e)
639
+ self.graph.incoming[tgt].append(e)