codejury 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codejury/__init__.py +8 -0
- codejury/agents/__init__.py +6 -0
- codejury/agents/base.py +21 -0
- codejury/agents/debate.py +188 -0
- codejury/agents/mock.py +38 -0
- codejury/agents/parsing.py +42 -0
- codejury/agents/verifier.py +106 -0
- codejury/assembly.py +76 -0
- codejury/cli.py +196 -0
- codejury/data/capabilities/authentication.yaml +67 -0
- codejury/data/capabilities/authorization.yaml +55 -0
- codejury/data/capabilities/business_logic.yaml +58 -0
- codejury/data/capabilities/crypto.yaml +78 -0
- codejury/data/capabilities/data_protection.yaml +57 -0
- codejury/data/capabilities/dependency_config.yaml +52 -0
- codejury/data/capabilities/error_logging.yaml +49 -0
- codejury/data/capabilities/input_validation.yaml +92 -0
- codejury/data/capabilities/output_encoding.yaml +56 -0
- codejury/data/capabilities/secrets.yaml +51 -0
- codejury/data/capabilities/session.yaml +60 -0
- codejury/data/golden/authn_bcrypt_password.yaml +5 -0
- codejury/data/golden/authn_sha256_password.yaml +5 -0
- codejury/data/golden/sqli_fstring_query.yaml +5 -0
- codejury/data/golden/sqli_parameterized_query.yaml +5 -0
- codejury/data/tasks/audit_diff_debate.yaml +4 -0
- codejury/data/tasks/quick_scan_single.yaml +4 -0
- codejury/domain/__init__.py +5 -0
- codejury/domain/artifact.py +20 -0
- codejury/domain/capability.py +123 -0
- codejury/domain/context.py +26 -0
- codejury/domain/observation.py +104 -0
- codejury/domain/result.py +19 -0
- codejury/evaluation.py +107 -0
- codejury/infrastructure/__init__.py +4 -0
- codejury/infrastructure/json_parse.py +57 -0
- codejury/orchestrators/__init__.py +6 -0
- codejury/orchestrators/base.py +19 -0
- codejury/orchestrators/debate.py +57 -0
- codejury/orchestrators/pipeline.py +32 -0
- codejury/orchestrators/reflexion.py +58 -0
- codejury/orchestrators/single.py +24 -0
- codejury/providers/__init__.py +5 -0
- codejury/providers/anthropic.py +68 -0
- codejury/providers/base.py +42 -0
- codejury/providers/litellm.py +68 -0
- codejury/providers/mock.py +32 -0
- codejury/providers/openai.py +57 -0
- codejury/providers/openai_format.py +30 -0
- codejury/providers/retry.py +48 -0
- codejury/reporting.py +114 -0
- codejury/resources.py +13 -0
- codejury/sources/__init__.py +6 -0
- codejury/sources/base.py +17 -0
- codejury/sources/chunker.py +33 -0
- codejury/sources/diff.py +69 -0
- codejury/sources/function.py +35 -0
- codejury/sources/mock.py +25 -0
- codejury/sources/repo.py +44 -0
- codejury/tasks/__init__.py +6 -0
- codejury/tasks/base.py +55 -0
- codejury/tasks/registry.py +22 -0
- codejury-0.1.0.dist-info/METADATA +110 -0
- codejury-0.1.0.dist-info/RECORD +67 -0
- codejury-0.1.0.dist-info/WHEEL +5 -0
- codejury-0.1.0.dist-info/entry_points.txt +2 -0
- codejury-0.1.0.dist-info/licenses/LICENSE +21 -0
- codejury-0.1.0.dist-info/top_level.txt +1 -0
codejury/evaluation.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Evaluation harness -- measure detection quality against labelled golden cases.
|
|
2
|
+
|
|
3
|
+
A golden case is a code snippet labelled vulnerable or not for one capability.
|
|
4
|
+
``evaluate`` runs the verifier over each case and scores predictions into a
|
|
5
|
+
confusion matrix with precision / recall / accuracy. The metric math is provider
|
|
6
|
+
-agnostic and unit-tested; real numbers need a real provider.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import yaml
|
|
16
|
+
|
|
17
|
+
from codejury.agents.verifier import VerifierAgent
|
|
18
|
+
from codejury.domain.artifact import CodeArtifact
|
|
19
|
+
from codejury.domain.capability import Capability
|
|
20
|
+
from codejury.domain.context import AnalysisContext
|
|
21
|
+
from codejury.providers.base import Provider
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True, kw_only=True)
|
|
25
|
+
class GoldenCase:
|
|
26
|
+
name: str
|
|
27
|
+
capability: str # capability id this case exercises
|
|
28
|
+
vulnerable: bool # the ground-truth label
|
|
29
|
+
code: str
|
|
30
|
+
|
|
31
|
+
@classmethod
|
|
32
|
+
def from_dict(cls, name: str, data: dict[str, Any]) -> GoldenCase:
|
|
33
|
+
return cls(
|
|
34
|
+
name=name,
|
|
35
|
+
capability=data["capability"],
|
|
36
|
+
vulnerable=bool(data["vulnerable"]),
|
|
37
|
+
code=data["code"],
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class Metrics:
    """Running confusion-matrix counts with derived ratios.

    Every ratio falls back to ``0.0`` when its denominator is zero.
    """

    tp: int = 0  # actual positive, predicted positive
    fp: int = 0  # actual negative, predicted positive
    tn: int = 0  # actual negative, predicted negative
    fn: int = 0  # actual positive, predicted negative

    def record(self, *, actual: bool, predicted: bool) -> None:
        """Add one (ground truth, prediction) pair to the matching counter."""
        if predicted:
            if actual:
                self.tp += 1
            else:
                self.fp += 1
        elif actual:
            self.fn += 1
        else:
            self.tn += 1

    @property
    def total(self) -> int:
        """Number of pairs recorded so far."""
        return sum((self.tp, self.fp, self.tn, self.fn))

    @property
    def precision(self) -> float:
        """Share of positive predictions that were correct."""
        flagged = self.tp + self.fp
        if not flagged:
            return 0.0
        return self.tp / flagged

    @property
    def recall(self) -> float:
        """Share of actual positives that were predicted positive."""
        positives = self.tp + self.fn
        if not positives:
            return 0.0
        return self.tp / positives

    @property
    def accuracy(self) -> float:
        """Share of all recorded pairs that were predicted correctly."""
        if not self.total:
            return 0.0
        return (self.tp + self.tn) / self.total
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def load_cases(directory: str | Path) -> list[GoldenCase]:
    """Load every ``*.yaml`` golden case directly inside ``directory``.

    Files are read in sorted path order and each file's stem becomes the case
    name.

    Args:
        directory: Folder to scan (not recursive).

    Returns:
        One ``GoldenCase`` per YAML file.

    Raises:
        ValueError: If a file does not hold a YAML mapping (for example, it is
            empty), so the offending file is named instead of surfacing an
            opaque ``TypeError`` from the subscript in ``from_dict``.
        KeyError: If a mapping lacks a required key.
    """
    cases = []
    for path in sorted(Path(directory).glob("*.yaml")):
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if not isinstance(data, dict):
            raise ValueError(
                f"golden case file {str(path)!r} must contain a YAML mapping, "
                f"got {type(data).__name__}"
            )
        cases.append(GoldenCase.from_dict(path.stem, data))
    return cases
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def evaluate(
    cases: list[GoldenCase],
    capabilities: list[Capability],
    *,
    provider: Provider,
    model: str,
    max_tokens: int = 2048,
) -> Metrics:
    """Run the verifier over each golden case and tally the outcomes.

    A case counts as predicted-vulnerable when any item returned by the
    verifier has a ``status`` attribute equal to ``"VULNERABLE"``.

    Args:
        cases: Labelled cases to score.
        capabilities: Capabilities available for lookup by ``id``.
        provider: Provider handed to the ``VerifierAgent``.
        model: Model name handed to the ``VerifierAgent``.
        max_tokens: Token limit handed to the ``VerifierAgent``.

    Returns:
        The accumulated ``Metrics``.

    Raises:
        ValueError: If a case names a capability id absent from
            ``capabilities``.
    """
    by_id = {cap.id: cap for cap in capabilities}
    verifier = VerifierAgent(provider=provider, model=model, max_tokens=max_tokens)
    tally = Metrics()

    for case in cases:
        matched = by_id.get(case.capability)
        if matched is None:
            raise ValueError(f"golden case {case.name!r} references unknown capability {case.capability!r}")

        # Each case is analysed on its own, against only its own capability.
        snippet = CodeArtifact(kind="file", path=case.name, content=case.code)
        ctx = AnalysisContext(artifact=snippet, capabilities=[matched])

        flagged = False
        for verdict in verifier.run(ctx):
            if getattr(verdict, "status", None) == "VULNERABLE":
                flagged = True
                break
        tally.record(actual=case.vulnerable, predicted=flagged)

    return tally
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Best-effort extraction of a JSON object from model output.
|
|
2
|
+
|
|
3
|
+
Models often wrap JSON in prose or code fences despite instructions. This
|
|
4
|
+
recovers the object with no third-party dependency: try a direct parse, then a
|
|
5
|
+
fenced ```json block, then the first balanced-brace span in the text.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import re
|
|
12
|
+
|
|
13
|
+
# Greedy so a fenced block with nested braces is captured whole.
|
|
14
|
+
# Matches a ``` fence whose "json" tag is optional; group 1 is the brace-delimited
# payload. DOTALL lets `.` cross newlines so a multi-line object still matches.
_FENCE = re.compile(r"```(?:json)?\s*(\{.*\})\s*```", re.DOTALL)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def extract_json_object(text: str) -> dict | None:
    """Return the first JSON object found in `text`, or None if there is none.

    Three strategies are tried in order: parse the whole (stripped) text, parse
    the payload of a fenced block, then scan for a brace-delimited span. If the
    whole text is valid JSON but not an object, None is returned without trying
    the later strategies.
    """
    stripped = text.strip()

    # Strategy 1: the reply is pure JSON.
    try:
        whole = json.loads(stripped)
    except json.JSONDecodeError:
        pass
    else:
        return whole if isinstance(whole, dict) else None

    # Strategy 2: JSON wrapped in a code fence.
    match = _FENCE.search(stripped)
    if match:
        try:
            payload = json.loads(match.group(1))
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict):
            return payload

    # Strategy 3: JSON embedded somewhere in prose.
    return _first_balanced_object(stripped)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _first_balanced_object(text: str) -> dict | None:
