contexttrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
contexttrace/config.py ADDED
@@ -0,0 +1,246 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Optional
7
+
8
+ from contexttrace.errors import ContextTraceConfigError
9
+
10
+
11
# Fallback values for individual settings.
DEFAULT_PROJECT = "default"
DEFAULT_BASE_URL = "http://localhost:8000"
DEFAULT_MODE = "local"

# Fallback locations for locally stored data (relative paths).
DEFAULT_LOCAL_STORE_DIR = ".contexttrace"
DEFAULT_STORAGE_PATH = ".contexttrace/contexttrace.db"

# Config file name that is looked up (as a relative path) when no explicit
# config path is supplied.
CONFIG_FILE = "contexttrace.yaml"
17
+
18
+
19
@dataclass(frozen=True)
class ContextTraceConfig:
    """Immutable container for a fully resolved set of ContextTrace settings.

    Instances are frozen, so a configuration cannot be modified after it has
    been built. Every field has a default, which allows ``ContextTraceConfig()``
    to be created without arguments.
    """

    # Identity and endpoint.
    api_key: Optional[str] = None
    project: str = DEFAULT_PROJECT
    base_url: str = DEFAULT_BASE_URL

    # Operating mode; ``load_config`` only accepts "hosted" or "local".
    mode: str = DEFAULT_MODE
    local_only: bool = True

    # Request tuning and diagnostics.
    timeout: float = 30.0
    retries: int = 2
    debug: bool = False

    # Local storage locations.
    local_store_dir: str = DEFAULT_LOCAL_STORE_DIR
    storage_path: str = DEFAULT_STORAGE_PATH

    # Logging switches.
    log_chunk_text: bool = True
    log_answer_text: bool = True

    # Evaluation settings.
    eval_endpoint: Optional[str] = None
    judge_provider: str = "local"
