contexttrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,311 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import json
5
+ import time
6
+ from typing import Any, Awaitable, Callable, Dict, Optional
7
+
8
+ from contexttrace.client import ContextTrace
9
+
10
# Predicate over the request-info dict; a falsy result means "do not trace this request".
ShouldTrace = Callable[[Dict[str, Any]], bool]
# Hook that turns request/response info dicts into trace fields. It may be sync or async
# and must produce a dict.
Extractor = Callable[..., Dict[str, Any]]
# Downstream ASGI application, invoked as ``await app(scope, receive, send)``.
ASGIApp = Callable[..., Awaitable[None]]
13
+
14
+
15
class ContextTraceFastAPIMiddleware:
    """ASGI middleware for tracing RAG-style FastAPI endpoints.

    The middleware buffers JSON request and response bodies, extracts RAG fields, and logs a
    ContextTrace trace after the endpoint completes. Custom extractors can return:
    query, metadata, retrieved_chunks, selected_context, answer, citations, model, and usage.

    Only ``http`` scopes are traced; every other scope type is passed straight through.
    The full response is buffered and forwarded to the client only after the trace has
    been logged, so a streamed response reaches the client in one burst at the end.
    Failures while extracting or logging are swallowed unless ``raise_logging_errors``
    is true.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        client: Optional[ContextTrace] = None,
        api_key: Optional[str] = None,
        project: str = "default",
        base_url: str = "http://localhost:8000",
        mode: Optional[str] = None,
        request_extractor: Optional[Extractor] = None,
        response_extractor: Optional[Extractor] = None,
        should_trace: Optional[ShouldTrace] = None,
        trace_metadata: Optional[dict[str, Any]] = None,
        raise_logging_errors: bool = False,
    ) -> None:
        """Wrap *app* with tracing.

        Args:
            app: Downstream ASGI application.
            client: Pre-built ContextTrace client. When omitted, one is created from
                ``api_key``, ``project``, ``base_url`` and ``mode``.
            request_extractor: Called with the request-info dict; defaults to
                ``default_request_extractor``.
            response_extractor: Called with the response-info dict followed by the
                request-info dict (see ``_call_extractor`` for the one-argument
                fallback); defaults to ``default_response_extractor``.
            should_trace: Optional predicate on the request-info dict. A falsy result
                forwards the request to the app without tracing.
            trace_metadata: Extra metadata merged into every trace's metadata.
            raise_logging_errors: Re-raise exceptions from extraction/logging instead of
                swallowing them.
        """
        self.app = app
        self.client = client or ContextTrace(
            api_key=api_key,
            project=project,
            base_url=base_url,
            mode=mode,
        )
        self.request_extractor = request_extractor or default_request_extractor
        self.response_extractor = response_extractor or default_response_extractor
        self.should_trace = should_trace
        self.trace_metadata = trace_metadata or {}
        self.raise_logging_errors = raise_logging_errors

    async def __call__(
        self,
        scope: dict[str, Any],
        receive: Callable[[], Awaitable[dict[str, Any]]],
        send: Callable[[dict[str, Any]], Awaitable[None]],
    ) -> None:
        """Handle one ASGI connection, tracing it when it is an HTTP request.

        Exceptions raised by the app are logged as a trace (best-effort) and re-raised.
        """
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return

        # The body is read up front so extractors / should_trace can inspect it. The
        # consumed messages are then replayed to the app, after which further receive()
        # calls go to the real server channel.
        request_body, request_messages = await _read_request_body(receive)
        request_info = _request_info(scope, request_body)
        downstream_receive = self._replay_then_forward(request_messages, receive)
        if self.should_trace and not self.should_trace(request_info):
            await self.app(scope, downstream_receive, send)
            return

        start_time = time.perf_counter()
        response_messages: list[dict[str, Any]] = []

        async def capture_send(message: dict[str, Any]) -> None:
            # Buffer instead of sending so the response body is available for extraction.
            response_messages.append(message)

        try:
            await self.app(scope, downstream_receive, capture_send)
        except BaseException as exc:
            await self._log_exception(request_info, exc, start_time)
            raise

        response_info = _response_info(response_messages, start_time)
        await self._log_trace(request_info, response_info)

        for message in response_messages:
            await send(message)

    @staticmethod
    def _replay_then_forward(
        messages: list[dict[str, Any]],
        receive: Callable[[], Awaitable[dict[str, Any]]],
    ) -> Callable[[], Awaitable[dict[str, Any]]]:
        """Return a receive callable that replays *messages*, then defers to *receive*.

        Once the buffered request messages are used up, calls are forwarded to the
        server's real ``receive``. Code waiting for ``http.disconnect`` (e.g. a disconnect
        listener) then waits properly instead of spinning on synthetic empty
        ``http.request`` messages that never yield to the event loop.
        """
        pending = list(messages)

        async def replay_receive() -> dict[str, Any]:
            if pending:
                return pending.pop(0)
            return await receive()

        return replay_receive

    async def _log_exception(
        self,
        request_info: dict[str, Any],
        exc: BaseException,
        start_time: float,
    ) -> None:
        """Best-effort: record an endpoint failure as a trace holding one agent-error entry.

        Errors raised while logging are swallowed unless ``raise_logging_errors`` is set.
        """
        try:
            extracted = await _call_extractor(self.request_extractor, request_info)
            query = extracted.get("query") or request_info.get("path") or "unknown request"
            with self.client.trace(
                query=str(query),
                metadata={
                    **self.trace_metadata,
                    **(extracted.get("metadata") or {}),
                    "integration": "fastapi",
                    "http": _http_metadata(request_info),
                },
            ) as trace:
                trace.log_agent_error(
                    str(exc),
                    name="fastapi_endpoint_error",
                    metadata={"error_type": exc.__class__.__name__},
                    latency_ms=_elapsed_ms(start_time),
                )
        except Exception:
            if self.raise_logging_errors:
                raise

    async def _log_trace(
        self,
        request_info: dict[str, Any],
        response_info: dict[str, Any],
    ) -> None:
        """Best-effort: log a trace built from the request and response extractors.

        Retrieval, selected context, answer and citations are each logged only when
        non-empty. Errors are swallowed unless ``raise_logging_errors`` is set.
        """
        try:
            request_data = await _call_extractor(self.request_extractor, request_info)
            response_data = await _call_extractor(self.response_extractor, response_info, request_info)
            # Response fields win, but only when they carry a value. The default response
            # extractor always returns empty placeholders ([] / {} / None), which would
            # otherwise wipe out chunks, context or citations taken from the request body.
            payload = {
                **request_data,
                **{key: value for key, value in response_data.items() if value},
            }
            query = payload.get("query") or request_info.get("path") or "unknown request"
            metadata = {
                **self.trace_metadata,
                **(request_data.get("metadata") or {}),
                **(response_data.get("metadata") or {}),
                "integration": "fastapi",
                "http": _http_metadata(request_info, response_info),
            }

            with self.client.trace(query=str(query), metadata=metadata) as trace:
                retrieved = payload.get("retrieved_chunks") or payload.get("chunks") or []
                selected = payload.get("selected_context") or payload.get("context") or []
                citations = payload.get("citations") or []
                answer = payload.get("answer")

                if retrieved:
                    trace.log_retrieval(
                        retrieved,
                        retriever_name=payload.get("retriever_name") or "fastapi_endpoint",
                        metadata={"source": "fastapi_response"},
                    )
                if selected:
                    trace.log_context(selected, metadata={"source": "fastapi_response"})
                if answer:
                    trace.log_answer(
                        str(answer),
                        model=payload.get("model"),
                        usage=payload.get("usage") or {},
                        metadata={"latency_ms": response_info.get("latency_ms")},
                    )
                if citations:
                    trace.log_citations(citations)
        except Exception:
            if self.raise_logging_errors:
                raise
157
+
158
+
159
def default_request_extractor(request: dict[str, Any]) -> dict[str, Any]:
    """Pull RAG fields out of a JSON-object request body.

    When the decoded body (``request["json"]``) is not a dict, only the request path is
    returned as the query, with empty metadata. Otherwise:

    * query: first non-blank string among query/question/input/prompt, else the path
    * metadata: the body's ``metadata`` if it is a dict, else ``{}``
    * retrieved_chunks: retrieved_chunks, then chunks (first truthy), else ``[]``
    * selected_context: selected_context, then context (first truthy), else ``[]``
    * citations: body ``citations`` normalised through ``_citations``
    """
    payload = request.get("json")
    if isinstance(payload, dict):
        raw_metadata = payload.get("metadata")
        question = _first_string(payload, "query", "question", "input", "prompt")
        return {
            "query": question or request.get("path"),
            "metadata": raw_metadata if isinstance(raw_metadata, dict) else {},
            "retrieved_chunks": payload.get("retrieved_chunks") or payload.get("chunks") or [],
            "selected_context": payload.get("selected_context") or payload.get("context") or [],
            "citations": _citations(payload.get("citations") or []),
        }
    return {"query": request.get("path"), "metadata": {}}
171
+
172
+
173
def default_response_extractor(response: dict[str, Any], request: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    """Pull RAG fields out of a JSON-object response body.

    ``request`` is accepted so the function fits the two-argument extractor call, but it
    is not used. Returns ``{}`` when the decoded body (``response["json"]``) is not a dict.

    * answer: first non-blank string among answer/response/output/result/text
    * retrieved_chunks: retrieved_chunks, contexts, chunks (first truthy), else ``[]``
    * selected_context: selected_context, context (first truthy), else ``[]``
    * citations: citations, else sources, normalised through ``_citations``
    * metadata / usage: kept only when they are dicts, else ``{}``
    * model: copied as-is
    """
    payload = response.get("json")
    if not isinstance(payload, dict):
        return {}
    raw_metadata = payload.get("metadata")
    raw_usage = payload.get("usage")
    chunks = payload.get("retrieved_chunks") or payload.get("contexts") or payload.get("chunks") or []
    context = payload.get("selected_context") or payload.get("context") or []
    return {
        "answer": _first_string(payload, "answer", "response", "output", "result", "text"),
        "metadata": raw_metadata if isinstance(raw_metadata, dict) else {},
        "retrieved_chunks": chunks,
        "selected_context": context,
        "citations": _citations(payload.get("citations") or payload.get("sources") or []),
        "model": payload.get("model"),
        "usage": raw_usage if isinstance(raw_usage, dict) else {},
    }
186
+
187
+
188
+ async def _read_request_body(
189
+ receive: Callable[[], Awaitable[dict[str, Any]]],
190
+ ) -> tuple[bytes, list[dict[str, Any]]]:
191
+ body_parts: list[bytes] = []
192
+ messages: list[dict[str, Any]] = []
193
+ while True:
194
+ message = await receive()
195
+ messages.append(message)
196
+ if message.get("type") != "http.request":
197
+ break
198
+ body_parts.append(message.get("body", b""))
199
+ if not message.get("more_body", False):
200
+ break
201
+ return b"".join(body_parts), messages
202
+
203
+
204
+ def _replay_receive(messages: list[dict[str, Any]]) -> Callable[[], Awaitable[dict[str, Any]]]:
205
+ pending = list(messages)
206
+
207
+ async def receive() -> dict[str, Any]:
208
+ if pending:
209
+ return pending.pop(0)
210
+ return {"type": "http.request", "body": b"", "more_body": False}
211
+
212
+ return receive
213
+
214
+
215
def _request_info(scope: dict[str, Any], body: bytes) -> dict[str, Any]:
    """Assemble the request description handed to extractors and ``should_trace``.

    Header names and values are latin-1 decoded and names are lower-cased; when a header
    name repeats, the last occurrence wins. ``json`` holds ``_decode_json(body)``, and the
    raw ASGI ``scope`` is included as-is.
    """
    header_map: dict[str, str] = {}
    for raw_name, raw_value in scope.get("headers", []):
        header_map[raw_name.decode("latin1").lower()] = raw_value.decode("latin1")
    return dict(
        method=scope.get("method"),
        path=scope.get("path"),
        headers=header_map,
        body=body,
        json=_decode_json(body),
        scope=scope,
    )
228
+
229
+
230
def _response_info(messages: list[dict[str, Any]], start_time: float) -> dict[str, Any]:
    """Summarise captured ASGI response messages for the response extractor.

    Status code and headers come from the ``http.response.start`` message (the last one
    wins if several appear). Body bytes from every ``http.response.body`` message are
    concatenated in order. Header names are lower-cased; names and values are latin-1
    decoded. ``latency_ms`` is measured from *start_time* when this function runs, after
    the body has been JSON-decoded.
    """
    status_code = None
    headers: dict[str, str] = {}
    chunks: list[bytes] = []
    for message in messages:
        kind = message.get("type")
        if kind == "http.response.start":
            status_code = message.get("status")
            headers = {}
            for raw_name, raw_value in message.get("headers", []):
                headers[raw_name.decode("latin1").lower()] = raw_value.decode("latin1")
        elif kind == "http.response.body":
            chunks.append(message.get("body", b""))
    payload = b"".join(chunks)
    return {
        "status_code": status_code,
        "headers": headers,
        "body": payload,
        "json": _decode_json(payload),
        "latency_ms": _elapsed_ms(start_time),
    }
251
+
252
+
253
+ async def _call_extractor(extractor: Extractor, *args: Any) -> dict[str, Any]:
254
+ try:
255
+ value = extractor(*args)
256
+ except TypeError:
257
+ value = extractor(args[0])
258
+ if inspect.isawaitable(value):
259
+ value = await value
260
+ if not isinstance(value, dict):
261
+ raise ValueError("ContextTrace extractor must return a dictionary.")
262
+ return value
263
+
264
+
265
+ def _decode_json(body: bytes) -> Any:
266
+ if not body:
267
+ return None
268
+ try:
269
+ return json.loads(body.decode("utf-8"))
270
+ except (UnicodeDecodeError, json.JSONDecodeError):
271
+ return None
272
+
273
+
274
+ def _first_string(data: dict[str, Any], *keys: str) -> Optional[str]:
275
+ for key in keys:
276
+ value = data.get(key)
277
+ if isinstance(value, str) and value.strip():
278
+ return value
279
+ return None
280
+
281
+
282
+ def _citations(value: Any) -> list[dict[str, Any]]:
283
+ if not isinstance(value, list):
284
+ return []
285
+ citations = []
286
+ for citation in value:
287
+ if not isinstance(citation, dict):
288
+ continue
289
+ claim = citation.get("claim")
290
+ source_chunk_id = citation.get("source_chunk_id") or citation.get("chunk_id") or citation.get("source")
291
+ if claim and source_chunk_id:
292
+ citations.append({"claim": str(claim), "source_chunk_id": str(source_chunk_id)})
293
+ return citations
294
+
295
+
296
+ def _http_metadata(
297
+ request: dict[str, Any],
298
+ response: Optional[dict[str, Any]] = None,
299
+ ) -> dict[str, Any]:
300
+ metadata = {
301
+ "method": request.get("method"),
302
+ "path": request.get("path"),
303
+ }
304
+ if response is not None:
305
+ metadata["status_code"] = response.get("status_code")
306
+ metadata["latency_ms"] = response.get("latency_ms")
307
+ return metadata
308
+
309
+
310
+ def _elapsed_ms(start_time: float) -> int:
311
+ return int((time.perf_counter() - start_time) * 1000)