pr-context-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,225 @@
1
+ """Git history context for changed files.
2
+
3
+ Fetches the last N commit messages per file via git log, and optionally
4
+ resolves the most recent merged PRs that touched the same files via the
5
+ GitHub REST API.
6
+
7
+ Degrades gracefully on shallow clones (workflow uses fetch-depth: 50) or
8
+ when git/network is unavailable — callers receive limited_history=True and
9
+ an empty commit list rather than an exception. See docs/design-decisions.md
10
+ for the deliberate shallow-clone tradeoff.
11
+ """
12
+ import logging
13
+ import re
14
+ import subprocess
15
+ from dataclasses import dataclass, field
16
+ from pathlib import Path
17
+
18
+ import requests
19
+
20
+ from src.github_api import GITHUB_API_URL
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ _GIT_LOG_LIMIT = 5
25
+ _PR_MERGE_SCAN = 30 # merge commits to scan when searching for PR numbers
26
+ _MAX_PRS = 3
27
+ _REQUEST_TIMEOUT = 10
28
+
29
+
30
+ @dataclass
31
+ class CommitRecord:
32
+ """A single commit's abbreviated hash and subject line."""
33
+
34
+ sha: str
35
+ message: str
36
+
37
+
38
+ @dataclass
39
+ class FileHistory:
40
+ """Recent commit history for one changed file."""
41
+
42
+ file_path: str
43
+ recent_commits: list[CommitRecord] = field(default_factory=list)
44
+ limited_history: bool = False
45
+
46
+
47
+ @dataclass
48
+ class RecentPR:
49
+ """A recently merged PR that touched one of the changed files."""
50
+
51
+ number: int
52
+ title: str
53
+ body_first_line: str
54
+
55
+
56
+ def get_file_histories(
57
+ file_paths: list[str],
58
+ repo_root: str = ".",
59
+ max_commits: int = _GIT_LOG_LIMIT,
60
+ ) -> dict[str, FileHistory]:
61
+ """Return recent commit history for each file path.
62
+
63
+ Args:
64
+ file_paths: Repo-relative paths of changed files.
65
+ repo_root: Root of the git repository.
66
+ max_commits: Maximum commits to fetch per file.
67
+
68
+ Returns:
69
+ Mapping of file_path -> FileHistory. Files with no discoverable
70
+ history still get an entry (empty commits, limited_history=True).
71
+ """
72
+ root = Path(repo_root).resolve()
73
+ return {path: _fetch_file_history(path, root, max_commits) for path in file_paths}
74
+
75
+
76
+ def get_recent_merged_prs(
77
+ file_paths: list[str],
78
+ repo: str,
79
+ github_token: str,
80
+ repo_root: str = ".",
81
+ max_prs: int = _MAX_PRS,
82
+ ) -> list[RecentPR]:
83
+ """Find the most recent merged PRs that touched any of the given files.
84
+
85
+ Uses git merge-commit messages to locate PR numbers, then fetches details
86
+ via the GitHub API. Returns an empty list when git or the API is
87
+ unavailable rather than raising.
88
+
89
+ Args:
90
+ file_paths: Repo-relative paths of changed files.
91
+ repo: Repository in "owner/name" form.
92
+ github_token: GitHub token used for API requests.
93
+ repo_root: Root of the git repository.
94
+ max_prs: Maximum number of PRs to return.
95
+
96
+ Returns:
97
+ List of RecentPR, most recent first.
98
+ """
99
+ if not file_paths:
100
+ return []
101
+
102
+ root = Path(repo_root).resolve()
103
+ pr_numbers = _find_pr_numbers_from_merges(file_paths, root)
104
+ if not pr_numbers:
105
+ return []
106
+
107
+ headers = {
108
+ "Authorization": f"token {github_token}",
109
+ "Accept": "application/vnd.github.v3+json",
110
+ }
111
+
112
+ prs: list[RecentPR] = []
113
+ for pr_number in pr_numbers:
114
+ if len(prs) >= max_prs:
115
+ break
116
+ pr = _fetch_pr_details(repo, pr_number, headers)
117
+ if pr is not None:
118
+ prs.append(pr)
119
+
120
+ return prs
121
+
122
+
123
+ # ---------------------------------------------------------------------------
124
+ # Private helpers
125
+ # ---------------------------------------------------------------------------
126
+
127
+
128
+ def _fetch_file_history(path: str, root: Path, max_commits: int) -> FileHistory:
129
+ """Run git log for a single file and return its FileHistory."""
130
+ try:
131
+ result = subprocess.run(
132
+ [
133
+ "git",
134
+ "log",
135
+ "--follow",
136
+ f"--max-count={max_commits}",
137
+ "--format=%H %s",
138
+ "--",
139
+ path,
140
+ ],
141
+ cwd=root,
142
+ capture_output=True,
143
+ text=True,
144
+ check=True,
145
+ timeout=10,
146
+ )
147
+ except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired) as exc:
148
+ logger.warning("git log failed for %s: %s", path, exc)
149
+ return FileHistory(file_path=path, limited_history=True)
150
+
151
+ commits: list[CommitRecord] = []
152
+ for line in result.stdout.strip().splitlines():
153
+ if not line.strip():
154
+ continue
155
+ sha, _, message = line.partition(" ")
156
+ commits.append(CommitRecord(sha=sha[:8], message=message))
157
+
158
+ # git log silently returns fewer commits on shallow clones without any
159
+ # stderr output; hitting the limit is the only reliable signal.
160
+ limited = len(commits) == max_commits
161
+
162
+ return FileHistory(file_path=path, recent_commits=commits, limited_history=limited)
163
+
164
+
165
+ def _find_pr_numbers_from_merges(file_paths: list[str], root: Path) -> list[int]:
166
+ """Extract PR numbers from merge commits that touched any of the given files."""
167
+ try:
168
+ result = subprocess.run(
169
+ [
170
+ "git",
171
+ "log",
172
+ "--merges",
173
+ f"--max-count={_PR_MERGE_SCAN}",
174
+ "--format=%s",
175
+ "--",
176
+ *file_paths,
177
+ ],
178
+ cwd=root,
179
+ capture_output=True,
180
+ text=True,
181
+ check=True,
182
+ timeout=15,
183
+ )
184
+ except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired) as exc:
185
+ logger.warning("git log --merges failed: %s", exc)
186
+ return []
187
+
188
+ # Matches:
189
+ # "Merge pull request #N from ..." (GitHub standard merge)
190
+ # "Merge PR #N ..." (shorthand)
191
+ # "feat: something (#N)" (GitHub squash merge)
192
+ pattern = re.compile(r"(?:(?:pull\s+request|PR)\s+#(\d+)|\(#(\d+)\))", re.IGNORECASE)
193
+ seen: set[int] = set()
194
+ numbers: list[int] = []
195
+
196
+ for line in result.stdout.strip().splitlines():
197
+ match = pattern.search(line)
198
+ if match:
199
+ num = int(match.group(1) or match.group(2))
200
+ if num not in seen:
201
+ seen.add(num)
202
+ numbers.append(num)
203
+
204
+ return numbers
205
+
206
+
207
+ def _fetch_pr_details(repo: str, pr_number: int, headers: dict[str, str]) -> RecentPR | None:
208
+ """Fetch PR title and body from GitHub API; returns None when unavailable."""
209
+ url = f"{GITHUB_API_URL}/repos/{repo}/pulls/{pr_number}"
210
+ try:
211
+ resp = requests.get(url, headers=headers, timeout=_REQUEST_TIMEOUT)
212
+ resp.raise_for_status()
213
+ except requests.RequestException as exc:
214
+ logger.warning("GitHub API request failed for PR #%d: %s", pr_number, exc)
215
+ return None
216
+
217
+ data = resp.json()
218
+ if not data.get("merged_at"):
219
+ return None # ignore open or closed-unmerged PRs
220
+
221
+ title = (data.get("title") or "").strip()
222
+ body = data.get("body") or ""
223
+ body_first_line = body.split("\n")[0].strip()[:200]
224
+
225
+ return RecentPR(number=pr_number, title=title, body_first_line=body_first_line)
src/fixes/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Opt-in fix suggestion generation with confidence gating (Milestone 8)."""
@@ -0,0 +1,60 @@
1
+ """Confidence gating: decides which fix suggestions become collapsed code blocks vs prose.
2
+
3
+ High and medium confidence fixes (with a non-None patch) are formatted as collapsed
4
+ <details> blocks with a fenced code snippet. Low confidence (or missing patch) is
5
+ rendered as a prose warning only.
6
+
7
+ Note: GitHub's ```suggestion fence only works in line-level review comments, not in
8
+ general PR body comments. We use language-inferred fences so the patch renders as
9
+ readable code regardless of comment type.
10
+ """
11
+ from src.analyzers.diff_parser import detect_language
12
+ from src.fixes.fix_generator import FixSuggestion
13
+
14
+
15
+ def _lang_from_path(path: str) -> str:
16
+ return detect_language(path)
17
+
18
+ _CONFIDENCE_ICONS = {
19
+ "high": "🔴",
20
+ "medium": "🟡",
21
+ "low": "⚠️",
22
+ }
23
+
24
+
25
+ def is_block_eligible(suggestion: FixSuggestion) -> bool:
26
+ """True when confidence is high or medium AND a patch is present.
27
+
28
+ Low confidence suggestions must never produce a suggestion block even if
29
+ the LLM accidentally emitted patch text — the parser already nulls the patch
30
+ on low confidence, but this gate adds a second layer of enforcement.
31
+ """
32
+ return suggestion.confidence in ("high", "medium") and suggestion.patch is not None
33
+
34
+
35
+ def format_suggestion_block(suggestion: FixSuggestion) -> str:
36
+ """Render a high/medium confidence fix as a collapsed <details> block with a code patch."""
37
+ icon = _CONFIDENCE_ICONS.get(suggestion.confidence, "💡")
38
+ flag_label = suggestion.flag.flag
39
+ location = f"`{suggestion.flag.file}:{suggestion.flag.line}`"
40
+ lang = _lang_from_path(suggestion.flag.file)
41
+
42
+ return (
43
+ f"<details>\n"
44
+ f"<summary>{icon} <strong>{suggestion.confidence} confidence</strong>"
45
+ f" — {flag_label} in {location}</summary>\n\n"
46
+ f"**Rationale:** {suggestion.rationale}\n\n"
47
+ f"```{lang}\n{suggestion.patch}\n```\n\n"
48
+ f"</details>\n"
49
+ )
50
+
51
+
52
+ def format_prose_note(suggestion: FixSuggestion) -> str:
53
+ """Render a low-confidence fix (or missing patch) as a prose-only warning."""
54
+ icon = _CONFIDENCE_ICONS.get(suggestion.confidence, "⚠️")
55
+ flag_label = suggestion.flag.flag
56
+ location = f"`{suggestion.flag.file}:{suggestion.flag.line}`"
57
+ return (
58
+ f"> {icon} **{suggestion.confidence}** — {flag_label} in {location}: "
59
+ f"{suggestion.rationale}\n"
60
+ )
@@ -0,0 +1,152 @@
1
+ """Generate fix suggestions for located risk flags via a separate LLM call.
2
+
3
+ Each call targets a single RiskFlag with a concrete line number and asks the LLM
4
+ for a minimal replacement patch plus a confidence self-assessment. Low-confidence
5
+ responses intentionally produce no patch — a wrong fix is worse than no fix.
6
+ """
7
+ import logging
8
+ from dataclasses import dataclass
9
+
10
+ from src.analyzers.diff_parser import FileChange
11
+ from src.analyzers.risk_scorer import RiskFlag
12
+ from src.briefing.prompt_templates import FIX_SYSTEM_PROMPT
13
+ from src.llm.base import LLMProvider
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ _MAX_CONTEXT_LINES = 40
18
+
19
+
20
+ @dataclass
21
+ class FixSuggestion:
22
+ """A suggested fix for a specific risk flag."""
23
+
24
+ flag: RiskFlag
25
+ patch: str | None # replacement code; None when confidence is low or generation failed
26
+ rationale: str
27
+ confidence: str # "high" | "medium" | "low"
28
+
29
+
30
+ def generate_fixes(
31
+ provider: LLMProvider,
32
+ flags: list[RiskFlag],
33
+ changes: list[FileChange],
34
+ max_fixes: int = 3,
35
+ ) -> tuple[list[FixSuggestion], int]:
36
+ """Generate fix suggestions for up to max_fixes eligible flags.
37
+
38
+ Only flags with a non-null line number are fix-eligible. Returns a tuple of
39
+ (suggestions, extra_count) where extra_count is how many eligible flags were
40
+ skipped due to the cap.
41
+
42
+ Args:
43
+ provider: LLM provider instance for generation calls.
44
+ flags: All risk flags from the current PR.
45
+ changes: Parsed file changes (used to extract surrounding code context).
46
+ max_fixes: Maximum number of fix suggestions to generate (default 3).
47
+
48
+ Returns:
49
+ Tuple of (list of FixSuggestion, number of eligible flags beyond the cap).
50
+ """
51
+ eligible = [f for f in flags if f.line is not None]
52
+ capped = eligible[:max_fixes]
53
+ extra_count = len(eligible) - len(capped)
54
+
55
+ suggestions: list[FixSuggestion] = []
56
+ for flag in capped:
57
+ suggestion = _generate_single_fix(provider, flag, changes)
58
+ suggestions.append(suggestion)
59
+
60
+ return suggestions, extra_count
61
+
62
+
63
+ def _generate_single_fix(
64
+ provider: LLMProvider,
65
+ flag: RiskFlag,
66
+ changes: list[FileChange],
67
+ ) -> FixSuggestion:
68
+ """Make one LLM call to generate a fix for a single located flag."""
69
+ context = _get_flag_context(flag, changes)
70
+ prompt = _build_fix_prompt(flag, context)
71
+
72
+ try:
73
+ response = provider.generate(prompt)
74
+ except Exception as exc:
75
+ logger.warning("Fix generation failed for %s:%s: %s", flag.file, flag.line, exc)
76
+ return FixSuggestion(
77
+ flag=flag,
78
+ patch=None,
79
+ rationale="Fix generation failed — see briefing for details.",
80
+ confidence="low",
81
+ )
82
+
83
+ return _parse_fix_response(flag, response)
84
+
85
+
86
+ def _get_flag_context(flag: RiskFlag, changes: list[FileChange]) -> str:
87
+ """Extract the diff hunk containing the flagged line as context.
88
+
89
+ Checks both old-file and new-file line ranges so both modifies_auth
90
+ (new-file lines) and deletes_public_api (old-file lines) are handled.
91
+ Falls back to the flag snippet if no matching hunk is found.
92
+ """
93
+ for change in changes:
94
+ if change.path != flag.file:
95
+ continue
96
+ for hunk in change.hunks:
97
+ new_end = hunk.new_start + max(hunk.new_count, 1)
98
+ old_end = hunk.old_start + max(hunk.old_count, 1)
99
+ line = flag.line or 0
100
+ if hunk.new_start <= line < new_end or hunk.old_start <= line < old_end:
101
+ return "\n".join(hunk.lines[:_MAX_CONTEXT_LINES])
102
+ return flag.snippet
103
+
104
+
105
+ def _build_fix_prompt(flag: RiskFlag, context: str) -> str:
106
+ """Assemble the full prompt: system instructions + flag context."""
107
+ user_section = (
108
+ f"Flag type: {flag.flag}\n"
109
+ f"File: {flag.file}\n"
110
+ f"Line: {flag.line}\n"
111
+ f"Flagged snippet: {flag.snippet}\n\n"
112
+ f"Surrounding diff context:\n```\n{context}\n```"
113
+ )
114
+ return f"{FIX_SYSTEM_PROMPT}\n\n---\n\n{user_section}"
115
+
116
+
117
+ def _parse_fix_response(flag: RiskFlag, response: str) -> FixSuggestion:
118
+ """Parse structured fix response (CONFIDENCE / RATIONALE / PATCH) from LLM."""
119
+ confidence = "low"
120
+ rationale = ""
121
+ patch_lines: list[str] = []
122
+ in_patch = False
123
+
124
+ for line in response.strip().splitlines():
125
+ stripped = line.strip()
126
+ if stripped.upper().startswith("CONFIDENCE:"):
127
+ raw = stripped.split(":", 1)[1].strip().lower()
128
+ if raw in ("high", "medium", "low"):
129
+ confidence = raw
130
+ elif stripped.upper().startswith("RATIONALE:"):
131
+ rationale = stripped.split(":", 1)[1].strip()
132
+ elif stripped.upper().startswith("PATCH:"):
133
+ in_patch = True
134
+ elif in_patch:
135
+ patch_lines.append(line)
136
+
137
+ patch: str | None = None
138
+ if patch_lines:
139
+ raw_patch = "\n".join(patch_lines).strip()
140
+ if raw_patch and raw_patch.upper() != "NO_PATCH":
141
+ patch = raw_patch
142
+
143
+ # Enforce hard rule: low confidence must never carry a patch block.
144
+ if confidence == "low":
145
+ patch = None
146
+
147
+ return FixSuggestion(
148
+ flag=flag,
149
+ patch=patch,
150
+ rationale=rationale or flag.snippet,
151
+ confidence=confidence,
152
+ )
@@ -0,0 +1,3 @@
1
+ """GitHub API access: fetch pull request diffs and post review comments."""
2
+
3
+ GITHUB_API_URL = "https://api.github.com"
@@ -0,0 +1,95 @@
1
+ """GitHub REST access for a PR: fetch its unified diff and post a comment.
2
+
3
+ format_fix_section() renders Milestone 8 fix suggestions as collapsed <details>
4
+ blocks (high/medium confidence) or prose warnings (low confidence) suitable for
5
+ appending to the main briefing comment body.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from typing import TYPE_CHECKING
11
+
12
+ import requests
13
+ from github import Auth, Github
14
+
15
+ from src.github_api import GITHUB_API_URL
16
+
17
+ if TYPE_CHECKING:
18
+ from src.fixes.fix_generator import FixSuggestion
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ _API_VERSION = "2022-11-28"
23
+
24
+
25
+ def fetch_pr_diff(repo: str, pr_number: int, github_token: str | None) -> str:
26
+ """Fetch the unified diff of a pull request.
27
+
28
+ `repo` is in `owner/name` form. Returns the raw unified-diff text.
29
+ Omits the Authorization header when `github_token` is None (anonymous, public repos only).
30
+ """
31
+ url = f"{GITHUB_API_URL}/repos/{repo}/pulls/{pr_number}"
32
+ headers: dict[str, str] = {
33
+ "Accept": "application/vnd.github.diff",
34
+ "X-GitHub-Api-Version": _API_VERSION,
35
+ }
36
+ if github_token:
37
+ headers["Authorization"] = f"Bearer {github_token}"
38
+ logger.info("Fetching diff for %s PR #%d", repo, pr_number)
39
+ try:
40
+ response = requests.get(url, headers=headers, timeout=30)
41
+ except requests.RequestException as exc:
42
+ raise RuntimeError(f"Network error fetching PR diff: {exc}") from exc
43
+ if not response.ok:
44
+ logger.error("GitHub API error %d: %s", response.status_code, response.text[:300])
45
+ response.raise_for_status()
46
+ return response.text
47
+
48
+
49
+ def post_pr_comment(repo: str, pr_number: int, body: str, github_token: str) -> None:
50
+ """Post `body` as a general (issue-level) comment on the pull request.
51
+
52
+ `repo` is in `owner/name` form.
53
+ """
54
+ logger.info("Posting comment to %s PR #%d", repo, pr_number)
55
+ gh = Github(auth=Auth.Token(github_token))
56
+ pull_request = gh.get_repo(repo).get_pull(pr_number)
57
+ pull_request.create_issue_comment(body)
58
+
59
+
60
+ def format_fix_section(suggestions: list[FixSuggestion], extra_count: int = 0) -> str:
61
+ """Render fix suggestions as a markdown section for appending to the briefing comment.
62
+
63
+ High/medium confidence suggestions with a patch become collapsed <details> blocks.
64
+ Low confidence (or missing patch) suggestions become prose warnings only.
65
+ If extra_count > 0, a trailing note indicates how many eligible flags were skipped.
66
+
67
+ Args:
68
+ suggestions: List of FixSuggestion objects (may mix confidence levels).
69
+ extra_count: Number of fix-eligible flags beyond the 3-suggestion cap.
70
+
71
+ Returns:
72
+ Markdown string ready to append after the briefing's closing `---` line.
73
+ Returns empty string if suggestions is empty.
74
+ """
75
+ # Deferred to avoid pulling src.fixes into comment_poster at module load time;
76
+ # this module is imported by cli.py regardless of whether fixes are enabled.
77
+ from src.fixes.confidence import format_prose_note, format_suggestion_block, is_block_eligible
78
+
79
+ if not suggestions:
80
+ return ""
81
+
82
+ parts: list[str] = ["\n\n### 💡 Fix Suggestions\n\n"]
83
+
84
+ for suggestion in suggestions:
85
+ if is_block_eligible(suggestion):
86
+ parts.append(format_suggestion_block(suggestion))
87
+ parts.append("\n")
88
+ else:
89
+ parts.append(format_prose_note(suggestion))
90
+
91
+ if extra_count > 0:
92
+ noun = "issue" if extra_count == 1 else "issues"
93
+ parts.append(f"\n_{extra_count} more {noun} detected — see Risk Flags above._\n")
94
+
95
+ return "".join(parts)
src/llm/__init__.py ADDED
@@ -0,0 +1,106 @@
1
+ """LLM provider abstraction: a provider-agnostic interface and its implementations.
2
+
3
+ Exports FailoverProvider — wraps an ordered list of providers and tries each in
4
+ sequence, recording which one succeeded and why others were skipped. This is the
5
+ runtime payoff for the ADR-0 provider abstraction built in Milestone 2.
6
+ """
7
+ import logging
8
+ from dataclasses import dataclass
9
+
10
+ from src.llm.base import LLMProvider
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class _ProviderAttempt:
17
+ name: str
18
+ was_rate_limited: bool
19
+
20
+
21
+ class FailoverProvider(LLMProvider):
22
+ """Try providers in order; fall back on rate-limit or any error.
23
+
24
+ After a successful generate() call:
25
+ - provider_used: name of the provider that responded
26
+ - skipped: records of providers that were tried and failed
27
+
28
+ Call attribution() after generate() to get a footer-friendly string such as
29
+ "groq" or "gemini (groq rate-limited)".
30
+ """
31
+
32
+ def __init__(self, providers: list[tuple[str, LLMProvider]]) -> None:
33
+ """
34
+ Args:
35
+ providers: Ordered list of (name, provider) pairs. First is tried first.
36
+ """
37
+ if not providers:
38
+ raise ValueError("FailoverProvider requires at least one provider")
39
+ self._providers = providers
40
+ self.provider_used: str | None = None
41
+ self.skipped: list[_ProviderAttempt] = []
42
+
43
+ def generate(self, prompt: str) -> str:
44
+ """Try each provider in order; return the first successful response."""
45
+ self.provider_used = None
46
+ self.skipped = []
47
+ last_exc: Exception | None = None
48
+
49
+ for name, provider in self._providers:
50
+ try:
51
+ result = provider.generate(prompt)
52
+ self.provider_used = name
53
+ logger.info("Provider %s succeeded", name)
54
+ return result
55
+ except (TypeError, AttributeError):
56
+ raise
57
+ except Exception as exc:
58
+ is_rate_limited = _is_rate_limit_error(exc)
59
+ self.skipped.append(_ProviderAttempt(name=name, was_rate_limited=is_rate_limited))
60
+ if is_rate_limited:
61
+ logger.warning("Provider %s rate-limited; trying next", name)
62
+ else:
63
+ logger.warning("Provider %s failed (%s); trying next", name, exc)
64
+ last_exc = exc
65
+
66
+ raise RuntimeError(
67
+ f"All {len(self._providers)} provider(s) failed. Last error: {last_exc}"
68
+ ) from last_exc
69
+
70
+ def attribution(self) -> str:
71
+ """One-line attribution for the PR comment footer.
72
+
73
+ Must be called after generate(). Raises RuntimeError if generate() has not
74
+ been called yet.
75
+
76
+ Examples:
77
+ "groq"
78
+ "gemini (groq rate-limited)"
79
+ "gemini (groq failed)"
80
+ "all providers failed"
81
+ """
82
+ if self.provider_used is None and not self.skipped:
83
+ raise RuntimeError("attribution() called before generate()")
84
+ if self.provider_used is None:
85
+ return "all providers failed"
86
+ if not self.skipped:
87
+ return self.provider_used
88
+
89
+ rate_limited = [s.name for s in self.skipped if s.was_rate_limited]
90
+ errors = [s.name for s in self.skipped if not s.was_rate_limited]
91
+ reasons: list[str] = []
92
+ if rate_limited:
93
+ reasons.append(f"{', '.join(rate_limited)} rate-limited")
94
+ if errors:
95
+ reasons.append(f"{', '.join(errors)} failed")
96
+ return f"{self.provider_used} ({'; '.join(reasons)})"
97
+
98
+
99
+ def _is_rate_limit_error(exc: Exception) -> bool:
100
+ """Return True if exc represents a rate-limit or quota exhaustion error."""
101
+ msg = str(exc).lower()
102
+ keywords = ("rate limit", "rate_limit", "quota", "429", "resource exhausted", "too many requests")
103
+ if any(kw in msg for kw in keywords):
104
+ return True
105
+ status = getattr(exc, "status_code", None) or getattr(exc, "code", None)
106
+ return status == 429
@@ -0,0 +1,32 @@
1
+ """Anthropic Claude-backed LLMProvider implementation (model: claude-sonnet-4-6)."""
2
+ import logging
3
+
4
+ import anthropic
5
+
6
+ from src.llm.base import LLMProvider
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ _MODEL = "claude-sonnet-4-6"
11
+ _MAX_TOKENS = 1024
12
+
13
+
14
+ class AnthropicProvider(LLMProvider):
15
+ """Calls the Anthropic Messages API to satisfy the LLMProvider contract."""
16
+
17
+ def __init__(self, api_key: str) -> None:
18
+ """Build an Anthropic client from an explicit API key."""
19
+ self._client = anthropic.Anthropic(api_key=api_key)
20
+
21
+ def generate(self, prompt: str) -> str:
22
+ """Send `prompt` as a single user message and return the completion text."""
23
+ logger.info("Requesting Anthropic completion (model=%s)", _MODEL)
24
+ message = self._client.messages.create(
25
+ model=_MODEL,
26
+ max_tokens=_MAX_TOKENS,
27
+ messages=[{"role": "user", "content": prompt}],
28
+ )
29
+ if not message.content or not message.content[0].text:
30
+ raise RuntimeError("Anthropic returned an empty completion")
31
+ text = message.content[0].text
32
+ return text
src/llm/base.py ADDED
@@ -0,0 +1,11 @@
1
+ """Abstract LLM provider contract — the single interface all providers implement."""
2
+ from abc import ABC, abstractmethod
3
+
4
+
5
+ class LLMProvider(ABC):
6
+ """Provider-agnostic LLM interface. One method; no provider types leak out."""
7
+
8
+ @abstractmethod
9
+ def generate(self, prompt: str) -> str:
10
+ """Return the model's text completion for the given prompt."""
11
+ ...