superqode-0.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- superqode/__init__.py +33 -0
- superqode/acp/__init__.py +23 -0
- superqode/acp/client.py +913 -0
- superqode/acp/permission_screen.py +457 -0
- superqode/acp/types.py +480 -0
- superqode/acp_discovery.py +856 -0
- superqode/agent/__init__.py +22 -0
- superqode/agent/edit_strategies.py +334 -0
- superqode/agent/loop.py +892 -0
- superqode/agent/qe_report_templates.py +39 -0
- superqode/agent/system_prompts.py +353 -0
- superqode/agent_output.py +721 -0
- superqode/agent_stream.py +953 -0
- superqode/agents/__init__.py +59 -0
- superqode/agents/acp_registry.py +305 -0
- superqode/agents/client.py +249 -0
- superqode/agents/data/augmentcode.com.toml +51 -0
- superqode/agents/data/cagent.dev.toml +51 -0
- superqode/agents/data/claude.com.toml +60 -0
- superqode/agents/data/codeassistant.dev.toml +51 -0
- superqode/agents/data/codex.openai.com.toml +57 -0
- superqode/agents/data/fastagent.ai.toml +66 -0
- superqode/agents/data/geminicli.com.toml +77 -0
- superqode/agents/data/goose.block.xyz.toml +54 -0
- superqode/agents/data/junie.jetbrains.com.toml +56 -0
- superqode/agents/data/kimi.moonshot.cn.toml +57 -0
- superqode/agents/data/llmlingagent.dev.toml +51 -0
- superqode/agents/data/molt.bot.toml +49 -0
- superqode/agents/data/opencode.ai.toml +60 -0
- superqode/agents/data/stakpak.dev.toml +51 -0
- superqode/agents/data/vtcode.dev.toml +51 -0
- superqode/agents/discovery.py +266 -0
- superqode/agents/messaging.py +160 -0
- superqode/agents/persona.py +166 -0
- superqode/agents/registry.py +421 -0
- superqode/agents/schema.py +72 -0
- superqode/agents/unified.py +367 -0
- superqode/app/__init__.py +111 -0
- superqode/app/constants.py +314 -0
- superqode/app/css.py +366 -0
- superqode/app/models.py +118 -0
- superqode/app/suggester.py +125 -0
- superqode/app/widgets.py +1591 -0
- superqode/app_enhanced.py +399 -0
- superqode/app_main.py +17187 -0
- superqode/approval.py +312 -0
- superqode/atomic.py +296 -0
- superqode/commands/__init__.py +1 -0
- superqode/commands/acp.py +965 -0
- superqode/commands/agents.py +180 -0
- superqode/commands/auth.py +278 -0
- superqode/commands/config.py +374 -0
- superqode/commands/init.py +826 -0
- superqode/commands/providers.py +819 -0
- superqode/commands/qe.py +1145 -0
- superqode/commands/roles.py +380 -0
- superqode/commands/serve.py +172 -0
- superqode/commands/suggestions.py +127 -0
- superqode/commands/superqe.py +460 -0
- superqode/config/__init__.py +51 -0
- superqode/config/loader.py +812 -0
- superqode/config/schema.py +498 -0
- superqode/core/__init__.py +111 -0
- superqode/core/roles.py +281 -0
- superqode/danger.py +386 -0
- superqode/data/superqode-template.yaml +1522 -0
- superqode/design_system.py +1080 -0
- superqode/dialogs/__init__.py +6 -0
- superqode/dialogs/base.py +39 -0
- superqode/dialogs/model.py +130 -0
- superqode/dialogs/provider.py +870 -0
- superqode/diff_view.py +919 -0
- superqode/enterprise.py +21 -0
- superqode/evaluation/__init__.py +25 -0
- superqode/evaluation/adapters.py +93 -0
- superqode/evaluation/behaviors.py +89 -0
- superqode/evaluation/engine.py +209 -0
- superqode/evaluation/scenarios.py +96 -0
- superqode/execution/__init__.py +36 -0
- superqode/execution/linter.py +538 -0
- superqode/execution/modes.py +347 -0
- superqode/execution/resolver.py +283 -0
- superqode/execution/runner.py +642 -0
- superqode/file_explorer.py +811 -0
- superqode/file_viewer.py +471 -0
- superqode/flash.py +183 -0
- superqode/guidance/__init__.py +58 -0
- superqode/guidance/config.py +203 -0
- superqode/guidance/prompts.py +71 -0
- superqode/harness/__init__.py +54 -0
- superqode/harness/accelerator.py +291 -0
- superqode/harness/config.py +319 -0
- superqode/harness/validator.py +147 -0
- superqode/history.py +279 -0
- superqode/integrations/superopt_runner.py +124 -0
- superqode/logging/__init__.py +49 -0
- superqode/logging/adapters.py +219 -0
- superqode/logging/formatter.py +923 -0
- superqode/logging/integration.py +341 -0
- superqode/logging/sinks.py +170 -0
- superqode/logging/unified_log.py +417 -0
- superqode/lsp/__init__.py +26 -0
- superqode/lsp/client.py +544 -0
- superqode/main.py +1069 -0
- superqode/mcp/__init__.py +89 -0
- superqode/mcp/auth_storage.py +380 -0
- superqode/mcp/client.py +1236 -0
- superqode/mcp/config.py +319 -0
- superqode/mcp/integration.py +337 -0
- superqode/mcp/oauth.py +436 -0
- superqode/mcp/oauth_callback.py +385 -0
- superqode/mcp/types.py +290 -0
- superqode/memory/__init__.py +31 -0
- superqode/memory/feedback.py +342 -0
- superqode/memory/store.py +522 -0
- superqode/notifications.py +369 -0
- superqode/optimization/__init__.py +5 -0
- superqode/optimization/config.py +33 -0
- superqode/permissions/__init__.py +25 -0
- superqode/permissions/rules.py +488 -0
- superqode/plan.py +323 -0
- superqode/providers/__init__.py +33 -0
- superqode/providers/gateway/__init__.py +165 -0
- superqode/providers/gateway/base.py +228 -0
- superqode/providers/gateway/litellm_gateway.py +1170 -0
- superqode/providers/gateway/openresponses_gateway.py +436 -0
- superqode/providers/health.py +297 -0
- superqode/providers/huggingface/__init__.py +74 -0
- superqode/providers/huggingface/downloader.py +472 -0
- superqode/providers/huggingface/endpoints.py +442 -0
- superqode/providers/huggingface/hub.py +531 -0
- superqode/providers/huggingface/inference.py +394 -0
- superqode/providers/huggingface/transformers_runner.py +516 -0
- superqode/providers/local/__init__.py +100 -0
- superqode/providers/local/base.py +438 -0
- superqode/providers/local/discovery.py +418 -0
- superqode/providers/local/lmstudio.py +256 -0
- superqode/providers/local/mlx.py +457 -0
- superqode/providers/local/ollama.py +486 -0
- superqode/providers/local/sglang.py +268 -0
- superqode/providers/local/tgi.py +260 -0
- superqode/providers/local/tool_support.py +477 -0
- superqode/providers/local/vllm.py +258 -0
- superqode/providers/manager.py +1338 -0
- superqode/providers/models.py +1016 -0
- superqode/providers/models_dev.py +578 -0
- superqode/providers/openresponses/__init__.py +87 -0
- superqode/providers/openresponses/converters/__init__.py +17 -0
- superqode/providers/openresponses/converters/messages.py +343 -0
- superqode/providers/openresponses/converters/tools.py +268 -0
- superqode/providers/openresponses/schema/__init__.py +56 -0
- superqode/providers/openresponses/schema/models.py +585 -0
- superqode/providers/openresponses/streaming/__init__.py +5 -0
- superqode/providers/openresponses/streaming/parser.py +338 -0
- superqode/providers/openresponses/tools/__init__.py +21 -0
- superqode/providers/openresponses/tools/apply_patch.py +352 -0
- superqode/providers/openresponses/tools/code_interpreter.py +290 -0
- superqode/providers/openresponses/tools/file_search.py +333 -0
- superqode/providers/openresponses/tools/mcp_adapter.py +252 -0
- superqode/providers/registry.py +716 -0
- superqode/providers/usage.py +332 -0
- superqode/pure_mode.py +384 -0
- superqode/qr/__init__.py +23 -0
- superqode/qr/dashboard.py +781 -0
- superqode/qr/generator.py +1018 -0
- superqode/qr/templates.py +135 -0
- superqode/safety/__init__.py +41 -0
- superqode/safety/sandbox.py +413 -0
- superqode/safety/warnings.py +256 -0
- superqode/server/__init__.py +33 -0
- superqode/server/lsp_server.py +775 -0
- superqode/server/web.py +250 -0
- superqode/session/__init__.py +25 -0
- superqode/session/persistence.py +580 -0
- superqode/session/sharing.py +477 -0
- superqode/session.py +475 -0
- superqode/sidebar.py +2991 -0
- superqode/stream_view.py +648 -0
- superqode/styles/__init__.py +3 -0
- superqode/superqe/__init__.py +184 -0
- superqode/superqe/acp_runner.py +1064 -0
- superqode/superqe/constitution/__init__.py +62 -0
- superqode/superqe/constitution/evaluator.py +308 -0
- superqode/superqe/constitution/loader.py +432 -0
- superqode/superqe/constitution/schema.py +250 -0
- superqode/superqe/events.py +591 -0
- superqode/superqe/frameworks/__init__.py +65 -0
- superqode/superqe/frameworks/base.py +234 -0
- superqode/superqe/frameworks/e2e.py +263 -0
- superqode/superqe/frameworks/executor.py +237 -0
- superqode/superqe/frameworks/javascript.py +409 -0
- superqode/superqe/frameworks/python.py +373 -0
- superqode/superqe/frameworks/registry.py +92 -0
- superqode/superqe/mcp_tools/__init__.py +47 -0
- superqode/superqe/mcp_tools/core_tools.py +418 -0
- superqode/superqe/mcp_tools/registry.py +230 -0
- superqode/superqe/mcp_tools/testing_tools.py +167 -0
- superqode/superqe/noise.py +89 -0
- superqode/superqe/orchestrator.py +778 -0
- superqode/superqe/roles.py +609 -0
- superqode/superqe/session.py +713 -0
- superqode/superqe/skills/__init__.py +57 -0
- superqode/superqe/skills/base.py +106 -0
- superqode/superqe/skills/core_skills.py +899 -0
- superqode/superqe/skills/registry.py +90 -0
- superqode/superqe/verifier.py +101 -0
- superqode/superqe_cli.py +76 -0
- superqode/tool_call.py +358 -0
- superqode/tools/__init__.py +93 -0
- superqode/tools/agent_tools.py +496 -0
- superqode/tools/base.py +324 -0
- superqode/tools/batch_tool.py +133 -0
- superqode/tools/diagnostics.py +311 -0
- superqode/tools/edit_tools.py +653 -0
- superqode/tools/enhanced_base.py +515 -0
- superqode/tools/file_tools.py +269 -0
- superqode/tools/file_tracking.py +45 -0
- superqode/tools/lsp_tools.py +610 -0
- superqode/tools/network_tools.py +350 -0
- superqode/tools/permissions.py +400 -0
- superqode/tools/question_tool.py +324 -0
- superqode/tools/search_tools.py +598 -0
- superqode/tools/shell_tools.py +259 -0
- superqode/tools/todo_tools.py +121 -0
- superqode/tools/validation.py +80 -0
- superqode/tools/web_tools.py +639 -0
- superqode/tui.py +1152 -0
- superqode/tui_integration.py +875 -0
- superqode/tui_widgets/__init__.py +27 -0
- superqode/tui_widgets/widgets/__init__.py +18 -0
- superqode/tui_widgets/widgets/progress.py +185 -0
- superqode/tui_widgets/widgets/tool_display.py +188 -0
- superqode/undo_manager.py +574 -0
- superqode/utils/__init__.py +5 -0
- superqode/utils/error_handling.py +323 -0
- superqode/utils/fuzzy.py +257 -0
- superqode/widgets/__init__.py +477 -0
- superqode/widgets/agent_collab.py +390 -0
- superqode/widgets/agent_store.py +936 -0
- superqode/widgets/agent_switcher.py +395 -0
- superqode/widgets/animation_manager.py +284 -0
- superqode/widgets/code_context.py +356 -0
- superqode/widgets/command_palette.py +412 -0
- superqode/widgets/connection_status.py +537 -0
- superqode/widgets/conversation_history.py +470 -0
- superqode/widgets/diff_indicator.py +155 -0
- superqode/widgets/enhanced_status_bar.py +385 -0
- superqode/widgets/enhanced_toast.py +476 -0
- superqode/widgets/file_browser.py +809 -0
- superqode/widgets/file_reference.py +585 -0
- superqode/widgets/issue_timeline.py +340 -0
- superqode/widgets/leader_key.py +264 -0
- superqode/widgets/mode_switcher.py +445 -0
- superqode/widgets/model_picker.py +234 -0
- superqode/widgets/permission_preview.py +1205 -0
- superqode/widgets/prompt.py +358 -0
- superqode/widgets/provider_connect.py +725 -0
- superqode/widgets/pty_shell.py +587 -0
- superqode/widgets/qe_dashboard.py +321 -0
- superqode/widgets/resizable_sidebar.py +377 -0
- superqode/widgets/response_changes.py +218 -0
- superqode/widgets/response_display.py +528 -0
- superqode/widgets/rich_tool_display.py +613 -0
- superqode/widgets/sidebar_panels.py +1180 -0
- superqode/widgets/slash_complete.py +356 -0
- superqode/widgets/split_view.py +612 -0
- superqode/widgets/status_bar.py +273 -0
- superqode/widgets/superqode_display.py +786 -0
- superqode/widgets/thinking_display.py +815 -0
- superqode/widgets/throbber.py +87 -0
- superqode/widgets/toast.py +206 -0
- superqode/widgets/unified_output.py +1073 -0
- superqode/workspace/__init__.py +75 -0
- superqode/workspace/artifacts.py +472 -0
- superqode/workspace/coordinator.py +353 -0
- superqode/workspace/diff_tracker.py +429 -0
- superqode/workspace/git_guard.py +373 -0
- superqode/workspace/git_snapshot.py +526 -0
- superqode/workspace/manager.py +750 -0
- superqode/workspace/snapshot.py +357 -0
- superqode/workspace/watcher.py +535 -0
- superqode/workspace/worktree.py +440 -0
- superqode-0.1.5.dist-info/METADATA +204 -0
- superqode-0.1.5.dist-info/RECORD +288 -0
- superqode-0.1.5.dist-info/WHEEL +5 -0
- superqode-0.1.5.dist-info/entry_points.txt +3 -0
- superqode-0.1.5.dist-info/licenses/LICENSE +648 -0
- superqode-0.1.5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1170 @@
"""
LiteLLM Gateway Implementation.

Default gateway for BYOK mode using LiteLLM for unified API access
to 100+ LLM providers.

Performance features:
- Background prewarming to avoid cold-start latency
- Shared module instance across gateway instances
"""

import asyncio
import concurrent.futures
import json
import os
import threading
import time
from typing import Any, AsyncIterator, Dict, List, Optional

from .base import (
    AuthenticationError,
    Cost,
    GatewayError,
    GatewayInterface,
    GatewayResponse,
    InvalidRequestError,
    Message,
    ModelNotFoundError,
    RateLimitError,
    StreamChunk,
    ToolDefinition,
    Usage,
)
from ..registry import PROVIDERS, ProviderDef


# Module-level shared state for prewarming
_litellm_module = None
_litellm_lock = threading.Lock()
_prewarm_task: Optional[asyncio.Task] = None
_prewarm_complete = threading.Event()


def _load_litellm():
    """Load and configure litellm module (thread-safe)."""
    global _litellm_module
    with _litellm_lock:
        if _litellm_module is None:
            import litellm

            litellm.drop_params = True  # Drop unsupported params
            litellm.set_verbose = False
            _litellm_module = litellm
            _prewarm_complete.set()
    return _litellm_module

class LiteLLMGateway(GatewayInterface):
    """LiteLLM-based gateway for BYOK mode.

    Uses LiteLLM to provide unified access to 100+ LLM providers.

    Performance:
        Call prewarm() during app startup to load litellm in background,
        avoiding ~500-800ms cold-start on first LLM request.
    """

    # Class-level executor for background tasks
    _executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)

    def __init__(
        self,
        track_costs: bool = True,
        timeout: float = 300.0,
    ):
        self.track_costs = track_costs
        self.timeout = timeout

    @classmethod
    def prewarm(cls) -> None:
        """Start prewarming litellm in background thread.

        Call this during app startup for faster first LLM request.
        Non-blocking - returns immediately while loading happens in background.

        Example:
            # In app startup
            LiteLLMGateway.prewarm()

            # Later, first request will be fast
            gateway = LiteLLMGateway()
            await gateway.chat_completion(...)
        """
        if _prewarm_complete.is_set():
            return  # Already loaded

        # Submit to thread pool (non-blocking)
        cls._executor.submit(_load_litellm)

    @classmethod
    async def prewarm_async(cls) -> None:
        """Async version of prewarm - await to ensure litellm is loaded.

        Use this if you want to wait for prewarming to complete.
        """
        if _prewarm_complete.is_set():
            return

        loop = asyncio.get_event_loop()
        await loop.run_in_executor(cls._executor, _load_litellm)

    @classmethod
    def is_prewarmed(cls) -> bool:
        """Check if litellm has been loaded."""
        return _prewarm_complete.is_set()

    @classmethod
    def wait_for_prewarm(cls, timeout: float = 5.0) -> bool:
        """Wait for prewarming to complete.

        Args:
            timeout: Maximum seconds to wait

        Returns:
            True if prewarmed, False if timeout
        """
        return _prewarm_complete.wait(timeout=timeout)

    def _get_litellm(self):
        """Get litellm module (uses shared prewarmed instance if available)."""
        global _litellm_module
        if _litellm_module is not None:
            return _litellm_module

        # Not prewarmed - load synchronously (will be cached for next time)
        try:
            return _load_litellm()
        except ImportError as e:
            raise GatewayError("LiteLLM is not installed. Install with: pip install litellm") from e

    def get_model_string(self, provider: str, model: str) -> str:
        """Get the full model string for LiteLLM.

        Args:
            provider: Provider ID (e.g., "anthropic")
            model: Model ID (e.g., "claude-sonnet-4-20250514")

        Returns:
            Full model string for LiteLLM (e.g., "anthropic/claude-sonnet-4-20250514")
        """
        provider_def = PROVIDERS.get(provider)

        if provider_def and provider_def.litellm_prefix:
            # Don't double-prefix
            if model.startswith(provider_def.litellm_prefix):
                return model
            # Empty prefix means no prefix needed (e.g., OpenAI)
            if provider_def.litellm_prefix == "":
                return model
            return f"{provider_def.litellm_prefix}{model}"

        # Unknown provider - try as-is
        return model

    def _setup_provider_env(self, provider: str) -> None:
        """Set up environment for a provider if needed."""
        provider_def = PROVIDERS.get(provider)
        if not provider_def:
            return

        # Handle base URL for local/custom providers
        if provider_def.base_url_env:
            base_url = os.environ.get(provider_def.base_url_env)
            if not base_url and provider_def.default_base_url:
                # Set default base URL if not configured
                os.environ[provider_def.base_url_env] = provider_def.default_base_url
                base_url = provider_def.default_base_url

            # For Ollama, configure LiteLLM via OLLAMA_API_BASE environment variable
            # LiteLLM 1.80.11 uses OLLAMA_API_BASE env var (not ollama_base_url attribute)
            if provider == "ollama" and base_url:
                # Set both OLLAMA_HOST (our convention) and OLLAMA_API_BASE (LiteLLM convention)
                os.environ["OLLAMA_HOST"] = base_url
                os.environ["OLLAMA_API_BASE"] = base_url

            # For LM Studio - configure for local OpenAI-compatible API
            if provider == "lmstudio" and base_url:
                # LM Studio uses OpenAI-compatible API at /v1
                # Set OPENAI_API_BASE to the base URL (already includes /v1)
                clean_url = base_url.rstrip("/")
                os.environ["OPENAI_API_BASE"] = clean_url
                # Also set the provider-specific env var
                os.environ["LMSTUDIO_HOST"] = clean_url
                # For local LM Studio, set a dummy API key to avoid LiteLLM auth errors
                # Local servers typically don't require authentication
                os.environ["OPENAI_API_KEY"] = os.environ.get(
                    "OPENAI_API_KEY", "sk-local-lmstudio-dummy"
                )

            # For vLLM - configure for OpenAI-compatible API
            if provider == "vllm" and base_url:
                # vLLM uses OpenAI-compatible API at /v1
                # Set OPENAI_API_BASE to the base URL (already includes /v1)
                clean_url = base_url.rstrip("/")
                os.environ["OPENAI_API_BASE"] = clean_url
                # Also set the provider-specific env var
                os.environ["VLLM_HOST"] = clean_url
                # For local vLLM, set a dummy API key to avoid LiteLLM auth errors
                # Local servers typically don't require authentication
                os.environ["OPENAI_API_KEY"] = os.environ.get(
                    "OPENAI_API_KEY", "sk-local-vllm-dummy"
                )

            # For SGLang - configure for OpenAI-compatible API
            if provider == "sglang" and base_url:
                # SGLang uses OpenAI-compatible API at /v1
                # Set OPENAI_API_BASE to the base URL (already includes /v1)
                clean_url = base_url.rstrip("/")
                os.environ["OPENAI_API_BASE"] = clean_url
                # Also set the provider-specific env var
                os.environ["SGLANG_HOST"] = clean_url
                # For local SGLang, set a dummy API key to avoid LiteLLM auth errors
                # Local servers typically don't require authentication
                os.environ["OPENAI_API_KEY"] = os.environ.get(
                    "OPENAI_API_KEY", "sk-local-sglang-dummy"
                )

        # MLX is handled directly, not through LiteLLM, so no env setup needed

        # Ensure API keys are set for cloud providers (LiteLLM reads from environment)
        # Google - supports both GOOGLE_API_KEY and GEMINI_API_KEY
        if provider == "google":
            google_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
            if google_key:
                # Ensure both are set for maximum compatibility
                os.environ["GOOGLE_API_KEY"] = google_key
                if not os.environ.get("GEMINI_API_KEY"):
                    os.environ["GEMINI_API_KEY"] = google_key

    def _convert_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert Message objects to LiteLLM format."""
        result = []
        for msg in messages:
            m = {"role": msg.role, "content": msg.content}
            if msg.name:
                m["name"] = msg.name
            if msg.tool_calls:
                m["tool_calls"] = msg.tool_calls
            if msg.tool_call_id:
                m["tool_call_id"] = msg.tool_call_id
            result.append(m)
        return result

    def _convert_tools(
        self, tools: Optional[List[ToolDefinition]]
    ) -> Optional[List[Dict[str, Any]]]:
        """Convert ToolDefinition objects to LiteLLM format."""
        if not tools:
            return None
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    def _normalize_tool_calls(self, tool_calls: Any) -> Optional[List[Dict[str, Any]]]:
        """Normalize tool calls from LiteLLM to dictionaries.

        Handles both dict format and object format (ChatCompletionDeltaToolCall, etc.).
        This is necessary because different LiteLLM providers return tool calls in different formats.
        """
        if not tool_calls:
            return None

        if isinstance(tool_calls, list):
            normalized = []
            for tc in tool_calls:
                if isinstance(tc, dict):
                    # Already a dict - use as-is
                    normalized.append(tc)
                else:
                    # Object format (e.g., ChatCompletionDeltaToolCall) - convert to dict
                    tc_dict = {}

                    # Extract id if present
                    if hasattr(tc, "id"):
                        tc_dict["id"] = getattr(tc, "id", None)
                    elif hasattr(tc, "tool_call_id"):
                        tc_dict["id"] = getattr(tc, "tool_call_id", None)

                    # Extract function info
                    if hasattr(tc, "function"):
                        func = getattr(tc, "function")
                        if isinstance(func, dict):
                            tc_dict["function"] = func
                        else:
                            # Function object - extract fields
                            func_dict = {}
                            if hasattr(func, "name"):
                                func_dict["name"] = getattr(func, "name", "")
                            if hasattr(func, "arguments"):
                                func_dict["arguments"] = getattr(func, "arguments", "{}")
                            elif hasattr(func, "argument"):
                                func_dict["arguments"] = getattr(func, "argument", "{}")
                            tc_dict["function"] = func_dict
                    elif hasattr(tc, "name") or hasattr(tc, "function_name"):
                        # Tool call might have name directly
                        func_dict = {
                            "name": getattr(tc, "name", None) or getattr(tc, "function_name", ""),
                            "arguments": getattr(tc, "arguments", None)
                            or getattr(tc, "args", "{}")
                            or "{}",
                        }
                        tc_dict["function"] = func_dict

                    # If we couldn't extract anything useful, skip it
                    if not tc_dict or "function" not in tc_dict:
                        continue

                    normalized.append(tc_dict)
            return normalized if normalized else None

        # Single tool call (not a list) - wrap in list and process
        if isinstance(tool_calls, dict):
            return [tool_calls]
        else:
            # Object format - normalize it by wrapping in list
            result = self._normalize_tool_calls([tool_calls])
            return result

    def _handle_litellm_error(self, e: Exception, provider: str, model: str) -> None:
        """Convert LiteLLM exceptions to gateway errors."""
        litellm = self._get_litellm()
        error_msg = str(e)

        # Get provider info for helpful error messages
        provider_def = PROVIDERS.get(provider)
        docs_url = provider_def.docs_url if provider_def else ""
        env_vars = provider_def.env_vars if provider_def else []

        # Check for specific error types
        if isinstance(e, litellm.AuthenticationError):
            env_hint = f"Set {' or '.join(env_vars)}" if env_vars else ""
            raise AuthenticationError(
                f"Invalid API key for provider '{provider}'. {env_hint}. "
                f"Get your key at: {docs_url}",
                provider=provider,
                model=model,
                error_type="authentication",
            ) from e

        if isinstance(e, litellm.RateLimitError):
            raise RateLimitError(
                f"Rate limit exceeded for provider '{provider}'. "
                "Wait and retry, or upgrade your API plan.",
                provider=provider,
                model=model,
                error_type="rate_limit",
            ) from e

        if isinstance(e, litellm.NotFoundError):
            example_models = provider_def.example_models if provider_def else []
            models_hint = (
                f"Available models: {', '.join(example_models[:5])}" if example_models else ""
            )
            raise ModelNotFoundError(
                f"Model '{model}' not found for provider '{provider}'. {models_hint}",
                provider=provider,
                model=model,
                error_type="model_not_found",
            ) from e

        if isinstance(e, litellm.BadRequestError):
            raise InvalidRequestError(
                f"Invalid request to '{provider}': {error_msg}",
                provider=provider,
                model=model,
                error_type="invalid_request",
            ) from e

        # Generic error
        raise GatewayError(
            f"Error calling '{provider}/{model}': {error_msg}",
            provider=provider,
            model=model,
        ) from e

    async def _mlx_chat_completion(
        self,
        messages: List[Message],
        model: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[str] = None,
        **kwargs,
    ) -> GatewayResponse:
        """Handle MLX chat completion directly (bypassing LiteLLM auth issues)."""
        from ..local.mlx import MLXClient

        client = MLXClient()

        # Convert messages to OpenAI format
        openai_messages = []
        for msg in messages:
            openai_msg = {"role": msg.role, "content": msg.content}
            if msg.name:
                openai_msg["name"] = msg.name
            if msg.tool_calls:
                openai_msg["tool_calls"] = msg.tool_calls
            if msg.tool_call_id:
                openai_msg["tool_call_id"] = msg.tool_call_id
            openai_messages.append(openai_msg)

        # Build request
        request_data = {
            "model": model,
            "messages": openai_messages,
        }

        if temperature is not None:
            request_data["temperature"] = temperature
        if max_tokens is not None:
            request_data["max_tokens"] = max_tokens
        if tools:
            request_data["tools"] = self._convert_tools(tools)
        if tool_choice:
            request_data["tool_choice"] = tool_choice

        try:
            # Make direct request to MLX server (MLX models can be slow)
            response_data = await client._async_request(
                "POST", "/v1/chat/completions", request_data, timeout=120.0
            )

            # Extract response
            choice = response_data["choices"][0]
            message = choice["message"]

            # Build usage info
            usage_data = response_data.get("usage", {})
            usage = None
            if usage_data:
                usage = Usage(
                    prompt_tokens=usage_data.get("prompt_tokens", 0),
                    completion_tokens=usage_data.get("completion_tokens", 0),
                    total_tokens=usage_data.get("total_tokens", 0),
                )

            # Clean up MLX response content - remove special tokens that might confuse users
            content = message.get("content", "")
            if content:
                # Some MLX models return content with special tokens like <|channel|>, <|message|>, etc.
                # Clean these up for better user experience
                content = (
                    content.replace("<|channel|>", "")
                    .replace("<|message|>", "")
                    .replace("<|end|>", "")
                    .replace("<|start|>", "")
                )
                content = content.replace(
                    "assistant", ""
                ).strip()  # Remove duplicate assistant markers

            return GatewayResponse(
                content=content,
                role=message.get("role", "assistant"),
                finish_reason=choice.get("finish_reason"),
                usage=usage,
                model=response_data.get("model", model),
                provider="mlx",
                tool_calls=message.get("tool_calls"),
                raw_response=response_data,
            )

        except Exception as e:
            # Provide more specific MLX error messages
            error_msg = str(e)
            if "broadcast_shapes" in error_msg or "cannot be broadcast" in error_msg:
                raise GatewayError(
                    f"MLX server encountered a KV cache conflict (concurrent request issue).\n\n"
                    f"This happens when multiple requests are sent to the MLX server simultaneously.\n"
                    f"MLX servers can only handle one request at a time to avoid memory conflicts.\n\n"
                    f"To fix:\n"
                    f"1. [yellow]Wait for any running requests to complete[/yellow]\n"
                    f"2. [cyan]superqode providers mlx list[/cyan] - Check server status\n"
                    f"3. If server crashed: [cyan]superqode providers mlx server --model {model}[/cyan] - Restart server\n"
                    f"4. Try your request again with only one active session\n\n"
                    f"[dim]💡 MLX Tip: Each model needs its own server instance for concurrent use[/dim]",
                    provider="mlx",
                    model=model,
                ) from e
            elif "Expecting value" in error_msg or "Invalid JSON" in error_msg:
                raise GatewayError(
                    f"MLX server returned invalid response.\n\n"
                    f"This usually means the MLX server crashed or is in an error state.\n\n"
                    f"To fix:\n"
                    f"1. [cyan]superqode providers mlx list[/cyan] - Check if server is running\n"
                    f"2. If not running: [cyan]superqode providers mlx server --model {model}[/cyan] - Start server\n"
                    f"3. Wait 1-2 minutes for large models to load\n"
                    f"4. Try again",
                    provider="mlx",
                    model=model,
                ) from e
            elif "Connection refused" in error_msg:
                raise GatewayError(
                    f"Cannot connect to MLX server at http://localhost:8080.\n\n"
                    f"MLX server is not running. To fix:\n\n"
                    f"1. [cyan]superqode providers mlx setup[/cyan] - Complete setup guide\n"
                    f"2. [cyan]superqode providers mlx server --model {model}[/cyan] - Get server command\n"
                    f"3. Run the server command in a separate terminal\n"
                    f"4. Try connecting again",
                    provider="mlx",
                    model=model,
                ) from e
            elif "Connection timed out" in error_msg or "timeout" in error_msg.lower():
                raise GatewayError(
                    f"MLX server timed out. Large MLX models (like {model}) can take 1-2 minutes for first response.\n\n"
                    f"Please wait and try again. If this persists:\n"
                    f"1. Check server is still running: [cyan]superqode providers mlx list[/cyan]\n"
                    f"2. Try a smaller model for testing\n"
                    f"3. Restart the server if needed",
                    provider="mlx",
                    model=model,
                ) from e
            else:
                # Convert to gateway error
                self._handle_litellm_error(e, "mlx", model)

    async def _lmstudio_chat_completion(
        self,
        messages: List[Message],
        model: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[str] = None,
        **kwargs,
    ) -> GatewayResponse:
        """Handle LM Studio chat completion directly to control endpoint."""
        import aiohttp
        from ..registry import PROVIDERS

        # Get LM Studio base URL
        provider_def = PROVIDERS.get("lmstudio")
        base_url = provider_def.default_base_url if provider_def else "http://localhost:1234"
        if provider_def and provider_def.base_url_env:
            base_url = os.environ.get(provider_def.base_url_env, base_url)

        # LM Studio typically serves at /v1/chat/completions
        url = f"{base_url.rstrip('/')}/v1/chat/completions"

        # Convert messages to OpenAI format
        openai_messages = []
        for msg in messages:
            openai_msg = {"role": msg.role, "content": msg.content}
            if msg.name:
                openai_msg["name"] = msg.name
            if msg.tool_calls:
                openai_msg["tool_calls"] = msg.tool_calls
            if msg.tool_call_id:
                openai_msg["tool_call_id"] = msg.tool_call_id
            openai_messages.append(openai_msg)

        # Build request
        request_data = {
            "model": model,
            "messages": openai_messages,
        }

        if temperature is not None:
            request_data["temperature"] = temperature
        if max_tokens is not None:
            request_data["max_tokens"] = max_tokens
        if tools:
            request_data["tools"] = self._convert_tools(tools)
        if tool_choice:
            request_data["tool_choice"] = tool_choice

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY', 'sk-local-lmstudio-dummy')}",
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    json=request_data,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=120.0),
                ) as response:
                    response_data = await response.json()

                    # Extract response
                    choice = response_data["choices"][0]
                    message = choice["message"]

                    # Build usage info
                    usage_data = response_data.get("usage", {})
                    usage = None
                    if usage_data:
                        usage = Usage(
                            prompt_tokens=usage_data.get("prompt_tokens", 0),
                            completion_tokens=usage_data.get("completion_tokens", 0),
                            total_tokens=usage_data.get("total_tokens", 0),
                        )

                    return GatewayResponse(
                        content=message.get("content", ""),
                        role=message.get("role", "assistant"),
                        finish_reason=choice.get("finish_reason"),
                        usage=usage,
                        model=response_data.get("model", model),
                        provider="lmstudio",
                        tool_calls=message.get("tool_calls"),
                        raw_response=response_data,
                    )

        except aiohttp.ClientError as e:
            if "Connection refused" in str(e):
                raise GatewayError(
                    f"Cannot connect to LM Studio server at {base_url}.\n\n"
                    f"LM Studio server is not running. To fix:\n\n"
                    f"1. [cyan]Open LM Studio application[/cyan]\n"
                    f"2. [cyan]Load a model (like qwen/qwen3-30b)[/cyan]\n"
                    f"3. [cyan]Start the local server[/cyan]\n"
                    f"4. Try connecting again",
                    provider="lmstudio",
                    model=model,
                ) from e
            else:
                raise GatewayError(
                    f"LM Studio request failed: {str(e)}",
                    provider="lmstudio",
                    model=model,
                ) from e
        except Exception as e:
            raise GatewayError(
                f"LM Studio error: {str(e)}",
                provider="lmstudio",
                model=model,
            ) from e

    async def _mlx_stream_completion(
        self,
        messages: List[Message],
        model: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[str] = None,
        **kwargs,
    ) -> AsyncIterator[StreamChunk]:
        """Handle MLX streaming completion directly."""
        from ..local.mlx import MLXClient

        client = MLXClient()

        # Convert messages to OpenAI format
        openai_messages = []
        for msg in messages:
            openai_msg = {"role": msg.role, "content": msg.content}
            if msg.name:
                openai_msg["name"] = msg.name
            if msg.tool_calls:
                openai_msg["tool_calls"] = msg.tool_calls
            if msg.tool_call_id:
                openai_msg["tool_call_id"] = msg.tool_call_id
            openai_messages.append(openai_msg)

        # Build request - MLX server may not support streaming properly, so use non-streaming
        request_data = {
            "model": model,
            "messages": openai_messages,
            # Note: Not setting stream=True as MLX server streaming may cause KV cache issues
        }

        if temperature is not None:
            request_data["temperature"] = temperature
        if max_tokens is not None:
            request_data["max_tokens"] = max_tokens
        if tools:
            request_data["tools"] = self._convert_tools(tools)
        if tool_choice:
            request_data["tool_choice"] = tool_choice

        try:
            # Make non-streaming request to MLX server (streaming causes KV cache issues)
            response_data = await client._async_request(
                "POST", "/v1/chat/completions", request_data, timeout=120.0
            )

            # Extract response and yield as single chunk
            choice = response_data["choices"][0]
            message = choice["message"]

            # Get content and clean it up
            content = message.get("content", "")

            # Clean up MLX response content - remove special tokens that might confuse users
            if content:
                # Some MLX models return content with special tokens like <|channel|>, <|message|>, etc.
                # Clean these up for better user experience
                content = (
                    content.replace("<|channel|>", "")
                    .replace("<|message|>", "")
                    .replace("<|end|>", "")
                    .replace("<|start|>", "")
                )
                content = content.replace(
                    "assistant", ""
                ).strip()  # Remove duplicate assistant markers

            yield StreamChunk(
                content=content,
                role=message.get("role"),
                finish_reason=choice.get("finish_reason"),
                tool_calls=message.get("tool_calls"),
            )

        except Exception as e:
            # Provide more specific MLX error messages
            error_msg = str(e)
            if "broadcast_shapes" in error_msg or "cannot be broadcast" in error_msg:
                raise GatewayError(
                    f"MLX server encountered a KV cache conflict (concurrent request issue).\n\n"
                    f"This happens when multiple requests are sent to the MLX server simultaneously.\n"
                    f"MLX servers can only handle one request at a time to avoid memory conflicts.\n\n"
                    f"To fix:\n"
                    f"1. [yellow]Wait for any running requests to complete[/yellow]\n"
                    f"2. [cyan]superqode providers mlx list[/cyan] - Check server status\n"
                    f"3. If server crashed: [cyan]superqode providers mlx server --model {model}[/cyan] - Restart server\n"
                    f"4. Try your request again with only one active session\n\n"
                    f"[dim]💡 MLX Tip: Each model needs its own server instance for concurrent use[/dim]",
                    provider="mlx",
                    model=model,
                ) from e
            elif "Connection refused" in error_msg:
                raise GatewayError(
                    f"Cannot connect to MLX server at http://localhost:8080.\n\n"
                    f"MLX server is not running. To fix:\n\n"
                    f"1. [cyan]superqode providers mlx setup[/cyan] - Complete setup guide\n"
                    f"2. [cyan]superqode providers mlx server --model {model}[/cyan] - Get server command\n"
                    f"3. Run the server command in a separate terminal\n"
                    f"4. Try connecting again",
                    provider="mlx",
                    model=model,
                ) from e
            elif "Connection timed out" in error_msg or "timeout" in error_msg.lower():
                raise GatewayError(
                    f"MLX server timed out. Large MLX models (like {model}) can take 1-2 minutes for first response.\n\n"
                    f"Please wait and try again. If this persists:\n"
                    f"1. Check server is still running: [cyan]superqode providers mlx list[/cyan]\n"
                    f"2. Try a smaller model for testing\n"
                    f"3. Restart the server if needed",
                    provider="mlx",
                    model=model,
                ) from e
            else:
                # Convert to gateway error
                self._handle_litellm_error(e, "mlx", model)

    async def chat_completion(
        self,
        messages: List[Message],
        model: str,
        provider: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[str] = None,
        **kwargs,
    ) -> GatewayResponse:
        """Make a chat completion request via LiteLLM."""

        # Determine provider from model string if not specified
        if not provider and "/" in model:
            provider = model.split("/")[0]
        provider = provider or "unknown"

        # Special handling for MLX - use direct client instead of LiteLLM
        if provider == "mlx":
            return await self._mlx_chat_completion(
                messages, model, temperature, max_tokens, tools, tool_choice, **kwargs
            )

        # Special handling for LM Studio - use direct client to avoid cloud API
        if provider == "lmstudio":
            return await self._lmstudio_chat_completion(
                messages, model, temperature, max_tokens, tools, tool_choice, **kwargs
            )

        litellm = self._get_litellm()

        # Set up provider environment
        self._setup_provider_env(provider)

        # Build model string
        model_string = self.get_model_string(provider, model) if provider != "unknown" else model

        # Build request
        request_kwargs = {
            "model": model_string,
            "messages": self._convert_messages(messages),
            "timeout": self.timeout,
        }

        # Explicitly pass API keys for providers that need them
        # Some LiteLLM versions require explicit api_key parameter
        if provider == "google":
            google_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
            if google_key:
                request_kwargs["api_key"] = google_key

        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if max_tokens is not None:
            request_kwargs["max_tokens"] = max_tokens
        if tools:
            request_kwargs["tools"] = self._convert_tools(tools)
        if tool_choice:
            request_kwargs["tool_choice"] = tool_choice

        # Add any extra kwargs
        request_kwargs.update(kwargs)

        try:
            response = await litellm.acompletion(**request_kwargs)

            # Extract response data
            choice = response.choices[0]
            message = choice.message

            # Parse content - handle Ollama JSON responses and detect empty responses
            content = message.content or ""

            # Check if response is completely empty (no content, no tool calls)
            if not content.strip() and not (hasattr(message, "tool_calls") and message.tool_calls):
                # This model returned nothing - provide a helpful error
                content = f"⚠️ Model '{provider}/{model}' returned an empty response.\n\nThis usually means:\n• The model is not properly configured or available\n• The model may be overloaded or rate-limited\n• Check that the model exists and is accessible\n\nTry using a different model or check your provider configuration."

            elif isinstance(content, str) and content.strip().startswith("{"):
                try:
                    parsed = json.loads(content)
                    # Extract text from common Ollama JSON formats
                    if isinstance(parsed, dict):
                        # Try common fields in order of preference
                        content = (
                            parsed.get("response")
                            or parsed.get("message")
                            or parsed.get("content")
                            or parsed.get("text")
                            or parsed.get("answer")
                            or parsed.get("output")
                            or content
                        )
                        # If content is still a dict, try to extract from it
                        if isinstance(content, dict):
                            content = (
                                content.get("content")
                                or content.get("text")
                                or content.get("message")
                                or str(content)
                            )
                        elif not isinstance(content, str):
                            content = str(content)
                except (json.JSONDecodeError, AttributeError):
                    # Not valid JSON or can't parse, use as-is
                    pass

            # Build usage info
            usage = None
            if response.usage:
                usage = Usage(
                    prompt_tokens=response.usage.prompt_tokens or 0,
                    completion_tokens=response.usage.completion_tokens or 0,
                    total_tokens=response.usage.total_tokens or 0,
                )

            # Build cost info if tracking enabled
            cost = None
            if self.track_costs and hasattr(response, "_hidden_params"):
                hidden = response._hidden_params or {}
                if "response_cost" in hidden:
                    cost = Cost(total_cost=hidden["response_cost"])

            # Extract thinking/reasoning content from response
            thinking_content = None
            thinking_tokens = None

            # Check for extended thinking in various formats
            if hasattr(response, "_hidden_params"):
                hidden = response._hidden_params or {}
                # Claude extended thinking
                if "thinking" in hidden:
                    thinking_content = hidden["thinking"]
                elif "reasoning" in hidden:
                    thinking_content = hidden["reasoning"]
                # o1 reasoning tokens
                elif "reasoning_tokens" in hidden:
                    thinking_content = hidden.get("reasoning_content", "")
                    thinking_tokens = hidden.get("reasoning_tokens", 0)

            # Check raw response for thinking fields
            if not thinking_content and hasattr(response, "response_msgs"):
                # Some providers expose thinking in response_msgs
                for msg in response.response_msgs:
                    if hasattr(msg, "thinking") and msg.thinking:
                        thinking_content = msg.thinking
                        break

            # Check message for thinking fields (Claude format)
            if not thinking_content and hasattr(message, "thinking"):
                thinking_content = message.thinking

            # Check for stop_reason indicating thinking (Claude extended thinking)
            if not thinking_content and choice.finish_reason == "thinking":
                # Extended thinking mode - content might be in a different field
                if hasattr(choice, "thinking") and choice.thinking:
                    thinking_content = choice.thinking
                elif hasattr(message, "thinking") and message.thinking:
                    thinking_content = message.thinking

            # Extract thinking tokens from usage if available
            if thinking_content and usage and not thinking_tokens:
                # Some providers report thinking tokens separately
                if hasattr(response, "_hidden_params"):
                    hidden = response._hidden_params or {}
                    thinking_tokens = hidden.get("thinking_tokens") or hidden.get(
                        "reasoning_tokens"
                    )

            # Normalize tool calls from LiteLLM response (may be objects or dicts)
            tool_calls = None
            if hasattr(message, "tool_calls") and message.tool_calls:
                tool_calls = self._normalize_tool_calls(message.tool_calls)

            return GatewayResponse(
                content=content,
                role=message.role,
                finish_reason=choice.finish_reason,
                usage=usage,
                cost=cost,
                model=response.model,
                provider=provider,
                tool_calls=tool_calls,
                raw_response=response,
                thinking_content=thinking_content,
                thinking_tokens=thinking_tokens,
            )

        except Exception as e:
            self._handle_litellm_error(e, provider, model)

    async def stream_completion(
        self,
        messages: List[Message],
        model: str,
        provider: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[str] = None,
        **kwargs,
    ) -> AsyncIterator[StreamChunk]:
        """Make a streaming chat completion request via LiteLLM."""

        # Determine provider from model string if not specified
        if not provider and "/" in model:
            provider = model.split("/")[0]
        provider = provider or "unknown"

        # Special handling for MLX - use direct client instead of LiteLLM
        if provider == "mlx":
            async for chunk in self._mlx_stream_completion(
                messages, model, temperature, max_tokens, tools, tool_choice, **kwargs
            ):
                yield chunk
            return

        litellm = self._get_litellm()

        # Set up provider environment
        self._setup_provider_env(provider)

        # Build model string
        model_string = self.get_model_string(provider, model) if provider != "unknown" else model

        # Build request
        request_kwargs = {
            "model": model_string,
            "messages": self._convert_messages(messages),
            "stream": True,
            "timeout": self.timeout,
        }

        if temperature is not None:
            request_kwargs["temperature"] = temperature
        if max_tokens is not None:
            request_kwargs["max_tokens"] = max_tokens
        if tools:
            request_kwargs["tools"] = self._convert_tools(tools)
        if tool_choice:
            request_kwargs["tool_choice"] = tool_choice

        # Explicitly pass API keys for providers that need them
        # Some LiteLLM versions require explicit api_key parameter
        if provider == "google":
            google_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
            if google_key:
                request_kwargs["api_key"] = google_key

        request_kwargs.update(kwargs)

        try:
            response = await litellm.acompletion(**request_kwargs)

            if not response:
                raise GatewayError(
                    f"No response from {provider}/{model}",
                    provider=provider,
                    model=model,
                )

            async for chunk in response:
                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta

                # Extract thinking content if available (for extended thinking models)
                thinking_content = None
                if hasattr(delta, "thinking") and delta.thinking:
                    thinking_content = delta.thinking
                elif hasattr(choice, "thinking") and choice.thinking:
                    thinking_content = choice.thinking

                # Extract content - handle Ollama JSON responses
                content = ""
                if delta and delta.content:
                    content_str = delta.content
                    # Note: In streaming mode, JSON might come in chunks, so we only parse
                    # if we have a complete JSON object (starts with { and ends with })
                    # Otherwise, we pass through the content as-is
                    if (
                        isinstance(content_str, str)
                        and content_str.strip().startswith("{")
                        and content_str.strip().endswith("}")
                    ):
                        try:
                            parsed = json.loads(content_str)
                            # Extract text from common Ollama JSON formats
                            if isinstance(parsed, dict):
                                # Try common fields in order of preference
                                content = (
                                    parsed.get("response")
                                    or parsed.get("message")
                                    or parsed.get("content")
                                    or parsed.get("text")
                                    or parsed.get("answer")
                                    or parsed.get("output")
                                    or content_str
                                )
                                # If content is still a dict, try to extract from it
                                if isinstance(content, dict):
                                    content = (
                                        content.get("content")
                                        or content.get("text")
                                        or content.get("message")
                                        or content_str
                                    )
                            else:
                                content = content_str
                        except (json.JSONDecodeError, AttributeError):
                            # Not valid JSON or can't parse, use as-is
                            content = content_str
                    else:
                        content = content_str

                stream_chunk = StreamChunk(
                    content=content,
                    role=delta.role if delta and hasattr(delta, "role") else None,
                    finish_reason=choice.finish_reason,
                    thinking_content=thinking_content,
                )

                # Handle tool calls in stream
                # Normalize tool calls (may be objects or dicts from LiteLLM)
                if delta and hasattr(delta, "tool_calls") and delta.tool_calls:
                    stream_chunk.tool_calls = self._normalize_tool_calls(delta.tool_calls)

                yield stream_chunk

        except GatewayError:
            # Re-raise gateway errors (they're already formatted)
            raise
        except Exception as e:
            # Convert LiteLLM errors to gateway errors
            self._handle_litellm_error(e, provider, model)

    async def test_connection(
        self,
        provider: str,
        model: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Test connection to a provider."""
        provider_def = PROVIDERS.get(provider)

        if not provider_def:
            return {
                "success": False,
                "provider": provider,
                "error": f"Provider '{provider}' not found in registry",
            }

        # Use first example model if not specified
        test_model = model or (
            provider_def.example_models[0] if provider_def.example_models else None
        )

        if not test_model:
            return {
                "success": False,
                "provider": provider,
                "error": "No model specified and no example models available",
            }

        try:
            # Make a minimal test request
            response = await self.chat_completion(
                messages=[Message(role="user", content="Hi")],
                model=test_model,
                provider=provider,
                max_tokens=5,
            )

            return {
                "success": True,
                "provider": provider,
                "model": test_model,
                "response_model": response.model,
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
                    "completion_tokens": response.usage.completion_tokens if response.usage else 0,
                },
            }

        except GatewayError as e:
            return {
                "success": False,
                "provider": provider,
                "model": test_model,
                "error": str(e),
                "error_type": e.error_type,
            }
        except Exception as e:
            return {
                "success": False,
                "provider": provider,
                "model": test_model,
                "error": str(e),
            }
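For orientation, here is a minimal usage sketch of the gateway above, assembled from its own docstrings and from `test_connection`. It is not part of the package: the import paths are assumed to mirror the file layout in the listing (`superqode/providers/gateway/...`), the provider and model identifiers are simply the examples from `get_model_string`'s docstring, and a matching API key is assumed to already be set in the environment. Adjust all three to your setup.

```python
# Hypothetical usage sketch - assumes import paths mirror the wheel's file layout
# and that credentials for the chosen provider are already in the environment.
import asyncio

from superqode.providers.gateway.base import Message
from superqode.providers.gateway.litellm_gateway import LiteLLMGateway


async def main() -> None:
    # Kick off litellm loading in the background during startup (non-blocking),
    # as the prewarm() docstring suggests.
    LiteLLMGateway.prewarm()

    gateway = LiteLLMGateway()

    # One-shot completion; provider/model follow the docstring examples.
    response = await gateway.chat_completion(
        messages=[Message(role="user", content="Hi")],
        model="claude-sonnet-4-20250514",
        provider="anthropic",
        max_tokens=5,
    )
    print(response.content)

    # Streaming variant: chunks arrive as StreamChunk objects.
    async for chunk in gateway.stream_completion(
        messages=[Message(role="user", content="Say hello")],
        model="claude-sonnet-4-20250514",
        provider="anthropic",
    ):
        if chunk.content:
            print(chunk.content, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
```

If startup code can await, `prewarm_async()` or `wait_for_prewarm()` can be used instead of the fire-and-forget `prewarm()` to guarantee litellm is loaded before the first request.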