agentrec 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrec/__init__.py +94 -0
- agentrec/__main__.py +3 -0
- agentrec/capture.py +46 -0
- agentrec/cli.py +173 -0
- agentrec/comparators.py +271 -0
- agentrec/keying.py +133 -0
- agentrec/migration.py +506 -0
- agentrec/providers/__init__.py +248 -0
- agentrec/providers/anthropic.py +160 -0
- agentrec/providers/base.py +152 -0
- agentrec/providers/openai.py +159 -0
- agentrec/report.py +424 -0
- agentrec/session.py +160 -0
- agentrec/store.py +338 -0
- agentrec/transport.py +271 -0
- agentrec-0.2.0.dist-info/METADATA +269 -0
- agentrec-0.2.0.dist-info/RECORD +21 -0
- agentrec-0.2.0.dist-info/WHEEL +4 -0
- agentrec-0.2.0.dist-info/entry_points.txt +2 -0
- agentrec-0.2.0.dist-info/licenses/LICENSE +21 -0
- agentrec-0.2.0.dist-info/licenses/NOTICE +24 -0
agentrec/__init__.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError, version as _version
|
|
2
|
+
|
|
3
|
+
from .capture import CapturedChunk, CapturedInteraction, CapturedRequest
|
|
4
|
+
from .comparators import (
|
|
5
|
+
Comparator,
|
|
6
|
+
ComparisonResult,
|
|
7
|
+
EmbeddingComparator,
|
|
8
|
+
ExactMatchComparator,
|
|
9
|
+
FuzzyComparator,
|
|
10
|
+
JudgeComparator,
|
|
11
|
+
build_comparators,
|
|
12
|
+
)
|
|
13
|
+
from .keying import Fingerprint, default_key, fingerprint, fingerprint_of
|
|
14
|
+
from .migration import (
|
|
15
|
+
CategoryBreakdown,
|
|
16
|
+
MigrationReport,
|
|
17
|
+
RowResult,
|
|
18
|
+
TokenTotals,
|
|
19
|
+
annotate_corpus,
|
|
20
|
+
migration_id_for,
|
|
21
|
+
run_migration,
|
|
22
|
+
)
|
|
23
|
+
from .providers import (
|
|
24
|
+
Conversation,
|
|
25
|
+
DecodedResponse,
|
|
26
|
+
DecodeError,
|
|
27
|
+
MissingAPIKeyError,
|
|
28
|
+
ProviderAdapter,
|
|
29
|
+
UnsupportedRequestError,
|
|
30
|
+
conversation_of,
|
|
31
|
+
decode_interaction,
|
|
32
|
+
)
|
|
33
|
+
from .report import render_console, render_html, render_markdown
|
|
34
|
+
from .session import DynamicTransport, async_client, cassette
|
|
35
|
+
from .store import FileStore, InMemoryStore, InteractionStore
|
|
36
|
+
from .transport import AutoTransport, RecordingTransport, ReplayTransport
|
|
37
|
+
|
|
38
|
+
# Resolve the version from the installed distribution's metadata; when the
# distribution cannot be found, fall back to a fixed development placeholder.
try:
    __version__ = _version("agentrec")
except PackageNotFoundError:  # running from a source tree without an install
    __version__ = "0.0.0.dev0"
