router_maestro-0.1.2-py3-none-any.whl
This diff shows the contents of a publicly available package version as released to its public registry. It is provided for informational purposes only and reflects the package as it appears in that registry.
- router_maestro/__init__.py +3 -0
- router_maestro/__main__.py +6 -0
- router_maestro/auth/__init__.py +18 -0
- router_maestro/auth/github_oauth.py +181 -0
- router_maestro/auth/manager.py +136 -0
- router_maestro/auth/storage.py +91 -0
- router_maestro/cli/__init__.py +1 -0
- router_maestro/cli/auth.py +167 -0
- router_maestro/cli/client.py +322 -0
- router_maestro/cli/config.py +132 -0
- router_maestro/cli/context.py +146 -0
- router_maestro/cli/main.py +42 -0
- router_maestro/cli/model.py +288 -0
- router_maestro/cli/server.py +117 -0
- router_maestro/cli/stats.py +76 -0
- router_maestro/config/__init__.py +72 -0
- router_maestro/config/contexts.py +29 -0
- router_maestro/config/paths.py +50 -0
- router_maestro/config/priorities.py +93 -0
- router_maestro/config/providers.py +34 -0
- router_maestro/config/server.py +115 -0
- router_maestro/config/settings.py +76 -0
- router_maestro/providers/__init__.py +31 -0
- router_maestro/providers/anthropic.py +203 -0
- router_maestro/providers/base.py +123 -0
- router_maestro/providers/copilot.py +346 -0
- router_maestro/providers/openai.py +188 -0
- router_maestro/providers/openai_compat.py +175 -0
- router_maestro/routing/__init__.py +5 -0
- router_maestro/routing/router.py +526 -0
- router_maestro/server/__init__.py +5 -0
- router_maestro/server/app.py +87 -0
- router_maestro/server/middleware/__init__.py +11 -0
- router_maestro/server/middleware/auth.py +66 -0
- router_maestro/server/oauth_sessions.py +159 -0
- router_maestro/server/routes/__init__.py +8 -0
- router_maestro/server/routes/admin.py +358 -0
- router_maestro/server/routes/anthropic.py +228 -0
- router_maestro/server/routes/chat.py +142 -0
- router_maestro/server/routes/models.py +34 -0
- router_maestro/server/schemas/__init__.py +57 -0
- router_maestro/server/schemas/admin.py +87 -0
- router_maestro/server/schemas/anthropic.py +246 -0
- router_maestro/server/schemas/openai.py +107 -0
- router_maestro/server/translation.py +636 -0
- router_maestro/stats/__init__.py +14 -0
- router_maestro/stats/heatmap.py +154 -0
- router_maestro/stats/storage.py +228 -0
- router_maestro/stats/tracker.py +73 -0
- router_maestro/utils/__init__.py +16 -0
- router_maestro/utils/logging.py +81 -0
- router_maestro/utils/tokens.py +51 -0
- router_maestro-0.1.2.dist-info/METADATA +383 -0
- router_maestro-0.1.2.dist-info/RECORD +57 -0
- router_maestro-0.1.2.dist-info/WHEEL +4 -0
- router_maestro-0.1.2.dist-info/entry_points.txt +2 -0
- router_maestro-0.1.2.dist-info/licenses/LICENSE +21 -0
router_maestro/providers/copilot.py

@@ -0,0 +1,346 @@
"""GitHub Copilot provider implementation."""

import time
from collections.abc import AsyncIterator

import httpx

from router_maestro.auth import AuthManager, AuthType
from router_maestro.auth.github_oauth import get_copilot_token
from router_maestro.providers.base import (
    BaseProvider,
    ChatRequest,
    ChatResponse,
    ChatStreamChunk,
    ModelInfo,
    ProviderError,
)
from router_maestro.utils import get_logger

logger = get_logger("providers.copilot")

COPILOT_BASE_URL = "https://api.githubcopilot.com"
COPILOT_CHAT_URL = f"{COPILOT_BASE_URL}/chat/completions"
COPILOT_MODELS_URL = f"{COPILOT_BASE_URL}/models"

# Model cache TTL in seconds (5 minutes)
MODELS_CACHE_TTL = 300


class CopilotProvider(BaseProvider):
    """GitHub Copilot provider."""

    name = "github-copilot"

    def __init__(self) -> None:
        self.auth_manager = AuthManager()
        self._cached_token: str | None = None
        self._token_expires: int = 0
        # Model cache
        self._models_cache: list[ModelInfo] | None = None
        self._models_cache_expires: float = 0
        # Reusable HTTP client
        self._client: httpx.AsyncClient | None = None

    def is_authenticated(self) -> bool:
        """Check if authenticated with GitHub Copilot."""
        cred = self.auth_manager.get_credential("github-copilot")
        return cred is not None and cred.type == AuthType.OAUTH

    async def ensure_token(self) -> None:
        """Ensure we have a valid Copilot token, refreshing if needed."""
        cred = self.auth_manager.get_credential("github-copilot")
        if not cred or cred.type != AuthType.OAUTH:
            logger.error("Not authenticated with GitHub Copilot")
            raise ProviderError("Not authenticated with GitHub Copilot", status_code=401)

        current_time = int(time.time())

        # Check if we need to refresh (token expired or will expire soon)
        if self._cached_token and self._token_expires > current_time + 60:
            return  # Token still valid

        logger.debug("Refreshing Copilot token")
        # Refresh the Copilot token using the GitHub token
        client = self._get_client()
        try:
            copilot_token = await get_copilot_token(client, cred.refresh)
            self._cached_token = copilot_token.token
            self._token_expires = copilot_token.expires_at

            # Update stored credential with new access token
            cred.access = copilot_token.token
            cred.expires = copilot_token.expires_at
            self.auth_manager.save()
            logger.debug("Copilot token refreshed, expires at %d", copilot_token.expires_at)
        except httpx.HTTPError as e:
            logger.error("Failed to refresh Copilot token: %s", e)
            raise ProviderError(f"Failed to refresh Copilot token: {e}", retryable=True)

    def _get_headers(self, vision_request: bool = False) -> dict[str, str]:
        """Get headers for Copilot API requests.

        Args:
            vision_request: Whether this request contains images (vision)
        """
        if not self._cached_token:
            raise ProviderError("No valid token available", status_code=401)

        headers = {
            "Authorization": f"Bearer {self._cached_token}",
            "Content-Type": "application/json",
            "Editor-Version": "vscode/1.85.0",
            "Editor-Plugin-Version": "copilot/1.0.0",
            "Copilot-Integration-Id": "vscode-chat",
        }

        if vision_request:
            headers["Copilot-Vision-Request"] = "true"

        return headers

    def _get_client(self) -> httpx.AsyncClient:
        """Get or create a reusable HTTP client."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    def _build_messages_payload(self, request: ChatRequest) -> tuple[list[dict], bool]:
        """Build messages payload and detect if images are present.

        Args:
            request: The chat request

        Returns:
            Tuple of (messages list, has_images flag)
        """
        messages = []
        has_images = False

        for m in request.messages:
            msg: dict = {"role": m.role, "content": m.content}
            if m.tool_call_id:
                msg["tool_call_id"] = m.tool_call_id
            if m.tool_calls:
                msg["tool_calls"] = m.tool_calls
            messages.append(msg)

            # Check if this message contains images (multimodal content)
            if isinstance(m.content, list):
                for part in m.content:
                    if isinstance(part, dict) and part.get("type") == "image_url":
                        has_images = True
                        break

        return messages, has_images

    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
        """Generate a chat completion via Copilot."""
        await self.ensure_token()

        messages, has_images = self._build_messages_payload(request)

        payload: dict = {
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature,
            "stream": False,
        }
        if request.max_tokens:
            payload["max_tokens"] = request.max_tokens

        logger.debug("Copilot chat completion: model=%s", request.model)
        client = self._get_client()
        try:
            response = await client.post(
                COPILOT_CHAT_URL,
                json=payload,
                headers=self._get_headers(vision_request=has_images),
            )
            response.raise_for_status()
            data = response.json()

            choices = data.get("choices", [])
            if not choices:
                import json

                logger.error("Copilot API returned empty choices: %s", json.dumps(data)[:500])
                raise ProviderError(
                    f"Copilot API returned empty choices: {json.dumps(data)[:500]}",
                    status_code=500,
                    retryable=True,
                )

            logger.debug("Copilot chat completion successful")
            return ChatResponse(
                content=choices[0]["message"]["content"],
                model=data.get("model", request.model),
                finish_reason=choices[0].get("finish_reason", "stop"),
                usage=data.get("usage"),
            )
        except httpx.HTTPStatusError as e:
            retryable = e.response.status_code in (429, 500, 502, 503, 504)
            try:
                error_body = e.response.text
            except Exception:
                error_body = ""
            logger.error("Copilot API error: %d - %s", e.response.status_code, error_body[:200])
            raise ProviderError(
                f"Copilot API error: {e.response.status_code} - {error_body}",
                status_code=e.response.status_code,
                retryable=retryable,
            )
        except httpx.HTTPError as e:
            logger.error("Copilot HTTP error: %s", e)
            raise ProviderError(f"HTTP error: {e}", retryable=True)

    async def chat_completion_stream(self, request: ChatRequest) -> AsyncIterator[ChatStreamChunk]:
        """Generate a streaming chat completion via Copilot."""
        await self.ensure_token()

        messages, has_images = self._build_messages_payload(request)

        payload: dict = {
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature,
            "stream": True,
        }
        if request.max_tokens:
            payload["max_tokens"] = request.max_tokens
        if request.tools:
            payload["tools"] = request.tools
        if request.tool_choice:
            payload["tool_choice"] = request.tool_choice

        logger.debug("Copilot streaming chat: model=%s", request.model)
        client = self._get_client()
        try:
            async with client.stream(
                "POST",
                COPILOT_CHAT_URL,
                json=payload,
                headers=self._get_headers(vision_request=has_images),
            ) as response:
                response.raise_for_status()

                stream_finished = False
                async for line in response.aiter_lines():
                    if stream_finished:
                        break

                    if not line or not line.startswith("data: "):
                        continue

                    data_str = line[6:]  # Remove "data: " prefix
                    if data_str == "[DONE]":
                        break

                    import json

                    data = json.loads(data_str)

                    # Extract usage if present (may come in separate chunk)
                    usage = data.get("usage")

                    if "choices" in data and data["choices"]:
                        delta = data["choices"][0].get("delta", {})
                        content = delta.get("content", "")
                        finish_reason = data["choices"][0].get("finish_reason")
                        tool_calls = delta.get("tool_calls")

                        if content or finish_reason or usage or tool_calls:
                            yield ChatStreamChunk(
                                content=content,
                                finish_reason=finish_reason,
                                usage=usage,
                                tool_calls=tool_calls,
                            )

                        # Mark stream as finished after receiving finish_reason
                        if finish_reason:
                            stream_finished = True
                    elif usage:
                        # Handle usage-only chunks (no choices)
                        yield ChatStreamChunk(
                            content="",
                            finish_reason=None,
                            usage=usage,
                        )
        except httpx.HTTPStatusError as e:
            retryable = e.response.status_code in (429, 500, 502, 503, 504)
            try:
                error_body = e.response.text
            except Exception:
                error_body = ""
            logger.error(
                "Copilot stream API error: %d - %s",
                e.response.status_code,
                error_body[:200],
            )
            raise ProviderError(
                f"Copilot API error: {e.response.status_code} - {error_body}",
                status_code=e.response.status_code,
                retryable=retryable,
            )
        except httpx.HTTPError as e:
            logger.error("Copilot stream HTTP error: %s", e)
            raise ProviderError(f"HTTP error: {e}", retryable=True)

    async def list_models(self, force_refresh: bool = False) -> list[ModelInfo]:
        """List available Copilot models from API with caching.

        Args:
            force_refresh: Force refresh the cache

        Returns:
            List of available models
        """
        current_time = time.time()

        # Return cached models if valid
        if (
            not force_refresh
            and self._models_cache is not None
            and current_time < self._models_cache_expires
        ):
            logger.debug("Using cached Copilot models (%d models)", len(self._models_cache))
            return self._models_cache

        await self.ensure_token()

        logger.debug("Fetching Copilot models from API")
        client = self._get_client()
        try:
            response = await client.get(
                COPILOT_MODELS_URL,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            data = response.json()

            models = []
            for model in data.get("data", []):
                # Only include models that are enabled in model picker
                if model.get("model_picker_enabled", True):
                    models.append(
                        ModelInfo(
                            id=model["id"],
                            name=model.get("name", model["id"]),
                            provider=self.name,
                        )
                    )

            # Update cache
            self._models_cache = models
            self._models_cache_expires = current_time + MODELS_CACHE_TTL

            logger.info("Fetched %d Copilot models", len(models))
            return models
        except httpx.HTTPError as e:
            # If cache exists, return stale cache on error
            if self._models_cache is not None:
                logger.warning("Failed to refresh Copilot models, using stale cache: %s", e)
                return self._models_cache
            logger.error("Failed to list Copilot models: %s", e)
            raise ProviderError(f"Failed to list models: {e}", retryable=True)
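For orientation, here is a minimal usage sketch for the provider above. It assumes the message class exported by router_maestro.providers.base is named Message and that ChatRequest accepts the keyword arguments shown; neither is confirmed by this diff, which only shows that messages carry role, content, tool_call_id, and tool_calls attributes.

import asyncio

# "Message" is an assumed name; the diff only confirms the other base imports.
from router_maestro.providers.base import ChatRequest, Message
from router_maestro.providers.copilot import CopilotProvider


async def main() -> None:
    provider = CopilotProvider()
    if not provider.is_authenticated():
        raise SystemExit("Not authenticated; complete the GitHub OAuth flow first")

    # Field names mirror the attributes the provider reads (model, messages,
    # temperature, max_tokens); the constructor signature itself is assumed.
    request = ChatRequest(
        model="gpt-4o",
        messages=[Message(role="user", content="Hello, Copilot!")],
        temperature=0.7,
        max_tokens=256,
    )

    # ensure_token() runs inside the call: the short-lived Copilot token is
    # refreshed from the stored GitHub OAuth credential when near expiry.
    async for chunk in provider.chat_completion_stream(request):
        if chunk.content:
            print(chunk.content, end="", flush=True)
        if chunk.usage:
            print("\nusage:", chunk.usage)


asyncio.run(main())

The provider keeps two caches: the Copilot bearer token (refreshed when within 60 seconds of expiry, then persisted back via AuthManager.save()) and the model list (5-minute TTL, with a stale-cache fallback when a refresh fails).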
router_maestro/providers/openai.py

@@ -0,0 +1,188 @@
"""OpenAI provider implementation."""

from collections.abc import AsyncIterator

import httpx

from router_maestro.auth import AuthManager, AuthType
from router_maestro.providers.base import (
    BaseProvider,
    ChatRequest,
    ChatResponse,
    ChatStreamChunk,
    ModelInfo,
    ProviderError,
)
from router_maestro.utils import get_logger

logger = get_logger("providers.openai")

OPENAI_API_URL = "https://api.openai.com/v1"


class OpenAIProvider(BaseProvider):
    """OpenAI official provider."""

    name = "openai"

    def __init__(self, base_url: str = OPENAI_API_URL) -> None:
        self.base_url = base_url.rstrip("/")
        self.auth_manager = AuthManager()

    def is_authenticated(self) -> bool:
        """Check if authenticated with OpenAI."""
        cred = self.auth_manager.get_credential("openai")
        return cred is not None and cred.type == AuthType.API_KEY

    def _get_api_key(self) -> str:
        """Get the API key."""
        cred = self.auth_manager.get_credential("openai")
        if not cred or cred.type != AuthType.API_KEY:
            logger.error("Not authenticated with OpenAI")
            raise ProviderError("Not authenticated with OpenAI", status_code=401)
        return cred.key

    def _get_headers(self) -> dict[str, str]:
        """Get headers for OpenAI API requests."""
        return {
            "Authorization": f"Bearer {self._get_api_key()}",
            "Content-Type": "application/json",
        }

    async def chat_completion(self, request: ChatRequest) -> ChatResponse:
        """Generate a chat completion via OpenAI."""
        payload = {
            "model": request.model,
            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
            "temperature": request.temperature,
            "stream": False,
        }
        if request.max_tokens:
            payload["max_tokens"] = request.max_tokens

        logger.debug("OpenAI chat completion: model=%s", request.model)
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    headers=self._get_headers(),
                    timeout=120.0,
                )
                response.raise_for_status()
                data = response.json()

                logger.debug("OpenAI chat completion successful")
                return ChatResponse(
                    content=data["choices"][0]["message"]["content"],
                    model=data.get("model", request.model),
                    finish_reason=data["choices"][0].get("finish_reason", "stop"),
                    usage=data.get("usage"),
                )
            except httpx.HTTPStatusError as e:
                retryable = e.response.status_code in (429, 500, 502, 503, 504)
                logger.error("OpenAI API error: %d", e.response.status_code)
                raise ProviderError(
                    f"OpenAI API error: {e.response.status_code}",
                    status_code=e.response.status_code,
                    retryable=retryable,
                )
            except httpx.HTTPError as e:
                logger.error("OpenAI HTTP error: %s", e)
                raise ProviderError(f"HTTP error: {e}", retryable=True)

    async def chat_completion_stream(self, request: ChatRequest) -> AsyncIterator[ChatStreamChunk]:
        """Generate a streaming chat completion via OpenAI."""
        payload = {
            "model": request.model,
            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
            "temperature": request.temperature,
            "stream": True,
            "stream_options": {"include_usage": True},  # Request usage info in stream
        }
        if request.max_tokens:
            payload["max_tokens"] = request.max_tokens

        logger.debug("OpenAI streaming chat: model=%s", request.model)
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    headers=self._get_headers(),
                    timeout=120.0,
                ) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if not line or not line.startswith("data: "):
                            continue

                        data_str = line[6:]
                        if data_str == "[DONE]":
                            break

                        import json

                        data = json.loads(data_str)

                        if "choices" in data and data["choices"]:
                            delta = data["choices"][0].get("delta", {})
                            content = delta.get("content", "")
                            finish_reason = data["choices"][0].get("finish_reason")
                            usage = data.get("usage")  # Capture usage info

                            if content or finish_reason:
                                yield ChatStreamChunk(
                                    content=content,
                                    finish_reason=finish_reason,
                                    usage=usage,
                                )
            except httpx.HTTPStatusError as e:
                retryable = e.response.status_code in (429, 500, 502, 503, 504)
                logger.error("OpenAI stream API error: %d", e.response.status_code)
                raise ProviderError(
                    f"OpenAI API error: {e.response.status_code}",
                    status_code=e.response.status_code,
                    retryable=retryable,
                )
            except httpx.HTTPError as e:
                logger.error("OpenAI stream HTTP error: %s", e)
                raise ProviderError(f"HTTP error: {e}", retryable=True)

    async def list_models(self) -> list[ModelInfo]:
        """List available OpenAI models."""
        logger.debug("Fetching OpenAI models")
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(
                    f"{self.base_url}/models",
                    headers=self._get_headers(),
                    timeout=30.0,
                )
                response.raise_for_status()
                data = response.json()

                models = []
                for model in data.get("data", []):
                    model_id = model["id"]
                    # Filter to chat models
                    if any(x in model_id for x in ["gpt-", "o1-", "o3-"]):
                        models.append(
                            ModelInfo(
                                id=model_id,
                                name=model_id,
                                provider=self.name,
                            )
                        )
                logger.info("Fetched %d OpenAI models", len(models))
                return models
            except httpx.HTTPError as e:
                logger.warning("Failed to list OpenAI models, using defaults: %s", e)
                # Return default models on error
                return [
                    ModelInfo(id="gpt-4o", name="GPT-4o", provider=self.name),
                    ModelInfo(id="gpt-4o-mini", name="GPT-4o Mini", provider=self.name),
                    ModelInfo(id="gpt-4-turbo", name="GPT-4 Turbo", provider=self.name),
                ]
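Both stream parsers (this one and the Copilot one above) consume the same server-sent-events framing: each event arrives as a line of the form "data: <json>", and the stream ends with the sentinel "data: [DONE]". A standalone sketch of that framing, using made-up sample events:

import json

# Fabricated sample events in the "data: <json>" framing both parsers read.
sample_lines = [
    'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}',
    'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}',
    'data: {"choices": [], "usage": {"prompt_tokens": 5, "completion_tokens": 2}}',
    "data: [DONE]",
]

for line in sample_lines:
    if not line or not line.startswith("data: "):
        continue
    data_str = line[6:]  # strip the "data: " prefix, exactly as both providers do
    if data_str == "[DONE]":
        break  # end-of-stream sentinel
    event = json.loads(data_str)
    if event.get("choices"):
        delta = event["choices"][0].get("delta", {})
        finish = event["choices"][0].get("finish_reason")
        print("content:", delta.get("content", ""), "finish:", finish)
    elif event.get("usage"):
        # Usage-only chunk: "choices" is empty, "usage" carries token counts.
        print("usage:", event["usage"])

One observable difference between the two implementations: the Copilot parser yields usage-only chunks through its elif usage: branch, while the OpenAI parser requests usage via stream_options but reads usage only inside the non-empty-choices branch, so a final usage-only chunk (which arrives with an empty choices list) appears to be dropped.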