contexttrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,440 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from collections.abc import Iterable as RuntimeIterable
5
+ from typing import Any, Callable, Dict, Iterable, Optional
6
+
7
+ from contexttrace.client import ContextTrace
8
+
9
+ try:
10
+ from langchain_core.callbacks import BaseCallbackHandler
11
+ except Exception: # pragma: no cover - exercised when langchain is not installed
12
+ BaseCallbackHandler = object # type: ignore[assignment]
13
+
14
+
15
# Signatures of the pluggable hooks accepted by ContextTraceCallbackHandler.

# Called with the chain inputs; returns the query text, or None when absent.
QueryExtractor = Callable[[Any], Optional[str]]

# Called with the chain outputs; returns the answer text, or None when absent.
AnswerExtractor = Callable[[Any], Optional[str]]

# Called with the chain outputs; yields citation dictionaries.
CitationExtractor = Callable[[Any], Iterable[Dict[str, Any]]]

# Called with (document, position in the retrieved list); returns a chunk dict.
DocumentConverter = Callable[[Any, int], Dict[str, Any]]

# Called with (chain outputs, base metadata); the returned dict is merged over
# the base metadata before the answer is logged.
MetadataExtractor = Callable[[Any, Dict[str, Any]], Dict[str, Any]]
20
+
21
+
22
class ContextTraceCallbackHandler(BaseCallbackHandler):  # type: ignore[misc]
    """LangChain callback handler that records a run as a ContextTrace trace.

    The first callback able to supply a query opens a trace through
    ``client.trace(...)``. Later callbacks log retrieved chunks, the selected
    context, tool calls/results, errors, and the final answer onto that trace.

    LangChain fires ``on_chain_start``/``on_chain_end`` for every nested
    runnable, not only for the outermost chain. The handler therefore remembers
    the ``run_id`` of the first parent-less chain it sees and lets only that
    run log the answer; otherwise the output of an inner step would be stored
    as the answer and block the real one.

    NOTE(review): the trace context manager is entered in ``_ensure_trace``
    but nothing in this class calls ``__exit__``. Confirm whether the client
    needs an explicit close/flush once the run is finished.
    """

    def __init__(
        self,
        *,
        project: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: str = "http://localhost:8000",
        client: Optional[ContextTrace] = None,
        trace_metadata: Optional[dict[str, Any]] = None,
        selected_context_limit: Optional[int] = None,
        query_extractor: Optional[QueryExtractor] = None,
        answer_extractor: Optional[AnswerExtractor] = None,
        citation_extractor: Optional[CitationExtractor] = None,
        document_converter: Optional[DocumentConverter] = None,
        metadata_extractor: Optional[MetadataExtractor] = None,
        log_agent_events: bool = True,
    ) -> None:
        """Configure the handler.

        Args:
            project: Project name for a client built here; "default" if unset.
            api_key: When given, the client is built in "hosted" mode with
                this key and ``base_url``.
            base_url: Server URL; only used together with ``api_key``.
            client: Pre-built client. When given, ``project``, ``api_key`` and
                ``base_url`` are ignored.
            trace_metadata: Extra metadata copied onto every trace.
            selected_context_limit: Keep only the first N retrieved chunks as
                selected context. ``None`` (or 0) keeps all of them.
            query_extractor: Override for pulling the query from chain inputs.
            answer_extractor: Override for pulling the answer from outputs.
            citation_extractor: Override for pulling citations from outputs.
            document_converter: Override for converting documents to chunks.
            metadata_extractor: Optional hook merged over the answer metadata.
            log_agent_events: When False, tool and chain-error events are not
                logged.
        """
        if client is None:
            kwargs: dict[str, Any] = {"project": project or "default"}
            if api_key:
                kwargs.update({"api_key": api_key, "base_url": base_url, "mode": "hosted"})
            client = ContextTrace(**kwargs)

        self.client = client
        self.trace_metadata = trace_metadata or {}
        self.selected_context_limit = selected_context_limit
        self.query_extractor = query_extractor or _extract_query
        self.answer_extractor = answer_extractor or _extract_answer
        self.citation_extractor = citation_extractor or _extract_citations
        self.document_converter = document_converter or langchain_document_to_chunk
        self.metadata_extractor = metadata_extractor
        self.log_agent_events = log_agent_events
        self.trace = None
        self.query: Optional[str] = None
        self.retrieved_chunks: list[dict[str, Any]] = []
        self.start_time: Optional[float] = None
        self.retriever_start_time: Optional[float] = None
        self.llm_model: Optional[str] = None
        self.llm_usage: dict[str, Any] = {}
        self.answer_logged = False
        self._tool_start_times: dict[str, float] = {}
        self._tool_names: dict[str, str] = {}
        # run_id of the outermost chain; None until a parent-less chain starts.
        self._root_run_id: Optional[str] = None
        # Name captured in on_retriever_start, reused in on_retriever_end.
        self._retriever_name: Optional[str] = None

    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: Any,
        **kwargs: Any,
    ) -> None:
        """Open the trace when the chain inputs contain a query."""
        self._remember_root_run(kwargs)
        query = self.query_extractor(inputs)
        if query:
            self._ensure_trace(
                query=query,
                event="chain_start",
                metadata=_event_metadata(serialized, kwargs),
            )

    def on_retriever_start(
        self,
        serialized: dict[str, Any],
        query: str,
        **kwargs: Any,
    ) -> None:
        """Start the retrieval timer and open the trace if needed."""
        self.retriever_start_time = time.perf_counter()
        # on_retriever_end is not handed `serialized`, so keep the name now.
        retriever_name = _serialized_name(serialized) or kwargs.get("name")
        self._retriever_name = str(retriever_name) if retriever_name else None
        self._ensure_trace(
            query=query,
            event="retriever_start",
            metadata=_event_metadata(serialized, kwargs),
        )

    def on_retriever_end(self, documents: Iterable[Any], **kwargs: Any) -> None:
        """Convert the retrieved documents and log retrieval plus context."""
        chunks = [self.document_converter(document, index) for index, document in enumerate(documents)]
        self.retrieved_chunks = chunks

        if self.trace is None:
            self._ensure_trace(
                query=self.query or "unknown query",
                event="retriever_end",
                metadata=_event_metadata(None, kwargs),
            )

        if not chunks or self.trace is None:
            return

        retriever_name = (
            _serialized_name(kwargs.get("serialized"))
            or self._retriever_name
            or "langchain_retriever"
        )
        self.trace.log_retrieval(
            chunks,
            retriever_name=retriever_name,
            metadata={
                **_event_metadata(None, kwargs),
                "latency_ms": _elapsed_ms(self.retriever_start_time),
            },
        )
        selected = chunks[: self.selected_context_limit] if self.selected_context_limit else chunks
        self.trace.log_context(
            selected,
            metadata={
                "source": "langchain_retriever_end",
                "selected_context_limit": self.selected_context_limit,
            },
        )

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
        """Remember the model name; open the trace from the prompt if needed."""
        model = _serialized_name(serialized)
        if model:
            self.llm_model = model
        if self.trace is None:
            query = self.query or (prompts[0] if prompts else "unknown query")
            self._ensure_trace(
                query=query,
                event="llm_start",
                metadata=_event_metadata(serialized, kwargs),
            )

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Capture token usage and the model name reported by the LLM."""
        self.llm_usage = _extract_token_usage(response)
        self.llm_model = _extract_model(response) or self.llm_model

    def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
        """Log the answer (and citations) once, for the outermost chain only."""
        if self._is_nested_run(kwargs):
            # Inner runnables also end with string/dict outputs; logging them
            # here would lock in an intermediate value as the answer.
            return

        answer = self.answer_extractor(outputs)
        if not answer:
            return

        if self.trace is None:
            self._ensure_trace(
                query=self.query or "unknown query",
                event="chain_end",
                metadata=_event_metadata(None, kwargs),
            )

        if self.trace is None or self.answer_logged:
            return

        self.trace.log_answer(
            answer,
            model=self.llm_model,
            usage=self.llm_usage,
            metadata=self._merge_metadata(
                outputs,
                {
                    "latency_ms": self._latency_ms(),
                    "langchain_output_keys": list(outputs.keys()) if isinstance(outputs, dict) else [],
                },
            ),
        )
        citations = list(self.citation_extractor(outputs))
        if citations:
            self.trace.log_citations(citations)
        self.answer_logged = True

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log the chain error; the outermost chain also logs a placeholder answer."""
        if self.trace is not None and self.log_agent_events:
            self.trace.log_agent_error(
                str(error),
                name="langchain_chain_error",
                metadata={
                    "error_type": error.__class__.__name__,
                    **_event_metadata(None, kwargs),
                },
                latency_ms=self._latency_ms(),
            )
        if self._is_nested_run(kwargs):
            # A failing inner step may still be recovered by the outer chain,
            # so it must not consume the single answer slot.
            return
        if self.trace is not None and not self.answer_logged:
            self.trace.log_answer(
                "LangChain run failed before producing an answer.",
                metadata={
                    "latency_ms": self._latency_ms(),
                    "error": str(error),
                    "error_type": error.__class__.__name__,
                },
            )
            self.answer_logged = True

    def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> None:
        """Log a tool call and start its latency timer."""
        if not self.log_agent_events:
            return
        if self.trace is None:
            self._ensure_trace(
                query=self.query or input_str or "unknown query",
                event="tool_start",
                metadata=_event_metadata(serialized, kwargs),
            )
        if self.trace is None:
            return
        # Same key derivation as on_tool_end/on_tool_error so the lookups match
        # even when LangChain supplies no run_id.
        run_key = self._tool_key(kwargs)
        tool_name = _serialized_name(serialized) or kwargs.get("name") or "langchain_tool"
        self._tool_start_times[run_key] = time.perf_counter()
        self._tool_names[run_key] = str(tool_name)
        self.trace.log_tool_call(
            str(tool_name),
            input_json={"input": input_str},
            metadata=_event_metadata(serialized, kwargs),
        )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Log a tool result together with its latency."""
        if not self.log_agent_events or self.trace is None:
            return
        run_key = self._tool_key(kwargs)
        # pop() so finished runs do not accumulate in the bookkeeping dicts.
        started_at = self._tool_start_times.pop(run_key, None)
        tool_name = self._tool_names.pop(run_key, None) or kwargs.get("name") or "langchain_tool"
        self.trace.log_tool_result(
            str(tool_name),
            output_json=_json_safe(output),
            metadata=_event_metadata(None, kwargs),
            latency_ms=_elapsed_ms(started_at),
        )

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log a tool failure together with its latency."""
        if not self.log_agent_events or self.trace is None:
            return
        run_key = self._tool_key(kwargs)
        started_at = self._tool_start_times.pop(run_key, None)
        tool_name = self._tool_names.pop(run_key, None) or kwargs.get("name") or "langchain_tool"
        self.trace.log_agent_error(
            str(error),
            name=str(tool_name),
            metadata={
                "error_type": error.__class__.__name__,
                **_event_metadata(None, kwargs),
            },
            latency_ms=_elapsed_ms(started_at),
        )

    def _ensure_trace(
        self,
        *,
        query: str,
        event: str,
        metadata: dict[str, Any],
    ) -> None:
        """Open the trace on first use; later calls are no-ops."""
        if self.trace is not None:
            return

        self.query = query
        self.start_time = time.perf_counter()
        trace_metadata = dict(self.trace_metadata)
        trace_metadata.update(
            {
                "integration": "langchain",
                "start_event": event,
                "langchain": metadata,
            }
        )
        self.trace = self.client.trace(query=query, metadata=trace_metadata).__enter__()

    def _latency_ms(self) -> int:
        """Milliseconds since the trace was opened (0 if it never was)."""
        if self.start_time is None:
            return 0
        return int((time.perf_counter() - self.start_time) * 1000)

    def _merge_metadata(self, source: Any, base: dict[str, Any]) -> dict[str, Any]:
        """Return ``base`` overlaid with whatever ``metadata_extractor`` yields."""
        if not self.metadata_extractor:
            return base
        extracted = self.metadata_extractor(source, base)
        if not extracted:
            return base
        merged = dict(base)
        merged.update(extracted)
        return merged

    def _remember_root_run(self, kwargs: dict[str, Any]) -> None:
        """Record the run_id of the first chain that has no parent run."""
        run_id = kwargs.get("run_id")
        if (
            self._root_run_id is None
            and run_id is not None
            and kwargs.get("parent_run_id") is None
        ):
            self._root_run_id = str(run_id)

    def _is_nested_run(self, kwargs: dict[str, Any]) -> bool:
        """True when the event belongs to a run other than the recorded root.

        Returns False whenever the root or the event's run_id is unknown, so
        callers that never pass run ids keep the ungated behaviour.
        """
        run_id = kwargs.get("run_id")
        return (
            self._root_run_id is not None
            and run_id is not None
            and str(run_id) != self._root_run_id
        )

    @staticmethod
    def _tool_key(kwargs: dict[str, Any]) -> str:
        """Bookkeeping key shared by the tool start/end/error callbacks."""
        return str(kwargs.get("run_id") or "langchain_tool")