35
+
36
+
37
def load_config(
    *,
    api_key: Optional[str] = None,
    project: Optional[str] = None,
    base_url: Optional[str] = None,
    api_url: Optional[str] = None,
    mode: Optional[str] = None,
    local_only: Optional[bool] = None,
    timeout: Optional[float] = None,
    retries: Optional[int] = None,
    debug: Optional[bool] = None,
    local_store_dir: Optional[str] = None,
    storage_path: Optional[str] = None,
    log_chunk_text: Optional[bool] = None,
    log_answer_text: Optional[bool] = None,
    eval_endpoint: Optional[str] = None,
    judge_provider: Optional[str] = None,
    config_path: Optional[str] = None,
) -> ContextTraceConfig:
    """Build a :class:`ContextTraceConfig` from arguments, environment, and file.

    Each setting is taken from the first source that supplies a non-``None``
    value, in this order: the keyword argument, the matching ``CONTEXTTRACE_*``
    environment variable, the entry in the config file, the built-in default.

    Special cases:

    * The server URL may be passed as ``api_url`` or ``base_url`` (``api_url``
      wins), read from ``CONTEXTTRACE_API_URL`` or ``CONTEXTTRACE_BASE_URL``
      (in that order), or from the file keys ``base_url`` / ``api_url`` (in
      that order). Empty strings count as "not supplied".
    * When no mode is configured anywhere, the mode is ``"hosted"`` if a URL
      was supplied by argument or environment and the ``local_only`` argument
      is not ``True``; otherwise it is ``DEFAULT_MODE``.
    * ``local_only`` defaults to ``False`` in hosted mode and ``True`` otherwise.
    * ``storage_path`` defaults to ``contexttrace.db`` inside the resolved
      ``local_store_dir``.

    Args:
        config_path: Optional path of the config file to read instead of the
            default ``CONFIG_FILE``.
        (all other arguments): Explicit overrides for the equally named fields
            of :class:`ContextTraceConfig`.

    Returns:
        The resolved, immutable configuration.

    Raises:
        ContextTraceConfigError: If the resolved mode is neither ``"hosted"``
            nor ``"local"``.
        ValueError: If the resolved ``timeout`` or ``retries`` value cannot be
            converted by ``float()`` / ``int()``.
    """
    file_values = _read_config_file(config_path)

    def setting(explicit: Any, env_name: str, file_key: str, default: Any = None) -> Any:
        """Return the first non-None of argument, environment, file entry, default."""
        return _first(explicit, os.getenv(env_name), file_values.get(file_key), default)

    # Empty URL strings are normalised to None. Mode inference already ignored
    # them (truthiness check); without the normalisation an empty ``api_url``
    # argument or an empty environment variable would still win the lookup
    # chain below and produce ``base_url == ""``.
    env_base_url = (
        os.getenv("CONTEXTTRACE_API_URL") or os.getenv("CONTEXTTRACE_BASE_URL") or None
    )
    explicit_api_url = api_url or base_url or None
    url_supplied = bool(explicit_api_url or env_base_url)

    resolved_mode = _first(
        mode,
        os.getenv("CONTEXTTRACE_MODE"),
        file_values.get("mode"),
        # A supplied URL implies hosted mode unless the caller explicitly
        # asked for local-only operation.
        "hosted" if url_supplied and local_only is not True else None,
        DEFAULT_MODE,
    )
    resolved_local_store_dir = str(
        setting(
            local_store_dir,
            "CONTEXTTRACE_LOCAL_STORE_DIR",
            "local_store_dir",
            DEFAULT_LOCAL_STORE_DIR,
        )
    )
    resolved_storage_path = str(
        setting(
            storage_path,
            "CONTEXTTRACE_STORAGE_PATH",
            "storage_path",
            str(Path(resolved_local_store_dir) / "contexttrace.db"),
        )
    )
    resolved_local_only = _as_bool(
        setting(
            local_only,
            "CONTEXTTRACE_LOCAL_ONLY",
            "local_only",
            resolved_mode != "hosted",
        )
    )

    resolved = ContextTraceConfig(
        api_key=setting(api_key, "CONTEXTTRACE_API_KEY", "api_key"),
        project=str(setting(project, "CONTEXTTRACE_PROJECT", "project", DEFAULT_PROJECT)),
        base_url=str(
            _first(
                explicit_api_url,
                env_base_url,
                file_values.get("base_url"),
                file_values.get("api_url"),
                DEFAULT_BASE_URL,
            )
        ),
        mode=str(resolved_mode),
        local_only=resolved_local_only,
        timeout=float(setting(timeout, "CONTEXTTRACE_TIMEOUT", "timeout", 30.0)),
        retries=int(setting(retries, "CONTEXTTRACE_RETRIES", "retries", 2)),
        debug=_as_bool(setting(debug, "CONTEXTTRACE_DEBUG", "debug", False)),
        local_store_dir=resolved_local_store_dir,
        storage_path=resolved_storage_path,
        log_chunk_text=_as_bool(
            setting(log_chunk_text, "CONTEXTTRACE_LOG_CHUNK_TEXT", "log_chunk_text", True)
        ),
        log_answer_text=_as_bool(
            setting(log_answer_text, "CONTEXTTRACE_LOG_ANSWER_TEXT", "log_answer_text", True)
        ),
        eval_endpoint=setting(eval_endpoint, "CONTEXTTRACE_EVAL_ENDPOINT", "eval_endpoint"),
        judge_provider=str(
            setting(judge_provider, "CONTEXTTRACE_JUDGE_PROVIDER", "judge_provider", "local")
        ),
    )

    if resolved.mode not in {"hosted", "local"}:
        raise ContextTraceConfigError("ContextTrace mode must be 'hosted' or 'local'.")
    return resolved
178
+
179
+
180
def write_default_config(path: str = CONFIG_FILE, *, overwrite: bool = False) -> str:
    """Write a config file populated with the default settings.

    Args:
        path: Destination of the config file.
        overwrite: When ``False`` (the default) an existing file is left
            untouched; when ``True`` it is replaced.

    Returns:
        The destination path as a string, whether or not a file was written.
    """
    output = Path(path)
    if output.exists() and not overwrite:
        return str(output)

    # Create missing parent directories first; otherwise a nested destination
    # such as "conf/contexttrace.yaml" fails with FileNotFoundError.
    output.parent.mkdir(parents=True, exist_ok=True)

    lines = [
        "mode: local",
        "project: default",
        "local_only: true",
        "local_store_dir: .contexttrace",
        "storage_path: .contexttrace/contexttrace.db",
        "log_chunk_text: true",
        "log_answer_text: true",
        "judge_provider: local",
        "api_key: ''",
        "base_url: ''",
        "timeout: 30",
        "retries: 2",
        "debug: false",
        "eval_endpoint: ''",
        # The trailing empty entry makes the joined text end with a newline.
        "",
    ]
    output.write_text("\n".join(lines), encoding="utf-8")
    return str(output)
