contextforge-eval 0.1.0 (contextforge_eval-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- context_forge/__init__.py +95 -0
- context_forge/core/__init__.py +55 -0
- context_forge/core/trace.py +369 -0
- context_forge/core/types.py +121 -0
- context_forge/evaluation.py +267 -0
- context_forge/exceptions.py +56 -0
- context_forge/graders/__init__.py +44 -0
- context_forge/graders/base.py +264 -0
- context_forge/graders/deterministic/__init__.py +11 -0
- context_forge/graders/deterministic/memory_corruption.py +130 -0
- context_forge/graders/hybrid.py +190 -0
- context_forge/graders/judges/__init__.py +11 -0
- context_forge/graders/judges/backends/__init__.py +9 -0
- context_forge/graders/judges/backends/ollama.py +173 -0
- context_forge/graders/judges/base.py +158 -0
- context_forge/graders/judges/memory_hygiene_judge.py +332 -0
- context_forge/graders/judges/models.py +113 -0
- context_forge/harness/__init__.py +43 -0
- context_forge/harness/user_simulator/__init__.py +70 -0
- context_forge/harness/user_simulator/adapters/__init__.py +13 -0
- context_forge/harness/user_simulator/adapters/base.py +67 -0
- context_forge/harness/user_simulator/adapters/crewai.py +100 -0
- context_forge/harness/user_simulator/adapters/langgraph.py +157 -0
- context_forge/harness/user_simulator/adapters/pydanticai.py +105 -0
- context_forge/harness/user_simulator/llm/__init__.py +5 -0
- context_forge/harness/user_simulator/llm/ollama.py +119 -0
- context_forge/harness/user_simulator/models.py +103 -0
- context_forge/harness/user_simulator/persona.py +154 -0
- context_forge/harness/user_simulator/runner.py +342 -0
- context_forge/harness/user_simulator/scenario.py +95 -0
- context_forge/harness/user_simulator/simulator.py +307 -0
- context_forge/instrumentation/__init__.py +23 -0
- context_forge/instrumentation/base.py +307 -0
- context_forge/instrumentation/instrumentors/__init__.py +17 -0
- context_forge/instrumentation/instrumentors/langchain.py +671 -0
- context_forge/instrumentation/instrumentors/langgraph.py +534 -0
- context_forge/instrumentation/tracer.py +588 -0
- context_forge/py.typed +0 -0
- contextforge_eval-0.1.0.dist-info/METADATA +420 -0
- contextforge_eval-0.1.0.dist-info/RECORD +43 -0
- contextforge_eval-0.1.0.dist-info/WHEEL +5 -0
- contextforge_eval-0.1.0.dist-info/licenses/LICENSE +201 -0
- contextforge_eval-0.1.0.dist-info/top_level.txt +1 -0
context_forge/instrumentation/instrumentors/langchain.py
@@ -0,0 +1,671 @@
+"""LangChain instrumentor for ContextForge.
+
+This module implements:
+- T039: LangChainInstrumentor
+- T040: LangChain callback hooks for LLM, Tool, Retriever
+- T041: Token usage capture from LangChain callbacks
+"""
+
+import logging
+import time
+import uuid
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional, Sequence
+from uuid import UUID
+
+from context_forge.core.trace import (
+    FinalOutputStep,
+    LLMCallStep,
+    RetrievalStep,
+    ToolCallStep,
+    UserInputStep,
+)
+from context_forge.core.types import RetrievalResult
+from context_forge.instrumentation.base import BaseInstrumentor, RedactionConfig
+
+logger = logging.getLogger(__name__)
+
+# Try to import LangChain's BaseCallbackHandler for proper inheritance
+try:
+    from langchain_core.callbacks import BaseCallbackHandler
+
+    _LANGCHAIN_AVAILABLE = True
+except ImportError:
+    BaseCallbackHandler = object  # type: ignore
+    _LANGCHAIN_AVAILABLE = False
+
+if TYPE_CHECKING:
+    from langchain_core.callbacks import BaseCallbackHandler
+
+
+class LangChainInstrumentor(BaseInstrumentor):
+    """Auto-instrumentor for LangChain/LangGraph agents.
+
+    Provides one-line instrumentation for LangChain-based agents
+    by installing a global callback handler.
+
+    Usage:
+        LangChainInstrumentor().instrument()
+        # ... your LangChain code runs normally ...
+        # All LLM calls, tool calls, and retrievals are captured
+
+    Or with context manager:
+        with LangChainInstrumentor(output_path="./traces") as instrumentor:
+            chain.invoke({"input": "hello"})
+        # Traces saved automatically
+    """
+
+    def __init__(
+        self,
+        agent_name: str = "langchain_agent",
+        agent_version: Optional[str] = None,
+        output_path: Optional[str | Path] = None,
+        redaction_config: Optional[RedactionConfig] = None,
+    ):
+        super().__init__(
+            agent_name=agent_name,
+            agent_version=agent_version,
+            output_path=output_path,
+            redaction_config=redaction_config,
+        )
+        self._handler: Optional["ContextForgeCallbackHandler"] = None
+        self._original_handlers: list[Any] = []
+        self._framework_version: Optional[str] = None
+
+    @property
+    def framework(self) -> str:
+        return "langchain"
+
+    @property
+    def framework_version(self) -> Optional[str]:
+        if self._framework_version is None:
+            try:
+                import langchain_core
+
+                self._framework_version = getattr(langchain_core, "__version__", None)
+            except ImportError:
+                pass
+        return self._framework_version
+
+    def _install_hooks(self) -> None:
+        """Install LangChain callback handler globally."""
+        if not _LANGCHAIN_AVAILABLE:
+            raise ImportError(
+                "LangChain is required for LangChainInstrumentor. "
+                "Install it with: pip install langchain-core"
+            )
+
+        # Create our callback handler
+        self._handler = ContextForgeCallbackHandler(instrumentor=self)
+        logger.debug("LangChain instrumentor handler ready")
+
+    def _remove_hooks(self) -> None:
+        """Remove LangChain callback handler."""
+        self._handler = None
+
+    def get_callback_handler(self) -> "ContextForgeCallbackHandler":
+        """Get the callback handler for explicit use.
+
+        Useful when you want to pass the handler explicitly:
+            handler = instrumentor.get_callback_handler()
+            chain.invoke(input, config={"callbacks": [handler]})
+
+        Returns:
+            The ContextForgeCallbackHandler instance
+        """
+        if self._handler is None:
+            if not _LANGCHAIN_AVAILABLE:
+                raise ImportError(
+                    "LangChain is required for LangChainInstrumentor. "
+                    "Install it with: pip install langchain-core"
+                )
+            self._handler = ContextForgeCallbackHandler(instrumentor=self)
+        return self._handler
+
+
+class ContextForgeCallbackHandler(BaseCallbackHandler):  # type: ignore[misc]
+    """LangChain callback handler that captures trace events.
+
+    Implements LangChain's callback interface to capture LLM calls,
+    tool executions, and retrieval operations.
+
+    Inherits from langchain_core.callbacks.BaseCallbackHandler when
+    LangChain is installed, ensuring compatibility with LangChain/LangGraph.
+    """
+
+    def __init__(self, instrumentor: LangChainInstrumentor):
+        # Call parent __init__ if we're inheriting from BaseCallbackHandler
+        if _LANGCHAIN_AVAILABLE and hasattr(super(), "__init__"):
+            super().__init__()
+        self._instrumentor = instrumentor
+        self._run_id_to_step_id: dict[str, str] = {}
+        self._run_id_to_start_time: dict[str, float] = {}
+        self._run_id_to_input: dict[str, Any] = {}
+        self._run_id_to_model: dict[str, str] = {}
+        self._run_id_to_tool_name: dict[str, str] = {}
+        self._run_id_to_node: dict[str, str | None] = {}
+        self._parent_run_id_map: dict[str, str] = {}
+
+    def _extract_node_name(self, tags: list[str] | None, metadata: dict[str, Any] | None) -> str | None:
+        """Extract LangGraph node name from tags or metadata.
+
+        LangGraph tags often include patterns like:
+        - 'graph:step:recommend'
+        - 'langgraph:step:2'
+        - 'seq:step:1'
+
+        Metadata may include 'langgraph_node' or 'node' keys.
+        """
+        # Check metadata first
+        if metadata:
+            if "langgraph_node" in metadata:
+                return metadata["langgraph_node"]
+            if "node" in metadata:
+                return metadata["node"]
+
+        # Parse tags for node name
+        if tags:
+            for tag in tags:
+                # Look for LangGraph step patterns
+                if tag.startswith("graph:step:"):
+                    return tag.split(":")[-1]
+                if tag.startswith("langgraph_step:"):
+                    return tag.split(":")[-1]
+            # Also check for plain node names in tags
+            # LangGraph sometimes includes node names directly
+            for tag in tags:
+                if not tag.startswith(("seq:", "graph:", "langgraph")):
+                    # Could be a node name
+                    return tag
+
+        return None
+
+    def _get_step_id(self, run_id: UUID) -> str:
+        """Get or create step ID for a run ID."""
+        run_id_str = str(run_id)
+        if run_id_str not in self._run_id_to_step_id:
+            self._run_id_to_step_id[run_id_str] = str(uuid.uuid4())
+        return self._run_id_to_step_id[run_id_str]
+
+    def _get_parent_step_id(self, parent_run_id: Optional[UUID]) -> Optional[str]:
+        """Get parent step ID if parent run exists."""
+        if parent_run_id is None:
+            return None
+        return self._run_id_to_step_id.get(str(parent_run_id))
+
+    # LLM Callbacks
+    def _extract_model_name(self, serialized: dict[str, Any]) -> str:
+        """Extract model name from serialized LLM config."""
+        # Try kwargs first (most common)
+        kwargs = serialized.get("kwargs", {})
+        if "model_name" in kwargs:
+            return kwargs["model_name"]
+        if "model" in kwargs:
+            return kwargs["model"]
+
+        # Try id field (e.g., ["langchain", "chat_models", "openai", "ChatOpenAI"])
+        id_list = serialized.get("id", [])
+        if id_list and len(id_list) > 0:
+            # Use the last non-class identifier
+            return id_list[-1]
+
+        # Try name field
+        if "name" in serialized:
+            return serialized["name"]
+
+        return "unknown"
+
+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        prompts: list[str],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when LLM starts."""
+        run_id_str = str(run_id)
+        self._run_id_to_start_time[run_id_str] = time.perf_counter()
+        self._run_id_to_input[run_id_str] = prompts[0] if len(prompts) == 1 else prompts
+        self._run_id_to_model[run_id_str] = self._extract_model_name(serialized)
+        self._run_id_to_node[run_id_str] = self._extract_node_name(tags, metadata)
+
+        if parent_run_id:
+            self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+    def on_chat_model_start(
+        self,
+        serialized: dict[str, Any],
+        messages: list[list[Any]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when chat model starts."""
+        run_id_str = str(run_id)
+        self._run_id_to_start_time[run_id_str] = time.perf_counter()
+        self._run_id_to_model[run_id_str] = self._extract_model_name(serialized)
+        self._run_id_to_node[run_id_str] = self._extract_node_name(tags, metadata)
+
+        # Convert messages to serializable format
+        formatted_messages = []
+        for msg_list in messages:
+            for msg in msg_list:
+                if hasattr(msg, "type") and hasattr(msg, "content"):
+                    formatted_messages.append({"role": msg.type, "content": msg.content})
+                elif isinstance(msg, dict):
+                    formatted_messages.append(msg)
+
+        self._run_id_to_input[run_id_str] = formatted_messages
+
+        if parent_run_id:
+            self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+    def on_llm_end(
+        self,
+        response: Any,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when LLM ends."""
+        run_id_str = str(run_id)
+        step_id = self._get_step_id(run_id)
+
+        # Calculate latency
+        start_time = self._run_id_to_start_time.pop(run_id_str, None)
+        latency_ms = None
+        if start_time:
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+        # Get input, model, and node
+        input_data = self._run_id_to_input.pop(run_id_str, "")
+        model = self._run_id_to_model.pop(run_id_str, "unknown")
+        node_name = self._run_id_to_node.pop(run_id_str, None)
+
+        # Try to get model from response if not captured from start
+        if model == "unknown" and hasattr(response, "llm_output") and response.llm_output:
+            model = response.llm_output.get("model_name", model)
+
+        # Extract output and token usage
+        output = ""
+        tokens_in = None
+        tokens_out = None
+        tokens_total = None
+
+        if hasattr(response, "generations") and response.generations:
+            for gen_list in response.generations:
+                for gen in gen_list:
+                    # Try to extract text output
+                    if hasattr(gen, "text") and gen.text is not None:
+                        output = gen.text
+                    elif hasattr(gen, "message") and gen.message is not None:
+                        if hasattr(gen.message, "content") and gen.message.content is not None:
+                            output = gen.message.content
+                        else:
+                            output = str(gen.message)
+
+                    # Extract token usage from message
+                    if hasattr(gen, "message") and gen.message is not None:
+                        if hasattr(gen.message, "usage_metadata") and gen.message.usage_metadata is not None:
+                            usage = gen.message.usage_metadata
+                            if hasattr(usage, "input_tokens"):
+                                tokens_in = usage.input_tokens
+                            if hasattr(usage, "output_tokens"):
+                                tokens_out = usage.output_tokens
+                            if hasattr(usage, "total_tokens"):
+                                tokens_total = usage.total_tokens
+
+        # Also check llm_output for token usage
+        if hasattr(response, "llm_output") and response.llm_output:
+            token_usage = response.llm_output.get("token_usage", {})
+            if token_usage:
+                tokens_in = tokens_in or token_usage.get("prompt_tokens")
+                tokens_out = tokens_out or token_usage.get("completion_tokens")
+                tokens_total = tokens_total or token_usage.get("total_tokens")
+
+        # Build metadata with node info
+        step_metadata = None
+        if node_name:
+            step_metadata = {"node": node_name}
+
+        # Create step
+        step = LLMCallStep(
+            step_id=step_id,
+            timestamp=datetime.now(timezone.utc),
+            parent_step_id=self._get_parent_step_id(parent_run_id),
+            metadata=step_metadata,
+            model=model,
+            input=input_data,
+            output=output,
+            tokens_in=tokens_in,
+            tokens_out=tokens_out,
+            tokens_total=tokens_total,
+            latency_ms=latency_ms,
+        )
+
+        trace = self._instrumentor._get_current_trace()
+        trace.add_step(step)
+
+    def on_llm_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when LLM errors."""
+        run_id_str = str(run_id)
+        self._run_id_to_start_time.pop(run_id_str, None)
+        self._run_id_to_input.pop(run_id_str, None)
+        self._run_id_to_model.pop(run_id_str, None)
+        self._run_id_to_node.pop(run_id_str, None)
+        logger.debug(f"LLM error: {error}")
+
+    # Tool Callbacks
+    def on_tool_start(
+        self,
+        serialized: dict[str, Any],
+        input_str: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when tool starts."""
+        run_id_str = str(run_id)
+        self._run_id_to_start_time[run_id_str] = time.perf_counter()
+        self._run_id_to_input[run_id_str] = inputs or {"input": input_str}
+        self._run_id_to_node[run_id_str] = self._extract_node_name(tags, metadata)
+
+        # Extract tool name from serialized
+        tool_name = serialized.get("name", "unknown_tool")
+        if tool_name == "unknown_tool":
+            # Try id field (e.g., ["langchain", "tools", "MyTool"])
+            id_list = serialized.get("id", [])
+            if id_list:
+                tool_name = id_list[-1]
+        self._run_id_to_tool_name[run_id_str] = tool_name
+
+        if parent_run_id:
+            self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+    def on_tool_end(
+        self,
+        output: Any,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when tool ends."""
+        run_id_str = str(run_id)
+        step_id = self._get_step_id(run_id)
+
+        # Calculate latency
+        start_time = self._run_id_to_start_time.pop(run_id_str, None)
+        latency_ms = None
+        if start_time:
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+        # Get input/arguments, tool name, and node
+        arguments = self._run_id_to_input.pop(run_id_str, {})
+        tool_name = self._run_id_to_tool_name.pop(run_id_str, None)
+        node_name = self._run_id_to_node.pop(run_id_str, None)
+        if not tool_name:
+            tool_name = kwargs.get("name", "unknown_tool")
+
+        # Convert output to serializable format
+        result = output
+        if hasattr(output, "content"):
+            result = output.content
+        elif not isinstance(output, (str, int, float, bool, list, dict, type(None))):
+            result = str(output)
+
+        # Build metadata with node info
+        step_metadata = None
+        if node_name:
+            step_metadata = {"node": node_name}
+
+        step = ToolCallStep(
+            step_id=step_id,
+            timestamp=datetime.now(timezone.utc),
+            parent_step_id=self._get_parent_step_id(parent_run_id),
+            metadata=step_metadata,
+            tool_name=tool_name,
+            arguments=arguments if isinstance(arguments, dict) else {"input": arguments},
+            result=result,
+            latency_ms=latency_ms,
+            success=True,
+        )
+
+        trace = self._instrumentor._get_current_trace()
+        trace.add_step(step)
+
+    def on_tool_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when tool errors."""
+        run_id_str = str(run_id)
+        step_id = self._get_step_id(run_id)
+
+        start_time = self._run_id_to_start_time.pop(run_id_str, None)
+        latency_ms = None
+        if start_time:
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+        arguments = self._run_id_to_input.pop(run_id_str, {})
+        tool_name = self._run_id_to_tool_name.pop(run_id_str, None)
+        node_name = self._run_id_to_node.pop(run_id_str, None)
+        if not tool_name:
+            tool_name = kwargs.get("name", "unknown_tool")
+
+        # Build metadata with node info
+        step_metadata = None
+        if node_name:
+            step_metadata = {"node": node_name}
+
+        step = ToolCallStep(
+            step_id=step_id,
+            timestamp=datetime.now(timezone.utc),
+            parent_step_id=self._get_parent_step_id(parent_run_id),
+            metadata=step_metadata,
+            tool_name=tool_name,
+            arguments=arguments if isinstance(arguments, dict) else {"input": arguments},
+            result=None,
+            latency_ms=latency_ms,
+            success=False,
+            error=str(error),
+        )
+
+        trace = self._instrumentor._get_current_trace()
+        trace.add_step(step)
+
+    # Retriever Callbacks
+    def on_retriever_start(
+        self,
+        serialized: dict[str, Any],
+        query: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when retriever starts."""
+        run_id_str = str(run_id)
+        self._run_id_to_start_time[run_id_str] = time.perf_counter()
+        self._run_id_to_input[run_id_str] = query
+
+        if parent_run_id:
+            self._parent_run_id_map[run_id_str] = str(parent_run_id)
+
+    def on_retriever_end(
+        self,
+        documents: Sequence[Any],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when retriever ends."""
+        run_id_str = str(run_id)
+        step_id = self._get_step_id(run_id)
+
+        start_time = self._run_id_to_start_time.pop(run_id_str, None)
+        latency_ms = None
+        if start_time:
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+        query = self._run_id_to_input.pop(run_id_str, "")
+
+        # Convert documents to RetrievalResult
+        results = []
+        for doc in documents:
+            content = ""
+            doc_metadata: dict[str, Any] = {}
+            score = None
+
+            if hasattr(doc, "page_content"):
+                content = doc.page_content
+            elif isinstance(doc, str):
+                content = doc
+            else:
+                content = str(doc)
+
+            if hasattr(doc, "metadata"):
+                doc_metadata = doc.metadata
+                score = doc_metadata.pop("score", None) if isinstance(doc_metadata, dict) else None
+
+            results.append(RetrievalResult(content=content, score=score, metadata=doc_metadata or None))
+
+        step = RetrievalStep(
+            step_id=step_id,
+            timestamp=datetime.now(timezone.utc),
+            parent_step_id=self._get_parent_step_id(parent_run_id),
+            query=query,
+            results=results,
+            match_count=len(results),
+            latency_ms=latency_ms,
+        )
+
+        trace = self._instrumentor._get_current_trace()
+        trace.add_step(step)
+
+    def on_retriever_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when retriever errors."""
+        run_id_str = str(run_id)
+        self._run_id_to_start_time.pop(run_id_str, None)
+        self._run_id_to_input.pop(run_id_str, None)
+        logger.debug(f"Retriever error: {error}")
+
+    # Chain callbacks (for user input / final output tracking)
+    def on_chain_start(
+        self,
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when chain starts."""
+        # Only record as user input if this is a top-level chain
+        if parent_run_id is None:
+            step_id = self._get_step_id(run_id)
+
+            # Extract user input from inputs
+            content = ""
+            if "input" in inputs:
+                content = str(inputs["input"])
+            elif "question" in inputs:
+                content = str(inputs["question"])
+            elif len(inputs) == 1:
+                content = str(list(inputs.values())[0])
+            else:
+                content = str(inputs)
+
+            step = UserInputStep(
+                step_id=step_id,
+                timestamp=datetime.now(timezone.utc),
+                content=content,
+            )
+
+            trace = self._instrumentor._get_current_trace()
+            trace.add_step(step)
+
+    def on_chain_end(
+        self,
+        outputs: dict[str, Any],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when chain ends."""
+        # Only record as final output if this is a top-level chain
+        if parent_run_id is None:
+            step_id = str(uuid.uuid4())
+
+            # Extract output
+            content: Any = ""
+            if "output" in outputs:
+                content = outputs["output"]
+            elif "result" in outputs:
+                content = outputs["result"]
+            elif "answer" in outputs:
+                content = outputs["answer"]
+            elif len(outputs) == 1:
+                content = list(outputs.values())[0]
+            else:
+                content = outputs
+
+            step = FinalOutputStep(
+                step_id=step_id,
+                timestamp=datetime.now(timezone.utc),
+                content=content,
+            )
+
+            trace = self._instrumentor._get_current_trace()
+            trace.add_step(step)
+
+            # Finalize the trace
+            self._instrumentor._finalize_current_trace()
+
+    def on_chain_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Called when chain errors."""
+        logger.debug(f"Chain error: {error}")