277
+
278
+
279
def _first_not_none(*candidates: Any) -> Any:
    """Return the first candidate that is not ``None``, or ``None`` if all are."""
    return next((item for item in candidates if item is not None), None)


def _pick_content(candidates: list) -> Any:
    """Prefer the first non-empty candidate, else the first that is not ``None``."""
    return next((item for item in candidates if item), None) if any(candidates) else _first_not_none(*candidates)


def langchain_document_to_chunk(document: Any, index: int = 0) -> dict[str, Any]:
    """Convert a LangChain-style document (object or dict) into a chunk dict.

    Args:
        document: Object exposing ``page_content``/``content``/``text`` (and
            optionally ``metadata``, ``id``, ``score``) as attributes, or a
            dict with the same keys.
        index: Position of the document in its result list; used to build a
            fallback chunk id.

    Returns:
        Dict with ``chunk_id``, ``content``, ``source``, ``metadata`` and
        ``relevance_score``.

    Raises:
        ValueError: If none of the content fields is present.
    """
    content_fields = ("page_content", "content", "text")

    metadata = getattr(document, "metadata", None) or {}
    content = _pick_content([getattr(document, field, None) for field in content_fields])
    if content is None and isinstance(document, dict):
        content = _pick_content([document.get(field) for field in content_fields])
        metadata = document.get("metadata") or metadata

    # Normalise after the dict branch too, so the .get() calls below are safe.
    if not isinstance(metadata, dict):
        metadata = {"metadata": metadata}

    # An empty string is still "present"; only a missing field is an error.
    if content is None:
        raise ValueError("LangChain document must include page_content, content, or text.")

    # Skip None and "" only, so a legitimate id of 0 is kept.
    id_candidates = (
        metadata.get("chunk_id"),
        metadata.get("id"),
        metadata.get("doc_id"),
        getattr(document, "id", None),
    )
    chunk_id = next(
        (item for item in id_candidates if item is not None and item != ""),
        f"langchain_doc_{index}",
    )
    source = metadata.get("source") or metadata.get("url") or metadata.get("path")
    # A score of 0 / 0.0 is a real value and must not be treated as missing.
    relevance_score = _first_not_none(
        metadata.get("relevance_score"),
        metadata.get("score"),
        getattr(document, "score", None),
    )

    return {
        "chunk_id": str(chunk_id),
        "content": str(content),
        "source": source,
        "metadata": metadata,
        "relevance_score": relevance_score,
    }
