rag-benchmarking 1.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +5 -0
- app/api/evaluate.py +170 -0
- app/api/query.py +72 -0
- app/api/security.py +23 -0
- app/config/__init__.py +1 -0
- app/config/settings.py +74 -0
- app/engine/__init__.py +0 -0
- app/engine/rag_engine.py +189 -0
- app/eval/agentic_llm_metrics.py +151 -0
- app/eval/agentic_metrics.py +45 -0
- app/eval/faithfulness.py +79 -0
- app/eval/ragas_runner.py +204 -0
- app/eval/reporting.py +30 -0
- app/eval/result_store.py +269 -0
- app/eval/retrieval_metrics.py +42 -0
- app/exceptions.py +22 -0
- app/llm/client.py +175 -0
- app/logging/__init__.py +1 -0
- app/logging/json_logger.py +53 -0
- app/main.py +89 -0
- app/quality/self_check.py +13 -0
- app/retrieval/chunking.py +49 -0
- app/retrieval/embeddings.py +20 -0
- app/retrieval/ingest_cli.py +75 -0
- app/retrieval/qdrant_store.py +155 -0
- app/retrieval/reranker.py +25 -0
- app/retrieval/service.py +55 -0
- app/sdk/__init__.py +0 -0
- app/sdk/client.py +192 -0
- app/utils/timing.py +17 -0
- harness/__init__.py +25 -0
- harness/protocol.py +68 -0
- harness/result_store.py +101 -0
- harness/runner.py +154 -0
- harness/schemas.py +132 -0
- rag_benchmarking-1.0.0rc1.dist-info/METADATA +386 -0
- rag_benchmarking-1.0.0rc1.dist-info/RECORD +40 -0
- rag_benchmarking-1.0.0rc1.dist-info/WHEEL +5 -0
- rag_benchmarking-1.0.0rc1.dist-info/licenses/LICENSE +192 -0
- rag_benchmarking-1.0.0rc1.dist-info/top_level.txt +2 -0
app/__init__.py
ADDED
app/api/evaluate.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import functools
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from fastapi import APIRouter, Depends, HTTPException
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
from app.api.security import get_api_key
|
|
12
|
+
from app.eval.reporting import write_report_files
|
|
13
|
+
from app.eval.result_store import ResultStore
|
|
14
|
+
from harness.schemas import AgentTrace
|
|
15
|
+
|
|
16
|
+
# Router for the evaluation and run-history endpoints; all paths are under /v1.
router = APIRouter(prefix="/v1", tags=["evaluate"])

# Module-level store shared by the /v1/runs and /v1/runs/compare endpoints.
_result_store = ResultStore()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class EvalSample(BaseModel):
    """One RAG sample submitted to ``POST /v1/evaluate``.

    ``post_evaluate`` converts it into ``harness.schemas.EvalSample`` before
    scoring.
    """

    question: str
    # Context passages for the sample; forwarded to the harness unchanged.
    contexts: list[str] = Field(default_factory=list)
    answer: str
    # Reference answers. Only the first entry is forwarded (as the harness
    # sample's ``ground_truth``); an empty list forwards None.
    ground_truths: list[str] = Field(default_factory=list)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class EvalRequest(BaseModel):
    """Request body for ``POST /v1/evaluate``."""

    samples: list[EvalSample]
    # Metric names to compute. None or an empty list makes post_evaluate fall
    # back to ["faithfulness", "answer_relevancy"].
    metrics: list[str] | None = None
    # Optional server-side paths handed to write_report_files as out_json /
    # out_md. NOTE(review): these are client-controlled filesystem paths;
    # consider restricting them to a fixed reports directory.
    out_json: str | None = None
    out_md: str | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@router.post("/evaluate")
