sandboxy-0.0.4-py3-none-any.whl → sandboxy-0.0.6-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry; it is provided for informational purposes only.
- sandboxy/agents/llm_prompt.py +85 -14
- sandboxy/api/app.py +2 -1
- sandboxy/api/routes/local.py +34 -1
- sandboxy/api/routes/providers.py +369 -0
- sandboxy/cli/main.py +371 -0
- sandboxy/mlflow/exporter.py +7 -1
- sandboxy/providers/__init__.py +37 -3
- sandboxy/providers/config.py +243 -0
- sandboxy/providers/local.py +498 -0
- sandboxy/providers/registry.py +107 -13
- sandboxy/scenarios/unified.py +27 -3
- sandboxy/ui/dist/assets/index-BZFjoK-_.js +377 -0
- sandboxy/ui/dist/assets/index-Qf7gGJk_.css +1 -0
- sandboxy/ui/dist/index.html +2 -2
- {sandboxy-0.0.4.dist-info → sandboxy-0.0.6.dist-info}/METADATA +67 -27
- {sandboxy-0.0.4.dist-info → sandboxy-0.0.6.dist-info}/RECORD +19 -16
- sandboxy/ui/dist/assets/index-CU06wBqc.js +0 -362
- sandboxy/ui/dist/assets/index-Cgg2wY2m.css +0 -1
- {sandboxy-0.0.4.dist-info → sandboxy-0.0.6.dist-info}/WHEEL +0 -0
- {sandboxy-0.0.4.dist-info → sandboxy-0.0.6.dist-info}/entry_points.txt +0 -0
- {sandboxy-0.0.4.dist-info → sandboxy-0.0.6.dist-info}/licenses/LICENSE +0 -0
sandboxy/providers/local.py
ADDED
@@ -0,0 +1,498 @@
+"""Local model provider for OpenAI-compatible servers (Ollama, LM Studio, vLLM, etc.)."""
+
+from __future__ import annotations
+
+import logging
+import time
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from sandboxy.providers.base import BaseProvider, ModelInfo, ModelResponse, ProviderError
+from sandboxy.providers.config import (
+    LocalModelInfo,
+    LocalProviderConfig,
+    ProviderStatus,
+    ProviderStatusEnum,
+)
+
+logger = logging.getLogger(__name__)
+
+# Default timeout for requests (60 seconds)
+DEFAULT_TIMEOUT = 60.0
+
+
+class LocalProviderConnectionError(ProviderError):
+    """Error when local provider is unreachable."""
+
+    def __init__(self, provider_name: str, base_url: str, original_error: str):
+        self.base_url = base_url
+        self.original_error = original_error
+        message = (
+            f"Cannot connect to {provider_name} at {base_url}. "
+            f"Is the server running? Error: {original_error}"
+        )
+        super().__init__(message, provider=provider_name)
+
+
+class LocalProvider(BaseProvider):
+    """Provider for local OpenAI-compatible servers.
+
+    Supports:
+    - Ollama (http://localhost:11434/v1)
+    - LM Studio (http://localhost:1234/v1)
+    - vLLM (http://localhost:8000/v1)
+    - Any OpenAI-compatible endpoint
+
+    """
+
+    provider_name: str = "local"
+
+    def __init__(self, config: LocalProviderConfig):
+        """Initialize local provider with configuration.
+
+        Args:
+            config: Provider configuration including base URL and optional API key
+
+        """
+        self.config = config
+        self.provider_name = config.name
+
+        # Build headers
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if config.api_key:
+            headers["Authorization"] = f"Bearer {config.api_key}"
+
+        self._client = httpx.AsyncClient(
+            base_url=config.base_url,
+            headers=headers,
+            timeout=DEFAULT_TIMEOUT,
+        )
+
+        # Cache for discovered models
+        self._models_cache: list[LocalModelInfo] | None = None
+        self._tool_support_cache: dict[str, bool] = {}
+
+    async def close(self) -> None:
+        """Close the HTTP client."""
+        await self._client.aclose()
+
+    async def complete(
+        self,
+        model: str,
+        messages: list[dict[str, Any]],
+        temperature: float = 0.7,
+        max_tokens: int = 4096,
+        tools: list[dict[str, Any]] | None = None,
+        **kwargs: Any,
+    ) -> ModelResponse:
+        """Send a chat completion request to the local server.
+
+        Args:
+            model: Model identifier (e.g., "llama3:8b", "mistral:latest")
+            messages: List of message dicts with 'role' and 'content'
+            temperature: Sampling temperature (0-2)
+            max_tokens: Maximum tokens in response
+            tools: Optional list of tool definitions for function calling
+            **kwargs: Additional parameters passed to the API
+
+        Returns:
+            ModelResponse with content and metadata
+
+        Raises:
+            LocalProviderConnectionError: If server is unreachable
+            ProviderError: If the request fails
+
+        """
+        # Strip provider prefix if present (e.g., "ollama/llama3" -> "llama3")
+        if "/" in model:
+            _, model = model.rsplit("/", 1)
+
+        start_time = time.perf_counter()
+
+        # Build request payload
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "stream": False,
+        }
+
+        # Add tools if provided and model might support them
+        if tools:
+            payload["tools"] = tools
+
+        # Merge any default params from config
+        payload.update(self.config.default_params)
+        payload.update(kwargs)
+
+        try:
+            response = await self._client.post("/chat/completions", json=payload)
+            response.raise_for_status()
+            data = response.json()
+        except httpx.ConnectError as e:
+            raise LocalProviderConnectionError(
+                self.config.name,
+                self.config.base_url,
+                str(e),
+            ) from e
+        except httpx.HTTPStatusError as e:
+            error_detail = self._extract_error_detail(e)
+            raise ProviderError(
+                f"Request failed: {error_detail}",
+                provider=self.config.name,
+                model=model,
+            ) from e
+        except httpx.TimeoutException as e:
+            raise ProviderError(
+                f"Request timed out after {DEFAULT_TIMEOUT}s",
+                provider=self.config.name,
+                model=model,
+            ) from e
+
+        latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+        # Extract response content
+        choice = data.get("choices", [{}])[0]
+        message = choice.get("message", {})
+        content = message.get("content", "")
+
+        # Handle tool calls in response
+        tool_calls = message.get("tool_calls")
+        if tool_calls:
+            # Include tool calls in raw response for caller to handle
+            pass
+
+        # Extract token usage
+        usage = data.get("usage", {})
+        input_tokens = usage.get("prompt_tokens", 0)
+        output_tokens = usage.get("completion_tokens", 0)
+
+        # If no usage provided, estimate with tiktoken
+        if input_tokens == 0 and output_tokens == 0:
+            input_tokens, output_tokens = self._estimate_tokens(messages, content)
+
+        return ModelResponse(
+            content=content,
+            model_id=model,
+            latency_ms=latency_ms,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            cost_usd=0.0,  # Local models have no API cost
+            finish_reason=choice.get("finish_reason"),
+            raw_response=data,
+        )
+
+    async def stream(
+        self,
+        model: str,
+        messages: list[dict[str, Any]],
+        temperature: float = 0.7,
+        max_tokens: int = 4096,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        """Stream a chat completion response.
+
+        Args:
+            model: Model identifier
+            messages: List of message dicts
+            temperature: Sampling temperature
+            max_tokens: Maximum tokens
+            **kwargs: Additional parameters
+
+        Yields:
+            Content chunks as they arrive
+
+        """
+        # Strip provider prefix if present
+        if "/" in model:
+            _, model = model.rsplit("/", 1)
+
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "stream": True,
+        }
+        payload.update(self.config.default_params)
+        payload.update(kwargs)
+
+        try:
+            async with self._client.stream("POST", "/chat/completions", json=payload) as response:
+                response.raise_for_status()
+
+                async for line in response.aiter_lines():
+                    if not line or not line.startswith("data: "):
+                        continue
+
+                    data_str = line[6:]  # Remove "data: " prefix
+                    if data_str == "[DONE]":
+                        break
+
+                    try:
+                        import json
+
+                        data = json.loads(data_str)
+                        delta = data.get("choices", [{}])[0].get("delta", {})
+                        content = delta.get("content", "")
+                        if content:
+                            yield content
+                    except Exception:
+                        continue
+
+        except httpx.ConnectError as e:
+            raise LocalProviderConnectionError(
+                self.config.name,
+                self.config.base_url,
+                str(e),
+            ) from e
+
+    def list_models(self) -> list[ModelInfo]:
+        """Return available models from this provider.
+
+        Returns cached list if available. Call refresh_models() to update.
+
+        """
+        if self._models_cache is not None:
+            return self._models_cache
+
+        # Return manually configured models if any
+        if self.config.models:
+            return [
+                LocalModelInfo(
+                    id=model_id,
+                    name=model_id,
+                    provider=self.config.name,
+                    provider_name=self.config.name,
+                    context_length=8192,  # Default, unknown
+                    input_cost_per_million=None,
+                    output_cost_per_million=None,
+                    supports_tools=False,  # Unknown until verified
+                    supports_vision=False,
+                    supports_streaming=True,
+                    is_local=True,
+                    capabilities_verified=False,
+                )
+                for model_id in self.config.models
+            ]
+
+        # Return empty list - caller should use async refresh_models()
+        return []
+
+    async def refresh_models(self) -> list[LocalModelInfo]:
+        """Fetch available models from the provider's /v1/models endpoint.
+
+        Returns:
+            List of discovered models
+
+        """
+        try:
+            response = await self._client.get("/models")
+            response.raise_for_status()
+            data = response.json()
+
+            models: list[LocalModelInfo] = []
+
+            # Handle different response formats:
+            # - OpenAI format: {"data": [...]}
+            # - Ollama format: {"models": [...]} or direct list
+            model_list = data.get("data") or data.get("models") or []
+            if model_list is None:
+                model_list = []
+
+            # If data itself is a list (some providers return this)
+            if isinstance(data, list):
+                model_list = data
+
+            for model_data in model_list:
+                # Handle different model object formats
+                if isinstance(model_data, str):
+                    # Some providers return just model IDs as strings
+                    model_id = model_data
+                    model_name = model_data
+                    context_length = 8192
+                else:
+                    # Object format - try various field names
+                    model_id = (
+                        model_data.get("id")
+                        or model_data.get("model")
+                        or model_data.get("name")
+                        or "unknown"
+                    )
+                    model_name = model_data.get("name", model_id)
+                    context_length = model_data.get("context_length", 8192)
+
+                models.append(
+                    LocalModelInfo(
+                        id=model_id,
+                        name=model_name,
+                        provider=self.config.name,
+                        provider_name=self.config.name,
+                        context_length=context_length,
+                        input_cost_per_million=None,
+                        output_cost_per_million=None,
+                        supports_tools=self._infer_tool_support(model_id),
+                        supports_vision=False,
+                        supports_streaming=True,
+                        is_local=True,
+                        capabilities_verified=False,
+                    )
+                )
+
+            self._models_cache = models
+            return models
+
+        except httpx.ConnectError as e:
+            raise LocalProviderConnectionError(
+                self.config.name,
+                self.config.base_url,
+                str(e),
+            ) from e
+        except Exception as e:
+            logger.warning(f"Failed to fetch models from {self.config.name}: {e}")
+            # Return manually configured models as fallback
+            return self.list_models()
+
+    def supports_model(self, model_id: str) -> bool:
+        """Check if this provider supports a given model.
+
+        Args:
+            model_id: Model identifier to check
+
+        Returns:
+            True if the model is supported
+
+        """
+        # Strip provider prefix if present
+        if "/" in model_id:
+            prefix, model_name = model_id.rsplit("/", 1)
+            if prefix != self.config.name:
+                return False
+            model_id = model_name
+
+        # Check manually configured models
+        if self.config.models:
+            return model_id in self.config.models
+
+        # Check cached models
+        if self._models_cache:
+            return any(m.id == model_id for m in self._models_cache)
+
+        # If no cache, assume we support it (will fail at runtime if not)
+        return True
+
+    async def test_connection(self) -> ProviderStatus:
+        """Test connectivity to the provider and return status.
+
+        Returns:
+            ProviderStatus with connection details
+
+        """
+        from datetime import datetime
+
+        start_time = time.perf_counter()
+
+        try:
+            models = await self.refresh_models()
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+            return ProviderStatus(
+                name=self.config.name,
+                status=ProviderStatusEnum.CONNECTED,
+                last_checked=datetime.now(),
+                available_models=[m.id for m in models],
+                latency_ms=latency_ms,
+            )
+
+        except LocalProviderConnectionError as e:
+            return ProviderStatus(
+                name=self.config.name,
+                status=ProviderStatusEnum.DISCONNECTED,
+                last_checked=datetime.now(),
+                error_message=str(e),
+            )
+        except Exception as e:
+            return ProviderStatus(
+                name=self.config.name,
+                status=ProviderStatusEnum.ERROR,
+                last_checked=datetime.now(),
+                error_message=str(e),
+            )
+
+    def _extract_error_detail(self, error: httpx.HTTPStatusError) -> str:
+        """Extract error detail from HTTP error response."""
+        try:
+            data = error.response.json()
+            if "error" in data:
+                err = data["error"]
+                if isinstance(err, dict):
+                    return err.get("message", str(err))
+                return str(err)
+        except Exception:
+            logger.debug("Failed to parse error response JSON")
+        return f"HTTP {error.response.status_code}"
+
+    def _estimate_tokens(
+        self, messages: list[dict[str, Any]], response_content: str
+    ) -> tuple[int, int]:
+        """Estimate token counts using tiktoken when server doesn't provide them.
+
+        Args:
+            messages: Input messages
+            response_content: Output content
+
+        Returns:
+            Tuple of (input_tokens, output_tokens)
+
+        """
+        try:
+            import tiktoken
+
+            enc = tiktoken.get_encoding("cl100k_base")
+
+            # Estimate input tokens
+            input_text = ""
+            for msg in messages:
+                input_text += msg.get("role", "") + " " + msg.get("content", "") + " "
+            input_tokens = len(enc.encode(input_text))
+
+            # Estimate output tokens
+            output_tokens = len(enc.encode(response_content))
+
+            return input_tokens, output_tokens
+        except ImportError:
+            # tiktoken not available, return rough estimates
+            input_chars = sum(len(str(m.get("content", ""))) for m in messages)
+            return input_chars // 4, len(response_content) // 4
+
+    def _infer_tool_support(self, model_id: str) -> bool:
+        """Infer whether a model likely supports tool calling.
+
+        Based on known models that support function calling.
+
+        """
+        # Check cache first
+        if model_id in self._tool_support_cache:
+            return self._tool_support_cache[model_id]
+
+        model_lower = model_id.lower()
+
+        # Models known to support tools
+        tool_supporting_patterns = [
+            "llama3.1",
+            "llama-3.1",
+            "llama3.2",
+            "llama-3.2",
+            "mistral",
+            "mixtral",
+            "qwen",
+            "command-r",
+            "gemma2",
+        ]
+
+        supports = any(pattern in model_lower for pattern in tool_supporting_patterns)
+        self._tool_support_cache[model_id] = supports
+        return supports
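
For orientation, here is a minimal usage sketch of the new LocalProvider, based only on the code shown in this diff. The keyword fields passed to LocalProviderConfig (defined in sandboxy/providers/config.py, whose body is not shown here) are assumed from how LocalProvider reads them (name, base_url, api_key, models, default_params), so treat the sketch as illustrative rather than authoritative.

# Hypothetical sketch; LocalProviderConfig constructor fields are assumed from usage above.
import asyncio

from sandboxy.providers.config import LocalProviderConfig
from sandboxy.providers.local import LocalProvider


async def main() -> None:
    config = LocalProviderConfig(  # assumed keyword fields
        name="ollama",
        base_url="http://localhost:11434/v1",
        api_key=None,
        models=[],          # empty: discover models via the /models endpoint instead
        default_params={},
    )
    provider = LocalProvider(config)
    try:
        # Hits GET /models and reports CONNECTED / DISCONNECTED / ERROR
        status = await provider.test_connection()
        print(status.status, status.available_models)

        # Chat completion; the "ollama/" prefix is stripped before the request
        resp = await provider.complete(
            model="ollama/llama3.1",
            messages=[{"role": "user", "content": "Say hello."}],
        )
        print(resp.content, resp.latency_ms, resp.cost_usd)
    finally:
        await provider.close()


asyncio.run(main())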
sandboxy/providers/registry.py
CHANGED
@@ -1,12 +1,57 @@
 """Provider registry for managing multiple LLM providers."""
 
+from __future__ import annotations
+
 import logging
 import os
+from typing import TYPE_CHECKING
 
 from sandboxy.providers.base import BaseProvider, ModelInfo, ProviderError
 
+if TYPE_CHECKING:
+    from sandboxy.providers.local import LocalProvider
+
 logger = logging.getLogger(__name__)
 
+# Local providers are lazily loaded to avoid circular imports
+_local_providers: dict[str, LocalProvider] | None = None
+
+
+def _get_local_providers() -> dict[str, BaseProvider]:
+    """Load local providers from config file.
+
+    Returns:
+        Dict mapping provider name to LocalProvider instance
+
+    """
+    global _local_providers
+    if _local_providers is not None:
+        return _local_providers
+
+    _local_providers = {}
+
+    try:
+        from sandboxy.providers.config import get_enabled_providers
+        from sandboxy.providers.local import LocalProvider
+
+        for config in get_enabled_providers():
+            try:
+                _local_providers[config.name] = LocalProvider(config)
+                logger.info(f"Local provider '{config.name}' loaded from config")
+            except Exception as e:
+                logger.warning(f"Failed to load local provider '{config.name}': {e}")
+    except Exception as e:
+        logger.debug(f"Could not load local providers: {e}")
+
+    return _local_providers
+
+
+def reload_local_providers() -> None:
+    """Force reload of local providers from config file."""
+    global _local_providers
+    _local_providers = None
+    _get_local_providers()
+
 
 class ProviderRegistry:
     """Registry of available LLM providers.
@@ -25,9 +70,15 @@ class ProviderRegistry:
 
     """
 
-    def __init__(self):
-        """Initialize registry and detect available providers.
+    def __init__(self, include_local: bool = True):
+        """Initialize registry and detect available providers.
+
+        Args:
+            include_local: Whether to include local providers from config
+
+        """
         self.providers: dict[str, BaseProvider] = {}
+        self._include_local = include_local
         self._init_providers()
 
     def _init_providers(self) -> None:
@@ -62,10 +113,18 @@ class ProviderRegistry:
         except ProviderError as e:
             logger.warning(f"Failed to init Anthropic: {e}")
 
+        # Load local providers from config
+        if self._include_local:
+            local_providers = _get_local_providers()
+            for name, provider in local_providers.items():
+                self.providers[name] = provider
+                logger.info(f"Local provider '{name}' registered")
+
         if not self.providers:
             logger.warning(
                 "No providers available. Set at least one API key: "
-                "OPENROUTER_API_KEY, OPENAI_API_KEY, or ANTHROPIC_API_KEY"
+                "OPENROUTER_API_KEY, OPENAI_API_KEY, or ANTHROPIC_API_KEY, "
+                "or configure local providers with 'sandboxy providers add'"
             )
 
     def get_provider_for_model(self, model_id: str) -> BaseProvider:
@@ -91,15 +150,23 @@ class ProviderRegistry:
                 provider="registry",
             )
 
-        # If model has a prefix (
-        # This is OpenRouter's convention - direct APIs don't use prefixes
+        # If model has a prefix (provider/model format)
         if "/" in model_id:
+            provider_name, model_name = model_id.split("/", 1)
+
+            # Check for local provider first (e.g., "ollama/llama3")
+            if provider_name in self.providers:
+                provider = self.providers[provider_name]
+                # Verify it's a local provider or supports the model
+                if hasattr(provider, "config") or provider.supports_model(model_id):
+                    return provider
+
+            # OpenRouter format (e.g., "openai/gpt-4o")
             if "openrouter" in self.providers:
                 return self.providers["openrouter"]
-
-
+
+            # Fallback to direct provider if prefix matches
             if provider_name == "openai" and "openai" in self.providers:
-                # Note: caller should strip prefix when calling direct provider
                 return self.providers["openai"]
             if provider_name == "anthropic" and "anthropic" in self.providers:
                 return self.providers["anthropic"]
@@ -131,18 +198,32 @@ class ProviderRegistry:
     def list_all_models(self) -> list[ModelInfo]:
         """List all models from all providers.
 
-        Returns deduplicated list with
-
+        Returns deduplicated list with:
+        1. Local providers first (highest priority)
+        2. Direct cloud providers (OpenAI, Anthropic)
+        3. OpenRouter last (fallback)
         """
         seen_ids: set[str] = set()
         models: list[ModelInfo] = []
 
-        # Add
+        # Add local provider models first (highest priority)
         for name, provider in self.providers.items():
-            if name
-                continue
+            if name in ("openrouter", "openai", "anthropic"):
+                continue
 
             for model in provider.list_models():
+                # Use provider-prefixed ID for local models
+                prefixed_id = f"{name}/{model.id}"
+                if prefixed_id not in seen_ids:
+                    seen_ids.add(prefixed_id)
+                    models.append(model)
+
+        # Add direct cloud provider models
+        for name in ("openai", "anthropic"):
+            if name not in self.providers:
+                continue
+
+            for model in self.providers[name].list_models():
                 if model.id not in seen_ids:
                     seen_ids.add(model.id)
                    models.append(model)
@@ -156,6 +237,19 @@ class ProviderRegistry:
 
         return models
 
+    def get_local_providers(self) -> dict[str, BaseProvider]:
+        """Get all local providers.
+
+        Returns:
+            Dict of local provider name to provider instance
+
+        """
+        return {
+            name: provider
+            for name, provider in self.providers.items()
+            if hasattr(provider, "config")  # LocalProvider has config attribute
+        }
+
     def get_provider(self, provider_name: str) -> BaseProvider | None:
         """Get a specific provider by name.
 
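
The registry changes above add a routing order for prefixed model IDs: a matching local provider is tried first, then OpenRouter, then the direct OpenAI/Anthropic providers, and list_all_models() orders its catalogue the same way. A short illustrative sketch of that behavior follows; it assumes at least one provider (an API key or a local provider config) is available, and which provider each call actually returns depends on what is configured.

# Illustrative only; results depend on which providers are registered.
from sandboxy.providers.registry import ProviderRegistry, reload_local_providers

registry = ProviderRegistry(include_local=True)

# "ollama/llama3.1": the "ollama" prefix matches a registered local provider,
# so it is returned before the OpenRouter fallback is considered.
local = registry.get_provider_for_model("ollama/llama3.1")

# "openai/gpt-4o": typically resolved to OpenRouter when it is configured,
# otherwise to the direct OpenAI provider.
cloud = registry.get_provider_for_model("openai/gpt-4o")

# Deduplicated catalogue: local models first, then OpenAI/Anthropic, then OpenRouter.
for model in registry.list_all_models():
    print(model.id)

# Only the local providers (identified by their config attribute).
print(list(registry.get_local_providers()))

# After editing the local provider config file, clear the module-level cache
# and rebuild the registry to pick up the changes.
reload_local_providers()
registry = ProviderRegistry()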