sqlsaber 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -0,0 +1,328 @@
+"""PostgreSQL database connection and schema introspection."""
+
+import asyncio
+import ssl
+from typing import Any
+from urllib.parse import parse_qs, urlparse
+
+import asyncpg
+
+from .base import (
+    DEFAULT_QUERY_TIMEOUT,
+    BaseDatabaseConnection,
+    BaseSchemaIntrospector,
+    QueryTimeoutError,
+)
+
+
+class PostgreSQLConnection(BaseDatabaseConnection):
+    """PostgreSQL database connection using asyncpg."""
+
+    def __init__(self, connection_string: str):
+        super().__init__(connection_string)
+        self._pool: asyncpg.Pool | None = None
+        self._ssl_context = self._create_ssl_context()
+
+    def _create_ssl_context(self) -> ssl.SSLContext | None:
+        """Create SSL context from connection string parameters."""
+        parsed = urlparse(self.connection_string)
+        if not parsed.query:
+            return None
+
+        params = parse_qs(parsed.query)
+        ssl_mode = params.get("sslmode", [None])[0]
+
+        if not ssl_mode or ssl_mode == "disable":
+            return None
+
+        # Create SSL context based on mode
+        if ssl_mode in ["require", "verify-ca", "verify-full"]:
+            ssl_context = ssl.create_default_context()
+
+            # Configure certificate verification
+            if ssl_mode == "require":
+                ssl_context.check_hostname = False
+                ssl_context.verify_mode = ssl.CERT_NONE
+            elif ssl_mode == "verify-ca":
+                ssl_context.check_hostname = False
+                ssl_context.verify_mode = ssl.CERT_REQUIRED
+            elif ssl_mode == "verify-full":
+                ssl_context.check_hostname = True
+                ssl_context.verify_mode = ssl.CERT_REQUIRED
+
+            # Load certificates if provided
+            ssl_ca = params.get("sslrootcert", [None])[0]
+            ssl_cert = params.get("sslcert", [None])[0]
+            ssl_key = params.get("sslkey", [None])[0]
+
+            if ssl_ca:
+                ssl_context.load_verify_locations(ssl_ca)
+
+            if ssl_cert and ssl_key:
+                ssl_context.load_cert_chain(ssl_cert, ssl_key)
+
+            return ssl_context
+
+        return None
+
+    async def get_pool(self) -> asyncpg.Pool:
+        """Get or create connection pool."""
+        if self._pool is None:
+            # Create pool with SSL context if configured
+            if self._ssl_context:
+                self._pool = await asyncpg.create_pool(
+                    self.connection_string,
+                    min_size=1,
+                    max_size=10,
+                    ssl=self._ssl_context,
+                )
+            else:
+                self._pool = await asyncpg.create_pool(
+                    self.connection_string, min_size=1, max_size=10
+                )
+        return self._pool
+
+    async def close(self):
+        """Close the connection pool."""
+        if self._pool:
+            await self._pool.close()
+            self._pool = None
+
+    async def execute_query(
+        self, query: str, *args, timeout: float | None = None
+    ) -> list[dict[str, Any]]:
+        """Execute a query and return results as list of dicts.
+
+        All queries run in a transaction that is rolled back at the end,
+        ensuring no changes are persisted to the database.
+        """
+        effective_timeout = timeout or DEFAULT_QUERY_TIMEOUT
+        pool = await self.get_pool()
+
+        async with pool.acquire() as conn:
+            # Start a transaction that we'll always rollback
+            transaction = conn.transaction()
+            await transaction.start()
+
+            try:
+                # Set server-side timeout if specified
+                if effective_timeout:
+                    # Clamp timeout to sane range (10ms to 5 minutes) and validate
+                    timeout_ms = max(10, min(int(effective_timeout * 1000), 300000))
+                    await conn.execute(f"SET LOCAL statement_timeout = {timeout_ms}")
+
+                # Execute query with client-side timeout
+                if effective_timeout:
+                    rows = await asyncio.wait_for(
+                        conn.fetch(query, *args), timeout=effective_timeout
+                    )
+                else:
+                    rows = await conn.fetch(query, *args)
+
+                return [dict(row) for row in rows]
+            except asyncio.TimeoutError as exc:
+                raise QueryTimeoutError(effective_timeout or 0) from exc
+            finally:
+                # Always rollback to ensure no changes are committed
+                await transaction.rollback()
+
+
+class PostgreSQLSchemaIntrospector(BaseSchemaIntrospector):
+    """PostgreSQL-specific schema introspection."""
+
+    def _build_table_filter_clause(self, tables: list) -> tuple[str, list]:
+        """Build VALUES clause with bind parameters for table filtering.
+
+        Args:
+            tables: List of table dictionaries with table_schema and table_name keys
+
+        Returns:
+            Tuple of (values_clause, params) for use in SQL queries
+        """
+        if not tables:
+            return "", []
+
+        table_pairs = [(table["table_schema"], table["table_name"]) for table in tables]
+        values_clause = ", ".join(
+            [f"(${2 * i + 1}, ${2 * i + 2})" for i in range(len(table_pairs))]
+        )
+        params = [value for pair in table_pairs for value in pair]
+        return values_clause, params
+
+    async def get_tables_info(
+        self, connection, table_pattern: str | None = None
+    ) -> dict[str, Any]:
+        """Get tables information for PostgreSQL."""
+        pool = await connection.get_pool()
+        async with pool.acquire() as conn:
+            # Build WHERE clause for filtering
+            where_conditions = [
+                "table_schema NOT IN ('pg_catalog', 'information_schema')"
+            ]
+            params = []
+
+            if table_pattern:
+                # Support patterns like 'schema.table' or just 'table'
+                if "." in table_pattern:
+                    schema_pattern, table_name_pattern = table_pattern.split(".", 1)
+                    where_conditions.append(
+                        "(table_schema LIKE $1 AND table_name LIKE $2)"
+                    )
+                    params.extend([schema_pattern, table_name_pattern])
+                else:
+                    where_conditions.append(
+                        "(table_name LIKE $1 OR table_schema || '.' || table_name LIKE $1)"
+                    )
+                    params.append(table_pattern)
+
+            # Get tables
+            tables_query = f"""
+                SELECT
+                    table_schema,
+                    table_name,
+                    table_type
+                FROM information_schema.tables
+                WHERE {" AND ".join(where_conditions)}
+                ORDER BY table_schema, table_name;
+            """
+            return await conn.fetch(tables_query, *params)
+
+    async def get_columns_info(self, connection, tables: list) -> list:
+        """Get columns information for PostgreSQL."""
+        if not tables:
+            return []
+
+        pool = await connection.get_pool()
+        async with pool.acquire() as conn:
+            values_clause, params = self._build_table_filter_clause(tables)
+
+            columns_query = f"""
+                SELECT
+                    c.table_schema,
+                    c.table_name,
+                    c.column_name,
+                    c.data_type,
+                    c.is_nullable,
+                    c.column_default,
+                    c.character_maximum_length,
+                    c.numeric_precision,
+                    c.numeric_scale
+                FROM information_schema.columns c
+                WHERE (c.table_schema, c.table_name) IN (VALUES {values_clause})
+                ORDER BY c.table_schema, c.table_name, c.ordinal_position;
+            """
+            return await conn.fetch(columns_query, *params)
+
+    async def get_foreign_keys_info(self, connection, tables: list) -> list:
+        """Get foreign keys information for PostgreSQL."""
+        if not tables:
+            return []
+
+        pool = await connection.get_pool()
+        async with pool.acquire() as conn:
+            values_clause, params = self._build_table_filter_clause(tables)
+
+            fk_query = f"""
+                WITH t(schema, name) AS (VALUES {values_clause})
+                SELECT
+                    tc.table_schema,
+                    tc.table_name,
+                    kcu.column_name,
+                    ccu.table_schema AS foreign_table_schema,
+                    ccu.table_name AS foreign_table_name,
+                    ccu.column_name AS foreign_column_name
+                FROM information_schema.table_constraints AS tc
+                JOIN information_schema.key_column_usage AS kcu
+                    ON tc.constraint_name = kcu.constraint_name
+                    AND tc.table_schema = kcu.table_schema
+                JOIN information_schema.constraint_column_usage AS ccu
+                    ON ccu.constraint_name = tc.constraint_name
+                    AND ccu.table_schema = tc.table_schema
+                JOIN t ON t.schema = tc.table_schema AND t.name = tc.table_name
+                WHERE tc.constraint_type = 'FOREIGN KEY';
+            """
+            return await conn.fetch(fk_query, *params)
+
+    async def get_primary_keys_info(self, connection, tables: list) -> list:
+        """Get primary keys information for PostgreSQL."""
+        if not tables:
+            return []
+
+        pool = await connection.get_pool()
+        async with pool.acquire() as conn:
+            values_clause, params = self._build_table_filter_clause(tables)
+
+            pk_query = f"""
+                WITH t(schema, name) AS (VALUES {values_clause})
+                SELECT
+                    tc.table_schema,
+                    tc.table_name,
+                    kcu.column_name
+                FROM information_schema.table_constraints AS tc
+                JOIN information_schema.key_column_usage AS kcu
+                    ON tc.constraint_name = kcu.constraint_name
+                    AND tc.table_schema = kcu.table_schema
+                JOIN t ON t.schema = tc.table_schema AND t.name = tc.table_name
+                WHERE tc.constraint_type = 'PRIMARY KEY'
+                ORDER BY tc.table_schema, tc.table_name, kcu.ordinal_position;
+            """
+            return await conn.fetch(pk_query, *params)
+
+    async def get_indexes_info(self, connection, tables: list) -> list:
+        """Get indexes information for PostgreSQL."""
+        if not tables:
+            return []
+
+        pool = await connection.get_pool()
+        async with pool.acquire() as conn:
+            values_clause, params = self._build_table_filter_clause(tables)
+
+            idx_query = f"""
+                WITH t_filter(schema, name) AS (VALUES {values_clause})
+                SELECT
+                    ns.nspname AS table_schema,
+                    tcls.relname AS table_name,
+                    icls.relname AS index_name,
+                    ix.indisunique AS is_unique,
+                    am.amname AS index_type,
+                    string_agg(a.attname, ',' ORDER BY att.ordinality) AS column_names
+                FROM pg_class tcls
+                JOIN pg_namespace ns ON tcls.relnamespace = ns.oid
+                JOIN pg_index ix ON tcls.oid = ix.indrelid
+                JOIN pg_class icls ON icls.oid = ix.indexrelid
+                JOIN pg_am am ON icls.relam = am.oid
+                JOIN pg_attribute a ON a.attrelid = tcls.oid
+                JOIN unnest(ix.indkey) WITH ORDINALITY AS att(attnum, ordinality) ON a.attnum = att.attnum
+                JOIN t_filter ON t_filter.schema = ns.nspname AND t_filter.name = tcls.relname
+                WHERE tcls.relkind = 'r'
+                    AND icls.relname NOT LIKE '%_pkey'
+                GROUP BY ns.nspname, tcls.relname, icls.relname, ix.indisunique, am.amname
+                ORDER BY ns.nspname, tcls.relname, icls.relname;
+            """
+            return await conn.fetch(idx_query, *params)
+
+    async def list_tables_info(self, connection) -> list[dict[str, Any]]:
+        """Get list of tables with basic information for PostgreSQL."""
+        pool = await connection.get_pool()
+        async with pool.acquire() as conn:
+            # Get table names and basic info without row counts for better performance
+            tables_query = """
+                SELECT
+                    table_schema,
+                    table_name,
+                    table_type
+                FROM information_schema.tables
+                WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
+                ORDER BY table_schema, table_name;
+            """
+            tables = await conn.fetch(tables_query)
+
+            # Convert to expected format
+            return [
+                {
+                    "table_schema": table["table_schema"],
+                    "table_name": table["table_name"],
+                    "table_type": table["table_type"],
+                }
+                for table in tables
+            ]
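
For context, here is a minimal usage sketch of the new module; it is not part of the diff. The import path (sqlsaber.database.postgresql), the DSN, and the no-argument introspector constructor are assumptions, since the diff shows neither the file's location nor the base-class signatures.

import asyncio

# Hypothetical import path -- the diff does not show where the module lives.
from sqlsaber.database.postgresql import (
    PostgreSQLConnection,
    PostgreSQLSchemaIntrospector,
)


async def main() -> None:
    # sslmode=require takes the _create_ssl_context() branch that encrypts
    # the connection but skips certificate verification.
    conn = PostgreSQLConnection(
        "postgresql://app:secret@db.example.com:5432/appdb?sslmode=require"
    )
    try:
        # execute_query() always rolls back its transaction, so even a stray
        # write statement would not persist; the 5-second timeout is enforced
        # both server-side (statement_timeout) and client-side (wait_for).
        rows = await conn.execute_query("SELECT now() AS server_time", timeout=5.0)
        print(rows)

        # Assumes BaseSchemaIntrospector can be constructed without arguments.
        introspector = PostgreSQLSchemaIntrospector()
        tables = await introspector.list_tables_info(conn)
        print(tables[:5])
    finally:
        await conn.close()


asyncio.run(main())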
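The SSL behaviour of _create_ssl_context() is driven entirely by query parameters in the DSN. Continuing the sketch above (hostnames and certificate paths are made up), the mapping is roughly:

#   ?sslmode=disable, absent, or any other unlisted value -> no SSL context
#   ?sslmode=require      -> SSL, check_hostname=False, verify_mode=CERT_NONE
#   ?sslmode=verify-ca    -> SSL, check_hostname=False, verify_mode=CERT_REQUIRED
#   ?sslmode=verify-full  -> SSL, check_hostname=True,  verify_mode=CERT_REQUIRED
# sslrootcert, sslcert and sslkey are loaded into the context when supplied:
conn = PostgreSQLConnection(
    "postgresql://app@db.example.com/appdb"
    "?sslmode=verify-full&sslrootcert=/etc/ssl/ca.pem"
    "&sslcert=/etc/ssl/client.pem&sslkey=/etc/ssl/client.key"
)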
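The introspection queries all share _build_table_filter_clause(), which turns a list of schema-qualified tables into numbered asyncpg placeholders. An illustrative input/output pair (values chosen arbitrarily):

tables = [
    {"table_schema": "public", "table_name": "users"},
    {"table_schema": "billing", "table_name": "invoices"},
]
values_clause, params = PostgreSQLSchemaIntrospector()._build_table_filter_clause(tables)
# values_clause == "($1, $2), ($3, $4)"
# params        == ["public", "users", "billing", "invoices"]
# get_columns_info() then interpolates it as:
#   WHERE (c.table_schema, c.table_name) IN (VALUES ($1, $2), ($3, $4))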