317
+
318
+
319
+ def _extract_query(inputs: Any) -> Optional[str]:
320
+ if isinstance(inputs, str):
321
+ return inputs
322
+ if not isinstance(inputs, dict):
323
+ return None
324
+
325
+ for key in ("query", "question", "input", "prompt"):
326
+ value = inputs.get(key)
327
+ if isinstance(value, str) and value.strip():
328
+ return value
329
+
330
+ for value in inputs.values():
331
+ if isinstance(value, str) and value.strip():
332
+ return value
333
+ return None
334
+
335
+
336
+ def _extract_answer(outputs: Any) -> Optional[str]:
337
+ if isinstance(outputs, str):
338
+ return outputs
339
+ if not isinstance(outputs, dict):
340
+ return None
341
+
342
+ for key in ("answer", "output", "result", "text", "response"):
343
+ value = outputs.get(key)
344
+ if isinstance(value, str) and value.strip():
345
+ return value
346
+
347
+ for value in outputs.values():
348
+ if isinstance(value, str) and value.strip():
349
+ return value
350
+ return None
351
+
352
+
353
+ def _extract_citations(outputs: Any) -> Iterable[dict[str, Any]]:
354
+ if not isinstance(outputs, dict):
355
+ return []
356
+ raw = outputs.get("citations") or outputs.get("citation_checks") or []
357
+ if not isinstance(raw, RuntimeIterable) or isinstance(raw, (str, bytes)):
358
+ return []
359
+ citations = []
360
+ for citation in raw:
361
+ if not isinstance(citation, dict):
362
+ continue
363
+ claim = citation.get("claim")
364
+ source_chunk_id = citation.get("source_chunk_id") or citation.get("chunk_id") or citation.get("source")
365
+ if claim and source_chunk_id:
366
+ citations.append({"claim": str(claim), "source_chunk_id": str(source_chunk_id)})
367
+ return citations
368
+
369
+
370
+ def _extract_token_usage(response: Any) -> dict[str, Any]:
371
+ llm_output = getattr(response, "llm_output", None) or {}
372
+ if isinstance(llm_output, dict):
373
+ token_usage = llm_output.get("token_usage") or llm_output.get("usage")
374
+ if isinstance(token_usage, dict):
375
+ return token_usage
376
+ return {}
377
+
378
+
379
+ def _extract_model(response: Any) -> Optional[str]:
380
+ llm_output = getattr(response, "llm_output", None) or {}
381
+ if isinstance(llm_output, dict):
382
+ model = llm_output.get("model_name") or llm_output.get("model")
383
+ if isinstance(model, str):
384
+ return model
385
+ return None
386
+
387
+
388
+ def _elapsed_ms(start_time: Optional[float]) -> Optional[int]:
389
+ if start_time is None:
390
+ return None
391
+ return int((time.perf_counter() - start_time) * 1000)
392
+
393
+
394
+ def _json_safe(value: Any) -> Any:
395
+ if value is None or isinstance(value, (str, int, float, bool)):
396
+ return value
397
+ if isinstance(value, dict):
398
+ return {str(key): _json_safe(item) for key, item in value.items()}
399
+ if isinstance(value, (list, tuple, set)):
400
+ return [_json_safe(item) for item in value]
401
+ return str(value)
402
+
403
+
404
def _event_metadata(serialized: Any, kwargs: dict[str, Any]) -> dict[str, Any]:
    """Collect the identifying details of one callback event.

    The result may contain "serialized_name" (from ``serialized``), "metadata"
    (only when the callback metadata is a dict), "tags" (as a list), and the
    stringified "run_id" / "parent_run_id". Absent values are left out.
    """
    collected: dict[str, Any] = {}

    component = _serialized_name(serialized)
    if component:
        collected["serialized_name"] = component

    extra = kwargs.get("metadata")
    if isinstance(extra, dict):
        collected["metadata"] = extra

    tag_values = kwargs.get("tags")
    if tag_values:
        collected["tags"] = list(tag_values)

    # Both ids get the same treatment: kept only when present, stored as text.
    for id_key in ("run_id", "parent_run_id"):
        identifier = kwargs.get(id_key)
        if identifier is not None:
            collected[id_key] = str(identifier)

    return collected