|
|
40
|
+
depth = 0
|
|
41
|
+
start = -1
|
|
42
|
+
for i, ch in enumerate(text):
|
|
43
|
+
if ch == "{":
|
|
44
|
+
if depth == 0:
|
|
45
|
+
start = i
|
|
46
|
+
depth += 1
|
|
47
|
+
elif ch == "}" and depth:
|
|
48
|
+
depth -= 1
|
|
49
|
+
if depth == 0:
|
|
50
|
+
try:
|
|
51
|
+
obj = json.loads(text[start : i + 1])
|
|
52
|
+
except json.JSONDecodeError:
|
|
53
|
+
start = -1
|
|
54
|
+
continue
|
|
55
|
+
if isinstance(obj, dict):
|
|
56
|
+
return obj
|
|
57
|
+
return None
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
"""codejury.orchestrators -- strategies for running agents over a context.
|
|
2
|
+
|
|
3
|
+
single / debate / pipeline / reflexion. The strategy is the "any orchestration"
|
|
4
|
+
axis; a task picks one. Each takes the same agents and context and returns an
|
|
5
|
+
AnalysisResult, so they are interchangeable.
|
|
6
|
+
"""
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Orchestrator ABC.
|
|
2
|
+
|
|
3
|
+
An orchestrator decides how agents run over a context -- one pass, an
|
|
4
|
+
adversarial debate, capability-by-capability, etc. Capabilities are read from
|
|
5
|
+
``context.capabilities``, so they are not a separate argument.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
|
|
12
|
+
from codejury.agents.base import Agent
|
|
13
|
+
from codejury.domain.context import AnalysisContext
|
|
14
|
+
from codejury.domain.result import AnalysisResult
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Orchestrator(ABC):
    """Abstract strategy for running a set of agents over one context."""

    @abstractmethod
    def run(self, agents: dict[str, Agent], context: AnalysisContext) -> AnalysisResult:
        """Run `agents` (keyed by name) over `context` and return the result."""
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""DebateOrchestrator -- adversarial Finder -> Challenger -> Judge rounds.
|
|
2
|
+
|
|
3
|
+
Each round the three agents run in turn, with the accumulated history and round
|
|
4
|
+
number threaded into their context. The round's product is the Judge's ruling
|
|
5
|
+
(surviving Findings + dismissed Concessions).
|
|
6
|
+
|
|
7
|
+
Convergence is decided here, not by the Judge: the debate stops when the set of
|
|
8
|
+
surviving finding titles is unchanged from the previous round, or when
|
|
9
|
+
max_rounds is reached. The final result is the last round's ruling.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import dataclasses
|
|
15
|
+
|
|
16
|
+
from codejury.agents.base import Agent
|
|
17
|
+
from codejury.domain.context import AnalysisContext
|
|
18
|
+
from codejury.domain.observation import Finding, Observation
|
|
19
|
+
from codejury.domain.result import AnalysisResult
|
|
20
|
+
from codejury.orchestrators.base import Orchestrator
|
|
21
|
+
|
|
22
|
+
# Keys that must be present in the `agents` mapping, listed in the order the
# agents run within each round.
_REQUIRED_ROLES = ("finder", "challenger", "judge")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DebateOrchestrator(Orchestrator):
    """Run finder -> challenger -> judge rounds until the ruling stabilises."""

    def __init__(self, *, max_rounds: int = 3) -> None:
        """Store the upper bound on the number of debate rounds."""
        self._max_rounds = max_rounds

    def run(self, agents: dict[str, Agent], context: AnalysisContext) -> AnalysisResult:
        """Debate `context` and return the judge's last ruling.

        Each round runs finder, challenger and judge in turn; every agent sees
        the history accumulated so far plus the round number. The loop ends when
        the set of surviving ``Finding`` titles equals the previous round's, or
        after ``max_rounds`` rounds.

        Returns:
            The final ruling. If a required role is missing, or ``max_rounds``
            is below 1, an error-only result is returned. If any agent raises,
            the most recently completed ruling is returned with the error
            recorded.
        """
        missing = [role for role in _REQUIRED_ROLES if role not in agents]
        if missing:
            return AnalysisResult(error=f"debate requires agents: {', '.join(missing)}")
        if self._max_rounds < 1:
            # With no rounds the loop never runs and the result would look like
            # a clean audit with zero findings; report the misconfiguration.
            return AnalysisResult(error=f"debate requires max_rounds >= 1, got {self._max_rounds}")
        finder, challenger, judge = (agents[role] for role in _REQUIRED_ROLES)

        history: list[Observation] = []
        ruling: list[Observation] = []
        previous_survivors: frozenset[str] | None = None

        for round_num in range(1, self._max_rounds + 1):
            try:
                history = history + finder.run(_round_ctx(context, history, round_num))
                history = history + challenger.run(_round_ctx(context, history, round_num))
                ruling = judge.run(_round_ctx(context, history, round_num))
                history = history + ruling
            except Exception as exc:
                # `ruling` still holds the last round that completed (or is
                # empty), so the caller gets partial output plus the error.
                return AnalysisResult(observations=ruling, error=f"debate round {round_num} failed: {exc}")

            survivors = frozenset(o.title for o in ruling if isinstance(o, Finding))
            if survivors == previous_survivors:
                break
            previous_survivors = survivors

        return AnalysisResult(observations=ruling)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _round_ctx(context: AnalysisContext, history: list[Observation], round_num: int) -> AnalysisContext:
