judgeval 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +139 -12
- judgeval/api/__init__.py +501 -0
- judgeval/api/api_types.py +344 -0
- judgeval/cli.py +2 -4
- judgeval/constants.py +10 -26
- judgeval/data/evaluation_run.py +49 -26
- judgeval/data/example.py +2 -2
- judgeval/data/judgment_types.py +266 -82
- judgeval/data/result.py +4 -5
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +2 -2
- judgeval/data/trace.py +7 -50
- judgeval/data/trace_run.py +7 -4
- judgeval/{dataset.py → dataset/__init__.py} +43 -28
- judgeval/env.py +67 -0
- judgeval/{run_evaluation.py → evaluation/__init__.py} +29 -95
- judgeval/exceptions.py +27 -0
- judgeval/integrations/langgraph/__init__.py +788 -0
- judgeval/judges/__init__.py +2 -2
- judgeval/judges/litellm_judge.py +75 -15
- judgeval/judges/together_judge.py +86 -18
- judgeval/judges/utils.py +7 -21
- judgeval/{common/logger.py → logger.py} +8 -6
- judgeval/scorers/__init__.py +0 -4
- judgeval/scorers/agent_scorer.py +3 -7
- judgeval/scorers/api_scorer.py +8 -13
- judgeval/scorers/base_scorer.py +52 -32
- judgeval/scorers/example_scorer.py +1 -3
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -14
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +45 -20
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +2 -2
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +3 -3
- judgeval/scorers/score.py +21 -31
- judgeval/scorers/trace_api_scorer.py +5 -0
- judgeval/scorers/utils.py +1 -103
- judgeval/tracer/__init__.py +1075 -2
- judgeval/tracer/constants.py +1 -0
- judgeval/tracer/exporters/__init__.py +37 -0
- judgeval/tracer/exporters/s3.py +119 -0
- judgeval/tracer/exporters/store.py +43 -0
- judgeval/tracer/exporters/utils.py +32 -0
- judgeval/tracer/keys.py +67 -0
- judgeval/tracer/llm/__init__.py +1233 -0
- judgeval/{common/tracer → tracer/llm}/providers.py +5 -10
- judgeval/{local_eval_queue.py → tracer/local_eval_queue.py} +15 -10
- judgeval/tracer/managers.py +188 -0
- judgeval/tracer/processors/__init__.py +181 -0
- judgeval/tracer/utils.py +20 -0
- judgeval/trainer/__init__.py +5 -0
- judgeval/{common/trainer → trainer}/config.py +12 -9
- judgeval/{common/trainer → trainer}/console.py +2 -9
- judgeval/{common/trainer → trainer}/trainable_model.py +12 -7
- judgeval/{common/trainer → trainer}/trainer.py +119 -17
- judgeval/utils/async_utils.py +2 -3
- judgeval/utils/decorators.py +24 -0
- judgeval/utils/file_utils.py +37 -4
- judgeval/utils/guards.py +32 -0
- judgeval/utils/meta.py +14 -0
- judgeval/{common/api/json_encoder.py → utils/serialize.py} +7 -1
- judgeval/utils/testing.py +88 -0
- judgeval/utils/url.py +10 -0
- judgeval/{version_check.py → utils/version_check.py} +3 -3
- judgeval/version.py +5 -0
- judgeval/warnings.py +4 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/METADATA +12 -14
- judgeval-0.9.0.dist-info/RECORD +80 -0
- judgeval/clients.py +0 -35
- judgeval/common/__init__.py +0 -13
- judgeval/common/api/__init__.py +0 -3
- judgeval/common/api/api.py +0 -375
- judgeval/common/api/constants.py +0 -186
- judgeval/common/exceptions.py +0 -27
- judgeval/common/storage/__init__.py +0 -6
- judgeval/common/storage/s3_storage.py +0 -97
- judgeval/common/tracer/__init__.py +0 -31
- judgeval/common/tracer/constants.py +0 -22
- judgeval/common/tracer/core.py +0 -2427
- judgeval/common/tracer/otel_exporter.py +0 -108
- judgeval/common/tracer/otel_span_processor.py +0 -188
- judgeval/common/tracer/span_processor.py +0 -37
- judgeval/common/tracer/span_transformer.py +0 -207
- judgeval/common/tracer/trace_manager.py +0 -101
- judgeval/common/trainer/__init__.py +0 -5
- judgeval/common/utils.py +0 -948
- judgeval/integrations/langgraph.py +0 -844
- judgeval/judges/mixture_of_judges.py +0 -287
- judgeval/judgment_client.py +0 -267
- judgeval/rules.py +0 -521
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
- judgeval/utils/alerts.py +0 -93
- judgeval/utils/requests.py +0 -50
- judgeval-0.7.1.dist-info/RECORD +0 -82
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/WHEEL +0 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/entry_points.txt +0 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,788 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import time
|
4
|
+
import uuid
|
5
|
+
from typing import Any, Dict, List, Optional, Sequence, Set, Type
|
6
|
+
from uuid import UUID
|
7
|
+
|
8
|
+
try:
|
9
|
+
from langchain_core.callbacks import BaseCallbackHandler
|
10
|
+
from langchain_core.agents import AgentAction, AgentFinish
|
11
|
+
from langchain_core.outputs import LLMResult, ChatGeneration
|
12
|
+
from langchain_core.messages import (
|
13
|
+
AIMessage,
|
14
|
+
BaseMessage,
|
15
|
+
ChatMessage,
|
16
|
+
FunctionMessage,
|
17
|
+
HumanMessage,
|
18
|
+
SystemMessage,
|
19
|
+
ToolMessage,
|
20
|
+
)
|
21
|
+
from langchain_core.documents import Document
|
22
|
+
except ImportError as e:
|
23
|
+
raise ImportError(
|
24
|
+
"Judgeval's langgraph integration requires langchain to be installed. Please install it with `pip install judgeval[langchain]`"
|
25
|
+
) from e
|
26
|
+
|
27
|
+
from judgeval.tracer import Tracer
|
28
|
+
from judgeval.tracer.keys import AttributeKeys
|
29
|
+
from judgeval.tracer.managers import sync_span_context
|
30
|
+
from judgeval.utils.serialize import safe_serialize
|
31
|
+
from judgeval.logger import judgeval_logger
|
32
|
+
from opentelemetry.trace import Status, StatusCode, Span
|
33
|
+
|
34
|
+
# Control flow exception types that should not be treated as errors
|
35
|
+
CONTROL_FLOW_EXCEPTION_TYPES: Set[Type[BaseException]] = set()
|
36
|
+
|
37
|
+
try:
|
38
|
+
from langgraph.errors import GraphBubbleUp
|
39
|
+
|
40
|
+
CONTROL_FLOW_EXCEPTION_TYPES.add(GraphBubbleUp)
|
41
|
+
except ImportError:
|
42
|
+
pass
|
43
|
+
|
44
|
+
LANGSMITH_TAG_HIDDEN: str = "langsmith:hidden"
|
45
|
+
|
46
|
+
|
47
|
+
class JudgevalCallbackHandler(BaseCallbackHandler):
    """
    LangGraph/LangChain Callback Handler that creates OpenTelemetry spans
    using the Judgeval tracer framework.

    This handler tracks the execution of chains, tools, LLMs, and other components
    in a LangGraph/LangChain application, creating proper span hierarchies for monitoring.
    """

    # Prevent LangChain serialization issues
    lc_serializable = False
    lc_kwargs: dict = {}

    def __init__(self, tracer: Optional[Tracer] = None):
        """
        Initialize the callback handler.

        Args:
            tracer: Optional Tracer instance. If not provided, will try to use an active tracer.
        """
        # Initialize tracking state unconditionally so callback methods can
        # safely reference these attributes even when no tracer is available.
        # (Previously these were skipped on the no-tracer path, causing
        # AttributeError in every subsequent callback.)

        # Track spans by run_id for proper hierarchy
        self.spans: Dict[UUID, Span] = {}
        self.span_start_times: Dict[UUID, float] = {}
        self.run_id_to_span_id: Dict[UUID, str] = {}
        self.span_id_to_depth: Dict[str, int] = {}
        self.root_run_id: Optional[UUID] = None

        # Track execution for debugging
        self.executed_nodes: List[str] = []
        self.executed_tools: List[str] = []
        self.executed_node_tools: List[Dict[str, Any]] = []

        self.tracer = tracer
        if self.tracer is None:
            # Try to get an active tracer
            if Tracer._active_tracers:
                self.tracer = next(iter(Tracer._active_tracers))
            else:
                judgeval_logger.warning(
                    "No tracer provided and no active tracers found. "
                    "Callback handler will not create spans."
                )

    def reset(self):
        """Reset handler state for reuse across multiple executions."""
        self.spans.clear()
        self.span_start_times.clear()
        # Also clear span-hierarchy bookkeeping; leaving it populated would
        # leak stale parent/depth info into the next execution.
        self.run_id_to_span_id.clear()
        self.span_id_to_depth.clear()
        self.root_run_id = None
        self.executed_nodes.clear()
        self.executed_tools.clear()
        self.executed_node_tools.clear()

    def _get_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str:
        """Extract the name of the operation from serialized data or kwargs."""
        # Explicit name passed by the caller wins.
        if "name" in kwargs and kwargs["name"] is not None:
            return str(kwargs["name"])

        if serialized is None:
            return "<unknown>"

        try:
            return str(serialized["name"])
        except (KeyError, TypeError):
            pass

        # Fall back to the last segment of the serialized class path.
        try:
            return str(serialized["id"][-1])
        except (KeyError, TypeError):
            pass

        return "<unknown>"

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        """Convert a LangChain message to a dictionary for storage."""
        if isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolMessage):
            message_dict = {
                "role": "tool",
                "content": message.content,
                "tool_call_id": message.tool_call_id,
            }
        elif isinstance(message, FunctionMessage):
            message_dict = {"role": "function", "content": message.content}
        elif isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        else:
            message_dict = {"role": "unknown", "content": str(message.content)}

        if hasattr(message, "additional_kwargs") and message.additional_kwargs:
            message_dict["additional_kwargs"] = str(message.additional_kwargs)

        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """Convert a list of LangChain messages to dictionaries."""
        return [self._convert_message_to_dict(m) for m in messages]

    def _join_tags_and_metadata(
        self,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Join tags and metadata into a single dictionary.

        Returns None when both inputs are empty so callers can skip the
        attribute entirely.
        """
        final_dict = {}
        if tags is not None and len(tags) > 0:
            final_dict["tags"] = tags
        if metadata is not None:
            final_dict.update(metadata)
        return final_dict if final_dict else None

    def _start_span(
        self,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        name: str,
        span_type: str,
        inputs: Any = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **extra_attributes: Any,
    ) -> None:
        """Start a new span for the given run.

        Records the span plus its start time and depth so later callbacks can
        attach attributes and close it by run_id.
        """
        if not self.tracer:
            return

        # Skip internal spans (dunder-named LangChain internals)
        if name.startswith("__") and name.endswith("__"):
            return

        try:
            # Determine if this is a root span
            is_root = parent_run_id is None
            if is_root:
                self.root_run_id = run_id

            # Calculate depth for proper hierarchy
            current_depth = 0
            if parent_run_id and parent_run_id in self.run_id_to_span_id:
                parent_span_id = self.run_id_to_span_id[parent_run_id]
                current_depth = self.span_id_to_depth.get(parent_span_id, 0) + 1

            # Create span attributes
            attributes = {
                AttributeKeys.JUDGMENT_SPAN_KIND.value: span_type,
            }

            # Add metadata and tags
            combined_metadata = self._join_tags_and_metadata(tags, metadata)
            if combined_metadata:
                metadata_str = safe_serialize(combined_metadata)
                attributes["metadata"] = metadata_str

            # Add extra attributes (stringified for OTel compatibility)
            for key, value in extra_attributes.items():
                if value is not None:
                    attributes[str(key)] = str(value)

            # Create span using the tracer's context manager for proper hierarchy
            with sync_span_context(self.tracer, name, attributes) as span:
                # Set input data if provided
                if inputs is not None:
                    span.set_attribute(
                        AttributeKeys.JUDGMENT_INPUT.value, safe_serialize(inputs)
                    )

                # Store span information for tracking
                span_id = (
                    str(span.get_span_context().span_id)
                    if span.get_span_context()
                    else str(uuid.uuid4())
                )
                self.spans[run_id] = span
                self.span_start_times[run_id] = time.time()
                self.run_id_to_span_id[run_id] = span_id
                self.span_id_to_depth[span_id] = current_depth

        except Exception as e:
            judgeval_logger.exception(f"Error starting span for {name}: {e}")

    def _end_span(
        self,
        run_id: UUID,
        outputs: Any = None,
        error: Optional[BaseException] = None,
        **extra_attributes: Any,
    ) -> None:
        """End the span for the given run.

        Sets outputs/error status on the tracked span and removes all
        bookkeeping for run_id, regardless of whether attribute-setting
        succeeds.
        """
        if run_id not in self.spans:
            return

        try:
            span = self.spans[run_id]

            # Set output data if provided
            if outputs is not None:
                span.set_attribute(
                    AttributeKeys.JUDGMENT_OUTPUT.value, safe_serialize(outputs)
                )

            # Set additional attributes
            for key, value in extra_attributes.items():
                if value is not None:
                    span.set_attribute(str(key), str(value))

            # Handle errors
            if error is not None:
                # Check if this is a control flow exception (e.g. langgraph
                # GraphBubbleUp) -- those are expected and not real failures.
                is_control_flow = any(
                    isinstance(error, t) for t in CONTROL_FLOW_EXCEPTION_TYPES
                )
                if not is_control_flow:
                    span.record_exception(error)
                    span.set_status(Status(StatusCode.ERROR, str(error)))
                # Control flow exceptions don't set error status
            else:
                span.set_status(Status(StatusCode.OK))

            # Note: The span will be ended automatically by the context manager

        except Exception as e:
            judgeval_logger.exception(f"Error ending span for run_id {run_id}: {e}")
        finally:
            # Cleanup tracking data
            if run_id in self.spans:
                del self.spans[run_id]
            if run_id in self.span_start_times:
                del self.span_start_times[run_id]
            if run_id in self.run_id_to_span_id:
                span_id = self.run_id_to_span_id[run_id]
                del self.run_id_to_span_id[run_id]
                if span_id in self.span_id_to_depth:
                    del self.span_id_to_depth[span_id]

            # Check if this is the root run ending
            if run_id == self.root_run_id:
                self.root_run_id = None

    def _log_debug_event(
        self,
        event_name: str,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Log debug information about callback events."""
        judgeval_logger.debug(
            f"Event: {event_name}, run_id: {str(run_id)[:8]}, "
            f"parent_run_id: {str(parent_run_id)[:8] if parent_run_id else None}"
        )

    # Chain callbacks
    def on_chain_start(
        self,
        serialized: Optional[Dict[str, Any]],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain starts running."""
        try:
            self._log_debug_event(
                "on_chain_start", run_id, parent_run_id, inputs=inputs
            )

            name = self._get_run_name(serialized, **kwargs)

            # Check for LangGraph node
            node_name = metadata.get("langgraph_node") if metadata else None
            if node_name:
                name = node_name
                if name not in self.executed_nodes:
                    self.executed_nodes.append(name)

            # Determine if this is a root LangGraph execution
            is_langgraph_root = (
                kwargs.get("name") == "LangGraph" and parent_run_id is None
            )
            if is_langgraph_root:
                name = "LangGraph"

            # Hidden LangSmith spans are demoted to DEBUG level.
            span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None

            self._start_span(
                run_id=run_id,
                parent_run_id=parent_run_id,
                name=name,
                span_type="chain",
                inputs=inputs,
                tags=tags,
                metadata=metadata,
                level=span_level,
                serialized=safe_serialize(serialized) if serialized else None,
            )
        except Exception as e:
            judgeval_logger.exception(f"Error in on_chain_start: {e}")

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain ends successfully."""
        try:
            self._log_debug_event(
                "on_chain_end", run_id, parent_run_id, outputs=outputs
            )
            self._end_span(run_id=run_id, outputs=outputs)
        except Exception as e:
            judgeval_logger.exception(f"Error in on_chain_end: {e}")

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when a chain encounters an error."""
        try:
            self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
            self._end_span(run_id=run_id, error=error)
        except Exception as e:
            judgeval_logger.exception(f"Error in on_chain_error: {e}")

    # LLM callbacks
    def on_llm_start(
        self,
        serialized: Optional[Dict[str, Any]],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM starts generating."""
        try:
            self._log_debug_event(
                "on_llm_start", run_id, parent_run_id, prompts=prompts
            )

            name = self._get_run_name(serialized, **kwargs)
            model_name = self._extract_model_name(serialized, kwargs)

            # Unwrap a single prompt so it isn't stored as a one-element list.
            prompt_data = prompts[0] if len(prompts) == 1 else prompts

            self._start_span(
                run_id=run_id,
                parent_run_id=parent_run_id,
                name=name,
                span_type="llm",
                inputs=prompt_data,
                tags=tags,
                metadata=metadata,
                model=model_name,
                serialized=safe_serialize(serialized) if serialized else None,
            )

            # Set GenAI specific attributes
            if run_id in self.spans:
                span = self.spans[run_id]
                if model_name:
                    span.set_attribute(AttributeKeys.GEN_AI_REQUEST_MODEL, model_name)
                span.set_attribute(
                    AttributeKeys.GEN_AI_PROMPT, safe_serialize(prompt_data)
                )

                # Set model parameters if available
                invocation_params = kwargs.get("invocation_params", {})
                if "temperature" in invocation_params:
                    span.set_attribute(
                        AttributeKeys.GEN_AI_REQUEST_TEMPERATURE,
                        float(invocation_params["temperature"]),
                    )
                if "max_tokens" in invocation_params:
                    span.set_attribute(
                        AttributeKeys.GEN_AI_REQUEST_MAX_TOKENS,
                        int(invocation_params["max_tokens"]),
                    )

        except Exception as e:
            judgeval_logger.exception(f"Error in on_llm_start: {e}")

    def on_chat_model_start(
        self,
        serialized: Optional[Dict[str, Any]],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chat model starts generating."""
        try:
            self._log_debug_event(
                "on_chat_model_start", run_id, parent_run_id, messages=messages
            )

            name = self._get_run_name(serialized, **kwargs)
            model_name = self._extract_model_name(serialized, kwargs)

            # Flatten the nested message batches into one list of dicts.
            flattened_messages = []
            for message_list in messages:
                flattened_messages.extend(self._create_message_dicts(message_list))

            self._start_span(
                run_id=run_id,
                parent_run_id=parent_run_id,
                name=name,
                span_type="llm",
                inputs=flattened_messages,
                tags=tags,
                metadata=metadata,
                model=model_name,
                serialized=safe_serialize(serialized) if serialized else None,
            )

            # Set GenAI specific attributes
            if run_id in self.spans:
                span = self.spans[run_id]
                if model_name:
                    span.set_attribute(AttributeKeys.GEN_AI_REQUEST_MODEL, model_name)
                span.set_attribute(
                    AttributeKeys.GEN_AI_PROMPT, safe_serialize(flattened_messages)
                )

        except Exception as e:
            judgeval_logger.exception(f"Error in on_chat_model_start: {e}")

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM finishes generating."""
        try:
            self._log_debug_event(
                "on_llm_end", run_id, parent_run_id, response=response
            )

            # Extract response content from the last generation.
            if response.generations:
                last_generation = response.generations[-1][-1]
                if (
                    isinstance(last_generation, ChatGeneration)
                    and last_generation.message
                ):
                    output = self._convert_message_to_dict(last_generation.message)
                else:
                    output = (
                        last_generation.text
                        if hasattr(last_generation, "text")
                        else str(last_generation)
                    )
            else:
                output = ""

            # Extract usage information. Providers typically populate
            # llm_output["token_usage"] with a plain dict, but some return an
            # object exposing attributes -- support both shapes. (The previous
            # hasattr-only check silently skipped dict-shaped usage.)
            usage_attrs = {}
            if response.llm_output and "token_usage" in response.llm_output:
                token_usage = response.llm_output["token_usage"]
                if isinstance(token_usage, dict):
                    prompt_tokens = token_usage.get("prompt_tokens")
                    completion_tokens = token_usage.get("completion_tokens")
                else:
                    prompt_tokens = getattr(token_usage, "prompt_tokens", None)
                    completion_tokens = getattr(token_usage, "completion_tokens", None)
                if prompt_tokens is not None:
                    usage_attrs[AttributeKeys.GEN_AI_USAGE_INPUT_TOKENS] = (
                        prompt_tokens
                    )
                if completion_tokens is not None:
                    usage_attrs[AttributeKeys.GEN_AI_USAGE_OUTPUT_TOKENS] = (
                        completion_tokens
                    )

            # Set completion attribute
            if run_id in self.spans:
                span = self.spans[run_id]
                span.set_attribute(
                    AttributeKeys.GEN_AI_COMPLETION, safe_serialize(output)
                )

                # Set usage attributes
                for key, value in usage_attrs.items():
                    span.set_attribute(key, value)

            self._end_span(run_id=run_id, outputs=output, **usage_attrs)

        except Exception as e:
            judgeval_logger.exception(f"Error in on_llm_end: {e}")

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM encounters an error."""
        try:
            self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
            self._end_span(run_id=run_id, error=error)
        except Exception as e:
            judgeval_logger.exception(f"Error in on_llm_error: {e}")

    # Tool callbacks
    def on_tool_start(
        self,
        serialized: Optional[Dict[str, Any]],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool starts executing."""
        try:
            self._log_debug_event(
                "on_tool_start", run_id, parent_run_id, input_str=input_str
            )

            name = self._get_run_name(serialized, **kwargs)
            if name not in self.executed_tools:
                self.executed_tools.append(name)

            self._start_span(
                run_id=run_id,
                parent_run_id=parent_run_id,
                name=name,
                span_type="tool",
                inputs=input_str,
                tags=tags,
                metadata=metadata,
                serialized=safe_serialize(serialized) if serialized else None,
            )
        except Exception as e:
            judgeval_logger.exception(f"Error in on_tool_start: {e}")

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool finishes executing."""
        try:
            self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
            self._end_span(run_id=run_id, outputs=output)
        except Exception as e:
            judgeval_logger.exception(f"Error in on_tool_end: {e}")

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool encounters an error."""
        try:
            self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
            self._end_span(run_id=run_id, error=error)
        except Exception as e:
            judgeval_logger.exception(f"Error in on_tool_error: {e}")

    # Agent callbacks
    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an agent takes an action."""
        try:
            self._log_debug_event(
                "on_agent_action", run_id, parent_run_id, action=action
            )

            if run_id in self.spans:
                span = self.spans[run_id]
                span.set_attribute("agent.action.tool", action.tool)
                span.set_attribute(
                    "agent.action.tool_input", safe_serialize(action.tool_input)
                )
                span.set_attribute("agent.action.log", action.log)

            self._end_span(
                run_id=run_id,
                outputs={"action": action.tool, "input": action.tool_input},
            )
        except Exception as e:
            judgeval_logger.exception(f"Error in on_agent_action: {e}")

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an agent finishes."""
        try:
            self._log_debug_event(
                "on_agent_finish", run_id, parent_run_id, finish=finish
            )

            if run_id in self.spans:
                span = self.spans[run_id]
                span.set_attribute("agent.finish.log", finish.log)

            self._end_span(run_id=run_id, outputs=finish.return_values)
        except Exception as e:
            judgeval_logger.exception(f"Error in on_agent_finish: {e}")

    # Retriever callbacks
    def on_retriever_start(
        self,
        serialized: Optional[Dict[str, Any]],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a retriever starts."""
        try:
            self._log_debug_event(
                "on_retriever_start", run_id, parent_run_id, query=query
            )

            name = self._get_run_name(serialized, **kwargs)

            self._start_span(
                run_id=run_id,
                parent_run_id=parent_run_id,
                name=name,
                span_type="retriever",
                inputs=query,
                tags=tags,
                metadata=metadata,
                serialized=safe_serialize(serialized) if serialized else None,
            )
        except Exception as e:
            judgeval_logger.exception(f"Error in on_retriever_start: {e}")

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a retriever finishes."""
        try:
            self._log_debug_event(
                "on_retriever_end", run_id, parent_run_id, documents=documents
            )

            # Convert documents to serializable format
            doc_data = [
                {"page_content": doc.page_content, "metadata": doc.metadata}
                for doc in documents
            ]

            if run_id in self.spans:
                span = self.spans[run_id]
                span.set_attribute("retriever.document_count", len(documents))

            self._end_span(
                run_id=run_id, outputs=doc_data, document_count=len(documents)
            )
        except Exception as e:
            judgeval_logger.exception(f"Error in on_retriever_end: {e}")

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a retriever encounters an error."""
        try:
            self._log_debug_event(
                "on_retriever_error", run_id, parent_run_id, error=error
            )
            self._end_span(run_id=run_id, error=error)
        except Exception as e:
            judgeval_logger.exception(f"Error in on_retriever_error: {e}")

    def _extract_model_name(
        self, serialized: Optional[Dict[str, Any]], kwargs: Dict[str, Any]
    ) -> Optional[str]:
        """Extract model name from serialized data or kwargs.

        Checks invocation params first (most reliable at call time), then the
        serialized model config. Returns None when no model name is found.
        """
        # Try to get from invocation params
        invocation_params = kwargs.get("invocation_params", {})
        if "model_name" in invocation_params:
            return invocation_params["model_name"]
        if "model" in invocation_params:
            return invocation_params["model"]

        # Try to get from serialized data
        if serialized:
            if "model_name" in serialized:
                return serialized["model_name"]
            if "model" in serialized:
                return serialized["model"]

        return None
|
786
|
+
|
787
|
+
|
788
|
+
__all__ = ["JudgevalCallbackHandler"]
|