java-codebase-rag 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pr_analysis.py ADDED
@@ -0,0 +1,534 @@
1
+ """Unified-diff → symbol mapping and PR-style risk scoring (B4 / PR-B).
2
+
3
+ Uses the `unidiff` library for parsing. Graph-resident symbols only; newly
4
+ added Java members are not modelled — see `notes` on the returned report.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import re
9
+ from dataclasses import asdict, dataclass
10
+ from typing import Any
11
+
12
+ from unidiff import PatchSet
13
+ from unidiff.errors import UnidiffParseError
14
+
15
+ from kuzu_queries import SymbolHit, find_symbols_in_file_range, _row_to_symbol
16
+
17
+
18
@dataclass
class DiffHunk:
    """One unified-diff hunk, expressed as old-file and new-file line ranges.

    Attributes:
        target_path: Path of the file on the new side of the diff.
        source_path: Path of the file on the old side of the diff.
        target_line_start: First new-file line of the hunk (1-based,
            inclusive); 0 when the hunk has no new-file lines.
        target_line_end: Last new-file line of the hunk (inclusive).
        source_line_start: First old-file line of the hunk.
        source_line_end: Last old-file line of the hunk.
        source_length: Number of old-file lines covered by the hunk.
        target_length: Number of new-file lines covered by the hunk.
    """

    target_path: str
    source_path: str
    target_line_start: int
    target_line_end: int
    source_line_start: int
    source_line_end: int
    source_length: int = 0
    target_length: int = 0
30
+
31
+
32
@dataclass
class ChangedSymbol:
    """A graph symbol touched by a diff, plus the lines that touched it.

    Attributes:
        symbol_id: Identifier of the symbol in the graph.
        fqn: Fully qualified name of the symbol.
        kind: One of 'method', 'type' or 'field'.
        change_type: One of 'added', 'removed' or 'modified'.
        file: Filename the symbol is recorded under.
        hunk_lines: Line numbers of the hunk(s) that overlap the symbol.
        cross_service_callers_count: Cross-service callers attributed to this
            symbol; defaults to 0 until risk computation fills it in.
    """

    symbol_id: str
    fqn: str
    kind: str
    change_type: str
    file: str
    hunk_lines: list[int]
    cross_service_callers_count: int = 0
42
+
43
@dataclass
class PrRiskReport:
    """Aggregated PR-style risk report for a set of changed symbols.

    Attributes:
        changed_symbols: The symbols the diff was mapped to.
        blast_radius_total: Sum of the per-symbol blast radii.
        blast_radius_by_symbol: Blast radius keyed by symbol id.
        cross_service_callers: Number of caller edges crossing microservices.
        routes_touched: Ids of routes exposed by the changed symbols.
        risk_score: Scalar risk score.
        risk_band: Textual band derived from the score.
        notes: Human-readable caveats collected while building the report.
    """

    changed_symbols: list[ChangedSymbol]
    blast_radius_total: int
    blast_radius_by_symbol: dict[str, int]
    cross_service_callers: int
    routes_touched: list[str]
    risk_score: float
    risk_band: str
    notes: list[str]
53
+
54
+
55
+ _BINARY_DIFF_LINE = re.compile(r"^Binary files .+ differ\s*$")
56
+ # Heuristic: new Java method/ctor-looking line. Covers annotations, method-level
57
+ # generics, `default` interface methods, and return types with spaces (e.g.
58
+ # `Map<String, String> m(`). Misses multi-line signatures, some compact record
59
+ # forms, and unusual annotations; `_notes_for_unindexed_additions` is best-effort.
60
+ _DECL_ADD = re.compile(
61
+ r"^\+\s*"
62
+ r"(?:(?:@[\w.]+\([^)]*\))\s+)*"
63
+ r"(?:<[^>]+>\s+)?"
64
+ r"(?:(?:public|private|protected|default|static|final|synchronized|abstract|native)\s+)*"
65
+ r"(.+?)\s+(\w+)\s*\(",
66
+ )
67
+
68
+
69
+ def _strip_ab_prefix(path: str) -> str:
70
+ p = path.strip()
71
+ if p.startswith(("a/", "b/")):
72
+ return p[2:]
73
+ return p
74
+
75
+
76
+ def _hunk_ranges(h: Any) -> tuple[tuple[int, int], tuple[int, int]]:
77
+ """Return ((src_start, src_end inclusive), (tgt_start, tgt_end inclusive))."""
78
+ src_len = int(getattr(h, "source_length", 0) or 0)
79
+ tgt_len = int(getattr(h, "target_length", 0) or 0)
80
+ src_start = int(getattr(h, "source_start", 0) or 0)
81
+ tgt_start = int(getattr(h, "target_start", 0) or 0)
82
+ if src_len <= 0:
83
+ src_start, src_end = 0, 0
84
+ else:
85
+ src_end = src_start + src_len - 1
86
+ if tgt_len <= 0:
87
+ tgt_start, tgt_end = 0, 0
88
+ else:
89
+ tgt_end = tgt_start + tgt_len - 1
90
+ return (src_start, src_end), (tgt_start, tgt_end)
91
+
92
+
93
def parse_unified_diff(diff_text: str) -> list[DiffHunk]:
    """Turn unified-diff text into a flat list of `DiffHunk` records.

    Blank input and text that `unidiff` rejects with `UnidiffParseError` both
    produce an empty list. Files flagged as renames and files without a
    target path are skipped.
    """
    if not (diff_text or "").strip():
        return []
    try:
        patch_set = PatchSet(diff_text.splitlines(keepends=True))
    except UnidiffParseError:
        return []

    hunks: list[DiffHunk] = []
    for patched_file in patch_set:
        if getattr(patched_file, "is_rename", False):
            continue
        new_path = _strip_ab_prefix(str(patched_file.path or ""))
        # Fall back to the file's path when no source_file is available.
        old_path = _strip_ab_prefix(
            str(getattr(patched_file, "source_file", "") or patched_file.path or "")
        )
        if not new_path:
            continue
        for hunk in patched_file:
            old_span, new_span = _hunk_ranges(hunk)
            old_len = int(getattr(hunk, "source_length", 0) or 0)
            new_len = int(getattr(hunk, "target_length", 0) or 0)
            record = DiffHunk(
                target_path=new_path,
                source_path=old_path,
                target_line_start=new_span[0],
                target_line_end=new_span[1],
                source_line_start=old_span[0],
                source_line_end=old_span[1],
                source_length=old_len,
                target_length=new_len,
            )
            hunks.append(record)
    return hunks
126
+
127
+
128
def collect_diff_file_notes(diff_text: str) -> list[str]:
    """Build human-readable notes about binary files and renames in a diff.

    Never raises on unparseable input: if `unidiff` rejects the text, a note
    saying so is appended and whatever was collected so far is returned.
    """
    if not (diff_text or "").strip():
        return []

    # Binary files are detected textually, independent of the patch parser.
    notes: list[str] = [
        f"skipped binary diff: {raw.strip()}"
        for raw in diff_text.splitlines()
        if _BINARY_DIFF_LINE.match(raw)
    ]

    try:
        patch_set = PatchSet(diff_text.splitlines(keepends=True))
    except UnidiffParseError:
        notes.append("diff text could not be fully parsed as a unified patch")
        return notes

    for patched_file in patch_set:
        if not getattr(patched_file, "is_rename", False):
            continue
        old_name = _strip_ab_prefix(str(getattr(patched_file, "source_file", "") or ""))
        new_name = _strip_ab_prefix(str(patched_file.path or ""))
        notes.append(f"rename (symbols not mapped): {old_name} -> {new_name}")
    return notes
147
+
148
+
149
def _resolve_graph_filename(
    graph: Any,
    path: str,
    *,
    ambiguity_notes: list[str] | None = None,
) -> str | None:
    """Map a diff path to a `Symbol.filename` value stored in the graph.

    Resolution order:
      1. exact match on the path with any `a/` / `b/` prefix stripped;
      2. exact match on that path with a leading `./` removed as well;
      3. a unique `ENDS WITH "/<basename>"` match.

    Args:
        graph: Object exposing `_rows(query, params)`, returning a list of
            dict-like rows.
        path: Path as it appears in the diff.
        ambiguity_notes: Optional list that receives a note when the basename
            fallback matches more than one graph filename.

    Returns:
        The graph filename, or None when nothing matches or the basename
        fallback is ambiguous.
    """
    stripped = _strip_ab_prefix(path)
    # Ordered list rather than a set: when both "./x" and "x" exist in the
    # graph, the winner must not depend on string hash order.
    candidates = [stripped]
    if stripped.startswith("./"):
        candidates.append(stripped[2:])

    for candidate in candidates:
        if not candidate:
            continue
        rows = graph._rows(
            "MATCH (s:Symbol) WHERE s.filename = $fn RETURN s.filename AS fn LIMIT 1",
            {"fn": candidate},
        )
        if rows and rows[0].get("fn"):
            return str(rows[0]["fn"])

    tail = path.strip().rsplit("/", 1)[-1]
    if not tail:
        return None

    # Basename fallback: only trusted when exactly one graph file matches.
    rows = graph._rows(
        "MATCH (s:Symbol) WHERE s.filename ENDS WITH $tail "
        "RETURN DISTINCT s.filename AS fn LIMIT 8",
        {"tail": "/" + tail},
    )
    n = len(rows)
    if n > 1 and ambiguity_notes is not None:
        fns = [str(r.get("fn") or "") for r in rows if r.get("fn")]
        ambiguity_notes.append(
            f"ambiguous filename tail {tail!r} ({n} graph paths); "
            f"ENDS WITH resolution skipped ({', '.join(fns[:4])}"
            f"{'…' if len(fns) > 4 else ''})",
        )
    if n == 1 and rows[0].get("fn"):
        return str(rows[0]["fn"])
    return None
187
+
188
+
189
def _symbol_to_changed(
    sym: SymbolHit,
    *,
    change_type: str,
    lines: list[int],
) -> ChangedSymbol:
    """Convert a graph `SymbolHit` into a `ChangedSymbol`.

    The symbol kind is collapsed to 'type', 'field' or 'method' (anything
    that is neither a type kind nor a field, constructors included, becomes
    'method'). `lines` is reduced to its sorted, distinct, positive values.
    """
    type_kinds = ("class", "interface", "enum", "record", "annotation")
    if sym.kind in type_kinds:
        mapped_kind = "type"
    elif sym.kind == "field":
        mapped_kind = "field"
    else:
        mapped_kind = "method"

    positive_lines: set[int] = set()
    for raw in lines:
        number = int(raw)
        if number > 0:
            positive_lines.add(number)

    return ChangedSymbol(
        symbol_id=sym.id,
        fqn=sym.fqn,
        kind=mapped_kind,
        change_type=change_type,
        file=sym.filename,
        hunk_lines=sorted(positive_lines),
    )
213
+
214
+
215
+ def _decl_added_lines_for_file(diff_text: str, resolved_filename: str) -> int:
216
+ """Count `+` lines in the diff that look like Java member declarations for one file."""
217
+ lines = diff_text.splitlines()
218
+ in_file = False
219
+ n = 0
220
+ for line in lines:
221
+ if line.startswith("+++ "):
222
+ rest = line[4:].strip()
223
+ if rest.startswith("b/"):
224
+ rest = rest[2:]
225
+ in_file = rest.endswith(resolved_filename) or resolved_filename.endswith(rest)
226
+ continue
227
+ if not in_file:
228
+ continue
229
+ if _DECL_ADD.match(line):
230
+ n += 1
231
+ return n
232
+
233
+
234
+ def _notes_for_unindexed_additions(
235
+ graph: Any,
236
+ diff_text: str,
237
+ changed: list[ChangedSymbol],
238
+ hunks: list[DiffHunk],
239
+ ) -> list[str]:
240
+ """Heuristic: added declaration lines vs indexed methods touched on the same file."""
241
+ notes: list[str] = []
242
+ if not diff_text.strip():
243
+ return notes
244
+ for h in hunks:
245
+ tgt_fn = _resolve_graph_filename(graph, h.target_path)
246
+ if not tgt_fn or h.target_line_start <= 0:
247
+ continue
248
+ decls = _decl_added_lines_for_file(diff_text, tgt_fn)
249
+ if decls <= 0:
250
+ continue
251
+ methods_here = [c for c in changed if c.kind == "method" and c.file == tgt_fn]
252
+ if decls > len(methods_here):
253
+ extra = decls - len(methods_here)
254
+ notes.append(
255
+ f"{extra} new method(s) not yet indexed; risk underestimated",
256
+ )
257
+ return notes
258
+
259
+
260
def map_hunks_to_symbols(
    graph: Any,
    hunks: list[DiffHunk],
    *,
    path_ambiguity_notes: list[str] | None = None,
) -> list[ChangedSymbol]:
    """Map diff hunks to overlapping `Symbol` rows (graph-resident only).

    Hunks that only delete lines mark the symbols overlapping their old-file
    range as 'removed'. Every hunk with new-file lines marks the symbols
    overlapping its new-file range as 'modified'. Symbols of kind 'file' are
    ignored. A symbol hit by several hunks is merged into a single entry.

    Args:
        graph: Graph handle passed to the filename resolver and range query.
        hunks: Parsed diff hunks.
        path_ambiguity_notes: Optional list that receives notes about
            ambiguous filename resolution; each distinct diff path is resolved
            once, so it contributes at most one note.

    Returns:
        One `ChangedSymbol` per distinct symbol id, in first-seen order.
    """
    by_id: dict[str, ChangedSymbol] = {}

    def merge(sym: ChangedSymbol) -> None:
        """Fold `sym` into `by_id`; 'modified' wins over 'removed'."""
        existing = by_id.get(sym.symbol_id)
        if existing is None:
            by_id[sym.symbol_id] = sym
            return
        if "modified" in (existing.change_type, sym.change_type):
            ct = "modified"
        elif "removed" in (existing.change_type, sym.change_type):
            ct = "removed"
        else:
            ct = sym.change_type
        by_id[sym.symbol_id] = ChangedSymbol(
            symbol_id=existing.symbol_id,
            fqn=existing.fqn,
            kind=existing.kind,
            change_type=ct,
            file=existing.file,
            hunk_lines=sorted(set(existing.hunk_lines + sym.hunk_lines)),
        )

    resolved: dict[str, str | None] = {}

    def resolve(path: str) -> str | None:
        """Resolve a diff path once; later hunks of the same file reuse it."""
        if path not in resolved:
            resolved[path] = _resolve_graph_filename(
                graph, path, ambiguity_notes=path_ambiguity_notes,
            )
        return resolved[path]

    for h in hunks:
        tgt_fn = resolve(h.target_path)
        src_fn = resolve(h.source_path) if h.source_path else tgt_fn
        if not tgt_fn and not src_fn:
            continue

        minus_only = h.target_length == 0 and h.source_length > 0

        # Removed lines on the old file. Only pure-deletion hunks are reported
        # as 'removed', so the old-file range is only queried for those.
        if (
            minus_only
            and src_fn
            and h.source_line_start > 0
            and h.source_line_end >= h.source_line_start
        ):
            rows = find_symbols_in_file_range(
                graph,
                filename=src_fn,
                start_line=h.source_line_start,
                end_line=h.source_line_end,
            )
            for sym in rows:
                if sym.kind == "file":
                    continue
                overlap = list(range(
                    max(h.source_line_start, sym.start_line),
                    min(h.source_line_end, sym.end_line) + 1,
                ))
                merge(_symbol_to_changed(sym, change_type="removed", lines=overlap))

        # Modified / added lines on the new file.
        if h.target_line_start > 0 and h.target_line_end >= h.target_line_start and tgt_fn:
            rows = find_symbols_in_file_range(
                graph,
                filename=tgt_fn,
                start_line=h.target_line_start,
                end_line=h.target_line_end,
            )
            for sym in rows:
                if sym.kind == "file":
                    continue
                overlap = list(range(
                    max(h.target_line_start, sym.start_line),
                    min(h.target_line_end, sym.end_line) + 1,
                ))
                merge(_symbol_to_changed(sym, change_type="modified", lines=overlap))

    return list(by_id.values())
341
+
342
+
343
+ def _impact_needle_for_changed(_graph: Any, fqn: str, mapped_kind: str) -> str:
344
+ """Pick the `impact_analysis` needle: type FQN for members, else the symbol FQN."""
345
+ if mapped_kind in ("method", "field", "constructor"):
346
+ if "#" in fqn:
347
+ return fqn.split("#", 1)[0]
348
+ return fqn
349
+
350
+
351
+ def _is_public_interface_method(graph: Any, sym: SymbolHit) -> bool:
352
+ if sym.kind != "method":
353
+ return False
354
+ if "private" in (sym.modifiers or []):
355
+ return False
356
+ type_fqn = sym.fqn.split("#", 1)[0] if "#" in sym.fqn else sym.fqn
357
+ rows = graph._rows(
358
+ "MATCH (t:Symbol) WHERE t.fqn = $f AND t.kind = 'interface' RETURN t.id LIMIT 1",
359
+ {"f": type_fqn},
360
+ )
361
+ return bool(rows)
362
+
363
+
364
+ def _route_ids_for_symbol(graph: Any, symbol_id: str) -> list[str]:
365
+ # Note: Kuzu rejects `ORDER BY r.id` together with `RETURN DISTINCT r.id` (binder loses `r`).
366
+ q = (
367
+ "MATCH (s:Symbol)-[e:EXPOSES]->(r:Route) WHERE s.id = $sid "
368
+ "RETURN r.id AS id ORDER BY id"
369
+ )
370
+ seen: set[str] = set()
371
+ out: list[str] = []
372
+ for row in graph._rows(q, {"sid": symbol_id}):
373
+ rid = str(row.get("id") or "")
374
+ if rid and rid not in seen:
375
+ seen.add(rid)
376
+ out.append(rid)
377
+ return out
378
+
379
+
380
def compute_risk(graph: Any, changed: list[ChangedSymbol]) -> PrRiskReport:
    """Aggregate blast radius, routes, cross-service callers, and v1 risk score.

    Risk score stays in [0, 1]. Cross-service route callers add a bounded
    bump (up to +1.0) after normalization so they influence rank while
    preserving the public scalar contract.

    Changed symbols whose id is no longer found in the graph are left out of
    the report. The returned `notes` list is currently always empty.

    Args:
        graph: Graph handle exposing `_rows`, `impact_analysis` and
            `find_callers`.
        changed: Symbols the diff was mapped to.

    Returns:
        A `PrRiskReport` whose `changed_symbols` are copies of the inputs with
        `cross_service_callers_count` filled in.
    """
    notes: list[str] = []
    blast_by: dict[str, int] = {}
    blast_total = 0
    routes: list[str] = []
    seen_routes: set[str] = set()  # O(1) membership; `routes` keeps the order
    cross_total = 0

    sym_cols = (
        "id", "kind", "name", "fqn", "package", "module", "microservice",
        "filename", "start_line", "end_line", "start_byte", "end_byte",
        "modifiers", "annotations", "capabilities", "role", "signature",
        "parent_id", "resolved",
    )
    sym_return = ", ".join(f"s.{c} AS {c}" for c in sym_cols)

    iface_hit = 0.0
    enriched_changed: list[ChangedSymbol] = []
    for cs in changed:
        sym_row = graph._rows(
            "MATCH (s:Symbol) WHERE s.id = $id RETURN " + sym_return,
            {"id": cs.symbol_id},
        )
        if not sym_row:
            continue
        row0 = sym_row[0]
        # One interface-method hit is enough; stop probing after the first.
        if iface_hit < 1.0 and _is_public_interface_method(graph, _row_to_symbol(row0)):
            iface_hit = 1.0

        fqn = str(row0.get("fqn") or cs.fqn)
        needle = _impact_needle_for_changed(graph, fqn, cs.kind)
        n = len(graph.impact_analysis(needle, depth=2, limit=400))
        blast_by[cs.symbol_id] = n
        blast_total += n

        # Caller edges whose two endpoints live in different microservices.
        cross_total += sum(
            1
            for e in graph.find_callers(cs.fqn, depth=2, limit=400)
            if e.src.microservice
            and e.dst.microservice
            and e.src.microservice != e.dst.microservice
        )

        cs_cross_service = 0
        for rid in _route_ids_for_symbol(graph, cs.symbol_id):
            if rid not in seen_routes:
                seen_routes.add(rid)
                routes.append(rid)
            http_callers = graph._rows(
                "MATCH (s:Symbol)-[:DECLARES_CLIENT]->(c:Client)-[e:HTTP_CALLS]->(r:Route {id: $rid}) "
                "WHERE e.match = 'cross_service' "
                "RETURN c.id AS id LIMIT 500",
                {"rid": rid},
            )
            async_callers = graph._rows(
                "MATCH (s:Symbol)-[:DECLARES_PRODUCER]->(p:Producer)-[e:ASYNC_CALLS]->(r:Route {id: $rid}) "
                "WHERE e.match = 'cross_service' "
                "RETURN p.id AS id LIMIT 500",
                {"rid": rid},
            )
            # Sum the lengths rather than `+=`-extending the first result, so
            # a list owned by the graph layer is never mutated.
            cs_cross_service += len(http_callers) + len(async_callers)

        enriched_changed.append(
            ChangedSymbol(
                symbol_id=cs.symbol_id,
                fqn=cs.fqn,
                kind=cs.kind,
                change_type=cs.change_type,
                file=cs.file,
                hunk_lines=list(cs.hunk_lines),
                cross_service_callers_count=cs_cross_service,
            ),
        )

    def _normalize(x: float, ceiling: float) -> float:
        """Clamp `x` at `ceiling` and scale to [0, 1]; 0.0 for a non-positive ceiling."""
        if ceiling <= 0:
            return 0.0
        return min(float(x), ceiling) / ceiling

    # v1 risk weights / ceilings (PR-B §1.2): intentionally simple baselines;
    # these constants are expected to be tuned after real-world use — do not treat as stable.
    w_blast, cap_blast = 0.4, 100.0
    w_cross, cap_cross = 0.3, 20.0
    w_iface = 0.2
    w_routes, cap_routes = 0.1, 5.0

    raw = (
        w_blast * _normalize(float(blast_total), cap_blast)
        + w_cross * _normalize(float(cross_total), cap_cross)
        + w_iface * iface_hit
        + w_routes * _normalize(float(len(routes)), cap_routes)
    )
    # Bonus saturates at 5 cross-service route callers (== +1.0 before clamping).
    cross_service_bonus = min(
        5.0,
        float(sum(c.cross_service_callers_count for c in enriched_changed)),
    )
    score = max(0.0, min(1.0, raw + (cross_service_bonus / 5.0)))
    if score < 0.3:
        band = "low"
    elif score < 0.7:
        band = "medium"
    else:
        band = "high"

    return PrRiskReport(
        changed_symbols=list(enriched_changed),
        blast_radius_total=blast_total,
        blast_radius_by_symbol=blast_by,
        cross_service_callers=cross_total,
        routes_touched=routes,
        risk_score=score,
        risk_band=band,
        notes=notes,
    )
500
+
501
+
502
def pr_report_to_dict(rep: PrRiskReport) -> dict[str, Any]:
    """Convert a report into a plain dict.

    Changed symbols go through `asdict`; the dict and list fields are shallow
    copies, so editing the result's containers leaves `rep` untouched.
    """
    payload: dict[str, Any] = {}
    payload["changed_symbols"] = [asdict(symbol) for symbol in rep.changed_symbols]
    payload["blast_radius_total"] = rep.blast_radius_total
    payload["blast_radius_by_symbol"] = dict(rep.blast_radius_by_symbol)
    payload["cross_service_callers"] = rep.cross_service_callers
    payload["routes_touched"] = list(rep.routes_touched)
    payload["risk_score"] = rep.risk_score
    payload["risk_band"] = rep.risk_band
    payload["notes"] = list(rep.notes)
    return payload
513
+
514
+
515
def analyze_pr_pipeline(graph: Any, diff_unified: str) -> PrRiskReport:
    """Run the whole PR-B pipeline: parse → notes → map → risk.

    The returned report carries the risk figures from `compute_risk` and a
    de-duplicated, order-preserving union of every note gathered on the way
    (file notes, path-ambiguity notes, unindexed-addition notes, risk notes).
    """
    collected = collect_diff_file_notes(diff_unified)
    hunks = parse_unified_diff(diff_unified)

    ambiguity: list[str] = []
    changed = map_hunks_to_symbols(graph, hunks, path_ambiguity_notes=ambiguity)
    collected += ambiguity
    collected += _notes_for_unindexed_additions(graph, diff_unified, changed, hunks)

    report = compute_risk(graph, changed)

    # Keep the first occurrence of each note, in encounter order.
    unique_notes: list[str] = []
    already: set[str] = set()
    for note in (*collected, *report.notes):
        if note not in already:
            already.add(note)
            unique_notes.append(note)

    return PrRiskReport(
        changed_symbols=report.changed_symbols,
        blast_radius_total=report.blast_radius_total,
        blast_radius_by_symbol=report.blast_radius_by_symbol,
        cross_service_callers=report.cross_service_callers,
        routes_touched=report.routes_touched,
        risk_score=report.risk_score,
        risk_band=report.risk_band,
        notes=unique_notes,
    )