aicert 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,152 @@
1
+ """OpenAI-compatible provider for LLM API calls."""
2
+
3
+ import os
4
+ from typing import Any, Dict, Optional
5
+
6
+ import httpx
7
+
8
+ from aicert.providers.base import BaseProvider
9
+
10
+
11
class OpenAICompatibleProvider(BaseProvider):
    """OpenAI-compatible provider implementation (e.g., local LLMs, proxy servers).

    Sends single-turn chat requests to ``POST {base_url}/chat/completions``.
    The API key is optional: when neither the constructor argument nor the
    ``OPENAI_COMPAT_API_KEY`` environment variable is set, requests are sent
    without an ``Authorization`` header.
    """

    DEFAULT_BASE_URL = "http://localhost:8000/v1"
    API_KEY_ENV = "OPENAI_COMPAT_API_KEY"
    # HTTP statuses treated as transient; they are surfaced as RetriableError.
    RETRIABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.7,
        **kwargs,
    ):
        """Create the provider.

        Args:
            model: Model name sent in the ``model`` field of each request.
            api_key: Explicit API key; when None, the ``api_key`` property
                falls back to the ``OPENAI_COMPAT_API_KEY`` environment variable.
            base_url: API root (e.g. ``http://host:port/v1``); when None,
                ``DEFAULT_BASE_URL`` is used.
            temperature: Sampling temperature sent with every request.
            **kwargs: Forwarded unchanged to ``BaseProvider.__init__``.
        """
        super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
        self.temperature = temperature
        # Created lazily by _get_client on the first request.
        self._client: Optional[httpx.AsyncClient] = None

    @property
    def api_key(self) -> Optional[str]:
        """Return the configured API key, else the value of ``API_KEY_ENV`` (or None)."""
        if self._api_key is None:
            return os.environ.get(self.API_KEY_ENV)
        return self._api_key

    @api_key.setter
    def api_key(self, value: Optional[str]):
        self._api_key = value

    @property
    def base_url(self) -> str:
        """Return the configured base URL, or ``DEFAULT_BASE_URL`` when unset."""
        if self._base_url is None:
            return self.DEFAULT_BASE_URL
        return self._base_url

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared async HTTP client, creating it on first use.

        The ``Authorization`` header is captured when the client is created, so
        changing the API key afterwards has no effect until ``close()`` is called.
        """
        if self._client is None:
            headers = {}
            api_key = self.api_key
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(120.0),  # Longer timeout for local models
                headers=headers,
            )
        return self._client

    async def close(self):
        """Close the HTTP client (if one was created) and forget it."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    def _chat_url(self) -> str:
        """Return the chat-completions URL, tolerating a trailing slash on base_url."""
        return f"{self.base_url.rstrip('/')}/chat/completions"

    def _build_payload(self, prompt: str, stream: bool = False) -> Dict[str, Any]:
        """Build the JSON request body for a single user message."""
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
        }
        if stream:
            payload["stream"] = True
        return payload

    @staticmethod
    def _extract_error_message(response: httpx.Response) -> str:
        """Return a readable error message from an already-read error response.

        Understands ``{"error": {"message": ...}}`` and ``{"error": "..."}``;
        anything else (including non-JSON bodies) falls back to the body text.
        """
        text = response.text
        try:
            error_data = response.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            return text
        if isinstance(error_data, dict):
            error = error_data.get("error")
            if isinstance(error, dict):
                message = error.get("message")
                if message:
                    return str(message)
            elif isinstance(error, str) and error:
                return error
        return text

    @classmethod
    def _raise_api_error(cls, status_code: int, error_msg: str):
        """Always raise: RetriableError for transient statuses, ValueError otherwise."""
        message = f"OpenAI-compatible API error ({status_code}): {error_msg}"
        if status_code in cls.RETRIABLE_STATUS_CODES:
            # Imported lazily, as in the original code; presumably avoids a
            # circular import with aicert.runner — confirm before hoisting.
            from aicert.runner import RetriableError

            raise RetriableError(message)
        raise ValueError(message)

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a response from an OpenAI-compatible endpoint.

        Args:
            prompt: Content of the single user message.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            A dict with ``choices`` (list) and ``usage`` (dict) — empty when the
            server omits them or sends null — plus the decoded body as ``raw``.

        Raises:
            ConnectionError: The request could not be sent/received.
            RetriableError: HTTP 429, 500, 502, 503 or 504.
            ValueError: Any other non-success status, or a non-JSON success body.
        """
        client = await self._get_client()

        try:
            response = await client.post(self._chat_url(), json=self._build_payload(prompt))
        except httpx.RequestError as e:
            raise ConnectionError(f"Failed to connect to OpenAI-compatible API: {e}") from e

        if not response.is_success:
            self._raise_api_error(response.status_code, self._extract_error_message(response))

        result = response.json()

        # Ensure we have the expected structure even if fields are missing or null.
        return {
            "choices": result.get("choices") or [],
            "usage": result.get("usage") or {},
            "raw": result,
        }

    async def generate_stream(self, prompt: str, **kwargs):
        """Stream chat-completion chunks from an OpenAI-compatible endpoint.

        Parses the server-sent-events body and yields each ``data:`` payload
        decoded as JSON, stopping at ``data: [DONE]``. Payloads that are not
        valid JSON are skipped.

        Args:
            prompt: Content of the single user message.
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            ConnectionError: Transport failure while connecting or streaming.
            RetriableError: HTTP 429, 500, 502, 503 or 504.
            ValueError: Any other non-success status.
        """
        import json

        client = await self._get_client()
        payload = self._build_payload(prompt, stream=True)

        try:
            async with client.stream("POST", self._chat_url(), json=payload) as response:
                if not response.is_success:
                    # Streaming responses are not buffered: read the body before
                    # .json()/.text can be used to build the error message.
                    await response.aread()
                    self._raise_api_error(
                        response.status_code, self._extract_error_message(response)
                    )

                async for line in response.aiter_lines():
                    # SSE allows both "data: x" and "data:x".
                    if not line.startswith("data:"):
                        continue
                    data = line[len("data:"):].strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except ValueError:
                        # Best-effort: skip malformed/empty payloads rather than abort.
                        continue
                    # Yield outside the try so exceptions thrown into the
                    # generator are not swallowed.
                    yield chunk
        except httpx.RequestError as e:
            raise ConnectionError(f"Failed to connect to OpenAI-compatible API: {e}") from e

    @property
    def provider_type(self) -> str:
        """Return the provider type identifier."""
        return "openai-compatible"