aicert-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aicert/__init__.py +3 -0
- aicert/__main__.py +6 -0
- aicert/artifacts.py +104 -0
- aicert/cli.py +1423 -0
- aicert/config.py +193 -0
- aicert/doctor.py +366 -0
- aicert/hashing.py +28 -0
- aicert/metrics.py +305 -0
- aicert/providers/__init__.py +13 -0
- aicert/providers/anthropic.py +182 -0
- aicert/providers/base.py +36 -0
- aicert/providers/openai.py +153 -0
- aicert/providers/openai_compatible.py +152 -0
- aicert/runner.py +620 -0
- aicert/templating.py +83 -0
- aicert/validation.py +322 -0
- aicert-0.1.0.dist-info/METADATA +306 -0
- aicert-0.1.0.dist-info/RECORD +22 -0
- aicert-0.1.0.dist-info/WHEEL +5 -0
- aicert-0.1.0.dist-info/entry_points.txt +2 -0
- aicert-0.1.0.dist-info/licenses/LICENSE +21 -0
- aicert-0.1.0.dist-info/top_level.txt +1 -0
aicert/runner.py
ADDED
@@ -0,0 +1,620 @@
"""Async runner utilities for aicert."""

import asyncio
import json
import random
import time
from typing import Any, Dict, List, Optional

from aicert.artifacts import append_result, create_run_dir
from aicert.config import Config, ProviderConfig, ChaosConfig
from aicert.providers.base import BaseProvider
from aicert.templating import build_schema_hint, render_prompt
from aicert.validation import validate_output


class RetriableError(Exception):
    """Exception for retriable errors (429, 5xx, etc.)."""
    pass


class ProviderError(Exception):
    """Exception for provider errors (429, 500, etc.)."""

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(f"Provider error {status_code}: {message}")


class FakeAdapter(BaseProvider):
    """Fake adapter for testing the runner end-to-end.

    Returns deterministic JSON output based on prompt content.
    By default, produces stable output with no errors for testing.
    When chaos mode is enabled (via chaos config or AICERT_FAKE_CHAOS env var),
    simulates various failure modes including invalid JSON, schema drift,
    timeouts, and HTTP errors.
    """

    def __init__(
        self,
        model: str = "fake-model",
        error_rate: float = 0.0,
        latency_ms: int = 10,
        seed: Optional[int] = None,
        chaos: Optional[ChaosConfig] = None,
    ):
        super().__init__(model=model)
        self.error_rate = error_rate
        self.latency_ms = latency_ms
        self.call_count = 0

        # Chaos mode: enabled via config or environment variable
        self.chaos_enabled = chaos is not None
        self.chaos = chaos or ChaosConfig()

        # Set random seed for reproducibility if provided
        if seed is not None:
            random.seed(seed)
        elif self.chaos_enabled:
            # Use chaos config seed for reproducible chaos
            random.seed(self.chaos.seed)

    def _generate_base_response(self, prompt: str) -> Dict[str, Any]:
        """Generate the base deterministic response based on prompt content."""
        # Generate deterministic output based on prompt
        # These mappings ensure stable output for testing
        if "2 + 2" in prompt or "What is 2 + 2" in prompt:
            answer = "4"
            greeting = "Hello!"
        elif "capital of France" in prompt:
            answer = "Paris"
            greeting = "Hi there!"
        elif "How many planets" in prompt:
            answer = "8"
            greeting = "Greetings!"
        else:
            answer = "test answer"
            greeting = "Hello!"

        confidence = 0.95

        return {
            "greeting": greeting,
            "answer": answer,
            "confidence": confidence,
        }

    def _apply_chaos(self, prompt: str, base_response: Dict[str, Any]) -> tuple[str, Optional[Exception]]:
        """Apply chaos transformations to the response.

        Returns:
            Tuple of (content, exception) where exception is not None if an error should be raised
        """
        chaos = self.chaos

        # Roll for timeout (raises asyncio.TimeoutError)
        if random.random() < chaos.p_timeout:
            return "", asyncio.TimeoutError("Chaos timeout")

        # Roll for HTTP 429
        if random.random() < chaos.p_http_429:
            return "", ProviderError(429, "Rate limited - chaos mode")

        # Roll for HTTP 500
        if random.random() < chaos.p_http_500:
            return "", ProviderError(500, "Internal server error - chaos mode")

        # Roll for non-JSON response
        if random.random() < chaos.p_non_json:
            return "This is just plain text, not JSON at all!", None

        # Roll for wrapped JSON (markdown fence with json language)
        if random.random() < chaos.p_wrapped_json:
            json_str = json.dumps(base_response)
            return f"```json\n{json_str}\n```\n\nHere is your response!", None

        # Roll for extra keys (valid JSON but with extra fields)
        if random.random() < chaos.p_extra_keys:
            response = base_response.copy()
            response["extra_field_1"] = "should not be here"
            response["extra_field_2"] = 12345
            return json.dumps(response), None

        # Roll for wrong schema (missing required fields or wrong types)
        if random.random() < chaos.p_wrong_schema:
            # Create a response missing required fields
            wrong_response = {
                "greeting": "Hi",
                # Missing 'answer' field
                # 'confidence' is string instead of number
                "confidence": "high",
            }
            return json.dumps(wrong_response), None

        # Roll for invalid JSON
        if random.random() < chaos.p_invalid_json:
            valid_json = json.dumps(base_response)
            # Truncate to make it invalid JSON
            truncated = valid_json[:-10] if len(valid_json) > 10 else "{"
            return truncated, None

        # Return valid JSON (normal case)
        return json.dumps(base_response), None

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a fake response with optional chaos mode."""
        self.call_count += 1

        # Calculate latency: simulate realistic distribution (100-2000ms) if chaos enabled
        if self.chaos_enabled:
            # Use uniform distribution in range, or exponential for more realism
            latency = random.uniform(100, 2000) / 1000.0  # Convert to seconds
        else:
            latency = self.latency_ms / 1000.0

        await asyncio.sleep(latency)

        # Simulate occasional retriable errors in normal mode (only if error_rate > 0)
        if not self.chaos_enabled and self.error_rate > 0 and random.random() < self.error_rate:
            error_type = random.choice([429, 500, 502, 503])
            raise RetriableError(f"Simulated error {error_type}")

        # Generate base response
        base_response = self._generate_base_response(prompt)

        # Apply chaos if enabled
        if self.chaos_enabled:
            content, exception = self._apply_chaos(prompt, base_response)

            # Raise exception if chaos generated one
            if exception is not None:
                raise exception
        else:
            content = json.dumps(base_response)

        return {
            "choices": [
                {
                    "message": {
                        "content": content
                    }
                }
            ],
            "usage": {
                "prompt_tokens": len(prompt.split()),
                "completion_tokens": 50,
            }
        }

    async def generate_stream(self, prompt: str, **kwargs):
        """Streaming not supported for FakeAdapter."""
        raise NotImplementedError("Streaming not supported for FakeAdapter")

    @property
    def provider_type(self) -> str:
        return "fake"


def create_provider_adapter(config: ProviderConfig) -> BaseProvider:
    """Create a provider adapter from config.

    Args:
        config: Provider configuration.

    Returns:
        Initialized provider adapter.
    """
    from aicert.providers import OpenAIProvider, AnthropicProvider, OpenAICompatibleProvider

    provider_map = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "openai_compatible": OpenAICompatibleProvider,
    }

    provider_class = provider_map.get(config.provider)
    if provider_class is None:
        raise ValueError(f"Unknown provider type: {config.provider}")

    return provider_class(
        model=config.model,
        api_key=None,  # Will be loaded from environment
        base_url=config.base_url,
        temperature=config.temperature,
    )


async def run_single_with_retry(
    adapter: BaseProvider,
    prompt: str,
    timeout_s: int,
    max_retries: int = 3,
) -> Dict[str, Any]:
    """Run a single prompt with timeout and retries.

    Args:
        adapter: Provider adapter.
        prompt: The prompt to send.
        timeout_s: Timeout in seconds.
        max_retries: Maximum retry attempts.

    Returns:
        Response dict with timing and result info.

    Raises:
        RetriableError: After exhausting retries.
    """
    base_delay = 0.5  # seconds

    for attempt in range(max_retries + 1):
        try:
            start_time = time.perf_counter()
            response = await asyncio.wait_for(
                adapter.generate(prompt),
                timeout=timeout_s
            )
            latency_ms = (time.perf_counter() - start_time) * 1000

            # Extract content and token usage
            content = response.get("choices", [{}])[0].get("message", {}).get("content", "")
            usage = response.get("usage", {})

            return {
                "ok": True,
                "content": content,
                "latency_ms": latency_ms,
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "error": None,
                "attempt": attempt,
            }
        except asyncio.TimeoutError:
            elapsed = time.perf_counter() - start_time
            if attempt < max_retries:
                # Exponential backoff with jitter
                delay = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
                await asyncio.sleep(delay)
                continue
            else:
                return {
                    "ok": False,
                    "content": "",
                    "latency_ms": elapsed * 1000,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "error": f"Timeout after {timeout_s}s",
                    "attempt": attempt,
                }
        except ProviderError as e:
            # Handle provider errors (429, 500, etc.) - these are retriable
            if attempt < max_retries:
                # Exponential backoff with jitter
                delay = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
                await asyncio.sleep(delay)
                continue
            else:
                return {
                    "ok": False,
                    "content": "",
                    "latency_ms": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "error": str(e),
                    "attempt": attempt,
                }
        except RetriableError as e:
            if attempt < max_retries:
                # Exponential backoff with jitter
                delay = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
                await asyncio.sleep(delay)
                continue
            else:
                return {
                    "ok": False,
                    "content": "",
                    "latency_ms": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "error": str(e),
                    "attempt": attempt,
                }
        except Exception as e:
            # Non-retriable error
            return {
                "ok": False,
                "content": "",
                "latency_ms": 0,
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "error": str(e),
                "attempt": attempt,
            }

    # Should not reach here
    return {
        "ok": False,
        "content": "",
        "latency_ms": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "error": "Max retries exceeded",
        "attempt": max_retries,
    }


async def execute_case(
    adapter: BaseProvider,
    case: Dict[str, Any],
    case_id: str,
    schema: Dict[str, Any],
    schema_hint: str,
    config: Config,
    run_index: int,
    semaphore: asyncio.Semaphore,
) -> Dict[str, Any]:
    """Execute a single test case with concurrency control.

    Args:
        adapter: Provider adapter.
        case: Test case dict.
        case_id: Unique case identifier.
        schema: JSON schema for validation.
        schema_hint: Compact schema hint string.
        config: Main config.
        run_index: Index of this run (0-based).
        semaphore: Concurrency semaphore.

    Returns:
        Result dict for this execution.
    """
    async with semaphore:
        # Render prompt
        prompt_template = case.get("prompt", "")
        variables = case.get("variables", {})

        try:
            rendered_prompt = render_prompt(
                template=prompt_template,
                case=variables,
                schema_hint=schema_hint,
                case_id=case_id,
            )
        except Exception as e:
            return {
                "provider_id": adapter.model,
                "case_id": case_id,
                "run_index": run_index,
                "ok_json": False,
                "ok_schema": False,
                "extra_keys": [],
                "output_json": None,
                "latency_ms": 0,
                "cost_usd": 0,
                "error": f"Prompt rendering error: {str(e)}",
            }

        # Call adapter with timeout and retries
        result = await run_single_with_retry(
            adapter=adapter,
            prompt=rendered_prompt,
            timeout_s=config.timeout_s,
            max_retries=3,
        )

        if not result["ok"]:
            return {
                "provider_id": adapter.model,
                "case_id": case_id,
                "run_index": run_index,
                "ok_json": False,
                "ok_schema": False,
                "extra_keys": [],
                "output_json": None,
                "latency_ms": result.get("latency_ms", 0),
                "cost_usd": 0,
                "error": result.get("error", "Unknown error"),
            }

        # Validate output
        validation_result = validate_output(
            text=result["content"],
            schema=schema,
            extract_json=config.validation.extract_json,
            allow_extra_keys=config.validation.allow_extra_keys,
        )

        # Calculate cost (simplified - based on tokens)
        cost_usd = estimate_cost(
            prompt_tokens=result.get("prompt_tokens", 0),
            completion_tokens=result.get("completion_tokens", 0),
            model=adapter.model,
        )

        return {
            "provider_id": adapter.model,
            "case_id": case_id,
            "run_index": run_index,
            "ok_json": validation_result.ok_json,
            "ok_schema": validation_result.ok_schema,
            "extra_keys": validation_result.extra_keys,
            "output_json": validation_result.parsed,
            "latency_ms": result["latency_ms"],
            "cost_usd": cost_usd,
            "error": validation_result.error,
        }


def estimate_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    """Estimate cost in USD based on token counts.

    This is a simplified estimator. In production, you'd use
    actual provider pricing.
    """
    # Simplified pricing (placeholder values)
    pricing = {
        "gpt-4": {"prompt": 0.00003, "completion": 0.00006},
        "gpt-4o": {"prompt": 0.000005, "completion": 0.000015},
        "gpt-3.5-turbo": {"prompt": 0.0000015, "completion": 0.000002},
        "claude-3-opus-20240229": {"prompt": 0.000015, "completion": 0.000075},
        "claude-3-sonnet-20240229": {"prompt": 0.000003, "completion": 0.000015},
        "fake-model": {"prompt": 0.0, "completion": 0.0},
    }

    # Try exact match first, then fall back to basic estimates
    if model in pricing:
        rates = pricing[model]
    elif "gpt-4" in model:
        rates = pricing["gpt-4"]
    elif "claude" in model:
        rates = pricing["claude-3-sonnet-20240229"]
    else:
        # Default to zero cost for unknown models
        rates = {"prompt": 0.0, "completion": 0.0}

    return (prompt_tokens * rates["prompt"]) + (completion_tokens * rates["completion"])


async def run_provider_cases(
    adapter: BaseProvider,
    cases: List[Dict[str, Any]],
    schema: Dict[str, Any],
    schema_hint: str,
    config: Config,
) -> List[Dict[str, Any]]:
    """Run all cases for a single provider.

    Args:
        adapter: Provider adapter.
        cases: List of test cases.
        schema: JSON schema for validation.
        schema_hint: Compact schema hint string.
        config: Main config.

    Returns:
        List of result dicts.
    """
    results = []
    semaphore = asyncio.Semaphore(config.concurrency)

    # Create tasks for all runs
    tasks = []
    for case in cases:
        case_id = case.get("name", case.get("id", str(cases.index(case))))
        for run_idx in range(config.runs):
            task = execute_case(
                adapter=adapter,
                case=case,
                case_id=case_id,
                schema=schema,
                schema_hint=schema_hint,
                config=config,
                run_index=run_idx,
                semaphore=semaphore,
            )
            tasks.append(task)

    # Execute all tasks concurrently
    results = await asyncio.gather(*tasks)

    return results


def load_cases(cases_file: str) -> List[Dict[str, Any]]:
    """Load test cases from JSONL file.

    Args:
        cases_file: Path to JSONL file.

    Returns:
        List of case dicts.
    """
    cases = []
    with open(cases_file, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                cases.append(json.loads(line))
    return cases


def load_schema(schema_file: str) -> Dict[str, Any]:
    """Load JSON schema from file.

    Args:
        schema_file: Path to schema file.

    Returns:
        Schema dict.
    """
    import yaml
    with open(schema_file, "r") as f:
        return yaml.safe_load(f)


async def run_suite(config: Config, output_dir: Optional[str] = None) -> List[Dict[str, Any]]:
    """Run the test suite for all providers and cases.

    Args:
        config: Aicert configuration.
        output_dir: Optional output directory for results.

    Returns:
        List of all execution result dicts.
    """
    from pathlib import Path

    # Create run directory
    run_dir = create_run_dir(output_dir)

    # Resolve paths relative to config file location if possible
    config_path = getattr(config, '_config_path', None)
    if config_path:
        config_dir = Path(config_path).parent
    else:
        config_dir = Path.cwd()

    # Load cases and schema (resolve relative to config directory)
    cases_file = config_dir / config.cases_file if not Path(config.cases_file).is_absolute() else config.cases_file
    schema_file = config_dir / config.schema_file if not Path(config.schema_file).is_absolute() else config.schema_file

    cases = load_cases(str(cases_file))
    schema = load_schema(str(schema_file))
    schema_hint = build_schema_hint(schema)

    all_results = []

    # Run each provider
    for provider_config in config.providers:
        # Create adapter (use FakeAdapter if specified, otherwise real adapter)
        if provider_config.provider == "fake":
            # Pass chaos config if available
            adapter = FakeAdapter(
                model=provider_config.id,
                chaos=provider_config.chaos,
            )
        else:
            adapter = create_provider_adapter(provider_config)

        try:
            # Run all cases for this provider
            provider_results = await run_provider_cases(
                adapter=adapter,
                cases=cases,
                schema=schema,
                schema_hint=schema_hint,
                config=config,
            )
        finally:
            # Close the adapter if it has a close method
            if hasattr(adapter, 'close'):
                await adapter.close()

        # Append results to artifacts
        for result in provider_results:
            append_result(run_dir, result)

        all_results.extend(provider_results)

    return all_results
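A minimal usage sketch for this module (illustrative, not part of the package diff): it drives the FakeAdapter through run_single_with_retry, the same call path run_suite takes for each case. It assumes aicert 0.1.0 is installed and importable; the prompt text and latency value are made up for the example.

# Hypothetical usage sketch; assumes `pip install aicert==0.1.0`.
import asyncio

from aicert.runner import FakeAdapter, run_single_with_retry


async def main() -> None:
    # No chaos config, so the adapter returns stable, deterministic JSON.
    adapter = FakeAdapter(model="fake-model", latency_ms=5)
    result = await run_single_with_retry(
        adapter=adapter,
        prompt="What is 2 + 2?",
        timeout_s=10,
        max_retries=3,
    )
    # "ok" reports success; "content" holds the raw JSON string from the adapter.
    print(result["ok"], result["content"])


asyncio.run(main())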
aicert/templating.py
ADDED
@@ -0,0 +1,83 @@
"""Templating utilities for aicert."""

import re
from string import Template
from typing import Any, Dict, Optional


def render_template(template: str, variables: Dict[str, Any]) -> str:
    """Render a template string with variables."""
    t = Template(template)
    return t.substitute(**variables)


def render_template_file(path: str, variables: Dict[str, Any]) -> str:
    """Render a template file with variables."""
    with open(path, "r") as f:
        template = f.read()
    return render_template(template, variables)


def build_schema_hint(schema: dict) -> str:
    """Build a deterministic, compact schema hint from a JSON schema.

    Includes required keys and their types (shallow extraction).
    Output is deterministic and compact.
    """
    if not isinstance(schema, dict):
        return ""

    required = schema.get("required", [])
    properties = schema.get("properties", {})

    # Build hint for required fields with their types
    hints = []
    for key in sorted(required):  # sorted for determinism
        prop = properties.get(key, {})
        prop_type = prop.get("type", "any")
        hints.append(f"{key}: {prop_type}")

    return " | ".join(hints) if hints else ""


def render_prompt(
    template: str,
    case: dict,
    schema_hint: str,
    case_id: Optional[str] = None
) -> str:
    """Render a prompt template with placeholders.

    Supports {{ key }} placeholders with optional whitespace.
    Injects schema_hint automatically.
    Raises an error if any placeholder key is missing.
    """
    # Build variables dict from case, including schema_hint
    variables = dict(case)
    if schema_hint:
        variables["schema_hint"] = schema_hint

    # Pattern to match {{ key }} with optional whitespace
    pattern = re.compile(r'\{\{\s*(\w+)\s*\}\}')

    # Find all placeholders and check for missing keys
    missing_keys = []
    for match in pattern.finditer(template):
        key = match.group(1)
        if key not in variables:
            missing_keys.append(key)

    if missing_keys:
        case_info = f" (case_id: {case_id})" if case_id else ""
        missing_str = ", ".join(missing_keys)
        raise ValueError(
            f"Missing placeholder value(s) in template: {missing_str}{case_info}"
        )

    # Replace all placeholders
    result = pattern.sub(
        lambda m: str(variables[m.group(1)]),
        template
    )

    return result
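A short sketch of the templating helpers (illustrative, not part of the package diff): build_schema_hint flattens a schema's required keys into a compact type hint, and render_prompt substitutes {{ key }} placeholders, raising ValueError if any key is missing. The schema and template below are invented for the example; only the function signatures come from the file above.

# Hypothetical usage sketch; assumes `pip install aicert==0.1.0`.
from aicert.templating import build_schema_hint, render_prompt

schema = {
    "required": ["answer", "confidence", "greeting"],
    "properties": {
        "answer": {"type": "string"},
        "confidence": {"type": "number"},
        "greeting": {"type": "string"},
    },
}

# Deterministic, sorted output: "answer: string | confidence: number | greeting: string"
hint = build_schema_hint(schema)

prompt = render_prompt(
    template="Q: {{ question }}\nReply as JSON with fields: {{ schema_hint }}",
    case={"question": "What is the capital of France?"},
    schema_hint=hint,
    case_id="capital-fr",
)
print(prompt)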