async def post_evaluate(req: EvalRequest, _: str | None = Depends(get_api_key)) -> dict[str, Any]:
    """Evaluate RAG samples against the requested metrics.

    Samples are converted to ``harness.schemas.EvalSample`` objects (only the
    first entry of ``ground_truths`` is forwarded) and scored by
    ``harness.runner.EvaluationRunner``. When ``metrics`` is omitted or empty,
    ``faithfulness`` and ``answer_relevancy`` are requested.

    The runner handles:
    - Deterministic metrics (source_attribution_accuracy) — no LLM needed
    - RAGAS metrics (faithfulness, answer_relevancy, context_precision, context_recall) — needs Gemini/OpenAI
    - Retrieval metrics (precision_at_k, recall_at_k, mrr, ndcg_at_k) — no LLM needed

    Both the evaluation and the report-file writing are offloaded to the
    default thread-pool executor so neither blocks the FastAPI event loop.

    NOTE(review): ``out_json`` / ``out_md`` are client-supplied paths written on
    the server with the process's permissions; consider confining them to a
    fixed reports directory.

    Returns:
        dict with ``metrics``, ``skipped_metrics``, ``skip_reasons``,
        ``run_id``, ``n_samples`` and ``written`` (the value returned by
        ``write_report_files``).

    Raises:
        HTTPException: status 500 wrapping any error raised while converting,
            evaluating or writing reports.
    """
    from harness.runner import EvaluationRunner
    from harness.schemas import EvalSample as HarnessEvalSample, RunConfig

    loop = asyncio.get_running_loop()
    try:
        # Convert API samples to harness EvalSample objects.
        harness_samples = [
            HarnessEvalSample(
                question=s.question,
                contexts=s.contexts,
                answer=s.answer,
                ground_truth=s.ground_truths[0] if s.ground_truths else None,
            )
            for s in req.samples
        ]

        config = RunConfig(metrics=req.metrics or ["faithfulness", "answer_relevancy"])
        runner = EvaluationRunner(config)

        result = await loop.run_in_executor(None, runner.evaluate, harness_samples)

        # Convert the runner's report to a plain dict for the response.
        output: dict[str, Any] = {
            "metrics": result.metrics,
            "skipped_metrics": result.skipped_metrics,
            "skip_reasons": result.skip_reasons,
            "run_id": result.run_id,
            "n_samples": result.n_samples,
        }

        report_payload = {
            "metrics": result.metrics,
            "per_sample": {},
            "skipped_metrics": result.skipped_metrics,
        }
        # File writes are blocking I/O; keep them off the event loop as well.
        paths = await loop.run_in_executor(
            None,
            functools.partial(
                write_report_files,
                report_payload,
                out_json=Path(req.out_json) if req.out_json else None,
                out_md=Path(req.out_md) if req.out_md else None,
            ),
        )
        output["written"] = paths
        return output

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class AgentEvalRequest(BaseModel):
    """Request body for ``POST /v1/evaluate/agent``."""

    # Agent trace to score (schema defined in harness.schemas.AgentTrace).
    trace: AgentTrace
    # Metrics to compute. post_evaluate_agent recognises
    # source_attribution_accuracy, agent_faithfulness, tool_call_accuracy and
    # retrieval_necessity; any other name is skipped without an error.
    metrics: list[str] = Field(
        default=["source_attribution_accuracy", "agent_faithfulness", "tool_call_accuracy"]
    )
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@router.post("/evaluate/agent")
async def post_evaluate_agent(
    request: AgentEvalRequest,
    _: str | None = Depends(get_api_key),
) -> dict[str, Any]:
    """Evaluate an agentic RAG trace using agentic-specific metrics.

    Recognised metric names:
    - ``source_attribution_accuracy``: deterministic. It compares the
      ``[source: <id>]`` citations found in ``trace.final_answer`` with the
      ``source_id`` of each retrieved chunk.
    - ``agent_faithfulness``, ``tool_call_accuracy``, ``retrieval_necessity``:
      LLM-judged. They run on the default thread-pool executor because the
      judge call is synchronous.

    Unrecognised metric names are skipped silently: they get no score and
    raise no error.

    Returns:
        ``{"scores": {metric: float}, "details": {metric: dict}, "trace_id": ...}``
    """
    import re

    from app.eval.agentic_metrics import source_attribution_accuracy
    from app.eval.agentic_llm_metrics import (
        compute_agent_faithfulness,
        compute_retrieval_necessity,
        compute_tool_call_accuracy,
    )

    trace = request.trace
    loop = asyncio.get_running_loop()
    scores: dict[str, float] = {}
    details: dict[str, Any] = {}

    for metric in request.metrics:
        if metric == "source_attribution_accuracy":
            # The capture group is greedy up to "]", so "[source: doc-1 ]"
            # yields "doc-1 ". Strip it so it matches the retrieved id "doc-1".
            cited = [
                cid.strip()
                for cid in re.findall(r'\[source:\s*([^\]]+)\]', trace.final_answer)
            ]
            retrieved = [c.source_id for c in trace.retrieved_chunks]
            r = source_attribution_accuracy(cited, retrieved)
        elif metric == "agent_faithfulness":
            r = await loop.run_in_executor(None, compute_agent_faithfulness, trace)
        elif metric == "tool_call_accuracy":
            r = await loop.run_in_executor(None, compute_tool_call_accuracy, trace)
        elif metric == "retrieval_necessity":
            contexts = [c.content for c in trace.retrieved_chunks]
            r = await loop.run_in_executor(
                None,
                functools.partial(
                    compute_retrieval_necessity,
                    trace.question,
                    trace.final_answer,
                    contexts,
                ),
            )
        else:
            # Unknown metric name: ignored.
            continue

        scores[metric] = r["score"]
        details[metric] = r

    return {
        "scores": scores,
        "details": details,
        "trace_id": trace.trace_id,
    }
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@router.get("/runs")
async def list_runs(
    limit: int = 50,
    _: str | None = Depends(get_api_key),
) -> list[dict[str, Any]]:
    """Return recent evaluation runs, as listed by the module's result store.

    ``limit`` is passed straight through to ``ResultStore.list_runs``. That
    call is synchronous, so it runs on the default thread-pool executor
    rather than on the event loop.
    """
    fetch_runs = functools.partial(_result_store.list_runs, limit)
    return await asyncio.get_running_loop().run_in_executor(None, fetch_runs)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@router.post("/runs/compare")