|
|
42
|
+
|
|
43
|
+
# Public API of the package, grouped by the submodule each name comes from.
__all__ = [
    "__version__",
    # Data
    "CapturedChunk",
    "CapturedInteraction",
    "CapturedRequest",
    # Keying
    "Fingerprint",
    "default_key",
    "fingerprint",
    "fingerprint_of",
    # Stores
    "FileStore",
    "InMemoryStore",
    "InteractionStore",
    # Low-level transports
    "AutoTransport",
    "RecordingTransport",
    "ReplayTransport",
    # High-level facade
    "DynamicTransport",
    "async_client",
    "cassette",
    # Providers
    "Conversation",
    "DecodedResponse",
    "DecodeError",
    "MissingAPIKeyError",
    "ProviderAdapter",
    "UnsupportedRequestError",
    "conversation_of",
    "decode_interaction",
    # Comparators
    "Comparator",
    "ComparisonResult",
    "ExactMatchComparator",
    "FuzzyComparator",
    "EmbeddingComparator",
    "JudgeComparator",
    "build_comparators",
    # Migration report
    "CategoryBreakdown",
    "MigrationReport",
    "RowResult",
    "TokenTotals",
    "annotate_corpus",
    "migration_id_for",
    "run_migration",
    # Report rendering (imported from .report)
    "render_console",
    "render_html",
    "render_markdown",
]
|
agentrec/__main__.py
ADDED
agentrec/capture.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Storage-agnostic data structures for a captured HTTP interaction.
|
|
3
|
+
|
|
4
|
+
Storage backends (the JSON-cassette FileStore today, others later) consume
|
|
5
|
+
CapturedInteraction without the transports knowing about them.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any, Dict, List, Tuple
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class CapturedChunk:
    """One raw byte frame from an HTTP response body."""

    # The frame's bytes, kept exactly as captured.
    data: bytes
    # Seconds elapsed since the first chunk arrived. Stored so replay can
    # optionally simulate realistic pacing; instant replay ignores it.
    timestamp_offset: float = 0.0
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class CapturedRequest:
    """The request half of a captured exchange: method, URL, headers, content."""

    # HTTP method and target URL, both held as plain strings.
    method: str
    url: str
    # Raw (bytes, bytes) header pairs — provider-neutral, no decoding here.
    headers: List[Tuple[bytes, bytes]]
    # Request content as bytes (presumably the request body — confirm against
    # the recording transport).
    content: bytes
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class CapturedInteraction:
    """Complete record of one request/response exchange."""

    # The request that produced this response.
    request: CapturedRequest
    # Status code and raw (bytes, bytes) header pairs of the response.
    response_status: int
    response_headers: List[Tuple[bytes, bytes]]
    # extensions minus transport-specific keys like "network_stream"
    response_extensions: dict
    # Response body frames; defaults to a fresh empty list per instance.
    chunks: List[CapturedChunk] = field(default_factory=list)
    # Provenance for the corpus: provider, model, semantic_key, recorded_at.
    # Free-form so new fields (e.g. a future migration-report needs) drop in
    # without a schema change. Populated by RecordingTransport; empty when an
    # interaction is hand-built or loaded from a pre-metadata cassette.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