|
|
57
|
+
return dataclasses.replace(context, history=history, round_num=round_num)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""PipelineOrchestrator -- capability-by-capability full sweep.
|
|
2
|
+
|
|
3
|
+
Each capability is checked in its own single-capability context, so a failure or
|
|
4
|
+
bad reply on one capability does not abort the rest; errors are collected and
|
|
5
|
+
reported together. This is the robust choice for auditing a whole repository
|
|
6
|
+
across every dimension, where the single orchestrator would stop at the first
|
|
7
|
+
agent error.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import dataclasses
|
|
13
|
+
|
|
14
|
+
from codejury.agents.base import Agent
|
|
15
|
+
from codejury.domain.context import AnalysisContext
|
|
16
|
+
from codejury.domain.observation import Observation
|
|
17
|
+
from codejury.domain.result import AnalysisResult
|
|
18
|
+
from codejury.orchestrators.base import Orchestrator
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class PipelineOrchestrator(Orchestrator):
    """Run every agent once per capability, collecting failures as it goes."""

    def run(self, agents: dict[str, Agent], context: AnalysisContext) -> AnalysisResult:
        """Sweep `context.capabilities` one at a time.

        Each agent runs against a copy of the context narrowed to a single
        capability. An exception from one agent is recorded as
        ``"<capability id>/<agent name>: <error>"`` and the sweep continues.

        Returns:
            All collected observations, plus the recorded failures joined with
            ``"; "`` (or ``None`` when there were none).
        """
        collected: list[Observation] = []
        failures: list[str] = []

        for capability in context.capabilities:
            scoped = dataclasses.replace(context, capabilities=[capability])
            for role, agent in agents.items():
                try:
                    collected.extend(agent.run(scoped))
                except Exception as exc:
                    failures.append(f"{capability.id}/{role}: {exc}")

        combined = "; ".join(failures)
        return AnalysisResult(observations=collected, error=combined or None)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""ReflexionOrchestrator -- actor -> critic -> actor self-revision loop.
