sandboxy 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,243 @@
1
+ """Configuration models for local model providers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import re
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+ from enum import Enum
11
+ from pathlib import Path
12
+ from typing import Any, Literal
13
+
14
+ from pydantic import BaseModel, Field, field_validator
15
+
16
+ from sandboxy.providers.base import ModelInfo
17
+
18
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)

# Default config file location: <home>/.sandboxy/providers.json
DEFAULT_CONFIG_PATH = Path.home().joinpath(".sandboxy", "providers.json")
22
+
23
+
24
class ProviderStatusEnum(str, Enum):
    """Connection state of a local provider.

    Because the enum mixes in ``str``, every member is also a plain string
    and compares equal to its lowercase value (e.g. ``"connected"``).
    """

    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    ERROR = "error"
    # Fallback when the state has not been (or cannot be) determined.
    UNKNOWN = "unknown"
31
+
32
+
33
class LocalProviderConfig(BaseModel):
    """Configuration for a local model provider.

    Validated on construction:

    - ``name`` must consist solely of ASCII letters, digits, hyphens and
      underscores (no whitespace, not even a trailing newline).
    - ``base_url`` must be an ``http://`` or ``https://`` URL with something
      after the scheme. Trailing slashes are stripped.
    """

    name: str = Field(
        ...,
        description="User-friendly name for this provider",
        examples=["ollama-local", "my-vllm-server"],
    )

    type: Literal["ollama", "lmstudio", "vllm", "openai-compatible"] = Field(
        default="openai-compatible",
        description="Provider type for specialized handling",
    )

    base_url: str = Field(
        ...,
        description="Base URL for the provider API",
        examples=["http://localhost:11434/v1", "http://localhost:1234/v1"],
    )

    api_key: str | None = Field(
        default=None,
        description="Optional API key for authenticated providers",
    )

    enabled: bool = Field(
        default=True,
        description="Whether this provider is active",
    )

    models: list[str] = Field(
        default_factory=list,
        description="Manually configured model IDs (overrides auto-discovery)",
    )

    default_params: dict[str, Any] = Field(
        default_factory=dict,
        description="Default parameters for completions (temperature, max_tokens, etc.)",
    )

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Validate provider name is alphanumeric with hyphens/underscores.

        Raises:
            ValueError: If the name is empty or contains any other character.
        """
        # fullmatch rather than match(r"^...$"): "$" also matches just before
        # a trailing newline, so a name like "abc\n" used to be accepted.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            msg = "Provider name must be alphanumeric with hyphens/underscores only"
            raise ValueError(msg)
        return v

    @field_validator("base_url")
    @classmethod
    def validate_base_url(cls, v: str) -> str:
        """Validate the base URL and strip any trailing slashes.

        Raises:
            ValueError: If the scheme is not http/https, or nothing follows it.
        """
        if not v.startswith(("http://", "https://")):
            msg = "Base URL must start with http:// or https://"
            raise ValueError(msg)
        # Without this check, "http://" (or "http:///") would be accepted and
        # then collapse to "http:" after the trailing-slash strip below.
        if not v.partition("://")[2].strip("/"):
            msg = "Base URL must include a host"
            raise ValueError(msg)
        # Remove trailing slash for consistency
        return v.rstrip("/")
91
+
92
+
93
class ProvidersConfigFile(BaseModel):
    """Root structure for ~/.sandboxy/providers.json.

    The helper methods treat provider names as unique keys: ``add_provider``
    and ``update_provider`` refuse to create a second provider with an
    existing name.
    """

    version: int = Field(
        default=1,
        description="Config file schema version for migrations",
    )

    providers: list[LocalProviderConfig] = Field(
        default_factory=list,
        description="List of configured local providers",
    )

    def _find_index(self, name: str) -> int | None:
        """Return the list index of the first provider called *name*, or None."""
        for index, provider in enumerate(self.providers):
            if provider.name == name:
                return index
        return None

    def get_provider(self, name: str) -> LocalProviderConfig | None:
        """Get a provider by name, or None if no provider has that name."""
        index = self._find_index(name)
        return None if index is None else self.providers[index]

    def add_provider(self, config: LocalProviderConfig) -> None:
        """Add a new provider configuration.

        Raises:
            ValueError: If provider with same name already exists

        """
        if self.get_provider(config.name) is not None:
            msg = f"Provider '{config.name}' already exists"
            raise ValueError(msg)
        self.providers.append(config)

    def remove_provider(self, name: str) -> bool:
        """Remove a provider by name.

        Returns:
            True if removed, False if not found

        """
        index = self._find_index(name)
        if index is None:
            return False
        del self.providers[index]
        return True

    def update_provider(self, name: str, **updates: Any) -> LocalProviderConfig | None:
        """Update a provider's configuration.

        The existing config is dumped, merged with *updates* and re-validated
        by constructing a new ``LocalProviderConfig``. Keys that are not model
        fields are not applied (under pydantic's default ``extra="ignore"``
        behaviour, since this model sets no config).

        Args:
            name: Provider name to update
            **updates: Fields to update

        Returns:
            Updated config or None if not found

        Raises:
            ValueError: If the merged data fails validation, or if *updates*
                renames the provider to a name another provider already uses.

        """
        index = self._find_index(name)
        if index is None:
            return None

        # Build the updated config; this re-runs the field validators.
        data = self.providers[index].model_dump()
        data.update(updates)
        updated = LocalProviderConfig(**data)

        # A rename must not collide with another existing provider, otherwise
        # name-based lookups would become ambiguous.
        if updated.name != name and self.get_provider(updated.name) is not None:
            msg = f"Provider '{updated.name}' already exists"
            raise ValueError(msg)

        # Replace in list
        self.providers[index] = updated
        return updated
