contexttrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,422 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from typing import Any, Callable, Dict, Iterable, Optional
5
+
6
+ from contexttrace.client import ContextTrace
7
+
8
+ try:
9
+ from llama_index.core.callbacks.base_handler import BaseCallbackHandler
10
+ except Exception: # pragma: no cover - exercised when llama-index-core is not installed
11
+ BaseCallbackHandler = object # type: ignore[assignment]
12
+
13
+
14
+ QueryExtractor = Callable[[Any], Optional[str]]
15
+ ResponseExtractor = Callable[[Any], Optional[str]]
16
+ NodeConverter = Callable[[Any, int], Dict[str, Any]]
17
+
18
+
19
class ContextTraceLlamaIndexCallbackHandler(BaseCallbackHandler):  # type: ignore[misc]
    """LlamaIndex callback handler that records RAG runs as ContextTrace traces.

    A ContextTrace trace is opened lazily by the first query, retrieval or
    response event (or by a manual ``trace_*`` call). It is closed when
    LlamaIndex ends its top-level trace (``end_trace``) or when
    :meth:`finish_trace` is called. Per-query state is reset at that point, so
    one handler instance can be attached to a callback manager and reused for
    many queries.
    """

    def __init__(
        self,
        *,
        project: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: str = "http://localhost:8000",
        client: Optional[ContextTrace] = None,
        trace_metadata: Optional[dict[str, Any]] = None,
        selected_context_limit: Optional[int] = None,
        query_extractor: Optional[QueryExtractor] = None,
        response_extractor: Optional[ResponseExtractor] = None,
        node_converter: Optional[NodeConverter] = None,
    ) -> None:
        """Create the handler.

        Args:
            project: ContextTrace project used when ``client`` is not given
                (``"default"`` when omitted).
            api_key: when set and ``client`` is not given, the client is built
                with ``mode="hosted"`` against ``base_url``.
            base_url: ContextTrace server URL; only used together with ``api_key``.
            client: pre-built ContextTrace client; takes precedence over the
                three options above.
            trace_metadata: extra metadata copied into every trace.
            selected_context_limit: if truthy, at most this many chunks are
                logged as selected context.
            query_extractor: callable returning the query text from an event
                payload (or response), or None.
            response_extractor: callable returning the answer text from a
                response, or None.
            node_converter: callable turning ``(node, index)`` into a chunk dict.
        """
        # The real BaseCallbackHandler requires the ignore lists; the ``object``
        # fallback used when llama-index-core is missing accepts no arguments.
        try:
            super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        except TypeError:
            try:
                super().__init__()
            except TypeError:
                pass

        # Set explicitly so the attributes exist even with the ``object`` fallback base.
        self.event_starts_to_ignore = []
        self.event_ends_to_ignore = []

        if client is None:
            client_kwargs: dict[str, Any] = {"project": project or "default"}
            if api_key:
                client_kwargs.update({"api_key": api_key, "base_url": base_url, "mode": "hosted"})
            client = ContextTrace(**client_kwargs)

        self.client = client
        self.trace_metadata = trace_metadata or {}
        self.selected_context_limit = selected_context_limit
        self.query_extractor = query_extractor or _extract_query
        self.response_extractor = response_extractor or _extract_response_text
        self.node_converter = node_converter or llamaindex_node_to_chunk
        self._reset_state()

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Handle the LlamaIndex hook for the start of a top-level trace.

        Nothing happens here: the ContextTrace trace is opened lazily by the
        first event that carries a query, retrieval or response.
        """
        return None

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[dict[str, list[str]]] = None,
    ) -> None:
        """Handle the LlamaIndex hook for the end of a top-level trace.

        Closes the open ContextTrace trace and resets per-query state.
        """
        self.finish_trace()

    def finish_trace(self) -> None:
        """Close the open ContextTrace trace, if any, and reset per-query state.

        ``end_trace`` calls this automatically. Call it yourself after using the
        manual ``trace_*`` helpers. It is safe to call when no trace is open.
        """
        context = self._trace_context
        # Reset first so a failure while closing cannot leave a stale trace attached.
        self._reset_state()
        if context is not None:
            context.__exit__(None, None, None)

    def on_event_start(
        self,
        event_type: Any,
        payload: Optional[dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Open the trace on a query start and time retrievals.

        Returns ``event_id`` unchanged, as the callback protocol expects.
        """
        if _event_matches(event_type, "query"):
            query = self.query_extractor(payload)
            if query:
                self._ensure_trace(
                    query=query,
                    event="query_start",
                    metadata=_event_metadata(event_type, payload, event_id, parent_id, kwargs),
                )
        if _event_matches(event_type, "retrieve", "retriever"):
            self.retrieve_start_time = time.perf_counter()
            # Remember the retrieval query so a retrieve-only run (no QUERY event)
            # is not traced as "unknown query".
            if self.query is None:
                query = self.query_extractor(payload)
                if query:
                    self.query = query
        return event_id

    def on_event_end(
        self,
        event_type: Any,
        payload: Optional[dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Log retrieved nodes on retrieve end, and context/answer on query/synthesize/response end."""
        if _event_matches(event_type, "retrieve", "retriever"):
            self._handle_retrieval_end(event_type, payload, event_id, kwargs)
            return

        if _event_matches(event_type, "query", "synthesize", "response"):
            self._handle_response_end(event_type, payload, event_id, kwargs)

    def trace_query(self, query: str, *, metadata: Optional[dict[str, Any]] = None) -> None:
        """Manually open a trace for *query*. This is a no-op if a trace is already open."""
        self._ensure_trace(query=query, event="manual_query", metadata=metadata or {})

    def trace_retrieved_nodes(self, nodes: Iterable[Any], *, metadata: Optional[dict[str, Any]] = None) -> None:
        """Manually log *nodes* as the retrieval result of the current trace."""
        self._log_retrieved_nodes(nodes, metadata=metadata or {})

    def trace_response(self, response: Any, *, metadata: Optional[dict[str, Any]] = None) -> None:
        """Manually log the selected context and answer extracted from *response*."""
        self._log_response(response, metadata=metadata or {})

    def _handle_retrieval_end(
        self,
        event_type: Any,
        payload: Optional[dict[str, Any]],
        event_id: str,
        kwargs: dict[str, Any],
    ) -> None:
        """Pull nodes out of a retrieve-end payload and log them."""
        nodes = _extract_nodes(payload, keys=("nodes", "source_nodes", "documents"))
        metadata = _event_metadata(event_type, payload, event_id, None, kwargs)
        self._log_retrieved_nodes(nodes, metadata=metadata)

    def _handle_response_end(
        self,
        event_type: Any,
        payload: Optional[dict[str, Any]],
        event_id: str,
        kwargs: dict[str, Any],
    ) -> None:
        """Pull the response out of an end payload and log it.

        The whole payload is used as the response when no response-like key is present.
        """
        response = _payload_get(payload, "response", "output", "result")
        if response is None and payload:
            response = payload
        metadata = _event_metadata(event_type, payload, event_id, None, kwargs)
        self._log_response(response, metadata=metadata)

    def _log_retrieved_nodes(self, nodes: Iterable[Any], *, metadata: dict[str, Any]) -> None:
        """Convert *nodes* to chunks, remember them, and log a retrieval step.

        Opens a trace first if none is open. Nothing is logged when there are no chunks.
        """
        chunks = [self.node_converter(node, index) for index, node in enumerate(nodes or [])]
        self.retrieved_chunks = chunks

        if self.trace is None:
            self._ensure_trace(
                query=self.query or "unknown query",
                event="retrieve_end",
                metadata=metadata,
            )

        if self.trace is None or not chunks:
            return

        self.trace.log_retrieval(
            chunks,
            retriever_name="llamaindex_retriever",
            metadata={
                **metadata,
                "latency_ms": _elapsed_ms(self.retrieve_start_time),
            },
        )

    def _log_response(self, response: Any, *, metadata: dict[str, Any]) -> None:
        """Log the selected context and the answer for *response*.

        Selected context is the response's source nodes, falling back to the
        retrieved nodes, optionally truncated to ``selected_context_limit``.
        Context and answer are each logged at most once per trace. LlamaIndex
        ends both the SYNTHESIZE and the QUERY event with the same response,
        so this method is typically reached twice per query.
        """
        answer = self.response_extractor(response)
        source_nodes = _extract_source_nodes(response)
        source_chunks = [
            self.node_converter(node, index) for index, node in enumerate(source_nodes)
        ]
        self.source_chunks = source_chunks

        if self.trace is None:
            self._ensure_trace(
                query=self.query or self.query_extractor(response) or "unknown query",
                event="response_end",
                metadata=metadata,
            )

        if self.trace is None:
            return

        selected_chunks = source_chunks or self.retrieved_chunks
        if self.selected_context_limit:
            selected_chunks = selected_chunks[: self.selected_context_limit]
        if selected_chunks and not self.context_logged:
            self.trace.log_context(
                selected_chunks,
                metadata={
                    "source": "llamaindex_source_nodes" if source_chunks else "llamaindex_retrieved_nodes",
                    "source_node_count": len(source_chunks),
                    "retrieved_node_count": len(self.retrieved_chunks),
                },
            )
            self.context_logged = True

        if answer and not self.answer_logged:
            answer_metadata = dict(metadata)
            answer_metadata.update(
                {
                    "latency_ms": self._latency_ms(),
                    "source_node_count": len(source_chunks),
                    "retrieved_node_count": len(self.retrieved_chunks),
                }
            )
            self.trace.log_answer(
                answer,
                metadata=answer_metadata,
            )
            self.answer_logged = True

    def _ensure_trace(
        self,
        *,
        query: str,
        event: str,
        metadata: dict[str, Any],
    ) -> None:
        """Open a ContextTrace trace for *query* unless one is already open.

        The trace metadata is ``trace_metadata`` plus the integration name, the
        triggering event and the LlamaIndex event metadata. The context manager
        is kept so ``finish_trace`` can exit it.
        """
        if self.trace is not None:
            return

        self.query = query
        self.start_time = time.perf_counter()
        trace_metadata = dict(self.trace_metadata)
        trace_metadata.update(
            {
                "integration": "llamaindex",
                "start_event": event,
                "llamaindex": metadata,
            }
        )
        context = self.client.trace(query=query, metadata=trace_metadata)
        self.trace = context.__enter__()
        # Only remembered once __enter__ succeeded, so we never exit an unentered context.
        self._trace_context = context

    def _reset_state(self) -> None:
        """Forget the open trace and everything collected for the current query."""
        self.trace: Any = None
        # Context manager returned by ``client.trace(...)``; exited in finish_trace().
        self._trace_context: Any = None
        self.query: Optional[str] = None
        self.start_time: Optional[float] = None
        self.retrieve_start_time: Optional[float] = None
        self.retrieved_chunks: list[dict[str, Any]] = []
        self.source_chunks: list[dict[str, Any]] = []
        self.answer_logged = False
        self.context_logged = False

    def _latency_ms(self) -> int:
        """Return whole milliseconds since the trace was opened, or 0 if no trace has started."""
        if self.start_time is None:
            return 0
        return int((time.perf_counter() - self.start_time) * 1000)
236
+
237
+
238
def llamaindex_node_to_chunk(node_or_node_with_score: Any, index: int = 0) -> dict[str, Any]:
    """Convert a LlamaIndex node (or a ``NodeWithScore``-style wrapper) into a chunk dict.

    Args:
        node_or_node_with_score: a node, or a wrapper exposing ``.node`` and ``.score``.
        index: position of the node in its result list; used for the fallback chunk id.

    Returns:
        Dict with ``chunk_id`` (str), ``content`` (str), ``source``, ``metadata``
        and ``relevance_score``.

    Raises:
        ValueError: if no text can be extracted from the node.
    """
    node = getattr(node_or_node_with_score, "node", node_or_node_with_score)
    score = getattr(node_or_node_with_score, "score", None)
    metadata = _node_metadata(node)
    content = _node_content(node)
    if content is None:
        raise ValueError("LlamaIndex node must include text, content, or get_content().")

    chunk_id = (
        metadata.get("chunk_id")
        or metadata.get("id")
        or metadata.get("doc_id")
        or getattr(node, "node_id", None)
        or getattr(node, "id_", None)
        or getattr(node, "id", None)
        or f"llamaindex_node_{index}"
    )
    source = metadata.get("source") or metadata.get("url") or metadata.get("path") or metadata.get("file_name")

    # Prefer the wrapper's score, then metadata "relevance_score", then "score".
    # Explicit None checks keep a legitimate score of 0 / 0.0 instead of
    # skipping past it as a falsy value.
    relevance_score = score
    if relevance_score is None:
        relevance_score = metadata.get("relevance_score")
    if relevance_score is None:
        relevance_score = metadata.get("score")

    return {
        "chunk_id": str(chunk_id),
        "content": str(content),
        "source": source,
        "metadata": metadata,
        "relevance_score": relevance_score,
    }
269
+
270
+
271
def _node_metadata(node: Any) -> dict[str, Any]:
    """Return a node's metadata as a fresh dict.

    Reads ``node.metadata``, falling back to ``node.extra_info``. For plain dict
    nodes, the ``"metadata"`` / ``"extra_info"`` keys take precedence. A
    non-dict value is wrapped as ``{"metadata": value}``, and a missing one
    yields ``{}``.

    A shallow copy is returned so the chunk built from it does not alias, and
    cannot mutate, the node's own metadata dict.
    """
    metadata = getattr(node, "metadata", None)
    if metadata is None:
        metadata = getattr(node, "extra_info", None)
    if isinstance(node, dict):
        metadata = node.get("metadata") or node.get("extra_info") or metadata
    if metadata is None:
        return {}
    if not isinstance(metadata, dict):
        return {"metadata": metadata}
    return dict(metadata)
280
+
281
+
282
def _node_content(node: Any) -> Optional[str]:
    """Extract a node's text, or return None if none can be found.

    The sources are tried in this order:
    1. ``node.get_content()``, when callable without arguments and non-None.
    2. The ``text``, ``content`` or ``page_content`` attributes.
    3. ``node.text_resource.text``.
    4. The ``text``, ``content`` or ``page_content`` keys of a dict node.

    Any non-None value found is converted with ``str``.
    """
    fetch = getattr(node, "get_content", None)
    if callable(fetch):
        try:
            raw = fetch()
            if raw is not None:
                return str(raw)
        except TypeError:
            # get_content() that demands arguments: fall back to the attributes below.
            pass

    text_fields = ("text", "content", "page_content")

    for field_name in text_fields:
        found = getattr(node, field_name, None)
        if found is not None:
            return str(found)

    resource_text = getattr(getattr(node, "text_resource", None), "text", None)
    if resource_text is not None:
        return str(resource_text)

    if not isinstance(node, dict):
        return None
    for field_name in text_fields:
        found = node.get(field_name)
        if found is not None:
            return str(found)
    return None
308
+
309
+
310
def _event_matches(event_type: Any, *names: str) -> bool:
    """Report whether the normalised name of *event_type* is one of *names* (already lowercase)."""
    return _event_name(event_type) in names
313
+
314
+
315
def _event_name(event_type: Any) -> str:
    """Normalise an event type (enum member, object or string) to a lowercase short name.

    A string ``value`` attribute wins, then a string ``name`` attribute.
    Otherwise the text after the last dot of ``str(event_type)`` is used, so
    ``"CBEventType.QUERY"`` becomes ``"query"``.
    """
    for label in (getattr(event_type, attr, None) for attr in ("value", "name")):
        if isinstance(label, str):
            return label.lower()
    _, _, tail = str(event_type).lower().rpartition(".")
    return tail
322
+
323
+
324
def _payload_get(payload: Any, *keys: str) -> Any:
    """Case-insensitive lookup of the first payload entry whose key matches any of *keys*.

    Each payload key is compared by its ``str()``, plus its ``name`` and
    ``value`` attributes when those are strings, so enum keys match by member
    name or value. Payload order decides which entry wins. Non-dict payloads
    and misses return None.
    """
    if not isinstance(payload, dict):
        return None
    wanted = {candidate.lower() for candidate in keys}

    def spellings(raw_key: Any) -> set[str]:
        names = {str(raw_key).lower()}
        for attr in ("name", "value"):
            label = getattr(raw_key, attr, None)
            if isinstance(label, str):
                names.add(label.lower())
        return names

    return next(
        (entry for raw_key, entry in payload.items() if not wanted.isdisjoint(spellings(raw_key))),
        None,
    )
340
+
341
+
342
def _extract_query(value: Any) -> Optional[str]:
    """Pull a non-blank query string out of an event payload or a query-like object.

    Accepted inputs:
      * a plain string;
      * a payload dict, looked up case-insensitively under ``query_str``,
        ``query``, ``question``, ``input`` or ``prompt`` (enum keys allowed);
      * an object, or a dict value, exposing a string ``query_str`` or
        ``query`` attribute (e.g. a LlamaIndex ``QueryBundle``).

    Whitespace-only strings are rejected in every branch. Returns None when no
    query is found.
    """
    if isinstance(value, dict):
        value = _payload_get(value, "query_str", "query", "question", "input", "prompt")
    if isinstance(value, str):
        return value if value.strip() else None
    for attr in ("query_str", "query"):
        text = getattr(value, attr, None)
        if isinstance(text, str) and text.strip():
            return text
    return None
353
+
354
+
355
def _extract_nodes(payload: Any, *, keys: tuple[str, ...]) -> Iterable[Any]:
    """Return the first payload entry matching *keys* (via ``_payload_get``), or ``[]`` when absent."""
    found = _payload_get(payload, *keys)
    return [] if found is None else found
360
+
361
+
362
def _extract_response_text(response: Any) -> Optional[str]:
    """Best-effort extraction of the answer text from a LlamaIndex response.

    Handles the following inputs:
      * plain strings;
      * dicts with a non-blank ``response``/``answer``/``output``/``result``/
        ``text`` string;
      * objects with a non-blank ``response``/``answer``/``output``/``text``/
        ``response_txt`` string attribute.

    ``str(response)`` is used only as a last resort, and never for streaming
    responses or ``Response``-like objects (those exposing ``response``).
    Default ``<object ...>`` reprs are ignored.

    Returns None when no answer text is available.
    """
    if response is None:
        return None
    if isinstance(response, str):
        return response
    if isinstance(response, dict):
        for key in ("response", "answer", "output", "result", "text"):
            value = response.get(key)
            if isinstance(value, str) and value.strip():
                return value
        # A dict's repr is not an answer.
        return None
    for attr in ("response", "answer", "output", "text", "response_txt"):
        value = getattr(response, attr, None)
        if isinstance(value, str) and value.strip():
            return value
    # Streaming responses build their text in __str__ by draining the token
    # generator, which would consume the stream before the caller reads it.
    if (
        getattr(response, "response_gen", None) is not None
        or getattr(response, "async_response_gen", None) is not None
    ):
        return None
    # Response-like objects with an empty ``response`` field stringify to a
    # placeholder (LlamaIndex's Response.__str__ yields "None"), not an answer.
    if hasattr(response, "response"):
        return None
    text = str(response)
    return text if text and not text.startswith("<") else None
378
+
379
+
380
def _extract_source_nodes(response: Any) -> list[Any]:
    """Return the source nodes attached to *response* as a list (empty if none).

    Dicts are read via their ``source_nodes`` key, falling back to ``nodes``
    when that is empty or missing. Other objects are read via their
    ``source_nodes`` attribute, falling back to ``source_nodes_with_scores``
    only when the attribute is missing or None.
    """
    if isinstance(response, dict):
        found = response.get("source_nodes") or response.get("nodes")
    else:
        found = getattr(response, "source_nodes", None)
        if found is None:
            found = getattr(response, "source_nodes_with_scores", None)
    return list(found) if found else []
388
+
389
+
390
def _elapsed_ms(start_time: Optional[float]) -> Optional[int]:
    """Return whole milliseconds since *start_time* (a ``time.perf_counter()`` reading), or None when unset."""
    if start_time is not None:
        elapsed_seconds = time.perf_counter() - start_time
        return int(elapsed_seconds * 1000)
    return None
394
+
395
+
396
def _event_metadata(
    event_type: Any,
    payload: Optional[dict[str, Any]],
    event_id: Optional[str],
    parent_id: Optional[str],
    kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Build a compact description of a LlamaIndex callback event.

    The result always contains ``event_type`` (normalised by ``_event_name``).
    It also contains, when applicable:
      * ``event_id`` / ``parent_id``, when truthy;
      * the callback's ``metadata`` kwarg, when it is a dict;
      * ``payload_keys``, when the payload is non-empty. Only key names are
        recorded, never payload values.
    """
    described: dict[str, Any] = {"event_type": _event_name(event_type)}
    for field_name, identifier in (("event_id", event_id), ("parent_id", parent_id)):
        if identifier:
            described[field_name] = identifier

    extra = kwargs.get("metadata")
    if isinstance(extra, dict):
        described["metadata"] = extra

    if payload:
        # Keys are reported by their ``name`` attribute first, then ``value``, then str().
        described["payload_keys"] = [
            str(getattr(key, "name", None) or getattr(key, "value", None) or str(key))
            for key in payload.keys()
        ]
    return described
@@ -0,0 +1,111 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any, Optional
5
+
6
+
7
class OpenTelemetryExporter:
    """Optional OpenTelemetry exporter for ContextTrace trace dictionaries.

    Export only happens when ``enabled`` resolves to true (explicit argument,
    else the ``CONTEXTTRACE_OTEL_ENABLED`` environment variable) and a tracer
    is available. The tracer is either passed in or loaded from
    ``opentelemetry`` when that package is importable.
    """

    def __init__(
        self,
        *,
        enabled: Optional[bool] = None,
        tracer: Any = None,
        tracer_name: str = "contexttrace",
    ) -> None:
        """Configure the exporter.

        Args:
            enabled: force export on/off; None defers to the environment variable.
            tracer: tracer to use; when None, one named *tracer_name* is loaded.
            tracer_name: name passed to ``opentelemetry.trace.get_tracer``.
        """
        self.enabled = _enabled(enabled)
        self.tracer = tracer if tracer is not None else self._load_tracer(tracer_name)

    def export_trace(self, trace: dict[str, Any]) -> list[str]:
        """Export *trace* as one ``contexttrace.trace`` span with events.

        The span carries trace id, project, query and failure-type attributes.
        An event is added per chunk, for the answer, per citation check and per
        agent event. Fields that are missing or are not dicts (e.g. an answer
        stored as a plain string) are treated as empty instead of raising.

        Returns:
            Names of the span and events emitted, in order. ``[]`` when export
            is disabled or no tracer is available.
        """
        if not self.enabled or self.tracer is None:
            return []

        as_dict = self._as_dict
        exported: list[str] = []
        with self.tracer.start_as_current_span("contexttrace.trace") as span:
            _set(span, "contexttrace.trace_id", trace.get("id") or trace.get("trace_id"))
            _set(span, "contexttrace.project", trace.get("project") or trace.get("project_id"))
            _set(span, "contexttrace.query", trace.get("query"))
            failure = as_dict(as_dict(trace.get("evaluation")).get("failure") or trace.get("failure"))
            _set(span, "contexttrace.failure_type", failure.get("failure_type") or failure.get("type"))
            exported.append("contexttrace.trace")

            for raw_chunk in trace.get("chunks") or []:
                chunk = as_dict(raw_chunk)
                _event(
                    span,
                    "contexttrace.chunk",
                    {
                        "chunk_id": chunk.get("chunk_id"),
                        "source": chunk.get("source"),
                        "selected": bool(chunk.get("selected")),
                    },
                )
                exported.append("contexttrace.chunk")

            answer = trace.get("answer")
            if answer:
                answer_fields = as_dict(answer)
                _event(
                    span,
                    "contexttrace.answer",
                    {
                        "model": answer_fields.get("model"),
                        "total_tokens": as_dict(answer_fields.get("usage")).get("total_tokens"),
                    },
                )
                exported.append("contexttrace.answer")

            for raw_check in trace.get("citation_checks") or []:
                check = as_dict(raw_check)
                _event(
                    span,
                    "contexttrace.citation_check",
                    {
                        "claim": check.get("claim"),
                        "source_chunk_id": check.get("source_chunk_id"),
                        "support_status": check.get("support_status") or check.get("verdict"),
                        "support_score": check.get("support_score"),
                    },
                )
                exported.append("contexttrace.citation_check")

            for raw_event in trace.get("agent_events") or []:
                agent_event = as_dict(raw_event)
                _event(
                    span,
                    "contexttrace.agent_event",
                    {
                        "event_type": agent_event.get("event_type"),
                        "name": agent_event.get("name"),
                        "latency_ms": agent_event.get("latency_ms"),
                        "error": agent_event.get("error_message"),
                    },
                )
                exported.append("contexttrace.agent_event")
        return exported

    @staticmethod
    def _as_dict(value: Any) -> dict[str, Any]:
        """Return *value* when it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    def _load_tracer(self, tracer_name: str) -> Any:
        """Load an OpenTelemetry tracer.

        Returns None when export is disabled or ``opentelemetry`` cannot be imported.
        """
        if not self.enabled:
            return None
        try:
            from opentelemetry import trace as otel_trace
        except Exception:
            # The dependency is optional; any import failure just disables export.
            return None
        return otel_trace.get_tracer(tracer_name)
92
+
93
+
94
def export_contexttrace_trace(trace: dict[str, Any], *, enabled: Optional[bool] = None, tracer: Any = None) -> list[str]:
    """Build a one-off ``OpenTelemetryExporter`` and export *trace* with it.

    Returns the names of the emitted span and events (empty when export is off).
    """
    exporter = OpenTelemetryExporter(enabled=enabled, tracer=tracer)
    return exporter.export_trace(trace)
96
+
97
+
98
def _enabled(value: Optional[bool]) -> bool:
    """Decide whether OpenTelemetry export is switched on.

    An explicit *value* wins and is coerced to ``bool``. Otherwise the
    ``CONTEXTTRACE_OTEL_ENABLED`` environment variable is read: "1", "true",
    "yes" or "on" enable export. The match is case-insensitive and ignores
    surrounding whitespace.
    """
    if value is not None:
        return bool(value)
    raw = os.getenv("CONTEXTTRACE_OTEL_ENABLED", "")
    return raw.strip().lower() in {"1", "true", "yes", "on"}
102
+
103
+
104
def _set(span: Any, key: str, value: Any) -> None:
    """Set span attribute *key* to *value*, skipping None (not a valid OpenTelemetry attribute value)."""
    if value is None:
        return
    span.set_attribute(key, value)
107
+
108
+
109
def _event(span: Any, name: str, attributes: dict[str, Any]) -> None:
    """Add event *name* to *span*, dropping attributes whose value is None."""
    present: dict[str, Any] = {}
    for attr_key, attr_value in attributes.items():
        if attr_value is not None:
            present[attr_key] = attr_value
    span.add_event(name, present)
111
+