aicert 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aicert/__init__.py +3 -0
- aicert/__main__.py +6 -0
- aicert/artifacts.py +104 -0
- aicert/cli.py +1423 -0
- aicert/config.py +193 -0
- aicert/doctor.py +366 -0
- aicert/hashing.py +28 -0
- aicert/metrics.py +305 -0
- aicert/providers/__init__.py +13 -0
- aicert/providers/anthropic.py +182 -0
- aicert/providers/base.py +36 -0
- aicert/providers/openai.py +153 -0
- aicert/providers/openai_compatible.py +152 -0
- aicert/runner.py +620 -0
- aicert/templating.py +83 -0
- aicert/validation.py +322 -0
- aicert-0.1.0.dist-info/METADATA +306 -0
- aicert-0.1.0.dist-info/RECORD +22 -0
- aicert-0.1.0.dist-info/WHEEL +5 -0
- aicert-0.1.0.dist-info/entry_points.txt +2 -0
- aicert-0.1.0.dist-info/licenses/LICENSE +21 -0
- aicert-0.1.0.dist-info/top_level.txt +1 -0
aicert/metrics.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""Metrics utilities for aicert."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import re
|
|
5
|
+
import statistics
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def clamp(value: float, min_val: float, max_val: float) -> float:
    """Bound *value* to the closed range [min_val, max_val].

    The upper bound is applied first and the lower bound second, so when
    ``min_val > max_val`` the result is ``min_val``.
    """
    capped_above = min(value, max_val)
    return max(min_val, capped_above)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def canonicalize_json(obj: Any) -> Any:
    """Return an order-independent form of a parsed JSON value.

    Dicts become a list of ``(key, canonical_value)`` tuples sorted by key,
    lists are canonicalized element by element (order kept), and every other
    value is returned unchanged.
    """
    if isinstance(obj, list):
        return [canonicalize_json(element) for element in obj]
    if isinstance(obj, dict):
        pairs = [(key, canonicalize_json(value)) for key, value in obj.items()]
        pairs.sort()
        return pairs
    return obj
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def stringify_compact(obj: Any) -> str:
    """Serialize *obj* to JSON with no whitespace after separators."""
    compact_separators = (',', ':')
    return json.dumps(obj, separators=compact_separators)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def tokenize(text: str) -> set:
    """Return the set of lowercase word tokens found in *text*.

    A token is a maximal run of regex word characters (letters, digits,
    underscore); punctuation and JSON syntax are discarded.
    """
    lowered = text.lower()
    return {token for token in re.findall(r'\w+', lowered)}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def jaccard_similarity(set1: set, set2: set) -> float:
    """Return |set1 ∩ set2| / |set1 ∪ set2|.

    Two empty sets are defined as identical (1.0).
    """
    if not (set1 or set2):
        return 1.0
    union_size = len(set1 | set2)
    if union_size <= 0:
        return 0.0
    return len(set1 & set2) / union_size
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def compute_similarity(outputs: List[Optional[Dict]]) -> float:
    """Score (0-100) how similar the valid outputs are to one another.

    Each non-``None`` output is canonicalized (dict keys sorted recursively),
    serialized compactly and split into a set of lowercase word tokens. The
    first valid output is the baseline; the score is the mean Jaccard
    similarity of every *other* valid output against that baseline, times 100.

    Args:
        outputs: Parsed JSON outputs; ``None`` entries (failed runs) are skipped.

    Returns:
        0.0 when there are no valid outputs, 100.0 when there is exactly one
        (nothing to disagree with), otherwise the mean similarity in [0, 100].
    """
    valid_outputs = [o for o in outputs if o is not None]
    if not valid_outputs:
        return 0.0

    token_sets = [tokenize(stringify_compact(canonicalize_json(o))) for o in valid_outputs]
    baseline_tokens, *others = token_sets
    if not others:
        return 100.0

    # Compare only the *other* outputs with the baseline: counting the
    # baseline's trivial self-match (always 1.0) would inflate the mean and
    # put a floor of 100/n under the score (two disjoint outputs scored 50).
    similarities = [jaccard_similarity(baseline_tokens, ts) for ts in others]
    return sum(similarities) / len(similarities) * 100
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def compute_structural_consistency(outputs: List[Optional[Dict]], required_keys: List[str]) -> float:
    """Score (0-100) how consistently outputs contain the required keys.

    For each required key, the fraction of non-``None`` outputs that are
    dicts containing that key is computed; the score is the mean of those
    fractions times 100. Returns 100.0 when ``required_keys`` is empty and
    0.0 when no output is available.
    """
    if not required_keys:
        return 100.0

    candidates = [output for output in outputs if output is not None]
    if not candidates:
        return 0.0

    denominator = len(candidates)
    fractions = []
    for key in required_keys:
        hits = 0
        for output in candidates:
            if isinstance(output, dict) and key in output:
                hits += 1
        fractions.append(hits / denominator)

    # Reachable when required_keys is a truthy but empty iterable (e.g. a generator).
    if not fractions:
        return 0.0
    return (sum(fractions) / len(fractions)) * 100
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def compute_latency_stats(latencies: List[float]) -> Dict[str, float]:
    """Summarize latencies as mean, 95th percentile and sample std deviation.

    p95 is the sorted value at index ``int(n * 0.95)`` (clamped to the last
    element). std is ``statistics.stdev`` (sample), or 0.0 for fewer than two
    values. An empty input yields all zeros.
    """
    if not latencies:
        return {"mean": 0.0, "p95": 0.0, "std": 0.0}

    average = statistics.mean(latencies)

    ordered = sorted(latencies)
    last_index = len(ordered) - 1
    percentile_95 = ordered[min(int(len(ordered) * 0.95), last_index)]

    if len(latencies) > 1:
        spread = statistics.stdev(latencies)
    else:
        spread = 0.0

    return {"mean": average, "p95": percentile_95, "std": spread}
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def compute_latency_stability(latency_stats: Dict[str, float]) -> float:
    """Convert latency statistics into a 0-100 stability score.

    The score is ``100 * (1 - std / mean)``, i.e. one minus the coefficient
    of variation, clamped to [0, 100]. A missing or non-positive mean gives 0.0.
    """
    mean = latency_stats.get("mean", 0)
    std = latency_stats.get("std", 0)

    if mean <= 0:
        return 0.0

    coefficient_of_variation = std / mean
    return clamp(100 * (1 - coefficient_of_variation), 0, 100)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def compute_stability_score(
    compliance: float,
    structural: float,
    similarity: float,
    latency_stability: float
) -> float:
    """Blend the component scores into the final stability score.

    Weights: schema compliance 40%, structural consistency 25%, similarity
    25%, latency stability 10%. With 0-100 inputs the result is 0-100.
    """
    compliance_part = compliance * 0.40
    structural_part = structural * 0.25
    similarity_part = similarity * 0.25
    latency_part = latency_stability * 0.10
    return compliance_part + structural_part + similarity_part + latency_part
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _percent(count: int, total: int) -> float:
    """Return ``count / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return (count / total * 100) if total > 0 else 0.0


def _numeric_field(results: List[Dict[str, Any]], field: str) -> List[Any]:
    """Collect ``field`` from every result, using 0 when it is missing or None."""
    values = []
    for result in results:
        value = result.get(field)
        values.append(0 if value is None else value)
    return values


def _count_failures(results: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count failure categories across ``results``.

    Returns a dict (in this key order) with:
      - ``json_parse_failures``: ``ok_json`` falsy;
      - ``schema_failures``: ``ok_json`` truthy but ``ok_schema`` falsy;
      - ``provider_errors``: error text contains "429", "500" or "Provider error";
      - ``timeouts``: error text contains "Timeout".
    One result may count as both a provider error and a timeout.
    """
    counts = {
        "json_parse_failures": 0,
        "schema_failures": 0,
        "provider_errors": 0,
        "timeouts": 0,
    }
    for result in results:
        if not result.get("ok_json", False):
            counts["json_parse_failures"] += 1
        elif not result.get("ok_schema", False):
            counts["schema_failures"] += 1

        error = result.get("error")
        if not error:
            continue
        # Stringify non-string errors so the substring checks cannot misbehave.
        text = error if isinstance(error, str) else str(error)
        if any(marker in text for marker in ("429", "500", "Provider error")):
            counts["provider_errors"] += 1
        if "Timeout" in text:
            counts["timeouts"] += 1
    return counts