207
+
208
+
209
+ def _read_config_file(config_path: Optional[str]) -> dict[str, Any]:
210
+ candidates = [Path(config_path)] if config_path else [Path(CONFIG_FILE)]
211
+ for candidate in candidates:
212
+ if candidate.exists():
213
+ return _parse_simple_yaml(candidate.read_text(encoding="utf-8"))
214
+ return {}
215
+
216
+
217
+ def _parse_simple_yaml(text: str) -> dict[str, Any]:
218
+ values: dict[str, Any] = {}
219
+ for raw_line in text.splitlines():
220
+ line = raw_line.strip()
221
+ if not line or line.startswith("#") or ":" not in line:
222
+ continue
223
+ key, value = line.split(":", 1)
224
+ parsed = value.strip().strip("\"'")
225
+ if parsed == "":
226
+ values[key.strip()] = None
227
+ elif parsed.lower() in {"true", "false"}:
228
+ values[key.strip()] = parsed.lower() == "true"
229
+ else:
230
+ values[key.strip()] = parsed
231
+ return values
232
+
233
+
234
+ def _first(*values: Any) -> Any:
235
+ for value in values:
236
+ if value is not None:
237
+ return value
238
+ return None
239
+
240
+
241
+ def _as_bool(value: Any) -> bool:
242
+ if isinstance(value, bool):
243
+ return value
244
+ if value is None:
245
+ return False
246
+ return str(value).strip().lower() in {"1", "true", "yes", "on"}
contexttrace/demo.py ADDED
@@ -0,0 +1,311 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Iterable
7
+
8
+ from contexttrace.client import ContextTrace
9
+ from contexttrace.demo_data import load_demo_dataset
10
+ from contexttrace.report import ReportGenerator
11
+
12
+
13
# Number of chunks retrieved/selected per demo strategy name. Strategies that
# are not listed fall back to 3 at the lookup sites.
STRATEGY_TOP_K = {
    # two chunks
    "dense_top_k": 2,
    "bm25": 2,
    "bm25_top_k": 2,
    # three chunks
    "hybrid": 3,
    # four chunks
    "hybrid_rerank": 4,
    "corrective": 4,
    "corrective_rag": 4,
    "adaptive": 4,
    "contexttrace_adaptive": 4,
}
24
+
25
+
26
@dataclass(frozen=True)
class DemoRun:
    """Immutable result record returned by ``run_demo_dataset``."""

    # Name of the dataset that was run.
    dataset: str
    # Id of the stored eval run; None when no store was available.
    eval_run_id: str | None
    # One trace id per question, in question order.
    trace_ids: list[str]
    # Location the HTML report was written to.
    report_path: str
    # Aggregated metrics over all traces of the run.
    summary: dict[str, Any]
33
+
34
+
35
def run_demo_dataset(
    *,
    dataset: str,
    contexttrace: ContextTrace,
    strategy: str = "adaptive",
    report_path: str | None = None,
    max_questions: int | None = None,
) -> DemoRun:
    """Run every question of a bundled demo dataset and write an HTML report.

    Each question is traced through ``_run_question``; the resulting traces are
    fetched back from the client and aggregated. If the client exposes a store
    (``contexttrace._transport.store``), an eval run and one eval question per
    trace are recorded there.

    Args:
        dataset: Identifier passed to ``load_demo_dataset``.
        contexttrace: Client used for tracing and for reading traces back.
        strategy: Retrieval strategy name, forwarded to every question.
        report_path: Destination of the report. Defaults to
            ``.contexttrace/reports/<dataset name>_demo.html``.
        max_questions: When given, only the first ``max_questions`` questions
            are run.

    Returns:
        A ``DemoRun`` describing the dataset, trace ids, report path and
        aggregated summary.
    """
    loaded = load_demo_dataset(dataset)
    chunks = _document_chunks(loaded["documents"])
    questions = list(loaded["questions"])
    if max_questions is not None:
        questions = questions[:max_questions]

    trace_ids: list[str] = []
    traces: list[dict[str, Any]] = []
    for position, question in enumerate(questions):
        question_id = question["id"]
        new_trace_id = _run_question(
            contexttrace=contexttrace,
            dataset_name=loaded["name"],
            question=question,
            expected_answer=loaded["expected_answers"].get(question_id, ""),
            expected_sources=list(loaded["expected_sources"].get(question_id, [])),
            chunks=chunks,
            strategy=strategy,
            position=position,
        )
        trace_ids.append(new_trace_id)
        traces.append(contexttrace.get_trace(new_trace_id))

    summary = aggregate_trace_metrics(traces)
    dataset_name = loaded["name"]

    # The store is optional: it is only present when the client's transport
    # object carries a ``store`` attribute.
    transport = getattr(contexttrace, "_transport", None)
    store = getattr(transport, "store", None)
    eval_run_id = None
    if store is not None:
        eval_run_id = store.create_eval_run(
            dataset=dataset_name,
            endpoint="contexttrace-demo",
            summary=summary,
        )
        for position, (question, recorded_trace_id) in enumerate(zip(questions, trace_ids)):
            store.add_eval_question(
                eval_run_id=eval_run_id,
                question=question,
                trace_id=recorded_trace_id,
                position=position,
            )

    if report_path is None:
        report_file = "%s_demo.html" % dataset_name
        report_path = str(Path(".contexttrace") / "reports" / report_file)

    eval_run = {
        "id": eval_run_id or "%s-demo" % dataset_name,
        "dataset": dataset_name,
        "endpoint": "contexttrace-demo",
        "summary": summary,
    }
    ReportGenerator().generate_eval_report(eval_run, traces, path=report_path)

    return DemoRun(
        dataset=dataset_name,
        eval_run_id=eval_run_id,
        trace_ids=trace_ids,
        report_path=report_path,
        summary=summary,
    )
