prismiq 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prismiq/executor.py ADDED
@@ -0,0 +1,345 @@
1
+ """Query executor for running validated queries against PostgreSQL.
2
+
3
+ This module provides the QueryExecutor class that executes queries with
4
+ timeout handling, row limits, and proper result formatting.
5
+ """
6
+
7
from __future__ import annotations

import asyncio
import ipaddress
import json
import re
import time
from datetime import date, datetime, timedelta
from datetime import time as time_type
from decimal import Decimal
from typing import TYPE_CHECKING, Any
from uuid import UUID

from prismiq.query import QueryBuilder
from prismiq.sql_validator import SQLValidationError, SQLValidator
from prismiq.types import (
    DatabaseSchema,
    QueryDefinition,
    QueryExecutionError,
    QueryResult,
    QueryTimeoutError,
    QueryValidationError,
)

if TYPE_CHECKING:
    from asyncpg import Pool  # type: ignore[import-not-found]
30
+
31
+
32
def serialize_value(value: Any) -> Any:
    """Convert a database value into a JSON-serializable Python value.

    Conversions applied:
        * ``Decimal`` -> ``float`` (precision may be lost).
        * ``datetime`` / ``date`` / ``time`` -> ISO 8601 string.
        * ``timedelta`` -> total number of seconds as ``float``.
        * ``UUID`` -> canonical string form.
        * ``ipaddress`` addresses, interfaces and networks -> string form.
        * ``bytes`` / ``bytearray`` / ``memoryview`` -> hexadecimal string.
        * ``list`` / ``tuple`` -> ``list`` with every element converted.
        * ``dict`` -> ``dict`` with every value converted (keys are kept as-is).

    Args:
        value: Value taken from a database row.

    Returns:
        The converted value. ``None`` and any type not listed above are
        returned unchanged.
    """
    if value is None:
        return None
    if isinstance(value, Decimal):
        # JSON has no decimal type; float is the closest native number.
        return float(value)
    if isinstance(value, (datetime, date, time_type)):
        return value.isoformat()
    if isinstance(value, timedelta):
        return value.total_seconds()
    if isinstance(value, UUID):
        return str(value)
    if isinstance(
        value,
        (
            # The *Interface classes subclass the *Address classes, so they
            # are covered by these four entries as well.
            ipaddress.IPv4Address,
            ipaddress.IPv6Address,
            ipaddress.IPv4Network,
            ipaddress.IPv6Network,
        ),
    ):
        return str(value)
    if isinstance(value, (bytes, bytearray, memoryview)):
        return value.hex()
    if isinstance(value, (list, tuple)):
        return [serialize_value(item) for item in value]
    if isinstance(value, dict):
        return {key: serialize_value(item) for key, item in value.items()}
    return value
