sandboxy 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,11 +16,37 @@ MAX_RETRIES = 3
16
16
  RETRY_DELAY_BASE = 1.0 # seconds
17
17
 
18
18
 
19
+ def _is_local_provider_model(model_id: str) -> bool:
20
+ """Check if a model ID refers to a local provider.
21
+
22
+ Args:
23
+ model_id: Model identifier
24
+
25
+ Returns:
26
+ True if the model is from a configured local provider
27
+ """
28
+ if "/" not in model_id:
29
+ return False
30
+
31
+ provider_name = model_id.split("/")[0]
32
+
33
+ # Check if this provider name matches a configured local provider
34
+ try:
35
+ from sandboxy.providers.config import load_providers_config
36
+
37
+ config = load_providers_config()
38
+ return any(p.name == provider_name and p.enabled for p in config.providers)
39
+ except Exception:
40
+ return False
41
+
42
+
19
43
  class LlmPromptAgent(BaseAgent):
20
44
  """Agent that uses an LLM via OpenAI-compatible API.
21
45
 
22
- Supports both direct OpenAI and OpenRouter (for 400+ models).
23
- Uses OpenRouter when model contains "/" (e.g., "openai/gpt-4o").
46
+ Supports:
47
+ - Local providers (Ollama, LM Studio, vLLM) when model matches configured provider
48
+ - OpenRouter (for 400+ cloud models)
49
+ - Direct OpenAI when model has no prefix
24
50
  """
25
51
 
26
52
  def __init__(self, config: AgentConfig) -> None:
