contexttrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,448 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import time
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, Iterable, Optional
9
+
10
+ import httpx
11
+
12
+ from contexttrace.client import ContextTrace
13
+ from contexttrace.reliability import ReliabilityScorer
14
+
15
# Citation-check verdicts that _score_evaluation counts as unsupported claims.
UNSUPPORTED_VERDICTS = {
    "not_enough_info",
    "contradicted",
    "unsupported",
}

# Failure type that EvalRunSummary.failure_rate treats as "no failure".
NO_FAILURE = "no_failure_detected"
17
+
18
+
19
@dataclass
class EvalQuestion:
    """One dataset entry to send to the RAG endpoint under evaluation.

    ``payload`` is an explicit request body; when it is empty,
    ``call_rag_endpoint`` builds a default body from ``question`` and
    ``metadata`` instead.
    """

    # Question text; also used as the query of the logged trace.
    question: str
    # Identifier recorded in the trace metadata as ``question_id``.
    id: Optional[str] = None
    # Reference answer, recorded in the trace metadata.
    expected_answer: Optional[str] = None
    # Extra key/value pairs merged into the trace metadata.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Explicit endpoint request body; empty means "build the default one".
    payload: dict[str, Any] = field(default_factory=dict)
26
+
27
+
28
@dataclass
class EvalThresholds:
    """Pass/fail limits checked against an EvalRunSummary's aggregate metrics."""

    # Lowest acceptable average citation support.
    min_citation_support: float = 0.8
    # Highest acceptable average unsupported-claim rate.
    max_unsupported_claim_rate: float = 0.1
    # Highest acceptable share of questions that errored or reported a failure.
    max_failure_rate: float = 0.0
33
+
34
+
35
@dataclass
class EvalResult:
    """Outcome of evaluating a single dataset question."""

    # The dataset entry that was evaluated.
    question: EvalQuestion
    # ID of the ContextTrace trace logged for this question, if any.
    trace_id: Optional[str]
    # Mean support score over the trace's citation checks.
    citation_support: float
    # Share of citation checks whose verdict counts as unsupported.
    unsupported_claim_rate: float
    # Failure type from the trace evaluation, or "endpoint_error".
    failure_type: str
    # Token usage reported by the endpoint (possibly empty).
    token_usage: dict[str, Any]
    # Duration of the endpoint call in milliseconds.
    latency_ms: float
    # Error text when the endpoint call failed, otherwise None.
    error: Optional[str] = None
45
+
46
+
47
@dataclass
class EvalRunSummary:
    """Aggregated outcome of an evaluation run plus its rendered Markdown report.

    Every aggregate is recomputed from ``results`` on access. An empty run is
    scored pessimistically (no support, all claims unsupported, everything
    failed), so it never passes the thresholds.
    """

    results: list[EvalResult]
    thresholds: EvalThresholds
    markdown: str

    @property
    def avg_citation_support(self) -> float:
        """Mean citation support over all results, rounded to 3 places (0.0 if none)."""
        count = len(self.results)
        if count == 0:
            return 0.0
        total = sum(item.citation_support for item in self.results)
        return round(total / count, 3)

    @property
    def unsupported_claim_rate(self) -> float:
        """Mean unsupported-claim rate over all results, rounded (1.0 if none)."""
        count = len(self.results)
        if count == 0:
            return 1.0
        total = sum(item.unsupported_claim_rate for item in self.results)
        return round(total / count, 3)

    @property
    def failure_rate(self) -> float:
        """Share of results that errored or report a failure type (1.0 if none)."""
        count = len(self.results)
        if count == 0:
            return 1.0
        failed_count = sum(
            1
            for item in self.results
            if item.error or item.failure_type != NO_FAILURE
        )
        return round(failed_count / count, 3)

    @property
    def reliability(self) -> dict[str, Any]:
        """Reliability score computed by ReliabilityScorer from the aggregates."""
        scorer = ReliabilityScorer()
        report = scorer.score(
            citation_support=self.avg_citation_support,
            unsupported_claim_rate=self.unsupported_claim_rate,
            failure_rate=self.failure_rate,
        )
        return report.to_dict()

    @property
    def failed(self) -> bool:
        """True when any aggregate metric is outside its threshold."""
        limits = self.thresholds
        if self.avg_citation_support < limits.min_citation_support:
            return True
        if self.unsupported_claim_rate > limits.max_unsupported_claim_rate:
            return True
        return self.failure_rate > limits.max_failure_rate