54
+
55
+
56
class QueryExecutor:
    """Executes validated queries against a PostgreSQL database.

    Handles query validation, timeout enforcement, row limits,
    and result formatting.

    Example:
        >>> executor = QueryExecutor(pool, schema, query_timeout=30.0)
        >>> result = await executor.execute(query_definition)
        >>> print(result.row_count)
        100
    """

    # Matches ``:name`` placeholders. The negative lookbehind skips the second
    # colon of a PostgreSQL ``::type`` cast so that ``col::text`` is left alone.
    _NAMED_PARAM = re.compile(r"(?<!:):([a-zA-Z_][a-zA-Z0-9_]*)")

    def __init__(
        self,
        pool: Pool,
        schema: DatabaseSchema,
        query_timeout: float = 30.0,
        max_rows: int = 10000,
        schema_name: str | None = None,
    ) -> None:
        """Initialize the query executor.

        Args:
            pool: asyncpg connection pool.
            schema: Database schema for validation.
            query_timeout: Maximum query execution time in seconds.
            max_rows: Maximum number of rows to return.
            schema_name: PostgreSQL schema name for schema-qualified SQL generation.
                If None, tables are referenced without schema prefix.
        """
        self._pool = pool
        self._schema = schema
        self._query_timeout = query_timeout
        self._max_rows = max_rows
        self._schema_name = schema_name
        self._builder = QueryBuilder(schema, schema_name=schema_name)
        self._sql_validator = SQLValidator(schema)

    async def execute(self, query: QueryDefinition) -> QueryResult:
        """Execute a query and return results.

        The query's filters are sanitized, the query is validated, and the
        generated SQL is run under the configured timeout. When the query has
        no limit, or a limit above ``max_rows``, it is run with a limit of
        ``max_rows + 1`` so that truncation can be detected; at most
        ``max_rows`` rows are returned.

        Args:
            query: Query definition to execute.

        Returns:
            QueryResult with columns, rows, and execution metadata.

        Raises:
            QueryValidationError: If the query fails validation.
            QueryTimeoutError: If the query exceeds the timeout.
            QueryExecutionError: If the query execution fails.
        """
        # Sanitize filters - remove filters referencing non-existent columns
        query = self._builder.sanitize_filters(query)

        errors = self._builder.validate(query)
        if errors:
            raise QueryValidationError("Query validation failed", errors=errors)

        # Decide on the effective limit first so the SQL is built only once.
        cap_applied = query.limit is None or query.limit > self._max_rows
        if cap_applied:
            # Fetch one row beyond the cap: its presence proves truncation.
            query = query.model_copy(update={"limit": self._max_rows + 1})
        sql, params = self._builder.build(query)

        rows, execution_time_ms = await self._fetch_timed(sql, params, reported_sql=sql)

        truncated = cap_applied and len(rows) > self._max_rows
        if truncated:
            rows = rows[: self._max_rows]

        return self._format_result(rows, execution_time_ms, truncated)

    async def preview(self, query: QueryDefinition, limit: int = 100) -> QueryResult:
        """Execute a query with a smaller limit for quick preview.

        Args:
            query: Query definition to execute.
            limit: Maximum rows to return (default: 100). Values above
                ``max_rows`` are clamped to ``max_rows``.

        Returns:
            QueryResult with limited rows.
        """
        preview_limit = min(limit, self._max_rows)
        preview_query = query.model_copy(update={"limit": preview_limit})
        return await self.execute(preview_query)

    async def explain(self, query: QueryDefinition) -> dict[str, Any]:
        """Run EXPLAIN ANALYZE on a query.

        ``EXPLAIN ANALYZE`` actually executes the statement, so it is run
        under the same timeout as regular queries. Filters are sanitized the
        same way ``execute`` does, so the plan describes the query that
        ``execute`` would run.

        Args:
            query: Query definition to analyze.

        Returns:
            Query plan as a dictionary. If the database output is not a
            non-empty list, it is returned wrapped as ``{"plan": output}``.

        Raises:
            QueryValidationError: If the query fails validation.
            QueryExecutionError: If the explain fails or exceeds the timeout.
        """
        query = self._builder.sanitize_filters(query)

        errors = self._builder.validate(query)
        if errors:
            raise QueryValidationError("Query validation failed", errors=errors)

        sql, params = self._builder.build(query)
        explain_sql = f"EXPLAIN (ANALYZE, FORMAT JSON) {sql}"

        try:
            rows = await self._execute_with_timeout(explain_sql, params)
        except asyncio.TimeoutError as e:
            # Reported as QueryExecutionError to keep this method's documented
            # exception contract unchanged.
            raise QueryExecutionError(
                f"EXPLAIN exceeded timeout of {self._query_timeout} seconds",
                sql=explain_sql,
            ) from e
        except Exception as e:
            raise QueryExecutionError(str(e), sql=explain_sql) from e

        # Equivalent of fetchval(): first column of the first row, if any.
        raw_plan = rows[0][0] if rows else None
        return self._parse_explain_output(raw_plan)

    async def execute_raw_sql(
        self,
        sql: str,
        params: dict[str, Any] | None = None,
        tenant_id: str | None = None,
        tenant_column: str = "tenant_id",
    ) -> QueryResult:
        """Execute a raw SQL query with validation.

        The SQL is validated, wrapped in a CTE with a ``max_rows + 1`` limit,
        and executed under the configured timeout. Named ``:name`` parameters
        are converted to asyncpg positional ``$n`` parameters; ``::type``
        casts are left untouched. Placeholder detection is purely textual and
        does not parse string literals.

        NOTE(review): ``tenant_id`` and ``tenant_column`` are accepted for
        interface compatibility but are not applied anywhere in this method,
        so no tenant filtering happens here. Confirm whether callers rely on
        row-level filtering being enforced at this layer.

        Args:
            sql: Raw SQL query (must be SELECT only).
            params: Optional named parameters for the query.
            tenant_id: Optional tenant ID (currently unused, see note above).
            tenant_column: Tenant column name (currently unused, see note above).

        Returns:
            QueryResult with columns, rows, and execution metadata.

        Raises:
            SQLValidationError: If the SQL fails validation or references a
                named parameter that was not provided.
            QueryTimeoutError: If the query exceeds the timeout.
            QueryExecutionError: If the query execution fails.
        """
        validation = self._sql_validator.validate(sql)
        if not validation.valid:
            raise SQLValidationError(
                "SQL validation failed: " + "; ".join(validation.errors),
                errors=validation.errors,
            )

        # Use the sanitized SQL from validation. An explicit check is used
        # instead of ``assert`` because asserts are stripped under ``-O``.
        safe_sql = validation.sanitized_sql
        if safe_sql is None:
            raise SQLValidationError(
                "SQL validation failed: no sanitized SQL produced",
                errors=["Validator returned no sanitized SQL"],
            )

        # Apply row limit using a CTE wrapper (one extra row detects truncation)
        limited_sql = f"WITH _cte AS ({safe_sql}) SELECT * FROM _cte LIMIT {self._max_rows + 1}"

        # asyncpg uses $1, $2, etc. for positional params
        param_list: list[Any] = []
        if params:
            limited_sql, param_list = self._bind_named_params(limited_sql, params)

        # Errors report the caller's original SQL, not the wrapped statement.
        rows, execution_time_ms = await self._fetch_timed(
            limited_sql, param_list, reported_sql=sql
        )

        truncated = len(rows) > self._max_rows
        if truncated:
            rows = rows[: self._max_rows]

        return self._format_result(rows, execution_time_ms, truncated)

    @classmethod
    def _bind_named_params(cls, sql: str, params: dict[str, Any]) -> tuple[str, list[Any]]:
        """Rewrite ``:name`` placeholders to ``$n`` and order the arguments.

        Each distinct name gets one positional index, assigned in order of
        first appearance; repeated names reuse their index.

        Args:
            sql: SQL text containing ``:name`` placeholders.
            params: Mapping of parameter name to value.

        Returns:
            Tuple of the rewritten SQL and the positional argument list.

        Raises:
            SQLValidationError: If the SQL references a name missing from
                ``params``.
        """
        positions: dict[str, int] = {}

        def to_positional(match: re.Match[str]) -> str:
            name = match.group(1)
            # len() is evaluated before insertion, giving 1-based indices.
            return f"${positions.setdefault(name, len(positions) + 1)}"

        rewritten = cls._NAMED_PARAM.sub(to_positional, sql)

        # ``positions`` preserves insertion order, which equals index order.
        for name in positions:
            if name not in params:
                raise SQLValidationError(
                    f"Missing parameter: {name}",
                    errors=[f"Parameter '{name}' referenced in SQL but not provided"],
                )
        return rewritten, [params[name] for name in positions]

    @staticmethod
    def _parse_explain_output(raw: Any) -> dict[str, Any]:
        """Normalize the value returned by ``EXPLAIN (FORMAT JSON)``.

        The value may arrive as JSON text or as an already decoded list.

        Args:
            raw: First column of the first EXPLAIN row, or None.

        Returns:
            The first element when the (decoded) output is a non-empty list,
            otherwise ``{"plan": output}``.
        """
        plan = raw
        if isinstance(plan, (str, bytes, bytearray)):
            try:
                plan = json.loads(plan)
            except ValueError:
                # Not JSON after all; hand back the raw text unchanged.
                plan = raw
        if isinstance(plan, list) and plan:
            return plan[0]
        return {"plan": plan}

    async def _fetch_timed(
        self, sql: str, params: list[Any], *, reported_sql: str
    ) -> tuple[list[Any], float]:
        """Run a statement, time it, and translate failures.

        Args:
            sql: Statement to execute.
            params: Positional arguments for the statement.
            reported_sql: SQL text attached to ``QueryExecutionError``.

        Returns:
            Tuple of the fetched rows and the elapsed time in milliseconds.

        Raises:
            QueryTimeoutError: If the statement exceeds the timeout.
            QueryExecutionError: If execution fails for any other reason.
        """
        started = time.perf_counter()
        try:
            rows = await self._execute_with_timeout(sql, params)
        except asyncio.TimeoutError as e:
            raise QueryTimeoutError(
                f"Query exceeded timeout of {self._query_timeout} seconds",
                timeout_seconds=self._query_timeout,
            ) from e
        except Exception as e:
            raise QueryExecutionError(str(e), sql=reported_sql) from e
        elapsed_ms = (time.perf_counter() - started) * 1000
        return rows, elapsed_ms

    async def _execute_with_timeout(self, sql: str, params: list[Any]) -> list[Any]:
        """Execute SQL with timeout.

        A server-side ``statement_timeout`` is set on the acquired connection
        and the fetch is additionally bounded client-side by
        ``asyncio.wait_for``. The setting is restored afterwards.

        Args:
            sql: Statement to execute.
            params: Positional arguments for the statement.

        Returns:
            The rows returned by the connection's ``fetch``.

        Raises:
            asyncio.TimeoutError: If the client-side timeout elapses.
        """
        async with self._pool.acquire() as conn:
            timeout_ms = int(self._query_timeout * 1000)
            await conn.execute(f"SET statement_timeout = {timeout_ms}")

            try:
                rows = await asyncio.wait_for(
                    conn.fetch(sql, *params),
                    timeout=self._query_timeout,
                )
            except BaseException:
                # Best-effort cleanup: a failure here must not replace the
                # error already propagating (e.g. hide a TimeoutError).
                try:
                    await conn.execute("RESET statement_timeout")
                except Exception:
                    pass
                raise

            # RESET restores the session default instead of forcing 0
            # (which would disable any configured default timeout).
            await conn.execute("RESET statement_timeout")
            return rows

    @staticmethod
    def _infer_column_types(columns: list[str], rows: list[Any]) -> list[str]:
        """Derive a Python type name for each column.

        Args:
            columns: Column names, in result order.
            rows: Rows indexable by column name.

        Returns:
            For each column, the type name of its first non-None value, or
            ``"unknown"`` when every value in the column is None.
        """
        inferred: list[str] = []
        for name in columns:
            sample = next((row[name] for row in rows if row[name] is not None), None)
            inferred.append("unknown" if sample is None else type(sample).__name__)
        return inferred

    def _format_result(
        self, rows: list[Any], execution_time_ms: float, truncated: bool
    ) -> QueryResult:
        """Format raw database rows into QueryResult.

        Args:
            rows: Rows as returned by the connection's ``fetch``.
            execution_time_ms: Elapsed execution time in milliseconds.
            truncated: Whether rows were cut off at the row limit.

        Returns:
            QueryResult with column names, column type names, and rows whose
            values are JSON-serializable. An empty input yields a result with
            no columns, since column names are read from the first row.
        """
        if not rows:
            return QueryResult(
                columns=[],
                column_types=[],
                rows=[],
                row_count=0,
                truncated=False,
                execution_time_ms=execution_time_ms,
            )

        # Must use keys(): iterating a record directly yields values.
        columns = list(rows[0].keys())
        column_types = self._infer_column_types(columns, rows)

        # Convert rows to lists with JSON-serializable values
        result_rows = [[serialize_value(v) for v in row.values()] for row in rows]

        return QueryResult(
            columns=columns,
            column_types=column_types,
            rows=result_rows,
            row_count=len(result_rows),
            truncated=truncated,
            execution_time_ms=execution_time_ms,
        )