sql_testing_library-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_testing_library/__init__.py +42 -0
- sql_testing_library/_adapters/__init__.py +15 -0
- sql_testing_library/_adapters/athena.py +309 -0
- sql_testing_library/_adapters/base.py +49 -0
- sql_testing_library/_adapters/bigquery.py +139 -0
- sql_testing_library/_adapters/redshift.py +219 -0
- sql_testing_library/_adapters/snowflake.py +270 -0
- sql_testing_library/_adapters/trino.py +263 -0
- sql_testing_library/_core.py +502 -0
- sql_testing_library/_exceptions.py +55 -0
- sql_testing_library/_mock_table.py +200 -0
- sql_testing_library/_pytest_plugin.py +451 -0
- sql_testing_library/_sql_utils.py +225 -0
- sql_testing_library/_types.py +142 -0
- sql_testing_library/py.typed +0 -0
- sql_testing_library-0.4.0.dist-info/LICENSE +21 -0
- sql_testing_library-0.4.0.dist-info/METADATA +956 -0
- sql_testing_library-0.4.0.dist-info/RECORD +20 -0
- sql_testing_library-0.4.0.dist-info/WHEEL +4 -0
- sql_testing_library-0.4.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,502 @@
+"""Core SQL testing framework."""
+
+from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Literal,
+    Optional,
+    Type,
+    TypeVar,
+    get_type_hints,
+)
+
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+# Heavy imports moved to function level for better performance
+from ._adapters.base import DatabaseAdapter
+from ._exceptions import (
+    MockTableNotFoundError,
+    QuerySizeLimitExceeded,
+    SQLParseError,
+    TypeConversionError,
+)
+from ._mock_table import BaseMockTable
+
+
+# Identifiers for the supported database adapters
+AdapterType = Literal["bigquery", "athena", "redshift", "trino", "snowflake"]
+
+T = TypeVar("T")
+
+
+@dataclass
+class SQLTestCase(Generic[T]):
+    """Represents a SQL test case."""
+
+    __test__ = False  # Tell pytest this is not a test class
+
+    query: str
+    default_namespace: Optional[str] = None
+    mock_tables: Optional[List[BaseMockTable]] = None
+    result_class: Optional[Type[T]] = None
+    use_physical_tables: bool = False
+    description: Optional[str] = None
+    adapter_type: Optional[AdapterType] = None
+    # Backward compatibility
+    execution_database: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        """Handle backward compatibility for execution_database parameter."""
+        if self.execution_database is not None and self.default_namespace is not None:
+            # Both provided - warn and prefer default_namespace
+            import warnings
+
+            warnings.warn(
+                "Both 'default_namespace' and 'execution_database' provided. "
+                "Using 'default_namespace'. Please migrate to 'default_namespace' only.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        elif self.execution_database is not None and self.default_namespace is None:
+            # Only execution_database provided - use it with deprecation warning
+            import warnings
+
+            warnings.warn(
+                "'execution_database' parameter is deprecated. Use 'default_namespace' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            self.default_namespace = self.execution_database
+        elif self.default_namespace is None and self.execution_database is None:
+            # Neither provided - this is an error
+            raise ValueError(
+                "Must provide either 'default_namespace' (preferred) or 'execution_database' "
+                "(deprecated) parameter"
+            )
+
+
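For orientation, a minimal sketch of building a test case against this dataclass. It assumes a BaseMockTable subclass for the 'users' table has been defined elsewhere; UserResult and users_mock are hypothetical names, and only SQLTestCase's own fields come from the code above.

    # Sketch only: UserResult and users_mock are hypothetical stand-ins.
    from dataclasses import dataclass

    @dataclass
    class UserResult:
        user_id: int
        name: str

    case = SQLTestCase(
        query="SELECT user_id, name FROM users",
        default_namespace="my_project.my_dataset",  # qualifies the bare 'users' reference
        mock_tables=[users_mock],  # a BaseMockTable instance describing 'users'
        result_class=UserResult,
        adapter_type="bigquery",
    )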
+class SQLTestFramework:
+    """Main framework for executing SQL tests."""
+
+    def __init__(self, adapter: DatabaseAdapter) -> None:
+        self.adapter = adapter
+        self.type_converter = self.adapter.get_type_converter()
+        self.temp_tables: List[str] = []
+
+    def run_test(self, test_case: SQLTestCase[T]) -> List[T]:
+        """
+        Execute a test case and return deserialized results.
+
+        Args:
+            test_case: The test case to execute
+
+        Returns:
+            List of result objects of type test_case.result_class
+        """
+        try:
+            # Validate required fields
+            if test_case.mock_tables is None:
+                raise ValueError(
+                    "mock_tables must be provided either in SQLTestCase or sql_test decorator"
+                )
+
+            if test_case.result_class is None:
+                raise ValueError(
+                    "result_class must be provided either in SQLTestCase or sql_test decorator"
+                )
+
+            # Parse SQL to find table references
+            referenced_tables = self._parse_sql_tables(test_case.query)
+
+            # Resolve unqualified table names
+            # default_namespace is guaranteed to be set by __post_init__
+            assert test_case.default_namespace is not None
+            resolved_tables = self._resolve_table_names(
+                referenced_tables, test_case.default_namespace
+            )
+
+            # Validate all required mock tables are provided
+            self._validate_mock_tables(resolved_tables, test_case.mock_tables)
+
+            # Create table name mapping
+            table_mapping = self._create_table_mapping(resolved_tables, test_case.mock_tables)
+
+            if test_case.use_physical_tables:
+                # Create physical temporary tables
+                final_query = self._execute_with_physical_tables(
+                    test_case.query, table_mapping, test_case.mock_tables
+                )
+            else:
+                # Generate query with CTEs
+                final_query = self._generate_cte_query(
+                    test_case.query, table_mapping, test_case.mock_tables
+                )
+
+            # Check size limit for adapters that need it
+            size_limit = self.adapter.get_query_size_limit()
+            if size_limit and len(final_query.encode("utf-8")) > size_limit:
+                raise QuerySizeLimitExceeded(
+                    len(final_query.encode("utf-8")),
+                    size_limit,
+                    self.adapter.__class__.__name__,
+                )
+
+            # Execute query
+            result_df = self.adapter.execute_query(final_query)
+
+            # Convert results to typed objects
+            return self._deserialize_results(result_df, test_case.result_class)
+
+        finally:
+            # Clean up any temporary tables
+            if self.temp_tables:
+                self.adapter.cleanup_temp_tables(self.temp_tables)
+                self.temp_tables = []
+
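run_test is the whole pipeline in one call: parse, resolve, validate, rewrite, size-check, execute, deserialize. A hedged sketch of driving it; make_adapter() is a placeholder since adapter construction lives in the _adapters modules and is not shown in this file.

    # Sketch: 'make_adapter()' stands in for building one of the adapters under
    # sql_testing_library._adapters; its real constructor arguments are not shown here.
    framework = SQLTestFramework(adapter=make_adapter())
    results = framework.run_test(case)  # List[UserResult], per case.result_class
    for r in results:
        print(r.user_id, r.name)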
+    def _parse_sql_tables(self, query: str) -> List[str]:
+        """Parse SQL query to extract table references."""
+        try:
+            import sqlglot
+            from sqlglot import exp
+
+            dialect = self.adapter.get_sqlglot_dialect()
+            parsed = sqlglot.parse_one(query, dialect=dialect)
+
+            # Get all CTE (WITH clause) aliases to filter them out
+            cte_aliases = set()
+            for cte in parsed.find_all(exp.CTE):
+                if hasattr(cte, "alias"):
+                    cte_aliases.add(str(cte.alias))
+
+            # Find all real tables (excluding the CTEs)
+            tables = []
+            for table in parsed.find_all(exp.Table):
+                # Skip tables that are actually CTE references
+                if str(table.name) in cte_aliases:
+                    continue
+
+                # Get the fully qualified name including catalog/schema if present
+                if table.db and table.catalog:
+                    qualified_name = f"{table.catalog}.{table.db}.{table.name}"
+                elif table.db:
+                    qualified_name = f"{table.db}.{table.name}"
+                else:
+                    qualified_name = str(table.name)
+
+                tables.append(qualified_name)
+
+            return list(set(tables))  # Remove duplicates
+
+        except Exception as e:
+            raise SQLParseError(query, str(e))  # noqa: B904
+
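The CTE-alias filtering above matters because sqlglot's find_all(exp.Table) also yields references to CTEs defined in the query itself. A standalone sketch of the same extraction, runnable against sqlglot directly (the sample query is invented):

    import sqlglot
    from sqlglot import exp

    sql = """
    WITH active AS (SELECT * FROM analytics.users WHERE is_active)
    SELECT a.user_id, o.total
    FROM active AS a JOIN orders AS o ON o.user_id = a.user_id
    """

    parsed = sqlglot.parse_one(sql, dialect="bigquery")
    cte_aliases = {str(cte.alias) for cte in parsed.find_all(exp.CTE)}

    tables = {
        ".".join(part for part in (t.catalog, t.db, t.name) if part)
        for t in parsed.find_all(exp.Table)
        if str(t.name) not in cte_aliases
    }
    print(sorted(tables))  # ['analytics.users', 'orders']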
+    def _resolve_table_names(
+        self, referenced_tables: List[str], default_namespace: str
+    ) -> Dict[str, str]:
+        """
+        Resolve unqualified table names using default namespace context.
+
+        Returns:
+            Dict mapping original table name to fully qualified name
+        """
+        resolved = {}
+        for table_name in referenced_tables:
+            if "." in table_name:
+                # Already qualified
+                resolved[table_name] = table_name
+            else:
+                # Add namespace prefix
+                qualified_name = f"{default_namespace}.{table_name}"
+                resolved[table_name] = qualified_name
+
+        return resolved
+
+    def _validate_mock_tables(
+        self, resolved_tables: Dict[str, str], mock_tables: List[BaseMockTable]
+    ) -> None:
+        """Validate that all required mock tables are provided."""
+        provided_tables = {mock.get_qualified_name() for mock in mock_tables}
+        required_tables = set(resolved_tables.values())
+
+        # Perform case-insensitive validation for all SQL databases
+        provided_tables_upper = {table.upper() for table in provided_tables}
+        missing_tables = set()
+
+        for required_table in required_tables:
+            if required_table.upper() not in provided_tables_upper:
+                missing_tables.add(required_table)
+
+        if missing_tables:
+            raise MockTableNotFoundError(
+                list(missing_tables)[0],  # Show first missing table
+                list(provided_tables),
+            )
+
+    def _create_table_mapping(
+        self, resolved_tables: Dict[str, str], mock_tables: List[BaseMockTable]
+    ) -> Dict[str, BaseMockTable]:
+        """Create mapping from qualified table names to mock table objects."""
+        mock_table_map = {mock.get_qualified_name(): mock for mock in mock_tables}
+
+        # Map original table references to mock tables using case-insensitive matching
+        table_mapping = {}
+
+        for original_name, qualified_name in resolved_tables.items():
+            # Case-insensitive matching for all SQL databases
+            matched_mock = None
+            for mock_qualified_name, mock_table in mock_table_map.items():
+                if qualified_name.upper() == mock_qualified_name.upper():
+                    matched_mock = mock_table
+                    break
+            if matched_mock:
+                table_mapping[original_name] = matched_mock
+            else:
+                # This shouldn't happen if validation passed, but fall back to exact match
+                exact_match = mock_table_map.get(qualified_name)
+                if exact_match:
+                    table_mapping[original_name] = exact_match
+
+        return table_mapping
+
+    def _generate_cte_query(
+        self,
+        query: str,
+        table_mapping: Dict[str, BaseMockTable],
+        mock_tables: List[BaseMockTable],
+    ) -> str:
+        """Generate query with CTE injections for mock data."""
+        # Generate CTEs for each mock table
+        ctes = []
+        replacement_mapping = {}
+
+        for original_name, mock_table in table_mapping.items():
+            cte_alias = mock_table.get_cte_alias()
+            cte_sql = self._generate_cte(mock_table, cte_alias)
+            ctes.append(cte_sql)
+            replacement_mapping[original_name] = cte_alias
+
+        # Replace table names in original query
+        modified_query = self._replace_table_names_in_query(query, replacement_mapping)
+
+        # Combine CTEs with original query
+        if ctes:
+            # Check if modified query already starts with WITH
+            modified_query_stripped = modified_query.strip()
+            if modified_query_stripped.upper().startswith("WITH"):
+                # Query already has WITH clause, so append our CTEs with comma
+                cte_block = ",\n".join(ctes)
+                final_query = f"WITH {cte_block},\n{modified_query_stripped[4:].strip()}"
+            else:
+                # Query doesn't have WITH clause, add it
+                cte_block = "WITH " + ",\n".join(ctes)
+                final_query = f"{cte_block}\n{modified_query}"
+        else:
+            final_query = modified_query
+
+        return final_query
+
+    def _generate_cte(self, mock_table: BaseMockTable, alias: str) -> str:
+        """Generate CTE SQL for a mock table."""
+        df = mock_table.to_dataframe()
+        column_types = mock_table.get_column_types()
+        if df.empty:
+            # Generate empty CTE
+            columns = list(column_types.keys())
+            return f"{alias} AS (SELECT {', '.join(f'NULL as {col}' for col in columns)} WHERE 1=0)"  # noqa: E501
+
+        # Get dialect to determine the correct CTE format
+        dialect = self.adapter.get_sqlglot_dialect()
+
+        if dialect in ["bigquery", "snowflake"]:
+            # BigQuery and Snowflake-specific format using UNION ALL
+            # (Snowflake VALUES clauses don't support complex expressions like ARRAY_CONSTRUCT)
+            columns = list(df.columns)
+            select_statements = []
+
+            for idx, (_, row) in enumerate(df.iterrows()):
+                if idx == 0:
+                    # First SELECT with column aliases
+                    select_expressions = []
+                    for col_name, value in row.items():
+                        col_type = column_types.get(col_name, str)
+                        formatted_value = self.adapter.format_value_for_cte(value, col_type)
+                        select_expressions.append(f"{formatted_value} AS {col_name}")
+                    select_statements.append(f"SELECT {', '.join(select_expressions)}")
+                else:
+                    # Subsequent SELECTs without aliases
+                    row_values = []
+                    for col_name, value in row.items():
+                        col_type = column_types.get(col_name, str)
+                        formatted_value = self.adapter.format_value_for_cte(value, col_type)
+                        row_values.append(formatted_value)
+                    select_statements.append(f"SELECT {', '.join(row_values)}")
+
+            union_query = "\n UNION ALL\n ".join(select_statements)
+            return f"{alias} AS (\n {union_query}\n)"
+        elif dialect == "redshift":
+            # Redshift-specific format using UNION ALL (VALUES not supported in CTEs)
+            columns = list(df.columns)
+            select_statements = []
+
+            for idx, (_, row) in enumerate(df.iterrows()):
+                if idx == 0:
+                    # First SELECT with column aliases
+                    select_expressions = []
+                    for col_name, value in row.items():
+                        col_type = column_types.get(col_name, str)
+                        formatted_value = self.adapter.format_value_for_cte(value, col_type)
+                        select_expressions.append(f"{formatted_value} AS {col_name}")
+                    select_statements.append(f"SELECT {', '.join(select_expressions)}")
+                else:
+                    # Subsequent SELECTs without aliases
+                    row_values = []
+                    for col_name, value in row.items():
+                        col_type = column_types.get(col_name, str)
+                        formatted_value = self.adapter.format_value_for_cte(value, col_type)
+                        row_values.append(formatted_value)
+                    select_statements.append(f"SELECT {', '.join(row_values)}")
+
+            union_query = "\n UNION ALL\n ".join(select_statements)
+            return f"{alias} AS (\n {union_query}\n)"
+        else:
+            # Standard SQL format using VALUES clause
+            values_rows = []
+            for _, row in df.iterrows():
+                row_values = []
+                for col_name, value in row.items():
+                    col_type = column_types.get(col_name, str)
+                    formatted_value = self.adapter.format_value_for_cte(value, col_type)
+                    row_values.append(formatted_value)
+                values_rows.append(f"({', '.join(row_values)})")
+
+            column_list = ", ".join(df.columns)
+            values_clause = ", ".join(values_rows)
+
+            return f"{alias} AS (SELECT * FROM (VALUES {values_clause}) AS t({column_list}))"
+
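To make the branches above concrete, this is roughly what each emits for a hypothetical two-row 'orders' mock; identifiers and whitespace are illustrative, not taken from the library. The VALUES form is the standard-SQL fallback (e.g. Trino/Athena), while BigQuery, Snowflake, and Redshift get the UNION ALL form because those engines reject or restrict VALUES inside a CTE. _generate_cte_query then prepends "WITH <cte>, ..." or splices the CTEs into the user's existing WITH clause.

    # Shape of the generated CTE text (hypothetical names):
    values_branch = (
        "orders_cte AS (SELECT * FROM (VALUES "
        "(1, 'alice'), (2, 'bob')) AS t(user_id, name))"
    )
    union_branch = (
        "orders_cte AS (\n"
        " SELECT 1 AS user_id, 'alice' AS name\n"
        " UNION ALL\n"
        " SELECT 2, 'bob'\n"
        ")"
    )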
+    def _replace_table_names_in_query(self, query: str, replacement_mapping: Dict[str, str]) -> str:
+        """Replace table names in query using sqlglot AST transformation."""
+        try:
+            import sqlglot
+            from sqlglot import exp
+
+            dialect = self.adapter.get_sqlglot_dialect()
+
+            # Parse the query to an AST
+            parsed = sqlglot.parse_one(query, dialect=dialect)
+
+            # Create a transformer to replace table names
+            def transform_tables(node: exp.Expression) -> exp.Expression:
+                if isinstance(node, exp.Table):
+                    # Get the original table name
+                    if node.db and node.catalog:
+                        original_name = f"{node.catalog}.{node.db}.{node.name}"
+                    elif node.db:
+                        original_name = f"{node.db}.{node.name}"
+                    else:
+                        original_name = str(node.name)
+
+                    # Check if this table should be replaced
+                    # Perform case-insensitive matching for all SQL databases
+                    replacement_name = None
+                    for mapping_key, mapping_value in replacement_mapping.items():
+                        if original_name.upper() == mapping_key.upper():
+                            replacement_name = mapping_value
+                            break
+
+                    if replacement_name:
+                        # Create a new Table node with the replacement name
+                        new_table = exp.Table(this=exp.Identifier(this=replacement_name))
+
+                        # Preserve the table alias if it exists
+                        if hasattr(node, "alias") and node.alias:
+                            new_table.set("alias", node.alias)
+
+                        return new_table
+
+                return node
+
+            # Apply the transformation to the AST
+            transformed = parsed.transform(transform_tables)
+
+            # Generate the SQL from the transformed AST
+            result_sql: str = transformed.sql(dialect=dialect)
+            return result_sql
+
+        except Exception as e:
+            # Re-raise the exception as SQLParseError to maintain compatibility
+            # with the existing error handling expectations
+            raise SQLParseError(query, str(e))  # noqa: B904
+
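The alias-preservation detail above is easy to get wrong: in sqlglot, the FROM-clause alias lives on the Table node itself, so a replacement node must carry it over or column references like o.total break. A self-contained sketch of the same idea that mutates the node in place instead of rebuilding it (sample query invented):

    import sqlglot
    from sqlglot import exp

    def rename_orders(node: exp.Expression) -> exp.Expression:
        # Point the bare 'orders' reference at a CTE alias, keeping 'AS o' intact.
        if isinstance(node, exp.Table) and node.name == "orders":
            node.set("this", exp.to_identifier("orders_cte"))
        return node

    sql = "SELECT o.total FROM orders AS o"
    print(sqlglot.parse_one(sql).transform(rename_orders).sql())
    # SELECT o.total FROM orders_cte AS o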
+    def _execute_with_physical_tables(
+        self,
+        query: str,
+        table_mapping: Dict[str, BaseMockTable],
+        mock_tables: List[BaseMockTable],
+    ) -> str:
+        """Execute query using physical temporary tables."""
+        # Create physical tables
+        replacement_mapping = {}
+
+        for original_name, mock_table in table_mapping.items():
+            temp_table_name = self.adapter.create_temp_table(mock_table)
+            self.temp_tables.append(temp_table_name)
+            replacement_mapping[original_name] = temp_table_name
+
+        # Replace table names and return modified query
+        return self._replace_table_names_in_query(query, replacement_mapping)
+
+    def _deserialize_results(self, result_df: "pd.DataFrame", result_class: Type[T]) -> List[T]:
+        """Deserialize query results to typed objects."""
+        import numpy as np
+
+        if result_df.empty:
+            return []
+
+        # STEP 1: Convert database-returned NaN values to Python None
+        #
+        # WHY THIS IS NEEDED:
+        # - SQL databases return NULL values which pandas converts to NaN
+        # - Different database adapters may return NaN for null numeric/float columns
+        # - NaN values break object serialization (dataclass/Pydantic instantiation)
+        # - Python None is the correct representation for nullable/optional fields
+        #
+        # RELATIONSHIP TO mock_table.py NaN HANDLING:
+        # - mock_table.py: Handles NaN created during DataFrame dtype conversion (input side)
+        # - core.py (here): Handles NaN returned from actual database queries (output side)
+        # - Both are needed because NaN can appear at different pipeline stages
+        result_df = result_df.replace([np.nan], [None])
+        # Get type hints from the result class
+        type_hints = get_type_hints(result_class)
+
+        results: List[T] = []
+        for _, row in result_df.iterrows():
+            # Convert row to dictionary with proper types
+            converted_row: Dict[str, Any] = {}
+            for col_name, value in row.items():
+                if col_name in type_hints:
+                    target_type = type_hints[col_name]
+                    try:
+                        converted_value = self.type_converter.convert(value, target_type)
+                        converted_row[col_name] = converted_value
+                    except Exception:
+                        raise TypeConversionError(value, target_type, col_name)  # noqa: B904
+                else:
+                    converted_row[col_name] = value
+
+            # Create instance of result class
+            try:
+                result_obj = result_class(**converted_row)
+                results.append(result_obj)
+            except Exception as e:
+                raise TypeError(  # noqa: B904
+                    f"Failed to create {result_class.__name__} instance: {e}"
+                )
+
+        return results
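The list-form replace in STEP 1 is deliberate: in older pandas, plain df.replace(np.nan, None) treats value=None as "no replacement value" and falls back to fill semantics, so the list form is the reliable spelling. A two-line demonstration of the exact call the framework makes:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"score": [1.5, np.nan]})
    cleaned = df.replace([np.nan], [None])  # same call as in _deserialize_results
    print(cleaned["score"].tolist())  # [1.5, None], object dtype, dataclass/Pydantic friendly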
@@ -0,0 +1,55 @@
+"""Custom exceptions for SQL testing library."""
+
+from typing import Any, List
+
+
+class SQLTestingError(Exception):
+    """Base exception for SQL testing library."""
+
+    pass
+
+
+class MockTableNotFoundError(SQLTestingError):
+    """Raised when a required mock table is not provided."""
+
+    def __init__(self, qualified_table_name: str, available_mocks: List[str]):
+        self.qualified_table_name = qualified_table_name
+        self.available_mocks = available_mocks
+        available_list = ", ".join(available_mocks) if available_mocks else "None"
+        super().__init__(
+            f"Mock table not found: '{qualified_table_name}'. Available: {available_list}"
+        )
+
+
+class SQLParseError(SQLTestingError):
+    """Raised when SQL parsing fails."""
+
+    def __init__(self, query: str, parse_error: str):
+        self.query = query
+        self.parse_error = parse_error
+        super().__init__(f"Failed to parse SQL: {parse_error}")
+
+
+class QuerySizeLimitExceeded(SQLTestingError):
+    """Raised when query size exceeds database limits."""
+
+    def __init__(self, actual_size: int, limit: int, adapter_name: str):
+        self.actual_size = actual_size
+        self.limit = limit
+        self.adapter_name = adapter_name
+        super().__init__(
+            f"Query size ({actual_size} bytes) exceeds {adapter_name} limit "
+            f"({limit} bytes). Consider using use_physical_tables=True"
+        )
+
+
+class TypeConversionError(SQLTestingError):
+    """Raised when type conversion fails during result deserialization."""
+
+    def __init__(self, value: Any, target_type: type, column_name: str):
+        self.value = value
+        self.target_type = target_type
+        self.column_name = column_name
+        super().__init__(
+            f"Cannot convert '{value}' to {target_type.__name__} for column '{column_name}'"
+        )
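The exception attributes are designed for programmatic handling, not just messages. A hedged sketch of a caller reacting to the size limit by retrying with physical tables, with framework and case set up as sketched earlier; the import path mirrors the file layout above, though the package __init__ may also re-export these names:

    from sql_testing_library._exceptions import (
        MockTableNotFoundError,
        QuerySizeLimitExceeded,
    )

    try:
        results = framework.run_test(case)
    except QuerySizeLimitExceeded:
        # The message already suggests the built-in escape hatch.
        case.use_physical_tables = True
        results = framework.run_test(case)
    except MockTableNotFoundError as e:
        raise AssertionError(
            f"Add a mock for {e.qualified_table_name}; have: {e.available_mocks}"
        ) from e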