cherry-docs 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. app/__init__.py +0 -0
  2. app/repo_scope.py +24 -0
  3. app/services/__init__.py +0 -0
  4. app/services/agent_protocol.py +59 -0
  5. app/services/auto_promote_sessions.py +245 -0
  6. app/services/capture_adapters.py +89 -0
  7. app/services/capture_core.py +164 -0
  8. app/services/internal_memory_agent.py +214 -0
  9. app/services/memory_evidence.py +89 -0
  10. app/services/memory_extraction_normalize.py +134 -0
  11. app/services/memory_lifecycle.py +258 -0
  12. app/services/memory_profiles.py +88 -0
  13. app/services/memory_providers.py +113 -0
  14. app/services/memory_retrieval.py +327 -0
  15. app/services/memory_retrieval_scoring.py +106 -0
  16. app/services/memory_retrieval_text.py +113 -0
  17. app/services/memory_similarity.py +135 -0
  18. app/services/privacy.py +72 -0
  19. app/services/promoted_memory_answer.py +157 -0
  20. app/services/promoted_memory_pipeline.py +194 -0
  21. app/services/promoted_memory_store.py +57 -0
  22. cherry_docs-0.2.0.dist-info/METADATA +143 -0
  23. cherry_docs-0.2.0.dist-info/RECORD +42 -0
  24. cherry_docs-0.2.0.dist-info/WHEEL +5 -0
  25. cherry_docs-0.2.0.dist-info/entry_points.txt +4 -0
  26. cherry_docs-0.2.0.dist-info/top_level.txt +3 -0
  27. cherrydocs/__init__.py +3 -0
  28. cherrydocs/cli.py +213 -0
  29. cherrydocs/hook.py +27 -0
  30. cherrydocs/mcp.py +22 -0
  31. scripts/__init__.py +0 -0
  32. scripts/auto_promote_capture.py +63 -0
  33. scripts/check_size_limits.py +115 -0
  34. scripts/ci_auto_capture.py +289 -0
  35. scripts/claude_hooks/__init__.py +0 -0
  36. scripts/claude_hooks/state_manager.py +526 -0
  37. scripts/coverage_regression_gate.py +121 -0
  38. scripts/eval_projects.py +247 -0
  39. scripts/install.py +212 -0
  40. scripts/pr_gate_report.py +282 -0
  41. scripts/promptfoo_regression_gate.py +176 -0
  42. scripts/render_agent_prompts.py +57 -0
