sandboxy 0.0.4-py3-none-any.whl → 0.0.6-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -0,0 +1,498 @@
+"""Local model provider for OpenAI-compatible servers (Ollama, LM Studio, vLLM, etc.)."""
+
+from __future__ import annotations
+
+import json
+import logging
+import time
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from sandboxy.providers.base import BaseProvider, ModelInfo, ModelResponse, ProviderError
+from sandboxy.providers.config import (
+    LocalModelInfo,
+    LocalProviderConfig,
+    ProviderStatus,
+    ProviderStatusEnum,
+)
+
+logger = logging.getLogger(__name__)
+
+# Default timeout for requests (60 seconds)
+DEFAULT_TIMEOUT = 60.0
+
+
+class LocalProviderConnectionError(ProviderError):
+    """Error when local provider is unreachable."""
+
+    def __init__(self, provider_name: str, base_url: str, original_error: str):
+        self.base_url = base_url
+        self.original_error = original_error
+        message = (
+            f"Cannot connect to {provider_name} at {base_url}. "
+            f"Is the server running? Error: {original_error}"
+        )
+        super().__init__(message, provider=provider_name)
+
+
+class LocalProvider(BaseProvider):
+    """Provider for local OpenAI-compatible servers.
+
+    Supports:
+    - Ollama (http://localhost:11434/v1)
+    - LM Studio (http://localhost:1234/v1)
+    - vLLM (http://localhost:8000/v1)
+    - Any OpenAI-compatible endpoint
+
+    """
+
+    provider_name: str = "local"
+
+    def __init__(self, config: LocalProviderConfig):
+        """Initialize local provider with configuration.
+
+        Args:
+            config: Provider configuration including base URL and optional API key
+
+        """
+        self.config = config
+        self.provider_name = config.name
+
+        # Build headers
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if config.api_key:
+            headers["Authorization"] = f"Bearer {config.api_key}"
+
+        self._client = httpx.AsyncClient(
+            base_url=config.base_url,
+            headers=headers,
+            timeout=DEFAULT_TIMEOUT,
+        )
+
+        # Cache for discovered models
+        self._models_cache: list[LocalModelInfo] | None = None
+        self._tool_support_cache: dict[str, bool] = {}
+
+    async def close(self) -> None:
+        """Close the HTTP client."""
+        await self._client.aclose()
+
+    async def complete(
+        self,
+        model: str,
+        messages: list[dict[str, Any]],
+        temperature: float = 0.7,
+        max_tokens: int = 4096,
+        tools: list[dict[str, Any]] | None = None,
+        **kwargs: Any,
+    ) -> ModelResponse:
+        """Send a chat completion request to the local server.
+
+        Args:
+            model: Model identifier (e.g., "llama3:8b", "mistral:latest")
+            messages: List of message dicts with 'role' and 'content'
+            temperature: Sampling temperature (0-2)
+            max_tokens: Maximum tokens in response
+            tools: Optional list of tool definitions for function calling
+            **kwargs: Additional parameters passed to the API
+
+        Returns:
+            ModelResponse with content and metadata
+
+        Raises:
+            LocalProviderConnectionError: If server is unreachable
+            ProviderError: If the request fails
+
+        """
+        # Strip provider prefix if present (e.g., "ollama/llama3" -> "llama3")
+        if "/" in model:
+            _, model = model.rsplit("/", 1)
+
+        start_time = time.perf_counter()
+
+        # Build request payload
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "stream": False,
+        }
+
+        # Add tools if provided and model might support them
+        if tools:
+            payload["tools"] = tools
+
+        # Merge any default params from config
+        payload.update(self.config.default_params)
+        payload.update(kwargs)
+
+        try:
+            response = await self._client.post("/chat/completions", json=payload)
+            response.raise_for_status()
+            data = response.json()
+        except httpx.ConnectError as e:
+            raise LocalProviderConnectionError(
+                self.config.name,
+                self.config.base_url,
+                str(e),
+            ) from e
+        except httpx.HTTPStatusError as e:
+            error_detail = self._extract_error_detail(e)
+            raise ProviderError(
+                f"Request failed: {error_detail}",
+                provider=self.config.name,
+                model=model,
+            ) from e
+        except httpx.TimeoutException as e:
+            raise ProviderError(
+                f"Request timed out after {DEFAULT_TIMEOUT}s",
+                provider=self.config.name,
+                model=model,
+            ) from e
+
+        latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+        # Extract response content
+        choice = data.get("choices", [{}])[0]
+        message = choice.get("message", {})
+        content = message.get("content", "")
+
+        # Tool calls, if any, are returned in raw_response for the caller to handle
+
+        # Extract token usage
+        usage = data.get("usage", {})
+        input_tokens = usage.get("prompt_tokens", 0)
+        output_tokens = usage.get("completion_tokens", 0)
+
+        # If no usage provided, estimate with tiktoken
+        if input_tokens == 0 and output_tokens == 0:
+            input_tokens, output_tokens = self._estimate_tokens(messages, content)
+
+        return ModelResponse(
+            content=content,
+            model_id=model,
+            latency_ms=latency_ms,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            cost_usd=0.0,  # Local models have no API cost
+            finish_reason=choice.get("finish_reason"),
+            raw_response=data,
+        )
+
+    async def stream(
+        self,
+        model: str,
+        messages: list[dict[str, Any]],
+        temperature: float = 0.7,
+        max_tokens: int = 4096,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        """Stream a chat completion response.
+
+        Args:
+            model: Model identifier
+            messages: List of message dicts
+            temperature: Sampling temperature
+            max_tokens: Maximum tokens
+            **kwargs: Additional parameters
+
+        Yields:
+            Content chunks as they arrive
+
+        """
+        # Strip provider prefix if present
+        if "/" in model:
+            _, model = model.rsplit("/", 1)
+
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "stream": True,
+        }
+        payload.update(self.config.default_params)
+        payload.update(kwargs)
+
+        try:
+            async with self._client.stream("POST", "/chat/completions", json=payload) as response:
+                response.raise_for_status()
+
+                async for line in response.aiter_lines():
+                    if not line or not line.startswith("data: "):
+                        continue
+
+                    data_str = line[6:]  # Remove "data: " prefix
+                    if data_str == "[DONE]":
+                        break
+
+                    try:
+                        data = json.loads(data_str)
+                        delta = data.get("choices", [{}])[0].get("delta", {})
+                        content = delta.get("content", "")
+                        if content:
+                            yield content
+                    except Exception:
+                        continue
+
+        except httpx.ConnectError as e:
+            raise LocalProviderConnectionError(
+                self.config.name,
+                self.config.base_url,
+                str(e),
+            ) from e
+
+    def list_models(self) -> list[ModelInfo]:
+        """Return available models from this provider.
+
+        Returns cached list if available. Call refresh_models() to update.
+
+        """
+        if self._models_cache is not None:
+            return self._models_cache
+
+        # Return manually configured models if any
+        if self.config.models:
+            return [
+                LocalModelInfo(
+                    id=model_id,
+                    name=model_id,
+                    provider=self.config.name,
+                    provider_name=self.config.name,
+                    context_length=8192,  # Default, unknown
+                    input_cost_per_million=None,
+                    output_cost_per_million=None,
+                    supports_tools=False,  # Unknown until verified
+                    supports_vision=False,
+                    supports_streaming=True,
+                    is_local=True,
+                    capabilities_verified=False,
+                )
+                for model_id in self.config.models
+            ]
+
+        # Return empty list - caller should use async refresh_models()
+        return []
+
+    async def refresh_models(self) -> list[LocalModelInfo]:
+        """Fetch available models from the provider's /v1/models endpoint.
+
+        Returns:
+            List of discovered models
+
+        """
+        try:
+            response = await self._client.get("/models")
+            response.raise_for_status()
+            data = response.json()
+
+            models: list[LocalModelInfo] = []
+
+            # Handle different response formats:
+            # - OpenAI format: {"data": [...]}
+            # - Ollama format: {"models": [...]}
+            # - Bare list (some providers return this); check it first,
+            #   since calling .get() on a list would raise AttributeError
+            if isinstance(data, list):
+                model_list = data
+            else:
+                model_list = data.get("data") or data.get("models") or []
+
+            for model_data in model_list:
+                # Handle different model object formats
+                if isinstance(model_data, str):
+                    # Some providers return just model IDs as strings
+                    model_id = model_data
+                    model_name = model_data
+                    context_length = 8192
+                else:
+                    # Object format - try various field names
+                    model_id = (
+                        model_data.get("id")
+                        or model_data.get("model")
+                        or model_data.get("name")
+                        or "unknown"
+                    )
+                    model_name = model_data.get("name", model_id)
+                    context_length = model_data.get("context_length", 8192)
+
+                models.append(
+                    LocalModelInfo(
+                        id=model_id,
+                        name=model_name,
+                        provider=self.config.name,
+                        provider_name=self.config.name,
+                        context_length=context_length,
+                        input_cost_per_million=None,
+                        output_cost_per_million=None,
+                        supports_tools=self._infer_tool_support(model_id),
+                        supports_vision=False,
+                        supports_streaming=True,
+                        is_local=True,
+                        capabilities_verified=False,
+                    )
+                )
+
+            self._models_cache = models
+            return models
+
+        except httpx.ConnectError as e:
+            raise LocalProviderConnectionError(
+                self.config.name,
+                self.config.base_url,
+                str(e),
+            ) from e
+        except Exception as e:
+            logger.warning(f"Failed to fetch models from {self.config.name}: {e}")
+            # Return manually configured models as fallback
+            return self.list_models()
+
+    def supports_model(self, model_id: str) -> bool:
+        """Check if this provider supports a given model.
+
+        Args:
+            model_id: Model identifier to check
+
+        Returns:
+            True if the model is supported
+
+        """
+        # Strip provider prefix if present
+        if "/" in model_id:
+            prefix, model_name = model_id.rsplit("/", 1)
+            if prefix != self.config.name:
+                return False
+            model_id = model_name
+
+        # Check manually configured models
+        if self.config.models:
+            return model_id in self.config.models
+
+        # Check cached models
+        if self._models_cache:
+            return any(m.id == model_id for m in self._models_cache)
+
+        # If no cache, assume we support it (will fail at runtime if not)
+        return True
+
+    async def test_connection(self) -> ProviderStatus:
+        """Test connectivity to the provider and return status.
+
+        Returns:
+            ProviderStatus with connection details
+
+        """
+        from datetime import datetime
+
+        start_time = time.perf_counter()
+
+        try:
+            models = await self.refresh_models()
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+            return ProviderStatus(
+                name=self.config.name,
+                status=ProviderStatusEnum.CONNECTED,
+                last_checked=datetime.now(),
+                available_models=[m.id for m in models],
+                latency_ms=latency_ms,
+            )
+
+        except LocalProviderConnectionError as e:
+            return ProviderStatus(
+                name=self.config.name,
+                status=ProviderStatusEnum.DISCONNECTED,
+                last_checked=datetime.now(),
+                error_message=str(e),
+            )
+        except Exception as e:
+            return ProviderStatus(
+                name=self.config.name,
+                status=ProviderStatusEnum.ERROR,
+                last_checked=datetime.now(),
+                error_message=str(e),
+            )
+
+    def _extract_error_detail(self, error: httpx.HTTPStatusError) -> str:
+        """Extract error detail from HTTP error response."""
+        try:
+            data = error.response.json()
+            if "error" in data:
+                err = data["error"]
+                if isinstance(err, dict):
+                    return err.get("message", str(err))
+                return str(err)
+        except Exception:
+            logger.debug("Failed to parse error response JSON")
+        return f"HTTP {error.response.status_code}"
+
+    def _estimate_tokens(
+        self, messages: list[dict[str, Any]], response_content: str
+    ) -> tuple[int, int]:
+        """Estimate token counts using tiktoken when server doesn't provide them.
+
+        Args:
+            messages: Input messages
+            response_content: Output content
+
+        Returns:
+            Tuple of (input_tokens, output_tokens)
+
+        """
+        try:
+            import tiktoken
+
+            enc = tiktoken.get_encoding("cl100k_base")
+
+            # Estimate input tokens (content may be a list for multimodal
+            # messages, so coerce to str before concatenating)
+            input_text = ""
+            for msg in messages:
+                input_text += msg.get("role", "") + " " + str(msg.get("content", "")) + " "
+            input_tokens = len(enc.encode(input_text))
+
+            # Estimate output tokens
+            output_tokens = len(enc.encode(response_content))
+
+            return input_tokens, output_tokens
+        except ImportError:
+            # tiktoken not available; fall back to a rough ~4 chars/token estimate
+            input_chars = sum(len(str(m.get("content", ""))) for m in messages)
+            return input_chars // 4, len(response_content) // 4
+
+    def _infer_tool_support(self, model_id: str) -> bool:
+        """Infer whether a model likely supports tool calling.
+
+        Based on known models that support function calling.
+
+        """
+        # Check cache first
+        if model_id in self._tool_support_cache:
+            return self._tool_support_cache[model_id]
+
+        model_lower = model_id.lower()
+
+        # Models known to support tools
+        tool_supporting_patterns = [
+            "llama3.1",
+            "llama-3.1",
+            "llama3.2",
+            "llama-3.2",
+            "mistral",
+            "mixtral",
+            "qwen",
+            "command-r",
+            "gemma2",
+        ]
+
+        supports = any(pattern in model_lower for pattern in tool_supporting_patterns)
+        self._tool_support_cache[model_id] = supports
+        return supports
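
Taken together, the new module gives sandboxy a client for any OpenAI-compatible local server. A minimal usage sketch (not from the package): it assumes LocalProviderConfig accepts the fields the provider reads above (name, base_url, api_key, models, default_params), and the model tag is hypothetical; use whatever your server actually hosts.

    import asyncio

    from sandboxy.providers.config import LocalProviderConfig
    from sandboxy.providers.local import LocalProvider


    async def main() -> None:
        # Assumed constructor fields, mirroring the attributes read above
        config = LocalProviderConfig(
            name="ollama",
            base_url="http://localhost:11434/v1",
            api_key=None,
            models=[],          # empty -> discover via refresh_models()
            default_params={},
        )
        provider = LocalProvider(config)
        try:
            status = await provider.test_connection()
            print(status.status, status.available_models)

            # Non-streaming completion
            response = await provider.complete(
                model="llama3.1:8b",  # hypothetical tag
                messages=[{"role": "user", "content": "Say hello."}],
            )
            print(response.content, f"({response.latency_ms} ms)")

            # Streaming completion
            async for chunk in provider.stream(
                model="llama3.1:8b",
                messages=[{"role": "user", "content": "Count to three."}],
            ):
                print(chunk, end="", flush=True)
        finally:
            await provider.close()


    asyncio.run(main())

The second diff wires these providers into the provider registry.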
@@ -1,12 +1,57 @@
 """Provider registry for managing multiple LLM providers."""
 
+from __future__ import annotations
+
 import logging
 import os
+from typing import TYPE_CHECKING
 
 from sandboxy.providers.base import BaseProvider, ModelInfo, ProviderError
 
+if TYPE_CHECKING:
+    from sandboxy.providers.local import LocalProvider
+
 logger = logging.getLogger(__name__)
 
+# Local providers are lazily loaded to avoid circular imports
+_local_providers: dict[str, LocalProvider] | None = None
+
+
+def _get_local_providers() -> dict[str, LocalProvider]:
+    """Load local providers from config file.
+
+    Returns:
+        Dict mapping provider name to LocalProvider instance
+
+    """
+    global _local_providers
+    if _local_providers is not None:
+        return _local_providers
+
+    _local_providers = {}
+
+    try:
+        from sandboxy.providers.config import get_enabled_providers
+        from sandboxy.providers.local import LocalProvider
+
+        for config in get_enabled_providers():
+            try:
+                _local_providers[config.name] = LocalProvider(config)
+                logger.info(f"Local provider '{config.name}' loaded from config")
+            except Exception as e:
+                logger.warning(f"Failed to load local provider '{config.name}': {e}")
+    except Exception as e:
+        logger.debug(f"Could not load local providers: {e}")
+
+    return _local_providers
+
+
+def reload_local_providers() -> None:
+    """Force reload of local providers from config file."""
+    global _local_providers
+    _local_providers = None
+    _get_local_providers()
+
 
 class ProviderRegistry:
     """Registry of available LLM providers.
@@ -25,9 +70,15 @@ class ProviderRegistry:
 
     """
 
-    def __init__(self):
-        """Initialize registry and detect available providers."""
+    def __init__(self, include_local: bool = True):
+        """Initialize registry and detect available providers.
+
+        Args:
+            include_local: Whether to include local providers from config
+
+        """
         self.providers: dict[str, BaseProvider] = {}
+        self._include_local = include_local
         self._init_providers()
 
     def _init_providers(self) -> None:
@@ -62,10 +113,18 @@ class ProviderRegistry:
         except ProviderError as e:
             logger.warning(f"Failed to init Anthropic: {e}")
 
+        # Load local providers from config
+        if self._include_local:
+            local_providers = _get_local_providers()
+            for name, provider in local_providers.items():
+                self.providers[name] = provider
+                logger.info(f"Local provider '{name}' registered")
+
         if not self.providers:
             logger.warning(
                 "No providers available. Set at least one API key: "
-                "OPENROUTER_API_KEY, OPENAI_API_KEY, or ANTHROPIC_API_KEY"
+                "OPENROUTER_API_KEY, OPENAI_API_KEY, or ANTHROPIC_API_KEY, "
+                "or configure local providers with 'sandboxy providers add'"
             )
 
     def get_provider_for_model(self, model_id: str) -> BaseProvider:
@@ -91,15 +150,23 @@ class ProviderRegistry:
                 provider="registry",
             )
 
-        # If model has a prefix (openai/gpt-4o format), use OpenRouter
-        # This is OpenRouter's convention - direct APIs don't use prefixes
+        # If model has a prefix (provider/model format)
         if "/" in model_id:
+            provider_name, model_name = model_id.split("/", 1)
+
+            # Check for local provider first (e.g., "ollama/llama3")
+            if provider_name in self.providers:
+                provider = self.providers[provider_name]
+                # Verify it's a local provider or supports the model
+                if hasattr(provider, "config") or provider.supports_model(model_id):
+                    return provider
+
+            # OpenRouter format (e.g., "openai/gpt-4o")
             if "openrouter" in self.providers:
                 return self.providers["openrouter"]
-            # If no OpenRouter, try to extract and use direct provider
-            provider_name, model_name = model_id.split("/", 1)
+
+            # Fall back to a direct provider if the prefix matches
             if provider_name == "openai" and "openai" in self.providers:
-                # Note: caller should strip prefix when calling direct provider
                 return self.providers["openai"]
             if provider_name == "anthropic" and "anthropic" in self.providers:
                 return self.providers["anthropic"]
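
The resolution order added here: an exact match on the prefix against a registered provider wins, then OpenRouter, then a direct provider whose name matches the prefix. A routing sketch, assuming a local provider named "ollama" plus an OpenRouter key are configured (so "openai" is not itself a registered provider):

    reg = ProviderRegistry()

    # Prefix matches the registered local provider named "ollama"
    local = reg.get_provider_for_model("ollama/llama3")

    # No registered provider named "openai" in this setup, so the
    # lookup falls through to OpenRouter
    routed = reg.get_provider_for_model("openai/gpt-4o")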
@@ -131,18 +198,32 @@ class ProviderRegistry:
     def list_all_models(self) -> list[ModelInfo]:
         """List all models from all providers.
 
-        Returns deduplicated list with direct providers preferred
-        over OpenRouter for overlapping models.
+        Returns deduplicated list with:
+        1. Local providers first (highest priority)
+        2. Direct cloud providers (OpenAI, Anthropic)
+        3. OpenRouter last (fallback)
         """
         seen_ids: set[str] = set()
         models: list[ModelInfo] = []
 
-        # Add direct provider models first (preferred)
+        # Add local provider models first (highest priority)
         for name, provider in self.providers.items():
-            if name == "openrouter":
-                continue  # Add last
+            if name in ("openrouter", "openai", "anthropic"):
+                continue
 
             for model in provider.list_models():
+                # Use provider-prefixed ID for local models
+                prefixed_id = f"{name}/{model.id}"
+                if prefixed_id not in seen_ids:
+                    seen_ids.add(prefixed_id)
+                    models.append(model)
+
+        # Add direct cloud provider models
+        for name in ("openai", "anthropic"):
+            if name not in self.providers:
+                continue
+
+            for model in self.providers[name].list_models():
                 if model.id not in seen_ids:
                     seen_ids.add(model.id)
                     models.append(model)
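
With this ordering, callers of list_all_models() see local models first, then OpenAI/Anthropic, then OpenRouter's catalog, deduplicated within each tier. A small sketch; the fields printed (id, name) are the ones constructed in the local provider diff above:

    reg = ProviderRegistry()
    for model in reg.list_all_models():
        print(model.id, model.name)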
@@ -156,6 +237,19 @@ class ProviderRegistry:
 
         return models
 
+    def get_local_providers(self) -> dict[str, BaseProvider]:
+        """Get all local providers.
+
+        Returns:
+            Dict of local provider name to provider instance
+
+        """
+        return {
+            name: provider
+            for name, provider in self.providers.items()
+            if hasattr(provider, "config")  # LocalProvider has a config attribute
+        }
+
     def get_provider(self, provider_name: str) -> BaseProvider | None:
         """Get a specific provider by name.