aicert-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aicert/metrics.py ADDED
@@ -0,0 +1,305 @@
+ """Metrics utilities for aicert."""
+
+ import json
+ import re
+ import statistics
+ from typing import Any, Dict, List, Optional
+
+
+ def clamp(value: float, min_val: float, max_val: float) -> float:
+     """Clamp value between min and max."""
+     return max(min_val, min(value, max_val))
+
+
+ def canonicalize_json(obj: Any) -> Any:
+     """Recursively canonicalize JSON by sorting keys."""
+     if isinstance(obj, dict):
+         return sorted((k, canonicalize_json(v)) for k, v in obj.items())
+     elif isinstance(obj, list):
+         return [canonicalize_json(item) for item in obj]
+     return obj
+
+
+ def stringify_compact(obj: Any) -> str:
+     """Convert canonicalized JSON to compact string."""
+     return json.dumps(obj, separators=(',', ':'))
+
+
+ def tokenize(text: str) -> set:
+     """Tokenize text by regex into tokens."""
+     # Tokenize by alphanumeric sequences
+     tokens = re.findall(r'\w+', text.lower())
+     return set(tokens)
+
+
+ def jaccard_similarity(set1: set, set2: set) -> float:
+     """Compute Jaccard similarity between two sets."""
+     if not set1 and not set2:
+         return 1.0
+     intersection = len(set1 & set2)
+     union = len(set1 | set2)
+     return intersection / union if union > 0 else 0.0
+
+
+ def compute_similarity(outputs: List[Optional[Dict]]) -> float:
+     """Compute similarity score based on Jaccard similarity of canonicalized tokens."""
+     # Filter valid outputs (non-None)
+     valid_outputs = [o for o in outputs if o is not None]
+
+     if not valid_outputs:
+         return 0.0
+
+     # Canonicalize and tokenize each output
+     canonicalized = [stringify_compact(canonicalize_json(o)) for o in valid_outputs]
+     token_sets = [tokenize(c) for c in canonicalized]
+
+     # Choose first valid as baseline
+     baseline_tokens = token_sets[0]
+
+     # Compute Jaccard similarity with baseline for each other
+     similarities = [jaccard_similarity(baseline_tokens, ts) for ts in token_sets]
+
+     # Average * 100
+     return sum(similarities) / len(similarities) * 100 if similarities else 0.0
+
+
+ def compute_structural_consistency(outputs: List[Optional[Dict]], required_keys: List[str]) -> float:
+     """Compute structural consistency based on required key frequency."""
+     if not required_keys:
+         return 100.0
+
+     valid_outputs = [o for o in outputs if o is not None]
+
+     if not valid_outputs:
+         return 0.0
+
+     # For each required key, compute frequency present
+     key_frequencies = []
+     for key in required_keys:
+         present_count = sum(1 for o in valid_outputs if isinstance(o, dict) and key in o)
+         freq = present_count / len(valid_outputs)
+         key_frequencies.append(freq)
+
+     # Average across required keys * 100
+     return (sum(key_frequencies) / len(key_frequencies)) * 100 if key_frequencies else 0.0
+
+
+ def compute_latency_stats(latencies: List[float]) -> Dict[str, float]:
+     """Compute latency statistics: mean, p95, std."""
+     if not latencies:
+         return {"mean": 0.0, "p95": 0.0, "std": 0.0}
+
+     mean_val = statistics.mean(latencies)
+
+     # Calculate p95
+     sorted_latencies = sorted(latencies)
+     p95_idx = int(len(sorted_latencies) * 0.95)
+     p95_val = sorted_latencies[min(p95_idx, len(sorted_latencies) - 1)]
+
+     # Calculate std
+     std_val = statistics.stdev(latencies) if len(latencies) > 1 else 0.0
+
+     return {
+         "mean": mean_val,
+         "p95": p95_val,
+         "std": std_val
+     }
+
+
+ def compute_latency_stability(latency_stats: Dict[str, float]) -> float:
+     """Compute latency stability score."""
+     mean = latency_stats.get("mean", 0)
+     std = latency_stats.get("std", 0)
+
+     if mean <= 0:
+         return 0.0
+
+     return clamp(100 * (1 - std / mean), 0, 100)
+
+
+ def compute_stability_score(
+     compliance: float,
+     structural: float,
+     similarity: float,
+     latency_stability: float
+ ) -> float:
+     """Compute final stability score."""
+     return compliance * 0.40 + structural * 0.25 + similarity * 0.25 + latency_stability * 0.10
+
+
+ def compute_summary(
+     results: List[Dict[str, Any]],
+     schema: Dict[str, Any],
+     prompt_hash: Optional[str] = None,
+     schema_hash: Optional[str] = None,
+ ) -> Dict[str, Any]:
+     """
+     Compute metrics summary from execution results.
+
+     Args:
+         results: List of execution result dicts with fields:
+             provider_id, case_id, ok_json, ok_schema, extra_keys,
+             latency_ms, cost_usd, output_json (parsed JSON when ok_json), error
+         schema: JSON schema dict
+         prompt_hash: Optional SHA-256 hash of the prompt file
+         schema_hash: Optional SHA-256 hash of the schema file
+
+     Returns:
+         Dict containing per-provider metrics and overall summary
+     """
+     # Group results by provider
+     providers: Dict[str, List[Dict[str, Any]]] = {}
+     for result in results:
+         provider_id = result.get("provider_id", "unknown")
+         if provider_id not in providers:
+             providers[provider_id] = []
+         providers[provider_id].append(result)
+
+     # Get required keys from schema
+     required_keys = schema.get("required", []) if schema else []
+
+     # Compute per-provider metrics
+     per_provider: Dict[str, Dict[str, Any]] = {}
+     for provider_id, provider_results in providers.items():
+         total_runs = len(provider_results)
+
+         # Count ok_json and ok_schema
+         ok_json_count = sum(1 for r in provider_results if r.get("ok_json", False))
+         ok_schema_count = sum(1 for r in provider_results if r.get("ok_schema", False))
+
+         # Count error types
+         json_parse_failures = sum(1 for r in provider_results if not r.get("ok_json", False))
+         schema_failures = sum(1 for r in provider_results if r.get("ok_json", False) and not r.get("ok_schema", False))
+         provider_errors = sum(1 for r in provider_results if r.get("error") and any(x in r.get("error", "") for x in ["429", "500", "Provider error"]))
+         timeouts = sum(1 for r in provider_results if r.get("error") and "Timeout" in r.get("error", ""))
+
+         # Compute rates
+         json_parse_rate = (ok_json_count / total_runs * 100) if total_runs > 0 else 0.0
+         schema_compliance = (ok_schema_count / total_runs * 100) if total_runs > 0 else 0.0
+
+         # Collect outputs for structural consistency and similarity
+         outputs = [r.get("output_json") for r in provider_results]
+
+         # Compute structural consistency
+         structural_consistency = compute_structural_consistency(outputs, required_keys)
+
+         # Compute similarity
+         similarity = compute_similarity(outputs)
+
+         # Collect latencies
+         latencies = [r.get("latency_ms", 0) for r in provider_results]
+         latency_stats = compute_latency_stats(latencies)
+
+         # Compute latency stability
+         latency_stability = compute_latency_stability(latency_stats)
+
+         # Compute final stability score
+         stability_score = compute_stability_score(
+             schema_compliance,
+             structural_consistency,
+             similarity,
+             latency_stability
+         )
+
+         # Collect costs
+         costs = [r.get("cost_usd", 0) for r in provider_results]
+         total_cost = sum(costs)
+
+         per_provider[provider_id] = {
+             "prompt_hash": prompt_hash,
+             "schema_hash": schema_hash,
+             "total_runs": total_runs,
+             "ok_json_count": ok_json_count,
+             "ok_schema_count": ok_schema_count,
+             "json_parse_failures": json_parse_failures,
+             "schema_failures": schema_failures,
+             "provider_errors": provider_errors,
+             "timeouts": timeouts,
+             "json_parse_rate": json_parse_rate,
+             "schema_compliance": schema_compliance,
+             "structural_consistency": structural_consistency,
+             "similarity": similarity,
+             "latency_stats": latency_stats,
+             "latency_stability": latency_stability,
+             "stability_score": stability_score,
+             "total_cost_usd": total_cost,
+         }
+
+     # Compute overall metrics
+     all_results = list(results)
+     all_outputs = [r.get("output_json") for r in all_results]
+     all_latencies = [r.get("latency_ms", 0) for r in all_results]
+     all_costs = [r.get("cost_usd", 0) for r in all_results]
+
+     # Overall error counts
+     overall_json_parse_failures = sum(1 for r in all_results if not r.get("ok_json", False))
+     overall_schema_failures = sum(1 for r in all_results if r.get("ok_json", False) and not r.get("ok_schema", False))
+     overall_provider_errors = sum(1 for r in all_results if r.get("error") and any(x in r.get("error", "") for x in ["429", "500", "Provider error"]))
+     overall_timeouts = sum(1 for r in all_results if r.get("error") and "Timeout" in r.get("error", ""))
+
+     overall = {
+         "total_runs": len(all_results),
+         "providers_count": len(providers),
+         "json_parse_failures": overall_json_parse_failures,
+         "schema_failures": overall_schema_failures,
+         "provider_errors": overall_provider_errors,
+         "timeouts": overall_timeouts,
+         "json_parse_rate": (sum(1 for r in all_results if r.get("ok_json", False)) / len(all_results) * 100) if all_results else 0.0,
+         "schema_compliance": (sum(1 for r in all_results if r.get("ok_schema", False)) / len(all_results) * 100) if all_results else 0.0,
+         "structural_consistency": compute_structural_consistency(all_outputs, required_keys),
+         "similarity": compute_similarity(all_outputs),
+         "latency_stats": compute_latency_stats(all_latencies),
+         "latency_stability": compute_latency_stability(compute_latency_stats(all_latencies)),
+         "total_cost_usd": sum(all_costs),
+     }
+
+     overall["stability_score"] = compute_stability_score(
+         overall["schema_compliance"],
+         overall["structural_consistency"],
+         overall["similarity"],
+         overall["latency_stability"]
+     )
+
+     return {
+         "prompt_hash": prompt_hash,
+         "schema_hash": schema_hash,
+         "per_provider": per_provider,
+         "overall": overall,
+     }
+
+
+ class Metrics:
+     """Container for validation metrics."""
+
+     def __init__(self):
+         self.total: int = 0
+         self.passed: int = 0
+         self.failed: int = 0
+         self.errors: List[Dict[str, Any]] = []
+
+     def add_result(self, passed: bool, error: Optional[str] = None) -> None:
+         """Add a validation result."""
+         self.total += 1
+         if passed:
+             self.passed += 1
+         else:
+             self.failed += 1
+         if error:
+             self.errors.append({"error": error})
+
+     @property
+     def success_rate(self) -> float:
+         """Calculate success rate."""
+         if self.total == 0:
+             return 0.0
+         return self.passed / self.total
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert metrics to dictionary."""
+         return {
+             "total": self.total,
+             "passed": self.passed,
+             "failed": self.failed,
+             "success_rate": self.success_rate,
+             "errors": self.errors,
+         }
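A minimal usage sketch of the metrics module above. The two result dicts and the schema are synthetic, illustrative values only; their fields follow the list documented in compute_summary's docstring.

    from aicert.metrics import compute_summary

    # Two illustrative runs for a single provider.
    results = [
        {"provider_id": "openai", "case_id": "c1", "ok_json": True, "ok_schema": True,
         "latency_ms": 820.0, "cost_usd": 0.0021, "output_json": {"answer": "42"}, "error": None},
        {"provider_id": "openai", "case_id": "c1", "ok_json": True, "ok_schema": False,
         "latency_ms": 910.0, "cost_usd": 0.0019, "output_json": {"answer": "41", "extra": 1}, "error": None},
    ]
    schema = {"type": "object", "required": ["answer"]}

    summary = compute_summary(results, schema)
    # stability_score = 0.40*compliance + 0.25*structural + 0.25*similarity + 0.10*latency stability
    print(summary["overall"]["stability_score"])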
aicert/providers/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """LLM providers for aicert."""
+
+ from aicert.providers.base import BaseProvider
+ from aicert.providers.openai import OpenAIProvider
+ from aicert.providers.anthropic import AnthropicProvider
+ from aicert.providers.openai_compatible import OpenAICompatibleProvider
+
+ __all__ = [
+     "BaseProvider",
+     "OpenAIProvider",
+     "AnthropicProvider",
+     "OpenAICompatibleProvider",
+ ]
aicert/providers/anthropic.py ADDED
@@ -0,0 +1,189 @@
+ """Anthropic provider for LLM API calls."""
+
+ import json
+ import os
+ from typing import Any, Dict, Optional
+
+ import httpx
+
+ from aicert.providers.base import BaseProvider
+
+
+ class AnthropicProvider(BaseProvider):
+     """Anthropic provider implementation using the Messages API."""
+
+     DEFAULT_BASE_URL = "https://api.anthropic.com"
+     API_KEY_ENV = "ANTHROPIC_API_KEY"
+     API_VERSION = "2023-06-01"
+
+     def __init__(
+         self,
+         model: str,
+         api_key: Optional[str] = None,
+         base_url: Optional[str] = None,
+         temperature: float = 0.7,
+         **kwargs,
+     ):
+         super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
+         self.temperature = temperature
+         self._client: Optional[httpx.AsyncClient] = None
+
+     @property
+     def api_key(self) -> str:
+         """Get API key from environment if not set."""
+         if self._api_key is None:
+             api_key = os.environ.get(self.API_KEY_ENV)
+             if not api_key:
+                 raise ValueError(
+                     f"API key not found. Set {self.API_KEY_ENV} environment variable "
+                     "or pass api_key to the provider."
+                 )
+             return api_key
+         return self._api_key
+
+     @api_key.setter
+     def api_key(self, value: Optional[str]):
+         self._api_key = value
+
+     @property
+     def base_url(self) -> str:
+         """Get base URL for API calls."""
+         if self._base_url is None:
+             return self.DEFAULT_BASE_URL
+         return self._base_url
+
+     @base_url.setter
+     def base_url(self, value: Optional[str]):
+         self._base_url = value
+
+     async def _get_client(self) -> httpx.AsyncClient:
+         """Get or create async HTTP client."""
+         if self._client is None:
+             self._client = httpx.AsyncClient(
+                 timeout=httpx.Timeout(60.0),
+                 headers={
+                     "x-api-key": self.api_key,
+                     "anthropic-version": self.API_VERSION,
+                     "Content-Type": "application/json",
+                 },
+             )
+         return self._client
+
+     async def close(self):
+         """Close the HTTP client."""
+         if self._client:
+             await self._client.aclose()
+             self._client = None
+
+     def _transform_response(self, response_data: Dict[str, Any]) -> Dict[str, Any]:
+         """Transform Anthropic response to OpenAI-compatible format."""
+         choices = []
+         for content in response_data.get("content", []):
+             if content.get("type") == "text":
+                 choices.append({
+                     "message": {
+                         "content": content.get("text", ""),
+                     },
+                     "index": 0,
+                     "finish_reason": response_data.get("stop_reason", "stop"),
+                 })
+                 break
+
+         usage = response_data.get("usage", {})
+         # Anthropic uses different field names
+         transformed_usage = {
+             "prompt_tokens": usage.get("input_tokens", 0),
+             "completion_tokens": usage.get("output_tokens", 0),
+         }
+
+         return {
+             "choices": choices,
+             "usage": transformed_usage,
+             "raw": response_data,
+         }
+
+     async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
+         """Generate a response from Anthropic."""
+         client = await self._get_client()
+
+         url = f"{self.base_url}/v1/messages"
+
+         payload = {
+             "model": self.model,
+             "messages": [{"role": "user", "content": prompt}],
+             "max_tokens": 4096,
+             "temperature": self.temperature,
+         }
+
+         try:
+             response = await client.post(url, json=payload)
+         except httpx.RequestError as e:
+             raise ConnectionError(f"Failed to connect to Anthropic API: {e}")
+
+         if not response.is_success:
+             status_code = response.status_code
+             try:
+                 error_data = response.json()
+                 error_msg = error_data.get("error", {}).get("message", response.text)
+             except Exception:
+                 error_msg = response.text
+
+             if status_code in (429, 500, 502, 503, 504):
+                 from aicert.runner import RetriableError
+                 raise RetriableError(f"Anthropic API error ({status_code}): {error_msg}")
+             else:
+                 raise ValueError(f"Anthropic API error ({status_code}): {error_msg}")
+
+         result = response.json()
+
+         return self._transform_response(result)
+
+     async def generate_stream(self, prompt: str, **kwargs):
+         """Generate a streaming response from Anthropic."""
+         client = await self._get_client()
+
+         url = f"{self.base_url}/v1/messages"
+
+         payload = {
+             "model": self.model,
+             "messages": [{"role": "user", "content": prompt}],
+             "max_tokens": 4096,
+             "temperature": self.temperature,
+             "stream": True,
+         }
+
+         try:
+             async with client.stream("POST", url, json=payload) as response:
+                 if not response.is_success:
+                     status_code = response.status_code
+                     # Read the streamed body before parsing it (httpx's .json() is synchronous).
+                     body = (await response.aread()).decode("utf-8", errors="replace")
+                     try:
+                         error_data = json.loads(body)
+                         error_msg = error_data.get("error", {}).get("message", body)
+                     except Exception:
+                         error_msg = body
+
+                     if status_code in (429, 500, 502, 503, 504):
+                         from aicert.runner import RetriableError
+                         raise RetriableError(f"Anthropic API error ({status_code}): {error_msg}")
+                     else:
+                         raise ValueError(f"Anthropic API error ({status_code}): {error_msg}")
+
+                 async for line in response.aiter_lines():
+                     if line.startswith("data: "):
+                         data = line[6:]
+                         if data == "[DONE]":
+                             break
+                         try:
+                             chunk = json.loads(data)
+                             yield chunk
+                         except Exception:
+                             continue
+         except httpx.RequestError as e:
+             raise ConnectionError(f"Failed to connect to Anthropic API: {e}")
+
+     @property
+     def provider_type(self) -> str:
+         """Return the provider type identifier."""
+         return "anthropic"
aicert/providers/base.py ADDED
@@ -0,0 +1,36 @@
+ """Base provider for LLM API calls."""
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+
+
+ class BaseProvider(ABC):
+     """Base class for LLM providers."""
+
+     def __init__(
+         self,
+         model: str,
+         api_key: Optional[str] = None,
+         base_url: Optional[str] = None,
+         **kwargs,
+     ):
+         self.model = model
+         self.api_key = api_key
+         self.base_url = base_url
+         self.kwargs = kwargs
+
+     @abstractmethod
+     async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
+         """Generate a response from the model."""
+         raise NotImplementedError
+
+     @abstractmethod
+     async def generate_stream(self, prompt: str, **kwargs):
+         """Generate a streaming response from the model."""
+         raise NotImplementedError
+
+     @property
+     @abstractmethod
+     def provider_type(self) -> str:
+         """Return the provider type identifier."""
+         raise NotImplementedError
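A minimal sketch of a custom provider built on BaseProvider. The EchoProvider below is hypothetical and only illustrates the abstract contract (generate, generate_stream, provider_type); it is not part of the package.

    from typing import Any, Dict

    from aicert.providers.base import BaseProvider

    class EchoProvider(BaseProvider):
        """Toy provider that echoes the prompt back; handy for offline dry runs."""

        async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
            return {"choices": [{"message": {"content": prompt}}], "usage": {}}

        async def generate_stream(self, prompt: str, **kwargs):
            yield {"choices": [{"delta": {"content": prompt}}]}

        @property
        def provider_type(self) -> str:
            return "echo"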
aicert/providers/openai.py ADDED
@@ -0,0 +1,160 @@
+ """OpenAI provider for LLM API calls."""
+
+ import json
+ import os
+ from typing import Any, Dict, Optional
+
+ import httpx
+
+ from aicert.providers.base import BaseProvider
+
+
+ class OpenAIProvider(BaseProvider):
+     """OpenAI provider implementation."""
+
+     DEFAULT_BASE_URL = "https://api.openai.com/v1"
+     API_KEY_ENV = "OPENAI_API_KEY"
+
+     def __init__(
+         self,
+         model: str,
+         api_key: Optional[str] = None,
+         base_url: Optional[str] = None,
+         temperature: float = 0.7,
+         **kwargs,
+     ):
+         super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
+         self.temperature = temperature
+         self._client: Optional[httpx.AsyncClient] = None
+
+     @property
+     def api_key(self) -> str:
+         """Get API key from environment if not set."""
+         if self._api_key is None:
+             api_key = os.environ.get(self.API_KEY_ENV)
+             if not api_key:
+                 raise ValueError(
+                     f"API key not found. Set {self.API_KEY_ENV} environment variable "
+                     "or pass api_key to the provider."
+                 )
+             return api_key
+         return self._api_key
+
+     @api_key.setter
+     def api_key(self, value: Optional[str]):
+         self._api_key = value
+
+     @property
+     def base_url(self) -> str:
+         """Get base URL for API calls."""
+         if self._base_url is None:
+             return self.DEFAULT_BASE_URL
+         return self._base_url
+
+     @base_url.setter
+     def base_url(self, value: Optional[str]):
+         self._base_url = value
+
+     async def _get_client(self) -> httpx.AsyncClient:
+         """Get or create async HTTP client."""
+         if self._client is None:
+             self._client = httpx.AsyncClient(
+                 timeout=httpx.Timeout(60.0),
+                 headers={"Authorization": f"Bearer {self.api_key}"},
+             )
+         return self._client
+
+     async def close(self):
+         """Close the HTTP client."""
+         if self._client:
+             await self._client.aclose()
+             self._client = None
+
+     async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
+         """Generate a response from OpenAI."""
+         client = await self._get_client()
+
+         url = f"{self.base_url}/chat/completions"
+
+         payload = {
+             "model": self.model,
+             "messages": [{"role": "user", "content": prompt}],
+             "temperature": self.temperature,
+         }
+
+         try:
+             response = await client.post(url, json=payload)
+         except httpx.RequestError as e:
+             raise ConnectionError(f"Failed to connect to OpenAI API: {e}")
+
+         if not response.is_success:
+             status_code = response.status_code
+             try:
+                 error_data = response.json()
+                 error_msg = error_data.get("error", {}).get("message", response.text)
+             except Exception:
+                 error_msg = response.text
+
+             if status_code in (429, 500, 502, 503, 504):
+                 from aicert.runner import RetriableError
+                 raise RetriableError(f"OpenAI API error ({status_code}): {error_msg}")
+             else:
+                 raise ValueError(f"OpenAI API error ({status_code}): {error_msg}")
+
+         result = response.json()
+
+         # Ensure we have the expected structure
+         return {
+             "choices": result.get("choices", []),
+             "usage": result.get("usage", {}),
+             "raw": result,
+         }
+
+     async def generate_stream(self, prompt: str, **kwargs):
+         """Generate a streaming response from OpenAI."""
+         client = await self._get_client()
+
+         url = f"{self.base_url}/chat/completions"
+
+         payload = {
+             "model": self.model,
+             "messages": [{"role": "user", "content": prompt}],
+             "temperature": self.temperature,
+             "stream": True,
+         }
+
+         try:
+             async with client.stream("POST", url, json=payload) as response:
+                 if not response.is_success:
+                     status_code = response.status_code
+                     # Read the streamed body before parsing it (httpx's .json() is synchronous).
+                     body = (await response.aread()).decode("utf-8", errors="replace")
+                     try:
+                         error_data = json.loads(body)
+                         error_msg = error_data.get("error", {}).get("message", body)
+                     except Exception:
+                         error_msg = body
+
+                     if status_code in (429, 500, 502, 503, 504):
+                         from aicert.runner import RetriableError
+                         raise RetriableError(f"OpenAI API error ({status_code}): {error_msg}")
+                     else:
+                         raise ValueError(f"OpenAI API error ({status_code}): {error_msg}")
+
+                 async for line in response.aiter_lines():
+                     if line.startswith("data: "):
+                         data = line[6:]
+                         if data == "[DONE]":
+                             break
+                         try:
+                             chunk = json.loads(data)
+                             yield chunk
+                         except Exception:
+                             continue
+         except httpx.RequestError as e:
+             raise ConnectionError(f"Failed to connect to OpenAI API: {e}")
+
+     @property
+     def provider_type(self) -> str:
+         """Return the provider type identifier."""
+         return "openai"