97
+
98
+
99
# Signature of a function that queries the RAG endpoint:
# (endpoint, question, timeout, headers) -> response dict.
EndpointCaller = Callable[
    [str, EvalQuestion, float, Dict[str, str]],
    Dict[str, Any],
]
100
+
101
+
102
def load_dataset(path: str) -> list[EvalQuestion]:
    """Load evaluation questions from a UTF-8 JSON file.

    The file may contain either a bare list of entries or an object whose
    ``questions`` key holds that list. Each entry is converted with
    ``_normalize_question`` using its position in the list as ``index``.

    Raises:
        ValueError: if no list of questions is found (JSON decode errors and
            invalid entries also propagate).
    """
    with open(path, "r", encoding="utf-8") as handle:
        data = json.load(handle)

    raw_questions = data.get("questions") if isinstance(data, dict) else data
    if not isinstance(raw_questions, list):
        raise ValueError("Dataset must be a JSON list or an object with a questions list.")

    return [
        _normalize_question(entry, index=position)
        for position, entry in enumerate(raw_questions)
    ]
118
+
119
+
120
def run_evaluation(
    *,
    dataset_path: str,
    endpoint: str,
    api_key: str,
    project: str,
    base_url: str,
    thresholds: EvalThresholds,
    summary_path: Optional[str] = None,
    timeout: float = 30.0,
    endpoint_headers: Optional[dict[str, str]] = None,
    endpoint_caller: Optional[EndpointCaller] = None,
    contexttrace: Optional[ContextTrace] = None,
) -> EvalRunSummary:
    """Evaluate every question of a dataset against a RAG endpoint.

    Each question is sent to ``endpoint`` via ``endpoint_caller`` (default
    ``call_rag_endpoint``), logged and evaluated as a ContextTrace trace, and
    the per-question results are aggregated into an ``EvalRunSummary``. The
    rendered Markdown report is stored on the summary and written with
    ``write_summary``.

    Args:
        dataset_path: JSON dataset file (see ``load_dataset``).
        endpoint: Endpoint URL handed to the endpoint caller.
        api_key: Used only when building a ContextTrace client here.
        project: Used only when building a ContextTrace client here.
        base_url: Used only when building a ContextTrace client here.
        thresholds: Limits that decide whether the run passed.
        summary_path: Report destination handed to ``write_summary``.
        timeout: Passed to the endpoint caller and to a client built here.
        endpoint_headers: Extra request headers for the endpoint.
        endpoint_caller: Replacement for ``call_rag_endpoint``.
        contexttrace: Existing client to use. It is left open; a client
            created by this function is closed before returning.

    Returns:
        The run summary with its ``markdown`` field populated.
    """
    questions = load_dataset(dataset_path)
    caller = endpoint_caller or call_rag_endpoint

    # Decide ownership once, with the same ``is None`` test used for cleanup.
    # Previously creation used ``or`` while cleanup used ``is None``, so a
    # falsy caller-supplied client led to a new client that was never closed.
    owns_client = contexttrace is None
    if owns_client:
        ct = ContextTrace(
            api_key=api_key,
            project=project,
            base_url=base_url,
            timeout=timeout,
        )
    else:
        ct = contexttrace

    results: list[EvalResult] = []
    try:
        for question in questions:
            results.append(
                _evaluate_question(
                    contexttrace=ct,
                    question=question,
                    endpoint=endpoint,
                    timeout=timeout,
                    endpoint_headers=endpoint_headers or {},
                    endpoint_caller=caller,
                    dataset_path=dataset_path,
                )
            )
    finally:
        if owns_client:
            ct.close()

    summary = EvalRunSummary(
        results=results,
        thresholds=thresholds,
        markdown="",
    )
    summary.markdown = render_markdown_summary(summary)
    write_summary(summary.markdown, summary_path=summary_path)
    return summary