164
+
165
+
166
@dataclass
class LocalModelInfo(ModelInfo):
    """ModelInfo subclass carrying metadata specific to local providers.

    Every added field has a default value. This keeps the dataclass valid
    even when ModelInfo's own fields have defaults.
    """

    # Local-specific fields, on top of those inherited from ModelInfo.
    # NOTE(review): presumably the LocalProviderConfig.name that served this model — confirm.
    provider_name: str = ""
    is_local: bool = True
    capabilities_verified: bool = False
177
+
178
+
179
class ProviderStatus(BaseModel):
    """Snapshot of a provider connection's runtime state."""

    # Required: which provider this is and what state it is in.
    name: str
    status: ProviderStatusEnum

    # Optional details; each defaults to None when not (yet) known.
    last_checked: datetime | None = Field(default=None)
    error_message: str | None = Field(default=None)

    # Fresh empty list per instance.
    available_models: list[str] = Field(default_factory=list)

    latency_ms: int | None = Field(default=None)
188
+
189
+
190
+ # --- Config file load/save functions ---
191
+
192
+
193
def load_providers_config(path: Path | None = None) -> ProvidersConfigFile:
    """Load providers configuration from file.

    Loading is best-effort and never raises. The following all produce a
    default (empty) configuration, with the problem logged:

    - a missing file
    - an unreadable file
    - a file that is not valid UTF-8
    - malformed JSON
    - data that fails schema validation

    Args:
        path: Config file path. Defaults to ~/.sandboxy/providers.json

    Returns:
        ProvidersConfigFile with loaded or default configuration

    """
    config_path = DEFAULT_CONFIG_PATH if path is None else path

    try:
        # JSON is UTF-8 by spec (RFC 8259). Use an explicit encoding rather
        # than the locale default (e.g. cp1252 on Windows). "utf-8-sig" also
        # accepts a leading BOM, which some editors add; json.load would
        # otherwise reject it.
        with open(config_path, encoding="utf-8-sig") as f:
            data = json.load(f)
        return ProvidersConfigFile.model_validate(data)
    except FileNotFoundError:
        # Single open attempt instead of exists()+open(): no race between the
        # check and the read.
        logger.debug("Config file not found at %s, using defaults", config_path)
        return ProvidersConfigFile()
    except json.JSONDecodeError as e:
        logger.warning("Invalid JSON in config file: %s", e)
        return ProvidersConfigFile()
    except Exception as e:
        # Deliberately broad: includes OSError, UnicodeDecodeError and
        # validation errors. A bad config must not crash callers.
        logger.warning("Failed to load config: %s", e)
        return ProvidersConfigFile()
219
+
220
+
221
def save_providers_config(config: ProvidersConfigFile, path: Path | None = None) -> None:
    """Save providers configuration to file.

    The write is atomic on POSIX. Content goes to a temporary file in the same
    directory, which is then renamed over the target. A crash mid-write
    therefore cannot leave a truncated, unloadable config behind. The
    resulting file is readable and writable only by the current user, because
    the config may contain API keys.

    Args:
        config: Configuration to save
        path: Config file path. Defaults to ~/.sandboxy/providers.json

    Raises:
        OSError: If the directory or file cannot be created or written.

    """
    import tempfile  # only needed here, for the atomic write

    config_path = DEFAULT_CONFIG_PATH if path is None else path

    # Ensure directory exists
    config_path.parent.mkdir(parents=True, exist_ok=True)

    # If the config is a symlink (e.g. a managed dotfile), write through to
    # its target instead of replacing the link with a regular file.
    target = config_path.resolve()

    # mkstemp creates the file with owner-only permissions. Placing it in the
    # target's directory keeps the final rename on one filesystem.
    fd, tmp_name = tempfile.mkstemp(
        dir=target.parent, prefix=f".{target.name}.", suffix=".tmp"
    )
    tmp_path = Path(tmp_name)
    try:
        with open(fd, "w", encoding="utf-8") as f:
            json.dump(config.model_dump(), f, indent=2, default=str)
        # os.replace semantics: atomic on POSIX; overwrites on Windows too.
        tmp_path.replace(target)
    except BaseException:
        # Never leave stray temp files behind, even on KeyboardInterrupt.
        tmp_path.unlink(missing_ok=True)
        raise

    logger.debug("Saved providers config to %s", config_path)
238
+
239
+
240
def get_enabled_providers(path: Path | None = None) -> list[LocalProviderConfig]:
    """Get list of enabled local providers from config.

    Args:
        path: Config file path, forwarded to ``load_providers_config``.
            Defaults to ~/.sandboxy/providers.json (same as the loader).

    Returns:
        Providers whose ``enabled`` flag is true, in config-file order.
    """
    config = load_providers_config(path)
    return [p for p in config.providers if p.enabled]