judgeval-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. judgeval/__init__.py +83 -0
  2. judgeval/clients.py +19 -0
  3. judgeval/common/__init__.py +8 -0
  4. judgeval/common/exceptions.py +28 -0
  5. judgeval/common/logger.py +189 -0
  6. judgeval/common/tracer.py +587 -0
  7. judgeval/common/utils.py +763 -0
  8. judgeval/constants.py +55 -0
  9. judgeval/data/__init__.py +14 -0
  10. judgeval/data/api_example.py +111 -0
  11. judgeval/data/datasets/__init__.py +4 -0
  12. judgeval/data/datasets/dataset.py +407 -0
  13. judgeval/data/datasets/ground_truth.py +54 -0
  14. judgeval/data/datasets/utils.py +74 -0
  15. judgeval/data/example.py +76 -0
  16. judgeval/data/result.py +83 -0
  17. judgeval/data/scorer_data.py +86 -0
  18. judgeval/evaluation_run.py +130 -0
  19. judgeval/judges/__init__.py +7 -0
  20. judgeval/judges/base_judge.py +44 -0
  21. judgeval/judges/litellm_judge.py +49 -0
  22. judgeval/judges/mixture_of_judges.py +248 -0
  23. judgeval/judges/together_judge.py +55 -0
  24. judgeval/judges/utils.py +45 -0
  25. judgeval/judgment_client.py +244 -0
  26. judgeval/run_evaluation.py +355 -0
  27. judgeval/scorers/__init__.py +30 -0
  28. judgeval/scorers/base_scorer.py +51 -0
  29. judgeval/scorers/custom_scorer.py +134 -0
  30. judgeval/scorers/judgeval_scorers/__init__.py +21 -0
  31. judgeval/scorers/judgeval_scorers/answer_relevancy.py +19 -0
  32. judgeval/scorers/judgeval_scorers/contextual_precision.py +19 -0
  33. judgeval/scorers/judgeval_scorers/contextual_recall.py +19 -0
  34. judgeval/scorers/judgeval_scorers/contextual_relevancy.py +22 -0
  35. judgeval/scorers/judgeval_scorers/faithfulness.py +19 -0
  36. judgeval/scorers/judgeval_scorers/hallucination.py +19 -0
  37. judgeval/scorers/judgeval_scorers/json_correctness.py +32 -0
  38. judgeval/scorers/judgeval_scorers/summarization.py +20 -0
  39. judgeval/scorers/judgeval_scorers/tool_correctness.py +19 -0
  40. judgeval/scorers/prompt_scorer.py +439 -0
  41. judgeval/scorers/score.py +427 -0
  42. judgeval/scorers/utils.py +175 -0
  43. judgeval-0.0.1.dist-info/METADATA +40 -0
  44. judgeval-0.0.1.dist-info/RECORD +46 -0
  45. judgeval-0.0.1.dist-info/WHEEL +4 -0
  46. judgeval-0.0.1.dist-info/licenses/LICENSE.md +202 -0
judgeval/common/tracer.py
@@ -0,0 +1,587 @@
+ """
+ Tracing system for judgeval that allows for function tracing using decorators.
+ """
+
+ import time
+ import functools
+ import requests
+ import uuid
+ from contextlib import contextmanager
+ from typing import (
+     Optional,
+     Any,
+     List,
+     Literal,
+     Tuple,
+     Generator,
+     TypeAlias,
+     Union
+ )
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from openai import OpenAI
+ from together import Together
+ from anthropic import Anthropic
+ from typing import Dict
+ import inspect
+ import asyncio
+ import json
+ import warnings
+ from pydantic import BaseModel
+ from http import HTTPStatus
+
+ from judgeval.constants import JUDGMENT_TRACES_SAVE_API_URL
+ from judgeval.judgment_client import JudgmentClient
+ from judgeval.data import Example
+ from judgeval.scorers import JudgmentScorer, CustomScorer
+ from judgeval.data.result import ScoringResult
+
+ # Define type aliases for better code readability and maintainability
+ ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
+ TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
+ SpanType = Literal['span', 'tool', 'llm', 'evaluation']
+ @dataclass
+ class TraceEntry:
+     """Represents a single trace entry with its visual representation.
+
+     Visual representations:
+     - enter: → (function entry)
+     - exit: ← (function exit)
+     - output: Output: (function return value)
+     - input: Input: (function parameters)
+     - evaluation: Evaluation: (evaluation results)
+     """
+     type: TraceEntryType
+     function: str # Name of the function being traced
+     depth: int # Indentation level for nested calls
+     message: str # Human-readable description
+     timestamp: float # Unix timestamp when entry was created
+     duration: Optional[float] = None # Time taken (for exit/evaluation entries)
+     output: Any = None # Function output value
+     # Use field() for mutable defaults to avoid shared state issues
+     inputs: dict = field(default_factory=dict)
+     span_type: SpanType = "span"
+     evaluation_result: Optional[List[ScoringResult]] = field(default=None)
+
+     def print_entry(self):
+         indent = " " * self.depth
+         if self.type == "enter":
+             print(f"{indent}→ {self.function} (trace: {self.message})")
+         elif self.type == "exit":
+             print(f"{indent}← {self.function} ({self.duration:.3f}s)")
+         elif self.type == "output":
+             print(f"{indent}Output: {self.output}")
+         elif self.type == "input":
+             print(f"{indent}Input: {self.inputs}")
+         elif self.type == "evaluation":
+             print(f"{indent}Evaluation: {self.evaluation_result} ({self.duration:.3f}s)")
+
+     def to_dict(self) -> dict:
+         """Convert the trace entry to a dictionary format for storage/transmission."""
+         try:
+             output = self._serialize_output()
+         except (TypeError, OverflowError, ValueError):
+             # Handle cases where output cannot be serialized
+             warnings.warn(f"Output for function {self.function} is not JSON serializable. Setting to None.")
+             output = None
+
+         # Build a complete dictionary representation of the trace entry
+         return {
+             "type": self.type,
+             "function": self.function,
+             "depth": self.depth,
+             "message": self.message,
+             "timestamp": self.timestamp,
+             "duration": self.duration,
+             "output": output,
+             "inputs": self.inputs or None, # Convert empty dict to None
+             "evaluation_result": [result.to_dict() for result in self.evaluation_result] if self.evaluation_result else None,
+             "span_type": self.span_type
+         }
+
+     def _serialize_output(self) -> Any:
+         """Helper method to serialize output data safely.
+
+         Handles special cases:
+         - Pydantic models are converted using model_dump()
+         - Other objects must be JSON serializable
+         """
+         if isinstance(self.output, BaseModel):
+             return self.output.model_dump()
+
+         # Verify JSON serialization is possible
+         json.dumps(self.output)
+         return self.output
+
+ class TraceClient:
+     """Client for managing a single trace context"""
+     def __init__(self, tracer, trace_id: str, name: str, project_name: str = "default_project"):
+         self.tracer = tracer
+         self.trace_id = trace_id
+         self.name = name
+         self.project_name = project_name
+         self.client: JudgmentClient = tracer.client
+         self.entries: List[TraceEntry] = []
+         self.start_time = time.time()
+         self.span_type = None
+         self._current_span: Optional[TraceEntry] = None
+
+     @contextmanager
+     def span(self, name: str, span_type: SpanType = "span"):
+         """Context manager for creating a trace span"""
+         start_time = time.time()
+
+         # Record span entry
+         self.add_entry(TraceEntry(
+             type="enter",
+             function=name,
+             depth=self.tracer.depth,
+             message=name,
+             timestamp=start_time,
+             span_type=span_type
+         ))
+
+         self.tracer.depth += 1
+         prev_span = self._current_span
+         self._current_span = name
+
+         try:
+             yield self
+         finally:
+             self.tracer.depth -= 1
+             duration = time.time() - start_time
+
+             # Record span exit
+             self.add_entry(TraceEntry(
+                 type="exit",
+                 function=name,
+                 depth=self.tracer.depth,
+                 message=f"← {name}",
+                 timestamp=time.time(),
+                 duration=duration,
+                 span_type=span_type
+             ))
+             self._current_span = prev_span
+
+     async def async_evaluate(
+         self,
+         scorers: List[Union[JudgmentScorer, CustomScorer]],
+         input: Optional[str] = None,
+         actual_output: Optional[str] = None,
+         expected_output: Optional[str] = None,
+         context: Optional[List[str]] = None,
+         retrieval_context: Optional[List[str]] = None,
+         tools_called: Optional[List[str]] = None,
+         expected_tools: Optional[List[str]] = None,
+         additional_metadata: Optional[Dict[str, Any]] = None,
+         model: Optional[str] = None,
+         log_results: Optional[bool] = False,
+     ):
+         start_time = time.time() # Record start time
+         example = Example(
+             input=input,
+             actual_output=actual_output,
+             expected_output=expected_output,
+             context=context,
+             retrieval_context=retrieval_context,
+             tools_called=tools_called,
+             expected_tools=expected_tools,
+             additional_metadata=additional_metadata,
+             trace_id=self.trace_id
+         )
+         scoring_results = self.client.run_evaluation(
+             examples=[example],
+             scorers=scorers,
+             model=model,
+             metadata={},
+             log_results=log_results,
+             project_name="TestSpanLevel1", # TODO this should be dynamic
+             eval_run_name="TestSpanLevel1",
+             override=True,
+         )
+
+         self.record_evaluation(scoring_results, start_time) # Pass start_time to record_evaluation
+
+     def record_evaluation(self, results: List[ScoringResult], start_time: float):
+         """Record evaluation results for the current span"""
+         if self._current_span:
+             duration = time.time() - start_time # Calculate duration from start_time
+
+             self.add_entry(TraceEntry(
+                 type="evaluation",
+                 function=self._current_span,
+                 depth=self.tracer.depth,
+                 message=f"Evaluation results for {self._current_span}",
+                 timestamp=time.time(),
+                 evaluation_result=results,
+                 duration=duration,
+                 span_type="evaluation"
+             ))
+
+     def record_input(self, inputs: dict):
+         """Record input parameters for the current span"""
+         if self._current_span:
+             self.add_entry(TraceEntry(
+                 type="input",
+                 function=self._current_span,
+                 depth=self.tracer.depth,
+                 message=f"Inputs to {self._current_span}",
+                 timestamp=time.time(),
+                 inputs=inputs,
+                 span_type=self.span_type
+             ))
+
+     async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
+         """Helper method to update the output of a trace entry once the coroutine completes"""
+         try:
+             result = await coroutine
+             entry.output = result
+             return result
+         except Exception as e:
+             entry.output = f"Error: {str(e)}"
+             raise
+
+     def record_output(self, output: Any):
+         """Record output for the current span"""
+         if self._current_span:
+             entry = TraceEntry(
+                 type="output",
+                 function=self._current_span,
+                 depth=self.tracer.depth,
+                 message=f"Output from {self._current_span}",
+                 timestamp=time.time(),
+                 output="<pending>" if inspect.iscoroutine(output) else output,
+                 span_type=self.span_type
+             )
+             self.add_entry(entry)
+
+             if inspect.iscoroutine(output):
+                 # Create a task to update the output once the coroutine completes
+                 asyncio.create_task(self._update_coroutine_output(entry, output))
+
+     def add_entry(self, entry: TraceEntry):
+         """Add a trace entry to this trace context"""
+         self.entries.append(entry)
+         return self
+
+     def print(self):
+         """Print the complete trace with proper visual structure"""
+         for entry in self.entries:
+             entry.print_entry()
+
+     def get_duration(self) -> float:
+         """
+         Get the total duration of this trace
+         """
+         return time.time() - self.start_time
+
+     def condense_trace(self, entries: List[dict]) -> List[dict]:
+         """
+         Condenses trace entries into a single entry for each function call.
+         """
+         condensed = []
+         active_functions = [] # Stack to track nested function calls
+         function_entries = {} # Store entries for each function
+
+         for entry in entries:
+             function = entry["function"]
+
+             if entry["type"] == "enter":
+                 # Initialize new function entry
+                 function_entries[function] = {
+                     "depth": entry["depth"],
+                     "function": function,
+                     "timestamp": entry["timestamp"],
+                     "inputs": None,
+                     "output": None,
+                     "evaluation_result": None,
+                     "span_type": entry.get("span_type", "span")
+                 }
+                 active_functions.append(function)
+
+             elif entry["type"] == "exit" and function in active_functions:
+                 # Complete function entry
+                 current_entry = function_entries[function]
+                 current_entry["duration"] = entry["timestamp"] - current_entry["timestamp"]
+                 condensed.append(current_entry)
+                 active_functions.remove(function)
+                 del function_entries[function]
+
+             elif function in active_functions:
+                 # Update existing function entry with additional data
+                 current_entry = function_entries[function]
+
+                 if entry["type"] == "input" and entry["inputs"]:
+                     current_entry["inputs"] = entry["inputs"]
+
+                 if entry["type"] == "output" and entry["output"]:
+                     current_entry["output"] = entry["output"]
+
+                 if entry["type"] == "evaluation" and entry["evaluation_result"]:
+                     current_entry["evaluation_result"] = entry["evaluation_result"]
+
+         # Sort by timestamp
+         condensed.sort(key=lambda x: x["timestamp"])
+         return condensed
+
+     def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
+         """
+         Save the current trace to the database.
+         Returns a tuple of (trace_id, trace_data) where trace_data is the trace data that was saved.
+         """
+         # Calculate total elapsed time
+         total_duration = self.get_duration()
+
+         raw_entries = [entry.to_dict() for entry in self.entries]
+         condensed_entries = self.condense_trace(raw_entries)
+
+         # Create trace document
+         trace_data = {
+             "trace_id": self.trace_id,
+             "api_key": self.tracer.api_key,
+             "name": self.name,
+             "project_name": self.project_name,
+             "created_at": datetime.fromtimestamp(self.start_time).isoformat(),
+             "duration": total_duration,
+             "token_counts": {
+                 "prompt_tokens": 0, # Dummy value
+                 "completion_tokens": 0, # Dummy value
+                 "total_tokens": 0, # Dummy value
+             }, # TODO: Add token counts
+             "entries": condensed_entries,
+             "empty_save": empty_save,
+             "overwrite": overwrite
+         }
+
+         # Save trace data by making POST request to API
+         response = requests.post(
+             JUDGMENT_TRACES_SAVE_API_URL,
+             json=trace_data,
+             headers={
+                 "Content-Type": "application/json",
+             }
+         )
+
+         if response.status_code == HTTPStatus.BAD_REQUEST:
+             raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
+         elif response.status_code != HTTPStatus.OK:
+             raise ValueError(f"Failed to save trace data: {response.text}")
+
+         return self.trace_id, trace_data
+
+ class Tracer:
+     _instance = None
+
+     def __new__(cls, *args, **kwargs):
+         if cls._instance is None:
+             cls._instance = super(Tracer, cls).__new__(cls)
+         return cls._instance
+
+     def __init__(self, api_key: str):
+         if not hasattr(self, 'initialized'):
+
+             if not api_key:
+                 raise ValueError("Tracer must be configured with a Judgment API key")
+
+             self.api_key: str = api_key
+             self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
+             self.depth: int = 0
+             self._current_trace: Optional[str] = None
+             self.initialized: bool = True
+
+     @contextmanager
+     def trace(self, name: str, project_name: str = "default_project", overwrite: bool = False) -> Generator[TraceClient, None, None]:
+         """Start a new trace context using a context manager"""
+         trace_id = str(uuid.uuid4())
+         trace = TraceClient(self, trace_id, name, project_name=project_name)
+         prev_trace = self._current_trace
+         self._current_trace = trace
+
+         # Automatically create top-level span
+         with trace.span(name or "unnamed_trace") as span:
+             try:
+                 # Save the trace to the database to handle Evaluations' trace_id referential integrity
+                 trace.save(empty_save=True, overwrite=overwrite)
+                 yield trace
+             finally:
+                 self._current_trace = prev_trace
+
+     def get_current_trace(self) -> Optional[TraceClient]:
+         """
+         Get the current trace context
+         """
+         return self._current_trace
+
+     def observe(self, func=None, *, name=None, span_type: SpanType = "span"):
+         """
+         Decorator to trace function execution with detailed entry/exit information.
+
+         Args:
+             func: The function to trace
+             name: Optional custom name for the function
+             span_type: The type of span to use for this observation (default: "span")
+         """
+         if func is None:
+             return lambda f: self.observe(f, name=name, span_type=span_type)
+
+         if asyncio.iscoroutinefunction(func):
+             @functools.wraps(func)
+             async def async_wrapper(*args, **kwargs):
+                 if self._current_trace:
+                     span_name = name or func.__name__
+
+                     with self._current_trace.span(span_name, span_type=span_type) as span:
+                         # Set the span type
+                         span.span_type = span_type
+
+                         # Record inputs
+                         span.record_input({
+                             'args': list(args),
+                             'kwargs': kwargs
+                         })
+
+                         # Execute function
+                         result = await func(*args, **kwargs)
+
+                         # Record output
+                         span.record_output(result)
+
+                         return result
+
+                 return await func(*args, **kwargs)
+             return async_wrapper
+         else:
+             @functools.wraps(func)
+             def wrapper(*args, **kwargs):
+                 if self._current_trace:
+                     span_name = name or func.__name__
+
+                     with self._current_trace.span(span_name, span_type=span_type) as span:
+                         # Set the span type
+                         span.span_type = span_type
+
+                         # Record inputs
+                         span.record_input({
+                             'args': list(args),
+                             'kwargs': kwargs
+                         })
+
+                         # Execute function
+                         result = func(*args, **kwargs)
+
+                         # Record output
+                         span.record_output(result)
+
+                         return result
+
+                 return func(*args, **kwargs)
+             return wrapper
+
+ def wrap(client: Any) -> Any:
+     """
+     Wraps an API client to add tracing capabilities.
+     Supports OpenAI, Together, and Anthropic clients.
+     """
+     tracer = Tracer._instance # Get the global tracer instance
+
+     # Get the appropriate configuration for this client type
+     span_name, original_create = _get_client_config(client)
+
+     def traced_create(*args, **kwargs):
+         # Skip tracing if no active trace
+         if not (tracer and tracer._current_trace):
+             return original_create(*args, **kwargs)
+
+         with tracer._current_trace.span(span_name, span_type="llm") as span:
+             # Format and record the input parameters
+             input_data = _format_input_data(client, **kwargs)
+             span.record_input(input_data)
+
+             # Make the actual API call
+             response = original_create(*args, **kwargs)
+
+             # Format and record the output
+             output_data = _format_output_data(client, response)
+             span.record_output(output_data)
+
+             return response
+
+     # Replace the original method with our traced version
+     if isinstance(client, (OpenAI, Together)):
+         client.chat.completions.create = traced_create
+     elif isinstance(client, Anthropic):
+         client.messages.create = traced_create
+
+     return client
+
+ # Helper functions for client-specific operations
+
+ def _get_client_config(client: ApiClient) -> tuple[str, callable]:
+     """Returns configuration tuple for the given API client.
+
+     Args:
+         client: An instance of OpenAI, Together, or Anthropic client
+
+     Returns:
+         tuple: (span_name, create_method)
+         - span_name: String identifier for tracing
+         - create_method: Reference to the client's creation method
+
+     Raises:
+         ValueError: If client type is not supported
+     """
+     if isinstance(client, OpenAI):
+         return "OPENAI_API_CALL", client.chat.completions.create
+     elif isinstance(client, Together):
+         return "TOGETHER_API_CALL", client.chat.completions.create
+     elif isinstance(client, Anthropic):
+         return "ANTHROPIC_API_CALL", client.messages.create
+     raise ValueError(f"Unsupported client type: {type(client)}")
+
+ def _format_input_data(client: ApiClient, **kwargs) -> dict:
+     """Format input parameters based on client type.
+
+     Extracts relevant parameters from kwargs based on the client type
+     to ensure consistent tracing across different APIs.
+     """
+     if isinstance(client, (OpenAI, Together)):
+         return {
+             "model": kwargs.get("model"),
+             "messages": kwargs.get("messages"),
+         }
+     # Anthropic requires additional max_tokens parameter
+     return {
+         "model": kwargs.get("model"),
+         "messages": kwargs.get("messages"),
+         "max_tokens": kwargs.get("max_tokens")
+     }
+
+ def _format_output_data(client: ApiClient, response: Any) -> dict:
+     """Format API response data based on client type.
+
+     Normalizes different response formats into a consistent structure
+     for tracing purposes.
+
+     Returns:
+         dict containing:
+             - content: The generated text
+             - usage: Token usage statistics
+     """
+     if isinstance(client, (OpenAI, Together)):
+         return {
+             "content": response.choices[0].message.content,
+             "usage": {
+                 "prompt_tokens": response.usage.prompt_tokens,
+                 "completion_tokens": response.usage.completion_tokens,
+                 "total_tokens": response.usage.total_tokens
+             }
+         }
+     # Anthropic has a different response structure
+     return {
+         "content": response.content[0].text,
+         "usage": {
+             "input_tokens": response.usage.input_tokens,
+             "output_tokens": response.usage.output_tokens,
+             "total_tokens": response.usage.input_tokens + response.usage.output_tokens
+         }
+     }
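
For orientation, the module above composes three entry points: Tracer.trace() opens a trace and registers it as the current one (saving an empty placeholder up front), the observe() decorator records a span with inputs and output for each decorated call, and wrap() patches a supported OpenAI, Together, or Anthropic client so its completion calls appear as "llm" spans. The sketch below shows how these pieces could be wired together; it is not part of the packaged code, and the project name, model name, and environment variables are illustrative assumptions.

# Hypothetical usage sketch for judgeval/common/tracer.py (not part of the wheel).
# Assumes JUDGMENT_API_KEY and OPENAI_API_KEY are set in the environment.
import os
from openai import OpenAI
from judgeval.common.tracer import Tracer, wrap

tracer = Tracer(api_key=os.environ["JUDGMENT_API_KEY"])  # singleton instance
client = wrap(OpenAI())  # completion calls become "llm" spans while a trace is active

@tracer.observe(span_type="tool")
def answer(question: str) -> str:
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model choice
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content

with tracer.trace("demo_trace", project_name="default_project", overwrite=True) as trace:
    answer("What does judgeval trace?")
    trace.print()               # pretty-print the recorded entries
    trace.save(overwrite=True)  # persist the completed entries via the Judgment API

Inside an active span, TraceClient.async_evaluate() can additionally attach ScoringResult entries to the trace, though as released it hard-codes the "TestSpanLevel1" project and run names (flagged by its own TODO).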