169
+
170
+
171
def call_rag_endpoint(
    endpoint: str,
    question: EvalQuestion,
    timeout: float,
    headers: dict[str, str],
) -> dict[str, Any]:
    """POST one question to the RAG endpoint and return the decoded JSON body.

    The request body is a copy of ``question.payload`` when one is given.
    Otherwise it is built from the question text (sent as both ``query`` and
    ``question``) and its metadata. ``expected_answer`` is added when the
    question has one.

    Args:
        endpoint: URL to POST to.
        question: The dataset entry being evaluated; never modified.
        timeout: httpx client timeout.
        headers: Request headers.

    Returns:
        Whatever ``response.json()`` decodes.

    Raises:
        Exceptions from the HTTP request, from ``raise_for_status`` for an
        unsuccessful status, and from JSON decoding propagate to the caller.
    """
    if question.payload:
        # Copy so the expected answer added below does not leak back into
        # the caller's EvalQuestion (the payload was mutated in place before).
        payload = question.payload.copy()
    else:
        payload = {
            "query": question.question,
            "question": question.question,
            "metadata": question.metadata,
        }
    if question.expected_answer is not None:
        payload["expected_answer"] = question.expected_answer

    with httpx.Client(timeout=timeout) as client:
        response = client.post(endpoint, json=payload, headers=headers)
        response.raise_for_status()
        return response.json()
189
+
190
+
191
def render_markdown_summary(summary: EvalRunSummary) -> str:
    """Render an evaluation run as a Markdown report.

    The report contains the pass/fail status, the reliability score with its
    recommendations, a table of aggregate metrics next to their thresholds,
    and one table row per evaluated question.

    Returns:
        The report text, terminated by a newline.
    """
    if summary.failed:
        status = "failed"
    else:
        status = "passed"
    reliability = summary.reliability
    limits = summary.thresholds

    out = [
        "# ContextTrace RAG Evaluation",
        "",
        f"Status: **{status}**",
        "",
        "## Reliability Score",
        "",
        f"Score: **{reliability['score']} ({reliability['grade']})**",
        "",
        "This is a practical diagnostic score. It summarizes the available reliability metrics, but the raw metrics below remain the source of truth.",
        "",
        "Recommendations:",
    ]
    out.extend("- %s" % tip for tip in reliability["recommendations"])
    out.append("")

    # Aggregate metrics compared against the configured thresholds.
    out.append("| Metric | Value | Threshold |")
    out.append("| --- | ---: | ---: |")
    out.append(
        "| Average citation support | %.3f | >= %.3f |"
        % (summary.avg_citation_support, limits.min_citation_support)
    )
    out.append(
        "| Unsupported claim rate | %.3f | <= %.3f |"
        % (summary.unsupported_claim_rate, limits.max_unsupported_claim_rate)
    )
    out.append(
        "| Failure rate | %.3f | <= %.3f |"
        % (summary.failure_rate, limits.max_failure_rate)
    )
    out.append("")

    # One row per question.
    out.append("| Question | Trace | Citation support | Unsupported claims | Failure | Tokens | Latency |")
    out.append("| --- | --- | ---: | ---: | --- | ---: | ---: |")
    for result in summary.results:
        trace_cell = "`%s`" % result.trace_id if result.trace_id else ""
        row_values = (
            _escape_table(result.question.question),
            trace_cell,
            result.citation_support,
            result.unsupported_claim_rate,
            _escape_table(result.error or result.failure_type),
            result.token_usage.get("total_tokens", ""),
            result.latency_ms,
        )
        out.append("| %s | %s | %.3f | %.3f | %s | %s | %.1f ms |" % row_values)

    return "\n".join(out) + "\n"
245
+
246
+
247
def write_summary(markdown: str, *, summary_path: Optional[str]) -> None:
    """Write the Markdown report to disk and, if configured, the job summary.

    Args:
        markdown: Rendered report text.
        summary_path: Destination file. When None or empty,
            ``contexttrace-eval-summary.md`` in the working directory is used.
            Missing parent directories are created.

    Side effects:
        Overwrites the destination file. When the ``GITHUB_STEP_SUMMARY``
        environment variable names a file (GitHub Actions job summary), the
        report is also appended to it.
    """
    output_path = Path(summary_path or "contexttrace-eval-summary.md")
    # Create missing parent directories so a path such as
    # "reports/eval.md" works on a fresh checkout instead of raising
    # FileNotFoundError.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(markdown, encoding="utf-8")

    github_summary = os.getenv("GITHUB_STEP_SUMMARY")
    if github_summary:
        # Append: other steps of the same job may already have written to it.
        with open(github_summary, "a", encoding="utf-8") as handle:
            handle.write(markdown)
