sql-testing-library 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_testing_library/__init__.py +42 -0
- sql_testing_library/_adapters/__init__.py +15 -0
- sql_testing_library/_adapters/athena.py +309 -0
- sql_testing_library/_adapters/base.py +49 -0
- sql_testing_library/_adapters/bigquery.py +139 -0
- sql_testing_library/_adapters/redshift.py +219 -0
- sql_testing_library/_adapters/snowflake.py +270 -0
- sql_testing_library/_adapters/trino.py +263 -0
- sql_testing_library/_core.py +502 -0
- sql_testing_library/_exceptions.py +55 -0
- sql_testing_library/_mock_table.py +200 -0
- sql_testing_library/_pytest_plugin.py +451 -0
- sql_testing_library/_sql_utils.py +225 -0
- sql_testing_library/_types.py +142 -0
- sql_testing_library/py.typed +0 -0
- sql_testing_library-0.4.0.dist-info/LICENSE +21 -0
- sql_testing_library-0.4.0.dist-info/METADATA +956 -0
- sql_testing_library-0.4.0.dist-info/RECORD +20 -0
- sql_testing_library-0.4.0.dist-info/WHEEL +4 -0
- sql_testing_library-0.4.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""SQL Testing Library - Test SQL queries with mock data injection."""
|
|
2
|
+
|
|
3
|
+
# Import from private modules (leading underscore indicates internal use)
|
|
4
|
+
from ._adapters.base import DatabaseAdapter # noqa: F401
|
|
5
|
+
from ._core import SQLTestCase, SQLTestFramework # noqa: F401
|
|
6
|
+
from ._exceptions import (
|
|
7
|
+
MockTableNotFoundError, # noqa: F401
|
|
8
|
+
QuerySizeLimitExceeded, # noqa: F401
|
|
9
|
+
SQLParseError, # noqa: F401
|
|
10
|
+
SQLTestingError, # noqa: F401
|
|
11
|
+
TypeConversionError, # noqa: F401
|
|
12
|
+
)
|
|
13
|
+
from ._mock_table import BaseMockTable # noqa: F401
|
|
14
|
+
from ._pytest_plugin import sql_test # noqa: F401
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Backward compatibility alias
|
|
18
|
+
TestCase = SQLTestCase
|
|
19
|
+
|
|
20
|
+
# Import adapters if their dependencies are available
|
|
21
|
+
try:
|
|
22
|
+
from ._adapters.bigquery import BigQueryAdapter
|
|
23
|
+
|
|
24
|
+
__all__ = ["BigQueryAdapter"]
|
|
25
|
+
except ImportError:
|
|
26
|
+
__all__ = []
|
|
27
|
+
|
|
28
|
+
__version__ = "0.3.0"
|
|
29
|
+
__all__.extend(
|
|
30
|
+
[
|
|
31
|
+
"SQLTestFramework",
|
|
32
|
+
"TestCase",
|
|
33
|
+
"BaseMockTable",
|
|
34
|
+
"DatabaseAdapter",
|
|
35
|
+
"sql_test",
|
|
36
|
+
"SQLTestingError",
|
|
37
|
+
"MockTableNotFoundError",
|
|
38
|
+
"SQLParseError",
|
|
39
|
+
"QuerySizeLimitExceeded",
|
|
40
|
+
"TypeConversionError",
|
|
41
|
+
]
|
|
42
|
+
)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Database adapters for SQL testing library."""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Lazy import adapters - only import when explicitly requested
|
|
7
|
+
# This prevents loading all heavy database SDKs when just importing the base adapter
|
|
8
|
+
__all__: List[str] = []
|
|
9
|
+
|
|
10
|
+
# Individual adapters can be imported directly:
|
|
11
|
+
# from sql_testing_library._adapters.bigquery import BigQueryAdapter
|
|
12
|
+
# from sql_testing_library._adapters.athena import AthenaAdapter
|
|
13
|
+
# from sql_testing_library._adapters.redshift import RedshiftAdapter
|
|
14
|
+
# from sql_testing_library._adapters.trino import TrinoAdapter
|
|
15
|
+
# from sql_testing_library._adapters.snowflake import SnowflakeAdapter
|
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
"""Amazon Athena adapter implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from datetime import date, datetime
|
|
6
|
+
from decimal import Decimal
|
|
7
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
import boto3
|
|
14
|
+
|
|
15
|
+
# Heavy import moved to function level for better performance
|
|
16
|
+
from .._mock_table import BaseMockTable
|
|
17
|
+
from .._types import BaseTypeConverter
|
|
18
|
+
from .base import DatabaseAdapter
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
HAS_BOTO3 = True
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
# This is a separate import to keep the module type
|
|
25
|
+
# for type checking, even if the module fails to import
|
|
26
|
+
import boto3 as _boto3_module # noqa: F401
|
|
27
|
+
except ImportError:
|
|
28
|
+
HAS_BOTO3 = False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AthenaTypeConverter(BaseTypeConverter):
    """Type converter with Athena-specific NULL handling."""

    def convert(self, value: Any, target_type: Type) -> Any:
        """Convert an Athena result value to ``target_type``.

        Athena's result API encodes SQL NULL as the literal string "NULL";
        translate that to None first, then defer to the base converter,
        which covers the proper Python types Athena already returns.
        """
        if value == "NULL":
            return None
        return BaseTypeConverter.convert(self, value, target_type)
|
44
|
+
class AthenaAdapter(DatabaseAdapter):
    """Amazon Athena adapter for SQL testing.

    Runs queries through the boto3 Athena client and materializes mock
    tables as Parquet-backed tables rooted at ``s3_output_location``.
    """

    def __init__(
        self,
        database: str,
        s3_output_location: str,
        region: str = "us-west-2",
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ) -> None:
        """Initialize the Athena client.

        Args:
            database: Database used as the query execution context.
            s3_output_location: S3 URI where Athena writes query results.
            region: AWS region for the Athena client.
            aws_access_key_id: Optional explicit access key; when omitted,
                the default boto3 credential chain is used.
            aws_secret_access_key: Secret key paired with the access key.

        Raises:
            ImportError: If boto3 is not installed.
        """
        if not HAS_BOTO3:
            raise ImportError(
                "Athena adapter requires boto3. "
                "Install with: pip install sql-testing-library[athena]"
            )

        self.database = database
        self.s3_output_location = s3_output_location
        self.region = region

        # Initialize Athena client; without an explicit key pair, boto3
        # falls back to ~/.aws/credentials or environment variables.
        if aws_access_key_id and aws_secret_access_key:
            self.client = boto3.client(
                "athena",
                region_name=region,
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
            )
        else:
            self.client = boto3.client("athena", region_name=region)

    def get_sqlglot_dialect(self) -> str:
        """Return Presto dialect for sqlglot (Athena uses Presto SQL)."""
        return "presto"

    def execute_query(self, query: str) -> "pd.DataFrame":
        """Execute query and return results as DataFrame.

        Raises:
            Exception: If the query does not reach the SUCCEEDED state.
        """
        import pandas as pd

        # Start query execution
        response = self.client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_location},
        )
        query_execution_id = response["QueryExecutionId"]

        # Block until the query reaches a terminal state.
        query_status, error_info = self._wait_for_query_with_error(query_execution_id)
        if query_status != "SUCCEEDED":
            error_message = f"Athena query failed with status: {query_status}"
            if error_info:
                error_message += f"; Error details: {error_info}"
            raise Exception(error_message)

        # Get query results
        results = self.client.get_query_results(QueryExecutionId=query_execution_id)

        # Convert to DataFrame; Athena's first result row is the header.
        if "ResultSet" in results and "Rows" in results["ResultSet"]:
            rows = results["ResultSet"]["Rows"]
            if not rows:
                return pd.DataFrame()

            header = [col["VarCharValue"] for col in rows[0]["Data"]]
            # A missing "VarCharValue" key marks a NULL cell.
            data = [[col.get("VarCharValue") for col in row["Data"]] for row in rows[1:]]
            return pd.DataFrame(data, columns=header)
        else:
            return pd.DataFrame()

    def create_temp_table(self, mock_table: BaseMockTable) -> str:
        """Create a temporary table in Athena using CTAS.

        Returns:
            Fully qualified ``database.table`` name of the new table.
        """
        # Millisecond timestamp keeps concurrently created tables unique.
        timestamp = int(time.time() * 1000)
        temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
        qualified_table_name = f"{self.database}.{temp_table_name}"

        # Generate and run the CTAS statement (CREATE TABLE AS SELECT).
        ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
        self.execute_query(ctas_sql)

        return qualified_table_name

    def cleanup_temp_tables(self, table_names: List[str]) -> None:
        """Drop temporary tables, logging (not raising) on failure."""
        for full_table_name in table_names:
            try:
                # DROP executes in the self.database context, so strip any
                # "database." qualifier from the name.
                table_name = full_table_name.split(".")[-1]
                drop_query = f"DROP TABLE IF EXISTS {table_name}"
                self.execute_query(drop_query)
            except Exception as e:
                logging.warning(f"Warning: Failed to drop table {full_table_name}: {e}")

    def format_value_for_cte(self, value: Any, column_type: type) -> str:
        """Format value for Athena CTE VALUES clause."""
        from .._sql_utils import format_sql_value

        return format_sql_value(value, column_type, dialect="athena")

    def get_type_converter(self) -> BaseTypeConverter:
        """Get Athena-specific type converter."""
        return AthenaTypeConverter()

    def get_query_size_limit(self) -> Optional[int]:
        """Return query size limit in bytes for Athena."""
        # Athena has a 256KB limit for query strings
        return 256 * 1024  # 256KB

    def _wait_for_query(self, query_execution_id: str, max_retries: int = 60) -> str:
        """Wait for query to complete, returns final status."""
        status, _ = self._wait_for_query_with_error(query_execution_id, max_retries)
        return status

    def _wait_for_query_with_error(
        self, query_execution_id: str, max_retries: int = 60
    ) -> tuple[str, Optional[str]]:
        """Poll a query once per second until it finishes or times out.

        Returns:
            Tuple of (final status, error details or None). On timeout the
            status is the sentinel "TIMEOUT".
        """
        for _ in range(max_retries):
            response = self.client.get_query_execution(QueryExecutionId=query_execution_id)
            query_execution = response["QueryExecution"]
            # Explicitly cast to string to satisfy the type checker.
            query_status: str = str(query_execution["Status"]["State"])

            if query_status in ("SUCCEEDED", "FAILED", "CANCELLED"):
                error_info = None
                if query_status in ("FAILED", "CANCELLED"):
                    # Extract whatever error information Athena reported.
                    status_info = query_execution["Status"]
                    if "StateChangeReason" in status_info:
                        error_info = status_info["StateChangeReason"]
                    elif "AthenaError" in status_info:
                        athena_error = status_info["AthenaError"]
                        error_type = athena_error.get("ErrorType", "Unknown")
                        error_message = athena_error.get("ErrorMessage", "No details available")
                        error_info = f"{error_type}: {error_message}"

                return query_status, error_info

            # Wait before checking again
            time.sleep(1)

        # If we reached here, we timed out
        return "TIMEOUT", "Query execution timed out after waiting for completion"

    def _build_s3_location(self, table_name: str) -> str:
        """Build proper S3 location path avoiding double slashes."""
        base_location = self.s3_output_location.rstrip("/")
        return f"{base_location}/{table_name}/"

    @staticmethod
    def _unwrap_optional(col_type: type) -> type:
        """Return T for Optional[T]; any other type is returned unchanged."""
        if hasattr(col_type, "__origin__") and col_type.__origin__ is Union:
            non_none_types = [arg for arg in get_args(col_type) if arg is not type(None)]
            if non_none_types:
                return non_none_types[0]
        return col_type

    def _format_row(self, row: Any, columns: List[str], column_types: Any) -> List[str]:
        """Format one DataFrame row as a list of SQL literal strings."""
        return [
            self.format_value_for_cte(
                row[col_name], self._unwrap_optional(column_types.get(col_name, str))
            )
            for col_name in columns
        ]

    def _generate_ctas_sql(self, table_name: str, mock_table: BaseMockTable) -> str:
        """Generate the DDL that materializes ``mock_table`` in Athena.

        Empty tables become an empty external Parquet table with the right
        schema; non-empty tables use CTAS over a SELECT ... UNION ALL chain
        of literal rows.
        """
        df = mock_table.to_dataframe()
        column_types = mock_table.get_column_types()
        columns = list(df.columns)

        if df.empty:
            # Type mapping from Python types to Athena types; unknown types
            # fall back to VARCHAR.
            type_mapping = {
                str: "VARCHAR",
                int: "INTEGER",
                float: "DOUBLE",
                bool: "BOOLEAN",
                date: "DATE",
                datetime: "TIMESTAMP",
                Decimal: "DECIMAL(38,9)",
            }

            # Generate column definitions
            column_defs = []
            for col_name, col_type in column_types.items():
                athena_type = type_mapping.get(self._unwrap_optional(col_type), "VARCHAR")
                column_defs.append(f'"{col_name}" {athena_type}')

            columns_sql = ",\n          ".join(column_defs)

            # Create an empty external table with the correct schema
            return f"""
        CREATE EXTERNAL TABLE {table_name} (
          {columns_sql}
        )
        STORED AS PARQUET
        LOCATION '{self._build_s3_location(table_name)}'
        """
        else:
            # First row carries the column aliases; subsequent rows are
            # appended via UNION ALL.
            first_row_values = self._format_row(df.iloc[0], columns, column_types)
            select_expressions = [
                f'{value} AS "{col_name}"'
                for value, col_name in zip(first_row_values, columns)
            ]
            select_sql = f"SELECT {', '.join(select_expressions)}"

            for i in range(1, len(df)):
                row_values = self._format_row(df.iloc[i], columns, column_types)
                select_sql += f"\nUNION ALL SELECT {', '.join(row_values)}"

            # Create the CTAS statement for external table
            return f"""
        CREATE TABLE {table_name}
        WITH (
            format = 'PARQUET',
            external_location = '{self._build_s3_location(table_name)}'
        )
        AS {select_sql}
        """
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Base database adapter interface."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import TYPE_CHECKING, Any, List, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
# Heavy import moved to function level for better performance
|
|
11
|
+
from .._mock_table import BaseMockTable
|
|
12
|
+
from .._types import BaseTypeConverter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DatabaseAdapter(ABC):
    """Abstract base class for database adapters.

    Concrete adapters implement query execution, temporary-table lifecycle
    management, and SQL literal formatting for one database engine.
    """

    @abstractmethod
    def get_sqlglot_dialect(self) -> str:
        """Return the sqlglot dialect string for this database."""
        ...

    @abstractmethod
    def execute_query(self, query: str) -> "pd.DataFrame":
        """Run ``query`` and return the result set as a DataFrame."""
        ...

    @abstractmethod
    def create_temp_table(self, mock_table: BaseMockTable) -> str:
        """Materialize mock data as a temporary table; return its name."""
        ...

    @abstractmethod
    def cleanup_temp_tables(self, table_names: List[str]) -> None:
        """Drop the given temporary tables."""
        ...

    @abstractmethod
    def format_value_for_cte(self, value: Any, column_type: type) -> str:
        """Render ``value`` as a SQL literal for a CTE VALUES clause."""
        ...

    def get_type_converter(self) -> BaseTypeConverter:
        """Return this adapter's type converter; override for custom rules."""
        return BaseTypeConverter()

    def get_query_size_limit(self) -> Optional[int]:
        """Return the query size limit in bytes, or None when unlimited."""
        return None
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""BigQuery adapter implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import date, datetime
|
|
5
|
+
from decimal import Decimal
|
|
6
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
# Heavy imports moved to function level for better performance
|
|
13
|
+
from google.cloud import bigquery
|
|
14
|
+
|
|
15
|
+
from .._mock_table import BaseMockTable
|
|
16
|
+
from .._types import BaseTypeConverter
|
|
17
|
+
from .base import DatabaseAdapter
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
HAS_BIGQUERY = True
|
|
21
|
+
|
|
22
|
+
# The duplicate import is intentional
|
|
23
|
+
# First import is to get the types, second is to actually import the module
|
|
24
|
+
# If the second fails, we set HAS_BIGQUERY to False to handle it gracefully
|
|
25
|
+
try:
|
|
26
|
+
# This is a separate import to keep the module type
|
|
27
|
+
# for type checking, even if the module fails to import
|
|
28
|
+
import google.cloud.bigquery as _bigquery_module # noqa: F401
|
|
29
|
+
except ImportError:
|
|
30
|
+
HAS_BIGQUERY = False
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BigQueryTypeConverter(BaseTypeConverter):
    """BigQuery-specific type converter."""

    def convert(self, value: Any, target_type: Type) -> Any:
        """Convert a BigQuery result value to ``target_type``.

        The BigQuery client already yields proper Python objects, so the
        base converter's rules apply unchanged.
        """
        return BaseTypeConverter.convert(self, value, target_type)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class BigQueryAdapter(DatabaseAdapter):
    """Google BigQuery adapter for SQL testing."""

    def __init__(
        self, project_id: str, dataset_id: str, credentials_path: Optional[str] = None
    ) -> None:
        """Create a BigQuery client for ``project_id`` / ``dataset_id``.

        When ``credentials_path`` is given, credentials come from that
        service-account JSON file; otherwise the default application
        credentials are used.

        Raises:
            ImportError: If google-cloud-bigquery is not installed.
        """
        if not HAS_BIGQUERY:
            raise ImportError(
                "BigQuery adapter requires google-cloud-bigquery. "
                "Install with: pip install sql-testing-library[bigquery]"
            )

        self.project_id = project_id
        self.dataset_id = dataset_id

        if credentials_path:
            self.client = bigquery.Client.from_service_account_json(credentials_path)
        else:
            self.client = bigquery.Client(project=project_id)

    def get_sqlglot_dialect(self) -> str:
        """Return BigQuery dialect for sqlglot."""
        return "bigquery"

    def execute_query(self, query: str) -> "pd.DataFrame":
        """Execute query and return results as DataFrame."""
        return self.client.query(query).to_dataframe()

    def create_temp_table(self, mock_table: BaseMockTable) -> str:
        """Create a temporary BigQuery table holding the mock table's rows."""
        import time

        temp_table_name = f"temp_{mock_table.get_table_name()}_{int(time.time() * 1000)}"
        table_id = f"{self.project_id}.{self.dataset_id}.{temp_table_name}"

        # Create the table with an explicit schema first, then bulk-load rows.
        table = self.client.create_table(
            bigquery.Table(table_id, schema=self._get_bigquery_schema(mock_table))
        )

        frame = mock_table.to_dataframe()
        if not frame.empty:
            load_job = self.client.load_table_from_dataframe(
                frame, table, job_config=bigquery.LoadJobConfig()
            )
            load_job.result()  # Block until the load finishes.

        return table_id

    def cleanup_temp_tables(self, table_names: List[str]) -> None:
        """Delete temporary tables, logging failures instead of raising."""
        for table_name in table_names:
            try:
                self.client.delete_table(table_name)
            except Exception as e:
                logging.warning(f"Warning: Failed to delete table {table_name}: {e}")

    def format_value_for_cte(self, value: Any, column_type: type) -> str:
        """Format value for BigQuery CTE VALUES clause."""
        from .._sql_utils import format_sql_value

        return format_sql_value(value, column_type, dialect="bigquery")

    def get_type_converter(self) -> BaseTypeConverter:
        """Get BigQuery-specific type converter."""
        return BigQueryTypeConverter()

    def _get_bigquery_schema(self, mock_table: BaseMockTable) -> List[bigquery.SchemaField]:
        """Convert mock table column types to BigQuery schema fields."""
        # Python-to-BigQuery type map; unmapped types fall back to STRING.
        python_to_bq = {
            str: bigquery.enums.SqlTypeNames.STRING,
            int: bigquery.enums.SqlTypeNames.INT64,
            float: bigquery.enums.SqlTypeNames.FLOAT64,
            bool: bigquery.enums.SqlTypeNames.BOOL,
            date: bigquery.enums.SqlTypeNames.DATE,
            datetime: bigquery.enums.SqlTypeNames.DATETIME,
            Decimal: bigquery.enums.SqlTypeNames.NUMERIC,
        }

        fields = []
        for col_name, col_type in mock_table.get_column_types().items():
            # Unwrap Optional[T] to its underlying type T.
            if hasattr(col_type, "__origin__") and col_type.__origin__ is Union:
                concrete = [arg for arg in get_args(col_type) if arg is not type(None)]
                if concrete:
                    col_type = concrete[0]

            bq_type = python_to_bq.get(col_type, bigquery.enums.SqlTypeNames.STRING)
            fields.append(bigquery.SchemaField(col_name, bq_type))

        return fields
|