|
|
2
|
+
|
|
3
|
+
A lighter cousin of debate: an actor produces findings, a critic pushes back,
|
|
4
|
+
and the actor revises with the critique in its history. There is no judge; the
|
|
5
|
+
result is the actor's final output. Iterates until the actor's findings are
|
|
6
|
+
stable or max_iterations is reached.
|
|
7
|
+
|
|
8
|
+
The actor and critic are ordinary agents (e.g. Finder as actor, Challenger as
|
|
9
|
+
critic), so no reflexion-specific agent is needed.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import dataclasses
|
|
15
|
+
|
|
16
|
+
from codejury.agents.base import Agent
|
|
17
|
+
from codejury.domain.context import AnalysisContext
|
|
18
|
+
from codejury.domain.observation import Finding, Observation
|
|
19
|
+
from codejury.domain.result import AnalysisResult
|
|
20
|
+
from codejury.orchestrators.base import Orchestrator
|
|
21
|
+
|
|
22
|
+
# Keys that must be present in the `agents` mapping for a reflexion run.
_REQUIRED_ROLES = ("actor", "critic")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ReflexionOrchestrator(Orchestrator):
    """Alternate actor and critic until the actor's findings stop changing."""

    def __init__(self, *, max_iterations: int = 2) -> None:
        """Store the upper bound on the number of actor passes."""
        self._max_iterations = max_iterations

    def run(self, agents: dict[str, Agent], context: AnalysisContext) -> AnalysisResult:
        """Run the self-revision loop and return the actor's last output.

        Each iteration runs the actor with the accumulated history. The loop
        stops when the set of ``Finding`` titles matches the previous
        iteration's, or after ``max_iterations`` passes. The critic runs between
        actor passes, but not after the final one.

        Returns:
            The actor's final observations. If a required role is missing, or
            ``max_iterations`` is below 1, an error-only result is returned. If
            an agent raises, the actor's most recent output is returned with the
            error recorded.
        """
        missing = [role for role in _REQUIRED_ROLES if role not in agents]
        if missing:
            return AnalysisResult(error=f"reflexion requires agents: {', '.join(missing)}")
        if self._max_iterations < 1:
            # With no iterations the actor never runs and the empty result would
            # read as "nothing found"; report the misconfiguration instead.
            return AnalysisResult(error=f"reflexion requires max_iterations >= 1, got {self._max_iterations}")
        actor, critic = agents["actor"], agents["critic"]

        history: list[Observation] = []
        actor_output: list[Observation] = []
        previous_findings: frozenset[str] | None = None

        for iteration in range(1, self._max_iterations + 1):
            try:
                actor_output = actor.run(_iter_ctx(context, history, iteration))
                history = history + actor_output

                findings = frozenset(o.title for o in actor_output if isinstance(o, Finding))
                if findings == previous_findings:
                    break
                previous_findings = findings

                # A critique after the last actor pass would never be read.
                if iteration < self._max_iterations:
                    history = history + critic.run(_iter_ctx(context, history, iteration))
            except Exception as exc:
                return AnalysisResult(observations=actor_output, error=f"reflexion iteration {iteration} failed: {exc}")

        return AnalysisResult(observations=actor_output)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _iter_ctx(context: AnalysisContext, history: list[Observation], iteration: int) -> AnalysisContext:
|
|
58
|
+
return dataclasses.replace(context, history=history, round_num=iteration)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""SingleOrchestrator -- the baseline: run each agent once and collect verdicts.
|
|
2
|
+
|
|
3
|
+
The cheapest strategy. If an agent raises (e.g. a provider failure), the run
|
|
4
|
+
stops and the partial observations are returned with the error recorded, rather
|
|
5
|
+
than crashing the caller.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from codejury.agents.base import Agent
|
|
11
|
+
from codejury.domain.context import AnalysisContext
|
|
12
|
+
from codejury.domain.result import AnalysisResult
|
|
13
|
+
from codejury.orchestrators.base import Orchestrator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SingleOrchestrator(Orchestrator):
    """Run each agent once, in mapping order, over the unmodified context."""

    def run(self, agents: dict[str, Agent], context: AnalysisContext) -> AnalysisResult:
        """Collect every agent's observations, stopping at the first failure.

        Returns:
            All observations when every agent succeeds. If an agent raises, the
            observations gathered so far are returned with the error recorded
            and the remaining agents are not run.
        """
        collected = []
        for role, agent in agents.items():
            try:
                collected += agent.run(context)
            except Exception as exc:
                return AnalysisResult(observations=collected, error=f"agent {role!r} failed: {exc}")
        return AnalysisResult(observations=collected)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""AnthropicProvider -- Provider backed by the Anthropic Messages API.
