contextforge_eval-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (43)
  1. context_forge/__init__.py +95 -0
  2. context_forge/core/__init__.py +55 -0
  3. context_forge/core/trace.py +369 -0
  4. context_forge/core/types.py +121 -0
  5. context_forge/evaluation.py +267 -0
  6. context_forge/exceptions.py +56 -0
  7. context_forge/graders/__init__.py +44 -0
  8. context_forge/graders/base.py +264 -0
  9. context_forge/graders/deterministic/__init__.py +11 -0
  10. context_forge/graders/deterministic/memory_corruption.py +130 -0
  11. context_forge/graders/hybrid.py +190 -0
  12. context_forge/graders/judges/__init__.py +11 -0
  13. context_forge/graders/judges/backends/__init__.py +9 -0
  14. context_forge/graders/judges/backends/ollama.py +173 -0
  15. context_forge/graders/judges/base.py +158 -0
  16. context_forge/graders/judges/memory_hygiene_judge.py +332 -0
  17. context_forge/graders/judges/models.py +113 -0
  18. context_forge/harness/__init__.py +43 -0
  19. context_forge/harness/user_simulator/__init__.py +70 -0
  20. context_forge/harness/user_simulator/adapters/__init__.py +13 -0
  21. context_forge/harness/user_simulator/adapters/base.py +67 -0
  22. context_forge/harness/user_simulator/adapters/crewai.py +100 -0
  23. context_forge/harness/user_simulator/adapters/langgraph.py +157 -0
  24. context_forge/harness/user_simulator/adapters/pydanticai.py +105 -0
  25. context_forge/harness/user_simulator/llm/__init__.py +5 -0
  26. context_forge/harness/user_simulator/llm/ollama.py +119 -0
  27. context_forge/harness/user_simulator/models.py +103 -0
  28. context_forge/harness/user_simulator/persona.py +154 -0
  29. context_forge/harness/user_simulator/runner.py +342 -0
  30. context_forge/harness/user_simulator/scenario.py +95 -0
  31. context_forge/harness/user_simulator/simulator.py +307 -0
  32. context_forge/instrumentation/__init__.py +23 -0
  33. context_forge/instrumentation/base.py +307 -0
  34. context_forge/instrumentation/instrumentors/__init__.py +17 -0
  35. context_forge/instrumentation/instrumentors/langchain.py +671 -0
  36. context_forge/instrumentation/instrumentors/langgraph.py +534 -0
  37. context_forge/instrumentation/tracer.py +588 -0
  38. context_forge/py.typed +0 -0
  39. contextforge_eval-0.1.0.dist-info/METADATA +420 -0
  40. contextforge_eval-0.1.0.dist-info/RECORD +43 -0
  41. contextforge_eval-0.1.0.dist-info/WHEEL +5 -0
  42. contextforge_eval-0.1.0.dist-info/licenses/LICENSE +201 -0
  43. contextforge_eval-0.1.0.dist-info/top_level.txt +1 -0
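Of these 43 files, the largest instrumentor module, context_forge/instrumentation/instrumentors/langchain.py, is shown in full below. As orientation before the source, here is a minimal usage sketch based on the class docstrings in that file (the RunnableLambda stand-in chain and the agent_name/output_path values are illustrative, and the deep import path is used because a top-level re-export from context_forge is not shown in this diff):

    from langchain_core.runnables import RunnableLambda

    from context_forge.instrumentation.instrumentors.langchain import LangChainInstrumentor

    # A trivial runnable stands in for a real LangChain agent or chain
    chain = RunnableLambda(lambda x: {"output": x["input"].upper()})

    # Context-manager form from the class docstring: traces are saved on exit
    with LangChainInstrumentor(agent_name="demo", output_path="./traces"):
        chain.invoke({"input": "hello"})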
context_forge/instrumentation/instrumentors/langchain.py
@@ -0,0 +1,671 @@
+ """LangChain instrumentor for ContextForge.
+
+ This module implements:
+ - T039: LangChainInstrumentor
+ - T040: LangChain callback hooks for LLM, Tool, Retriever
+ - T041: Token usage capture from LangChain callbacks
+ """
+
+ import logging
+ import time
+ import uuid
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Optional, Sequence
+ from uuid import UUID
+
+ from context_forge.core.trace import (
+     FinalOutputStep,
+     LLMCallStep,
+     RetrievalStep,
+     ToolCallStep,
+     UserInputStep,
+ )
+ from context_forge.core.types import RetrievalResult
+ from context_forge.instrumentation.base import BaseInstrumentor, RedactionConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Try to import LangChain's BaseCallbackHandler for proper inheritance
+ try:
+     from langchain_core.callbacks import BaseCallbackHandler
+
+     _LANGCHAIN_AVAILABLE = True
+ except ImportError:
+     BaseCallbackHandler = object  # type: ignore
+     _LANGCHAIN_AVAILABLE = False
+
+ if TYPE_CHECKING:
+     from langchain_core.callbacks import BaseCallbackHandler
+
+
+ class LangChainInstrumentor(BaseInstrumentor):
+     """Auto-instrumentor for LangChain/LangGraph agents.
+
+     Provides one-line instrumentation for LangChain-based agents
+     by installing a global callback handler.
+
+     Usage:
+         LangChainInstrumentor().instrument()
+         # ... your LangChain code runs normally ...
+         # All LLM calls, tool calls, and retrievals are captured
+
+     Or with context manager:
+         with LangChainInstrumentor(output_path="./traces") as instrumentor:
+             chain.invoke({"input": "hello"})
+         # Traces saved automatically
+     """
+
+     def __init__(
+         self,
+         agent_name: str = "langchain_agent",
+         agent_version: Optional[str] = None,
+         output_path: Optional[str | Path] = None,
+         redaction_config: Optional[RedactionConfig] = None,
+     ):
+         super().__init__(
+             agent_name=agent_name,
+             agent_version=agent_version,
+             output_path=output_path,
+             redaction_config=redaction_config,
+         )
+         self._handler: Optional["ContextForgeCallbackHandler"] = None
+         self._original_handlers: list[Any] = []
+         self._framework_version: Optional[str] = None
+
+     @property
+     def framework(self) -> str:
+         return "langchain"
+
+     @property
+     def framework_version(self) -> Optional[str]:
+         if self._framework_version is None:
+             try:
+                 import langchain_core
+
+                 self._framework_version = getattr(langchain_core, "__version__", None)
+             except ImportError:
+                 pass
+         return self._framework_version
+
+     def _install_hooks(self) -> None:
+         """Install LangChain callback handler globally."""
+         if not _LANGCHAIN_AVAILABLE:
+             raise ImportError(
+                 "LangChain is required for LangChainInstrumentor. "
+                 "Install it with: pip install langchain-core"
+             )
+
+         # Create our callback handler
+         self._handler = ContextForgeCallbackHandler(instrumentor=self)
+         logger.debug("LangChain instrumentor handler ready")
+
+     def _remove_hooks(self) -> None:
+         """Remove LangChain callback handler."""
+         self._handler = None
+
+     def get_callback_handler(self) -> "ContextForgeCallbackHandler":
+         """Get the callback handler for explicit use.
+
+         Useful when you want to pass the handler explicitly:
+             handler = instrumentor.get_callback_handler()
+             chain.invoke(input, config={"callbacks": [handler]})
+
+         Returns:
+             The ContextForgeCallbackHandler instance
+         """
+         if self._handler is None:
+             if not _LANGCHAIN_AVAILABLE:
+                 raise ImportError(
+                     "LangChain is required for LangChainInstrumentor. "
+                     "Install it with: pip install langchain-core"
+                 )
+             self._handler = ContextForgeCallbackHandler(instrumentor=self)
+         return self._handler
+
+
+ class ContextForgeCallbackHandler(BaseCallbackHandler):  # type: ignore[misc]
+     """LangChain callback handler that captures trace events.
+
+     Implements LangChain's callback interface to capture LLM calls,
+     tool executions, and retrieval operations.
+
+     Inherits from langchain_core.callbacks.BaseCallbackHandler when
+     LangChain is installed, ensuring compatibility with LangChain/LangGraph.
+     """
+
+     def __init__(self, instrumentor: LangChainInstrumentor):
+         # Call parent __init__ if we're inheriting from BaseCallbackHandler
+         if _LANGCHAIN_AVAILABLE and hasattr(super(), "__init__"):
+             super().__init__()
+         self._instrumentor = instrumentor
+         self._run_id_to_step_id: dict[str, str] = {}
+         self._run_id_to_start_time: dict[str, float] = {}
+         self._run_id_to_input: dict[str, Any] = {}
+         self._run_id_to_model: dict[str, str] = {}
+         self._run_id_to_tool_name: dict[str, str] = {}
+         self._run_id_to_node: dict[str, str | None] = {}
+         self._parent_run_id_map: dict[str, str] = {}
+
+     def _extract_node_name(self, tags: list[str] | None, metadata: dict[str, Any] | None) -> str | None:
+         """Extract LangGraph node name from tags or metadata.
+
+         LangGraph tags often include patterns like:
+         - 'graph:step:recommend'
+         - 'langgraph_step:2'
+         - 'seq:step:1'
+
+         Metadata may include 'langgraph_node' or 'node' keys.
+         """
+         # Check metadata first
+         if metadata:
+             if "langgraph_node" in metadata:
+                 return metadata["langgraph_node"]
+             if "node" in metadata:
+                 return metadata["node"]
+
+         # Parse tags for node name
+         if tags:
+             for tag in tags:
+                 # Look for LangGraph step patterns
+                 if tag.startswith("graph:step:"):
+                     return tag.split(":")[-1]
+                 if tag.startswith("langgraph_step:"):
+                     return tag.split(":")[-1]
+             # Also check for plain node names in tags
+             # LangGraph sometimes includes node names directly
+             for tag in tags:
+                 if not tag.startswith(("seq:", "graph:", "langgraph")):
+                     # Could be a node name
+                     return tag
+
+         return None
+
+     def _get_step_id(self, run_id: UUID) -> str:
+         """Get or create step ID for a run ID."""
+         run_id_str = str(run_id)
+         if run_id_str not in self._run_id_to_step_id:
+             self._run_id_to_step_id[run_id_str] = str(uuid.uuid4())
+         return self._run_id_to_step_id[run_id_str]
+
+     def _get_parent_step_id(self, parent_run_id: Optional[UUID]) -> Optional[str]:
+         """Get parent step ID if parent run exists."""
+         if parent_run_id is None:
+             return None
+         return self._run_id_to_step_id.get(str(parent_run_id))
+
+     # LLM Callbacks
+     def _extract_model_name(self, serialized: dict[str, Any]) -> str:
+         """Extract model name from serialized LLM config."""
+         # Try kwargs first (most common)
+         kwargs = serialized.get("kwargs", {})
+         if "model_name" in kwargs:
+             return kwargs["model_name"]
+         if "model" in kwargs:
+             return kwargs["model"]
+
+         # Try id field (e.g., ["langchain", "chat_models", "openai", "ChatOpenAI"])
+         id_list = serialized.get("id", [])
+         if id_list:
+             # Fall back to the last id component (the class name, e.g. "ChatOpenAI")
+             return id_list[-1]
+
+         # Try name field
+         if "name" in serialized:
+             return serialized["name"]
+
+         return "unknown"
+
+     def on_llm_start(
+         self,
+         serialized: dict[str, Any],
+         prompts: list[str],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[list[str]] = None,
+         metadata: Optional[dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when LLM starts."""
+         run_id_str = str(run_id)
+         self._run_id_to_start_time[run_id_str] = time.perf_counter()
+         self._run_id_to_input[run_id_str] = prompts[0] if len(prompts) == 1 else prompts
+         self._run_id_to_model[run_id_str] = self._extract_model_name(serialized)
+         self._run_id_to_node[run_id_str] = self._extract_node_name(tags, metadata)
+
+         if parent_run_id:
+             self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+     def on_chat_model_start(
+         self,
+         serialized: dict[str, Any],
+         messages: list[list[Any]],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[list[str]] = None,
+         metadata: Optional[dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chat model starts."""
+         run_id_str = str(run_id)
+         self._run_id_to_start_time[run_id_str] = time.perf_counter()
+         self._run_id_to_model[run_id_str] = self._extract_model_name(serialized)
+         self._run_id_to_node[run_id_str] = self._extract_node_name(tags, metadata)
+
+         # Convert messages to serializable format
+         formatted_messages = []
+         for msg_list in messages:
+             for msg in msg_list:
+                 if hasattr(msg, "type") and hasattr(msg, "content"):
+                     formatted_messages.append({"role": msg.type, "content": msg.content})
+                 elif isinstance(msg, dict):
+                     formatted_messages.append(msg)
+
+         self._run_id_to_input[run_id_str] = formatted_messages
+
+         if parent_run_id:
+             self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+     def on_llm_end(
+         self,
+         response: Any,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when LLM ends."""
+         run_id_str = str(run_id)
+         step_id = self._get_step_id(run_id)
+
+         # Calculate latency
+         start_time = self._run_id_to_start_time.pop(run_id_str, None)
+         latency_ms = None
+         if start_time:
+             latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+         # Get input, model, and node
+         input_data = self._run_id_to_input.pop(run_id_str, "")
+         model = self._run_id_to_model.pop(run_id_str, "unknown")
+         node_name = self._run_id_to_node.pop(run_id_str, None)
+
+         # Try to get model from response if not captured from start
+         if model == "unknown" and hasattr(response, "llm_output") and response.llm_output:
+             model = response.llm_output.get("model_name", model)
+
+         # Extract output and token usage
+         output = ""
+         tokens_in = None
+         tokens_out = None
+         tokens_total = None
+
+         if hasattr(response, "generations") and response.generations:
+             for gen_list in response.generations:
+                 for gen in gen_list:
+                     # Try to extract text output
+                     if hasattr(gen, "text") and gen.text is not None:
+                         output = gen.text
+                     elif hasattr(gen, "message") and gen.message is not None:
+                         if hasattr(gen.message, "content") and gen.message.content is not None:
+                             output = gen.message.content
+                         else:
+                             output = str(gen.message)
+
+                     # Extract token usage from the message; usage_metadata is a
+                     # dict (LangChain's UsageMetadata TypedDict), so read keys
+                     # rather than attributes
+                     if hasattr(gen, "message") and gen.message is not None:
+                         usage = getattr(gen.message, "usage_metadata", None)
+                         if isinstance(usage, dict):
+                             tokens_in = usage.get("input_tokens", tokens_in)
+                             tokens_out = usage.get("output_tokens", tokens_out)
+                             tokens_total = usage.get("total_tokens", tokens_total)
+
+         # Also check llm_output for token usage
+         if hasattr(response, "llm_output") and response.llm_output:
+             token_usage = response.llm_output.get("token_usage", {})
+             if token_usage:
+                 tokens_in = tokens_in or token_usage.get("prompt_tokens")
+                 tokens_out = tokens_out or token_usage.get("completion_tokens")
+                 tokens_total = tokens_total or token_usage.get("total_tokens")
+
+         # Build metadata with node info
+         step_metadata = None
+         if node_name:
+             step_metadata = {"node": node_name}
+
+         # Create step
+         step = LLMCallStep(
+             step_id=step_id,
+             timestamp=datetime.now(timezone.utc),
+             parent_step_id=self._get_parent_step_id(parent_run_id),
+             metadata=step_metadata,
+             model=model,
+             input=input_data,
+             output=output,
+             tokens_in=tokens_in,
+             tokens_out=tokens_out,
+             tokens_total=tokens_total,
+             latency_ms=latency_ms,
+         )
+
+         trace = self._instrumentor._get_current_trace()
+         trace.add_step(step)
+
+     def on_llm_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when LLM errors."""
+         run_id_str = str(run_id)
+         self._run_id_to_start_time.pop(run_id_str, None)
+         self._run_id_to_input.pop(run_id_str, None)
+         self._run_id_to_model.pop(run_id_str, None)
+         self._run_id_to_node.pop(run_id_str, None)
+         logger.debug(f"LLM error: {error}")
+
+     # Tool Callbacks
+     def on_tool_start(
+         self,
+         serialized: dict[str, Any],
+         input_str: str,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[list[str]] = None,
+         metadata: Optional[dict[str, Any]] = None,
+         inputs: Optional[dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when tool starts."""
+         run_id_str = str(run_id)
+         self._run_id_to_start_time[run_id_str] = time.perf_counter()
+         self._run_id_to_input[run_id_str] = inputs or {"input": input_str}
+         self._run_id_to_node[run_id_str] = self._extract_node_name(tags, metadata)
+
+         # Extract tool name from serialized
+         tool_name = serialized.get("name", "unknown_tool")
+         if tool_name == "unknown_tool":
+             # Try id field (e.g., ["langchain", "tools", "MyTool"])
+             id_list = serialized.get("id", [])
+             if id_list:
+                 tool_name = id_list[-1]
+         self._run_id_to_tool_name[run_id_str] = tool_name
+
+         if parent_run_id:
+             self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+     def on_tool_end(
+         self,
+         output: Any,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when tool ends."""
+         run_id_str = str(run_id)
+         step_id = self._get_step_id(run_id)
+
+         # Calculate latency
+         start_time = self._run_id_to_start_time.pop(run_id_str, None)
+         latency_ms = None
+         if start_time:
+             latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+         # Get input/arguments, tool name, and node
+         arguments = self._run_id_to_input.pop(run_id_str, {})
+         tool_name = self._run_id_to_tool_name.pop(run_id_str, None)
+         node_name = self._run_id_to_node.pop(run_id_str, None)
+         if not tool_name:
+             tool_name = kwargs.get("name", "unknown_tool")
+
+         # Convert output to serializable format
+         result = output
+         if hasattr(output, "content"):
+             result = output.content
+         elif not isinstance(output, (str, int, float, bool, list, dict, type(None))):
+             result = str(output)
+
+         # Build metadata with node info
+         step_metadata = None
+         if node_name:
+             step_metadata = {"node": node_name}
+
+         step = ToolCallStep(
+             step_id=step_id,
+             timestamp=datetime.now(timezone.utc),
+             parent_step_id=self._get_parent_step_id(parent_run_id),
+             metadata=step_metadata,
+             tool_name=tool_name,
+             arguments=arguments if isinstance(arguments, dict) else {"input": arguments},
+             result=result,
+             latency_ms=latency_ms,
+             success=True,
+         )
+
+         trace = self._instrumentor._get_current_trace()
+         trace.add_step(step)
+
+     def on_tool_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when tool errors."""
+         run_id_str = str(run_id)
+         step_id = self._get_step_id(run_id)
+
+         start_time = self._run_id_to_start_time.pop(run_id_str, None)
+         latency_ms = None
+         if start_time:
+             latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+         arguments = self._run_id_to_input.pop(run_id_str, {})
+         tool_name = self._run_id_to_tool_name.pop(run_id_str, None)
+         node_name = self._run_id_to_node.pop(run_id_str, None)
+         if not tool_name:
+             tool_name = kwargs.get("name", "unknown_tool")
+
+         # Build metadata with node info
+         step_metadata = None
+         if node_name:
+             step_metadata = {"node": node_name}
+
+         step = ToolCallStep(
+             step_id=step_id,
+             timestamp=datetime.now(timezone.utc),
+             parent_step_id=self._get_parent_step_id(parent_run_id),
+             metadata=step_metadata,
+             tool_name=tool_name,
+             arguments=arguments if isinstance(arguments, dict) else {"input": arguments},
+             result=None,
+             latency_ms=latency_ms,
+             success=False,
+             error=str(error),
+         )
+
+         trace = self._instrumentor._get_current_trace()
+         trace.add_step(step)
+
+     # Retriever Callbacks
+     def on_retriever_start(
+         self,
+         serialized: dict[str, Any],
+         query: str,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[list[str]] = None,
+         metadata: Optional[dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when retriever starts."""
+         run_id_str = str(run_id)
+         self._run_id_to_start_time[run_id_str] = time.perf_counter()
+         self._run_id_to_input[run_id_str] = query
+
+         if parent_run_id:
+             self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+     def on_retriever_end(
+         self,
+         documents: Sequence[Any],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when retriever ends."""
+         run_id_str = str(run_id)
+         step_id = self._get_step_id(run_id)
+
+         start_time = self._run_id_to_start_time.pop(run_id_str, None)
+         latency_ms = None
+         if start_time:
+             latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+         query = self._run_id_to_input.pop(run_id_str, "")
+
+         # Convert documents to RetrievalResult
+         results = []
+         for doc in documents:
+             content = ""
+             doc_metadata: dict[str, Any] = {}
+             score = None
+
+             if hasattr(doc, "page_content"):
+                 content = doc.page_content
+             elif isinstance(doc, str):
+                 content = doc
+             else:
+                 content = str(doc)
+
+             if hasattr(doc, "metadata") and isinstance(doc.metadata, dict):
+                 # Copy before popping "score" so the caller's document is not mutated
+                 doc_metadata = dict(doc.metadata)
+                 score = doc_metadata.pop("score", None)
+
+             results.append(RetrievalResult(content=content, score=score, metadata=doc_metadata or None))
+
+         step = RetrievalStep(
+             step_id=step_id,
+             timestamp=datetime.now(timezone.utc),
+             parent_step_id=self._get_parent_step_id(parent_run_id),
+             query=query,
+             results=results,
+             match_count=len(results),
+             latency_ms=latency_ms,
+         )
+
+         trace = self._instrumentor._get_current_trace()
+         trace.add_step(step)
+
+     def on_retriever_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when retriever errors."""
+         run_id_str = str(run_id)
+         self._run_id_to_start_time.pop(run_id_str, None)
+         self._run_id_to_input.pop(run_id_str, None)
+         logger.debug(f"Retriever error: {error}")
+
+     # Chain callbacks (for user input / final output tracking)
+     def on_chain_start(
+         self,
+         serialized: dict[str, Any],
+         inputs: dict[str, Any],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[list[str]] = None,
+         metadata: Optional[dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chain starts."""
+         # Only record as user input if this is a top-level chain
+         if parent_run_id is None:
+             step_id = self._get_step_id(run_id)
+
+             # Extract user input from inputs
+             content = ""
+             if "input" in inputs:
+                 content = str(inputs["input"])
+             elif "question" in inputs:
+                 content = str(inputs["question"])
+             elif len(inputs) == 1:
+                 content = str(list(inputs.values())[0])
+             else:
+                 content = str(inputs)
+
+             step = UserInputStep(
+                 step_id=step_id,
+                 timestamp=datetime.now(timezone.utc),
+                 content=content,
+             )
+
+             trace = self._instrumentor._get_current_trace()
+             trace.add_step(step)
+
+     def on_chain_end(
+         self,
+         outputs: dict[str, Any],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chain ends."""
+         # Only record as final output if this is a top-level chain
+         if parent_run_id is None:
+             step_id = str(uuid.uuid4())
+
+             # Extract output
+             content: Any = ""
+             if "output" in outputs:
+                 content = outputs["output"]
+             elif "result" in outputs:
+                 content = outputs["result"]
+             elif "answer" in outputs:
+                 content = outputs["answer"]
+             elif len(outputs) == 1:
+                 content = list(outputs.values())[0]
+             else:
+                 content = outputs
+
+             step = FinalOutputStep(
+                 step_id=step_id,
+                 timestamp=datetime.now(timezone.utc),
+                 content=content,
+             )
+
+             trace = self._instrumentor._get_current_trace()
+             trace.add_step(step)
+
+             # Finalize the trace
+             self._instrumentor._finalize_current_trace()
+
+     def on_chain_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chain errors."""
+         logger.debug(f"Chain error: {error}")