contexttrace 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- contexttrace/__init__.py +36 -0
- contexttrace/_version.py +1 -0
- contexttrace/cli.py +474 -0
- contexttrace/client.py +1074 -0
- contexttrace/config.py +246 -0
- contexttrace/demo.py +311 -0
- contexttrace/demo_data.py +257 -0
- contexttrace/endpoint_eval.py +314 -0
- contexttrace/errors.py +14 -0
- contexttrace/evaluator.py +448 -0
- contexttrace/integrations/__init__.py +14 -0
- contexttrace/integrations/fastapi.py +311 -0
- contexttrace/integrations/langchain.py +440 -0
- contexttrace/integrations/langgraph.py +197 -0
- contexttrace/integrations/llamaindex.py +422 -0
- contexttrace/integrations/opentelemetry.py +111 -0
- contexttrace/local.py +325 -0
- contexttrace/py.typed +1 -0
- contexttrace/regression.py +123 -0
- contexttrace/reliability.py +284 -0
- contexttrace/report.py +550 -0
- contexttrace/storage/__init__.py +3 -0
- contexttrace/storage/sqlite_store.py +604 -0
- contexttrace/thresholds.py +50 -0
- contexttrace/transport.py +183 -0
- contexttrace/viewer.py +148 -0
- contexttrace-0.1.0.dist-info/METADATA +154 -0
- contexttrace-0.1.0.dist-info/RECORD +31 -0
- contexttrace-0.1.0.dist-info/WHEEL +5 -0
- contexttrace-0.1.0.dist-info/entry_points.txt +2 -0
- contexttrace-0.1.0.dist-info/top_level.txt +1 -0
contexttrace/config.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
from contexttrace.errors import ContextTraceConfigError
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Default values shared by ContextTraceConfig's field defaults and the
# fallbacks applied by load_config().
DEFAULT_PROJECT = "default"
DEFAULT_BASE_URL = "http://localhost:8000"
DEFAULT_MODE = "local"
DEFAULT_LOCAL_STORE_DIR = ".contexttrace"
# The database lives inside the local store directory by default.
DEFAULT_STORAGE_PATH = DEFAULT_LOCAL_STORE_DIR + "/contexttrace.db"
# Config file name read when no explicit config path is given (relative
# paths resolve against the current working directory).
CONFIG_FILE = "contexttrace.yaml"


