contexttrace 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- contexttrace/__init__.py +36 -0
- contexttrace/_version.py +1 -0
- contexttrace/cli.py +474 -0
- contexttrace/client.py +1074 -0
- contexttrace/config.py +246 -0
- contexttrace/demo.py +311 -0
- contexttrace/demo_data.py +257 -0
- contexttrace/endpoint_eval.py +314 -0
- contexttrace/errors.py +14 -0
- contexttrace/evaluator.py +448 -0
- contexttrace/integrations/__init__.py +14 -0
- contexttrace/integrations/fastapi.py +311 -0
- contexttrace/integrations/langchain.py +440 -0
- contexttrace/integrations/langgraph.py +197 -0
- contexttrace/integrations/llamaindex.py +422 -0
- contexttrace/integrations/opentelemetry.py +111 -0
- contexttrace/local.py +325 -0
- contexttrace/py.typed +1 -0
- contexttrace/regression.py +123 -0
- contexttrace/reliability.py +284 -0
- contexttrace/report.py +550 -0
- contexttrace/storage/__init__.py +3 -0
- contexttrace/storage/sqlite_store.py +604 -0
- contexttrace/thresholds.py +50 -0
- contexttrace/transport.py +183 -0
- contexttrace/viewer.py +148 -0
- contexttrace-0.1.0.dist-info/METADATA +154 -0
- contexttrace-0.1.0.dist-info/RECORD +31 -0
- contexttrace-0.1.0.dist-info/WHEEL +5 -0
- contexttrace-0.1.0.dist-info/entry_points.txt +2 -0
- contexttrace-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from collections.abc import Iterable as RuntimeIterable
|
|
5
|
+
from typing import Any, Callable, Dict, Iterable, Optional
|
|
6
|
+
|
|
7
|
+
from contexttrace.client import ContextTrace
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from langchain_core.callbacks import BaseCallbackHandler
|
|
11
|
+
except Exception: # pragma: no cover - exercised when langchain is not installed
|
|
12
|
+
BaseCallbackHandler = object # type: ignore[assignment]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
QueryExtractor = Callable[[Any], Optional[str]]
|
|
16
|
+
AnswerExtractor = Callable[[Any], Optional[str]]
|
|
17
|
+
CitationExtractor = Callable[[Any], Iterable[Dict[str, Any]]]
|
|
18
|
+
DocumentConverter = Callable[[Any, int], Dict[str, Any]]
|
|
19
|
+
MetadataExtractor = Callable[[Any, Dict[str, Any]], Dict[str, Any]]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ContextTraceCallbackHandler(BaseCallbackHandler):  # type: ignore[misc]
    """Record LangChain callback events into a single ContextTrace trace.

    The trace is opened lazily by the first callback that can supply a query
    (chain, retriever, LLM or tool start) and is reused by every later callback
    on this handler instance. Retrieved documents, the selected context, the
    final answer, citations, tool calls and errors are forwarded to that trace.

    NOTE(review): the trace context manager is entered in ``_ensure_trace`` but
    never exited by this class, and per-run state (``trace``, ``answer_logged``)
    is never reset -- presumably one handler instance serves one run; confirm
    against ``contexttrace.client``.
    """

    def __init__(
        self,
        *,
        project: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: str = "http://localhost:8000",
        client: Optional[ContextTrace] = None,
        trace_metadata: Optional[dict[str, Any]] = None,
        selected_context_limit: Optional[int] = None,
        query_extractor: Optional[QueryExtractor] = None,
        answer_extractor: Optional[AnswerExtractor] = None,
        citation_extractor: Optional[CitationExtractor] = None,
        document_converter: Optional[DocumentConverter] = None,
        metadata_extractor: Optional[MetadataExtractor] = None,
        log_agent_events: bool = True,
    ) -> None:
        """Configure the handler and, if needed, build a ContextTrace client.

        Args:
            project: Project name for a client built here; falls back to
                ``"default"``. Ignored when ``client`` is given.
            api_key: When truthy, the built client is created in ``"hosted"``
                mode with this key and ``base_url``.
            base_url: Only passed to the client together with ``api_key``.
            client: Pre-built client; when given, no client is constructed.
            trace_metadata: Extra metadata copied onto every trace opened.
            selected_context_limit: Keep only the first N retrieved chunks as
                selected context. ``None`` (and ``0``) keep all chunks.
            query_extractor: Pulls the query out of chain inputs.
            answer_extractor: Pulls the answer out of chain outputs.
            citation_extractor: Pulls citation dicts out of chain outputs.
            document_converter: Converts ``(document, index)`` into a chunk dict.
            metadata_extractor: Optional hook returning extra answer metadata.
            log_agent_events: When false, tool callbacks and the agent-error
                event of ``on_chain_error`` are skipped.
        """
        if client is None:
            kwargs: dict[str, Any] = {"project": project or "default"}
            if api_key:
                kwargs.update({"api_key": api_key, "base_url": base_url, "mode": "hosted"})
            client = ContextTrace(**kwargs)

        self.client = client
        self.trace_metadata = trace_metadata or {}
        self.selected_context_limit = selected_context_limit
        self.query_extractor = query_extractor or _extract_query
        self.answer_extractor = answer_extractor or _extract_answer
        self.citation_extractor = citation_extractor or _extract_citations
        self.document_converter = document_converter or langchain_document_to_chunk
        self.metadata_extractor = metadata_extractor
        self.log_agent_events = log_agent_events

        # Per-run state, filled in as callbacks arrive.
        self.trace = None
        self.query: Optional[str] = None
        self.retrieved_chunks: list[dict[str, Any]] = []
        self.start_time: Optional[float] = None
        self.retriever_start_time: Optional[float] = None
        self.llm_model: Optional[str] = None
        self.llm_usage: dict[str, Any] = {}
        self.answer_logged = False
        # Keyed by the value returned from _tool_run_key().
        self._tool_start_times: dict[str, float] = {}
        self._tool_names: dict[str, str] = {}

    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: Any,
        **kwargs: Any,
    ) -> None:
        """Open the trace if the chain inputs yield a query."""
        query = self.query_extractor(inputs)
        if query:
            self._ensure_trace(
                query=query,
                event="chain_start",
                metadata=_event_metadata(serialized, kwargs),
            )

    def on_retriever_start(
        self,
        serialized: dict[str, Any],
        query: str,
        **kwargs: Any,
    ) -> None:
        """Start the retriever timer and open the trace with the retriever query."""
        self.retriever_start_time = time.perf_counter()
        self._ensure_trace(
            query=query,
            event="retriever_start",
            metadata=_event_metadata(serialized, kwargs),
        )

    def on_retriever_end(self, documents: Iterable[Any], **kwargs: Any) -> None:
        """Convert retrieved documents to chunks and log retrieval and context."""
        chunks = [self.document_converter(document, index) for index, document in enumerate(documents)]
        self.retrieved_chunks = chunks

        if self.trace is None:
            self._ensure_trace(
                query=self.query or "unknown query",
                event="retriever_end",
                metadata=_event_metadata(None, kwargs),
            )

        if not chunks or self.trace is None:
            return

        self.trace.log_retrieval(
            chunks,
            retriever_name=_serialized_name(kwargs.get("serialized")) or "langchain_retriever",
            metadata={
                **_event_metadata(None, kwargs),
                "latency_ms": _elapsed_ms(self.retriever_start_time),
            },
        )
        # A falsy limit (None or 0) means "keep every retrieved chunk".
        selected = chunks[: self.selected_context_limit] if self.selected_context_limit else chunks
        self.trace.log_context(
            selected,
            metadata={
                "source": "langchain_retriever_end",
                "selected_context_limit": self.selected_context_limit,
            },
        )

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
        """Remember the model name and open the trace from the first prompt if needed."""
        model = _serialized_name(serialized)
        if model:
            self.llm_model = model
        if self.trace is None:
            query = self.query or (prompts[0] if prompts else "unknown query")
            self._ensure_trace(
                query=query,
                event="llm_start",
                metadata=_event_metadata(serialized, kwargs),
            )

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Capture token usage and model name reported by the LLM response."""
        self.llm_usage = _extract_token_usage(response)
        self.llm_model = _extract_model(response) or self.llm_model

    def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
        """Log the first extractable answer (and its citations) exactly once.

        NOTE(review): this fires for whichever chain end first yields an
        answer; if nested chains also report here, an intermediate output
        could be logged as the final answer -- confirm against LangChain's
        callback propagation.
        """
        answer = self.answer_extractor(outputs)
        if not answer:
            return

        if self.trace is None:
            self._ensure_trace(
                query=self.query or "unknown query",
                event="chain_end",
                metadata=_event_metadata(None, kwargs),
            )

        if self.trace is None or self.answer_logged:
            return

        self.trace.log_answer(
            answer,
            model=self.llm_model,
            usage=self.llm_usage,
            metadata=self._merge_metadata(
                outputs,
                {
                    "latency_ms": self._latency_ms(),
                    "langchain_output_keys": list(outputs.keys()) if isinstance(outputs, dict) else [],
                },
            ),
        )
        # Tolerate a custom extractor that returns None instead of an iterable.
        citations = list(self.citation_extractor(outputs) or [])
        if citations:
            self.trace.log_citations(citations)
        self.answer_logged = True

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log the chain failure and, if no answer exists yet, a placeholder answer."""
        if self.trace is not None and self.log_agent_events:
            self.trace.log_agent_error(
                str(error),
                name="langchain_chain_error",
                metadata={
                    "error_type": error.__class__.__name__,
                    **_event_metadata(None, kwargs),
                },
                latency_ms=self._latency_ms(),
            )
        if self.trace is not None and not self.answer_logged:
            self.trace.log_answer(
                "LangChain run failed before producing an answer.",
                metadata={
                    "latency_ms": self._latency_ms(),
                    "error": str(error),
                    "error_type": error.__class__.__name__,
                },
            )
            self.answer_logged = True

    def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> None:
        """Log a tool call and remember its start time and name for the end/error event."""
        if not self.log_agent_events:
            return
        if self.trace is None:
            self._ensure_trace(
                query=self.query or input_str or "unknown query",
                event="tool_start",
                metadata=_event_metadata(serialized, kwargs),
            )
        if self.trace is None:
            return
        # Same key derivation as on_tool_end/on_tool_error so the lookups match
        # even when LangChain does not supply a run_id.
        run_key = self._tool_run_key(kwargs)
        tool_name = str(_serialized_name(serialized) or kwargs.get("name") or "langchain_tool")
        self._tool_start_times[run_key] = time.perf_counter()
        self._tool_names[run_key] = tool_name
        self.trace.log_tool_call(
            tool_name,
            input_json={"input": input_str},
            metadata=_event_metadata(serialized, kwargs),
        )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Log a tool result with the latency measured since ``on_tool_start``."""
        if not self.log_agent_events or self.trace is None:
            return
        run_key = self._tool_run_key(kwargs)
        # Pop rather than get: a finished run must not keep its entries alive.
        started_at = self._tool_start_times.pop(run_key, None)
        tool_name = self._tool_names.pop(run_key, None) or kwargs.get("name") or "langchain_tool"
        self.trace.log_tool_result(
            str(tool_name),
            output_json=_json_safe(output),
            metadata=_event_metadata(None, kwargs),
            latency_ms=_elapsed_ms(started_at),
        )

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log a tool failure as an agent error named after the tool."""
        if not self.log_agent_events or self.trace is None:
            return
        run_key = self._tool_run_key(kwargs)
        started_at = self._tool_start_times.pop(run_key, None)
        tool_name = self._tool_names.pop(run_key, None) or kwargs.get("name") or "langchain_tool"
        self.trace.log_agent_error(
            str(error),
            name=str(tool_name),
            metadata={
                "error_type": error.__class__.__name__,
                **_event_metadata(None, kwargs),
            },
            latency_ms=_elapsed_ms(started_at),
        )

    @staticmethod
    def _tool_run_key(kwargs: dict[str, Any]) -> str:
        """Return the bookkeeping key for a tool run (its run_id, else a constant)."""
        return str(kwargs.get("run_id") or "langchain_tool")

    def _ensure_trace(
        self,
        *,
        query: str,
        event: str,
        metadata: dict[str, Any],
    ) -> None:
        """Open the trace on first use; later calls are no-ops.

        Args:
            query: Query text recorded on the trace and on ``self.query``.
            event: Name of the callback that triggered creation.
            metadata: LangChain event metadata stored under the ``"langchain"`` key.
        """
        if self.trace is not None:
            return

        self.query = query
        self.start_time = time.perf_counter()
        trace_metadata = dict(self.trace_metadata)
        trace_metadata.update(
            {
                "integration": "langchain",
                "start_event": event,
                "langchain": metadata,
            }
        )
        # NOTE(review): entered manually and never exited by this class.
        self.trace = self.client.trace(query=query, metadata=trace_metadata).__enter__()

    def _latency_ms(self) -> int:
        """Return whole milliseconds since the trace was opened, or 0 if it was not."""
        if self.start_time is None:
            return 0
        return int((time.perf_counter() - self.start_time) * 1000)

    def _merge_metadata(self, source: Any, base: dict[str, Any]) -> dict[str, Any]:
        """Overlay ``metadata_extractor(source, base)`` on a copy of ``base``.

        Returns ``base`` itself when no extractor is configured or it yields
        nothing.
        """
        if not self.metadata_extractor:
            return base
        extracted = self.metadata_extractor(source, base)
        if not extracted:
            return base
        merged = dict(base)
        merged.update(extracted)
        return merged
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _is_zero_number(value: Any) -> bool:
    """Return True for a falsy int/float (not bool), i.e. a real zero."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _prefer_truthy(candidates: Any, accept_falsy: Any) -> Any:
    """Return the first truthy candidate, else the first accepted falsy one.

    Keeps the historical "first truthy wins" order, but no longer discards a
    meaningful falsy value (such as ``0.0`` or ``""``) when nothing truthy exists.
    """
    values = list(candidates)
    for value in values:
        if value:
            return value
    for value in values:
        if value is not None and accept_falsy(value):
            return value
    return None


def langchain_document_to_chunk(document: Any, index: int = 0) -> dict[str, Any]:
    """Convert a LangChain-style document (object or dict) into a chunk dict.

    Args:
        document: Object exposing ``page_content``/``content``/``text`` and an
            optional ``metadata`` attribute, or a dict with the same keys.
        index: Position of the document, used to build a fallback chunk id.

    Returns:
        Dict with ``chunk_id``, ``content``, ``source``, ``metadata`` and
        ``relevance_score`` keys.

    Raises:
        ValueError: If no content field is present at all.
    """
    metadata = getattr(document, "metadata", None) or {}

    def is_text(value: Any) -> bool:
        return isinstance(value, str)

    content = _prefer_truthy(
        (getattr(document, attr, None) for attr in ("page_content", "content", "text")),
        is_text,
    )
    if content is None and isinstance(document, dict):
        content = _prefer_truthy(
            (document.get(key) for key in ("page_content", "content", "text")),
            is_text,
        )
        metadata = document.get("metadata") or metadata

    # Normalise after the dict branch so dict documents are covered as well.
    if not isinstance(metadata, dict):
        metadata = {"metadata": metadata}

    if content is None:
        raise ValueError("LangChain document must include page_content, content, or text.")

    chunk_id = _prefer_truthy(
        (
            metadata.get("chunk_id"),
            metadata.get("id"),
            metadata.get("doc_id"),
            getattr(document, "id", None),
        ),
        _is_zero_number,
    )
    if chunk_id is None:
        chunk_id = f"langchain_doc_{index}"

    source = metadata.get("source") or metadata.get("url") or metadata.get("path")
    # A score of exactly 0 is a valid score and must not collapse to None.
    relevance_score = _prefer_truthy(
        (
            metadata.get("relevance_score"),
            metadata.get("score"),
            getattr(document, "score", None),
        ),
        _is_zero_number,
    )

    return {
        "chunk_id": str(chunk_id),
        "content": str(content),
        "source": source,
        "metadata": metadata,
        "relevance_score": relevance_score,
    }
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _extract_query(inputs: Any) -> Optional[str]:
|
|
320
|
+
if isinstance(inputs, str):
|
|
321
|
+
return inputs
|
|
322
|
+
if not isinstance(inputs, dict):
|
|
323
|
+
return None
|
|
324
|
+
|
|
325
|
+
for key in ("query", "question", "input", "prompt"):
|
|
326
|
+
value = inputs.get(key)
|
|
327
|
+
if isinstance(value, str) and value.strip():
|
|
328
|
+
return value
|
|
329
|
+
|
|
330
|
+
for value in inputs.values():
|
|
331
|
+
if isinstance(value, str) and value.strip():
|
|
332
|
+
return value
|
|
333
|
+
return None
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _extract_answer(outputs: Any) -> Optional[str]:
|
|
337
|
+
if isinstance(outputs, str):
|
|
338
|
+
return outputs
|
|
339
|
+
if not isinstance(outputs, dict):
|
|
340
|
+
return None
|
|
341
|
+
|
|
342
|
+
for key in ("answer", "output", "result", "text", "response"):
|
|
343
|
+
value = outputs.get(key)
|
|
344
|
+
if isinstance(value, str) and value.strip():
|
|
345
|
+
return value
|
|
346
|
+
|
|
347
|
+
for value in outputs.values():
|
|
348
|
+
if isinstance(value, str) and value.strip():
|
|
349
|
+
return value
|
|
350
|
+
return None
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _extract_citations(outputs: Any) -> Iterable[dict[str, Any]]:
|
|
354
|
+
if not isinstance(outputs, dict):
|
|
355
|
+
return []
|
|
356
|
+
raw = outputs.get("citations") or outputs.get("citation_checks") or []
|
|
357
|
+
if not isinstance(raw, RuntimeIterable) or isinstance(raw, (str, bytes)):
|
|
358
|
+
return []
|
|
359
|
+
citations = []
|
|
360
|
+
for citation in raw:
|
|
361
|
+
if not isinstance(citation, dict):
|
|
362
|
+
continue
|
|
363
|
+
claim = citation.get("claim")
|
|
364
|
+
source_chunk_id = citation.get("source_chunk_id") or citation.get("chunk_id") or citation.get("source")
|
|
365
|
+
if claim and source_chunk_id:
|
|
366
|
+
citations.append({"claim": str(claim), "source_chunk_id": str(source_chunk_id)})
|
|
367
|
+
return citations
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _extract_token_usage(response: Any) -> dict[str, Any]:
|
|
371
|
+
llm_output = getattr(response, "llm_output", None) or {}
|
|
372
|
+
if isinstance(llm_output, dict):
|
|
373
|
+
token_usage = llm_output.get("token_usage") or llm_output.get("usage")
|
|
374
|
+
if isinstance(token_usage, dict):
|
|
375
|
+
return token_usage
|
|
376
|
+
return {}
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _extract_model(response: Any) -> Optional[str]:
|
|
380
|
+
llm_output = getattr(response, "llm_output", None) or {}
|
|
381
|
+
if isinstance(llm_output, dict):
|
|
382
|
+
model = llm_output.get("model_name") or llm_output.get("model")
|
|
383
|
+
if isinstance(model, str):
|
|
384
|
+
return model
|
|
385
|
+
return None
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _elapsed_ms(start_time: Optional[float]) -> Optional[int]:
|
|
389
|
+
if start_time is None:
|
|
390
|
+
return None
|
|
391
|
+
return int((time.perf_counter() - start_time) * 1000)
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def _json_safe(value: Any) -> Any:
|
|
395
|
+
if value is None or isinstance(value, (str, int, float, bool)):
|
|
396
|
+
return value
|
|
397
|
+
if isinstance(value, dict):
|
|
398
|
+
return {str(key): _json_safe(item) for key, item in value.items()}
|
|
399
|
+
if isinstance(value, (list, tuple, set)):
|
|
400
|
+
return [_json_safe(item) for item in value]
|
|
401
|
+
return str(value)
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _event_metadata(serialized: Any, kwargs: dict[str, Any]) -> dict[str, Any]:
    """Summarise a LangChain callback event as a small metadata dict.

    Includes, only when present: ``serialized_name`` (from ``serialized``),
    ``metadata`` (when the callback metadata is a dict), ``tags`` (as a list),
    and ``run_id`` / ``parent_run_id`` (as strings).
    """
    details: dict[str, Any] = {}

    label = _serialized_name(serialized)
    if label:
        details["serialized_name"] = label

    extra = kwargs.get("metadata")
    if isinstance(extra, dict):
        details["metadata"] = extra

    tags = kwargs.get("tags")
    if tags:
        details["tags"] = list(tags)

    # Both identifiers are stringified the same way.
    for field in ("run_id", "parent_run_id"):
        identifier = kwargs.get(field)
        if identifier is not None:
            details[field] = str(identifier)

    return details
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _serialized_name(serialized: Any) -> Optional[str]:
|
|
430
|
+
if not isinstance(serialized, dict):
|
|
431
|
+
return None
|
|
432
|
+
name = serialized.get("name")
|
|
433
|
+
if isinstance(name, str):
|
|
434
|
+
return name
|
|
435
|
+
serialized_id = serialized.get("id")
|
|
436
|
+
if isinstance(serialized_id, list) and serialized_id:
|
|
437
|
+
return str(serialized_id[-1])
|
|
438
|
+
if isinstance(serialized_id, str):
|
|
439
|
+
return serialized_id
|
|
440
|
+
return None
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import time
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from typing import Any, Callable, Optional
|
|
7
|
+
|
|
8
|
+
from contexttrace.client import ContextTrace, TraceSession
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ContextTraceLangGraphTracer:
    """Beta LangGraph adapter for logging graph nodes, tools, memory, and errors.

    A trace is opened explicitly with ``start_trace`` or implicitly by the
    first node/tool/error event, and detached again by ``end_trace``.

    NOTE(review): the trace context manager is entered in ``start_trace`` but
    never exited here -- confirm whether ``TraceSession`` needs ``__exit__`` to
    finalize.
    """

    def __init__(
        self,
        *,
        project: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: str = "http://localhost:8000",
        client: Optional[ContextTrace] = None,
        trace_metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Store the client (building one if needed) and default trace metadata.

        Args:
            project: Project name for a client built here; falls back to
                ``"default"``. Ignored when ``client`` is given.
            api_key: When truthy, the built client uses ``"hosted"`` mode with
                this key and ``base_url``.
            base_url: Only passed to the client together with ``api_key``.
            client: Pre-built client; when given, no client is constructed.
            trace_metadata: Metadata merged into every trace started here.
        """
        if client is None:
            kwargs: dict[str, Any] = {"project": project or "default"}
            if api_key:
                kwargs.update({"api_key": api_key, "base_url": base_url, "mode": "hosted"})
            client = ContextTrace(**kwargs)
        self.client = client
        self.trace_metadata = trace_metadata or {}
        self.trace: Optional[TraceSession] = None
        self.query: Optional[str] = None
        # perf_counter() reading per node/tool name, consumed by the end/error event.
        self._node_starts: dict[str, float] = {}

    def start_trace(self, query: str, *, metadata: Optional[dict[str, Any]] = None) -> TraceSession:
        """Open a trace for ``query``; returns the active trace if one already exists."""
        if self.trace is not None:
            return self.trace
        self.query = query
        trace_metadata = {
            **self.trace_metadata,
            **(metadata or {}),
            "integration": "langgraph",
        }
        self.trace = self.client.trace(query=query, metadata=trace_metadata).__enter__()
        return self.trace

    def end_trace(
        self,
        *,
        answer: Optional[str] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Optional[TraceSession]:
        """Log the final answer (if any) and detach the active trace.

        Returns:
            The trace that was active, or ``None`` when no trace was open.
        """
        if self.trace is None:
            return None
        if answer:
            self.trace.log_answer(answer, metadata=metadata or {})
            self.trace.log_agent_event(
                event_type="final_answer",
                name="final_answer",
                output_json={"answer": answer},
                metadata=metadata or {},
            )
        trace = self.trace
        self.trace = None
        # Drop per-run state so the next implicit trace cannot inherit this
        # run's query or stale start timestamps.
        self.query = None
        self._node_starts.clear()
        return trace

    def on_node_start(
        self,
        name: str,
        input_json: Any = None,
        *,
        event_type: str = "planner_step",
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Record the start of a node and begin timing it."""
        trace = self._ensure_trace(input_json)
        self._node_starts[name] = time.perf_counter()
        trace.log_agent_event(
            event_type=event_type,
            name=name,
            input_json=input_json,
            metadata={"phase": "start", **(metadata or {})},
        )

    def on_node_end(
        self,
        name: str,
        output_json: Any = None,
        *,
        event_type: str = "planner_step",
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Record the end of a node with its latency since ``on_node_start``."""
        trace = self._ensure_trace(output_json)
        # Pop so a later end event without a matching start reports no latency
        # instead of reusing an old timestamp.
        started_at = self._node_starts.pop(name, None)
        trace.log_agent_event(
            event_type=event_type,
            name=name,
            output_json=output_json,
            metadata={"phase": "end", **(metadata or {})},
            latency_ms=_elapsed_ms(started_at),
        )

    def on_tool_start(self, name: str, input_json: Any = None, *, metadata: Optional[dict[str, Any]] = None) -> None:
        """Record a tool call and begin timing it."""
        self._node_starts[name] = time.perf_counter()
        self._ensure_trace(input_json).log_tool_call(name, input_json=input_json, metadata=metadata)

    def on_tool_end(
        self,
        name: str,
        output_json: Any = None,
        *,
        input_json: Any = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Record a tool result with its latency since ``on_tool_start``."""
        trace = self._ensure_trace(output_json)
        started_at = self._node_starts.pop(name, None)
        trace.log_tool_result(
            name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=_elapsed_ms(started_at),
        )

    def on_error(self, name: str, error: BaseException, *, input_json: Any = None) -> None:
        """Record a failure of the node or tool called ``name``."""
        trace = self._ensure_trace(input_json)
        started_at = self._node_starts.pop(name, None)
        trace.log_agent_error(
            str(error),
            name=name,
            input_json=input_json,
            metadata={"error_type": error.__class__.__name__},
            latency_ms=_elapsed_ms(started_at),
        )

    def wrap_node(
        self,
        name: str,
        func: Callable[..., Any],
        *,
        event_type: str = "planner_step",
    ) -> Callable[..., Any]:
        """Wrap a sync or async node function so its start, end and errors are logged.

        The wrapper logs the call arguments, runs ``func``, logs its output and
        returns it unchanged. Exceptions raised by ``func`` are logged through
        ``on_error`` and re-raised.
        """
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                input_json = {"args": _json_safe(args), "kwargs": _json_safe(kwargs)}
                self.on_node_start(name, input_json, event_type=event_type)
                try:
                    output = await func(*args, **kwargs)
                except BaseException as exc:
                    self.on_error(name, exc, input_json=input_json)
                    raise
                # Outside the try: a failure while logging completion is not a
                # node failure and must not be reported as one.
                self.on_node_end(name, _json_safe(output), event_type=event_type)
                return output

            return async_wrapper

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            input_json = {"args": _json_safe(args), "kwargs": _json_safe(kwargs)}
            self.on_node_start(name, input_json, event_type=event_type)
            try:
                output = func(*args, **kwargs)
            except BaseException as exc:
                self.on_error(name, exc, input_json=input_json)
                raise
            self.on_node_end(name, _json_safe(output), event_type=event_type)
            return output

        return wrapper

    def _ensure_trace(self, value: Any = None) -> TraceSession:
        """Return the active trace, starting one from ``value`` if none is open.

        The query is taken from ``value`` when possible, else from the stored
        query, else the placeholder ``"langgraph run"``.
        """
        if self.trace is not None:
            return self.trace
        query = _query_from_value(value) or self.query or "langgraph run"
        return self.start_trace(query)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _elapsed_ms(start_time: Optional[float]) -> Optional[int]:
|
|
174
|
+
if start_time is None:
|
|
175
|
+
return None
|
|
176
|
+
return int((time.perf_counter() - start_time) * 1000)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _query_from_value(value: Any) -> Optional[str]:
|
|
180
|
+
if isinstance(value, str):
|
|
181
|
+
return value
|
|
182
|
+
if isinstance(value, dict):
|
|
183
|
+
for key in ("query", "question", "input", "prompt"):
|
|
184
|
+
candidate = value.get(key)
|
|
185
|
+
if isinstance(candidate, str) and candidate.strip():
|
|
186
|
+
return candidate
|
|
187
|
+
return None
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _json_safe(value: Any) -> Any:
|
|
191
|
+
if value is None or isinstance(value, (str, int, float, bool)):
|
|
192
|
+
return value
|
|
193
|
+
if isinstance(value, dict):
|
|
194
|
+
return {str(key): _json_safe(item) for key, item in value.items()}
|
|
195
|
+
if isinstance(value, (list, tuple, set)):
|
|
196
|
+
return [_json_safe(item) for item in value]
|
|
197
|
+
return str(value)
|