sqlsaber 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -0,0 +1,102 @@
1
+ """API Key management for SQLSaber."""
2
+
3
+ import getpass
4
+ import os
5
+ from typing import Optional
6
+
7
+ import keyring
8
+ from rich.console import Console
9
+
10
# Module-level Rich console that APIKeyManager uses to print status,
# warning and prompt-related messages.
console = Console()
11
+
12
+
13
class APIKeyManager:
    """Manages API keys with cascading retrieval: env var -> keyring -> prompt.

    Keys are looked up, in order, in the provider's environment variable, in
    the OS keyring (via ``keyring``), and finally by prompting the user. A key
    entered at the prompt is saved to the keyring for later runs.
    """

    # Environment variable checked for each known provider; any other
    # provider falls back to ``_DEFAULT_ENV_VAR``.
    _ENV_VARS = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
    }
    _DEFAULT_ENV_VAR = "AI_API_KEY"

    def __init__(self):
        self.service_prefix = "sqlsaber"

    def get_api_key(self, provider: str) -> Optional[str]:
        """Return the API key for *provider*, or ``None`` if none is available.

        Args:
            provider: Provider identifier such as ``"openai"`` or ``"anthropic"``.

        Returns:
            The key from the environment, the keyring, or the interactive
            prompt (first hit wins); ``None`` if the user skips or cancels.
        """
        env_var_name = self._get_env_var_name(provider)
        service_name = self._get_service_name(provider)

        # 1. Environment variable takes precedence.
        api_key = os.getenv(env_var_name)
        if api_key:
            console.print(f"Using {env_var_name} from environment", style="dim")
            return api_key

        # 2. OS keyring. A backend failure is not fatal: fall through to the prompt.
        try:
            api_key = keyring.get_password(service_name, provider)
        except Exception as e:
            console.print(f"Keyring access failed: {e}", style="dim yellow")
        else:
            if api_key:
                console.print(
                    f"Using stored {provider} API key from keyring", style="dim"
                )
                return api_key

        # 3. Ask the user (the answer is persisted to the keyring).
        return self._prompt_and_store_key(provider, env_var_name, service_name)

    def _get_env_var_name(self, provider: str) -> str:
        """Return the environment variable name expected for *provider*."""
        return self._ENV_VARS.get(provider, self._DEFAULT_ENV_VAR)

    def _get_service_name(self, provider: str) -> str:
        """Return the keyring service name, e.g. ``sqlsaber-openai-api-key``."""
        return f"{self.service_prefix}-{provider}-api-key"

    def _prompt_and_store_key(
        self, provider: str, env_var_name: str, service_name: str
    ) -> Optional[str]:
        """Prompt for an API key and try to save it in the OS keyring.

        Returns:
            The entered key with surrounding whitespace removed, or ``None``
            if the user enters nothing, presses Ctrl+C, or the prompt itself
            fails. A keyring write failure only prints a warning; the key is
            still returned.
        """
        try:
            console.print(
                f"\n{provider.title()} API key not found in environment or keyring."
            )
            console.print("You can either:")
            console.print(f" 1. Set the {env_var_name} environment variable")
            console.print(
                " 2. Enter it now to securely store using your operating system's credentials store"
            )

            # Normalize once; both the emptiness check and storage use this value.
            api_key = getpass.getpass(
                f"\nEnter your {provider.title()} API key (or press Enter to skip): "
            ).strip()

            if not api_key:
                console.print(
                    "No API key provided. Some functionality may not work.",
                    style="yellow",
                )
                return None

            # Persist for future runs; failing to store is only a warning.
            try:
                keyring.set_password(service_name, provider, api_key)
                console.print("API key stored securely for future use", style="green")
            except Exception as e:
                console.print(
                    f"Warning: Could not store API key in keyring: {e}", style="yellow"
                )
                console.print(
                    "You may need to enter it again next time", style="yellow"
                )

            return api_key

        except KeyboardInterrupt:
            console.print("\nOperation cancelled", style="yellow")
            return None
        except Exception as e:
            console.print(f"Error prompting for API key: {e}", style="red")
            return None
