ragbits-evaluate 0.0.30rc1__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
@@ -1,591 +0,0 @@
- """Task completion checkers for agent simulation scenarios."""
-
- from __future__ import annotations
-
- from abc import ABC, abstractmethod
- from collections.abc import Callable
- from typing import TYPE_CHECKING, Any, ClassVar, Literal
-
- from pydantic import BaseModel, Field
-
- from ragbits.core.prompt import Prompt
-
- if TYPE_CHECKING:
-     from ragbits.agents.tool import ToolCallResult
-     from ragbits.evaluate.agent_simulation.models import Task, Turn
-
-
- class LLMCheckerPromptInput(BaseModel):
-     """Input for the LLM checker prompt."""
-
-     task: str
-     expected_result: str
-     context_block: str
-     history_block: str
-
-
- class LLMCheckerPromptOutput(BaseModel):
-     """Output schema for the LLM checker prompt."""
-
-     done: bool = Field(..., description="Whether the task has been completed")
-     reason: str = Field(..., description="Short explanation of the decision")
-
-
- class LLMCheckerPrompt(Prompt[LLMCheckerPromptInput, LLMCheckerPromptOutput]):
-     """Prompt for LLM-based task completion checking."""
-
-     system_prompt = """
-     You are a strict task-completion judge for a user-assistant conversation.
-     Decide if the assistant has fulfilled the current task.
-     Current task: {{ task }}
-     Expected result: {{ expected_result }}
-     {{ context_block }}
-     """
-
-     user_prompt = """
-     [CONVERSATION]
-     {{ history_block }}
-
-     [TASK]
-     Evaluate if the task has been completed and provide your decision.
-     """
-
-
- class CheckerResult(BaseModel):
-     """Standard response schema for all checkers."""
-
-     completed: bool = Field(..., description="Whether the task/turn check passed")
-     reason: str = Field(..., description="Human-readable explanation of the decision")
-     checker_type: str = Field(..., description="Type of checker that produced this result")
-     details: dict[str, Any] = Field(default_factory=dict, description="Additional checker-specific details")
-
-
- class BaseCheckerConfig(BaseModel, ABC):
-     """Base configuration for all checkers. Subclass this to create new checker types."""
-
-     type: ClassVar[str]  # Each subclass must define this
-
-     @abstractmethod
-     async def check(
-         self,
-         task: Task,
-         history: list[Turn],
-         tool_calls: list[ToolCallResult],
-         state: dict[str, Any],
-         context: CheckerContext,
-     ) -> CheckerResult:
-         """Run the check and return result.
-
-         Args:
-             task: The current task being checked.
-             history: List of conversation turns so far.
-             tool_calls: List of tool calls made in the current turn.
-             state: Current conversation state.
-             context: Shared context with LLM and other resources.
-
-         Returns:
-             CheckerResult with the decision and reasoning.
-         """
-
-
- class CheckerContext(BaseModel):
-     """Shared context passed to all checkers during evaluation."""
-
-     model_config = {"arbitrary_types_allowed": True}
-
-     llm: Any = Field(default=None, description="LLM instance for checkers that need it")
-     domain_context: Any = Field(default=None, description="Optional domain context")
-
-
- # =============================================================================
- # Checker Registry
- # =============================================================================
-
- _CHECKER_REGISTRY: dict[str, type[BaseCheckerConfig]] = {}
-
-
- def register_checker(checker_type: str) -> Callable[[type[BaseCheckerConfig]], type[BaseCheckerConfig]]:
-     """Decorator to register a checker type.
-
-     Usage:
-         @register_checker("my_checker")
-         class MyCheckerConfig(BaseCheckerConfig):
-             type: ClassVar[str] = "my_checker"
-             ...
-     """
-
-     def decorator(cls: type[BaseCheckerConfig]) -> type[BaseCheckerConfig]:
-         cls.type = checker_type
-         _CHECKER_REGISTRY[checker_type] = cls
-         return cls
-
-     return decorator
-
-
- def get_checker_class(checker_type: str) -> type[BaseCheckerConfig]:
-     """Get a checker class by type name."""
-     if checker_type not in _CHECKER_REGISTRY:
-         available = ", ".join(_CHECKER_REGISTRY.keys())
-         raise ValueError(f"Unknown checker type: {checker_type}. Available: {available}")
-     return _CHECKER_REGISTRY[checker_type]
-
-
- def list_checker_types() -> list[str]:
-     """List all registered checker types."""
-     return list(_CHECKER_REGISTRY.keys())
-
-
- def parse_checker_config(data: dict[str, Any]) -> BaseCheckerConfig:
-     """Parse a checker config dict into the appropriate typed config.
-
-     Args:
-         data: Dict with "type" key and checker-specific fields.
-
-     Returns:
-         Typed checker config instance.
-     """
-     checker_type = data.get("type")
-     if not checker_type:
-         raise ValueError("Checker config must have a 'type' field")
-
-     checker_class = get_checker_class(checker_type)
-     return checker_class(**{k: v for k, v in data.items() if k != "type"})
-
-
- # =============================================================================
- # Built-in Checkers
- # =============================================================================
-
-
- @register_checker("llm")
- class LLMCheckerConfig(BaseCheckerConfig):
-     """LLM-based checker that uses a language model to evaluate task completion."""
-
-     type: ClassVar[str] = "llm"
-
-     expected_result: str = Field(..., description="Description of the expected outcome")
-
-     async def check(
-         self,
-         task: Task,
-         history: list[Turn],
-         tool_calls: list[ToolCallResult],
-         state: dict[str, Any],
-         context: CheckerContext,
-     ) -> CheckerResult:
-         """Check task completion using LLM evaluation."""
-         if context.llm is None:
-             return CheckerResult(
-                 completed=False,
-                 reason="LLM checker requires an LLM instance but none was provided",
-                 checker_type=self.type,
-             )
-
-         prompt = LLMCheckerPrompt(
-             LLMCheckerPromptInput(
-                 task=task.task,
-                 expected_result=self.expected_result,
-                 context_block=self._build_context_block(context),
-                 history_block=self._build_history_block(history),
-             )
-         )
-
-         response: LLMCheckerPromptOutput = await context.llm.generate(prompt)
-
-         return CheckerResult(
-             completed=response.done,
-             reason=response.reason,
-             checker_type=self.type,
-             details={"expected_result": self.expected_result},
-         )
-
-     @staticmethod
-     def _build_context_block(context: CheckerContext) -> str:
-         """Build the domain context block string."""
-         if context.domain_context:
-             return (
-                 "\n[IMPORTANT CONTEXT]\n"
-                 f"{context.domain_context.format_for_prompt()}\n\n"
-                 "When evaluating task completion, consider the domain context above "
-                 f"and use {context.domain_context.locale} locale conventions.\n"
-             )
-         return ""
-
-     @staticmethod
-     def _build_history_block(history: list[Turn]) -> str:
-         """Build the conversation history block string."""
-         if not history:
-             return "(no prior messages)"
-         history_text = [f"User: {t.user}\nAssistant: {t.assistant}" for t in history]
-         return "\n\n".join(history_text)
-
-
- class ToolCallExpectation(BaseModel):
-     """Expected tool call specification."""
-
-     name: str = Field(..., description="Name of the tool that should be called")
-     arguments: dict[str, Any] | None = Field(
-         default=None,
-         description="Optional arguments that should match (partial match)",
-     )
-     result_contains: str | None = Field(
-         default=None,
-         description="Optional string that should be present in the tool result",
-     )
-
-
- @register_checker("tool_call")
- class ToolCallCheckerConfig(BaseCheckerConfig):
-     """Checker that verifies specific tool calls were made."""
-
-     type: ClassVar[str] = "tool_call"
-
-     tools: list[ToolCallExpectation | str] = Field(
-         ..., description="Expected tools - can be names or detailed expectations"
-     )
-     mode: Literal["all", "any"] = Field(
-         default="all", description="'all' requires all tools, 'any' requires at least one"
-     )
-
-     async def check(  # noqa: PLR0912
-         self,
-         task: Task,
-         history: list[Turn],
-         tool_calls: list[ToolCallResult],
-         state: dict[str, Any],
-         context: CheckerContext,
-     ) -> CheckerResult:
-         """Check if expected tool calls were made."""
-         if not self.tools:
-             return CheckerResult(
-                 completed=True,
-                 reason="No expected tools specified",
-                 checker_type=self.type,
-             )
-
-         if not tool_calls:
-             return CheckerResult(
-                 completed=False,
-                 reason="No tools were called, but tools were expected",
-                 checker_type=self.type,
-                 details={"expected": [t if isinstance(t, str) else t.name for t in self.tools], "called": []},
-             )
-
-         # Normalize expectations
-         expectations: list[ToolCallExpectation] = []
-         for tool in self.tools:
-             if isinstance(tool, str):
-                 expectations.append(ToolCallExpectation(name=tool))
-             else:
-                 expectations.append(tool)
-
-         matched_tools: list[str] = []
-         unmatched_tools: list[str] = []
-         match_details: dict[str, dict[str, Any]] = {}
-
-         for expected in expectations:
-             found = False
-             for call in tool_calls:
-                 if call.name != expected.name:
-                     continue
-
-                 # Check arguments if specified
-                 if expected.arguments:
-                     args_match = all(call.arguments.get(k) == v for k, v in expected.arguments.items())
-                     if not args_match:
-                         match_details[expected.name] = {
-                             "status": "args_mismatch",
-                             "expected_args": expected.arguments,
-                             "actual_args": call.arguments,
-                         }
-                         continue
-
-                 # Check result contains if specified
-                 if expected.result_contains:
-                     result_str = str(call.result) if call.result else ""
-                     if expected.result_contains not in result_str:
-                         match_details[expected.name] = {
-                             "status": "result_mismatch",
-                             "expected_contains": expected.result_contains,
-                             "actual_result": result_str[:200],
-                         }
-                         continue
-
-                 found = True
-                 matched_tools.append(expected.name)
-                 match_details[expected.name] = {"status": "matched"}
-                 break
-
-             if not found and expected.name not in match_details:
-                 unmatched_tools.append(expected.name)
-                 match_details[expected.name] = {"status": "not_called"}
-
-         called_names = [tc.name for tc in tool_calls]
-
-         if self.mode == "all":
-             completed = len(unmatched_tools) == 0 and all(d.get("status") == "matched" for d in match_details.values())
-             if completed:
-                 reason = f"All expected tools matched: {', '.join(matched_tools)}"
-             else:
-                 failed = [k for k, v in match_details.items() if v.get("status") != "matched"]
-                 reason = f"Tool check failed for: {', '.join(failed)}. Called: {', '.join(called_names)}"
-         else:  # mode == "any"
-             completed = len(matched_tools) > 0
-             if completed:
-                 reason = f"Found matching tool(s): {', '.join(matched_tools)}"
-             else:
-                 reason = f"None of expected tools matched. Expected: {', '.join(e.name for e in expectations)}"
-
-         return CheckerResult(
-             completed=completed,
-             reason=reason,
-             checker_type=self.type,
-             details={
-                 "matched": matched_tools,
-                 "unmatched": unmatched_tools,
-                 "called": called_names,
-                 "match_details": match_details,
-             },
-         )
-
-
- class StateExpectation(BaseModel):
-     """Expected state value specification."""
-
-     key: str = Field(..., description="Key path in state (supports dot notation like 'user.name')")
-     value: Any | None = Field(default=None, description="Expected exact value")
-     exists: bool | None = Field(default=None, description="Check existence (True) or non-existence (False)")
-     contains: str | None = Field(default=None, description="For strings, check if contains this substring")
-     min_value: float | None = Field(default=None, description="For numbers, minimum allowed value")
-     max_value: float | None = Field(default=None, description="For numbers, maximum allowed value")
-
-
- @register_checker("state")
- class StateCheckerConfig(BaseCheckerConfig):
-     """Checker that verifies state has specific values."""
-
-     type: ClassVar[str] = "state"
-
-     checks: list[StateExpectation] = Field(..., description="State conditions to verify")
-     mode: Literal["all", "any"] = Field(
-         default="all", description="'all' requires all checks, 'any' requires at least one"
-     )
-
-     def _get_nested_value(self, state: dict[str, Any], key_path: str) -> tuple[bool, Any]:  # noqa: PLR6301
-         """Get a nested value from state using dot notation."""
-         keys = key_path.split(".")
-         current = state
-         for key in keys:
-             if isinstance(current, dict) and key in current:
-                 current = current[key]
-             else:
-                 return False, None
-         return True, current
-
-     async def check(  # noqa: PLR0912, PLR0915
-         self,
-         task: Task,
-         history: list[Turn],
-         tool_calls: list[ToolCallResult],
-         state: dict[str, Any],
-         context: CheckerContext,
-     ) -> CheckerResult:
-         """Check if state meets the specified criteria."""
-         if not self.checks:
-             return CheckerResult(
-                 completed=True,
-                 reason="No state checks specified",
-                 checker_type=self.type,
-             )
-
-         passed_checks: list[str] = []
-         failed_checks: list[str] = []
-         check_details: dict[str, dict[str, Any]] = {}
-
-         for check in self.checks:
-             exists, value = self._get_nested_value(state, check.key)
-
-             # Check existence
-             if check.exists is not None:
-                 if check.exists and not exists:
-                     failed_checks.append(check.key)
-                     check_details[check.key] = {"status": "not_exists", "expected": "exists"}
-                     continue
-                 elif not check.exists and exists:
-                     failed_checks.append(check.key)
-                     check_details[check.key] = {"status": "exists", "expected": "not_exists"}
-                     continue
-
-             if not exists and check.exists is None:
-                 failed_checks.append(check.key)
-                 check_details[check.key] = {"status": "not_exists"}
-                 continue
-
-             # Check exact value
-             if check.value is not None and value != check.value:
-                 failed_checks.append(check.key)
-                 check_details[check.key] = {
-                     "status": "value_mismatch",
-                     "expected": check.value,
-                     "actual": value,
-                 }
-                 continue
-
-             # Check contains (for strings)
-             if check.contains is not None:  # noqa: SIM102
-                 if not isinstance(value, str) or check.contains not in value:
-                     failed_checks.append(check.key)
-                     check_details[check.key] = {
-                         "status": "contains_failed",
-                         "expected_contains": check.contains,
-                         "actual": str(value)[:200],
-                     }
-                     continue
-
-             # Check min/max value (for numbers)
-             if check.min_value is not None:
-                 try:
-                     if float(value) < check.min_value:
-                         failed_checks.append(check.key)
-                         check_details[check.key] = {
-                             "status": "below_min",
-                             "min": check.min_value,
-                             "actual": value,
-                         }
-                         continue
-                 except (ValueError, TypeError):
-                     failed_checks.append(check.key)
-                     check_details[check.key] = {"status": "not_numeric", "actual": value}
-                     continue
-
-             if check.max_value is not None:
-                 try:
-                     if float(value) > check.max_value:
-                         failed_checks.append(check.key)
-                         check_details[check.key] = {
-                             "status": "above_max",
-                             "max": check.max_value,
-                             "actual": value,
-                         }
-                         continue
-                 except (ValueError, TypeError):
-                     failed_checks.append(check.key)
-                     check_details[check.key] = {"status": "not_numeric", "actual": value}
-                     continue
-
-             passed_checks.append(check.key)
-             check_details[check.key] = {"status": "passed", "value": value}
-
-         if self.mode == "all":
-             completed = len(failed_checks) == 0
-             if completed:
-                 reason = f"All state checks passed: {', '.join(passed_checks)}"
-             else:
-                 reason = f"State checks failed: {', '.join(failed_checks)}"
-         else:  # mode == "any"
-             completed = len(passed_checks) > 0
-             if completed:  # noqa: SIM108
-                 reason = f"State check(s) passed: {', '.join(passed_checks)}"
-             else:
-                 reason = "No state checks passed"
-
-         return CheckerResult(
-             completed=completed,
-             reason=reason,
-             checker_type=self.type,
-             details={
-                 "passed": passed_checks,
-                 "failed": failed_checks,
-                 "check_details": check_details,
-                 "current_state": state,
-             },
-         )
-
-
- # =============================================================================
- # Checker Runner
- # =============================================================================
-
-
- async def run_checkers(
-     checkers: list[BaseCheckerConfig],
-     task: Task,
-     history: list[Turn],
-     tool_calls: list[ToolCallResult],
-     state: dict[str, Any],
-     context: CheckerContext,
-     mode: Literal["all", "any"] = "all",
- ) -> CheckerResult:
-     """Run multiple checkers and combine results.
-
-     Args:
-         checkers: List of checker configs to run.
-         task: The current task.
-         history: Conversation history.
-         tool_calls: Tool calls from current turn.
-         state: Current state.
-         context: Shared checker context.
-         mode: 'all' requires all to pass, 'any' requires at least one.
-
-     Returns:
-         Combined CheckerResult.
-     """
-     if not checkers:
-         return CheckerResult(
-             completed=False,
-             reason="No checkers configured",
-             checker_type="none",
-         )
-
-     results: list[CheckerResult] = []
-     for checker in checkers:
-         result = await checker.check(task, history, tool_calls, state, context)
-         results.append(result)
-
-         if mode == "any" and result.completed:
-             return CheckerResult(
-                 completed=True,
-                 reason=f"[{result.checker_type}] {result.reason}",
-                 checker_type=result.checker_type,
-                 details={
-                     "mode": mode,
-                     "passed_checker": result.checker_type,
-                     "all_results": [r.model_dump() for r in results],
-                 },
-             )
-         elif mode == "all" and not result.completed:
-             return CheckerResult(
-                 completed=False,
-                 reason=f"[{result.checker_type}] {result.reason}",
-                 checker_type=result.checker_type,
-                 details={
-                     "mode": mode,
-                     "failed_checker": result.checker_type,
-                     "all_results": [r.model_dump() for r in results],
-                 },
-             )
-
-     # All passed (for "all" mode) or none passed (for "any" mode)
-     checker_types = [r.checker_type for r in results]
-     if mode == "all":
-         return CheckerResult(
-             completed=True,
-             reason=f"All {len(results)} checkers passed: {', '.join(checker_types)}",
-             checker_type="combined",
-             details={
-                 "mode": mode,
-                 "checkers": checker_types,
-                 "all_results": [r.model_dump() for r in results],
-             },
-         )
-     else:  # mode == "any" and none passed
-         return CheckerResult(
-             completed=False,
-             reason=f"None of {len(results)} checkers passed: {', '.join(checker_types)}",
-             checker_type="combined",
-             details={
-                 "mode": mode,
-                 "checkers": checker_types,
-                 "all_results": [r.model_dump() for r in results],
-             },
-         )