def _group_metrics(results: List[Dict[str, Any]], required_keys: List[str]) -> Dict[str, Any]:
    """Compute the metrics shared by the per-provider and overall summaries.

    The returned key order is: total_runs, ok_json_count, ok_schema_count,
    the ``_count_failures`` keys, json_parse_rate, schema_compliance,
    structural_consistency, similarity, latency_stats, latency_stability,
    total_cost_usd.
    """
    total_runs = len(results)
    ok_json_count = sum(1 for r in results if r.get("ok_json", False))
    ok_schema_count = sum(1 for r in results if r.get("ok_schema", False))
    outputs = [r.get("output_json") for r in results]
    latency_stats = compute_latency_stats(_numeric_field(results, "latency_ms"))
    return {
        "total_runs": total_runs,
        "ok_json_count": ok_json_count,
        "ok_schema_count": ok_schema_count,
        **_count_failures(results),
        "json_parse_rate": _percent(ok_json_count, total_runs),
        "schema_compliance": _percent(ok_schema_count, total_runs),
        "structural_consistency": compute_structural_consistency(outputs, required_keys),
        "similarity": compute_similarity(outputs),
        "latency_stats": latency_stats,
        "latency_stability": compute_latency_stability(latency_stats),
        "total_cost_usd": sum(_numeric_field(results, "cost_usd")),
    }