@@ -0,0 +1,252 @@
1
+ """Database configuration management."""
2
+
3
+ import json
4
+ import os
5
+ import platform
6
+ import stat
7
+ import keyring
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Any
11
+ from urllib.parse import quote_plus
12
+
13
+ import platformdirs
14
+
15
+
16
@dataclass
class DatabaseConfig:
    """Connection settings for one named database.

    ``to_dict``/``from_dict`` (de)serialize every field except ``password``.
    A password is either set on the instance or kept in the OS keyring under
    service ``"sqlsaber"`` with entry name ``"<name>_<username>"``.
    """

    name: str
    type: str  # postgresql, mysql, sqlite
    host: Optional[str]
    port: Optional[int]
    database: str
    username: Optional[str]
    password: Optional[str] = None
    ssl_mode: Optional[str] = None
    schema: Optional[str] = None

    def to_connection_string(self) -> str:
        """Build a connection URL for this configuration.

        Returns:
            A ``sqlite:///``, ``postgresql://`` or ``mysql://`` URL. For the
            network databases the password (``self.password`` or, failing
            that, the keyring entry) is percent-encoded into the userinfo.

        Raises:
            ValueError: if ``type`` is unsupported, or host/port/username is
                missing for PostgreSQL or MySQL.
        """
        if self.type == "sqlite":
            # File-based database: no credentials, so skip the keyring lookup.
            return f"sqlite:///{self.database}"

        labels = {"postgresql": "PostgreSQL", "mysql": "MySQL"}
        label = labels.get(self.type)
        if label is None:
            raise ValueError(f"Unsupported database type: {self.type}")
        if not all([self.host, self.port, self.username]):
            raise ValueError(f"Host, port, and username are required for {label}")

        # Local import keeps this block self-contained.
        from urllib.parse import quote

        password = self.password or self._get_password_from_keyring()
        # NOTE(review): the username is inserted without encoding (unchanged);
        # a name containing "@", ":" or "/" would yield a malformed URL.
        userinfo = self.username
        if password:
            # quote(..., safe="") rather than quote_plus: quote_plus maps a
            # space to "+", but URL userinfo is percent-decoded (e.g. with
            # urllib.parse.unquote), which reads "+" back as a literal plus.
            userinfo = f"{userinfo}:{quote(password, safe='')}"
        return f"{self.type}://{userinfo}@{self.host}:{self.port}/{self.database}"

    def _keyring_username(self) -> str:
        """Keyring entry name under which this connection's password is stored."""
        return f"{self.name}_{self.username}"

    def _get_password_from_keyring(self) -> Optional[str]:
        """Return the stored password, or ``None`` if absent or the keyring fails."""
        try:
            return keyring.get_password("sqlsaber", self._keyring_username())
        except Exception:
            # Best effort: an unavailable keyring backend just means "no password".
            return None

    def store_password_in_keyring(self, password: str) -> None:
        """Store *password* in the OS keyring (errors propagate to the caller)."""
        keyring.set_password("sqlsaber", self._keyring_username(), password)

    def delete_password_from_keyring(self) -> None:
        """Delete the stored password; a missing entry or keyring error is ignored."""
        try:
            keyring.delete_password("sqlsaber", self._keyring_username())
        except Exception:
            pass

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dict (``password`` deliberately omitted)."""
        return {
            "name": self.name,
            "type": self.type,
            "host": self.host,
            "port": self.port,
            "database": self.database,
            "username": self.username,
            "ssl_mode": self.ssl_mode,
            "schema": self.schema,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatabaseConfig":
        """Create an instance from a ``to_dict`` result (password left as ``None``).

        Raises:
            KeyError: if any of name/type/host/port/database/username is missing.
        """
        return cls(
            name=data["name"],
            type=data["type"],
            host=data["host"],
            port=data["port"],
            database=data["database"],
            username=data["username"],
            ssl_mode=data.get("ssl_mode"),
            schema=data.get("schema"),
        )
101
+
102
+
103
class DatabaseConfigManager:
    """Manages database configurations stored in ``database_config.json``.

    The file lives in the per-user config directory and has the shape
    ``{"default": <name or None>, "connections": {<name>: <DatabaseConfig dict>}}``.
    Passwords are not written to it; they go to the OS keyring via
    ``DatabaseConfig``.
    """

    def __init__(self):
        self.config_dir = Path(platformdirs.user_config_dir("sqlsaber", "sqlsaber"))
        self.config_file = self.config_dir / "database_config.json"
        self._ensure_config_dir()

    def _ensure_config_dir(self) -> None:
        """Create the config directory if needed and restrict it to the owner."""
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self._set_secure_permissions(self.config_dir, is_directory=True)

    def _set_secure_permissions(self, path: Path, is_directory: bool = False) -> None:
        """Best-effort owner-only permissions: 0o700 for dirs, 0o600 for files.

        On Windows this is a no-op and the default NTFS permissions are
        relied upon. chmod failures are ignored so the config still works on
        filesystems that do not support it.
        """
        if platform.system() == "Windows":
            return
        mode = stat.S_IRWXU if is_directory else stat.S_IRUSR | stat.S_IWUSR
        try:
            os.chmod(path, mode)
        except OSError:  # includes PermissionError
            pass

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration, falling back to an empty one.

        A missing, unreadable or malformed file (including JSON whose root is
        not an object) yields ``{"default": None, "connections": {}}``. A
        readable file lacking either key gets that key filled in, so callers
        can always index ``config["default"]`` and ``config["connections"]``.
        """
        empty = {"default": None, "connections": {}}
        if not self.config_file.exists():
            return empty

        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            return empty

        if not isinstance(config, dict):
            return empty
        config.setdefault("default", None)
        if not isinstance(config.get("connections"), dict):
            config["connections"] = {}
        return config

    def _save_config(self, config: Dict[str, Any]) -> None:
        """Atomically write *config* to the configuration file.

        The JSON goes to a sibling ``.tmp`` file created with mode 0o600 and
        is then moved over the real file with ``os.replace``. An interrupted
        write therefore cannot leave a truncated config behind (which
        ``_load_config`` would read as empty), and the data is not written
        through a file with default, possibly world-readable, permissions.
        """
        tmp_path = self.config_file.with_name(self.config_file.name + ".tmp")
        fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)
            os.replace(tmp_path, self.config_file)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

        # Still applied so the final file is 0o600 even if the umask or an
        # existing temp file left it wider (no-op on Windows).
        self._set_secure_permissions(self.config_file, is_directory=False)

    def add_database(
        self, db_config: DatabaseConfig, password: Optional[str] = None
    ) -> None:
        """Add *db_config*, storing *password* (if given) in the OS keyring.

        The first database added becomes the default.

        Raises:
            ValueError: if a database with the same name already exists.
        """
        config = self._load_config()

        if db_config.name in config["connections"]:
            raise ValueError(f"Database '{db_config.name}' already exists")

        if password:
            db_config.store_password_in_keyring(password)

        config["connections"][db_config.name] = db_config.to_dict()
        if not config["default"]:
            config["default"] = db_config.name

        self._save_config(config)

    def get_database(self, name: str) -> Optional[DatabaseConfig]:
        """Return the configuration named *name*, or ``None`` if unknown."""
        config = self._load_config()
        if name not in config["connections"]:
            return None
        return DatabaseConfig.from_dict(config["connections"][name])

    def get_default_database(self) -> Optional[DatabaseConfig]:
        """Return the default configuration, or ``None`` if no default is set."""
        default_name = self._load_config().get("default")
        if not default_name:
            return None
        return self.get_database(default_name)

    def list_databases(self) -> List[DatabaseConfig]:
        """Return all configurations in file order."""
        config = self._load_config()
        return [
            DatabaseConfig.from_dict(db_data)
            for db_data in config["connections"].values()
        ]

    def remove_database(self, name: str) -> bool:
        """Remove *name* and its stored password.

        If it was the default, the first remaining connection (or ``None``)
        becomes the default.

        Returns:
            ``True`` if the database existed and was removed, else ``False``.
        """
        config = self._load_config()
        if name not in config["connections"]:
            return False

        db_config = DatabaseConfig.from_dict(config["connections"].pop(name))

        if config["default"] == name:
            config["default"] = next(iter(config["connections"]), None)

        self._save_config(config)
        # Drop the password only after the config change is on disk, so a
        # failed save does not leave a connection without its password.
        db_config.delete_password_from_keyring()
        return True

    def set_default_database(self, name: str) -> bool:
        """Make *name* the default; return ``False`` if it is not configured."""
        config = self._load_config()
        if name not in config["connections"]:
            return False
        config["default"] = name
        self._save_config(config)
        return True

    def has_databases(self) -> bool:
        """Return ``True`` if at least one database is configured."""
        return bool(self._load_config()["connections"])

    def get_default_name(self) -> Optional[str]:
        """Return the name of the default database, or ``None``."""
        return self._load_config().get("default")