@@ -31,7 +57,12 @@ class LlmPromptAgent(BaseAgent):
31
57
  """
32
58
  super().__init__(config)
33
59
  self._client: Any = None
34
- self._is_openrouter = "/" in (config.model or "")
60
+ self._local_provider: Any = None
61
+
62
+ # Check for local provider first
63
+ self._is_local = _is_local_provider_model(config.model or "")
64
+ self._is_openrouter = not self._is_local and "/" in (config.model or "")
65
+
35
66
  # Token usage tracking
36
67
  self._total_input_tokens = 0
37
68
  self._total_output_tokens = 0
@@ -39,6 +70,9 @@ class LlmPromptAgent(BaseAgent):
39
70
  @property
40
71
  def api_key(self) -> str:
41
72
  """Get the appropriate API key based on model type."""
73
+ if self._is_local:
74
+ # Local providers may not need an API key, or it's in the provider config
75
+ return ""
42
76
  if self._is_openrouter:
43
77
  return os.getenv("OPENROUTER_API_KEY", "")
44
78
  return os.getenv("OPENAI_API_KEY", "")
@@ -49,15 +83,46 @@ class LlmPromptAgent(BaseAgent):
49
83
  if self._client is None:
50
84
  from openai import OpenAI
51
85
 
52
- if self._is_openrouter:
53
- logger.debug("Initializing OpenRouter client for model: %s", self.config.model)
54
- self._client = OpenAI(
55
- api_key=self.api_key,
56
- base_url="https://openrouter.ai/api/v1",
57
- )
58
- else:
59
- logger.debug("Initializing OpenAI client for model: %s", self.config.model)
60
- self._client = OpenAI(api_key=self.api_key)
86
+ if self._is_local:
87
+ # Get local provider and create client pointing to it
88
+ provider_name = (self.config.model or "").split("/")[0]
89
+ from sandboxy.providers.config import load_providers_config
90
+
91
+ config = load_providers_config()
92
+ provider_config = config.get_provider(provider_name)
93
+
94
+ if provider_config:
95
+ logger.debug(
96
+ "Initializing local client for %s at %s",
97
+ provider_name,
98
+ provider_config.base_url,
99
+ )
100
+ headers = {}
101
+ if provider_config.api_key:
102
+ headers["Authorization"] = f"Bearer {provider_config.api_key}"
103
+
104
+ self._client = OpenAI(
105
+ api_key=provider_config.api_key or "not-needed",
106
+ base_url=provider_config.base_url,
107
+ default_headers=headers if headers else None,
108
+ )
109
+ else:
110
+ logger.warning(
111
+ "Local provider %s not found, falling back to OpenRouter", provider_name
112
+ )
113
+ self._is_local = False
114
+ self._is_openrouter = True
115
+
116
+ if self._client is None: # Not set by local provider path
117
+ if self._is_openrouter:
118
+ logger.debug("Initializing OpenRouter client for model: %s", self.config.model)
119
+ self._client = OpenAI(
120
+ api_key=self.api_key,
121
+ base_url="https://openrouter.ai/api/v1",
122
+ )
123
+ else:
124
+ logger.debug("Initializing OpenAI client for model: %s", self.config.model)
125
+ self._client = OpenAI(api_key=self.api_key)
61
126
  return self._client
62
127
 
63
128
  def step(
@@ -66,7 +131,8 @@ class LlmPromptAgent(BaseAgent):
66
131
  available_tools: list[dict[str, Any]] | None = None,
67
132
  ) -> AgentAction:
68
133
  """Process conversation and return next action using LLM."""
69
- if not self.api_key:
134
+ # Local providers don't require an API key
135
+ if not self._is_local and not self.api_key:
70
136
  return self._stub_response(history)
71
137
 
72
138
  messages = self._build_messages(history)
@@ -188,8 +254,13 @@ class LlmPromptAgent(BaseAgent):
188
254
  messages: list[dict[str, Any]],
189
255
  tools: list[dict[str, Any]] | None,
190
256
  ) -> Any:
191
- """Make API call to OpenAI/OpenRouter."""
257
+ """Make API call to OpenAI/OpenRouter/Local provider."""
192
258
  model = self.config.model or "gpt-4o-mini"
259
+
260
+ # For local providers, strip the provider prefix (e.g., "ollama/llama3" -> "llama3")
261
+ if self._is_local and "/" in model:
262
+ model = model.split("/", 1)[1]
263
+
193
264
  kwargs: dict[str, Any] = {
194
265
  "model": model,
195
266
  "messages": messages,
sandboxy/api/app.py CHANGED
@@ -58,12 +58,13 @@ def create_local_app(
58
58
  )
59
59
 
60
60
  # Local routes only
61
- from sandboxy.api.routes import agents, tools
61
+ from sandboxy.api.routes import agents, providers, tools
62
62
  from sandboxy.api.routes import local as local_routes
63
63
 
64
64
  app.include_router(local_routes.router, prefix="/api/v1", tags=["local"])
65
65
  app.include_router(agents.router, prefix="/api/v1", tags=["agents"])
66
66
  app.include_router(tools.router, prefix="/api/v1", tags=["tools"])
67
+ app.include_router(providers.router, prefix="/api/v1", tags=["providers"])
67
68
 
68
69
  @app.get("/health")
69
70
  async def health_check():
@@ -636,6 +636,8 @@ async def compare_models(request: CompareModelsRequest) -> CompareModelsResponse
636
636
  exporter.export(
637
637
  result=result.to_dict(),
638
638
  scenario_path=scenario_path,
639
+ scenario_name=spec.name,
640
+ scenario_id=spec.id,
639
641
  agent_name=result.model,
640
642
  )
641
643
  except ImportError:
@@ -692,10 +694,40 @@ def calculate_cost(model_id: str, input_tokens: int, output_tokens: int) -> floa
692
694
 
693
695
  @router.get("/local/models")
694
696
  async def list_available_models() -> list[dict[str, Any]]:
695
- """List available models from OpenRouter."""
697
+ """List available models from OpenRouter and local providers."""
698
+ from sandboxy.providers.config import get_enabled_providers
699
+ from sandboxy.providers.local import LocalProvider
696
700
  from sandboxy.providers.openrouter import OPENROUTER_MODELS
697
701
 
698
702
  models = []
703
+
704
+ # Add models from local providers first
705
+ for provider_config in get_enabled_providers():
706
+ try:
707
+ provider = LocalProvider(provider_config)
708
+ local_models = await provider.refresh_models()
709
+ await provider.close()
710
+
711
+ for model in local_models:
712
+ # Model ID includes provider prefix for routing
713
+ full_model_id = f"{provider_config.name}/{model.id}"
714
+ models.append(
715
+ {
716
+ "id": full_model_id,
717
+ "name": model.name,
718
+ "price": "Local",
719
+ "pricing": {"input": 0, "output": 0},
720
+ "provider": provider_config.name,
721
+ "context_length": model.context_length,
722
+ "supports_vision": model.supports_vision,
723
+ "is_local": True,
724
+ "provider_name": provider_config.name,
725
+ }
726
+ )
727
+ except Exception as e:
728
+ logger.warning(f"Failed to fetch models from {provider_config.name}: {e}")
729
+
730
+ # Add OpenRouter models
699
731
  for model_id, info in OPENROUTER_MODELS.items():
700
732
  # Format price string
701
733
  if info.input_cost_per_million == 0 and info.output_cost_per_million == 0:
@@ -715,6 +747,7 @@ async def list_available_models() -> list[dict[str, Any]]:
715
747
  "provider": info.provider,
716
748
  "context_length": info.context_length,
717
749
  "supports_vision": info.supports_vision,
750
+ "is_local": False,
718
751
  }
719
752
  )
720
753
 
@@ -0,0 +1,369 @@
1
+ """API routes for local provider management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any, Literal
7
+
8
+ from fastapi import APIRouter, HTTPException, status
9
+ from pydantic import BaseModel, Field
10
+
11
+ from sandboxy.providers.config import (
12
+ LocalProviderConfig,
13
+ ProviderStatusEnum,
14
+ load_providers_config,
15
+ save_providers_config,
16
+ )
17
+ from sandboxy.providers.local import LocalProvider
18
+ from sandboxy.providers.registry import reload_local_providers
19
+
20
logger = logging.getLogger(__name__)