async def compare_runs(
    run_ids: list[str],
    _: str | None = Depends(get_api_key),
) -> dict[str, Any]:
    """Compare the stored metrics of the given runs.

    Delegates to the synchronous ``ResultStore.compare_runs`` through the
    default thread-pool executor so the event loop is not blocked.
    """
    compare = functools.partial(_result_store.compare_runs, run_ids)
    return await asyncio.get_running_loop().run_in_executor(None, compare)
|
app/api/query.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import functools
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from fastapi import APIRouter, Depends, HTTPException
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
from app.engine.rag_engine import RAGEngine, RetrievedChunk
|
|
11
|
+
|
|
12
|
+
# Router for the RAG query endpoint; paths are under /v1.
router = APIRouter(prefix="/v1", tags=["query"])
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class QueryRequest(BaseModel):
    """Request body for ``POST /v1/query``."""

    # The question to answer; must be non-empty.
    query: str = Field(..., min_length=1)
    # Number of chunks passed to the LLM and returned as citations (1-20).
    top_k: int = Field(5, ge=1, le=20)
    # Reorder retrieved chunks with the cross-encoder reranker before generation.
    rerank: bool = Field(default=False)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class QueryResponse(BaseModel):
    """Response body for ``POST /v1/query``."""

    answer: str
    # Chunks that were given to the LLM for the returned answer.
    citations: list[RetrievedChunk]
    # Per-stage timings in milliseconds, keyed by stage name (e.g. retrieve, generate).
    timings_ms: dict[str, float] | None = None
    # Token usage of the most recent LLM call (prompt/completion/total). All
    # zeros when unavailable. NOTE(review): after a self-check retry this
    # covers only the retry's generation, not the first one.
    tokens: dict[str, int] | None = None
    # Groundedness self-check score; None when the self-check did not run.
    groundedness: float | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_rag_engine() -> RAGEngine:
    """FastAPI dependency that builds a new ``RAGEngine`` for every request.

    Because each request gets its own engine (and thus its own LLM client),
    ``engine.llm._last_token_usage`` read by ``post_query`` is not shared
    between concurrent requests. The cost is re-creating the client on every
    call.
    """
    return RAGEngine()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@router.post("/query", response_model=QueryResponse)
async def post_query(
    req: QueryRequest, engine: RAGEngine = Depends(get_rag_engine)
) -> QueryResponse:
    """Answer ``req.query`` with the RAG pipeline and report usage metadata.

    ``RAGEngine.query()`` is synchronous: it calls sentence-transformers
    (CPU-bound) and makes blocking HTTP calls to OpenAI / Gemini. It is
    therefore dispatched to the default thread-pool executor so the FastAPI
    event loop is never blocked. The general pattern is:

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, functools.partial(sync_fn, *args))

    Any error raised by the pipeline is reported as HTTP 500 with its message.

    NOTE(review): unlike the /v1/evaluate and /v1/runs routes, this route does
    not declare ``Depends(get_api_key)``. Confirm it is protected elsewhere
    (e.g. when the router is included) or is intentionally public.
    """
    try:
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, functools.partial(engine.query, req.query, req.top_k, req.rerank)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # LLMClient.generate() records the usage of its latest call in
    # _last_token_usage. Test doubles may expose something other than a dict
    # (e.g. a Mock), so any non-dict value is replaced with zero counts.
    usage = getattr(getattr(engine, "llm", None), "_last_token_usage", None)
    if not isinstance(usage, dict):
        usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

    return QueryResponse(
        answer=result.answer,
        citations=result.citations,
        timings_ms=result.timings,
        tokens=usage,
        groundedness=result.groundedness,
    )