@@ -0,0 +1,115 @@
1
+ """Configuration management for SQLSaber SQL Agent."""
2
+
3
+ import json
4
+ import os
5
+ import platform
6
+ import stat
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ import platformdirs
11
+
12
+ from sqlsaber.config.api_keys import APIKeyManager
13
+
14
+
15
class ModelConfigManager:
    """Persists the selected model name in ``model_config.json``."""

    DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"

    def __init__(self):
        self.config_dir = Path(platformdirs.user_config_dir("sqlsaber", "sqlsaber"))
        self.config_file = self.config_dir / "model_config.json"
        self._ensure_config_dir()

    def _ensure_config_dir(self) -> None:
        """Create the config directory if needed and restrict it to the owner."""
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self._set_secure_permissions(self.config_dir, is_directory=True)

    def _set_secure_permissions(self, path: Path, is_directory: bool = False) -> None:
        """Best-effort owner-only permissions (0o700 dirs, 0o600 files); no-op on Windows."""
        if platform.system() == "Windows":
            return
        mode = stat.S_IRWXU if is_directory else stat.S_IRUSR | stat.S_IWUSR
        try:
            os.chmod(path, mode)
        except OSError:  # includes PermissionError
            pass

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration, guaranteeing a usable ``"model"`` entry.

        A missing, unreadable or malformed file (including JSON whose root is
        not an object) yields ``{"model": DEFAULT_MODEL}``. A missing, null,
        empty or non-string ``"model"`` value is replaced by ``DEFAULT_MODEL``.
        Other keys in the file are preserved.
        """
        if not self.config_file.exists():
            return {"model": self.DEFAULT_MODEL}

        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            return {"model": self.DEFAULT_MODEL}

        if not isinstance(config, dict):
            return {"model": self.DEFAULT_MODEL}

        model = config.get("model")
        if not isinstance(model, str) or not model:
            config["model"] = self.DEFAULT_MODEL
        return config

    def _save_config(self, config: Dict[str, Any]) -> None:
        """Atomically write *config*: temp file (mode 0o600), then ``os.replace``.

        An interrupted write therefore never leaves a truncated file behind.
        """
        tmp_path = self.config_file.with_name(self.config_file.name + ".tmp")
        fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)
            os.replace(tmp_path, self.config_file)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

        self._set_secure_permissions(self.config_file, is_directory=False)

    def get_model(self) -> str:
        """Return the configured model, or ``DEFAULT_MODEL`` if none is set."""
        return self._load_config().get("model", self.DEFAULT_MODEL)

    def set_model(self, model: str) -> None:
        """Persist *model* as the selected model, keeping other stored keys."""
        config = self._load_config()
        config["model"] = model
        self._save_config(config)
75
+
76
+
77
class Config:
    """Runtime configuration for SQLSaber: the selected model and its API key."""

    # Model-name prefix -> (APIKeyManager provider id, display name for errors).
    _PROVIDERS = {
        "openai:": ("openai", "OpenAI"),
        "anthropic:": ("anthropic", "Anthropic"),
    }
    # Used when the model name matches none of the prefixes above.
    _GENERIC_PROVIDER = ("generic", "generic")

    def __init__(self):
        self.model_config_manager = ModelConfigManager()
        self.model_name = self.model_config_manager.get_model()
        self.api_key_manager = APIKeyManager()
        self.api_key = self._get_api_key()

    def _provider(self):
        """Return ``(provider_id, display_name)`` for ``self.model_name``.

        Single source of truth for the prefix mapping used by both
        ``_get_api_key`` and ``validate``.
        """
        for prefix, provider in self._PROVIDERS.items():
            if self.model_name.startswith(prefix):
                return provider
        return self._GENERIC_PROVIDER

    def _get_api_key(self) -> Optional[str]:
        """Fetch the API key for the current model's provider (env -> keyring -> prompt)."""
        provider_id, _ = self._provider()
        return self.api_key_manager.get_api_key(provider_id)

    def set_model(self, model: str) -> None:
        """Persist *model* and re-resolve the API key for its provider."""
        self.model_config_manager.set_model(model)
        self.model_name = model
        self.api_key = self._get_api_key()

    def validate(self):
        """Raise if no API key is available for the current model.

        Raises:
            ValueError: ``"<Provider> API key not found."``.
        """
        if not self.api_key:
            _, display_name = self._provider()
            raise ValueError(f"{display_name} API key not found.")