|
|
2
|
+
|
|
3
|
+
When ``cache`` is set, the system prompt is marked with an ephemeral
|
|
4
|
+
cache_control block; the capability checklist is large and reused across
|
|
5
|
+
artifacts, so caching it is the high-value target.
|
|
6
|
+
|
|
7
|
+
The Anthropic client is injectable so the mapping and caching logic can be
|
|
8
|
+
tested without the SDK or an API key. Constructed lazily otherwise, reading
|
|
9
|
+
ANTHROPIC_API_KEY from the environment.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from codejury.providers.base import CompletionResult, Message, Provider
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AnthropicProvider(Provider):
    """Provider that sends completions through an Anthropic client."""

    def __init__(self, *, api_key: str | None = None, client: Any | None = None) -> None:
        """Store the optional API key and an optional pre-built client."""
        self._client = client
        self._api_key = api_key

    def _get_client(self) -> Any:
        """Return the client, creating it from the SDK on first use.

        Raises:
            RuntimeError: If no client was injected and the ``anthropic``
                package cannot be imported.
        """
        if self._client is not None:
            return self._client
        try:
            import anthropic
        except ImportError as exc:
            raise RuntimeError(
                "anthropic SDK not installed; run: pip install 'codejury[anthropic]'"
            ) from exc
        self._client = anthropic.Anthropic(api_key=self._api_key)
        return self._client

    def complete(
        self,
        *,
        system: str,
        messages: list[Message],
        model: str,
        max_tokens: int,
        cache: bool = False,
    ) -> CompletionResult:
        """Send one ``messages.create`` request and return the reply text.

        When `cache` is true and `system` is non-empty, the system prompt is
        sent as a single text block tagged with ephemeral ``cache_control``;
        otherwise it is sent as the plain string.
        """
        if cache and system:
            system_param: Any = [{"type": "text", "text": system, "cache_control": {"type": "ephemeral"}}]
        else:
            system_param = system

        create = self._get_client().messages.create
        payload = [{"role": m.role, "content": m.content} for m in messages]
        response = create(
            model=model,
            max_tokens=max_tokens,
            system=system_param,
            messages=payload,
        )
        return CompletionResult(text=_extract_text(response))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _extract_text(response: Any) -> str:
|
|
58
|
+
content = getattr(response, "content", None)
|
|
59
|
+
if not isinstance(content, list):
|
|
60
|
+
return str(content or "")
|
|
61
|
+
parts: list[str] = []
|
|
62
|
+
for block in content:
|
|
63
|
+
text = getattr(block, "text", None)
|
|
64
|
+
if text is None and isinstance(block, dict):
|
|
65
|
+
text = block.get("text")
|
|
66
|
+
if text:
|
|
67
|
+
parts.append(str(text))
|
|
68
|
+
return "".join(parts)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Provider ABC and its typed input/output.
|
|
2
|
+
|
|
3
|
+
Deliberately minimal: one synchronous, non-streaming ``complete``. Streaming and
|
|
4
|
+
tool-calling are intentionally left out until a concrete need appears, so the
|
|
5
|
+
interface does not over-commit early.
|
|
6
|
+
|
|
7
|
+
``cache`` is a portable hint, not a guarantee: Anthropic supports prompt caching
|
|
8
|
+
natively, OpenAI does not, LiteLLM depends on the backend. Each provider decides
|
|
9
|
+
how to map the hint onto its own implementation.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
18
|
+
# The speaker values a Message may carry.
Role = Literal["user", "assistant"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True, kw_only=True)
|
|
22
|
+
class Message:
|
|
23
|
+
role: Role
|
|
24
|
+
content: str
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True, kw_only=True)
|
|
28
|
+
class CompletionResult:
|
|
29
|
+
text: str
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Provider(ABC):
    """Abstract model backend exposing a single ``complete`` call."""

    @abstractmethod
    def complete(
        self,
        *,
        system: str,
        messages: list[Message],
        model: str,
        max_tokens: int,
        cache: bool = False,
    ) -> CompletionResult:
        """Produce one completion for `messages` under the `system` prompt.

        All arguments are keyword-only. `cache` defaults to False; how it is
        honoured is left to each implementation.
        """
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""LiteLLMProvider -- Provider backed by LiteLLM, reaching many backends.
|
|
2
|
+
|
|
3
|
+
LiteLLM speaks the OpenAI chat shape, so the system prompt is sent as the first
|
|
4
|
+
message. ``cache`` is accepted but not applied here: prompt caching under LiteLLM
|
|
5
|
+
is backend-specific, so it stays a no-op until a backend-aware mapping is needed.
|
|
6
|
+
|
|
7
|
+
The completion callable is injectable so the mapping can be tested without the
|
|
8
|
+
SDK or an API key.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Any, Callable
|
|
14
|
+
|
|
15
|
+
from codejury.providers.base import CompletionResult, Message, Provider
|
|
16
|
+
from codejury.providers.openai_format import choice_text
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LiteLLMProvider(Provider):
    """Provider that sends completions through a LiteLLM-style callable."""

    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        temperature: float = 0.2,
        completion: Callable[..., Any] | None = None,
    ) -> None:
        """Store the connection settings and an optional completion callable."""
        self._completion = completion
        self._temperature = temperature
        self._api_base = api_base
        self._api_key = api_key

    def _completion_fn(self) -> Callable[..., Any]:
        """Return the completion callable, importing ``litellm`` on first use.

        Raises:
            RuntimeError: If no callable was injected and ``litellm`` cannot be
                imported.
        """
        if self._completion is not None:
            return self._completion
        try:
            import litellm
        except ImportError as exc:
            raise RuntimeError("litellm not installed; run: pip install 'codejury[litellm]'") from exc
        self._completion = litellm.completion
        return self._completion

    def complete(
        self,
        *,
        system: str,
        messages: list[Message],
        model: str,
        max_tokens: int,
        cache: bool = False,
    ) -> CompletionResult:
        """Send one chat request and return the first choice's text.

        A non-empty `system` prompt becomes the leading ``system`` message.
        ``api_key`` and ``api_base`` are forwarded only when set. `cache` is
        accepted but not used.
        """
        chat: list[dict] = [{"role": "system", "content": system}] if system else []
        chat.extend({"role": m.role, "content": m.content} for m in messages)

        request: dict[str, Any] = dict(
            model=model,
            messages=chat,
            max_tokens=max_tokens,
            temperature=self._temperature,
        )
        for key, value in (("api_key", self._api_key), ("api_base", self._api_base)):
            if value:
                request[key] = value

        response = self._completion_fn()(**request)
        return CompletionResult(text=choice_text(response))
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""MockProvider -- a Provider that returns canned text instead of calling a model.
|
|
2
|
+
|
|
3
|
+
Used for the end-to-end dry-run and for tests, so the pipeline can run with no
|
|
4
|
+
API key and deterministic output. It holds no parsing or audit logic: it returns
|
|
5
|
+
whatever text it was configured with and records each call for inspection.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from codejury.providers.base import CompletionResult, Message, Provider
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MockProvider(Provider):
    """Provider that replays canned text and records every call it receives."""

    def __init__(self, *, responses: list[str] | None = None, default: str = "") -> None:
        """Queue `responses` and set the fallback text.

        Queued responses are handed out one per call, in order; once they run
        out, every further call gets `default`. The caller's list is copied,
        not consumed.
        """
        self.calls: list[dict] = []
        self._default = default
        self._responses = list(responses or [])

    def complete(
        self,
        *,
        system: str,
        messages: list[Message],
        model: str,
        max_tokens: int,
        cache: bool = False,
    ) -> CompletionResult:
        """Record the call and return the next queued text (or the default).

        Only `system`, `messages` and `model` are recorded in ``calls``.
        """
        self.calls.append({"system": system, "messages": messages, "model": model})
        if self._responses:
            text = self._responses.pop(0)
        else:
            text = self._default
        return CompletionResult(text=text)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""OpenAIProvider -- Provider backed by the OpenAI Chat Completions API.