255
+
256
+
257
def _evaluate_question(
    *,
    contexttrace: ContextTrace,
    question: EvalQuestion,
    endpoint: str,
    timeout: float,
    endpoint_headers: dict[str, str],
    endpoint_caller: EndpointCaller,
    dataset_path: str,
) -> EvalResult:
    """Call the endpoint for one question, log it as a trace, and score it.

    Endpoint problems do not abort the run. The error text is recorded, the
    trace is still logged (with a placeholder answer), and the result's
    failure type becomes ``endpoint_error``. A response that is not a dict
    counts as an endpoint error.

    Returns:
        The per-question EvalResult, including the endpoint-call latency in
        milliseconds.
    """
    started_at = time.perf_counter()
    error: Optional[str] = None
    response: dict[str, Any] = {}

    try:
        raw_response = endpoint_caller(endpoint, question, timeout, endpoint_headers)
    except Exception as exc:  # pragma: no cover - exact httpx exception surface varies
        # An exception raised without a message stringifies to "", which
        # would make the failure look like a success below; fall back to the
        # exception class name.
        error = str(exc) or type(exc).__name__
    else:
        if isinstance(raw_response, dict):
            response = raw_response
        else:
            # A JSON array/scalar body (or a custom caller returning the wrong
            # type) would otherwise crash the .get() calls below and abort
            # the whole run.
            error = "Endpoint returned %s instead of a JSON object." % type(raw_response).__name__

    latency_ms = round((time.perf_counter() - started_at) * 1000, 2)
    chunks = _normalize_chunks(response)
    selected_context = _normalize_selected_context(response, chunks)
    answer = _extract_answer(response, error=error)
    citations = _normalize_citations(response)
    usage = _extract_usage(response)
    model = _extract_model(response)

    with contexttrace.trace(
        query=question.question,
        metadata={
            "source": "contexttrace_cli_eval",
            "dataset": dataset_path,
            "question_id": question.id,
            "expected_answer": question.expected_answer,
            **question.metadata,
        },
    ) as trace:
        trace.log_retrieval(chunks, retriever_name="contexttrace-cli-eval")
        trace.log_context(selected_context)
        trace.log_answer(
            answer,
            model=model,
            usage=usage,
            metadata={"endpoint": endpoint, "latency_ms": latency_ms, "error": error},
        )
        trace.log_citations(citations)
        evaluation = trace.evaluate()

    citation_support, unsupported_claim_rate = _score_evaluation(evaluation)
    if error:
        failure_type = "endpoint_error"
    else:
        # ``failure`` may be missing or null in the evaluation payload.
        failure_type = (evaluation.get("failure") or {}).get("failure_type", "unknown")

    return EvalResult(
        question=question,
        trace_id=trace.trace_id,
        citation_support=citation_support,
        unsupported_claim_rate=unsupported_claim_rate,
        failure_type=failure_type,
        token_usage=usage,
        latency_ms=latency_ms,
        error=error,
    )