|
app/api/security.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from fastapi import HTTPException, Security, status
|
|
2
|
+
from fastapi.security import APIKeyHeader
|
|
3
|
+
|
|
4
|
+
from app.config.settings import get_settings
|
|
5
|
+
|
|
6
|
+
# Reads the X-API-Key request header. auto_error=False lets a missing header
# reach get_api_key as None, so open mode (no server key) can still allow it.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_api_key(api_key_header: str = Security(api_key_header)) -> str | None:
    """Validate the ``X-API-Key`` header against the configured server key.

    Returns:
        None when no ``API_KEY`` is configured on the server (open mode).
        WARNING: open mode is for development/POC only. Otherwise returns the
        presented key when it matches.

    Raises:
        HTTPException: 403 when a key is configured and the header is missing
            or does not match.
    """
    import secrets

    settings = get_settings()

    # If no API key is configured on the server, allow access (open mode).
    if not settings.api_key:
        return None

    # Constant-time comparison avoids leaking, through response timing, how
    # many leading characters were correct. Comparing UTF-8 bytes (rather than
    # str) keeps compare_digest from raising TypeError on non-ASCII input.
    if api_key_header is not None and secrets.compare_digest(
        api_key_header.encode("utf-8"), settings.api_key.encode("utf-8")
    ):
        return api_key_header

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Could not validate credentials",
    )