# All provider-management endpoints below hang off "/providers"; the app mounts
# this router under "/api/v1".
router = APIRouter(prefix="/providers", tags=["providers"])
23
+
24
+
25
+ # --- Response Models ---
26
+
27
+
28
class ProviderSummary(BaseModel):
    """Summary of a provider for list view.

    Returned per provider by ``GET /providers`` and for the newly created
    provider by ``POST /providers``.
    """

    name: str
    type: str  # provider kind, e.g. "ollama", "lmstudio", "vllm", "openai-compatible"
    base_url: str
    enabled: bool
    status: ProviderStatusEnum  # outcome of the connection test run for this response
    model_count: int
    models: list[str] = Field(default_factory=list)  # model IDs, when the handler fills them in
38
+
39
+
40
class ProviderListResponse(BaseModel):
    """Response for GET /api/v1/providers.

    One summary per provider in the providers config, in config order.
    """

    providers: list[ProviderSummary]
44
+
45
+
46
class LocalModelInfoResponse(BaseModel):
    """Model info in API response.

    Built from the models returned by ``LocalProvider.refresh_models()``.
    """

    id: str  # model ID as reported by the provider (no "<provider>/" prefix)
    name: str
    context_length: int
    supports_tools: bool  # tool/function-calling support as reported by the provider
    is_local: bool = True  # always True for models served by a local provider
54
+
55
+
56
class ProviderStatusResponse(BaseModel):
    """Provider connection status."""

    status: ProviderStatusEnum
    last_checked: str | None = None  # ISO-8601 timestamp of the check, if known
    available_models: list[str] = Field(default_factory=list)
    latency_ms: int | None = None  # round-trip time of the connection test
    error_message: str | None = None  # populated when the check failed
64
+
65
+
66
class ProviderDetailResponse(BaseModel):
    """Response for GET /api/v1/providers/{name}."""

    # LocalProviderConfig as dict (its model_dump()).
    # NOTE(review): presumably this includes api_key in clear text — confirm
    # whether it should be masked before being returned.
    config: dict[str, Any]
    status: ProviderStatusResponse
    models: list[LocalModelInfoResponse]
72
+
73
+
74
class AddProviderRequest(BaseModel):
    """Request body for POST /api/v1/providers.

    Field values are passed straight into ``LocalProviderConfig``, whose
    validation errors are reported as 400 ``validation_error``.
    """

    name: str  # unique provider name; also the "<name>/" prefix of its model IDs
    type: Literal["ollama", "lmstudio", "vllm", "openai-compatible"] = "openai-compatible"
    base_url: str
    api_key: str | None = None  # optional; many local servers need none
    models: list[str] = Field(default_factory=list)
    default_params: dict[str, Any] = Field(default_factory=dict)
83
+
84
+
85
class UpdateProviderRequest(BaseModel):
    """Request body for PATCH /api/v1/providers/{name}.

    Every field is optional. The PATCH handler drops fields whose value is
    None, so None means "leave unchanged" (and an api_key cannot be cleared
    through this request).
    """

    enabled: bool | None = None
    api_key: str | None = None
    models: list[str] | None = None
    default_params: dict[str, Any] | None = None
92
+
93
+
94
class TestConnectionResponse(BaseModel):
    """Response for POST /api/v1/providers/{name}/test."""

    success: bool  # True only when the provider status is CONNECTED
    latency_ms: int | None = None
    models_found: list[str] = Field(default_factory=list)
    error: str | None = None  # error message from the connection test, if any
101
+
102
+
103
class RefreshModelsResponse(BaseModel):
    """Response for POST /api/v1/providers/{name}/refresh."""

    models_found: list[str]  # model IDs currently reported by the provider
    models_added: list[str]  # reported now but absent from the configured models list
    models_removed: list[str]  # in the configured models list but no longer reported