@dataclass(frozen=True)
class ContextTraceConfig:
    """Immutable, fully resolved ContextTrace settings.

    Usually produced by ``load_config``; the field defaults below equal the
    built-in fallbacks that function uses when nothing else is configured.
    Field order is part of the positional constructor and must not change.
    """

    api_key: Optional[str] = None
    project: str = DEFAULT_PROJECT
    base_url: str = DEFAULT_BASE_URL
    # load_config only accepts "hosted" or "local".
    mode: str = DEFAULT_MODE
    local_only: bool = True
    timeout: float = 30.0
    retries: int = 2
    debug: bool = False
    local_store_dir: str = DEFAULT_LOCAL_STORE_DIR
    storage_path: str = DEFAULT_STORAGE_PATH
    # Flags controlling whether chunk / answer text is logged.
    log_chunk_text: bool = True
    log_answer_text: bool = True
    eval_endpoint: Optional[str] = None
    judge_provider: str = "local"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def load_config(
    *,
    api_key: Optional[str] = None,
    project: Optional[str] = None,
    base_url: Optional[str] = None,
    api_url: Optional[str] = None,
    mode: Optional[str] = None,
    local_only: Optional[bool] = None,
    timeout: Optional[float] = None,
    retries: Optional[int] = None,
    debug: Optional[bool] = None,
    local_store_dir: Optional[str] = None,
    storage_path: Optional[str] = None,
    log_chunk_text: Optional[bool] = None,
    log_answer_text: Optional[bool] = None,
    eval_endpoint: Optional[str] = None,
    judge_provider: Optional[str] = None,
    config_path: Optional[str] = None,
) -> ContextTraceConfig:
    """Resolve the effective ContextTrace configuration.

    Every setting comes from the first source that provides a non-``None``
    value, in this order: keyword argument, ``CONTEXTTRACE_*`` environment
    variable, config file (*config_path* or ``contexttrace.yaml``), built-in
    default. Environment variables set to an empty string count as unset.

    ``api_url`` and ``base_url`` are aliases (``api_url`` wins). When no mode
    is configured anywhere, an API URL given as an argument or via
    ``CONTEXTTRACE_API_URL``/``CONTEXTTRACE_BASE_URL`` selects ``"hosted"``
    mode unless ``local_only=True`` was passed explicitly. The mode is
    matched case-insensitively. ``local_only`` defaults to ``False`` in
    hosted mode and ``True`` otherwise. ``storage_path`` defaults to
    ``<local_store_dir>/contexttrace.db``.

    Raises:
        ContextTraceConfigError: if the mode is not ``"hosted"`` or
            ``"local"``, or if ``timeout``/``retries`` cannot be converted to
            ``float``/``int``.
    """
    file_values = _read_config_file(config_path)
    env_base_url = _env("CONTEXTTRACE_API_URL") or _env("CONTEXTTRACE_BASE_URL")
    explicit_api_url = api_url or base_url
    # URL-based mode inference is suppressed only by an explicit
    # local_only=True argument; an argument/env/file mode always wins.
    url_implies_hosted = local_only is not True
    raw_mode = _first(
        mode,
        _env("CONTEXTTRACE_MODE"),
        file_values.get("mode"),
        "hosted" if explicit_api_url and url_implies_hosted else None,
        "hosted" if env_base_url and url_implies_hosted else None,
        DEFAULT_MODE,
    )
    # Normalise so "Hosted" / " local " are accepted.
    resolved_mode = str(raw_mode).strip().lower()
    if resolved_mode not in {"hosted", "local"}:
        raise ContextTraceConfigError("ContextTrace mode must be 'hosted' or 'local'.")

    resolved_local_store_dir = str(
        _first(
            local_store_dir,
            _env("CONTEXTTRACE_LOCAL_STORE_DIR"),
            file_values.get("local_store_dir"),
            DEFAULT_LOCAL_STORE_DIR,
        )
    )
    resolved_storage_path = str(
        _first(
            storage_path,
            _env("CONTEXTTRACE_STORAGE_PATH"),
            file_values.get("storage_path"),
            str(Path(resolved_local_store_dir) / "contexttrace.db"),
        )
    )
    resolved_local_only = _as_bool(
        _first(
            local_only,
            _env("CONTEXTTRACE_LOCAL_ONLY"),
            file_values.get("local_only"),
            resolved_mode != "hosted",
        )
    )

    return ContextTraceConfig(
        api_key=_first(
            api_key,
            _env("CONTEXTTRACE_API_KEY"),
            file_values.get("api_key"),
        ),
        project=str(
            _first(
                project,
                _env("CONTEXTTRACE_PROJECT"),
                file_values.get("project"),
                DEFAULT_PROJECT,
            )
        ),
        base_url=str(
            _first(
                api_url,
                base_url,
                env_base_url,
                file_values.get("base_url"),
                file_values.get("api_url"),
                DEFAULT_BASE_URL,
            )
        ),
        mode=resolved_mode,
        local_only=resolved_local_only,
        timeout=_coerce_number(
            float,
            "timeout",
            _first(timeout, _env("CONTEXTTRACE_TIMEOUT"), file_values.get("timeout"), 30.0),
        ),
        retries=_coerce_number(
            int,
            "retries",
            _first(retries, _env("CONTEXTTRACE_RETRIES"), file_values.get("retries"), 2),
        ),
        debug=_as_bool(
            _first(debug, _env("CONTEXTTRACE_DEBUG"), file_values.get("debug"), False)
        ),
        local_store_dir=resolved_local_store_dir,
        storage_path=resolved_storage_path,
        log_chunk_text=_as_bool(
            _first(
                log_chunk_text,
                _env("CONTEXTTRACE_LOG_CHUNK_TEXT"),
                file_values.get("log_chunk_text"),
                True,
            )
        ),
        log_answer_text=_as_bool(
            _first(
                log_answer_text,
                _env("CONTEXTTRACE_LOG_ANSWER_TEXT"),
                file_values.get("log_answer_text"),
                True,
            )
        ),
        eval_endpoint=_first(
            eval_endpoint,
            _env("CONTEXTTRACE_EVAL_ENDPOINT"),
            file_values.get("eval_endpoint"),
        ),
        judge_provider=str(
            _first(
                judge_provider,
                _env("CONTEXTTRACE_JUDGE_PROVIDER"),
                file_values.get("judge_provider"),
                "local",
            )
        ),
    )