320
+
321
+
322
+ def _normalize_question(item: Any, *, index: int) -> EvalQuestion:
323
+ if isinstance(item, str):
324
+ return EvalQuestion(question=item, id=str(index))
325
+ if not isinstance(item, dict):
326
+ raise ValueError("Each dataset entry must be a string or object.")
327
+
328
+ question = item.get("question") or item.get("query")
329
+ if not question:
330
+ raise ValueError("Each dataset entry must include question or query.")
331
+
332
+ return EvalQuestion(
333
+ question=str(question),
334
+ id=str(item.get("id") or index),
335
+ expected_answer=item.get("expected_answer"),
336
+ metadata=item.get("metadata") or {},
337
+ payload=item.get("payload") or {},
338
+ )
339
+
340
+
341
+ def _normalize_chunks(response: dict[str, Any]) -> list[dict[str, Any]]:
342
+ candidates = (
343
+ response.get("retrieved_chunks")
344
+ or response.get("chunks")
345
+ or response.get("documents")
346
+ or response.get("sources")
347
+ or []
348
+ )
349
+ return [_normalize_chunk(chunk, index=index) for index, chunk in enumerate(candidates)]
350
+
351
+
352
+ def _normalize_selected_context(
353
+ response: dict[str, Any],
354
+ chunks: list[dict[str, Any]],
355
+ ) -> list[dict[str, Any]]:
356
+ candidates = response.get("selected_context") or response.get("context")
357
+ if not candidates:
358
+ return chunks
359
+ if isinstance(candidates, str):
360
+ candidates = [candidates]
361
+ return [_normalize_chunk(chunk, index=index) for index, chunk in enumerate(candidates)]
362
+
363
+
364
+ def _normalize_chunk(chunk: Any, *, index: int) -> dict[str, Any]:
365
+ if isinstance(chunk, str):
366
+ return {
367
+ "chunk_id": "chunk_%s" % index,
368
+ "content": chunk,
369
+ "source": None,
370
+ "metadata": {},
371
+ "relevance_score": None,
372
+ }
373
+ if not isinstance(chunk, dict):
374
+ content = getattr(chunk, "page_content", None) or getattr(chunk, "content", None)
375
+ metadata = getattr(chunk, "metadata", None) or {}
376
+ chunk_id = getattr(chunk, "id", None) or metadata.get("id") or "chunk_%s" % index
377
+ return {
378
+ "chunk_id": str(chunk_id),
379
+ "content": str(content or ""),
380
+ "source": metadata.get("source"),
381
+ "metadata": metadata,
382
+ "relevance_score": getattr(chunk, "score", None),
383
+ }
384
+
385
+ chunk_id = chunk.get("chunk_id") or chunk.get("id") or "chunk_%s" % index
386
+ content = chunk.get("content") or chunk.get("text") or chunk.get("page_content") or ""
387
+ return {
388
+ "chunk_id": str(chunk_id),
389
+ "content": str(content),
390
+ "source": chunk.get("source"),
391
+ "metadata": chunk.get("metadata") or {},
392
+ "relevance_score": chunk.get("relevance_score") or chunk.get("score"),
393
+ }
394
+
395
+
396
+ def _normalize_citations(response: dict[str, Any]) -> list[dict[str, Any]]:
397
+ citations = response.get("citations") or []
398
+ normalized = []
399
+ for item in citations:
400
+ if not isinstance(item, dict):
401
+ continue
402
+ claim = item.get("claim")
403
+ source_chunk_id = item.get("source_chunk_id") or item.get("chunk_id") or item.get("source_id")
404
+ if claim and source_chunk_id:
405
+ normalized.append(
406
+ {
407
+ "claim": str(claim),
408
+ "source_chunk_id": str(source_chunk_id),
409
+ "metadata": item.get("metadata") or {},
410
+ }
411
+ )
412
+ return normalized
413
+
414
+
415
+ def _extract_answer(response: dict[str, Any], *, error: Optional[str]) -> str:
416
+ if error:
417
+ return "Endpoint request failed: %s" % error
418
+ answer = response.get("answer") or response.get("output") or response.get("response")
419
+ return str(answer or "No answer returned by endpoint.")
420
+
421
+
422
+ def _extract_usage(response: dict[str, Any]) -> dict[str, Any]:
423
+ return response.get("usage") or response.get("token_usage") or {}
424
+
425
+
426
+ def _extract_model(response: dict[str, Any]) -> Optional[str]:
427
+ model = response.get("model") or response.get("model_name")
428
+ return str(model) if model else None
429
+
430
+
431
+ def _score_evaluation(evaluation: dict[str, Any]) -> tuple[float, float]:
432
+ checks = evaluation.get("citation_checks") or []
433
+ if not checks:
434
+ return 0.0, 1.0
435
+ support_scores = [float(check.get("support_score") or 0.0) for check in checks]
436
+ unsupported = [
437
+ check
438
+ for check in checks
439
+ if check.get("verdict") in UNSUPPORTED_VERDICTS
440
+ ]
441
+ return (
442
+ round(sum(support_scores) / len(support_scores), 3),
443
+ round(len(unsupported) / len(checks), 3),
444
+ )
445
+
446
+
447
+ def _escape_table(value: Any) -> str:
448
+ return str(value).replace("|", "\\|").replace("\n", " ")
@@ -0,0 +1,14 @@
1
+ from contexttrace.integrations.fastapi import ContextTraceFastAPIMiddleware
2
+ from contexttrace.integrations.langchain import ContextTraceCallbackHandler
3
+ from contexttrace.integrations.langgraph import ContextTraceLangGraphTracer
4
+ from contexttrace.integrations.llamaindex import ContextTraceLlamaIndexCallbackHandler
5
+ from contexttrace.integrations.opentelemetry import OpenTelemetryExporter, export_contexttrace_trace
6
+
7
# Names exported by ``from ... import *``: the integration classes and
# helpers imported above from the contexttrace.integrations submodules.
__all__ = [
    "ContextTraceCallbackHandler",
    "ContextTraceFastAPIMiddleware",
    "ContextTraceLangGraphTracer",
    "ContextTraceLlamaIndexCallbackHandler",
    "OpenTelemetryExporter",
    "export_contexttrace_trace",
]