97
+
98
+
99
def aggregate_trace_metrics(traces: Iterable[dict[str, Any]]) -> dict[str, Any]:
    """Summarise evaluation results over a collection of traces.

    Args:
        traces: Trace dicts. The fields read are ``evaluation.failure``,
            ``evaluation.scores``, ``evaluation.reliability.score``,
            ``answer.usage.total_tokens`` and ``answer.metadata.latency_ms``;
            missing or ``None`` sections are treated as empty.

    Returns:
        A summary dict with the same keys for empty and non-empty input. Rates
        and averages are rounded to three decimals. ``latency_ms`` and
        ``token_count`` average only the traces that report those values, and
        ``cost_usd`` is the average token count multiplied by ``0.000001``.
        A trace without a failure type is counted as an ``"unknown"`` failure.
    """
    rows = list(traces)
    count = len(rows)
    if not rows:
        return {
            "questions_tested": 0,
            "reliability_score": 0.0,
            "failure_rate": 0.0,
            "citation_support": 0.0,
            # Present for non-empty input too; keep both shapes identical so
            # consumers can read the key unconditionally.
            "avg_citation_support": 0.0,
            "unsupported_claim_rate": 0.0,
            "retrieval_miss_rate": 0.0,
            "latency_ms": 0.0,
            "token_count": 0.0,
            "cost_usd": 0.0,
            "top_failures": [],
        }

    failures: list[str] = []
    citation_support: list[float] = []
    unsupported: list[float] = []
    retrieval_misses = 0
    reliability_scores: list[float] = []
    latency: list[float] = []
    tokens: list[float] = []
    for trace in rows:
        evaluation = trace.get("evaluation") or {}
        failure = evaluation.get("failure") or {}
        scores = evaluation.get("scores") or {}
        failure_type = failure.get("failure_type") or "unknown"
        if failure_type != "no_failure_detected":
            failures.append(failure_type)
            if failure_type == "retrieval_miss":
                retrieval_misses += 1
        citation_support.append(float(scores.get("citation_support") or 0.0))
        unsupported.append(float(scores.get("unsupported_claim_rate") or 0.0))
        reliability = evaluation.get("reliability") or {}
        reliability_scores.append(float(reliability.get("score") or 0.0))

        answer = trace.get("answer") or {}
        usage = answer.get("usage") or {}
        metadata = answer.get("metadata") or {}
        if "latency_ms" in metadata:
            latency.append(float(metadata["latency_ms"]))
        if "total_tokens" in usage:
            tokens.append(float(usage["total_tokens"]))

    avg_tokens = _avg(tokens)
    avg_citation_support = _avg(citation_support)
    return {
        "questions_tested": count,
        "reliability_score": _avg(reliability_scores),
        "failure_rate": round(len(failures) / count, 3),
        "citation_support": avg_citation_support,
        "avg_citation_support": avg_citation_support,
        "unsupported_claim_rate": _avg(unsupported),
        "retrieval_miss_rate": round(retrieval_misses / count, 3),
        "latency_ms": _avg(latency),
        "token_count": avg_tokens,
        "cost_usd": round(avg_tokens * 0.000001, 6),
        "top_failures": _top_failures(failures),
    }
