prela-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prela/__init__.py +394 -0
- prela/_version.py +3 -0
- prela/contrib/CLI.md +431 -0
- prela/contrib/README.md +118 -0
- prela/contrib/__init__.py +5 -0
- prela/contrib/cli.py +1063 -0
- prela/contrib/explorer.py +571 -0
- prela/core/__init__.py +64 -0
- prela/core/clock.py +98 -0
- prela/core/context.py +228 -0
- prela/core/replay.py +403 -0
- prela/core/sampler.py +178 -0
- prela/core/span.py +295 -0
- prela/core/tracer.py +498 -0
- prela/evals/__init__.py +94 -0
- prela/evals/assertions/README.md +484 -0
- prela/evals/assertions/__init__.py +78 -0
- prela/evals/assertions/base.py +90 -0
- prela/evals/assertions/multi_agent.py +625 -0
- prela/evals/assertions/semantic.py +223 -0
- prela/evals/assertions/structural.py +443 -0
- prela/evals/assertions/tool.py +380 -0
- prela/evals/case.py +370 -0
- prela/evals/n8n/__init__.py +69 -0
- prela/evals/n8n/assertions.py +450 -0
- prela/evals/n8n/runner.py +497 -0
- prela/evals/reporters/README.md +184 -0
- prela/evals/reporters/__init__.py +32 -0
- prela/evals/reporters/console.py +251 -0
- prela/evals/reporters/json.py +176 -0
- prela/evals/reporters/junit.py +278 -0
- prela/evals/runner.py +525 -0
- prela/evals/suite.py +316 -0
- prela/exporters/__init__.py +27 -0
- prela/exporters/base.py +189 -0
- prela/exporters/console.py +443 -0
- prela/exporters/file.py +322 -0
- prela/exporters/http.py +394 -0
- prela/exporters/multi.py +154 -0
- prela/exporters/otlp.py +388 -0
- prela/instrumentation/ANTHROPIC.md +297 -0
- prela/instrumentation/LANGCHAIN.md +480 -0
- prela/instrumentation/OPENAI.md +59 -0
- prela/instrumentation/__init__.py +49 -0
- prela/instrumentation/anthropic.py +1436 -0
- prela/instrumentation/auto.py +129 -0
- prela/instrumentation/base.py +436 -0
- prela/instrumentation/langchain.py +959 -0
- prela/instrumentation/llamaindex.py +719 -0
- prela/instrumentation/multi_agent/__init__.py +48 -0
- prela/instrumentation/multi_agent/autogen.py +357 -0
- prela/instrumentation/multi_agent/crewai.py +404 -0
- prela/instrumentation/multi_agent/langgraph.py +299 -0
- prela/instrumentation/multi_agent/models.py +203 -0
- prela/instrumentation/multi_agent/swarm.py +231 -0
- prela/instrumentation/n8n/__init__.py +68 -0
- prela/instrumentation/n8n/code_node.py +534 -0
- prela/instrumentation/n8n/models.py +336 -0
- prela/instrumentation/n8n/webhook.py +489 -0
- prela/instrumentation/openai.py +1198 -0
- prela/license.py +245 -0
- prela/replay/__init__.py +31 -0
- prela/replay/comparison.py +390 -0
- prela/replay/engine.py +1227 -0
- prela/replay/loader.py +231 -0
- prela/replay/result.py +196 -0
- prela-0.1.0.dist-info/METADATA +399 -0
- prela-0.1.0.dist-info/RECORD +71 -0
- prela-0.1.0.dist-info/WHEEL +4 -0
- prela-0.1.0.dist-info/entry_points.txt +2 -0
- prela-0.1.0.dist-info/licenses/LICENSE +190 -0
prela/evals/runner.py
ADDED
@@ -0,0 +1,525 @@
"""
Evaluation runner for executing test cases against AI agents.

This module provides the core infrastructure for running evaluation suites,
executing test cases, running assertions, and aggregating results.
"""

from __future__ import annotations

import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable

from prela.core.clock import now
from prela.core.context import get_current_trace_id
from prela.core.tracer import Tracer

from .assertions.base import AssertionResult, BaseAssertion
from .case import EvalCase, EvalInput
from .suite import EvalSuite


@dataclass
class CaseResult:
    """Result of running a single eval case."""

    case_id: str
    case_name: str
    passed: bool
    duration_ms: float
    assertion_results: list[AssertionResult]
    output: Any = None
    error: str | None = None
    trace_id: str | None = None

    def __post_init__(self) -> None:
        """Validate fields."""
        if self.duration_ms < 0:
            raise ValueError("duration_ms must be non-negative")


@dataclass
class EvalRunResult:
    """Result of running an evaluation suite."""

    suite_name: str
    started_at: datetime
    completed_at: datetime
    total_cases: int
    passed_cases: int
    failed_cases: int
    pass_rate: float
    case_results: list[CaseResult] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Validate fields."""
        if self.total_cases < 0:
            raise ValueError("total_cases must be non-negative")
        if self.passed_cases < 0:
            raise ValueError("passed_cases must be non-negative")
        if self.failed_cases < 0:
            raise ValueError("failed_cases must be non-negative")
        if not 0.0 <= self.pass_rate <= 1.0:
            raise ValueError("pass_rate must be between 0.0 and 1.0")
        if self.passed_cases + self.failed_cases != self.total_cases:
            raise ValueError("passed_cases + failed_cases must equal total_cases")

    def summary(self) -> str:
        """Return human-readable summary of the evaluation run.

        Returns:
            Multi-line string with summary statistics and case results.
        """
        lines = [
            f"Evaluation Suite: {self.suite_name}",
            f"Started: {self.started_at.isoformat()}",
            f"Completed: {self.completed_at.isoformat()}",
            f"Duration: {(self.completed_at - self.started_at).total_seconds():.2f}s",
            "",
            f"Total Cases: {self.total_cases}",
            f"Passed: {self.passed_cases} ({self.pass_rate * 100:.1f}%)",
            f"Failed: {self.failed_cases}",
            "",
            "Case Results:",
        ]

        for result in self.case_results:
            status = "✓" if result.passed else "✗"
            lines.append(
                f"  {status} {result.case_name} ({result.duration_ms:.1f}ms)"
            )
            if not result.passed:
                # Show failed assertions
                for assertion in result.assertion_results:
                    if not assertion.passed:
                        lines.append(f"    - {assertion.message}")

        return "\n".join(lines)


class EvalRunner:
    """Runner for executing evaluation suites against AI agents.

    The runner executes test cases, runs assertions, captures traces,
    and aggregates results. Supports parallel execution with thread pools.

    Example:
        >>> from prela.evals import EvalSuite, EvalRunner
        >>> from prela import get_tracer
        >>>
        >>> suite = EvalSuite.from_yaml("tests.yaml")
        >>> tracer = get_tracer()
        >>>
        >>> def my_agent(input_data):
        ...     # Your agent logic here
        ...     return "agent output"
        >>>
        >>> runner = EvalRunner(suite, my_agent, tracer=tracer)
        >>> result = runner.run()
        >>> print(result.summary())
    """

    def __init__(
        self,
        suite: EvalSuite,
        agent: Callable[[EvalInput], Any],
        tracer: Tracer | None = None,
        parallel: bool = False,
        max_workers: int = 4,
        on_case_complete: Callable[[CaseResult], None] | None = None,
    ):
        """Initialize the evaluation runner.

        Args:
            suite: The evaluation suite to run.
            agent: Callable that takes an EvalInput and returns agent output.
            tracer: Optional tracer for capturing execution traces.
            parallel: Whether to run cases in parallel using a thread pool.
            max_workers: Maximum number of worker threads if parallel=True.
            on_case_complete: Optional callback invoked after each case completes.
        """
        self.suite = suite
        self.agent = agent
        self.tracer = tracer
        self.parallel = parallel
        self.max_workers = max_workers
        self.on_case_complete = on_case_complete

    def run(self) -> EvalRunResult:
        """Run all test cases in the evaluation suite.

        Executes setup/teardown hooks, runs all cases (sequentially or in parallel),
        executes assertions, and aggregates results.

        Returns:
            EvalRunResult with aggregated statistics and individual case results.
        """
        started_at = now()

        # Run setup hook if provided
        if self.suite.setup:
            try:
                self.suite.setup()
            except Exception as e:
                # If setup fails, fail the entire run
                return EvalRunResult(
                    suite_name=self.suite.name,
                    started_at=started_at,
                    completed_at=now(),
                    total_cases=len(self.suite.cases),
                    passed_cases=0,
                    failed_cases=len(self.suite.cases),
                    pass_rate=0.0,
                    case_results=[
                        CaseResult(
                            case_id=case.id,
                            case_name=case.name,
                            passed=False,
                            duration_ms=0.0,
                            assertion_results=[],
                            error=f"Setup failed: {str(e)}",
                        )
                        for case in self.suite.cases
                    ],
                )

        # Run all cases
        if self.parallel:
            case_results = self._run_parallel()
        else:
            case_results = self._run_sequential()

        # Run teardown hook if provided
        if self.suite.teardown:
            try:
                self.suite.teardown()
            except Exception as e:
                # Log teardown errors but don't fail the run
                # (results are already collected)
                pass

        completed_at = now()

        # Aggregate results
        passed_cases = sum(1 for r in case_results if r.passed)
        failed_cases = len(case_results) - passed_cases
        pass_rate = passed_cases / len(case_results) if case_results else 0.0

        return EvalRunResult(
            suite_name=self.suite.name,
            started_at=started_at,
            completed_at=completed_at,
            total_cases=len(case_results),
            passed_cases=passed_cases,
            failed_cases=failed_cases,
            pass_rate=pass_rate,
            case_results=case_results,
        )

    def _run_sequential(self) -> list[CaseResult]:
        """Run cases sequentially in the current thread."""
        results = []
        for case in self.suite.cases:
            result = self.run_case(case)
            results.append(result)
            if self.on_case_complete:
                try:
                    self.on_case_complete(result)
                except Exception:
                    # Don't let callback errors affect execution
                    pass
        return results

    def _run_parallel(self) -> list[CaseResult]:
        """Run cases in parallel using a thread pool.

        Note: Each case creates its own context via the tracer, so we don't
        need to propagate the parent context to worker threads.
        """
        results = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all cases directly (no context wrapping needed)
            # Each case will create its own trace context if tracer is configured
            future_to_case = {
                executor.submit(self.run_case, case): case
                for case in self.suite.cases
            }

            # Collect results as they complete
            for future in as_completed(future_to_case):
                try:
                    result = future.result()
                    results.append(result)
                    if self.on_case_complete:
                        try:
                            self.on_case_complete(result)
                        except Exception:
                            pass
                except Exception as e:
                    # If case execution fails catastrophically, create error result
                    case = future_to_case[future]
                    results.append(
                        CaseResult(
                            case_id=case.id,
                            case_name=case.name,
                            passed=False,
                            duration_ms=0.0,
                            assertion_results=[],
                            error=f"Execution failed: {str(e)}",
                        )
                    )

        return results

    def run_case(self, case: EvalCase) -> CaseResult:
        """Run a single test case.

        Executes the agent with the case input, runs all assertions,
        captures the trace ID if a tracer is configured, and returns
        aggregated results.

        Args:
            case: The test case to run.

        Returns:
            CaseResult with pass/fail status and assertion results.
        """
        start_time = time.perf_counter_ns()
        output = None
        error = None
        trace_id = None

        # Execute agent
        try:
            # Get agent input from case
            agent_input = case.input

            # If tracer is configured, wrap execution in a span
            if self.tracer:
                from prela.core.span import SpanType

                with self.tracer.span(
                    name=f"eval.case.{case.id}",
                    span_type=SpanType.AGENT,
                    attributes={
                        "eval.case_id": case.id,
                        "eval.case_name": case.name,
                        "eval.tags": ",".join(case.tags) if case.tags else "",
                    },
                ):
                    output = self.agent(agent_input)
                    # Capture trace_id after span is created
                    trace_id = get_current_trace_id()
            else:
                output = self.agent(agent_input)

        except Exception as e:
            error = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"

        duration_ms = (time.perf_counter_ns() - start_time) / 1_000_000

        # Run assertions
        assertion_results = []

        if error:
            # If agent execution failed, mark all assertions as failed
            # (can't run assertions without output)
            if case.expected:
                assertion_results.append(
                    AssertionResult(
                        assertion_type="execution",
                        passed=False,
                        message=f"Agent execution failed: {error}",
                    )
                )
        else:
            # Run assertions from case.expected (if provided)
            if case.expected:
                assertions = self._create_assertions_from_expected(case)
                for assertion in assertions:
                    try:
                        result = assertion.evaluate(output, case.expected, None)
                        assertion_results.append(result)
                    except Exception as e:
                        # If assertion itself crashes, mark as failed
                        assertion_results.append(
                            AssertionResult(
                                assertion_type=type(assertion).__name__,
                                passed=False,
                                message=f"Assertion error: {str(e)}",
                            )
                        )

            # Run assertions from case.assertions (if provided)
            if case.assertions:
                for assertion_config in case.assertions:
                    try:
                        assertion = create_assertion(assertion_config)
                        result = assertion.evaluate(output, case.expected, None)
                        assertion_results.append(result)
                    except Exception as e:
                        assertion_results.append(
                            AssertionResult(
                                assertion_type=assertion_config.get(
                                    "type", "unknown"
                                ),
                                passed=False,
                                message=f"Assertion error: {str(e)}",
                            )
                        )

            # Run default assertions from suite (if any)
            if self.suite.default_assertions:
                for assertion_config in self.suite.default_assertions:
                    try:
                        assertion = create_assertion(assertion_config)
                        result = assertion.evaluate(output, case.expected, None)
                        assertion_results.append(result)
                    except Exception as e:
                        assertion_results.append(
                            AssertionResult(
                                assertion_type=assertion_config.get(
                                    "type", "unknown"
                                ),
                                passed=False,
                                message=f"Default assertion error: {str(e)}",
                            )
                        )

        # Determine overall pass/fail
        passed = (not error) and all(r.passed for r in assertion_results)

        return CaseResult(
            case_id=case.id,
            case_name=case.name,
            passed=passed,
            duration_ms=duration_ms,
            assertion_results=assertion_results,
            output=output,
            error=error,
            trace_id=trace_id,
        )

    def _create_assertions_from_expected(
        self, case: EvalCase
    ) -> list[BaseAssertion]:
        """Create assertion objects from case.expected fields.

        This converts the EvalExpected fields (output, contains, not_contains, etc.)
        into actual assertion objects that can be run.

        Args:
            case: The eval case with expected output.

        Returns:
            List of assertion objects.
        """
        from .assertions.structural import ContainsAssertion, NotContainsAssertion

        assertions: list[BaseAssertion] = []

        if case.expected is None:
            return assertions

        # Exact output match
        if case.expected.output is not None:
            # For now, use contains with the exact string
            # TODO: Create a dedicated ExactMatchAssertion
            assertions.append(
                ContainsAssertion(text=case.expected.output, case_sensitive=True)
            )

        # Contains (all must be present)
        if case.expected.contains:
            for text in case.expected.contains:
                assertions.append(
                    ContainsAssertion(text=text, case_sensitive=True)
                )

        # Not contains (none must be present)
        if case.expected.not_contains:
            for text in case.expected.not_contains:
                assertions.append(
                    NotContainsAssertion(text=text, case_sensitive=True)
                )

        # TODO: Add tool_calls and metadata assertions when implemented

        return assertions


def create_assertion(config: dict) -> BaseAssertion:
    """Factory function to create assertion instances from configuration.

    This maps assertion type strings to concrete assertion classes and
    instantiates them with the provided configuration.

    Args:
        config: Dictionary with "type" key and type-specific parameters.

    Returns:
        Instantiated assertion object.

    Raises:
        ValueError: If assertion type is unknown or configuration is invalid.

    Example:
        >>> assertion = create_assertion({
        ...     "type": "contains",
        ...     "text": "hello",
        ...     "case_sensitive": False
        ... })
        >>> result = assertion.evaluate("Hello world", None, None)
        >>> assert result.passed
    """
    from .assertions.semantic import SemanticSimilarityAssertion
    from .assertions.structural import (
        ContainsAssertion,
        JSONValidAssertion,
        LengthAssertion,
        NotContainsAssertion,
        RegexAssertion,
    )
    from .assertions.tool import (
        ToolArgsAssertion,
        ToolCalledAssertion,
        ToolSequenceAssertion,
    )

    assertion_type = config.get("type")
    if not assertion_type:
        raise ValueError("Assertion config must have 'type' field")

    # Map type strings to classes
    registry: dict[str, type[BaseAssertion]] = {
        "contains": ContainsAssertion,
        "not_contains": NotContainsAssertion,
        "regex": RegexAssertion,
        "length": LengthAssertion,
        "json_valid": JSONValidAssertion,
        "semantic_similarity": SemanticSimilarityAssertion,
        "tool_called": ToolCalledAssertion,
        "tool_args": ToolArgsAssertion,
        "tool_sequence": ToolSequenceAssertion,
    }

    assertion_class = registry.get(assertion_type)
    if not assertion_class:
        raise ValueError(
            f"Unknown assertion type: {assertion_type}. "
            f"Available types: {', '.join(registry.keys())}"
        )

    # Extract parameters (everything except 'type')
    params = {k: v for k, v in config.items() if k != "type"}

    try:
        return assertion_class(**params)
    except TypeError as e:
        raise ValueError(
            f"Invalid parameters for {assertion_type} assertion: {str(e)}"
        ) from e
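
Taken together, the docstrings above imply the following end-to-end flow. This is an illustrative sketch derived from the examples in runner.py, not a file shipped in the wheel; the agent function and the "tests.yaml" path are hypothetical placeholders.

# Usage sketch based on the docstring examples in runner.py.
# echo_agent and "tests.yaml" are placeholders for illustration only.
from prela import get_tracer
from prela.evals import EvalRunner, EvalSuite

def echo_agent(eval_input):
    # Stand-in agent; a real agent would call an LLM or tool chain here.
    return f"echo: {eval_input}"

suite = EvalSuite.from_yaml("tests.yaml")   # cases, expected outputs, assertions
runner = EvalRunner(
    suite,
    echo_agent,
    tracer=get_tracer(),   # optional: wraps each case in an eval.case.<id> span
    parallel=True,         # run cases in a thread pool
    max_workers=4,
)
result = runner.run()
print(result.summary())    # per-case pass/fail, with failed assertion messages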