contexttrace 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- contexttrace/__init__.py +36 -0
- contexttrace/_version.py +1 -0
- contexttrace/cli.py +474 -0
- contexttrace/client.py +1074 -0
- contexttrace/config.py +246 -0
- contexttrace/demo.py +311 -0
- contexttrace/demo_data.py +257 -0
- contexttrace/endpoint_eval.py +314 -0
- contexttrace/errors.py +14 -0
- contexttrace/evaluator.py +448 -0
- contexttrace/integrations/__init__.py +14 -0
- contexttrace/integrations/fastapi.py +311 -0
- contexttrace/integrations/langchain.py +440 -0
- contexttrace/integrations/langgraph.py +197 -0
- contexttrace/integrations/llamaindex.py +422 -0
- contexttrace/integrations/opentelemetry.py +111 -0
- contexttrace/local.py +325 -0
- contexttrace/py.typed +1 -0
- contexttrace/regression.py +123 -0
- contexttrace/reliability.py +284 -0
- contexttrace/report.py +550 -0
- contexttrace/storage/__init__.py +3 -0
- contexttrace/storage/sqlite_store.py +604 -0
- contexttrace/thresholds.py +50 -0
- contexttrace/transport.py +183 -0
- contexttrace/viewer.py +148 -0
- contexttrace-0.1.0.dist-info/METADATA +154 -0
- contexttrace-0.1.0.dist-info/RECORD +31 -0
- contexttrace-0.1.0.dist-info/WHEEL +5 -0
- contexttrace-0.1.0.dist-info/entry_points.txt +2 -0
- contexttrace-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
from __future__ import annotations

import inspect
import json
import logging
import time
from typing import Any, Awaitable, Callable, Dict, Optional

from contexttrace.client import ContextTrace
|
|
9
|
+
|
|
10
|
+
# Callable shapes used by the middleware. These aliases are evaluated at import time, so they
# are spelled with ``typing`` generics.

# A downstream ASGI application, invoked as ``app(scope, receive, send)``.
ASGIApp = Callable[..., Awaitable[None]]

# A request/response extractor: receives info dict(s) and returns a dict of trace fields.
Extractor = Callable[..., Dict[str, Any]]