156
+
157
+
158
def _run_question(
    *,
    contexttrace: ContextTrace,
    dataset_name: str,
    question: dict[str, Any],
    expected_answer: str,
    expected_sources: list[str],
    chunks: list[dict[str, Any]],
    strategy: str,
    position: int,
) -> str:
    """Simulate one demo question end to end and record it as a trace.

    The retrieved chunk list is shaped according to the question's
    ``expected_failure`` so that the demo reproduces that scenario, then
    retrieval, context, answer and (if any) citations are logged on a trace,
    which is finally evaluated.

    Args:
        contexttrace: Client that provides the ``trace`` context manager.
        dataset_name: Dataset name stored in the trace metadata.
        question: Question dict; ``query`` and ``id`` are required,
            ``expected_failure`` and ``type`` are optional.
        expected_answer: Answer used unless the failure scenario replaces it.
        expected_sources: Sources the answer is expected to rely on.
        chunks: All chunks of the dataset.
        strategy: Strategy name; selects the top-k via ``STRATEGY_TOP_K``
            (3 when unknown).
        position: Zero-based question index; feeds the synthetic latency.

    Returns:
        The id of the recorded trace, as a string.
    """
    top_k = STRATEGY_TOP_K.get(strategy, 3)
    query = str(question["query"])
    retrieved = _retrieve(query, chunks, top_k=top_k)
    expected_failure = question.get("expected_failure")

    def from_expected_source(chunk: dict[str, Any]) -> bool:
        return chunk["source"] in expected_sources

    def is_archived(chunk: dict[str, Any]) -> bool:
        return (chunk.get("metadata") or {}).get("stance") == "archived"

    if expected_sources and expected_failure in {None, "citation_mismatch"}:
        # Guarantee the expected sources are present and drop archived chunks
        # that come from other sources.
        with_expected = _force_expected_sources(chunks, expected_sources, retrieved)
        retrieved = [
            chunk
            for chunk in with_expected
            if from_expected_source(chunk) or not is_archived(chunk)
        ]
    if expected_failure == "retrieval_miss":
        # Remove every expected source; fall back to two unrelated chunks so
        # the retrieval result is never empty.
        retrieved = [chunk for chunk in retrieved if not from_expected_source(chunk)]
        if not retrieved:
            retrieved = [chunk for chunk in chunks if not from_expected_source(chunk)][:2]
    if expected_failure == "conflicting_sources":
        retrieved = _force_expected_sources(chunks, expected_sources, retrieved)

    keep = max(1, min(len(retrieved), top_k))
    selected = retrieved[:keep]
    answer = _demo_answer(question, expected_answer)
    citations = _demo_citations(question, answer, selected, chunks, expected_sources)

    # Deterministic synthetic measurements for the demo.
    latency_ms = 35 + (position * 7) + len(selected) * 12
    context_words = sum(len(chunk["content"].split()) for chunk in selected)
    token_count = 80 + len(answer.split()) + context_words

    trace_metadata = {
        "dataset": dataset_name,
        "question_id": question["id"],
        "question_type": question.get("type"),
        "expected_sources": expected_sources,
        "expected_failure": expected_failure or "no_failure_detected",
        "strategy": strategy,
    }
    with contexttrace.trace(query=query, metadata=trace_metadata) as trace:
        trace.log_retrieval(retrieved, metadata={"strategy": strategy})
        trace.log_context(selected)
        trace.log_answer(
            answer,
            model="contexttrace-demo-rag",
            usage={"total_tokens": token_count},
            metadata={"latency_ms": latency_ms, "strategy": strategy},
        )
        if citations:
            trace.log_citations(citations)
        trace.evaluate()
    return str(trace.trace_id)