def _env(name: str) -> Optional[str]:
    """Return environment variable *name*, treating an empty value as unset."""
    value = os.getenv(name)
    return value if value else None


def _coerce_number(kind: type, name: str, value: Any) -> Any:
    """Convert *value* with *kind* (``int`` or ``float``).

    Raises:
        ContextTraceConfigError: if the conversion fails, chained from the
            underlying ``TypeError``/``ValueError``.
    """
    try:
        return kind(value)
    except (TypeError, ValueError) as err:
        raise ContextTraceConfigError(
            "ContextTrace %s must be a number, got %r." % (name, value)
        ) from err
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def write_default_config(path: str = CONFIG_FILE, *, overwrite: bool = False) -> str:
    """Write a starter config file and return its path as a string.

    An existing file is left untouched unless *overwrite* is true. Missing
    parent directories are created so a nested *path* such as
    ``configs/contexttrace.yaml`` works.
    """
    output = Path(path)
    if output.exists() and not overwrite:
        return str(output)
    template_lines = [
        "mode: local",
        "project: default",
        "local_only: true",
        "local_store_dir: .contexttrace",
        "storage_path: .contexttrace/contexttrace.db",
        "log_chunk_text: true",
        "log_answer_text: true",
        "judge_provider: local",
        "api_key: ''",
        "base_url: ''",
        "timeout: 30",
        "retries: 2",
        "debug: false",
        "eval_endpoint: ''",
        # Trailing empty entry makes the file end with a newline.
        "",
    ]
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text("\n".join(template_lines), encoding="utf-8")
    return str(output)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _read_config_file(config_path: Optional[str]) -> dict[str, Any]:
    """Load the config file's key/value pairs, or ``{}`` if there is none.

    *config_path* selects the file; when it is falsy, ``CONFIG_FILE`` relative
    to the current working directory is used. A path that does not exist, or
    that names something other than a regular file (e.g. a directory),
    yields an empty dict instead of an error.
    """
    candidate = Path(config_path) if config_path else Path(CONFIG_FILE)
    if not candidate.is_file():
        return {}
    return _parse_simple_yaml(candidate.read_text(encoding="utf-8"))
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _parse_simple_yaml(text: str) -> dict[str, Any]:
|
|
218
|
+
values: dict[str, Any] = {}
|
|
219
|
+
for raw_line in text.splitlines():
|
|
220
|
+
line = raw_line.strip()
|
|
221
|
+
if not line or line.startswith("#") or ":" not in line:
|
|
222
|
+
continue
|
|
223
|
+
key, value = line.split(":", 1)
|
|
224
|
+
parsed = value.strip().strip("\"'")
|
|
225
|
+
if parsed == "":
|
|
226
|
+
values[key.strip()] = None
|
|
227
|
+
elif parsed.lower() in {"true", "false"}:
|
|
228
|
+
values[key.strip()] = parsed.lower() == "true"
|
|
229
|
+
else:
|
|
230
|
+
values[key.strip()] = parsed
|
|
231
|
+
return values
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _first(*values: Any) -> Any:
|
|
235
|
+
for value in values:
|
|
236
|
+
if value is not None:
|
|
237
|
+
return value
|
|
238
|
+
return None
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _as_bool(value: Any) -> bool:
|
|
242
|
+
if isinstance(value, bool):
|
|
243
|
+
return value
|
|
244
|
+
if value is None:
|
|
245
|
+
return False
|
|
246
|
+
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
contexttrace/demo.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Iterable
|
|
7
|
+
|
|
8
|
+
from contexttrace.client import ContextTrace
|
|
9
|
+
from contexttrace.demo_data import load_demo_dataset
|
|
10
|
+
from contexttrace.report import ReportGenerator
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Number of retrieved chunks kept per demo retrieval strategy. Strategies
# missing from this table fall back to 3 in _run_question.
STRATEGY_TOP_K = {
    # top_k = 2
    "dense_top_k": 2,
    "bm25": 2,
    "bm25_top_k": 2,
    # top_k = 3
    "hybrid": 3,
    # top_k = 4
    "hybrid_rerank": 4,
    "corrective": 4,
    "corrective_rag": 4,
    "adaptive": 4,
    "contexttrace_adaptive": 4,
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True)
class DemoRun:
    """Immutable result of ``run_demo_dataset``."""

    # Name of the loaded demo dataset.
    dataset: str
    # Id returned by the store's create_eval_run, or None when the client
    # exposed no store.
    eval_run_id: str | None
    # One trace id per question, in question order.
    trace_ids: list[str]
    # Where the evaluation report was written.
    report_path: str
    # Aggregated metrics produced by aggregate_trace_metrics.
    summary: dict[str, Any]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def run_demo_dataset(
    *,
    dataset: str,
    contexttrace: ContextTrace,
    strategy: str = "adaptive",
    report_path: str | None = None,
    max_questions: int | None = None,
) -> DemoRun:
    """Run every question of a bundled demo dataset and write a report.

    Each question (optionally only the first *max_questions*) is traced via
    ``_run_question``, the stored traces are read back with
    ``contexttrace.get_trace`` and summarised by ``aggregate_trace_metrics``.
    If the client's ``_transport`` exposes a ``store``, an eval run plus one
    entry per question is persisted there. The report is written to
    *report_path*, defaulting to ``.contexttrace/reports/<name>_demo.html``.
    """
    loaded = load_demo_dataset(dataset)
    chunks = _document_chunks(loaded["documents"])
    questions = list(loaded["questions"])
    if max_questions is not None:
        questions = questions[:max_questions]

    trace_ids: list[str] = []
    traces: list[dict[str, Any]] = []
    for position, question in enumerate(questions):
        new_trace_id = _run_question(
            contexttrace=contexttrace,
            dataset_name=loaded["name"],
            question=question,
            expected_answer=loaded["expected_answers"].get(question["id"], ""),
            expected_sources=list(loaded["expected_sources"].get(question["id"], [])),
            chunks=chunks,
            strategy=strategy,
            position=position,
        )
        trace_ids.append(new_trace_id)
        traces.append(contexttrace.get_trace(new_trace_id))

    summary = aggregate_trace_metrics(traces)

    # Persistence is optional: only some transports expose a local store.
    transport = getattr(contexttrace, "_transport", None)
    store = getattr(transport, "store", None)
    eval_run_id = None
    if store is not None:
        eval_run_id = store.create_eval_run(dataset=loaded["name"], endpoint="contexttrace-demo", summary=summary)
        for position, (question, stored_trace_id) in enumerate(zip(questions, trace_ids)):
            store.add_eval_question(
                eval_run_id=eval_run_id,
                question=question,
                trace_id=stored_trace_id,
                position=position,
            )

    if report_path is None:
        report_path = str(Path(".contexttrace") / "reports" / ("%s_demo.html" % loaded["name"]))
    run_record = {
        "id": eval_run_id or "%s-demo" % loaded["name"],
        "dataset": loaded["name"],
        "endpoint": "contexttrace-demo",
        "summary": summary,
    }
    ReportGenerator().generate_eval_report(run_record, traces, path=report_path)

    return DemoRun(
        dataset=loaded["name"],
        eval_run_id=eval_run_id,
        trace_ids=trace_ids,
        report_path=report_path,
        summary=summary,
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def aggregate_trace_metrics(traces: Iterable[dict[str, Any]]) -> dict[str, Any]:
    """Summarise evaluated demo traces into report-level metrics.

    Each trace is read defensively: missing ``evaluation``, ``failure``,
    ``scores``, ``answer``, ``usage`` or ``metadata`` sections are treated as
    empty. A trace without a ``failure_type`` counts as an ``"unknown"``
    failure; only ``"no_failure_detected"`` counts as a pass. Latency and
    token values that are absent or ``None`` are left out of their means.

    Args:
        traces: Trace dicts (e.g. from ``ContextTrace.get_trace``).

    Returns:
        A dict with the same keys whether or not *traces* is empty. Rates are
        fractions of the number of traces; scores, latency and token counts
        are means rounded to 3 places; ``cost_usd`` is the mean token count
        times a flat 1e-6 USD/token rate; ``top_failures`` lists the most
        frequent failure types.
    """
    rows = list(traces)
    count = len(rows)
    failures: list[str] = []
    citation_support: list[float] = []
    unsupported: list[float] = []
    reliability_scores: list[float] = []
    latency: list[float] = []
    tokens: list[float] = []
    retrieval_misses = 0
    for trace in rows:
        evaluation = trace.get("evaluation") or {}
        failure = evaluation.get("failure") or {}
        scores = evaluation.get("scores") or {}
        failure_type = failure.get("failure_type") or "unknown"
        if failure_type != "no_failure_detected":
            failures.append(failure_type)
        if failure_type == "retrieval_miss":
            retrieval_misses += 1
        citation_support.append(float(scores.get("citation_support") or 0.0))
        unsupported.append(float(scores.get("unsupported_claim_rate") or 0.0))
        reliability_scores.append(float((evaluation.get("reliability") or {}).get("score") or 0.0))

        answer = trace.get("answer") or {}
        usage = answer.get("usage") or {}
        metadata = answer.get("metadata") or {}
        # Skip explicit nulls instead of crashing in float(None).
        if metadata.get("latency_ms") is not None:
            latency.append(float(metadata["latency_ms"]))
        if usage.get("total_tokens") is not None:
            tokens.append(float(usage["total_tokens"]))

    # `count or 1` lets empty input flow through the same code path, so the
    # result always has the same keys (previously the empty case omitted
    # "avg_citation_support") and every rate is 0.0.
    denominator = count or 1
    avg_tokens = _avg(tokens)
    avg_citation = _avg(citation_support)
    return {
        "questions_tested": count,
        "reliability_score": _avg(reliability_scores),
        "failure_rate": round(len(failures) / denominator, 3),
        "citation_support": avg_citation,
        "avg_citation_support": avg_citation,
        "unsupported_claim_rate": _avg(unsupported),
        "retrieval_miss_rate": round(retrieval_misses / denominator, 3),
        "latency_ms": _avg(latency),
        "token_count": avg_tokens,
        "cost_usd": round(avg_tokens * 0.000001, 6),
        "top_failures": _top_failures(failures),
    }
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _run_question(
    *,
    contexttrace: ContextTrace,
    dataset_name: str,
    question: dict[str, Any],
    expected_answer: str,
    expected_sources: list[str],
    chunks: list[dict[str, Any]],
    strategy: str,
    position: int,
) -> str:
    """Simulate one RAG call for *question*, record it as a trace, return its id.

    Retrieval is shaped by the question's ``expected_failure`` so each
    failure mode is reproduced deterministically. Expected sources are
    pinned in for clean questions, ``citation_mismatch`` and
    ``conflicting_sources``, and removed for ``retrieval_miss``. Chunks whose
    ``stance`` metadata is ``"archived"`` are dropped unless they belong to
    an expected source. Latency and token counts are synthetic, derived from
    *position* and the selected context.
    """
    query = str(question["query"])
    top_k = STRATEGY_TOP_K.get(strategy, 3)
    candidates = _retrieve(query, chunks, top_k=top_k)
    expected_failure = question.get("expected_failure")

    # Clean questions and citation mismatches retrieve the right sources; the
    # mismatch itself is introduced when citations are built.
    if expected_sources and expected_failure in {None, "citation_mismatch"}:
        candidates = _force_expected_sources(chunks, expected_sources, candidates)

    def _keep(chunk: dict[str, Any]) -> bool:
        # Expected sources always survive; other archived chunks are dropped.
        if chunk["source"] in expected_sources:
            return True
        return (chunk.get("metadata") or {}).get("stance") != "archived"

    candidates = [chunk for chunk in candidates if _keep(chunk)]

    if expected_failure == "retrieval_miss":
        candidates = [chunk for chunk in candidates if chunk["source"] not in expected_sources]
        if not candidates:
            # Everything retrieved was an expected source: fall back to the
            # first two corpus chunks from other sources.
            candidates = [chunk for chunk in chunks if chunk["source"] not in expected_sources][:2]
    if expected_failure == "conflicting_sources":
        candidates = _force_expected_sources(chunks, expected_sources, candidates)

    # At most top_k chunks (at least one when any are available).
    selected = candidates[: max(1, top_k)]
    answer = _demo_answer(question, expected_answer)
    citations = _demo_citations(question, answer, selected, chunks, expected_sources)
    latency_ms = 35 + position * 7 + len(selected) * 12
    context_words = sum(len(chunk["content"].split()) for chunk in selected)
    token_count = 80 + len(answer.split()) + context_words

    trace_metadata = {
        "dataset": dataset_name,
        "question_id": question["id"],
        "question_type": question.get("type"),
        "expected_sources": expected_sources,
        "expected_failure": expected_failure or "no_failure_detected",
        "strategy": strategy,
    }
    with contexttrace.trace(query=query, metadata=trace_metadata) as trace:
        trace.log_retrieval(candidates, metadata={"strategy": strategy})
        trace.log_context(selected)
        trace.log_answer(
            answer,
            model="contexttrace-demo-rag",
            usage={"total_tokens": token_count},
            metadata={"latency_ms": latency_ms, "strategy": strategy},
        )
        if citations:
            trace.log_citations(citations)
        trace.evaluate()
    return str(trace.trace_id)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _document_chunks(documents: dict[str, str]) -> list[dict[str, Any]]:
    """Split markdown documents into one chunk per ``##`` section.

    A leading ``# `` title marker is removed from each section. Very short
    sections (four words or fewer) that merely restate the document's file
    stem (underscores read as spaces) are skipped as title-only.
    ``chunk_id`` is ``<stem>_<n>``, where *n* counts all non-empty sections,
    skipped ones included. Every chunk starts with ``relevance_score`` 0.0.

    Args:
        documents: Mapping of source name (file name) to markdown text.
    """
    chunks: list[dict[str, Any]] = []
    for source, text in documents.items():
        stem = Path(source).stem
        title = stem.replace("_", " ").lower()
        # Split on "## " headings at the start of any line, including the very
        # first line; requiring a preceding "\n" left a document that opens
        # with "## Heading" unsplit and kept the marker in its first chunk.
        sections = [section.strip() for section in re.split(r"(?:^|\n)##\s+", text) if section.strip()]
        for index, section in enumerate(sections):
            content = re.sub(r"^#\s+", "", section).strip()
            if not content:
                continue
            if len(content.split()) <= 4 and title in content.lower():
                continue
            chunks.append(
                {
                    "chunk_id": "%s_%s" % (stem, index + 1),
                    "content": content,
                    "source": source,
                    "metadata": {
                        "section": content.splitlines()[0][:80],
                        "stance": _stance(source, content),
                    },
                    "relevance_score": 0.0,
                }
            )
    return chunks
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _retrieve(query: str, chunks: list[dict[str, Any]], *, top_k: int) -> list[dict[str, Any]]:
    """Rank *chunks* by query-term overlap and return the best *top_k*.

    The score is the fraction of the query's terms that appear in the
    chunk, rounded to 3 places and stored on a copy of each chunk as
    ``relevance_score``. Ties are ordered by ``source``; the sort is stable,
    so equal (score, source) pairs keep corpus order.
    """
    wanted = _terms(query)
    denominator = max(len(wanted), 1)
    ranked = [
        {**chunk, "relevance_score": round(len(wanted & _terms(chunk["content"])) / denominator, 3)}
        for chunk in chunks
    ]
    ranked.sort(key=lambda item: (-float(item["relevance_score"]), item["source"]))
    return ranked[:top_k]
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _force_expected_sources(
|
|
251
|
+
chunks: list[dict[str, Any]],
|
|
252
|
+
expected_sources: list[str],
|
|
253
|
+
retrieved: list[dict[str, Any]],
|
|
254
|
+
) -> list[dict[str, Any]]:
|
|
255
|
+
forced = [chunk for chunk in chunks if chunk["source"] in expected_sources]
|
|
256
|
+
seen = {chunk["chunk_id"] for chunk in forced}
|
|
257
|
+
return forced + [chunk for chunk in retrieved if chunk["chunk_id"] not in seen]
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _demo_answer(question: dict[str, Any], expected_answer: str) -> str:
|
|
261
|
+
failure = question.get("expected_failure")
|
|
262
|
+
if failure == "should_have_abstained":
|
|
263
|
+
return "Yes, the documents say this exception is available to eligible employees or customers."
|
|
264
|
+
if failure == "unsupported_answer":
|
|
265
|
+
return "Yes, the policy gives a special exception that is not limited by the documented rules."
|
|
266
|
+
return expected_answer
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _demo_citations(
|
|
270
|
+
question: dict[str, Any],
|
|
271
|
+
answer: str,
|
|
272
|
+
selected: list[dict[str, Any]],
|
|
273
|
+
chunks: list[dict[str, Any]],
|
|
274
|
+
expected_sources: list[str],
|
|
275
|
+
) -> list[dict[str, Any]]:
|
|
276
|
+
if not answer or not selected:
|
|
277
|
+
return []
|
|
278
|
+
failure = question.get("expected_failure")
|
|
279
|
+
claim = answer.split(".")[0].strip() + "."
|
|
280
|
+
if failure == "citation_mismatch":
|
|
281
|
+
wrong = next((chunk for chunk in selected if chunk["source"] not in expected_sources), None)
|
|
282
|
+
if wrong is None:
|
|
283
|
+
wrong = next((chunk for chunk in chunks if chunk["source"] not in expected_sources), selected[0])
|
|
284
|
+
return [{"claim": claim, "source_chunk_id": wrong["chunk_id"]}]
|
|
285
|
+
return [{"claim": claim, "source_chunk_id": selected[0]["chunk_id"]}]
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _stance(source: str, content: str) -> str:
|
|
289
|
+
lowered = (source + " " + content).lower()
|
|
290
|
+
if "archived" in lowered or "old_" in lowered or "legacy" in lowered:
|
|
291
|
+
return "archived"
|
|
292
|
+
if "current" in lowered or "policy" in lowered:
|
|
293
|
+
return "current"
|
|
294
|
+
return "neutral"
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _terms(text: str) -> set[str]:
|
|
298
|
+
return {
|
|
299
|
+
token.strip(".,:;!?()[]{}\"'").lower().rstrip("s")
|
|
300
|
+
for token in text.split()
|
|
301
|
+
if len(token.strip(".,:;!?()[]{}\"'")) > 2
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _avg(values: list[float]) -> float:
|
|
306
|
+
return round(sum(values) / len(values), 3) if values else 0.0
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _top_failures(failures: list[str]) -> list[str]:
|
|
310
|
+
counts = {failure: failures.count(failure) for failure in set(failures)}
|
|
311
|
+
return [name for name, _ in sorted(counts.items(), key=lambda item: (-item[1], item[0]))[:5]]
|