@@ -0,0 +1,9 @@
1
+ """Database module for SQLSaber."""
2
+
3
+ from .connection import DatabaseConnection
4
+ from .schema import SchemaManager
5
+
6
# Names exported by `from <package> import *`; both are re-exported from
# the .connection and .schema submodules imported above.
__all__ = [
    "DatabaseConnection",
    "SchemaManager",
]
@@ -0,0 +1,187 @@
1
+ """Database connection management."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Dict, List, Optional
5
+ from urllib.parse import urlparse
6
+
7
+ import aiomysql
8
+ import aiosqlite
9
+ import asyncpg
10
+
11
+
12
class BaseDatabaseConnection(ABC):
    """Common interface for the database backends.

    Holds the connection string and a ``_pool`` slot that starts as ``None``;
    subclasses decide what the pool is and when it is created.
    """

    def __init__(self, connection_string: str):
        self._pool = None
        self.connection_string = connection_string

    @abstractmethod
    async def get_pool(self):
        """Return the backend's pool (or equivalent handle), creating it if needed."""

    @abstractmethod
    async def close(self):
        """Release the pool, if any."""

    @abstractmethod
    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Run *query* with *args* and return the rows as dicts.

        Implementations execute inside a transaction that is always rolled
        back, so the query's changes are not persisted.
        """
37
+
38
+
39
class PostgreSQLConnection(BaseDatabaseConnection):
    """PostgreSQL backend built on a lazily created ``asyncpg`` pool."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        self._pool: Optional[asyncpg.Pool] = None

    async def get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it (1 to 10 connections) on first call."""
        if self._pool is not None:
            return self._pool
        self._pool = await asyncpg.create_pool(
            self.connection_string, min_size=1, max_size=10
        )
        return self._pool

    async def close(self):
        """Close the pool if one was created and forget it."""
        pool = self._pool
        if not pool:
            return
        await pool.close()
        self._pool = None

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Run *query* inside a transaction that is always rolled back.

        Returns:
            One dict per result row.

        NOTE(review): rollback does not undo effects PostgreSQL treats as
        non-transactional (e.g. sequence increments) — confirm acceptable.
        """
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            tx = conn.transaction()
            await tx.start()
            try:
                records = await conn.fetch(query, *args)
            finally:
                # Unconditional rollback: nothing the query did is committed.
                await tx.rollback()
            return [dict(record) for record in records]