|
app/config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
app/config/settings.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
|
|
5
|
+
from pydantic import Field
|
|
6
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AppSettings(BaseSettings):
    """Central application settings loaded from environment and optional .env file.

    Each field is populated from the environment variable named by its
    ``alias`` (e.g. ``APP_ENV``). Variables that match no field are ignored
    (``extra="ignore"``).

    Attributes
    ----------
    app_env : str
        Application environment (e.g., dev, prod). Default ``"dev"``.
    log_level : str
        Logging level string (e.g., INFO, DEBUG). Default ``"INFO"``.
    openai_api_key : str | None
        API key for OpenAI (optional).
    api_key : str | None
        Shared secret clients must send in the ``X-API-Key`` header. When
        unset, ``app.api.security.get_api_key`` lets every request through
        (open mode, development only).
    gemini_api_key : str | None
        API key for Google Gemini (optional).
    llm_provider : str | None
        LLM provider selection. Presumably ``"openai"`` or ``"gemini"`` and read
        by the LLM client; that client is not visible here, so confirm.
    openai_model, gemini_model : str | None
        Optional model-name overrides for each provider.
    llm_temperature : float
        Generation temperature, presumably passed to the LLM client. Default 0.2.
    llm_max_tokens : int
        Generation token limit, presumably passed to the LLM client. Default 512.
    qdrant_url : str | None
        Qdrant Cloud URL.
    qdrant_api_key : str | None
        Qdrant Cloud API key.
    qdrant_collection : str
        Default Qdrant collection name. Default ``"agentic_rag_poc"``.
    self_check_min_groundedness : float
        Groundedness score below which ``RAGEngine.query`` considers a retry.
        Default 0.7.
    self_check_retry : bool
        Whether ``RAGEngine.query`` retries when groundedness is below the
        threshold. Default True.
    system_prompt : str
        System prompt sent with every RAG generation.
    user_prompt_template : str
        ``str.format`` template for the user prompt. ``RAGEngine._call_llm``
        fills it with the ``context_blocks`` and ``query`` keyword arguments;
        any other ``{...}`` placeholder raises ``KeyError``.
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

    app_env: str = Field(default="dev", alias="APP_ENV")
    log_level: str = Field(default="INFO", alias="LOG_LEVEL")

    openai_api_key: str | None = Field(default=None, alias="OPENAI_API_KEY")
    api_key: str | None = Field(default=None, alias="API_KEY")
    gemini_api_key: str | None = Field(default=None, alias="GEMINI_API_KEY")
    llm_provider: str | None = Field(default=None, alias="LLM_PROVIDER")
    openai_model: str | None = Field(default=None, alias="OPENAI_MODEL")
    gemini_model: str | None = Field(default=None, alias="GEMINI_MODEL")
    llm_temperature: float = Field(default=0.2, alias="LLM_TEMPERATURE")
    llm_max_tokens: int = Field(default=512, alias="LLM_MAX_TOKENS")

    qdrant_url: str | None = Field(default=None, alias="QDRANT_URL")
    qdrant_api_key: str | None = Field(default=None, alias="QDRANT_API_KEY")
    qdrant_collection: str = Field(default="agentic_rag_poc", alias="QDRANT_COLLECTION")
    # Self-check configuration
    self_check_min_groundedness: float = Field(default=0.7, alias="SELF_CHECK_MIN_GROUNDEDNESS")
    self_check_retry: bool = Field(default=True, alias="SELF_CHECK_RETRY")

    # Prompts
    system_prompt: str = Field(
        default=(
            "You are a helpful assistant. Answer based only on the provided context. Cite sources."
        ),
        alias="SYSTEM_PROMPT",
    )
    user_prompt_template: str = Field(
        default="Context:\n{context_blocks}\n\nQuestion: {query}\nAnswer:",
        alias="USER_PROMPT_TEMPLATE",
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@lru_cache(maxsize=1)
def get_settings() -> AppSettings:
    """Return a cached instance of application settings.

    The first call reads the environment (and ``.env``). Later calls return
    the same object, so environment changes are not seen until
    ``get_settings.cache_clear()`` is called.

    Returns
    -------
    AppSettings
        Loaded settings object.
    """

    return AppSettings()  # type: ignore[call-arg]
|
app/engine/__init__.py
ADDED
|
File without changes
|
app/engine/rag_engine.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
import app.llm.client as llm_client
|
|
9
|
+
import app.retrieval.service as retrieval_service
|
|
10
|
+
from app.config.settings import get_settings
|
|
11
|
+
from app.utils.timing import timer
|
|
12
|
+
|
|
13
|
+
# Module logger for RAG pipeline progress and self-check retry decisions.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RetrievedChunk(BaseModel):
    """A chunk used to generate an answer, returned to clients as a citation."""

    text: str
    # Source identifier; rendered as "[source: <id>]" in the LLM prompt.
    source_id: str
    # Presumably the chunk's position within its source document (0 when the
    # retriever omits it). Confirm against the retrieval service.
    chunk_index: int
    # Score carried on the chunk dict after retrieval/rerank (0.0 when missing).
    score: float
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RAGResult(BaseModel):
    """Output of ``RAGEngine.query``."""

    answer: str
    # Chunks given to the LLM for the final answer.
    citations: list[RetrievedChunk]
    # Milliseconds per pipeline stage, keyed by stage name
    # (retrieve, rerank, generate, self_check, and *_retry variants).
    timings: dict[str, float]
    # Groundedness self-check score; None when the self-check was skipped or failed.
    groundedness: float | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RAGEngine:
    """Retrieval-augmented generation pipeline with a groundedness self-check.

    ``query`` does the following:
    1. Retrieves chunks.
    2. Optionally reranks them with a cross-encoder.
    3. Generates an answer with the LLM and scores its groundedness.
    4. When the score is below ``settings.self_check_min_groundedness`` and
       ``settings.self_check_retry`` is enabled, runs one retry over a wider
       retrieval. The retry is kept only if it scores at least as well.

    Reranking and the self-check are best-effort. Any failure, including a
    missing optional module, is logged and the pipeline continues without
    that step.
    """

    def __init__(self) -> None:
        self.settings = get_settings()
        self.llm = llm_client.LLMClient()

    def query(self, query: str, top_k: int, rerank: bool) -> RAGResult:
        """
        Execute the full RAG pipeline including retrieval, reranking, generation,
        and optional retry.

        Args:
            query: The user question.
            top_k: Number of chunks passed to the LLM and returned as citations.
            rerank: Reorder retrieved chunks with ``CrossEncoderReranker`` first.

        Returns:
            RAGResult with the answer, the chunks used as citations, per-stage
            timings in milliseconds, and the groundedness score (None when the
            self-check did not produce one). An empty retrieval returns an
            empty answer without calling the LLM.
        """
        timings: dict[str, float] = {}

        # 1. Retrieve. Fetch at least 10 candidates so reranking has room to reorder.
        logger.info("Starting RAG query", extra={"query": query, "top_k": top_k, "rerank": rerank})
        with timer() as t_retr:
            chunks = retrieval_service.retrieve_top_chunks(query, top_k=max(top_k, 10))
        timings["retrieve"] = t_retr["elapsed_ms"]

        if not chunks:
            logger.warning("No chunks retrieved for query", extra={"query": query})
            return RAGResult(answer="", citations=[], timings=timings)

        logger.info("Retrieved chunks", extra={"count": len(chunks)})

        # 2. Rerank (best-effort).
        if rerank:
            chunks = self._rerank(query, chunks, timings, "rerank")

        current_chunks = chunks[:top_k]

        # 3. Generate.
        with timer() as t_gen:
            answer = self._call_llm(query, current_chunks)
        timings["generate"] = t_gen["elapsed_ms"]

        # 4. Self-check (best-effort).
        groundedness = self._self_check(answer, current_chunks, timings, "self_check")

        # 5. Retry if needed.
        if (
            groundedness is not None
            and groundedness < self.settings.self_check_min_groundedness
            and self.settings.self_check_retry
        ):
            logger.info(
                "Groundedness below threshold, attempting retry",
                extra={
                    "groundedness": groundedness,
                    "threshold": self.settings.self_check_min_groundedness,
                },
            )
            try:
                retry_result = self._retry_workflow(query, top_k, rerank, groundedness)
                if retry_result:
                    logger.info(
                        "Retry successful, adopting new answer",
                        extra={"new_groundedness": retry_result["groundedness"]},
                    )
                    answer = retry_result["answer"]
                    groundedness = retry_result["groundedness"]
                    current_chunks = retry_result["chunks"]
                    timings.update(retry_result["timings"])
                else:
                    logger.info("Retry did not improve groundedness")
            except Exception as e:
                # The first answer is still valid; report the retry failure and keep it.
                logger.error(
                    "Error during retry workflow", extra={"error": str(e)}, exc_info=True
                )

        citations = [
            RetrievedChunk(
                text=c.get("text", ""),
                source_id=c.get("source_id", ""),
                chunk_index=int(c.get("chunk_index", 0)),
                score=float(c.get("score", 0.0)),
            )
            for c in current_chunks
        ]

        return RAGResult(
            answer=answer, citations=citations, timings=timings, groundedness=groundedness
        )

    def _rerank(
        self,
        query: str,
        chunks: list[dict[str, Any]],
        timings: dict[str, float],
        timing_key: str,
    ) -> list[dict[str, Any]]:
        """Best-effort cross-encoder rerank of ``chunks`` (all of them are kept).

        On success the elapsed time is stored in ``timings[timing_key]`` and the
        reordered list is returned. On any failure (import error, model load,
        inference) a warning is logged and ``chunks`` is returned unchanged.
        """
        try:
            from app.retrieval.reranker import CrossEncoderReranker

            with timer() as t_rr:
                reranker = CrossEncoderReranker()
                reranked = reranker.rerank(query, chunks, top_k=len(chunks))
            timings[timing_key] = t_rr["elapsed_ms"]
        except Exception:
            logger.warning("Reranking failed; keeping retrieval order", exc_info=True)
            return chunks
        return reranked

    def _self_check(
        self,
        answer: str,
        chunks: list[dict[str, Any]],
        timings: dict[str, float],
        timing_key: str,
    ) -> float | None:
        """Best-effort groundedness score of ``answer`` against the chunk texts.

        On success the elapsed time is stored in ``timings[timing_key]``. Returns
        None (after logging a warning) when the self-check cannot run.
        """
        try:
            from app.quality.self_check import compute_groundedness

            with timer() as t_sc:
                score = compute_groundedness(answer, [c.get("text", "") for c in chunks])
            timings[timing_key] = t_sc["elapsed_ms"]
        except Exception:
            logger.warning("Groundedness self-check failed", exc_info=True)
            return None
        return score

    def _call_llm(self, query: str, chunks: list[dict[str, Any]]) -> str:
        """Build the prompt from ``chunks`` and return the LLM's answer.

        Each chunk becomes a ``[source: <source_id>]`` header followed by its
        text. The blocks are joined by blank lines and substituted into
        ``settings.user_prompt_template`` together with the query.
        """
        context_blocks = "\n\n".join(
            f"[source: {c.get('source_id', '')}]\n{c.get('text', '')}" for c in chunks
        )
        user_prompt = self.settings.user_prompt_template.format(
            context_blocks=context_blocks, query=query
        )
        return self.llm.generate(self.settings.system_prompt, user_prompt)

    def _retry_workflow(
        self, query: str, top_k: int, rerank: bool, current_score: float
    ) -> dict[str, Any] | None:
        """Regenerate over a wider retrieval after a low groundedness score.

        Retrieves at least 20 candidates, and never fewer than ``top_k``, so a
        caller asking for more than 20 chunks still gets a full pool. It then
        optionally reranks, keeps the first ``top_k``, regenerates and
        re-scores.

        Returns:
            A dict with ``answer``, ``groundedness``, ``chunks`` and ``timings``
            (``*_retry`` keys) when the new score is >= ``current_score``.
            None when the score is lower or the self-check failed.
        """
        timings: dict[str, float] = {}

        # Expand retrieval.
        with timer() as t_retr:
            more_chunks = retrieval_service.retrieve_top_chunks(query, top_k=max(top_k, 20))
        timings["retrieve_retry"] = t_retr["elapsed_ms"]

        if rerank:
            more_chunks = self._rerank(query, more_chunks, timings, "rerank_retry")

        more_chunks = more_chunks[:top_k]

        # Generate.
        with timer() as t_gen:
            answer = self._call_llm(query, more_chunks)
        timings["generate_retry"] = t_gen["elapsed_ms"]

        # Check.
        groundedness = self._self_check(answer, more_chunks, timings, "self_check_retry")

        if groundedness is not None and groundedness >= current_score:
            return {
                "answer": answer,
                "groundedness": groundedness,
                "chunks": more_chunks,
                "timings": timings,
            }
        return None
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from app.llm.client import LLMClient
|
|
8
|
+
from harness.schemas import AgentTrace
|
|
9
|
+
|
|
10
|
+
# Module logger; _call_judge reports unparseable judge replies here.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _parse_judge_json(raw: str) -> dict[str, Any]:
    """Extract the JSON object from an LLM judge reply.

    Accepts any of these forms:
    - a bare JSON object;
    - an object wrapped in a Markdown code fence, with or without a
      ``json``/``JSON`` language tag;
    - an object surrounded by prose. In that case the outermost ``{...}`` span
      is tried as a fallback.

    Raises:
        ValueError: if no JSON can be decoded (``json.JSONDecodeError`` is a
            ``ValueError`` subclass) or the decoded value is not an object.
    """
    text = raw.strip()
    if text.startswith("```"):
        # "```json\n{...}\n```".split("```") -> ["", "json\n{...}\n", ""]
        text = text.split("```")[1]
        if text[:4].lower() == "json":
            text = text[4:]
        text = text.strip()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            raise
        data = json.loads(text[start : end + 1])
    # Callers use data.get(...); a list/number/string reply would crash them.
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    return data


def _call_judge(system: str, user: str) -> dict[str, Any]:
    """Call the LLM judge and return its reply parsed as a JSON object.

    Returns an empty dict when the reply cannot be parsed into a JSON object,
    so callers fall back to their defaults. Errors raised by the LLM call
    itself are not caught and propagate to the caller.
    """
    llm = LLMClient()
    raw = llm.generate(system, user)
    try:
        return _parse_judge_json(raw)
    except ValueError as exc:
        logger.warning("Judge parse failure: %s | raw: %.200s", exc, raw)
        return {}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# System prompt for compute_agent_faithfulness. The judge checks each reasoning
# step's claims against tool outputs / retrieved sources, and its
# "trace_faithfulness_score" becomes the metric score. The JSON below is a shape
# description for the model, not literal JSON.
_AGENT_FAITHFULNESS_SYSTEM = """You are evaluating factual consistency across an AI agent's full reasoning trace.

