sql-testing-library 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,502 @@
1
+ """Core SQL testing framework."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ Dict,
8
+ Generic,
9
+ List,
10
+ Literal,
11
+ Optional,
12
+ Type,
13
+ TypeVar,
14
+ get_type_hints,
15
+ )
16
+
17
+
18
+ if TYPE_CHECKING:
19
+ import pandas as pd
20
+
21
+ # Heavy imports moved to function level for better performance
22
+ from ._adapters.base import DatabaseAdapter
23
+ from ._exceptions import (
24
+ MockTableNotFoundError,
25
+ QuerySizeLimitExceeded,
26
+ SQLParseError,
27
+ TypeConversionError,
28
+ )
29
+ from ._mock_table import BaseMockTable
30
+
31
+
32
# Type for adapter types
# Literal of the supported database backends; pins SQLTestCase.adapter_type
# to a known adapter identifier.
AdapterType = Literal["bigquery", "athena", "redshift", "trino", "snowflake"]

# Row type produced by a test case's result_class (see SQLTestCase[T]).
T = TypeVar("T")
36
+
37
+
38
@dataclass
class SQLTestCase(Generic[T]):
    """Represents a SQL test case.

    Attributes:
        query: The SQL query under test.
        default_namespace: Namespace used to qualify unqualified table
            names referenced by the query.
        mock_tables: Mock tables supplying input data for the query.
        result_class: Class used to deserialize result rows.
        use_physical_tables: If True, mock data is loaded into physical
            temporary tables instead of injected CTEs.
        description: Optional human-readable description of the case.
        adapter_type: Optional explicit adapter selection.
        execution_database: Deprecated alias for default_namespace.
    """

    __test__ = False  # Tell pytest this is not a test class

    query: str
    default_namespace: Optional[str] = None
    mock_tables: Optional[List[BaseMockTable]] = None
    result_class: Optional[Type[T]] = None
    use_physical_tables: bool = False
    description: Optional[str] = None
    adapter_type: Optional[AdapterType] = None
    # Backward compatibility
    execution_database: Optional[str] = None

    def __post_init__(self) -> None:
        """Handle backward compatibility for the 'execution_database' parameter.

        Raises:
            ValueError: If neither 'default_namespace' nor
                'execution_database' was provided.
        """
        # Single local import (the original duplicated it in two branches);
        # kept function-local to keep module import light.
        import warnings

        if self.execution_database is not None:
            if self.default_namespace is not None:
                # Both provided - warn and prefer default_namespace
                warnings.warn(
                    "Both 'default_namespace' and 'execution_database' provided. "
                    "Using 'default_namespace'. Please migrate to 'default_namespace' only.",
                    DeprecationWarning,
                    stacklevel=2,
                )
            else:
                # Only execution_database provided - use it with deprecation warning
                warnings.warn(
                    "'execution_database' parameter is deprecated. Use 'default_namespace' instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                self.default_namespace = self.execution_database
        elif self.default_namespace is None:
            # Neither provided - this is an error
            raise ValueError(
                "Must provide either 'default_namespace' (preferred) or 'execution_database' "
                "(deprecated) parameter"
            )
82
+
83
+
84
+ class SQLTestFramework:
85
+ """Main framework for executing SQL tests."""
86
+
87
+ def __init__(self, adapter: DatabaseAdapter) -> None:
88
+ self.adapter = adapter
89
+ self.type_converter = self.adapter.get_type_converter()
90
+ self.temp_tables: List[str] = []
91
+
92
+ def run_test(self, test_case: SQLTestCase[T]) -> List[T]:
93
+ """
94
+ Execute a test case and return deserialized results.
95
+
96
+ Args:
97
+ test_case: The test case to execute
98
+
99
+ Returns:
100
+ List of result objects of type test_case.result_class
101
+ """
102
+ try:
103
+ # Validate required fields
104
+ if test_case.mock_tables is None:
105
+ raise ValueError(
106
+ "mock_tables must be provided either in SQLTestCase or sql_test decorator"
107
+ )
108
+
109
+ if test_case.result_class is None:
110
+ raise ValueError(
111
+ "result_class must be provided either in SQLTestCase or sql_test decorator"
112
+ )
113
+
114
+ # Parse SQL to find table references
115
+ referenced_tables = self._parse_sql_tables(test_case.query)
116
+
117
+ # Resolve unqualified table names
118
+ # default_namespace is guaranteed to be set by __post_init__
119
+ assert test_case.default_namespace is not None
120
+ resolved_tables = self._resolve_table_names(
121
+ referenced_tables, test_case.default_namespace
122
+ )
123
+
124
+ # Validate all required mock tables are provided
125
+ self._validate_mock_tables(resolved_tables, test_case.mock_tables)
126
+
127
+ # Create table name mapping
128
+ table_mapping = self._create_table_mapping(resolved_tables, test_case.mock_tables)
129
+
130
+ if test_case.use_physical_tables:
131
+ # Create physical temporary tables
132
+ final_query = self._execute_with_physical_tables(
133
+ test_case.query, table_mapping, test_case.mock_tables
134
+ )
135
+ else:
136
+ # Generate query with CTEs
137
+ final_query = self._generate_cte_query(
138
+ test_case.query, table_mapping, test_case.mock_tables
139
+ )
140
+
141
+ # Check size limit for adapters that need it
142
+ size_limit = self.adapter.get_query_size_limit()
143
+ if size_limit and len(final_query.encode("utf-8")) > size_limit:
144
+ raise QuerySizeLimitExceeded(
145
+ len(final_query.encode("utf-8")),
146
+ size_limit,
147
+ self.adapter.__class__.__name__,
148
+ )
149
+
150
+ # Execute query
151
+ result_df = self.adapter.execute_query(final_query)
152
+
153
+ # Convert results to typed objects
154
+ return self._deserialize_results(result_df, test_case.result_class)
155
+
156
+ finally:
157
+ # Cleanup any temporary tables
158
+ if self.temp_tables:
159
+ self.adapter.cleanup_temp_tables(self.temp_tables)
160
+ self.temp_tables = []
161
+
162
+ def _parse_sql_tables(self, query: str) -> List[str]:
163
+ """Parse SQL query to extract table references."""
164
+ try:
165
+ import sqlglot
166
+ from sqlglot import exp
167
+
168
+ dialect = self.adapter.get_sqlglot_dialect()
169
+ parsed = sqlglot.parse_one(query, dialect=dialect)
170
+
171
+ # Get all CTE (WITH clause) aliases to filter them out
172
+ cte_aliases = set()
173
+ for cte in parsed.find_all(exp.CTE):
174
+ if hasattr(cte, "alias"):
175
+ cte_aliases.add(str(cte.alias))
176
+
177
+ # Find all real tables (excluding the CTEs)
178
+ tables = []
179
+ for table in parsed.find_all(exp.Table):
180
+ # Skip tables that are actually CTE references
181
+ if str(table.name) in cte_aliases:
182
+ continue
183
+
184
+ # Get the fully qualified name including catalog/schema if present
185
+ if table.db and table.catalog:
186
+ qualified_name = f"{table.catalog}.{table.db}.{table.name}"
187
+ elif table.db:
188
+ qualified_name = f"{table.db}.{table.name}"
189
+ else:
190
+ qualified_name = str(table.name)
191
+
192
+ tables.append(qualified_name)
193
+
194
+ return list(set(tables)) # Remove duplicates
195
+
196
+ except Exception as e:
197
+ raise SQLParseError(query, str(e)) # noqa: B904
198
+
199
+ def _resolve_table_names(
200
+ self, referenced_tables: List[str], default_namespace: str
201
+ ) -> Dict[str, str]:
202
+ """
203
+ Resolve unqualified table names using default namespace context.
204
+
205
+ Returns:
206
+ Dict mapping original table name to fully qualified name
207
+ """
208
+ resolved = {}
209
+ for table_name in referenced_tables:
210
+ if "." in table_name:
211
+ # Already qualified
212
+ resolved[table_name] = table_name
213
+ else:
214
+ # Add namespace prefix
215
+ qualified_name = f"{default_namespace}.{table_name}"
216
+ resolved[table_name] = qualified_name
217
+
218
+ return resolved
219
+
220
+ def _validate_mock_tables(
221
+ self, resolved_tables: Dict[str, str], mock_tables: List[BaseMockTable]
222
+ ) -> None:
223
+ """Validate that all required mock tables are provided."""
224
+ provided_tables = {mock.get_qualified_name() for mock in mock_tables}
225
+ required_tables = set(resolved_tables.values())
226
+
227
+ # Perform case-insensitive validation for all SQL databases
228
+ provided_tables_upper = {table.upper() for table in provided_tables}
229
+ missing_tables = set()
230
+
231
+ for required_table in required_tables:
232
+ if required_table.upper() not in provided_tables_upper:
233
+ missing_tables.add(required_table)
234
+
235
+ if missing_tables:
236
+ raise MockTableNotFoundError(
237
+ list(missing_tables)[0], # Show first missing table
238
+ list(provided_tables),
239
+ )
240
+
241
+ def _create_table_mapping(
242
+ self, resolved_tables: Dict[str, str], mock_tables: List[BaseMockTable]
243
+ ) -> Dict[str, BaseMockTable]:
244
+ """Create mapping from qualified table names to mock table objects."""
245
+ mock_table_map = {mock.get_qualified_name(): mock for mock in mock_tables}
246
+
247
+ # Map original table references to mock tables using case-insensitive matching
248
+ table_mapping = {}
249
+
250
+ for original_name, qualified_name in resolved_tables.items():
251
+ # Case-insensitive matching for all SQL databases
252
+ matched_mock = None
253
+ for mock_qualified_name, mock_table in mock_table_map.items():
254
+ if qualified_name.upper() == mock_qualified_name.upper():
255
+ matched_mock = mock_table
256
+ break
257
+ if matched_mock:
258
+ table_mapping[original_name] = matched_mock
259
+ else:
260
+ # This shouldn't happen if validation passed, but fallback to exact match
261
+ exact_match = mock_table_map.get(qualified_name)
262
+ if exact_match:
263
+ table_mapping[original_name] = exact_match
264
+
265
+ return table_mapping
266
+
267
+ def _generate_cte_query(
268
+ self,
269
+ query: str,
270
+ table_mapping: Dict[str, BaseMockTable],
271
+ mock_tables: List[BaseMockTable],
272
+ ) -> str:
273
+ """Generate query with CTE injections for mock data."""
274
+ # Generate CTEs for each mock table
275
+ ctes = []
276
+ replacement_mapping = {}
277
+
278
+ for original_name, mock_table in table_mapping.items():
279
+ cte_alias = mock_table.get_cte_alias()
280
+ cte_sql = self._generate_cte(mock_table, cte_alias)
281
+ ctes.append(cte_sql)
282
+ replacement_mapping[original_name] = cte_alias
283
+
284
+ # Replace table names in original query
285
+ modified_query = self._replace_table_names_in_query(query, replacement_mapping)
286
+
287
+ # Combine CTEs with original query
288
+ if ctes:
289
+ # Check if modified query already starts with WITH
290
+ modified_query_stripped = modified_query.strip()
291
+ if modified_query_stripped.upper().startswith("WITH"):
292
+ # Query already has WITH clause, so append our CTEs with comma
293
+ cte_block = ",\n".join(ctes)
294
+ final_query = f"WITH {cte_block},\n{modified_query_stripped[4:].strip()}"
295
+ else:
296
+ # Query doesn't have WITH clause, add it
297
+ cte_block = "WITH " + ",\n".join(ctes)
298
+ final_query = f"{cte_block}\n{modified_query}"
299
+ else:
300
+ final_query = modified_query
301
+
302
+ return final_query
303
+
304
+ def _generate_cte(self, mock_table: BaseMockTable, alias: str) -> str:
305
+ """Generate CTE SQL for a mock table."""
306
+ df = mock_table.to_dataframe()
307
+ column_types = mock_table.get_column_types()
308
+ if df.empty:
309
+ # Generate empty CTE
310
+ columns = list(column_types.keys())
311
+ return f"{alias} AS (SELECT {', '.join(f'NULL as {col}' for col in columns)} WHERE 1=0)" # noqa: E501
312
+
313
+ # Get dialect to determine the correct CTE format
314
+ dialect = self.adapter.get_sqlglot_dialect()
315
+
316
+ if dialect in ["bigquery", "snowflake"]:
317
+ # BigQuery and Snowflake-specific format using UNION ALL
318
+ # (Snowflake VALUES clauses don't support complex expressions like ARRAY_CONSTRUCT)
319
+ columns = list(df.columns)
320
+ select_statements = []
321
+
322
+ for idx, (_, row) in enumerate(df.iterrows()):
323
+ if idx == 0:
324
+ # First SELECT with column aliases
325
+ select_expressions = []
326
+ for col_name, value in row.items():
327
+ col_type = column_types.get(col_name, str)
328
+ formatted_value = self.adapter.format_value_for_cte(value, col_type)
329
+ select_expressions.append(f"{formatted_value} AS {col_name}")
330
+ select_statements.append(f"SELECT {', '.join(select_expressions)}")
331
+ else:
332
+ # Subsequent SELECTs without aliases
333
+ row_values = []
334
+ for col_name, value in row.items():
335
+ col_type = column_types.get(col_name, str)
336
+ formatted_value = self.adapter.format_value_for_cte(value, col_type)
337
+ row_values.append(formatted_value)
338
+ select_statements.append(f"SELECT {', '.join(row_values)}")
339
+
340
+ union_query = "\n UNION ALL\n ".join(select_statements)
341
+ return f"{alias} AS (\n {union_query}\n)"
342
+ elif dialect == "redshift":
343
+ # Redshift-specific format using UNION ALL (VALUES not supported in CTEs)
344
+ columns = list(df.columns)
345
+ select_statements = []
346
+
347
+ for idx, (_, row) in enumerate(df.iterrows()):
348
+ if idx == 0:
349
+ # First SELECT with column aliases
350
+ select_expressions = []
351
+ for col_name, value in row.items():
352
+ col_type = column_types.get(col_name, str)
353
+ formatted_value = self.adapter.format_value_for_cte(value, col_type)
354
+ select_expressions.append(f"{formatted_value} AS {col_name}")
355
+ select_statements.append(f"SELECT {', '.join(select_expressions)}")
356
+ else:
357
+ # Subsequent SELECTs without aliases
358
+ row_values = []
359
+ for col_name, value in row.items():
360
+ col_type = column_types.get(col_name, str)
361
+ formatted_value = self.adapter.format_value_for_cte(value, col_type)
362
+ row_values.append(formatted_value)
363
+ select_statements.append(f"SELECT {', '.join(row_values)}")
364
+
365
+ union_query = "\n UNION ALL\n ".join(select_statements)
366
+ return f"{alias} AS (\n {union_query}\n)"
367
+ else:
368
+ # Standard SQL format using VALUES clause
369
+ values_rows = []
370
+ for _, row in df.iterrows():
371
+ row_values = []
372
+ for col_name, value in row.items():
373
+ col_type = column_types.get(col_name, str)
374
+ formatted_value = self.adapter.format_value_for_cte(value, col_type)
375
+ row_values.append(formatted_value)
376
+ values_rows.append(f"({', '.join(row_values)})")
377
+
378
+ column_list = ", ".join(df.columns)
379
+ values_clause = ", ".join(values_rows)
380
+
381
+ return f"{alias} AS (SELECT * FROM (VALUES {values_clause}) AS t({column_list}))"
382
+
383
+ def _replace_table_names_in_query(self, query: str, replacement_mapping: Dict[str, str]) -> str:
384
+ """Replace table names in query using sqlglot AST transformation."""
385
+ try:
386
+ import sqlglot
387
+ from sqlglot import exp
388
+
389
+ dialect = self.adapter.get_sqlglot_dialect()
390
+
391
+ # Parse the query to an AST
392
+ parsed = sqlglot.parse_one(query, dialect=dialect)
393
+
394
+ # Create a transformer to replace table names
395
+ def transform_tables(node: exp.Expression) -> exp.Expression:
396
+ if isinstance(node, exp.Table):
397
+ # Get the original table name
398
+ if node.db and node.catalog:
399
+ original_name = f"{node.catalog}.{node.db}.{node.name}"
400
+ elif node.db:
401
+ original_name = f"{node.db}.{node.name}"
402
+ else:
403
+ original_name = str(node.name)
404
+
405
+ # Check if this table should be replaced
406
+ # Perform case-insensitive matching for all SQL databases
407
+ replacement_name = None
408
+ for mapping_key, mapping_value in replacement_mapping.items():
409
+ if original_name.upper() == mapping_key.upper():
410
+ replacement_name = mapping_value
411
+ break
412
+
413
+ if replacement_name:
414
+ # Create a new Table node with the replacement name
415
+ new_table = exp.Table(this=exp.Identifier(this=replacement_name))
416
+
417
+ # Preserve the table alias if it exists
418
+ if hasattr(node, "alias") and node.alias:
419
+ new_table.set("alias", node.alias)
420
+
421
+ return new_table
422
+
423
+ return node
424
+
425
+ # Apply the transformation to the AST
426
+ transformed = parsed.transform(transform_tables)
427
+
428
+ # Generate the SQL from the transformed AST
429
+ result_sql: str = transformed.sql(dialect=dialect)
430
+ return result_sql
431
+
432
+ except Exception as e:
433
+ # Re-raise the exception as SQLParseError to maintain compatibility
434
+ # with the existing error handling expectations
435
+ raise SQLParseError(query, str(e)) # noqa: B904
436
+
437
+ def _execute_with_physical_tables(
438
+ self,
439
+ query: str,
440
+ table_mapping: Dict[str, BaseMockTable],
441
+ mock_tables: List[BaseMockTable],
442
+ ) -> str:
443
+ """Execute query using physical temporary tables."""
444
+ # Create physical tables
445
+ replacement_mapping = {}
446
+
447
+ for original_name, mock_table in table_mapping.items():
448
+ temp_table_name = self.adapter.create_temp_table(mock_table)
449
+ self.temp_tables.append(temp_table_name)
450
+ replacement_mapping[original_name] = temp_table_name
451
+
452
+ # Replace table names and return modified query
453
+ return self._replace_table_names_in_query(query, replacement_mapping)
454
+
455
+ def _deserialize_results(self, result_df: "pd.DataFrame", result_class: Type[T]) -> List[T]:
456
+ """Deserialize query results to typed objects."""
457
+ import numpy as np
458
+
459
+ if result_df.empty:
460
+ return []
461
+
462
+ # STEP 1: Convert database-returned NaN values to Python None
463
+ #
464
+ # WHY THIS IS NEEDED:
465
+ # - SQL databases return NULL values which pandas converts to NaN
466
+ # - Different database adapters may return NaN for null numeric/float columns
467
+ # - NaN values break object serialization (dataclass/Pydantic instantiation)
468
+ # - Python None is the correct representation for nullable/optional fields
469
+ #
470
+ # RELATIONSHIP TO mock_table.py NaN HANDLING:
471
+ # - mock_table.py: Handles NaN created during DataFrame dtype conversion (input side)
472
+ # - core.py (here): Handles NaN returned from actual database queries (output side)
473
+ # - Both are needed because NaN can appear at different pipeline stages
474
+ result_df = result_df.replace([np.nan], [None])
475
+ # Get type hints from the result class
476
+ type_hints = get_type_hints(result_class)
477
+
478
+ results: List[T] = []
479
+ for _, row in result_df.iterrows():
480
+ # Convert row to dictionary with proper types
481
+ converted_row: Dict[str, Any] = {}
482
+ for col_name, value in row.items():
483
+ if col_name in type_hints:
484
+ target_type = type_hints[col_name]
485
+ try:
486
+ converted_value = self.type_converter.convert(value, target_type)
487
+ converted_row[col_name] = converted_value
488
+ except Exception:
489
+ raise TypeConversionError(value, target_type, col_name) # noqa: B904
490
+ else:
491
+ converted_row[col_name] = value
492
+
493
+ # Create instance of result class
494
+ try:
495
+ result_obj = result_class(**converted_row)
496
+ results.append(result_obj)
497
+ except Exception as e:
498
+ raise TypeError( # noqa: B904
499
+ f"Failed to create {result_class.__name__} instance: {e}"
500
+ )
501
+
502
+ return results
@@ -0,0 +1,55 @@
1
+ """Custom exceptions for SQL testing library."""
2
+
3
+ from typing import Any, List
4
+
5
+
6
class SQLTestingError(Exception):
    """Base exception for SQL testing library."""


class MockTableNotFoundError(SQLTestingError):
    """Raised when a required mock table is not provided."""

    def __init__(self, qualified_table_name: str, available_mocks: List[str]):
        self.qualified_table_name = qualified_table_name
        self.available_mocks = available_mocks
        # Show "None" rather than an empty string when no mocks were given.
        available_list = ", ".join(available_mocks) if available_mocks else "None"
        super().__init__(
            f"Mock table not found: '{qualified_table_name}'. Available: {available_list}"
        )


class SQLParseError(SQLTestingError):
    """Raised when SQL parsing fails."""

    def __init__(self, query: str, parse_error: str):
        self.query = query
        self.parse_error = parse_error
        super().__init__(f"Failed to parse SQL: {parse_error}")


class QuerySizeLimitExceeded(SQLTestingError):
    """Raised when query size exceeds database limits."""

    def __init__(self, actual_size: int, limit: int, adapter_name: str):
        self.actual_size = actual_size
        self.limit = limit
        self.adapter_name = adapter_name
        super().__init__(
            f"Query size ({actual_size} bytes) exceeds {adapter_name} limit "
            f"({limit} bytes). Consider using use_physical_tables=True"
        )


class TypeConversionError(SQLTestingError):
    """Raised when type conversion fails during result deserialization."""

    def __init__(self, value: Any, target_type: type, column_name: str):
        self.value = value
        self.target_type = target_type
        self.column_name = column_name
        # typing constructs (e.g. Optional[int], produced by
        # get_type_hints in core.py) have no __name__ attribute; fall back
        # to str() so reporting a failure cannot itself raise.
        type_name = getattr(target_type, "__name__", str(target_type))
        super().__init__(
            f"Cannot convert '{value}' to {type_name} for column '{column_name}'"
        )