|
|
2
|
+
|
|
3
|
+
The system prompt is sent as the first chat message. ``cache`` is accepted but
|
|
4
|
+
not applied: OpenAI caches long prompts automatically server-side, with no
|
|
5
|
+
request parameter to set.
|
|
6
|
+
|
|
7
|
+
The client is injectable so the mapping can be tested without the SDK or a key.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from codejury.providers.base import CompletionResult, Message, Provider
|
|
15
|
+
from codejury.providers.openai_format import choice_text
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OpenAIProvider(Provider):
    """Provider that sends completions through an OpenAI chat client."""

    def __init__(self, *, api_key: str | None = None, base_url: str | None = None, client: Any | None = None) -> None:
        """Store the connection settings and an optional pre-built client."""
        self._client = client
        self._base_url = base_url
        self._api_key = api_key

    def _get_client(self) -> Any:
        """Return the client, creating it from the SDK on first use.

        ``api_key`` and ``base_url`` are passed to the constructor only when
        set.

        Raises:
            RuntimeError: If no client was injected and ``openai`` cannot be
                imported.
        """
        if self._client is not None:
            return self._client
        try:
            import openai
        except ImportError as exc:
            raise RuntimeError("openai not installed; run: pip install 'codejury[openai]'") from exc
        settings = (("api_key", self._api_key), ("base_url", self._base_url))
        options: dict[str, Any] = {key: value for key, value in settings if value}
        self._client = openai.OpenAI(**options)
        return self._client

    def complete(
        self,
        *,
        system: str,
        messages: list[Message],
        model: str,
        max_tokens: int,
        cache: bool = False,
    ) -> CompletionResult:
        """Send one chat-completions request and return the first choice's text.

        A non-empty `system` prompt becomes the leading ``system`` message.
        `cache` is accepted but not used.
        """
        chat: list[dict] = [{"role": "system", "content": system}] if system else []
        chat.extend({"role": m.role, "content": m.content} for m in messages)

        response = self._get_client().chat.completions.create(
            model=model,
            messages=chat,
            max_tokens=max_tokens,
        )
        return CompletionResult(text=choice_text(response))
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Text extraction for the OpenAI chat-completions response shape.
|
|
2
|
+
|
|
3
|
+
Shared by OpenAIProvider and LiteLLMProvider, since LiteLLM returns the same
|
|
4
|
+
``choices[0].message.content`` structure (a string, or a list of content blocks).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def choice_text(response: Any) -> str:
    """Return the text of the first choice in a chat-completions response.

    The response, the choice, the message and each content block may be either
    an object (read by attribute) or a dict (read by key), so SDK objects and
    plain decoded JSON are handled the same way.

    Args:
        response: A value exposing ``choices``; anything else yields ``""``.

    Returns:
        The message content when it is a string; the concatenated block texts
        when it is a list (string blocks are used as-is, other blocks
        contribute their ``text`` unless it is missing or ``None``); otherwise
        the stringified content, with a missing or falsy value giving ``""``.
    """
    choices = _get_field(response, "choices") or []
    if not choices:
        return ""

    first = choices[0]
    # Some shapes put `content` directly on the choice; fall back to the choice
    # itself when it has no `message`.
    message = _get_field(first, "message", first)
    content = _get_field(message, "content")

    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
                continue
            text = _get_field(block, "text")
            # Skip absent text rather than emitting the literal "None".
            if text is not None:
                parts.append(str(text))
        return "".join(parts)
    return str(content or "")


def _get_field(obj: Any, name: str, default: Any = None) -> Any:
    """Read `name` from a dict key or an object attribute, else `default`."""
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)
|