For each reasoning step, verify every factual claim is supported by the tool output at that step or a previously retrieved source.
A claim is NOT supported if it appears in the reasoning but not in any tool output, or contradicts the tool output.

Return JSON only:
{
"step_analysis": [
{"step_index": int, "claims": ["..."], "supported": [true|false], "faithfulness_score": float}
],
"trace_faithfulness_score": float,
"worst_step": int,
"critical_hallucinations": ["..."]
}"""

# System prompt for compute_tool_call_accuracy; "overall_score" becomes the metric score.
_TOOL_ACCURACY_SYSTEM = """You are evaluating whether an AI agent made appropriate tool calls to answer a question.
For each tool call assess: was it necessary, was the correct tool chosen, was the input well-formed?
Score each: 0 (wrong), 0.5 (partial), 1 (correct).

Return JSON only:
{
"tool_evaluations": [
{"step_index": int, "tool_name": str, "necessary": bool, "correct_tool": bool,
"input_quality": float, "score": float, "reason": str}
],
"overall_score": float
}"""

# System prompt for compute_retrieval_necessity; "score" becomes the metric score.
_NECESSITY_SYSTEM = """Evaluate whether retrieval was necessary to answer this question.
Categories: NECESSARY / HELPFUL / UNNECESSARY

Return JSON only:
{
"necessity": "NECESSARY"|"HELPFUL"|"UNNECESSARY",
"parametric_answer_possible": bool,
"retrieval_contribution": "none"|"marginal"|"significant"|"essential",
"score": float,
"reasoning": str
}"""
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _coerce_score(value: Any, default: float = 0.0) -> float:
    """Convert a judge-reported score to a float clamped to [0.0, 1.0].

    A value that is missing, non-numeric (``None``, ``"n/a"``), too large to
    convert, or non-finite (NaN, inf) yields ``default`` instead of raising.
    A malformed judge reply therefore scores the same as a missing field.
    NaN must be filtered explicitly, because ``max(0.0, min(1.0, nan))``
    evaluates to 1.0, a perfect score.
    """
    import math

    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    if not math.isfinite(score):
        return default
    return max(0.0, min(1.0, score))


def compute_agent_faithfulness(trace: AgentTrace) -> dict:
    """
    Evaluates faithfulness across the entire agent reasoning trace.

    Every reasoning step (thought + observation) and every retrieved chunk is
    sent to the LLM judge.

    Returns:
        dict with these keys:
        - ``score``: 0-1. It is 1.0 (neutral) when the trace has no reasoning
          steps, and 0.0 when the judge reply is unusable.
        - ``worst_step``: -1 when unknown.
        - ``critical_hallucinations`` and ``step_analysis``: as reported by
          the judge.
    """
    if not trace.reasoning_steps:
        return {
            "score": 1.0,
            "worst_step": -1,
            "critical_hallucinations": [],
            "step_analysis": [],
        }

    steps_text = "\n\n".join(
        f"Step {s.step_index}: {s.thought}\nObservation: {s.observation}"
        for s in trace.reasoning_steps
    )
    sources_text = "\n\n".join(
        f"[{c.source_id}]: {c.content}" for c in trace.retrieved_chunks
    ) or "No explicit chunks provided."

    user = f"FULL AGENT TRACE:\n{steps_text}\n\nSOURCES:\n{sources_text}"
    data = _call_judge(_AGENT_FAITHFULNESS_SYSTEM, user)

    return {
        "score": _coerce_score(data.get("trace_faithfulness_score")),
        "worst_step": data.get("worst_step", -1),
        "critical_hallucinations": data.get("critical_hallucinations", []),
        "step_analysis": data.get("step_analysis", []),
    }


def compute_tool_call_accuracy(trace: AgentTrace) -> dict:
    """
    Evaluates whether the agent made appropriate tool calls.

    Each call is summarised as ``tool_name(input) → output`` (output truncated
    to 200 characters) and judged against the question.

    Returns:
        dict with ``score`` (0-1; 1.0 with empty evaluations when there are no
        tool calls; 0.0 when the judge reply is unusable) and
        ``tool_evaluations`` as reported by the judge.
    """
    if not trace.tool_calls:
        return {"score": 1.0, "tool_evaluations": []}

    calls_text = "\n".join(
        f"Step {tc.step_index}: {tc.tool_name}({tc.tool_input}) → {tc.tool_output[:200]}"
        for tc in trace.tool_calls
    )
    user = (
        f"QUESTION: {trace.question}\n\n"
        f"TOOL CALLS:\n{calls_text}\n\n"
        "AVAILABLE TOOLS: retrieve, web_search, code_exec, calculator"
    )
    data = _call_judge(_TOOL_ACCURACY_SYSTEM, user)
    return {
        "score": _coerce_score(data.get("overall_score")),
        "tool_evaluations": data.get("tool_evaluations", []),
    }


def compute_retrieval_necessity(
    question: str,
    answer: str,
    contexts: list[str],
) -> dict:
    """
    Evaluates whether retrieval was actually necessary for this query.

    A high score means retrieval was essential; a low score means it was
    unnecessary.

    Returns:
        dict with ``score`` (0-1; 0.0 when the judge reply is unusable),
        ``necessity`` and ``retrieval_contribution`` ("UNKNOWN" when missing),
        and ``reasoning``.
    """
    context_text = "\n\n".join(contexts)
    user = (
        f"QUESTION: {question}\n\n"
        f"RETRIEVED CONTEXT:\n{context_text}\n\n"
        f"FINAL ANSWER: {answer}"
    )
    data = _call_judge(_NECESSITY_SYSTEM, user)
    return {
        "score": _coerce_score(data.get("score")),
        "necessity": data.get("necessity", "UNKNOWN"),
        "retrieval_contribution": data.get("retrieval_contribution", "UNKNOWN"),
        "reasoning": data.get("reasoning", ""),
    }
|