def _stability_from(metrics: Dict[str, Any]) -> float:
    """Apply ``compute_stability_score`` to a ``_group_metrics`` result."""
    return compute_stability_score(
        metrics["schema_compliance"],
        metrics["structural_consistency"],
        metrics["similarity"],
        metrics["latency_stability"],
    )


def compute_summary(
    results: List[Dict[str, Any]],
    schema: Dict[str, Any],
    prompt_hash: Optional[str] = None,
    schema_hash: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Compute metrics summary from execution results.

    Args:
        results: List of execution result dicts with fields:
            provider_id, case_id, ok_json, ok_schema, extra_keys,
            latency_ms, cost_usd, output_json (parsed JSON when ok_json), error.
            Missing or None ``latency_ms`` / ``cost_usd`` count as 0.
            Any iterable is accepted; it is consumed exactly once.
        schema: JSON schema dict; its top-level ``required`` list drives the
            structural-consistency metric.
        prompt_hash: Optional SHA-256 hash of the prompt file
        schema_hash: Optional SHA-256 hash of the schema file

    Returns:
        Dict with ``prompt_hash``, ``schema_hash``, ``per_provider`` (metrics
        keyed by provider_id, "unknown" when absent) and ``overall``.
    """
    # Materialize once so a one-shot iterable feeds both the grouping and the
    # overall section (iterating ``results`` twice left ``overall`` empty).
    all_results = list(results)

    providers: Dict[str, List[Dict[str, Any]]] = {}
    for result in all_results:
        providers.setdefault(result.get("provider_id", "unknown"), []).append(result)

    required_keys = schema.get("required", []) if schema else []

    per_provider: Dict[str, Dict[str, Any]] = {}
    for provider_id, provider_results in providers.items():
        metrics = _group_metrics(provider_results, required_keys)
        # Per-provider layout places stability_score before total_cost_usd.
        total_cost = metrics.pop("total_cost_usd")
        per_provider[provider_id] = {
            "prompt_hash": prompt_hash,
            "schema_hash": schema_hash,
            **metrics,
            "stability_score": _stability_from(metrics),
            "total_cost_usd": total_cost,
        }

    metrics = _group_metrics(all_results, required_keys)
    overall: Dict[str, Any] = {
        "total_runs": metrics["total_runs"],
        "providers_count": len(providers),
    }
    # The overall section omits the raw ok_* counts; keep the remaining order.
    for key, value in metrics.items():
        if key not in ("total_runs", "ok_json_count", "ok_schema_count"):
            overall[key] = value
    overall["stability_score"] = _stability_from(metrics)

    return {
        "prompt_hash": prompt_hash,
        "schema_hash": schema_hash,
        "per_provider": per_provider,
        "overall": overall,
    }
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class Metrics:
    """Container for validation metrics: a running pass/fail tally."""

    def __init__(self) -> None:
        self.total: int = 0
        self.passed: int = 0
        self.failed: int = 0
        # One {"error": message} entry per failed result that carried a message.
        self.errors: List[Dict[str, Any]] = []

    def add_result(self, passed: bool, error: Optional[str] = None) -> None:
        """Record one validation outcome.

        Args:
            passed: Whether the validation succeeded.
            error: Optional failure message; recorded only for failed results
                and only when non-empty.
        """
        self.total += 1
        if passed:
            self.passed += 1
            return
        self.failed += 1
        if error:
            self.errors.append({"error": error})

    @property
    def success_rate(self) -> float:
        """Return passed / total as a fraction in [0, 1] (0.0 when empty)."""
        if self.total == 0:
            return 0.0
        return self.passed / self.total

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the metrics.

        ``errors`` is a copy of the list, so mutating the snapshot does not
        alter this object's state.
        """
        return {
            "total": self.total,
            "passed": self.passed,
            "failed": self.failed,
            "success_rate": self.success_rate,
            "errors": list(self.errors),
        }
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""LLM providers for aicert."""
|
|
2
|
+
|
|
3
|
+
from aicert.providers.base import BaseProvider
|
|
4
|
+
from aicert.providers.openai import OpenAIProvider
|
|
5
|
+
from aicert.providers.anthropic import AnthropicProvider
|
|
6
|
+
from aicert.providers.openai_compatible import OpenAICompatibleProvider
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"BaseProvider",
|
|
10
|
+
"OpenAIProvider",
|
|
11
|
+
"AnthropicProvider",
|
|
12
|
+
"OpenAICompatibleProvider",
|
|
13
|
+
]
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Anthropic provider for LLM API calls."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from aicert.providers.base import BaseProvider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AnthropicProvider(BaseProvider):
    """Anthropic provider implementation using the Messages API.

    Non-streaming responses are reshaped by ``_transform_response`` into an
    OpenAI-style ``{"choices": ..., "usage": ..., "raw": ...}`` dict.
    """

    DEFAULT_BASE_URL = "https://api.anthropic.com"
    API_KEY_ENV = "ANTHROPIC_API_KEY"
    API_VERSION = "2023-06-01"
    # Value sent as ``max_tokens`` with every request.
    DEFAULT_MAX_TOKENS = 4096
    # HTTP statuses reported as RetriableError (rate limit / transient server errors).
    RETRIABLE_STATUS_CODES = (429, 500, 502, 503, 504)

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.7,
        **kwargs,
    ):
        """Create the provider.

        Args:
            model: Model name sent with every request.
            api_key: Explicit API key; when None, ``ANTHROPIC_API_KEY`` is read
                lazily on first access of ``api_key``.
            base_url: API root; None means ``DEFAULT_BASE_URL``.
            temperature: Sampling temperature sent with every request.
            **kwargs: Extra options, stored on ``self.kwargs`` by the base class.
        """
        # BaseProvider.__init__ assigns self.api_key and self.base_url, which
        # go through the property setters defined below.
        super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
        self.temperature = temperature
        self._client: Optional[httpx.AsyncClient] = None

    @property
    def api_key(self) -> str:
        """Return the configured API key, falling back to the environment.

        Raises:
            ValueError: If no key was passed and the env var is unset or empty.
        """
        if self._api_key is None:
            api_key = os.environ.get(self.API_KEY_ENV)
            if not api_key:
                raise ValueError(
                    f"API key not found. Set {self.API_KEY_ENV} environment variable "
                    "or pass api_key to the provider."
                )
            return api_key
        return self._api_key

    @api_key.setter
    def api_key(self, value: Optional[str]) -> None:
        self._api_key = value

    @property
    def base_url(self) -> str:
        """Return the API root URL (``DEFAULT_BASE_URL`` when unset)."""
        if self._base_url is None:
            return self.DEFAULT_BASE_URL
        return self._base_url

    @base_url.setter
    def base_url(self, value: Optional[str]) -> None:
        # Required: BaseProvider.__init__ assigns self.base_url. With a
        # getter-only property that assignment raises AttributeError, so the
        # provider could never be constructed.
        self._base_url = value

    def _messages_url(self) -> str:
        """Return the Messages endpoint URL, tolerating a trailing '/' on base_url."""
        return f"{self.base_url.rstrip('/')}/v1/messages"

    def _build_payload(self, prompt: str, stream: bool = False) -> Dict[str, Any]:
        """Build the Messages API request body for a single user prompt."""
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": self.DEFAULT_MAX_TOKENS,
            "temperature": self.temperature,
        }
        if stream:
            payload["stream"] = True
        return payload

    @staticmethod
    def _error_message(response: httpx.Response) -> str:
        """Extract ``error.message`` from an error response whose body is loaded.

        Best effort: falls back to the raw body text when the body is not JSON
        or lacks the ``{"error": {"message": ...}}`` shape.
        """
        try:
            return response.json().get("error", {}).get("message", response.text)
        except Exception:
            return response.text

    def _raise_api_error(self, status_code: int, error_msg: str) -> None:
        """Raise RetriableError for retriable statuses, ValueError otherwise."""
        message = f"Anthropic API error ({status_code}): {error_msg}"
        if status_code in self.RETRIABLE_STATUS_CODES:
            # Local import kept from the original; presumably avoids an import
            # cycle with aicert.runner.
            from aicert.runner import RetriableError

            raise RetriableError(message)
        raise ValueError(message)

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared async HTTP client.

        Auth headers are captured when the client is created, so a later
        change to ``api_key`` does not affect an existing client.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0),
                headers={
                    "x-api-key": self.api_key,
                    "anthropic-version": self.API_VERSION,
                    "Content-Type": "application/json",
                },
            )
        return self._client

    async def close(self):
        """Close the HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None

    def _transform_response(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform an Anthropic response into an OpenAI-compatible dict.

        Only the first ``text`` content block becomes a choice. Usage fields
        ``input_tokens``/``output_tokens`` are renamed to
        ``prompt_tokens``/``completion_tokens``; the raw payload is kept.
        """
        choices = []
        for content in response_data.get("content", []):
            if content.get("type") == "text":
                choices.append({
                    "message": {
                        "content": content.get("text", ""),
                    },
                    "index": 0,
                    "finish_reason": response_data.get("stop_reason", "stop"),
                })
                break

        usage = response_data.get("usage", {})
        transformed_usage = {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
        }

        return {
            "choices": choices,
            "usage": transformed_usage,
            "raw": response_data,
        }

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a response from Anthropic.

        Raises:
            ConnectionError: On transport-level failures (httpx.RequestError).
            RetriableError: For HTTP 429/500/502/503/504.
            ValueError: For any other non-success HTTP status.
        """
        client = await self._get_client()
        payload = self._build_payload(prompt)

        try:
            response = await client.post(self._messages_url(), json=payload)
        except httpx.RequestError as e:
            raise ConnectionError(f"Failed to connect to Anthropic API: {e}") from e

        if not response.is_success:
            self._raise_api_error(response.status_code, self._error_message(response))

        return self._transform_response(response.json())

    async def generate_stream(self, prompt: str, **kwargs):
        """Yield parsed JSON objects from the ``data:`` lines of a streaming response.

        Lines that are not ``data:`` lines, or whose payload is not valid
        JSON, are skipped. Raises the same errors as ``generate``.
        """
        import json

        client = await self._get_client()
        payload = self._build_payload(prompt, stream=True)

        try:
            async with client.stream("POST", self._messages_url(), json=payload) as response:
                if not response.is_success:
                    # A streamed body is not loaded automatically; read it
                    # before parsing (Response.json() is synchronous and
                    # fails on an unread stream).
                    await response.aread()
                    self._raise_api_error(response.status_code, self._error_message(response))

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except ValueError:
                        continue
                    # Yield outside the try so exceptions thrown into the
                    # generator by the consumer are not silently swallowed.
                    yield chunk
        except httpx.RequestError as e:
            raise ConnectionError(f"Failed to connect to Anthropic API: {e}") from e

    @property
    def provider_type(self) -> str:
        """Return the provider type identifier."""
        return "anthropic"
|
aicert/providers/base.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Base provider for LLM API calls."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseProvider(ABC):
    """Abstract interface shared by every LLM provider.

    Concrete providers implement ``generate``, ``generate_stream`` and the
    ``provider_type`` property.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ):
        """Store the common provider configuration.

        ``api_key`` and ``base_url`` are set via normal attribute assignment,
        so subclasses may intercept them with property setters.
        """
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        # Any extra provider-specific options, kept as given.
        self.kwargs = kwargs

    @abstractmethod
    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Return the model's response to *prompt*."""
        raise NotImplementedError

    @abstractmethod
    async def generate_stream(self, prompt: str, **kwargs):
        """Stream the model's response to *prompt*."""
        raise NotImplementedError

    @property
    @abstractmethod
    def provider_type(self) -> str:
        """Short identifier naming the provider implementation."""
        raise NotImplementedError
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""OpenAI provider for LLM API calls."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from aicert.providers.base import BaseProvider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OpenAIProvider(BaseProvider):
    """OpenAI provider implementation using the Chat Completions API."""

    DEFAULT_BASE_URL = "https://api.openai.com/v1"
    API_KEY_ENV = "OPENAI_API_KEY"
    # HTTP statuses reported as RetriableError (rate limit / transient server errors).
    RETRIABLE_STATUS_CODES = (429, 500, 502, 503, 504)

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.7,
        **kwargs,
    ):
        """Create the provider.

        Args:
            model: Model name sent with every request.
            api_key: Explicit API key; when None, ``OPENAI_API_KEY`` is read
                lazily on first access of ``api_key``.
            base_url: API root (including ``/v1``); None means ``DEFAULT_BASE_URL``.
            temperature: Sampling temperature sent with every request.
            **kwargs: Extra options, stored on ``self.kwargs`` by the base class.
        """
        # BaseProvider.__init__ assigns self.api_key and self.base_url, which
        # go through the property setters defined below.
        super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
        self.temperature = temperature
        self._client: Optional[httpx.AsyncClient] = None

    @property
    def api_key(self) -> str:
        """Return the configured API key, falling back to the environment.

        Raises:
            ValueError: If no key was passed and the env var is unset or empty.
        """
        if self._api_key is None:
            api_key = os.environ.get(self.API_KEY_ENV)
            if not api_key:
                raise ValueError(
                    f"API key not found. Set {self.API_KEY_ENV} environment variable "
                    "or pass api_key to the provider."
                )
            return api_key
        return self._api_key

    @api_key.setter
    def api_key(self, value: Optional[str]) -> None:
        self._api_key = value

    @property
    def base_url(self) -> str:
        """Return the API root URL (``DEFAULT_BASE_URL`` when unset)."""
        if self._base_url is None:
            return self.DEFAULT_BASE_URL
        return self._base_url

    @base_url.setter
    def base_url(self, value: Optional[str]) -> None:
        # Required: BaseProvider.__init__ assigns self.base_url. With a
        # getter-only property that assignment raises AttributeError, so the
        # provider could never be constructed.
        self._base_url = value

    def _completions_url(self) -> str:
        """Return the chat-completions URL, tolerating a trailing '/' on base_url."""
        return f"{self.base_url.rstrip('/')}/chat/completions"

    def _build_payload(self, prompt: str, stream: bool = False) -> Dict[str, Any]:
        """Build the Chat Completions request body for a single user prompt."""
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
        }
        if stream:
            payload["stream"] = True
        return payload

    @staticmethod
    def _error_message(response: httpx.Response) -> str:
        """Extract ``error.message`` from an error response whose body is loaded.

        Best effort: falls back to the raw body text when the body is not JSON
        or lacks the ``{"error": {"message": ...}}`` shape.
        """
        try:
            return response.json().get("error", {}).get("message", response.text)
        except Exception:
            return response.text

    def _raise_api_error(self, status_code: int, error_msg: str) -> None:
        """Raise RetriableError for retriable statuses, ValueError otherwise."""
        message = f"OpenAI API error ({status_code}): {error_msg}"
        if status_code in self.RETRIABLE_STATUS_CODES:
            # Local import kept from the original; presumably avoids an import
            # cycle with aicert.runner.
            from aicert.runner import RetriableError

            raise RetriableError(message)
        raise ValueError(message)

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared async HTTP client.

        The bearer token is captured when the client is created, so a later
        change to ``api_key`` does not affect an existing client.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0),
                headers={"Authorization": f"Bearer {self.api_key}"},
            )
        return self._client

    async def close(self):
        """Close the HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a response from OpenAI.

        Returns:
            ``{"choices": [...], "usage": {...}, "raw": <full response JSON>}``.

        Raises:
            ConnectionError: On transport-level failures (httpx.RequestError).
            RetriableError: For HTTP 429/500/502/503/504.
            ValueError: For any other non-success HTTP status.
        """
        client = await self._get_client()
        payload = self._build_payload(prompt)

        try:
            response = await client.post(self._completions_url(), json=payload)
        except httpx.RequestError as e:
            raise ConnectionError(f"Failed to connect to OpenAI API: {e}") from e

        if not response.is_success:
            self._raise_api_error(response.status_code, self._error_message(response))

        result = response.json()

        # Ensure we have the expected structure
        return {
            "choices": result.get("choices", []),
            "usage": result.get("usage", {}),
            "raw": result,
        }

    async def generate_stream(self, prompt: str, **kwargs):
        """Yield parsed JSON chunks from the ``data:`` lines of a streaming response.

        Stops at the ``[DONE]`` sentinel; lines that are not ``data:`` lines,
        or whose payload is not valid JSON, are skipped. Raises the same
        errors as ``generate``.
        """
        import json

        client = await self._get_client()
        payload = self._build_payload(prompt, stream=True)

        try:
            async with client.stream("POST", self._completions_url(), json=payload) as response:
                if not response.is_success:
                    # A streamed body is not loaded automatically; read it
                    # before parsing (Response.json() is synchronous and
                    # fails on an unread stream).
                    await response.aread()
                    self._raise_api_error(response.status_code, self._error_message(response))

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except ValueError:
                        continue
                    # Yield outside the try so exceptions thrown into the
                    # generator by the consumer are not silently swallowed.
                    yield chunk
        except httpx.RequestError as e:
            raise ConnectionError(f"Failed to connect to OpenAI API: {e}") from e

    @property
    def provider_type(self) -> str:
        """Return the provider type identifier."""
        return "openai"
|