427
+
428
+
429
+ def _serialized_name(serialized: Any) -> Optional[str]:
430
+ if not isinstance(serialized, dict):
431
+ return None
432
+ name = serialized.get("name")
433
+ if isinstance(name, str):
434
+ return name
435
+ serialized_id = serialized.get("id")
436
+ if isinstance(serialized_id, list) and serialized_id:
437
+ return str(serialized_id[-1])
438
+ if isinstance(serialized_id, str):
439
+ return serialized_id
440
+ return None
@@ -0,0 +1,197 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import time
5
+ from functools import wraps
6
+ from typing import Any, Callable, Optional
7
+
8
+ from contexttrace.client import ContextTrace, TraceSession
9
+
10
+
11
class ContextTraceLangGraphTracer:
    """Beta LangGraph adapter that logs graph nodes, tool calls, and errors.

    Events are written to a single trace that is opened either explicitly via
    ``start_trace`` or lazily by the first event. ``wrap_node`` decorates a
    node function so its start, end, and failure are logged automatically.

    NOTE(review): the trace context manager is entered in ``start_trace`` but
    ``end_trace`` never calls ``__exit__``. Confirm whether the client needs
    an explicit close/flush.
    """

    def __init__(
        self,
        *,
        project: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: str = "http://localhost:8000",
        client: Optional[ContextTrace] = None,
        trace_metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Configure the tracer.

        Args:
            project: Project name for a client built here; "default" if unset.
            api_key: When given, the client is built in "hosted" mode with
                this key and ``base_url``.
            base_url: Server URL; only used together with ``api_key``.
            client: Pre-built client; takes precedence over the other options.
            trace_metadata: Extra metadata copied onto every trace.
        """
        if client is None:
            kwargs: dict[str, Any] = {"project": project or "default"}
            if api_key:
                kwargs.update({"api_key": api_key, "base_url": base_url, "mode": "hosted"})
            client = ContextTrace(**kwargs)
        self.client = client
        self.trace_metadata = trace_metadata or {}
        self.trace: Optional[TraceSession] = None
        self.query: Optional[str] = None
        # perf_counter() readings keyed by node/tool name.
        self._node_starts: dict[str, float] = {}

    def start_trace(self, query: str, *, metadata: Optional[dict[str, Any]] = None) -> TraceSession:
        """Open a trace for ``query``, or return the one already open."""
        if self.trace is not None:
            return self.trace
        self.query = query
        trace_metadata = {
            **self.trace_metadata,
            **(metadata or {}),
            "integration": "langgraph",
        }
        self.trace = self.client.trace(query=query, metadata=trace_metadata).__enter__()
        return self.trace

    def end_trace(
        self,
        *,
        answer: Optional[str] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Optional[TraceSession]:
        """Detach the current trace, logging ``answer`` first when given.

        Returns:
            The trace that was active, or ``None`` if there was none.
        """
        if self.trace is None:
            return None
        if answer:
            self.trace.log_answer(answer, metadata=metadata or {})
            self.trace.log_agent_event(
                event_type="final_answer",
                name="final_answer",
                output_json={"answer": answer},
                metadata=metadata or {},
            )
        trace = self.trace
        self.trace = None
        # Timers from this trace must not leak into the next one.
        self._node_starts.clear()
        return trace

    def on_node_start(
        self,
        name: str,
        input_json: Any = None,
        *,
        event_type: str = "planner_step",
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Log the start of a node and begin timing it."""
        trace = self._ensure_trace(input_json)
        self._node_starts[name] = time.perf_counter()
        trace.log_agent_event(
            event_type=event_type,
            name=name,
            input_json=input_json,
            metadata={"phase": "start", **(metadata or {})},
        )

    def on_node_end(
        self,
        name: str,
        output_json: Any = None,
        *,
        event_type: str = "planner_step",
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Log the end of a node with the latency since its start event."""
        trace = self._ensure_trace(output_json)
        trace.log_agent_event(
            event_type=event_type,
            name=name,
            output_json=output_json,
            metadata={"phase": "end", **(metadata or {})},
            latency_ms=self._pop_elapsed_ms(name),
        )

    def on_tool_start(self, name: str, input_json: Any = None, *, metadata: Optional[dict[str, Any]] = None) -> None:
        """Log a tool call and begin timing it."""
        self._node_starts[name] = time.perf_counter()
        self._ensure_trace(input_json).log_tool_call(name, input_json=input_json, metadata=metadata)

    def on_tool_end(
        self,
        name: str,
        output_json: Any = None,
        *,
        input_json: Any = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Log a tool result with the latency since its start event."""
        self._ensure_trace(output_json).log_tool_result(
            name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=self._pop_elapsed_ms(name),
        )

    def on_error(self, name: str, error: BaseException, *, input_json: Any = None) -> None:
        """Log a node/tool failure with the latency since its start event."""
        self._ensure_trace(input_json).log_agent_error(
            str(error),
            name=name,
            input_json=input_json,
            metadata={"error_type": error.__class__.__name__},
            latency_ms=self._pop_elapsed_ms(name),
        )

    def wrap_node(
        self,
        name: str,
        func: Callable[..., Any],
        *,
        event_type: str = "planner_step",
    ) -> Callable[..., Any]:
        """Wrap a sync or async node function so its execution is logged.

        The wrapper logs a start event, calls ``func``, then logs an end
        event. If ``func`` raises, the error is logged and re-raised
        unchanged. Only exceptions raised by ``func`` itself are reported as
        node errors; a failure while logging the end event propagates as is.
        """
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                input_json = {"args": _json_safe(args), "kwargs": _json_safe(kwargs)}
                self.on_node_start(name, input_json, event_type=event_type)
                try:
                    output = await func(*args, **kwargs)
                except BaseException as exc:
                    # BaseException so cancellations are recorded too; always re-raised.
                    self.on_error(name, exc, input_json=input_json)
                    raise
                self.on_node_end(name, _json_safe(output), event_type=event_type)
                return output

            return async_wrapper

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            input_json = {"args": _json_safe(args), "kwargs": _json_safe(kwargs)}
            self.on_node_start(name, input_json, event_type=event_type)
            try:
                output = func(*args, **kwargs)
            except BaseException as exc:
                # BaseException so interrupts are recorded too; always re-raised.
                self.on_error(name, exc, input_json=input_json)
                raise
            self.on_node_end(name, _json_safe(output), event_type=event_type)
            return output

        return wrapper

    def _ensure_trace(self, value: Any = None) -> TraceSession:
        """Return the open trace, opening one from ``value`` if necessary.

        NOTE(review): ``self.query`` is only set by ``start_trace``, so this
        fallback reuses the previous trace's query. Confirm that is intended.
        """
        if self.trace is not None:
            return self.trace
        query = _query_from_value(value) or self.query or "langgraph run"
        return self.start_trace(query)

    def _pop_elapsed_ms(self, name: str) -> Optional[int]:
        """Consume the start time recorded for ``name`` and return the latency.

        Popping keeps a later end/error event for the same name from reporting
        a latency measured against a stale start.
        """
        return _elapsed_ms(self._node_starts.pop(name, None))
171
+
172
+
173
+ def _elapsed_ms(start_time: Optional[float]) -> Optional[int]:
174
+ if start_time is None:
175
+ return None
176
+ return int((time.perf_counter() - start_time) * 1000)
177
+
178
+
179
+ def _query_from_value(value: Any) -> Optional[str]:
180
+ if isinstance(value, str):
181
+ return value
182
+ if isinstance(value, dict):
183
+ for key in ("query", "question", "input", "prompt"):
184
+ candidate = value.get(key)
185
+ if isinstance(candidate, str) and candidate.strip():
186
+ return candidate
187
+ return None
188
+
189
+
190
+ def _json_safe(value: Any) -> Any:
191
+ if value is None or isinstance(value, (str, int, float, bool)):
192
+ return value
193
+ if isinstance(value, dict):
194
+ return {str(key): _json_safe(item) for key, item in value.items()}
195
+ if isinstance(value, (list, tuple, set)):
196
+ return [_json_safe(item) for item in value]
197
+ return str(value)