sqlsaber 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/__init__.py +3 -0
- sqlsaber/__main__.py +4 -0
- sqlsaber/agents/__init__.py +9 -0
- sqlsaber/agents/anthropic.py +451 -0
- sqlsaber/agents/base.py +67 -0
- sqlsaber/agents/streaming.py +26 -0
- sqlsaber/cli/__init__.py +7 -0
- sqlsaber/cli/commands.py +132 -0
- sqlsaber/cli/database.py +275 -0
- sqlsaber/cli/display.py +207 -0
- sqlsaber/cli/interactive.py +93 -0
- sqlsaber/cli/memory.py +239 -0
- sqlsaber/cli/models.py +231 -0
- sqlsaber/cli/streaming.py +94 -0
- sqlsaber/config/__init__.py +7 -0
- sqlsaber/config/api_keys.py +102 -0
- sqlsaber/config/database.py +252 -0
- sqlsaber/config/settings.py +115 -0
- sqlsaber/database/__init__.py +9 -0
- sqlsaber/database/connection.py +187 -0
- sqlsaber/database/schema.py +678 -0
- sqlsaber/memory/__init__.py +1 -0
- sqlsaber/memory/manager.py +77 -0
- sqlsaber/memory/storage.py +176 -0
- sqlsaber/models/__init__.py +13 -0
- sqlsaber/models/events.py +28 -0
- sqlsaber/models/types.py +40 -0
- sqlsaber-0.1.0.dist-info/METADATA +168 -0
- sqlsaber-0.1.0.dist-info/RECORD +32 -0
- sqlsaber-0.1.0.dist-info/WHEEL +4 -0
- sqlsaber-0.1.0.dist-info/entry_points.txt +4 -0
- sqlsaber-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""API Key management for SQLSaber."""
|
|
2
|
+
|
|
3
|
+
import getpass
|
|
4
|
+
import os
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import keyring
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
|
|
10
|
+
# Module-wide Rich console used for every status, warning and prompt message
# printed by APIKeyManager.
console = Console()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class APIKeyManager:
    """Look up provider API keys: environment variable, then OS keyring, then prompt."""

    def __init__(self):
        # Every keyring service name created by this manager starts with this prefix.
        self.service_prefix = "sqlsaber"

    def get_api_key(self, provider: str) -> Optional[str]:
        """Return the API key for ``provider``, or ``None`` if none could be obtained.

        Sources are tried in order and the first non-empty value wins:

        1. the provider's environment variable (``_get_env_var_name``);
        2. the OS keyring entry ``(_get_service_name(provider), provider)``;
        3. an interactive prompt (``_prompt_and_store_key``), which also tries
           to save the entered key to the keyring.

        Keyring errors are reported on the console and fall through to the prompt.
        """
        variable = self._get_env_var_name(provider)
        service = self._get_service_name(provider)

        key_from_env = os.getenv(variable)
        if key_from_env:
            console.print(f"Using {variable} from environment", style="dim")
            return key_from_env

        try:
            key_from_keyring = keyring.get_password(service, provider)
            if key_from_keyring:
                console.print(
                    f"Using stored {provider} API key from keyring", style="dim"
                )
                return key_from_keyring
        except Exception as err:
            # A missing or locked keyring backend is not fatal: fall back to prompting.
            console.print(f"Keyring access failed: {err}", style="dim yellow")

        return self._prompt_and_store_key(provider, variable, service)

    def _get_env_var_name(self, provider: str) -> str:
        """Return the environment variable checked for ``provider``'s key.

        ``openai`` and ``anthropic`` have dedicated variables; every other
        provider falls back to ``AI_API_KEY``.
        """
        dedicated = {"openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY"}
        return dedicated.get(provider, "AI_API_KEY")

    def _get_service_name(self, provider: str) -> str:
        """Return the keyring service name under which ``provider``'s key is stored."""
        return f"{self.service_prefix}-{provider}-api-key"

    def _prompt_and_store_key(
        self, provider: str, env_var_name: str, service_name: str
    ) -> Optional[str]:
        """Ask the user for ``provider``'s key and try to remember it in the keyring.

        Returns:
            The entered key with surrounding whitespace removed, or ``None`` if
            the user entered nothing, pressed Ctrl-C, or prompting failed.
            Failing to save the key to the keyring only prints a warning; the
            key is still returned.
        """
        try:
            title = provider.title()
            console.print(f"\n{title} API key not found in environment or keyring.")
            console.print("You can either:")
            console.print(f" 1. Set the {env_var_name} environment variable")
            console.print(
                " 2. Enter it now to securely store using your operating system's credentials store"
            )

            # getpass reads the key without echoing it to the terminal.
            entered = getpass.getpass(
                f"\nEnter your {title} API key (or press Enter to skip): "
            ).strip()

            if not entered:
                console.print(
                    "No API key provided. Some functionality may not work.",
                    style="yellow",
                )
                return None

            try:
                keyring.set_password(service_name, provider, entered)
                console.print("API key stored securely for future use", style="green")
            except Exception as err:
                console.print(
                    f"Warning: Could not store API key in keyring: {err}", style="yellow"
                )
                console.print(
                    "You may need to enter it again next time", style="yellow"
                )

            return entered

        except KeyboardInterrupt:
            console.print("\nOperation cancelled", style="yellow")
            return None
        except Exception as err:
            console.print(f"Error prompting for API key: {err}", style="red")
            return None
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Database configuration management."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import platform
|
|
6
|
+
import stat
|
|
7
|
+
import keyring
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Dict, List, Optional, Any
|
|
11
|
+
from urllib.parse import quote_plus
|
|
12
|
+
|
|
13
|
+
import platformdirs
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class DatabaseConfig:
    """Connection settings for one named database.

    The password is not part of the serialized form (``to_dict``); it is taken
    from ``password`` when set, otherwise looked up in the OS keyring.
    """

    name: str
    type: str  # postgresql, mysql, sqlite
    host: Optional[str]
    port: Optional[int]
    database: str  # database name, or the database file path for sqlite
    username: Optional[str]
    password: Optional[str] = None
    ssl_mode: Optional[str] = None
    schema: Optional[str] = None

    def to_connection_string(self) -> str:
        """Build a connection URL for this configuration.

        Returns:
            ``postgresql://...`` or ``mysql://...`` (password included and
            percent-encoded when one is available), or ``sqlite:///<database>``.

        Raises:
            ValueError: if a PostgreSQL/MySQL config lacks host, port or
                username, or if ``type`` is not supported.
        """
        if self.type == "postgresql":
            return self._network_url("postgresql", "PostgreSQL")
        if self.type == "mysql":
            return self._network_url("mysql", "MySQL")
        if self.type == "sqlite":
            return f"sqlite:///{self.database}"
        raise ValueError(f"Unsupported database type: {self.type}")

    def _network_url(self, scheme: str, label: str) -> str:
        """Build ``scheme://user[:password]@host:port/database`` for a server database."""
        from urllib.parse import quote

        if not all([self.host, self.port, self.username]):
            raise ValueError(f"Host, port, and username are required for {label}")

        password = self.password or self._get_password_from_keyring()
        # NOTE(review): the username is inserted unencoded; a username containing
        # ':' would make the URL ambiguous.
        auth = f"{self.username}"
        if password:
            # quote(), not quote_plus(): inside URL userinfo '+' is a literal plus,
            # so quote_plus's space->'+' is not turned back into a space by parsers
            # that decode with urllib.parse.unquote. safe="" also escapes '/', ':'
            # and '@', which would otherwise break the URL structure.
            auth = f"{auth}:{quote(password, safe='')}"
        return f"{scheme}://{auth}@{self.host}:{self.port}/{self.database}"

    def _get_password_from_keyring(self) -> Optional[str]:
        """Return the stored password for this config, or ``None`` if unavailable.

        Any keyring error is swallowed and treated as "no password".
        """
        try:
            return keyring.get_password("sqlsaber", f"{self.name}_{self.username}")
        except Exception:
            return None

    def store_password_in_keyring(self, password: str) -> None:
        """Save ``password`` in the OS keyring under service ``sqlsaber``."""
        keyring.set_password("sqlsaber", f"{self.name}_{self.username}", password)

    def delete_password_from_keyring(self) -> None:
        """Remove this config's password from the OS keyring, ignoring any error."""
        try:
            keyring.delete_password("sqlsaber", f"{self.name}_{self.username}")
        except Exception:
            pass

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict of this config (password excluded)."""
        return {
            "name": self.name,
            "type": self.type,
            "host": self.host,
            "port": self.port,
            "database": self.database,
            "username": self.username,
            "ssl_mode": self.ssl_mode,
            "schema": self.schema,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatabaseConfig":
        """Create a config from a dict produced by ``to_dict``.

        ``name``, ``type`` and ``database`` are required. The optional
        connection fields may be absent (e.g. a hand-written sqlite entry) and
        then default to ``None``.
        """
        return cls(
            name=data["name"],
            type=data["type"],
            host=data.get("host"),
            port=data.get("port"),
            database=data["database"],
            username=data.get("username"),
            ssl_mode=data.get("ssl_mode"),
            schema=data.get("schema"),
        )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class DatabaseConfigManager:
    """Store named database configurations in a JSON file in the user config dir.

    File layout::

        {"default": <name or null>, "connections": {<name>: <DatabaseConfig.to_dict()>}}

    Passwords are never written to this file; ``add_database`` hands them to
    the OS keyring instead.
    """

    def __init__(self):
        self.config_dir = Path(platformdirs.user_config_dir("sqlsaber", "sqlsaber"))
        self.config_file = self.config_dir / "database_config.json"
        self._ensure_config_dir()

    def _ensure_config_dir(self) -> None:
        """Create the config directory if missing and restrict it to the owner."""
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self._set_secure_permissions(self.config_dir, is_directory=True)

    def _set_secure_permissions(self, path: Path, is_directory: bool = False) -> None:
        """Best-effort chmod to 0o700 (directories) or 0o600 (files).

        Skipped on Windows, where the default NTFS permissions are relied upon.
        Failures are ignored so configuration keeps working on filesystems
        that do not support chmod.
        """
        if platform.system() == "Windows":
            return
        mode = stat.S_IRWXU if is_directory else stat.S_IRUSR | stat.S_IWUSR
        try:
            os.chmod(path, mode)
        except OSError:
            pass

    @staticmethod
    def _empty_config() -> Dict[str, Any]:
        """Return a fresh configuration with no connections and no default."""
        return {"default": None, "connections": {}}

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration file.

        A missing, unreadable or malformed file yields an empty configuration.
        Missing top-level keys are filled in so callers can index
        ``config["default"]`` and ``config["connections"]`` directly.
        """
        if not self.config_file.exists():
            return self._empty_config()

        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (ValueError, OSError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return self._empty_config()

        if not isinstance(config, dict):
            return self._empty_config()
        config.setdefault("default", None)
        if not isinstance(config.get("connections"), dict):
            config["connections"] = {}
        return config

    def _save_config(self, config: Dict[str, Any]) -> None:
        """Atomically write ``config`` to the configuration file.

        The JSON is written to a temporary file in the same directory
        (``tempfile.mkstemp`` creates it readable/writable only by the owner)
        and then moved over the real file with ``os.replace``. Without this, a
        crash mid-write leaves truncated JSON. ``_load_config`` reads that as an
        empty config, and the next save then discards every stored connection.
        """
        import tempfile

        fd, tmp_path = tempfile.mkstemp(
            dir=self.config_dir, prefix=f".{self.config_file.name}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)
            os.replace(tmp_path, self.config_file)
        except BaseException:
            # Don't leave a stray temp file behind if writing or renaming failed.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

        self._set_secure_permissions(self.config_file, is_directory=False)

    def add_database(
        self, db_config: DatabaseConfig, password: Optional[str] = None
    ) -> None:
        """Add a database configuration and optionally store its password.

        The first database added becomes the default.

        Raises:
            ValueError: if a database with the same name already exists.
        """
        config = self._load_config()

        if db_config.name in config["connections"]:
            raise ValueError(f"Database '{db_config.name}' already exists")

        if password:
            db_config.store_password_in_keyring(password)

        config["connections"][db_config.name] = db_config.to_dict()

        if not config["default"]:
            config["default"] = db_config.name

        self._save_config(config)

    def get_database(self, name: str) -> Optional[DatabaseConfig]:
        """Return the configuration named ``name``, or ``None`` if it does not exist."""
        connections = self._load_config()["connections"]
        if name not in connections:
            return None
        return DatabaseConfig.from_dict(connections[name])

    def get_default_database(self) -> Optional[DatabaseConfig]:
        """Return the default database configuration, or ``None`` if none is set."""
        default_name = self._load_config().get("default")
        if not default_name:
            return None
        return self.get_database(default_name)

    def list_databases(self) -> List[DatabaseConfig]:
        """Return every stored configuration, in file order."""
        connections = self._load_config()["connections"]
        return [DatabaseConfig.from_dict(data) for data in connections.values()]

    def remove_database(self, name: str) -> bool:
        """Remove a configuration and its keyring password.

        If it was the default, the first remaining connection (or ``None``)
        becomes the new default.

        Returns:
            ``True`` if the database existed and was removed, ``False`` otherwise.
        """
        config = self._load_config()

        if name not in config["connections"]:
            return False

        db_config = DatabaseConfig.from_dict(config["connections"][name])
        db_config.delete_password_from_keyring()

        del config["connections"][name]

        if config["default"] == name:
            config["default"] = next(iter(config["connections"]), None)

        self._save_config(config)
        return True

    def set_default_database(self, name: str) -> bool:
        """Make ``name`` the default database; return ``False`` if it does not exist."""
        config = self._load_config()

        if name not in config["connections"]:
            return False

        config["default"] = name
        self._save_config(config)
        return True

    def has_databases(self) -> bool:
        """Return ``True`` if at least one database is configured."""
        return len(self._load_config()["connections"]) > 0

    def get_default_name(self) -> Optional[str]:
        """Return the name of the default database, or ``None``."""
        return self._load_config().get("default")
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Configuration management for SQLSaber SQL Agent."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import platform
|
|
6
|
+
import stat
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, Optional
|
|
9
|
+
|
|
10
|
+
import platformdirs
|
|
11
|
+
|
|
12
|
+
from sqlsaber.config.api_keys import APIKeyManager
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ModelConfigManager:
    """Persist the selected model name in ``model_config.json`` in the user config dir."""

    DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"

    def __init__(self):
        self.config_dir = Path(platformdirs.user_config_dir("sqlsaber", "sqlsaber"))
        self.config_file = self.config_dir / "model_config.json"
        self._ensure_config_dir()

    def _ensure_config_dir(self) -> None:
        """Create the config directory if missing and restrict it to the owner."""
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self._set_secure_permissions(self.config_dir, is_directory=True)

    def _set_secure_permissions(self, path: Path, is_directory: bool = False) -> None:
        """Best-effort chmod to 0o700 (directories) or 0o600 (files); skipped on Windows."""
        if platform.system() == "Windows":
            return
        mode = stat.S_IRWXU if is_directory else stat.S_IRUSR | stat.S_IWUSR
        try:
            os.chmod(path, mode)
        except OSError:
            pass

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration, always returning a dict that contains ``"model"``.

        A missing, unreadable or malformed file, or one whose top level is not
        a JSON object, yields ``{"model": DEFAULT_MODEL}``.
        """
        if not self.config_file.exists():
            return {"model": self.DEFAULT_MODEL}

        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (ValueError, OSError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {"model": self.DEFAULT_MODEL}

        if not isinstance(config, dict):
            return {"model": self.DEFAULT_MODEL}
        config.setdefault("model", self.DEFAULT_MODEL)
        return config

    def _save_config(self, config: Dict[str, Any]) -> None:
        """Atomically write ``config``: owner-only temp file, then ``os.replace``.

        This prevents a crash mid-write from leaving a truncated JSON file.
        """
        import tempfile

        fd, tmp_path = tempfile.mkstemp(
            dir=self.config_dir, prefix=f".{self.config_file.name}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)
            os.replace(tmp_path, self.config_file)
        except BaseException:
            # Don't leave a stray temp file behind if writing or renaming failed.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

        self._set_secure_permissions(self.config_file, is_directory=False)

    def get_model(self) -> str:
        """Return the configured model, falling back to ``DEFAULT_MODEL``.

        The fallback also applies when the stored value is not a non-empty
        string (e.g. ``"model": null``). Callers use ``str`` methods such as
        ``startswith`` on the result.
        """
        model = self._load_config().get("model")
        if isinstance(model, str) and model:
            return model
        return self.DEFAULT_MODEL

    def set_model(self, model: str) -> None:
        """Persist ``model`` as the selected model, keeping any other stored keys."""
        config = self._load_config()
        config["model"] = model
        self._save_config(config)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class Config:
    """Runtime configuration: the selected model and the API key for its provider."""

    def __init__(self):
        self.model_config_manager = ModelConfigManager()
        self.model_name = self.model_config_manager.get_model()
        self.api_key_manager = APIKeyManager()
        self.api_key = self._get_api_key()

    @staticmethod
    def _provider_for_model(model: str) -> str:
        """Return the provider id for a ``"<provider>:<model>"`` name.

        Only ``openai`` and ``anthropic`` are recognised; anything else maps
        to ``generic``.
        """
        for known in ("openai", "anthropic"):
            if model.startswith(f"{known}:"):
                return known
        return "generic"

    def _get_api_key(self) -> Optional[str]:
        """Resolve the API key for the current model's provider via the key manager."""
        provider = self._provider_for_model(self.model_name)
        return self.api_key_manager.get_api_key(provider)

    def set_model(self, model: str) -> None:
        """Persist ``model``, make it current, and re-resolve its API key."""
        self.model_config_manager.set_model(model)
        self.model_name = model
        # A different model may belong to a different provider, so fetch its key.
        self.api_key = self._get_api_key()

    def validate(self):
        """Raise ``ValueError`` naming the provider when no API key is available."""
        if self.api_key:
            return
        display_names = {"openai": "OpenAI", "anthropic": "Anthropic"}
        provider = display_names.get(self._provider_for_model(self.model_name), "generic")
        raise ValueError(f"{provider} API key not found.")
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Database connection management."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
from urllib.parse import urlparse
|
|
6
|
+
|
|
7
|
+
import aiomysql
|
|
8
|
+
import aiosqlite
|
|
9
|
+
import asyncpg
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseDatabaseConnection(ABC):
    """Common interface implemented by every database backend.

    A backend is built from a connection string and exposes async pool
    management plus a query method that returns rows as dictionaries.
    """

    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        # Pool handle; None until a subclass that pools connections creates one.
        self._pool = None

    @abstractmethod
    async def get_pool(self):
        """Return the backend's pool, creating it on first use."""

    @abstractmethod
    async def close(self):
        """Release the pool and any connections it holds."""

    @abstractmethod
    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Run ``query`` with positional ``args`` and return each row as a dict.

        Implementations must run the query inside a transaction that is always
        rolled back afterwards, so nothing it changes is persisted.
        """
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PostgreSQLConnection(BaseDatabaseConnection):
    """PostgreSQL backend built on an asyncpg connection pool."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        self._pool: Optional[asyncpg.Pool] = None

    async def get_pool(self) -> asyncpg.Pool:
        """Return the asyncpg pool, creating it (1-10 connections) on first call."""
        if self._pool is not None:
            return self._pool
        self._pool = await asyncpg.create_pool(
            self.connection_string, min_size=1, max_size=10
        )
        return self._pool

    async def close(self):
        """Close the pool if one was created and forget it."""
        if not self._pool:
            return
        await self._pool.close()
        self._pool = None

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Run ``query`` with ``args`` and return the rows as dicts.

        The statement runs in an explicitly started transaction that is rolled
        back in ``finally`` whether or not the query succeeds, so its changes
        are never committed.
        """
        pool = await self.get_pool()
        async with pool.acquire() as connection:
            tx = connection.transaction()
            await tx.start()
            try:
                records = await connection.fetch(query, *args)
                return list(map(dict, records))
            finally:
                # Roll back unconditionally: this class never commits.
                await tx.rollback()
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class MySQLConnection(BaseDatabaseConnection):
    """MySQL database connection using aiomysql."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        self._pool: Optional[aiomysql.Pool] = None
        self._parse_connection_string()

    def _parse_connection_string(self):
        """Split the ``mysql://`` URL into the keyword arguments aiomysql expects.

        ``urlparse`` returns the username, password and path still
        percent-encoded, so they are decoded here. Without this, a password
        built with ``urllib.parse.quote`` (e.g. ``p%40ss`` for ``p@ss``) would
        be sent to the server verbatim and authentication would fail.
        Host and port default to ``localhost`` and ``3306`` when absent.
        """
        from urllib.parse import unquote

        parsed = urlparse(self.connection_string)
        self.host = parsed.hostname or "localhost"
        self.port = parsed.port or 3306
        self.database = unquote(parsed.path.lstrip("/")) if parsed.path else ""
        self.user = unquote(parsed.username or "")
        self.password = unquote(parsed.password or "")

    async def get_pool(self) -> aiomysql.Pool:
        """Return the aiomysql pool, creating it (1-10 connections) on first call.

        ``autocommit=False`` so that ``execute_query``'s rollback undoes changes.
        """
        if self._pool is None:
            self._pool = await aiomysql.create_pool(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db=self.database,
                minsize=1,
                maxsize=10,
                autocommit=False,
            )
        return self._pool

    async def close(self):
        """Close the pool, wait for its connections to close, and forget it."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Execute a query and return results as list of dicts.

        The query runs inside an explicit transaction that is always rolled
        back, so its changes are not committed.
        """
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await conn.begin()
                try:
                    # aiomysql expects None (not an empty tuple) when there are no params.
                    await cursor.execute(query, args if args else None)
                    rows = await cursor.fetchall()
                    return [dict(row) for row in rows]
                finally:
                    # Always rollback to ensure no changes are committed
                    await conn.rollback()
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class SQLiteConnection(BaseDatabaseConnection):
    """SQLite backend using aiosqlite; a new connection is opened for each query."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        # Strip only the leading "sqlite:///" scheme. str.replace() would also
        # delete any later occurrence of that text inside the path itself.
        prefix = "sqlite:///"
        if connection_string.startswith(prefix):
            self.database_path = connection_string[len(prefix):]
        else:
            self.database_path = connection_string

    async def get_pool(self):
        """SQLite doesn't use connection pooling; return the database path instead."""
        return self.database_path

    async def close(self):
        """No-op: each ``execute_query`` call opens and closes its own connection."""

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Execute a query and return results as list of dicts.

        The query runs after an explicit ``BEGIN`` and the connection is always
        rolled back afterwards, so changes are not committed.
        """
        async with aiosqlite.connect(self.database_path) as conn:
            # Row factory lets each row be converted with dict(row).
            conn.row_factory = aiosqlite.Row

            await conn.execute("BEGIN")
            try:
                cursor = await conn.execute(query, args if args else ())
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]
            finally:
                # Always rollback to ensure no changes are committed
                await conn.rollback()
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
    """Create the backend matching the URL scheme of ``connection_string``.

    Supported prefixes: ``postgresql://``, ``mysql://`` and ``sqlite:///``.

    Raises:
        ValueError: if the scheme is not supported. The message shows the
            connection string with any password replaced by ``***`` so that
            credentials do not leak into error output or logs.
    """
    if connection_string.startswith("postgresql://"):
        return PostgreSQLConnection(connection_string)
    elif connection_string.startswith("mysql://"):
        return MySQLConnection(connection_string)
    elif connection_string.startswith("sqlite:///"):
        return SQLiteConnection(connection_string)
    raise ValueError(
        "Unsupported database type in connection string: "
        f"{_redact_connection_string(connection_string)}"
    )


def _redact_connection_string(connection_string: str) -> str:
    """Return ``connection_string`` with its URL password (if any) replaced by ``***``."""
    try:
        parts = urlparse(connection_string)
    except ValueError:
        # Not parseable as a URL; don't echo text that might contain secrets.
        return "<unparseable connection string>"
    if parts.password is None:
        return connection_string
    userinfo, _, hostinfo = parts.netloc.rpartition("@")
    username = userinfo.partition(":")[0]
    return parts._replace(netloc=f"{username}:***@{hostinfo}").geturl()
|