214
+
215
+
216
+ def _document_chunks(documents: dict[str, str]) -> list[dict[str, Any]]:
217
+ chunks: list[dict[str, Any]] = []
218
+ for source, text in documents.items():
219
+ sections = [section.strip() for section in re.split(r"\n##\s+", text) if section.strip()]
220
+ for index, section in enumerate(sections):
221
+ content = re.sub(r"^#\s+", "", section).strip()
222
+ if not content:
223
+ continue
224
+ if len(content.split()) <= 4 and Path(source).stem.replace("_", " ").lower() in content.lower():
225
+ continue
226
+ chunks.append(
227
+ {
228
+ "chunk_id": "%s_%s" % (Path(source).stem, index + 1),
229
+ "content": content,
230
+ "source": source,
231
+ "metadata": {
232
+ "section": content.splitlines()[0][:80],
233
+ "stance": _stance(source, content),
234
+ },
235
+ "relevance_score": 0.0,
236
+ }
237
+ )
238
+ return chunks
239
+
240
+
241
def _retrieve(query: str, chunks: list[dict[str, Any]], *, top_k: int) -> list[dict[str, Any]]:
    """Rank chunks by term overlap with the query and return the best ``top_k``.

    The score of a chunk is the share of query terms that also occur in the
    chunk content, rounded to three decimals. Ties are ordered by ``source``.
    The input chunks are not modified; scored copies are returned.
    """
    wanted = _terms(query)
    denominator = max(len(wanted), 1)  # avoids division by zero for empty queries
    ranked = []
    for chunk in chunks:
        overlap = wanted & _terms(chunk["content"])
        candidate = dict(chunk)
        candidate["relevance_score"] = round(len(overlap) / denominator, 3)
        ranked.append(candidate)
    ranked.sort(key=lambda item: (-float(item["relevance_score"]), item["source"]))
    return ranked[:top_k]
248
+
249
+
250
+ def _force_expected_sources(
251
+ chunks: list[dict[str, Any]],
252
+ expected_sources: list[str],
253
+ retrieved: list[dict[str, Any]],
254
+ ) -> list[dict[str, Any]]:
255
+ forced = [chunk for chunk in chunks if chunk["source"] in expected_sources]
256
+ seen = {chunk["chunk_id"] for chunk in forced}
257
+ return forced + [chunk for chunk in retrieved if chunk["chunk_id"] not in seen]
258
+
259
+
260
+ def _demo_answer(question: dict[str, Any], expected_answer: str) -> str:
261
+ failure = question.get("expected_failure")
262
+ if failure == "should_have_abstained":
263
+ return "Yes, the documents say this exception is available to eligible employees or customers."
264
+ if failure == "unsupported_answer":
265
+ return "Yes, the policy gives a special exception that is not limited by the documented rules."
266
+ return expected_answer
267
+
268
+
269
+ def _demo_citations(
270
+ question: dict[str, Any],
271
+ answer: str,
272
+ selected: list[dict[str, Any]],
273
+ chunks: list[dict[str, Any]],
274
+ expected_sources: list[str],
275
+ ) -> list[dict[str, Any]]:
276
+ if not answer or not selected:
277
+ return []
278
+ failure = question.get("expected_failure")
279
+ claim = answer.split(".")[0].strip() + "."
280
+ if failure == "citation_mismatch":
281
+ wrong = next((chunk for chunk in selected if chunk["source"] not in expected_sources), None)
282
+ if wrong is None:
283
+ wrong = next((chunk for chunk in chunks if chunk["source"] not in expected_sources), selected[0])
284
+ return [{"claim": claim, "source_chunk_id": wrong["chunk_id"]}]
285
+ return [{"claim": claim, "source_chunk_id": selected[0]["chunk_id"]}]
286
+
287
+
288
+ def _stance(source: str, content: str) -> str:
289
+ lowered = (source + " " + content).lower()
290
+ if "archived" in lowered or "old_" in lowered or "legacy" in lowered:
291
+ return "archived"
292
+ if "current" in lowered or "policy" in lowered:
293
+ return "current"
294
+ return "neutral"
295
+
296
+
297
+ def _terms(text: str) -> set[str]:
298
+ return {
299
+ token.strip(".,:;!?()[]{}\"'").lower().rstrip("s")
300
+ for token in text.split()
301
+ if len(token.strip(".,:;!?()[]{}\"'")) > 2
302
+ }
303
+
304
+
305
+ def _avg(values: list[float]) -> float:
306
+ return round(sum(values) / len(values), 3) if values else 0.0
307
+
308
+
309
+ def _top_failures(failures: list[str]) -> list[str]:
310
+ counts = {failure: failures.count(failure) for failure in set(failures)}
311
+ return [name for name, _ in sorted(counts.items(), key=lambda item: (-item[1], item[0]))[:5]]