78
+
79
+
80
class MySQLConnection(BaseDatabaseConnection):
    """MySQL backend built on a lazily created ``aiomysql`` pool."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        self._pool: Optional[aiomysql.Pool] = None
        self._parse_connection_string()

    def _parse_connection_string(self):
        """Split the ``mysql://`` URL into the keyword arguments aiomysql needs.

        Sets ``host`` (default ``"localhost"``), ``port`` (default 3306),
        ``database``, ``user`` and ``password``. ``urlparse`` returns userinfo
        and path still percent-encoded, so they are decoded here; otherwise an
        encoded password such as ``p%40ss`` would be sent literally and
        authentication would fail.
        """
        # Local import keeps this fix self-contained.
        from urllib.parse import unquote

        parsed = urlparse(self.connection_string)
        self.host = parsed.hostname or "localhost"
        self.port = parsed.port or 3306
        self.database = unquote(parsed.path.lstrip("/"))
        self.user = unquote(parsed.username or "")
        self.password = unquote(parsed.password or "")

    async def get_pool(self) -> aiomysql.Pool:
        """Return the shared pool, creating it (1 to 10 connections, autocommit off) on first call."""
        if self._pool is None:
            self._pool = await aiomysql.create_pool(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db=self.database,
                minsize=1,
                maxsize=10,
                # Required so the rollback in execute_query actually undoes work.
                autocommit=False,
            )
        return self._pool

    async def close(self):
        """Close the pool, wait for its connections to finish, and forget it."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Run *query* inside a transaction that is always rolled back.

        *args*, when given, are passed to ``cursor.execute`` as the parameter
        sequence; with no args ``None`` is passed.

        Returns:
            One dict per result row.
        """
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await conn.begin()
                try:
                    await cursor.execute(query, args if args else None)
                    rows = await cursor.fetchall()
                    return [dict(row) for row in rows]
                finally:
                    # Always roll back so nothing the query did is committed.
                    await conn.rollback()
137
+
138
+
139
class SQLiteConnection(BaseDatabaseConnection):
    """SQLite backend using ``aiosqlite``; opens a fresh connection per query."""

    def __init__(self, connection_string: str):
        super().__init__(connection_string)
        # Strip only the leading scheme (str.replace would also remove any
        # later occurrence). "sqlite:///rel.db" -> "rel.db",
        # "sqlite:////abs.db" -> "/abs.db"; other strings are used as-is.
        prefix = "sqlite:///"
        if connection_string.startswith(prefix):
            self.database_path = connection_string[len(prefix):]
        else:
            self.database_path = connection_string

    async def get_pool(self):
        """No pooling for SQLite; return the database file path instead."""
        return self.database_path

    async def close(self):
        """Nothing to close: connections are opened and closed per query."""
        pass

    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
        """Run *query* inside an explicit transaction that is always rolled back.

        Returns:
            One dict per result row.

        NOTE(review): connecting to a path that does not exist creates an
        empty database file (default sqlite3 behaviour) — confirm intended.
        """
        async with aiosqlite.connect(self.database_path) as conn:
            # aiosqlite.Row allows dict(row) below.
            conn.row_factory = aiosqlite.Row

            await conn.execute("BEGIN")
            try:
                cursor = await conn.execute(query, args if args else ())
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]
            finally:
                # Always roll back so nothing the query did is committed.
                await conn.rollback()
174
+
175
+
176
def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
    """Create the connection wrapper matching the scheme of *connection_string*.

    Supported prefixes are ``postgresql://``, ``mysql://`` and ``sqlite:///``.

    Raises:
        ValueError: for any other input. Only the URL scheme is included in
            the message, because the full string may embed a password.
    """
    if connection_string.startswith("postgresql://"):
        return PostgreSQLConnection(connection_string)
    if connection_string.startswith("mysql://"):
        return MySQLConnection(connection_string)
    if connection_string.startswith("sqlite:///"):
        return SQLiteConnection(connection_string)

    # Never echo the whole string: it may contain credentials.
    scheme, separator, _ = connection_string.partition("://")
    shown = scheme if separator and scheme else "<no scheme>"
    raise ValueError(f"Unsupported database type in connection string: {shown}")
+ )