@@ -0,0 +1,526 @@
1
+ #!/usr/bin/env python3
2
+ """Local session state manager for Claude Code hooks.
3
+
4
+ Commands: session-start | post-tool-use | stop | reset
5
+ State isolated per session_id; override path via CHERRY_HOOKS_STATE_FILE.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ import re
13
+ import subprocess
14
+ import sys
15
+ from datetime import datetime, timezone
16
+ from pathlib import Path
17
+
18
+ ROOT = Path(__file__).resolve().parents[2]
19
+ if str(ROOT) not in sys.path:
20
+ sys.path.insert(0, str(ROOT))
21
+
22
+ from app.services.capture_adapters import append_capture_event, infer_capture_event_type
23
+ from app.services.capture_core import CaptureEventType
24
+
25
# Fallback state file, used by _state_file when no session_id is available.
_STATE_FILE_DEFAULT = Path(".claude/hooks/session_state.json")
# Per-user CherryDocs directory under the home directory.
_HOME_CHERRY = Path.home() / ".cherrydocs"
# Capture buffer directory; CHERRY_CAPTURE_BUFFER_DIR overrides the default.
_CAPTURE_BUFFER_ROOT = Path(os.environ.get("CHERRY_CAPTURE_BUFFER_DIR", str(_HOME_CHERRY / "capture")))
# Root passed to the auto-promote script; CHERRY_PROMOTED_ROOT overrides the default.
_PROMOTED_ROOT = Path(os.environ.get("CHERRY_PROMOTED_ROOT", str(_HOME_CHERRY / "promoted")))
# Stores the most recent session_id so hooks whose payload lacks one can fall back to it.
_SESSION_ID_FILE = _CAPTURE_BUFFER_ROOT / ".current-session-id"

# Baseline session state. load_state forward-fills these keys into state
# files that were written before a key existed.
_DEFAULTS: dict = {
    "session_initialized": False,  # set once the startup reminder has been printed
    "needs_log": False,  # raised by code-editing tools, cleared by the log tool
    "needs_checkpoint": False,  # raised by code-editing tools, cleared by the checkpoint tool
    "last_log_at": None,  # ISO timestamp of the last log tool call
    "last_checkpoint_at": None,  # ISO timestamp of the last checkpoint tool call
    "meaningful_work_done": False,  # raised by code-editing tools
    "edits_since_last_log": 0,  # edit counter, reset by the log tool
    "files_touched": [],  # recently edited paths, reset by the log tool
}
41
+
42
+ # Edit/write actions that produce durable code changes.
43
+ _CODE_TOOLS = frozenset({"Edit", "Write", "NotebookEdit"})
44
+ # Running commands — may produce findings worth logging (test results, errors).
45
+ _RUN_TOOLS = frozenset({"Bash"})
46
+ # CherryDocs tools that satisfy requirements.
47
+ _LOG_TOOLS = frozenset({"mcp__cherry-docs__log_activity"})
48
+ _CHECKPOINT_TOOLS = frozenset({"mcp__cherry-docs__save_checkpoint"})
49
+
50
+ # After this many edits without a log, remind at every prompt.
51
+ _LOG_NUDGE_AFTER_EDITS = 3
52
+
53
+ _CLAUDE_PROMPT_PATH = Path(".claude/CLAUDE.md")
54
+ _GENERATED_PROMPT_RE = re.compile(
55
+ r"Generated from docs/agent_protocol\.toml version=([^\s]+) hash=([0-9a-f]+);"
56
+ )
57
+ _GIT_COMMIT_OUTPUT_RE = re.compile(r"\[(?:[^\]]+)\s+([0-9a-f]{5,40})\]")
58
+
59
+ _STARTUP_REMINDER = """\
60
+ [CHERRYDOCS STARTUP]
61
+ Protocol reminder for this session:
62
+ 1. Call onboard to load the default project memory view.
63
+ 2. Work normally and let passive capture record raw traces when available.
64
+ 3. Use remember/log_activity only when something important would otherwise be lost.
65
+ """
66
+
67
+
68
def _append_capture_event(
    event_type: CaptureEventType,
    *,
    session_id: str | None,
    cwd: str | None = None,
    text: str | None = None,
    command: str | None = None,
    exit_code: int | None = None,
    metadata: dict | None = None,
) -> None:
    """Append one capture event for the ``claude-code`` source, best effort.

    All arguments are forwarded unchanged to ``append_capture_event`` together
    with the module-level capture buffer directory. Any exception raised while
    appending is swallowed, so a capture failure never propagates to the hook.
    """
    fields = {
        "buffer_dir": _CAPTURE_BUFFER_ROOT,
        "source": "claude-code",
        "event_type": event_type,
        "session_id": session_id,
        "cwd": cwd,
        "text": text,
        "command": command,
        "exit_code": exit_code,
        "metadata": metadata,
    }
    try:
        append_capture_event(**fields)
    except Exception:
        # Intentionally ignored: the caller has nothing useful to do on failure.
        pass
92
+
93
+
94
+ def _parse_edited_file(event: dict) -> str | None:
95
+ """Extract the file path from an Edit/Write tool input."""
96
+ tool_input = event.get("tool_input") or {}
97
+ if isinstance(tool_input, str):
98
+ try:
99
+ tool_input = json.loads(tool_input)
100
+ except json.JSONDecodeError:
101
+ return None
102
+ return tool_input.get("file_path") or tool_input.get("path") or None
103
+
104
+
105
def _parse_generated_prompt_metadata(text: str) -> dict | None:
    """Return the ``version`` and ``hash`` from a generated-prompt marker.

    Searches ``text`` with ``_GENERATED_PROMPT_RE``; returns ``None`` when the
    marker is absent.
    """
    found = _GENERATED_PROMPT_RE.search(text)
    if found is None:
        return None
    version, digest = found.group(1), found.group(2)
    return {"version": version, "hash": digest}
110
+
111
+
112
def _read_local_prompt_metadata() -> dict | None:
    """Read generated-prompt metadata from the local ``.claude/CLAUDE.md``.

    Returns:
        The parsed ``{"version", "hash"}`` dict, or ``None`` when the file is
        missing, unreadable, not valid UTF-8, or carries no marker.
    """
    try:
        text = _CLAUDE_PROMPT_PATH.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # UnicodeDecodeError is not an OSError; a badly encoded prompt file
        # must not crash the hook.
        return None
    return _parse_generated_prompt_metadata(text)
117
+
118
+
119
+ def _read_ref_prompt_metadata(ref: str = "origin/main") -> dict | None:
120
+ try:
121
+ proc = subprocess.run(
122
+ ["git", "show", f"{ref}:.claude/CLAUDE.md"],
123
+ capture_output=True, text=True, check=True, timeout=2,
124
+ )
125
+ except (OSError, subprocess.SubprocessError):
126
+ return None
127
+ return _parse_generated_prompt_metadata(proc.stdout)
128
+
129
+
130
def _protocol_warning_text() -> str | None:
    """Build a warning when the local prompt hash differs from origin/main.

    Returns ``None`` when either side's metadata is unavailable or when both
    hashes match; otherwise returns a one-line warning naming both versions
    and hashes.
    """
    local_meta = _read_local_prompt_metadata()
    ref_meta = _read_ref_prompt_metadata()
    if not (local_meta and ref_meta):
        return None
    if local_meta["hash"] == ref_meta["hash"]:
        return None
    local_tag = f"(version={local_meta['version']} hash={local_meta['hash']})"
    ref_tag = f"(version={ref_meta['version']} hash={ref_meta['hash']})"
    return (
        f"[CHERRYDOCS PROTOCOL WARNING] Local .claude/CLAUDE.md {local_tag} differs from "
        f"origin/main {ref_tag}. "
        "Sync from latest main if this branch is not intentionally changing prompt rules."
    )
143
+
144
+
145
+ # State helpers
146
+
147
+ def _state_file(session_id: str | None = None) -> Path:
148
+ override = os.environ.get("CHERRY_HOOKS_STATE_FILE")
149
+ if override:
150
+ return Path(override)
151
+ if session_id:
152
+ safe = session_id[:12].replace("/", "_").replace(".", "_")
153
+ return Path(f".claude/hooks/session_{safe}.json")
154
+ return _STATE_FILE_DEFAULT
155
+
156
+
157
def load_state(session_id: str | None = None) -> dict:
    """Load the session state dict, falling back to defaults.

    A missing, unreadable, or malformed state file, or one whose JSON is not
    an object, is treated as empty. Keys from ``_DEFAULTS`` that are absent
    are forward-filled, so state files written by older versions keep working.

    Returns:
        A dict containing at least every key in ``_DEFAULTS``.
    """
    sf = _state_file(session_id)
    try:
        data = json.loads(sf.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        data = None
    if not isinstance(data, dict):
        # Valid JSON that is not an object (e.g. [] or null) cannot hold state.
        data = {}
    for key, default in _DEFAULTS.items():
        # Copy list defaults so callers never share (and mutate) the module constant.
        data.setdefault(key, list(default) if isinstance(default, list) else default)
    return data
169
+
170
+
171
def save_state(state: dict, session_id: str | None = None) -> None:
    """Write ``state`` as indented JSON to the session's state file.

    The parent directory is created first if it does not exist.
    """
    target = _state_file(session_id)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(state, indent=2)
    target.write_text(serialized, encoding="utf-8")
175
+
176
+
177
+ def _now() -> str:
178
+ return datetime.now(timezone.utc).isoformat()
179
+
180
+
181
+ def _minutes_since(iso: str | None) -> float | None:
182
+ if not iso:
183
+ return None
184
+ try:
185
+ dt = datetime.fromisoformat(iso)
186
+ return (datetime.now(timezone.utc) - dt).total_seconds() / 60
187
+ except ValueError:
188
+ return None
189
+
190
+
191
+ def _load_transcript_messages(transcript_path: str | None) -> list[dict]:
192
+ if not transcript_path:
193
+ return []
194
+ try:
195
+ raw = Path(transcript_path).read_text(encoding="utf-8").strip()
196
+ try:
197
+ data = json.loads(raw)
198
+ messages = data if isinstance(data, list) else [data]
199
+ except json.JSONDecodeError:
200
+ messages = []
201
+ for line in raw.splitlines():
202
+ line = line.strip()
203
+ if line:
204
+ try:
205
+ messages.append(json.loads(line))
206
+ except json.JSONDecodeError:
207
+ pass
208
+ except (OSError, ValueError, TypeError):
209
+ return []
210
+ return [msg for msg in messages if isinstance(msg, dict)]
211
+
212
+
213
+ def _message_text(msg: dict) -> str:
214
+ content = msg.get("content", "")
215
+ if isinstance(content, list):
216
+ return " ".join(
217
+ block.get("text", "") if isinstance(block, dict) else str(block)
218
+ for block in content
219
+ ).strip()
220
+ return str(content).strip()
221
+
222
+
223
def _latest_assistant_message(transcript_path: str | None) -> str | None:
    """Return the text of the most recent non-empty assistant message.

    Walks the transcript from the end. An entry counts as an assistant
    message when its ``role`` (or, failing that, its ``type``) equals
    ``"assistant"``. Returns ``None`` when no such entry has text.
    """
    history = _load_transcript_messages(transcript_path)
    for entry in history[::-1]:
        speaker = entry.get("role") or entry.get("type", "")
        if speaker == "assistant":
            body = _message_text(entry)
            if body:
                return body
    return None
232
+
233
+
234
+ def _assistant_message_from_event(event: dict | None) -> str | None:
235
+ if not isinstance(event, dict):
236
+ return None
237
+ direct = str(event.get("last_assistant_message") or "").strip()
238
+ if direct:
239
+ return direct
240
+ return _latest_assistant_message(event.get("transcript_path"))
241
+
242
+
243
+ # ---------------------------------------------------------------------------
244
+ # Command implementations
245
+ # ---------------------------------------------------------------------------
246
+
247
+
248
def cmd_session_start(session_id: str | None = None, event: dict | None = None) -> int:
    """Print startup reminder once per fresh session; badge on every prompt.

    Args:
        session_id: Session whose state file is loaded (and saved on first run).
        event: Hook payload. ``cwd`` and ``prompt``/``text`` are read from it
            when it is a dict; any other value is treated as an empty payload.

    Returns:
        Always 0.
    """
    state = load_state(session_id)
    # Tolerate a missing or non-dict payload by treating it as empty.
    payload = event if isinstance(event, dict) else {}
    cwd = payload.get("cwd")
    if not state.get("session_initialized"):
        # First call for this state file: show the reminder and mark it done.
        print(_STARTUP_REMINDER, end="")
        state["session_initialized"] = True
        # Protocol drift check: only run git subprocess once per session.
        protocol_warning = _protocol_warning_text()
        if protocol_warning:
            print(protocol_warning)
        save_state(state, session_id)
        _append_capture_event(
            CaptureEventType.SESSION_START,
            session_id=session_id,
            cwd=cwd,
            metadata={"source": "claude-hooks"},
        )
    # Record the user's prompt whenever the payload carries one.
    prompt = str(payload.get("prompt") or payload.get("text") or "").strip()
    if prompt:
        _append_capture_event(
            CaptureEventType.USER_PROMPT,
            session_id=session_id,
            cwd=cwd,
            text=prompt,
            metadata={"source": "claude-hooks"},
        )
    # Nudge if edits were done without a log in this session.
    edits = state.get("edits_since_last_log", 0)
    if edits >= _LOG_NUDGE_AFTER_EDITS:
        files = state.get("files_touched") or []
        # Mention at most the three most recently touched files.
        files_hint = f" (touched: {', '.join(files[-3:])})" if files else ""
        print(
            f"[CHERRYDOCS] {edits} edits since last log{files_hint}. "
            "Log what you found before continuing."
        )
    return 0
286
+
287
+
288
def cmd_post_tool_use(tool_name: str, session_id: str | None = None, event: dict | None = None) -> int:
    """Update enforcement state based on the tool that was used.

    Code-editing tools raise the log/checkpoint requirements and track the
    edited file; Bash commands containing ``git commit`` trigger background
    auto-distillation when a commit hash is found in the response; the log
    and checkpoint tools clear their respective requirements. Every call also
    appends a capture event and saves the state.

    Args:
        tool_name: Name of the tool that just ran.
        session_id: Session whose state file is updated.
        event: Hook payload; ``tool_input``, ``exit_code``,
            ``tool_response``/``tool_result`` and ``cwd`` are read from it.

    Returns:
        Always 0.
    """
    state = load_state(session_id)
    event = event or {}

    # Extract command + response before the dispatch block so all branches can use them.
    raw_input = event.get("tool_input")
    command = None
    if isinstance(raw_input, dict):
        command = str(raw_input.get("command") or raw_input.get("cmd") or "").strip() or None
    exit_code = event.get("exit_code")
    if not isinstance(exit_code, int):
        exit_code = None
    raw_response = event.get("tool_response") or event.get("tool_result")
    response_text = raw_response.strip() if isinstance(raw_response, str) else None
    # Structured (non-string) responses are captured as their JSON text.
    if raw_response and not response_text and not isinstance(raw_response, str):
        response_text = json.dumps(raw_response, ensure_ascii=False)

    if tool_name in _CODE_TOOLS:
        state["needs_log"] = True
        state["needs_checkpoint"] = True
        state["meaningful_work_done"] = True
        state["edits_since_last_log"] = state.get("edits_since_last_log", 0) + 1
        # Track which files were touched for specific stop messages.
        file_path = _parse_edited_file(event)
        if file_path:
            touched = state.get("files_touched") or []
            if file_path not in touched:
                touched.append(file_path)
            state["files_touched"] = touched[-10:]  # keep last 10

    elif tool_name in _RUN_TOOLS:
        # Detect a successful git commit and fire commit-anchored distillation immediately.
        if command and re.search(r"\bgit\s+commit\b", command):
            commit_hash = _parse_commit_hash(response_text)
            if commit_hash:
                state["last_commit_hash"] = commit_hash
                _trigger_auto_distill(event.get("cwd"), commit_hash=commit_hash)

    elif tool_name in _LOG_TOOLS:
        # Logging satisfies the log requirement and resets the edit tracking.
        state["needs_log"] = False
        state["last_log_at"] = _now()
        state["edits_since_last_log"] = 0
        state["files_touched"] = []

    elif tool_name in _CHECKPOINT_TOOLS:
        state["needs_checkpoint"] = False
        state["last_checkpoint_at"] = _now()

    # Every tool call is captured, whichever branch (if any) matched above.
    event_type = infer_capture_event_type(tool_name=tool_name, command=command)
    _append_capture_event(
        event_type,
        session_id=session_id,
        cwd=event.get("cwd"),
        text=response_text,
        command=command,
        exit_code=exit_code,
        metadata={"tool_name": tool_name},
    )
    save_state(state, session_id)
    return 0
349
+
350
+
351
def _check_no_log_in_transcript(transcript_path: str | None) -> bool:
    """Report whether the latest assistant message contains ``[NO LOG:``.

    Returns ``False`` when the transcript has no assistant message with text.
    """
    latest = _latest_assistant_message(transcript_path)
    if not latest:
        return False
    return "[NO LOG:" in latest
355
+
356
+
357
+ _SUCCESS_OUTPUT = json.dumps({"continue": True, "suppressOutput": True})
358
+
359
+
360
def _persist_session_id(session_id: str) -> None:
    """Save ``session_id`` to the session-id file for later hook calls.

    Creates the parent directory when needed. Any failure is ignored.
    """
    try:
        directory = _SESSION_ID_FILE.parent
        directory.mkdir(parents=True, exist_ok=True)
        _SESSION_ID_FILE.write_text(session_id, encoding="utf-8")
    except Exception:
        # Best effort only: hooks still work without the fallback file.
        return None
367
+
368
+
369
def _read_persisted_session_id() -> str | None:
    """Return the session_id stored in the session-id file, if any.

    Returns ``None`` when the file cannot be read or is blank.
    """
    try:
        stored = _SESSION_ID_FILE.read_text(encoding="utf-8")
    except Exception:
        return None
    stored = stored.strip()
    return stored if stored else None
375
+
376
+
377
def _parse_commit_hash(output: str | None) -> str | None:
    """Pull the abbreviated hash out of `git commit` output such as ``[main abc1234]``.

    Returns ``None`` for empty output or when no bracketed hash is found.
    """
    if not output:
        return None
    found = _GIT_COMMIT_OUTPUT_RE.search(output)
    if found is None:
        return None
    return found.group(1)
383
+
384
+
385
def _trigger_auto_distill(cwd: str | None, commit_hash: str | None = None) -> None:
    """Fire auto-promotion in background — does not block Stop hook.

    Launches ``scripts/auto_promote_capture.py`` as a detached child process
    with its output discarded. This function is also called from
    ``cmd_post_tool_use`` after a detected git commit.

    Args:
        cwd: Working directory used to derive the project id.
        commit_hash: When given, forwarded to the script as ``--commit-hash``.
    """
    try:
        # Imported lazily, inside the try, so an import failure is swallowed too.
        from app.services.capture_core import capture_repo_context
        from app.repo_scope import normalize_project_id
        ctx = capture_repo_context(cwd)
        # Fall back to the directory name when the context has no "repo" value.
        project_id = normalize_project_id(ctx.get("repo") or Path(cwd or ".").name)
        script = ROOT / "scripts" / "auto_promote_capture.py"
        cmd = [
            sys.executable, str(script),
            "--project-id", project_id,
            "--buffer-dir", str(_CAPTURE_BUFFER_ROOT.resolve()),
            "--promoted-root", str(_PROMOTED_ROOT.resolve()),
        ]
        if commit_hash:
            cmd += ["--commit-hash", commit_hash]
        subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,  # detach fully — survives parent exit
            cwd=str(ROOT),
        )
    except Exception:
        pass  # never block the Stop hook
410
+
411
+
412
def cmd_stop(
    stop_hook_active: bool,
    session_id: str | None = None,
    transcript_path: str | None = None,
    cwd: str | None = None,
    event: dict | None = None,
) -> int:
    """Return 0 (allow stop) or 2 (block stop with specific message).

    Args:
        stop_hook_active: When true the stop is allowed immediately, without
            consulting session state.
        session_id: Session whose state decides whether to block.
        transcript_path: Transcript used to find the latest assistant message
            when the event does not supply one.
        cwd: Working directory recorded on the capture event and used for
            auto-distillation.
        event: Raw hook payload; may carry ``last_assistant_message`` or its
            own ``transcript_path``.

    Returns:
        2 when a log or checkpoint is still required (details on stdout, a
        short label on stderr); otherwise 0.
    """
    if stop_hook_active:
        print(_SUCCESS_OUTPUT)
        print("CherryDocs ✓ session ended", file=sys.stderr)
        return 0

    state = load_state(session_id)
    latest_assistant = _assistant_message_from_event(event) or _latest_assistant_message(transcript_path)
    # With no assistant text there is nothing that could declare [NO LOG:,
    # so the transcript is not read a second time.
    no_log_declared = bool(latest_assistant) and "[NO LOG:" in latest_assistant
    if latest_assistant:
        _append_capture_event(
            CaptureEventType.ASSISTANT_OUTPUT,
            session_id=session_id,
            cwd=cwd,
            text=latest_assistant,
            metadata={"source": "claude-hooks"},
        )

    # Labels are collected next to the issues so the stderr summary names
    # exactly the requirements that are actually blocking.
    issues: list[str] = []
    labels: list[str] = []

    if state.get("needs_log") and not no_log_declared:
        files = state.get("files_touched") or []
        edits = state.get("edits_since_last_log", 0)
        files_hint = f": {', '.join(files[-3:])}" if files else ""
        issues.append(
            f"[CHERRYDOCS] Log required — {edits} edit(s) since last log{files_hint}.\n"
            " → mcp__cherry-docs__log_activity(type=..., summary=..., files=[...], reasoning=...)\n"
            " → Or write [NO LOG: <reason>] if nothing worth recording happened."
        )
        labels.append("log")

    if state.get("needs_checkpoint") and state.get("meaningful_work_done"):
        minutes = _minutes_since(state.get("last_checkpoint_at"))
        # Compare against None so a checkpoint taken moments ago still gets a hint.
        age_hint = f" (last checkpoint {minutes:.0f}m ago)" if minutes is not None else ""
        issues.append(
            f"[CHERRYDOCS] Checkpoint required before ending{age_hint}.\n"
            " → mcp__cherry-docs__save_checkpoint(summary=..., attempts=..., decisions=..., next_steps=[...])"
        )
        labels.append("checkpoint")

    if issues:
        print("\n".join(issues))
        print(f"CherryDocs ⚠ blocked: {' + '.join(labels)} required", file=sys.stderr)
        return 2

    _trigger_auto_distill(cwd)
    print(_SUCCESS_OUTPUT)
    edits = state.get("edits_since_last_log", 0)
    status = "no edits" if not edits else f"{edits} edit(s) logged"
    print(f"CherryDocs ✓ {status}", file=sys.stderr)
    return 0
471
+
472
+
473
def cmd_reset(session_id: str | None = None) -> int:
    """Delete the session's state file, if present, and confirm on stdout.

    Returns:
        Always 0.
    """
    target = _state_file(session_id)
    if target.exists():
        target.unlink()
    print("Session state cleared.")
    return 0
480
+
481
+
482
def main() -> int:
    """Dispatch a hook command using the JSON event read from stdin.

    The command is ``sys.argv[1]``: ``session-start``, ``post-tool-use``,
    ``stop`` or ``reset``. Stdin that is empty, not valid JSON, or not a JSON
    object is treated as an empty event.

    Returns:
        The selected command's exit code, or 1 for a missing or unknown
        command.
    """
    if len(sys.argv) < 2:
        print("Usage: state_manager.py <session-start|post-tool-use|stop|reset>", file=sys.stderr)
        return 1

    cmd = sys.argv[1]
    raw_stdin = sys.stdin.read()

    try:
        event = json.loads(raw_stdin) if raw_stdin.strip() else {}
    except json.JSONDecodeError:
        event = {}
    if not isinstance(event, dict):
        # Valid JSON that is not an object (list, string, number) has no fields to read.
        event = {}

    # Coerce to str so a numeric id cannot break the slicing done in _state_file.
    raw_session_id = event.get("session_id")
    payload_session_id = str(raw_session_id) if raw_session_id else None
    session_id: str | None = payload_session_id or _read_persisted_session_id()

    # Persist on every hook that carries a session_id — not just session-start.
    # This ensures Stop can always find the right state file even if it doesn't
    # receive session_id in its own payload.
    if payload_session_id:
        _persist_session_id(payload_session_id)

    if cmd == "session-start":
        return cmd_session_start(session_id, event)

    if cmd == "post-tool-use":
        return cmd_post_tool_use(event.get("tool_name", ""), session_id, event)

    if cmd == "stop":
        return cmd_stop(
            bool(event.get("stop_hook_active", False)),
            session_id,
            event.get("transcript_path"),
            event.get("cwd"),
            event,
        )

    if cmd == "reset":
        return cmd_reset(session_id)

    print(f"Unknown command: {cmd}", file=sys.stderr)
    return 1
523
+
524
+
525
+ if __name__ == "__main__":
526
+ sys.exit(main())
@@ -0,0 +1,121 @@
1
+ #!/usr/bin/env python3
2
+ """Summarize coverage XML and enforce a baseline regression gate."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import xml.etree.ElementTree as ET
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+
13
+ def _load_json(path: str | Path) -> dict[str, Any]:
14
+ return json.loads(Path(path).read_text())
15
+
16
+
17
+ def _load_coverage_xml(path: str | Path) -> dict[str, Any]:
18
+ root = ET.fromstring(Path(path).read_text())
19
+ return {
20
+ "line_rate": round(float(root.attrib.get("line-rate", 0.0)) * 100, 2),
21
+ "branch_rate": round(float(root.attrib.get("branch-rate", 0.0)) * 100, 2),
22
+ "lines_covered": int(root.attrib.get("lines-covered", 0)),
23
+ "lines_valid": int(root.attrib.get("lines-valid", 0)),
24
+ "branches_covered": int(root.attrib.get("branches-covered", 0)),
25
+ "branches_valid": int(root.attrib.get("branches-valid", 0)),
26
+ }
27
+
28
+
29
+ def _write_json(path: str | None, payload: dict[str, Any]) -> None:
30
+ if not path:
31
+ return
32
+ target = Path(path)
33
+ target.parent.mkdir(parents=True, exist_ok=True)
34
+ target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n")
35
+
36
+
37
+ def _compare(
38
+ baseline: dict[str, Any],
39
+ candidate: dict[str, Any],
40
+ *,
41
+ min_line_rate: float,
42
+ max_line_drop: float,
43
+ ) -> tuple[dict[str, Any], list[str]]:
44
+ errors: list[str] = []
45
+ line_drop = round(float(baseline.get("line_rate", 0.0)) - float(candidate.get("line_rate", 0.0)), 2)
46
+ branch_drop = round(float(baseline.get("branch_rate", 0.0)) - float(candidate.get("branch_rate", 0.0)), 2)
47
+
48
+ if float(candidate.get("line_rate", 0.0)) < min_line_rate:
49
+ errors.append(
50
+ f"line coverage {candidate['line_rate']:.2f}% is below minimum {min_line_rate:.2f}%"
51
+ )
52
+ if line_drop > max_line_drop:
53
+ errors.append(
54
+ f"line coverage dropped by {line_drop:.2f} points, exceeding max drop {max_line_drop:.2f}"
55
+ )
56
+
57
+ report = {
58
+ "baseline_line_rate": baseline.get("line_rate"),
59
+ "candidate_line_rate": candidate.get("line_rate"),
60
+ "line_rate_drop": line_drop,
61
+ "baseline_branch_rate": baseline.get("branch_rate"),
62
+ "candidate_branch_rate": candidate.get("branch_rate"),
63
+ "branch_rate_drop": branch_drop,
64
+ "min_line_rate": min_line_rate,
65
+ "max_line_drop": max_line_drop,
66
+ "baseline_lines_covered": baseline.get("lines_covered"),
67
+ "candidate_lines_covered": candidate.get("lines_covered"),
68
+ "baseline_lines_valid": baseline.get("lines_valid"),
69
+ "candidate_lines_valid": candidate.get("lines_valid"),
70
+ "errors": errors,
71
+ }
72
+ return report, errors
73
+
74
+
75
def main() -> int:
    """Summarize a coverage XML file and optionally gate it against a baseline.

    Always prints the candidate summary and, when requested, writes it as
    JSON. With ``--baseline`` it also compares the two, writes the comparison
    report when requested, and prints every violated rule.

    Returns:
        1 when the comparison produced errors, otherwise 0.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--candidate", required=True, help="Coverage XML file to summarize.")
    parser.add_argument("--baseline", help="Baseline coverage summary JSON to compare against.")
    parser.add_argument("--summary-output", help="Where to write the normalized coverage summary JSON.")
    parser.add_argument("--report-output", help="Where to write the comparison report JSON.")
    parser.add_argument("--min-line-rate", type=float, default=70.0)
    parser.add_argument("--max-line-drop", type=float, default=3.0)
    args = parser.parse_args()

    candidate = _load_coverage_xml(args.candidate)
    _write_json(args.summary_output, candidate)

    print(
        "Coverage summary:",
        f"line={candidate['line_rate']:.2f}%",
        f"branch={candidate['branch_rate']:.2f}%",
        f"covered={candidate['lines_covered']}/{candidate['lines_valid']}",
    )

    if not args.baseline:
        return 0

    baseline = _load_json(args.baseline)
    report, errors = _compare(
        baseline,
        candidate,
        min_line_rate=args.min_line_rate,
        max_line_drop=args.max_line_drop,
    )
    _write_json(args.report_output, report)

    # The report holds None for a rate the baseline lacks; _compare already
    # treated that as 0.0, so print it the same way instead of failing on
    # the float format.
    baseline_rate = float(report["baseline_line_rate"] or 0.0)
    candidate_rate = float(report["candidate_line_rate"] or 0.0)
    print(
        "Coverage regression:",
        f"baseline={baseline_rate:.2f}%",
        f"candidate={candidate_rate:.2f}%",
        f"drop={report['line_rate_drop']:.2f}",
    )
    if errors:
        for error in errors:
            print(f"ERROR: {error}")
        return 1
    return 0
118
+
119
+
120
+ if __name__ == "__main__":
121
+ raise SystemExit(main())