109
+
110
+
111
class ErrorDetail(BaseModel):
    """Standard error response.

    Routes serialize this with ``model_dump()`` into ``HTTPException.detail``.
    """

    code: str  # machine-readable code, e.g. "provider_not_found", "validation_error"
    message: str  # human-readable explanation
    details: dict[str, Any] | None = None  # optional extra context
117
+
118
+
119
+ # --- Routes ---
120
+
121
+
122
async def _summarize_provider(pconfig: LocalProviderConfig) -> ProviderSummary:
    """Probe one provider and build its list-view summary.

    A failed connection test is reported as ``ProviderStatusEnum.ERROR`` with
    no models instead of failing the whole listing. The provider client is
    always closed.
    """
    provider = LocalProvider(pconfig)
    try:
        conn_status = await provider.test_connection()
        return ProviderSummary(
            name=pconfig.name,
            type=pconfig.type,
            base_url=pconfig.base_url,
            enabled=pconfig.enabled,
            status=conn_status.status,
            model_count=len(conn_status.available_models),
            models=conn_status.available_models,
        )
    except Exception:
        # Keep the listing usable when one provider is unreachable, but leave
        # a trace instead of swallowing the error silently.
        logger.debug("Connection test failed for provider %s", pconfig.name, exc_info=True)
        return ProviderSummary(
            name=pconfig.name,
            type=pconfig.type,
            base_url=pconfig.base_url,
            enabled=pconfig.enabled,
            status=ProviderStatusEnum.ERROR,
            model_count=0,
            models=[],
        )
    finally:
        await provider.close()


@router.get("", response_model=ProviderListResponse)
async def list_providers() -> ProviderListResponse:
    """List all configured providers with their live connection status.

    Providers are probed concurrently, so latency is bounded by the slowest
    provider rather than the sum of all probes. Results keep the order of the
    providers config.
    """
    import asyncio

    config = load_providers_config()
    summaries = await asyncio.gather(*(_summarize_provider(p) for p in config.providers))
    return ProviderListResponse(providers=list(summaries))
159
+
160
+
161
@router.post("", response_model=ProviderSummary, status_code=status.HTTP_201_CREATED)
async def add_provider(request: AddProviderRequest) -> ProviderSummary:
    """Add a new provider.

    The provider's connection is tested before saving, but the provider is
    saved regardless of the outcome; the test result is reported in the
    returned summary.

    Raises:
        HTTPException: 409 ``provider_exists`` if the name is already taken;
            400 ``validation_error`` if ``LocalProviderConfig`` rejects the
            request fields.
    """
    config = load_providers_config()

    # Check for duplicate
    if config.get_provider(request.name):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=ErrorDetail(
                code="provider_exists",
                message=f"Provider '{request.name}' already exists",
            ).model_dump(),
        )

    # Validate and create config
    try:
        provider_config = LocalProviderConfig(
            name=request.name,
            type=request.type,
            base_url=request.base_url,
            api_key=request.api_key,
            models=request.models,
            default_params=request.default_params,
        )
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ErrorDetail(
                code="validation_error",
                message=str(e),
            ).model_dump(),
        ) from e

    # Test connection
    provider = LocalProvider(provider_config)
    try:
        conn_status = await provider.test_connection()
    finally:
        await provider.close()

    # Save config
    config.add_provider(provider_config)
    save_providers_config(config)
    reload_local_providers()

    # Include the model IDs so the summary is consistent with model_count
    # (and with what GET /providers returns for the same provider).
    return ProviderSummary(
        name=provider_config.name,
        type=provider_config.type,
        base_url=provider_config.base_url,
        enabled=provider_config.enabled,
        status=conn_status.status,
        model_count=len(conn_status.available_models),
        models=conn_status.available_models,
    )
215
+
216
+
217
@router.get("/{name}", response_model=ProviderDetailResponse)
async def get_provider(name: str) -> ProviderDetailResponse:
    """Return a provider's stored config, live connection status and models.

    Raises:
        HTTPException: 404 ``provider_not_found`` (listing the configured
            provider names) when no provider called ``name`` exists.
    """
    config = load_providers_config()
    provider_config = config.get_provider(name)

    if not provider_config:
        not_found = ErrorDetail(
            code="provider_not_found",
            message=f"Provider '{name}' not found",
            details={"available_providers": [p.name for p in config.providers]},
        )
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=not_found.model_dump())

    client = LocalProvider(provider_config)
    try:
        probe = await client.test_connection()
        discovered = await client.refresh_models()
    finally:
        await client.close()

    checked_at = probe.last_checked
    status_payload = ProviderStatusResponse(
        status=probe.status,
        last_checked=checked_at.isoformat() if checked_at else None,
        available_models=probe.available_models,
        latency_ms=probe.latency_ms,
        error_message=probe.error_message,
    )
    model_payload = [
        LocalModelInfoResponse(
            id=info.id,
            name=info.name,
            context_length=info.context_length,
            supports_tools=info.supports_tools,
        )
        for info in discovered
    ]

    # NOTE(review): model_dump() returns every config field, presumably
    # including api_key — confirm whether it should be masked here.
    return ProviderDetailResponse(
        config=provider_config.model_dump(),
        status=status_payload,
        models=model_payload,
    )