agentrec/cli.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command-line interface: ``python -m agentrec <command>`` (or ``agentrec ...``).
|
|
3
|
+
|
|
4
|
+
migrate Run the corpus against a target model (records new responses into
|
|
5
|
+
the corpus) and write a Markdown/HTML migration report.
|
|
6
|
+
report Re-render the report fully offline from already-recorded cassettes;
|
|
7
|
+
only the offline comparators (exact, fuzzy) are allowed.
|
|
8
|
+
annotate Backfill human-readable summary blocks and fingerprint metadata
|
|
9
|
+
into existing cassettes.
|
|
10
|
+
"""
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import asyncio
|
|
15
|
+
import sys
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import List, Optional
|
|
18
|
+
|
|
19
|
+
from .comparators import OFFLINE_COMPARATOR_NAMES, build_comparators
|
|
20
|
+
from .migration import annotate_corpus, run_migration
|
|
21
|
+
from .report import default_report_basename, render_console, render_html, render_markdown
|
|
22
|
+
from .store import FileStore
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _add_report_args(parser: argparse.ArgumentParser, *, default_compare: str) -> None:
|
|
26
|
+
parser.add_argument("--corpus", default="corpus", help="corpus directory (default: corpus)")
|
|
27
|
+
parser.add_argument("--target", required=True, help="target model id, e.g. claude-haiku-4-5")
|
|
28
|
+
parser.add_argument(
|
|
29
|
+
"--compare",
|
|
30
|
+
default=default_compare,
|
|
31
|
+
help=f"comma-separated comparators or 'all' (default: {default_compare})",
|
|
32
|
+
)
|
|
33
|
+
parser.add_argument(
|
|
34
|
+
"--target-provider",
|
|
35
|
+
default=None,
|
|
36
|
+
help="override the provider inferred from the target model id",
|
|
37
|
+
)
|
|
38
|
+
parser.add_argument("--judge-model", default="claude-opus-4-8", help="model for the judge comparator")
|
|
39
|
+
parser.add_argument(
|
|
40
|
+
"--embedding-model", default="text-embedding-3-small", help="model for the embedding comparator"
|
|
41
|
+
)
|
|
42
|
+
parser.add_argument("--fuzzy-threshold", type=float, default=0.8)
|
|
43
|
+
parser.add_argument("--embedding-threshold", type=float, default=0.8)
|
|
44
|
+
parser.add_argument(
|
|
45
|
+
"--max-tokens", type=int, default=4096,
|
|
46
|
+
help="max_tokens for target requests when the baseline did not set one",
|
|
47
|
+
)
|
|
48
|
+
parser.add_argument("--filter", default=None, help="only baselines whose id contains this substring")
|
|
49
|
+
parser.add_argument(
|
|
50
|
+
"--concurrency", type=int, default=8,
|
|
51
|
+
help="rows scored in parallel (default: 8)",
|
|
52
|
+
)
|
|
53
|
+
parser.add_argument(
|
|
54
|
+
"--format", choices=("md", "html", "both"), default="both", help="report format(s) to write"
|
|
55
|
+
)
|
|
56
|
+
parser.add_argument(
|
|
57
|
+
"--out", default=None,
|
|
58
|
+
help="output base path (extension added per format; default: migration-report__<target>__<timestamp>)",
|
|
59
|
+
)
|
|
60
|
+
parser.add_argument(
|
|
61
|
+
"--strict", action="store_true",
|
|
62
|
+
help="exit with code 1 if any comparison failed or errored (CI gate)",
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _parse(argv: Optional[List[str]]) -> argparse.Namespace:
    """Build the ``agentrec`` argument parser and parse *argv*.

    A subcommand is mandatory; the chosen one is stored on the namespace as
    ``command``. ``migrate`` and ``report`` take the shared report options,
    ``annotate`` only takes ``--corpus``.
    """
    root = argparse.ArgumentParser(
        prog="agentrec", description="Record/replay LLM corpus tooling and migration reports."
    )
    commands = root.add_subparsers(dest="command", required=True)

    # The two report-producing commands accept an identical option set.
    report_commands = (
        ("migrate", "run corpus prompts against a target model and write a migration report"),
        ("report", "re-render a migration report offline from recorded cassettes"),
    )
    for command_name, summary in report_commands:
        _add_report_args(
            commands.add_parser(command_name, help=summary),
            default_compare="exact,fuzzy",
        )

    annotate_parser = commands.add_parser(
        "annotate", help="backfill summary blocks and metadata into existing cassettes"
    )
    annotate_parser.add_argument(
        "--corpus", default="corpus", help="corpus directory (default: corpus)"
    )

    return root.parse_args(argv)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _write_reports(args: argparse.Namespace, report) -> List[Path]:
    """Render *report* in the requested format(s) and write it to disk.

    Args:
        args: Parsed CLI arguments; reads ``out``, ``target`` and ``format``.
        report: The migration report handed to the renderers.

    Returns:
        The paths written, Markdown first and HTML second when both are
        requested.
    """
    base = args.out or default_report_basename(args.target)
    base_path = Path(base)
    # Only a trailing report extension is dropped, so ``--out r.md`` together
    # with ``--format both`` still yields r.md and r.html.
    if base_path.suffix.lower() in (".md", ".html"):
        base_path = base_path.with_suffix("")

    renderers = (
        ("md", ".md", render_markdown),
        ("html", ".html", render_html),
    )
    written: List[Path] = []
    for format_name, extension, render in renderers:
        if args.format not in (format_name, "both"):
            continue
        # Append the extension instead of calling with_suffix(): with_suffix
        # replaces everything after the last dot, which would turn a base such
        # as "run-gpt-4.1" into "run-gpt-4.md".
        path = base_path.with_name(base_path.name + extension)
        path.write_text(render(report), encoding="utf-8")
        written.append(path)
    return written
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
async def _run_report_command(args: argparse.Namespace, *, offline: bool) -> int:
    """Run the ``migrate`` / ``report`` command and return its exit code.

    Args:
        args: Parsed CLI arguments. In offline mode ``args.compare`` is
            rewritten to the normalised comparator list.
        offline: When true, only the names in ``OFFLINE_COMPARATOR_NAMES`` are
            accepted, and the flag is forwarded to ``run_migration``.

    Returns:
        ``2`` when an offline run requests a comparator outside the offline
        set, ``1`` when ``--strict`` is set and ``report.all_passed`` is false,
        otherwise ``0``.
    """
    if offline:
        if args.compare.strip().lower() == "all":
            requested = list(OFFLINE_COMPARATOR_NAMES)
        else:
            # Lowercase the names exactly as build_comparators does; without
            # this, "Exact" is rejected here as a live comparator even though
            # `migrate --compare Exact` accepts it.
            requested = [
                name.strip().lower() for name in args.compare.split(",") if name.strip()
            ]
        online = [name for name in requested if name not in OFFLINE_COMPARATOR_NAMES]
        if online:
            print(
                f"report (offline) supports only {', '.join(OFFLINE_COMPARATOR_NAMES)}; "
                f"drop: {', '.join(online)} (use `agentrec migrate` for live comparators)",
                file=sys.stderr,
            )
            return 2
        args.compare = ",".join(requested)

    comparators = build_comparators(
        args.compare,
        judge_model=args.judge_model,
        embedding_model=args.embedding_model,
        fuzzy_threshold=args.fuzzy_threshold,
        embedding_threshold=args.embedding_threshold,
    )
    store = FileStore(args.corpus)
    report = await run_migration(
        store,
        args.target,
        comparators,
        target_provider=args.target_provider,
        offline=offline,
        max_tokens_default=args.max_tokens,
        filter_substr=args.filter,
        concurrency=args.concurrency,
    )
    written = _write_reports(args, report)
    print(render_console(report))
    for path in written:
        print(f"Report written: {path}")
    # --strict turns any failed or errored comparison into a failing exit code.
    if args.strict and not report.all_passed:
        return 1
    return 0
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
async def _run_annotate(args: argparse.Namespace) -> int:
    """Annotate the cassettes under ``args.corpus`` and report how many.

    Always returns ``0``.
    """
    store = FileStore(args.corpus)
    annotated = await annotate_corpus(store)
    count = len(annotated)
    print(f"Annotated {count} cassette(s) in {args.corpus}")
    return 0
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: parse *argv*, run the chosen command, return its exit code.

    ``ValueError`` and ``LookupError`` raised while running a command are
    printed to stderr as ``error: ...`` and mapped to exit code ``2``. An
    unrecognised command also yields ``2``.
    """
    args = _parse(argv)

    # Each entry builds the coroutine lazily, so only the selected command's
    # coroutine is ever created.
    dispatch = {
        "migrate": lambda: _run_report_command(args, offline=False),
        "report": lambda: _run_report_command(args, offline=True),
        "annotate": lambda: _run_annotate(args),
    }
    start = dispatch.get(args.command)
    if start is None:
        return 2

    try:
        return asyncio.run(start())
    except (ValueError, LookupError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 2
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
agentrec/comparators.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Comparators score a baseline response against a target-model response.
|
|
3
|
+
|
|
4
|
+
The migration runner evaluates *all* selected comparators per prompt in one
|
|
5
|
+
pass, so a single ``migrate`` run can report exact-match, fuzzy similarity,
|
|
6
|
+
embedding similarity and judge verdicts side by side.
|
|
7
|
+
|
|
8
|
+
* ``exact`` — normalized string equality. The right metric for
|
|
9
|
+
classification-style outputs ("positive" vs "Positive ").
|
|
10
|
+
* ``fuzzy`` — ``difflib.SequenceMatcher`` ratio; offline, dependency-free.
|
|
11
|
+
* ``embedding`` — cosine similarity of OpenAI embeddings (live API call).
|
|
12
|
+
* ``judge`` — an LLM scores semantic equivalence (live API call).
|
|
13
|
+
|
|
14
|
+
A comparator failure (missing API key, malformed judge reply) degrades to an
|
|
15
|
+
errored :class:`ComparisonResult` on that row — it never crashes the run.
|
|
16
|
+
"""
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import difflib
|
|
20
|
+
import json
|
|
21
|
+
import math
|
|
22
|
+
import re
|
|
23
|
+
from abc import ABC, abstractmethod
|
|
24
|
+
from dataclasses import dataclass
|
|
25
|
+
from typing import Dict, List, Optional, Sequence
|
|
26
|
+
|
|
27
|
+
import httpx
|
|
28
|
+
|
|
29
|
+
from .providers import Conversation, DecodedResponse, adapter_for_model, adapter_for_provider
|
|
30
|
+
|
|
31
|
+
# Comparators that run without a live API call.
OFFLINE_COMPARATOR_NAMES = ("exact", "fuzzy")
# Every comparator name build_comparators() accepts, and what "all" expands to.
ALL_COMPARATOR_NAMES = ("exact", "fuzzy", "embedding", "judge")
|
|
33
|
+
|
|
34
|
+
_WHITESPACE = re.compile(r"\s+")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _normalize(text: str) -> str:
|
|
38
|
+
return _WHITESPACE.sub(" ", text).strip().casefold()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _clip(text: str, limit: int = 6000) -> str:
|
|
42
|
+
if len(text) <= limit:
|
|
43
|
+
return text
|
|
44
|
+
return text[:limit] + " …[truncated]"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True)
class ComparisonResult:
    """Immutable outcome of one comparator run on a baseline/target pair."""

    # Name of the comparator that produced this result (e.g. "exact").
    comparator: str
    score: float  # 0.0–1.0
    passed: Optional[bool]  # None when pass/fail is not meaningful
    # Human-readable explanation of the score.
    detail: str = ""
    error: bool = False  # True when the comparator itself failed
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Comparator(ABC):
    """Scores how well *target* preserves the behaviour of *baseline*.

    Concrete comparators set ``name`` and implement :meth:`compare`.
    """

    # Short identifier for the comparator; the concrete classes in this module
    # assign it at class level ("exact", "fuzzy", "embedding", "judge").
    name: str

    @abstractmethod
    async def compare(
        self, prompt: str, baseline: DecodedResponse, target: DecodedResponse
    ) -> ComparisonResult:
        """Score *target* against *baseline* for the given *prompt*.

        Returns a :class:`ComparisonResult` describing the outcome.
        """
        ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ExactMatchComparator(Comparator):
    """Passes only when the two response texts are equal after normalisation."""

    name = "exact"

    async def compare(self, prompt, baseline, target) -> ComparisonResult:
        """Return a 1.0/pass result on a normalised match, else 0.0/fail.

        *prompt* is not used by this comparator.
        """
        if _normalize(baseline.text) == _normalize(target.text):
            return ComparisonResult(
                comparator=self.name,
                score=1.0,
                passed=True,
                detail="normalized texts match",
            )
        return ComparisonResult(
            comparator=self.name,
            score=0.0,
            passed=False,
            detail="normalized texts differ",
        )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class FuzzyComparator(Comparator):
    """Scores similarity with ``difflib.SequenceMatcher`` on normalised texts."""

    name = "fuzzy"

    def __init__(self, threshold: float = 0.8) -> None:
        """Store the minimum ratio at which a comparison counts as passed."""
        self._threshold = threshold

    async def compare(self, prompt, baseline, target) -> ComparisonResult:
        """Return the similarity ratio; passes when it reaches the threshold.

        *prompt* is not used by this comparator.
        """
        matcher = difflib.SequenceMatcher(
            a=_normalize(baseline.text),
            b=_normalize(target.text),
        )
        score = matcher.ratio()
        meets_threshold = score >= self._threshold
        return ComparisonResult(
            comparator=self.name,
            score=score,
            passed=meets_threshold,
            detail=f"sequence similarity {score:.2f} (threshold {self._threshold})",
        )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the cosine similarity of vectors *a* and *b*.

    Args:
        a: First vector.
        b: Second vector; must have the same length as *a*.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when either vector has zero
        magnitude (which includes two empty vectors).

    Raises:
        ValueError: If the vectors differ in length.
    """
    # zip() stops at the shorter input while the norms below use every
    # component, so a length mismatch would give a meaningless score rather
    # than an error. Reject it explicitly.
    if len(a) != len(b):
        raise ValueError(f"vector length mismatch: {len(a)} != {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0:
        # Direction is undefined for a zero vector; report no similarity.
        return 0.0
    return dot / norm
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class _HttpComparator(Comparator):
    """Shared plumbing for comparators that call an API via httpx."""

    def __init__(self, http: Optional[httpx.AsyncClient] = None) -> None:
        """Keep the optional caller-supplied client used for every request."""
        self._http = http

    async def _post(self, url: str, headers: Dict[str, str], body: dict) -> httpx.Response:
        """POST *body* as JSON to *url* and return the response.

        Uses the injected client when there is one; otherwise opens a
        short-lived client with a 60 second timeout for this single request.
        """
        shared_client = self._http
        if shared_client is None:
            async with httpx.AsyncClient(timeout=60.0) as one_shot:
                return await one_shot.post(url, headers=headers, json=body)
        return await shared_client.post(url, headers=headers, json=body)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class EmbeddingComparator(_HttpComparator):
    """Scores responses by the cosine similarity of their embeddings."""

    name = "embedding"

    embeddings_url = "https://api.openai.com/v1/embeddings"

    def __init__(
        self,
        http: Optional[httpx.AsyncClient] = None,
        *,
        model: str = "text-embedding-3-small",
        threshold: float = 0.8,
    ) -> None:
        """Store the embedding model id and the pass threshold."""
        super().__init__(http)
        self._model = model
        self._threshold = threshold

    async def compare(self, prompt, baseline, target) -> ComparisonResult:
        """Embed both response texts in one request and compare them.

        The score is the cosine similarity clamped to ``[0.0, 1.0]``; the
        result passes when it reaches the threshold. *prompt* is not used.
        A non-success HTTP status raises via ``raise_for_status()``.
        """
        # Resolve the key before anything else so a missing key fails first.
        api_key = adapter_for_provider("openai").api_key()
        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self._model,
            "input": [_clip(baseline.text), _clip(target.text)],
        }
        response = await self._post(self.embeddings_url, request_headers, payload)
        response.raise_for_status()

        # Order the entries by their "index" field so position 0 is the
        # baseline text and position 1 is the target text.
        ordered = sorted(response.json()["data"], key=lambda item: item["index"])
        similarity = cosine_similarity(ordered[0]["embedding"], ordered[1]["embedding"])
        score = max(0.0, min(1.0, similarity))
        return ComparisonResult(
            comparator=self.name,
            score=score,
            passed=score >= self._threshold,
            detail=f"cosine similarity {score:.2f} via {self._model} (threshold {self._threshold})",
        )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# System prompt for the judge model. It asks for a bare JSON verdict with the
# keys "equivalent", "score" and "reason", which JudgeComparator reads back.
_JUDGE_SYSTEM = (
    "You compare two AI assistant responses to the same prompt and judge whether "
    "the candidate response is an acceptable substitute for the baseline response. "
    "Judge semantic equivalence of the substantive content, not style or length. "
    'Reply with ONLY a JSON object: {"equivalent": true|false, "score": 0.0-1.0, '
    '"reason": "<one sentence>"}'
)

# User-message template filled via str.format() with the placeholders
# {prompt}, {baseline} and {target}.
_JUDGE_TEMPLATE = """<prompt>
{prompt}
</prompt>

<baseline_response>
{baseline}
</baseline_response>

<candidate_response>
{target}
</candidate_response>"""
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _first_json_object(text: str) -> dict:
|
|
176
|
+
decoder = json.JSONDecoder()
|
|
177
|
+
for start in range(len(text)):
|
|
178
|
+
if text[start] != "{":
|
|
179
|
+
continue
|
|
180
|
+
try:
|
|
181
|
+
obj, _ = decoder.raw_decode(text, start)
|
|
182
|
+
except ValueError:
|
|
183
|
+
continue
|
|
184
|
+
if isinstance(obj, dict):
|
|
185
|
+
return obj
|
|
186
|
+
raise ValueError(f"no JSON object found in judge reply: {text[:200]!r}")
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class JudgeComparator(_HttpComparator):
    """Asks a judge model whether the target response can replace the baseline."""

    name = "judge"

    def __init__(
        self,
        http: Optional[httpx.AsyncClient] = None,
        *,
        judge_model: str = "claude-opus-4-8",
    ) -> None:
        """Store the id of the model used as judge."""
        super().__init__(http)
        self._judge_model = judge_model

    @staticmethod
    def _as_bool(value: object) -> bool:
        """Interpret a verdict flag, accepting string spellings of true/false.

        Plain ``bool()`` treats every non-empty string as true, so a reply of
        ``"equivalent": "false"`` would count as a pass.
        """
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    async def compare(self, prompt, baseline, target) -> ComparisonResult:
        """Send prompt and both responses to the judge model and parse its verdict.

        Returns:
            A result whose ``passed`` is the judge's ``equivalent`` flag and
            whose ``score`` is the judge's numeric ``score`` clamped to
            ``[0.0, 1.0]``. When the score is missing, non-numeric or NaN it
            falls back to 1.0 or 0.0 according to ``equivalent``.

        Raises:
            ValueError: If the reply contains no JSON object.

        A non-success HTTP status raises via ``raise_for_status()``.
        """
        adapter = adapter_for_model(self._judge_model)
        conversation = Conversation(
            system=_JUDGE_SYSTEM,
            messages=[
                {
                    "role": "user",
                    "content": _JUDGE_TEMPLATE.format(
                        prompt=_clip(prompt),
                        baseline=_clip(baseline.text),
                        target=_clip(target.text),
                    ),
                }
            ],
            # No sampling params: the newest judge models reject them.
            max_tokens=1024,
        )
        url, headers, body = adapter.build_request(conversation, self._judge_model)
        response = await self._post(url, headers, body)
        response.raise_for_status()
        decoded = adapter.decode_response(await response.aread(), is_sse=False)
        verdict = _first_json_object(decoded.text)

        equivalent = self._as_bool(verdict.get("equivalent"))
        raw_score = verdict.get("score")
        score = float(raw_score) if isinstance(raw_score, (int, float)) else None
        # json accepts the literal NaN, and min()/max() would clamp it to 1.0,
        # so treat NaN the same as a missing score.
        if score is None or math.isnan(score):
            score = 1.0 if equivalent else 0.0
        score = max(0.0, min(1.0, score))
        return ComparisonResult(
            comparator=self.name,
            score=score,
            passed=equivalent,
            detail=str(verdict.get("reason", "")) or f"judge {self._judge_model} verdict",
        )
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def build_comparators(
    spec: str,
    *,
    judge_model: str = "claude-opus-4-8",
    embedding_model: str = "text-embedding-3-small",
    fuzzy_threshold: float = 0.8,
    embedding_threshold: float = 0.8,
    http: Optional[httpx.AsyncClient] = None,
) -> List[Comparator]:
    """Parse a ``--compare`` spec like ``"exact,judge"`` or ``"all"``.

    Names are case-insensitive. Duplicates are dropped and first-seen order
    is kept.

    Raises:
        ValueError: On an unknown comparator name, or when the spec selects
            no comparator at all.
    """
    if spec.strip().lower() == "all":
        requested = list(ALL_COMPARATOR_NAMES)
    else:
        requested = [part.strip().lower() for part in spec.split(",") if part.strip()]

    # A dict serves as an insertion-ordered set of the validated names.
    selected: Dict[str, None] = {}
    for name in requested:
        if name not in ALL_COMPARATOR_NAMES:
            raise ValueError(
                f"unknown comparator {name!r}; expected any of "
                f"{', '.join(ALL_COMPARATOR_NAMES)} or 'all'"
            )
        selected.setdefault(name, None)
    if not selected:
        raise ValueError("no comparators selected")

    def _make(kind: str) -> Comparator:
        # Instantiate only the comparator that was actually requested.
        if kind == "exact":
            return ExactMatchComparator()
        if kind == "fuzzy":
            return FuzzyComparator(threshold=fuzzy_threshold)
        if kind == "embedding":
            return EmbeddingComparator(
                http, model=embedding_model, threshold=embedding_threshold
            )
        return JudgeComparator(http, judge_model=judge_model)

    return [_make(kind) for kind in selected]
|