# Predicate over the request info dict; a falsy result skips tracing for that request.
ShouldTrace = Callable[[Dict[str, Any]], bool]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ContextTraceFastAPIMiddleware:
    """ASGI middleware for tracing RAG-style FastAPI endpoints.

    The middleware buffers JSON request and response bodies, extracts RAG fields, and logs a
    ContextTrace trace after the endpoint completes. Custom extractors can return:
    query, metadata, retrieved_chunks, selected_context, answer, citations, model, and usage.

    Behavior notes:
        * Scopes whose type is not ``"http"`` are passed straight through untraced.
        * The whole response is captured and only forwarded to the client after the trace
          has been logged, so streamed responses reach the client in one burst.
        * Failures while logging a trace are reported through the ``logging`` module and
          swallowed, unless ``raise_logging_errors`` is true.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        client: Optional[ContextTrace] = None,
        api_key: Optional[str] = None,
        project: str = "default",
        base_url: str = "http://localhost:8000",
        mode: Optional[str] = None,
        request_extractor: Optional[Extractor] = None,
        response_extractor: Optional[Extractor] = None,
        should_trace: Optional[ShouldTrace] = None,
        trace_metadata: Optional[dict[str, Any]] = None,
        raise_logging_errors: bool = False,
    ) -> None:
        """Wrap ``app`` with ContextTrace tracing.

        Args:
            app: The downstream ASGI application.
            client: An existing ``ContextTrace`` client. When omitted, one is built from
                ``api_key``, ``project``, ``base_url`` and ``mode``.
            request_extractor: Sync or async callable mapping request info to trace fields.
                Defaults to ``default_request_extractor``.
            response_extractor: Sync or async callable mapping response info (and optionally
                request info) to trace fields. Defaults to ``default_response_extractor``.
            should_trace: Optional predicate over request info. A falsy result passes the
                request through untraced.
            trace_metadata: Extra metadata merged into every trace.
            raise_logging_errors: Re-raise failures that occur while logging a trace, instead
                of only reporting them via ``logging``.
        """
        self.app = app
        self.client = client or ContextTrace(
            api_key=api_key,
            project=project,
            base_url=base_url,
            mode=mode,
        )
        self.request_extractor = request_extractor or default_request_extractor
        self.response_extractor = response_extractor or default_response_extractor
        self.should_trace = should_trace
        self.trace_metadata = trace_metadata or {}
        self.raise_logging_errors = raise_logging_errors

    async def __call__(
        self,
        scope: dict[str, Any],
        receive: Callable[[], Awaitable[dict[str, Any]]],
        send: Callable[[dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Handle one ASGI connection, tracing HTTP requests accepted by ``should_trace``."""
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return

        request_body, request_messages = await _read_request_body(receive)
        request_info = _request_info(scope, request_body)
        # Replay the buffered body, then hand over to the server's real ``receive``. The app
        # can then still await ``http.disconnect`` instead of spinning on fabricated,
        # never-blocking request messages.
        downstream_receive = self._replay_then_forward(request_messages, receive)

        if self.should_trace and not self.should_trace(request_info):
            await self.app(scope, downstream_receive, send)
            return

        start_time = time.perf_counter()
        response_messages: list[dict[str, Any]] = []

        async def capture_send(message: dict[str, Any]) -> None:
            # Hold every response message until the trace has been logged.
            response_messages.append(message)

        try:
            await self.app(scope, downstream_receive, capture_send)
        except BaseException as exc:
            await self._log_exception(request_info, exc, start_time)
            raise

        response_info = _response_info(response_messages, start_time)
        await self._log_trace(request_info, response_info)

        for message in response_messages:
            await send(message)

    @staticmethod
    def _replay_then_forward(
        messages: list[dict[str, Any]],
        receive: Callable[[], Awaitable[dict[str, Any]]],
    ) -> Callable[[], Awaitable[dict[str, Any]]]:
        """Return a ``receive`` that yields ``messages`` in order, then delegates to ``receive``.

        ``messages`` is copied, so later mutation of the caller's list has no effect.
        """
        pending = list(messages)

        async def replay() -> dict[str, Any]:
            if pending:
                return pending.pop(0)
            return await receive()

        return replay

    async def _log_exception(
        self,
        request_info: dict[str, Any],
        exc: BaseException,
        start_time: float,
    ) -> None:
        """Record an exception raised by the endpoint as an agent error on a new trace.

        Best effort: logging failures are reported via ``logging`` and swallowed, unless
        ``raise_logging_errors`` is set.
        """
        try:
            extracted = await _call_extractor(self.request_extractor, request_info)
            query = extracted.get("query") or request_info.get("path") or "unknown request"
            with self.client.trace(
                query=str(query),
                metadata={
                    **self.trace_metadata,
                    **(extracted.get("metadata") or {}),
                    "integration": "fastapi",
                    "http": _http_metadata(request_info),
                },
            ) as trace:
                trace.log_agent_error(
                    str(exc),
                    name="fastapi_endpoint_error",
                    metadata={"error_type": exc.__class__.__name__},
                    latency_ms=_elapsed_ms(start_time),
                )
        except Exception:
            if self.raise_logging_errors:
                raise
            logging.getLogger(__name__).warning(
                "ContextTrace failed to log a FastAPI endpoint error.", exc_info=True
            )

    async def _log_trace(
        self,
        request_info: dict[str, Any],
        response_info: dict[str, Any],
    ) -> None:
        """Build and log a trace from the extracted request and response fields.

        Response fields take precedence over request fields. The exception is a falsy
        response value (e.g. ``None``, ``[]``, ``""``), which never overwrites a value the
        request extractor supplied. Best effort: failures are reported via ``logging`` and
        swallowed, unless ``raise_logging_errors`` is set.
        """
        try:
            request_data = await _call_extractor(self.request_extractor, request_info)
            response_data = await _call_extractor(self.response_extractor, response_info, request_info)
            # The default extractors always emit keys such as ``retrieved_chunks`` and
            # ``citations`` (as ``[]`` when absent). A plain ``{**request, **response}`` merge
            # would therefore erase chunks/citations that only the request carried.
            payload = dict(request_data)
            for key, value in response_data.items():
                if value or key not in payload:
                    payload[key] = value

            query = payload.get("query") or request_info.get("path") or "unknown request"
            metadata = {
                **self.trace_metadata,
                **(request_data.get("metadata") or {}),
                **(response_data.get("metadata") or {}),
                "integration": "fastapi",
                "http": _http_metadata(request_info, response_info),
            }

            with self.client.trace(query=str(query), metadata=metadata) as trace:
                retrieved = payload.get("retrieved_chunks") or payload.get("chunks") or []
                selected = payload.get("selected_context") or payload.get("context") or []
                citations = payload.get("citations") or []
                answer = payload.get("answer")

                if retrieved:
                    trace.log_retrieval(
                        retrieved,
                        retriever_name=payload.get("retriever_name") or "fastapi_endpoint",
                        metadata={"source": "fastapi_response"},
                    )
                if selected:
                    trace.log_context(selected, metadata={"source": "fastapi_response"})
                if answer:
                    trace.log_answer(
                        str(answer),
                        model=payload.get("model"),
                        usage=payload.get("usage") or {},
                        metadata={"latency_ms": response_info.get("latency_ms")},
                    )
                if citations:
                    trace.log_citations(citations)
        except Exception:
            if self.raise_logging_errors:
                raise
            logging.getLogger(__name__).warning(
                "ContextTrace failed to log a FastAPI trace.", exc_info=True
            )
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def default_request_extractor(request: dict[str, Any]) -> dict[str, Any]:
    """Extract trace fields from a request info dict.

    When the decoded JSON body is not a dict, only the path (as the query) and empty metadata
    are returned. Otherwise the query is the first non-blank string among ``query``,
    ``question``, ``input`` and ``prompt`` (falling back to the path). Chunks, context and
    citations are taken from the body when present.
    """
    path = request.get("path")
    payload = request.get("json")
    if isinstance(payload, dict):
        raw_metadata = payload.get("metadata")
        return {
            "query": _first_string(payload, "query", "question", "input", "prompt") or path,
            "metadata": raw_metadata if isinstance(raw_metadata, dict) else {},
            "retrieved_chunks": payload.get("retrieved_chunks") or payload.get("chunks") or [],
            "selected_context": payload.get("selected_context") or payload.get("context") or [],
            "citations": _citations(payload.get("citations") or []),
        }
    return {"query": path, "metadata": {}}
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def default_response_extractor(response: dict[str, Any], request: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    """Extract trace fields from a response info dict.

    Returns ``{}`` unless the decoded JSON body is a dict. ``request`` is accepted for the
    two-argument extractor signature but is not used here.
    """
    payload = response.get("json")
    if not isinstance(payload, dict):
        return {}

    raw_metadata = payload.get("metadata")
    raw_usage = payload.get("usage")
    # Each field falls back through common alternative key names used by RAG endpoints.
    return {
        "answer": _first_string(payload, "answer", "response", "output", "result", "text"),
        "metadata": raw_metadata if isinstance(raw_metadata, dict) else {},
        "retrieved_chunks": (
            payload.get("retrieved_chunks") or payload.get("contexts") or payload.get("chunks") or []
        ),
        "selected_context": payload.get("selected_context") or payload.get("context") or [],
        "citations": _citations(payload.get("citations") or payload.get("sources") or []),
        "model": payload.get("model"),
        "usage": raw_usage if isinstance(raw_usage, dict) else {},
    }
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
async def _read_request_body(
    receive: Callable[[], Awaitable[dict[str, Any]]],
) -> tuple[bytes, list[dict[str, Any]]]:
    """Drain the request body from an ASGI ``receive`` callable.

    Reading stops after the first ``http.request`` message without a truthy ``more_body``, or
    at the first message of any other type (e.g. ``http.disconnect``).

    Returns:
        The concatenated body bytes, and every message received (so they can be replayed).
    """
    chunks: list[bytes] = []
    received: list[dict[str, Any]] = []
    expecting_more = True
    while expecting_more:
        message = await receive()
        received.append(message)
        if message.get("type") == "http.request":
            chunks.append(message.get("body", b""))
            expecting_more = bool(message.get("more_body", False))
        else:
            expecting_more = False
    return b"".join(chunks), received
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _replay_receive(messages: list[dict[str, Any]]) -> Callable[[], Awaitable[dict[str, Any]]]:
    """Build an ASGI ``receive`` callable that replays ``messages`` in order.

    ``messages`` is copied up front, so later changes to the caller's list are not seen. Once
    the copy is exhausted, every call returns a fresh, empty, final ``http.request`` message.

    NOTE(review): the returned callable never yields ``http.disconnect`` and never suspends
    once exhausted. Callers that loop on ``receive()`` waiting for a disconnect will not block.
    """
    backlog = list(messages)
    cursor = 0

    async def receive() -> dict[str, Any]:
        nonlocal cursor
        if cursor >= len(backlog):
            return {"type": "http.request", "body": b"", "more_body": False}
        message = backlog[cursor]
        cursor += 1
        return message

    return receive
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _request_info(scope: dict[str, Any], body: bytes) -> dict[str, Any]:
    """Assemble the request info dict handed to ``should_trace`` and request extractors.

    Header names and values are decoded as latin-1 and names are lower-cased; when a name
    repeats, the last value wins. ``json`` holds the decoded body, or ``None``.
    """
    decoded_headers: dict[str, str] = {}
    for raw_name, raw_value in scope.get("headers", []):
        decoded_headers[raw_name.decode("latin1").lower()] = raw_value.decode("latin1")

    info: dict[str, Any] = {}
    info["method"] = scope.get("method")
    info["path"] = scope.get("path")
    info["headers"] = decoded_headers
    info["body"] = body
    info["json"] = _decode_json(body)
    info["scope"] = scope
    return info
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _response_info(messages: list[dict[str, Any]], start_time: float) -> dict[str, Any]:
    """Summarize captured ASGI response messages into the response info dict.

    The last ``http.response.start`` message supplies the status code and headers (latin-1
    decoded, names lower-cased). All ``http.response.body`` chunks are concatenated, and a
    JSON decode of the result is attempted.
    """
    status_code = None
    headers: dict[str, str] = {}
    chunks: list[bytes] = []
    for message in messages:
        kind = message.get("type")
        if kind == "http.response.start":
            status_code = message.get("status")
            headers = {
                name.decode("latin1").lower(): value.decode("latin1")
                for name, value in message.get("headers", [])
            }
        elif kind == "http.response.body":
            chunks.append(message.get("body", b""))

    payload = b"".join(chunks)
    return {
        "status_code": status_code,
        "headers": headers,
        "body": payload,
        "json": _decode_json(payload),
        "latency_ms": _elapsed_ms(start_time),
    }
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
async def _call_extractor(extractor: Extractor, *args: Any) -> dict[str, Any]:
    """Invoke a sync or async extractor and check that it produced a dict.

    An extractor may accept fewer positional arguments than are offered. For example, response
    extractors are offered ``(response_info, request_info)``, but a one-parameter extractor
    receives only ``response_info``. The arity is read from the extractor's signature, so a
    ``TypeError`` raised *inside* the extractor propagates. It no longer triggers a second call
    with fewer arguments, which would repeat side effects and mask the real error.

    Raises:
        ValueError: if the extractor's (awaited) result is not a dict.
    """
    try:
        signature: Optional[inspect.Signature] = inspect.signature(extractor)
    except (TypeError, ValueError):
        signature = None  # some builtins / C callables expose no signature

    if signature is None:
        # Without a signature, keep the historical "all args, then first arg" fallback.
        try:
            value = extractor(*args)
        except TypeError:
            value = extractor(args[0])
    else:
        params = list(signature.parameters.values())
        call_args = args
        if not any(param.kind is inspect.Parameter.VAR_POSITIONAL for param in params):
            positional_count = sum(
                1
                for param in params
                if param.kind
                in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
            )
            # Always pass at least the primary info dict, as before.
            call_args = args[: max(1, positional_count)]
        value = extractor(*call_args)

    if inspect.isawaitable(value):
        value = await value
    if not isinstance(value, dict):
        raise ValueError("ContextTrace extractor must return a dictionary.")
    return value
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _decode_json(body: bytes) -> Any:
    """Best-effort JSON decode of a raw HTTP body.

    Returns ``None`` for an empty body or for anything that is not valid UTF-8 JSON. Bodies
    can come straight from clients, so every decode failure is swallowed:

    * ``ValueError`` covers ``UnicodeDecodeError``, ``json.JSONDecodeError``, and the plain
      ``ValueError`` raised for over-long integer literals on newer Pythons.
    * ``RecursionError`` covers pathologically nested documents.
    """
    if not body:
        return None
    try:
        return json.loads(body.decode("utf-8"))
    except (ValueError, RecursionError):
        return None
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _first_string(data: dict[str, Any], *keys: str) -> Optional[str]:
    """Return the value of the first key in ``keys`` that maps to a non-blank string.

    The string is returned unstripped. ``None`` is returned when no key qualifies.
    """
    candidates = (data.get(key) for key in keys)
    return next(
        (item for item in candidates if isinstance(item, str) and item.strip()),
        None,
    )
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _citations(value: Any) -> list[dict[str, Any]]:
    """Normalize raw citation entries into ``{"claim": str, "source_chunk_id": str}`` pairs.

    Non-list input yields ``[]``. Non-dict entries, and entries lacking a truthy claim or a
    source id, are dropped. The source id is the first present value among
    ``source_chunk_id``, ``chunk_id`` and ``source``. Integer ids (including ``0``) count as
    present; other values count only when truthy.
    """
    if not isinstance(value, list):
        return []

    def present(candidate: Any) -> bool:
        # ``0`` is a legitimate integer chunk id even though it is falsy.
        if isinstance(candidate, int) and not isinstance(candidate, bool):
            return True
        return bool(candidate)

    citations: list[dict[str, Any]] = []
    for citation in value:
        if not isinstance(citation, dict):
            continue
        claim = citation.get("claim")
        source_chunk_id = None
        for key in ("source_chunk_id", "chunk_id", "source"):
            candidate = citation.get(key)
            if present(candidate):
                source_chunk_id = candidate
                break
        if claim and source_chunk_id is not None:
            citations.append({"claim": str(claim), "source_chunk_id": str(source_chunk_id)})
    return citations
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _http_metadata(
    request: dict[str, Any],
    response: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the ``http`` metadata section of a trace.

    Always includes ``method`` and ``path``. When ``response`` is given, also includes
    ``status_code`` and ``latency_ms``.
    """
    summary: dict[str, Any] = {key: request.get(key) for key in ("method", "path")}
    if response is None:
        return summary
    summary["status_code"] = response.get("status_code")
    summary["latency_ms"] = response.get("latency_ms")
    return summary
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _elapsed_ms(start_time: float) -> int:
    """Return whole milliseconds since ``start_time`` (a ``time.perf_counter()`` reading), truncated."""
    elapsed_seconds = time.perf_counter() - start_time
    return int(elapsed_seconds * 1000)
|