259
+
260
+
261
@router.delete("/{name}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider(name: str) -> None:
    """Remove a provider from the config, persist it, and reload providers.

    Raises:
        HTTPException: 404 ``provider_not_found`` if ``name`` is not configured.
    """
    config = load_providers_config()

    removed = config.remove_provider(name)
    if removed:
        save_providers_config(config)
        reload_local_providers()
        return

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=ErrorDetail(
            code="provider_not_found",
            message=f"Provider '{name}' not found",
        ).model_dump(),
    )
277
+
278
+
279
@router.patch("/{name}", response_model=dict)
async def update_provider(name: str, request: UpdateProviderRequest) -> dict:
    """Apply a partial update to a provider's configuration.

    Only request fields with a non-None value are applied; the rest keep
    their stored value.

    Returns:
        The full updated provider config as a dict.

    Raises:
        HTTPException: 400 ``validation_error`` when no field carries a value;
            404 ``provider_not_found`` when ``name`` is not configured.
    """
    config = load_providers_config()

    updates: dict[str, Any] = {}
    for field_name, value in request.model_dump().items():
        if value is not None:
            updates[field_name] = value

    if len(updates) == 0:
        empty_error = ErrorDetail(code="validation_error", message="No fields to update")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=empty_error.model_dump())

    updated = config.update_provider(name, **updates)
    if not updated:
        missing_error = ErrorDetail(
            code="provider_not_found",
            message=f"Provider '{name}' not found",
        )
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=missing_error.model_dump())

    save_providers_config(config)
    reload_local_providers()
    return updated.model_dump()
308
+
309
+
310
@router.post("/{name}/test", response_model=TestConnectionResponse)
async def test_provider_connection(name: str) -> TestConnectionResponse:
    """Run a connection test against a configured provider.

    ``success`` is True only when the provider reports ``CONNECTED``.

    Raises:
        HTTPException: 404 ``provider_not_found`` if ``name`` is not configured.
    """
    provider_config = load_providers_config().get_provider(name)
    if not provider_config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ErrorDetail(
                code="provider_not_found",
                message=f"Provider '{name}' not found",
            ).model_dump(),
        )

    probe = LocalProvider(provider_config)
    try:
        outcome = await probe.test_connection()
    finally:
        await probe.close()

    is_connected = outcome.status == ProviderStatusEnum.CONNECTED
    return TestConnectionResponse(
        success=is_connected,
        latency_ms=outcome.latency_ms,
        models_found=outcome.available_models,
        error=outcome.error_message,
    )
337
+
338
+
339
@router.post("/{name}/refresh", response_model=RefreshModelsResponse)
async def refresh_provider_models(name: str) -> RefreshModelsResponse:
    """Re-query a provider's models and diff them against its configured list.

    This endpoint does not save the providers config; it only reports the
    difference.

    Returns:
        ``models_found`` in the order the provider reported them (duplicates
        removed). ``models_added`` (reported but not configured) in that same
        order. ``models_removed`` (configured but no longer reported) sorted
        by ID.

    Raises:
        HTTPException: 404 ``provider_not_found`` if ``name`` is not configured.
    """
    config = load_providers_config()
    provider_config = config.get_provider(name)

    if not provider_config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ErrorDetail(
                code="provider_not_found",
                message=f"Provider '{name}' not found",
            ).model_dump(),
        )

    # Models currently listed in the stored config
    configured = set(provider_config.models)

    provider = LocalProvider(provider_config)
    try:
        models = await provider.refresh_models()
    finally:
        await provider.close()

    # Set iteration order is arbitrary; build deterministic lists so identical
    # calls produce identical responses and the provider's ordering survives.
    found = list(dict.fromkeys(m.id for m in models))

    return RefreshModelsResponse(
        models_found=found,
        models_added=[model_id for model_id in found if model_id not in configured],
        models_removed=sorted(configured.difference(found)),
    )