sql-testing-library 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,42 @@
1
+ """SQL Testing Library - Test SQL queries with mock data injection."""
2
+
3
+ # Import from private modules (leading underscore indicates internal use)
4
+ from ._adapters.base import DatabaseAdapter # noqa: F401
5
+ from ._core import SQLTestCase, SQLTestFramework # noqa: F401
6
+ from ._exceptions import (
7
+ MockTableNotFoundError, # noqa: F401
8
+ QuerySizeLimitExceeded, # noqa: F401
9
+ SQLParseError, # noqa: F401
10
+ SQLTestingError, # noqa: F401
11
+ TypeConversionError, # noqa: F401
12
+ )
13
+ from ._mock_table import BaseMockTable # noqa: F401
14
+ from ._pytest_plugin import sql_test # noqa: F401
15
+
16
+
17
+ # Backward compatibility alias
18
+ TestCase = SQLTestCase
19
+
20
+ # Import adapters if their dependencies are available
21
+ try:
22
+ from ._adapters.bigquery import BigQueryAdapter
23
+
24
+ __all__ = ["BigQueryAdapter"]
25
+ except ImportError:
26
+ __all__ = []
27
+
28
+ __version__ = "0.3.0"
29
+ __all__.extend(
30
+ [
31
+ "SQLTestFramework",
32
+ "TestCase",
33
+ "BaseMockTable",
34
+ "DatabaseAdapter",
35
+ "sql_test",
36
+ "SQLTestingError",
37
+ "MockTableNotFoundError",
38
+ "SQLParseError",
39
+ "QuerySizeLimitExceeded",
40
+ "TypeConversionError",
41
+ ]
42
+ )
@@ -0,0 +1,15 @@
1
+ """Database adapters for SQL testing library."""
2
+
3
+ from typing import List
4
+
5
+
6
+ # Lazy import adapters - only import when explicitly requested
7
+ # This prevents loading all heavy database SDKs when just importing the base adapter
8
+ __all__: List[str] = []
9
+
10
+ # Individual adapters can be imported directly:
11
+ # from sql_testing_library._adapters.bigquery import BigQueryAdapter
12
+ # from sql_testing_library._adapters.athena import AthenaAdapter
13
+ # from sql_testing_library._adapters.redshift import RedshiftAdapter
14
+ # from sql_testing_library._adapters.trino import TrinoAdapter
15
+ # from sql_testing_library._adapters.snowflake import SnowflakeAdapter
@@ -0,0 +1,309 @@
1
+ """Amazon Athena adapter implementation."""
2
+
3
+ import logging
4
+ import time
5
+ from datetime import date, datetime
6
+ from decimal import Decimal
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
8
+
9
+
10
+ if TYPE_CHECKING:
11
+ import pandas as pd
12
+
13
+ import boto3
14
+
15
+ # Heavy import moved to function level for better performance
16
+ from .._mock_table import BaseMockTable
17
+ from .._types import BaseTypeConverter
18
+ from .base import DatabaseAdapter
19
+
20
+
21
+ HAS_BOTO3 = True
22
+
23
+ try:
24
+ # This is a separate import to keep the module type
25
+ # for type checking, even if the module fails to import
26
+ import boto3 as _boto3_module # noqa: F401
27
+ except ImportError:
28
+ HAS_BOTO3 = False
29
+
30
+
31
class AthenaTypeConverter(BaseTypeConverter):
    """Type converter with Athena-specific NULL handling."""

    def convert(self, value: Any, target_type: Type) -> Any:
        """Convert a raw Athena result value to ``target_type``.

        Athena's result API renders SQL NULLs as the literal string
        ``"NULL"``; translate those to ``None`` before delegating to the
        generic conversion logic of the base class.
        """
        return None if value == "NULL" else super().convert(value, target_type)
42
+
43
+
44
class AthenaAdapter(DatabaseAdapter):
    """Amazon Athena adapter for SQL testing.

    Queries run through the boto3 Athena client; results are written to the
    configured S3 output location and polled until completion.
    """

    def __init__(
        self,
        database: str,
        s3_output_location: str,
        region: str = "us-west-2",
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ) -> None:
        """Create an adapter bound to one Athena database.

        Args:
            database: Athena database used for queries and temp tables.
            s3_output_location: S3 URI where Athena writes query results.
            region: AWS region to create the Athena client in.
            aws_access_key_id: Optional explicit access key.
            aws_secret_access_key: Optional explicit secret key; when either
                credential is missing, the default boto3 chain is used.

        Raises:
            ImportError: If boto3 is not installed.
        """
        if not HAS_BOTO3:
            raise ImportError(
                "Athena adapter requires boto3. "
                "Install with: pip install sql-testing-library[athena]"
            )

        self.database = database
        self.s3_output_location = s3_output_location
        self.region = region

        # Initialize Athena client
        if aws_access_key_id and aws_secret_access_key:
            self.client = boto3.client(
                "athena",
                region_name=region,
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
            )
        else:
            # Use default credentials from ~/.aws/credentials or environment variables
            self.client = boto3.client("athena", region_name=region)

    def get_sqlglot_dialect(self) -> str:
        """Return Presto dialect for sqlglot (Athena uses Presto SQL)."""
        return "presto"

    def execute_query(self, query: str) -> "pd.DataFrame":
        """Execute `query` and return its results as a DataFrame.

        All cell values come back as strings (Athena's result API returns
        VarCharValue) or None for NULLs.

        Raises:
            Exception: If the query fails, is cancelled, or times out.
        """
        import pandas as pd

        # Start query execution
        response = self.client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_location},
        )

        query_execution_id = response["QueryExecutionId"]

        # Wait for query to complete
        query_status, error_info = self._wait_for_query_with_error(query_execution_id)
        if query_status != "SUCCEEDED":
            error_message = f"Athena query failed with status: {query_status}"
            if error_info:
                # Separator previously read ";Error details:" (missing space).
                error_message += f"; Error details: {error_info}"
            raise Exception(error_message)

        # Get query results.
        # NOTE(review): get_query_results returns at most one page of rows;
        # large result sets would need pagination — confirm expected sizes.
        results = self.client.get_query_results(QueryExecutionId=query_execution_id)

        # Convert to DataFrame
        if "ResultSet" in results and "Rows" in results["ResultSet"]:
            rows = results["ResultSet"]["Rows"]
            if not rows:
                return pd.DataFrame()

            # First row is header
            header = [col["VarCharValue"] for col in rows[0]["Data"]]

            # Rest are data; NULL cells omit "VarCharValue", so .get() yields None
            data = []
            for row in rows[1:]:
                data.append([col.get("VarCharValue") for col in row["Data"]])

            return pd.DataFrame(data, columns=header)
        else:
            return pd.DataFrame()

    def create_temp_table(self, mock_table: BaseMockTable) -> str:
        """Create a temporary table in Athena using CTAS.

        Returns the fully qualified ``database.table`` name. The millisecond
        timestamp suffix keeps concurrent test runs from colliding.
        """
        timestamp = int(time.time() * 1000)
        temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
        qualified_table_name = f"{self.database}.{temp_table_name}"

        # Generate CTAS statement (CREATE TABLE AS SELECT)
        ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)

        # Execute CTAS query
        self.execute_query(ctas_sql)

        return qualified_table_name

    def cleanup_temp_tables(self, table_names: List[str]) -> None:
        """Best-effort drop of temporary tables; failures are logged, not raised."""
        for full_table_name in table_names:
            try:
                # Extract just the table name, not the database.table format
                if "." in full_table_name:
                    table_name = full_table_name.split(".")[-1]
                else:
                    table_name = full_table_name

                drop_query = f"DROP TABLE IF EXISTS {table_name}"
                self.execute_query(drop_query)
            except Exception as e:
                logging.warning(f"Warning: Failed to drop table {full_table_name}: {e}")

    def format_value_for_cte(self, value: Any, column_type: type) -> str:
        """Format value for Athena CTE VALUES clause."""
        from .._sql_utils import format_sql_value

        return format_sql_value(value, column_type, dialect="athena")

    def get_type_converter(self) -> BaseTypeConverter:
        """Get Athena-specific type converter."""
        return AthenaTypeConverter()

    def get_query_size_limit(self) -> Optional[int]:
        """Return query size limit in bytes for Athena."""
        # Athena has a 256KB limit for query strings
        return 256 * 1024  # 256KB

    def _wait_for_query(self, query_execution_id: str, max_retries: int = 60) -> str:
        """Wait for query to complete, returns final status."""
        status, _ = self._wait_for_query_with_error(query_execution_id, max_retries)
        return status

    def _wait_for_query_with_error(
        self, query_execution_id: str, max_retries: int = 60
    ) -> tuple[str, Optional[str]]:
        """Poll the query once per second until it reaches a terminal state.

        Returns:
            ``(status, error_info)`` — ``error_info`` is None unless the query
            failed or was cancelled; after ``max_retries`` polls the status is
            the sentinel ``"TIMEOUT"``.
        """
        for _ in range(max_retries):
            response = self.client.get_query_execution(QueryExecutionId=query_execution_id)
            query_execution = response["QueryExecution"]
            status = query_execution["Status"]["State"]

            # Explicitly cast to string to satisfy type checker
            query_status: str = str(status)

            if query_status in ("SUCCEEDED", "FAILED", "CANCELLED"):
                error_info = None
                if query_status in ("FAILED", "CANCELLED"):
                    # Prefer the human-readable reason; fall back to the
                    # structured AthenaError payload.
                    status_info = query_execution["Status"]
                    if "StateChangeReason" in status_info:
                        error_info = status_info["StateChangeReason"]
                    elif "AthenaError" in status_info:
                        athena_error = status_info["AthenaError"]
                        error_type = athena_error.get("ErrorType", "Unknown")
                        error_message = athena_error.get("ErrorMessage", "No details available")
                        error_info = f"{error_type}: {error_message}"

                return query_status, error_info

            # Wait before checking again
            time.sleep(1)

        # If we reached here, we timed out
        return "TIMEOUT", "Query execution timed out after waiting for completion"

    def _build_s3_location(self, table_name: str) -> str:
        """Build proper S3 location path avoiding double slashes."""
        # Remove trailing slash from s3_output_location if present
        base_location = self.s3_output_location.rstrip("/")
        return f"{base_location}/{table_name}/"

    @staticmethod
    def _unwrap_optional(col_type: Any) -> Any:
        """Return T for Optional[T]; any other type passes through unchanged."""
        if hasattr(col_type, "__origin__") and col_type.__origin__ is Union:
            non_none_types = [arg for arg in get_args(col_type) if arg is not type(None)]
            if non_none_types:
                return non_none_types[0]
        return col_type

    def _generate_ctas_sql(self, table_name: str, mock_table: BaseMockTable) -> str:
        """Generate CREATE TABLE AS SELECT (CTAS) statement for Athena.

        Empty mock tables get a CREATE EXTERNAL TABLE with an explicit schema;
        non-empty ones get a CTAS over SELECT ... UNION ALL literal rows.
        """
        df = mock_table.to_dataframe()
        column_types = mock_table.get_column_types()
        columns = list(df.columns)

        if df.empty:
            # For empty tables, create an empty table with correct schema.
            # CREATE EXTERNAL TABLE takes Hive DDL types: the keyword is INT
            # (not INTEGER) and plain VARCHAR without a length is invalid, so
            # strings map to STRING.
            type_mapping = {
                str: "STRING",
                int: "INT",
                float: "DOUBLE",
                bool: "BOOLEAN",
                date: "DATE",
                datetime: "TIMESTAMP",
                Decimal: "DECIMAL(38,9)",
            }

            # Generate column definitions
            column_defs = []
            for col_name, col_type in column_types.items():
                # Handle Optional types
                col_type = self._unwrap_optional(col_type)
                athena_type = type_mapping.get(col_type, "STRING")
                column_defs.append(f'"{col_name}" {athena_type}')

            columns_sql = ",\n    ".join(column_defs)

            # Create an empty external table with the correct schema
            return f"""
    CREATE EXTERNAL TABLE {table_name} (
        {columns_sql}
    )
    STORED AS PARQUET
    LOCATION '{self._build_s3_location(table_name)}'
    """
        else:
            # For tables with data, use CTAS with a VALUES clause
            # Build a SELECT statement with literal values
            select_expressions = []

            # Generate column expressions for the first row; only the first
            # SELECT carries column aliases, subsequent UNION ALL rows inherit.
            first_row = df.iloc[0]
            for col_name in columns:
                col_type = column_types.get(col_name, str)
                value = first_row[col_name]

                # Handle Optional types by extracting the non-None type for proper formatting
                actual_type = self._unwrap_optional(col_type)

                formatted_value = self.format_value_for_cte(value, actual_type)
                select_expressions.append(f'{formatted_value} AS "{col_name}"')

            # Start with the first row in the SELECT
            select_sql = f"SELECT {', '.join(select_expressions)}"

            # Add UNION ALL for each additional row
            for i in range(1, len(df)):
                row = df.iloc[i]
                row_values = []
                for col_name in columns:
                    col_type = column_types.get(col_name, str)
                    value = row[col_name]

                    # Handle Optional types by extracting the non-None type for proper formatting
                    actual_type = self._unwrap_optional(col_type)

                    formatted_value = self.format_value_for_cte(value, actual_type)
                    row_values.append(formatted_value)

                select_sql += f"\nUNION ALL SELECT {', '.join(row_values)}"

            # Create the CTAS statement for external table
            return f"""
    CREATE TABLE {table_name}
    WITH (
        format = 'PARQUET',
        external_location = '{self._build_s3_location(table_name)}'
    )
    AS {select_sql}
    """
@@ -0,0 +1,49 @@
1
+ """Base database adapter interface."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, Any, List, Optional
5
+
6
+
7
+ if TYPE_CHECKING:
8
+ import pandas as pd
9
+
10
+ # Heavy import moved to function level for better performance
11
+ from .._mock_table import BaseMockTable
12
+ from .._types import BaseTypeConverter
13
+
14
+
15
class DatabaseAdapter(ABC):
    """Abstract base class for database adapters.

    Concrete adapters implement the five abstract methods; the two concrete
    methods provide overridable defaults (generic type conversion, no query
    size limit).
    """

    @abstractmethod
    def get_sqlglot_dialect(self) -> str:
        """Name of the sqlglot dialect used to transpile SQL for this engine."""
        ...

    @abstractmethod
    def execute_query(self, query: str) -> "pd.DataFrame":
        """Run ``query`` against the database and return the result set."""
        ...

    @abstractmethod
    def create_temp_table(self, mock_table: BaseMockTable) -> str:
        """Materialize ``mock_table`` as a temp table; return its name."""
        ...

    @abstractmethod
    def cleanup_temp_tables(self, table_names: List[str]) -> None:
        """Drop the given temporary tables."""
        ...

    @abstractmethod
    def format_value_for_cte(self, value: Any, column_type: type) -> str:
        """Render ``value`` as a SQL literal for a CTE VALUES clause."""
        ...

    def get_type_converter(self) -> BaseTypeConverter:
        """Type converter for query results; override for engine-specific rules."""
        return BaseTypeConverter()

    def get_query_size_limit(self) -> Optional[int]:
        """Maximum query size in bytes, or None when the engine has no limit."""
        return None
@@ -0,0 +1,139 @@
1
+ """BigQuery adapter implementation."""
2
+
3
+ import logging
4
+ from datetime import date, datetime
5
+ from decimal import Decimal
6
+ from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ import pandas as pd
11
+
12
+ # Heavy imports moved to function level for better performance
13
+ from google.cloud import bigquery
14
+
15
+ from .._mock_table import BaseMockTable
16
+ from .._types import BaseTypeConverter
17
+ from .base import DatabaseAdapter
18
+
19
+
20
+ HAS_BIGQUERY = True
21
+
22
+ # The duplicate import is intentional
23
+ # First import is to get the types, second is to actually import the module
24
+ # If the second fails, we set HAS_BIGQUERY to False to handle it gracefully
25
+ try:
26
+ # This is a separate import to keep the module type
27
+ # for type checking, even if the module fails to import
28
+ import google.cloud.bigquery as _bigquery_module # noqa: F401
29
+ except ImportError:
30
+ HAS_BIGQUERY = False
31
+
32
+
33
class BigQueryTypeConverter(BaseTypeConverter):
    """BigQuery-specific type converter."""

    def convert(self, value: Any, target_type: Type) -> Any:
        """Convert a BigQuery result value to ``target_type``.

        The BigQuery client library already yields native Python objects, so
        the generic conversion in the base class is sufficient as-is.
        """
        converted = super().convert(value, target_type)
        return converted
40
+
41
+
42
class BigQueryAdapter(DatabaseAdapter):
    """Google BigQuery adapter for SQL testing.

    Temp tables are created inside the configured ``project_id.dataset_id``
    and populated via pandas load jobs.
    """

    def __init__(
        self, project_id: str, dataset_id: str, credentials_path: Optional[str] = None
    ) -> None:
        # project_id: GCP project that owns the dataset and runs queries.
        # dataset_id: dataset in which temporary tables are created.
        # credentials_path: optional service-account JSON file; when omitted,
        #     application-default credentials are used.
        # Raises ImportError if google-cloud-bigquery is not installed.
        if not HAS_BIGQUERY:
            raise ImportError(
                "BigQuery adapter requires google-cloud-bigquery. "
                "Install with: pip install sql-testing-library[bigquery]"
            )

        self.project_id = project_id
        self.dataset_id = dataset_id

        if credentials_path:
            self.client = bigquery.Client.from_service_account_json(credentials_path)
        else:
            self.client = bigquery.Client(project=project_id)

    def get_sqlglot_dialect(self) -> str:
        """Return BigQuery dialect for sqlglot."""
        return "bigquery"

    def execute_query(self, query: str) -> "pd.DataFrame":
        """Execute query and return results as DataFrame."""
        # query() is asynchronous; to_dataframe() blocks until completion and
        # raises if the job failed.
        job = self.client.query(query)
        return job.to_dataframe()

    def create_temp_table(self, mock_table: BaseMockTable) -> str:
        """Create temporary table in BigQuery.

        Returns the fully qualified ``project.dataset.table`` id. The
        millisecond timestamp suffix keeps concurrent runs from colliding.
        NOTE(review): no table expiration is set — cleanup relies entirely on
        cleanup_temp_tables being called.
        """
        import time

        temp_table_name = f"temp_{mock_table.get_table_name()}_{int(time.time() * 1000)}"
        table_id = f"{self.project_id}.{self.dataset_id}.{temp_table_name}"

        # Create table schema from mock table
        schema = self._get_bigquery_schema(mock_table)

        # Create table
        table = bigquery.Table(table_id, schema=schema)
        table = self.client.create_table(table)

        # Insert data (skipped for empty mock tables — the empty table with
        # the right schema already exists)
        df = mock_table.to_dataframe()
        if not df.empty:
            job_config = bigquery.LoadJobConfig()
            job = self.client.load_table_from_dataframe(df, table, job_config=job_config)
            job.result()  # Wait for job to complete

        return table_id

    def cleanup_temp_tables(self, table_names: List[str]) -> None:
        """Delete temporary tables.

        Best-effort: failures are logged as warnings, never raised, so one
        undeletable table does not abort cleanup of the rest.
        """
        for table_name in table_names:
            try:
                self.client.delete_table(table_name)
            except Exception as e:
                logging.warning(f"Warning: Failed to delete table {table_name}: {e}")

    def format_value_for_cte(self, value: Any, column_type: type) -> str:
        """Format value for BigQuery CTE VALUES clause."""
        # Local import keeps module import light (matches the file's pattern
        # of deferring heavy/shared imports to call time).
        from .._sql_utils import format_sql_value

        return format_sql_value(value, column_type, dialect="bigquery")

    def get_type_converter(self) -> BaseTypeConverter:
        """Get BigQuery-specific type converter."""
        return BigQueryTypeConverter()

    def _get_bigquery_schema(self, mock_table: BaseMockTable) -> List[bigquery.SchemaField]:
        """Convert mock table schema to BigQuery schema.

        Unmapped Python types fall back to STRING. Fields are created with the
        default mode (NULLABLE), so Optional and non-Optional columns end up
        identical in BigQuery.
        """
        column_types = mock_table.get_column_types()

        # Type mapping from Python types to BigQuery types
        type_mapping = {
            str: bigquery.enums.SqlTypeNames.STRING,
            int: bigquery.enums.SqlTypeNames.INT64,
            float: bigquery.enums.SqlTypeNames.FLOAT64,
            bool: bigquery.enums.SqlTypeNames.BOOL,
            date: bigquery.enums.SqlTypeNames.DATE,
            datetime: bigquery.enums.SqlTypeNames.DATETIME,
            Decimal: bigquery.enums.SqlTypeNames.NUMERIC,
        }

        schema = []
        for col_name, col_type in column_types.items():
            # Handle Optional types
            if hasattr(col_type, "__origin__") and col_type.__origin__ is Union:
                # Extract the non-None type from Optional[T]
                non_none_types = [arg for arg in get_args(col_type) if arg is not type(None)]
                if non_none_types:
                    col_type = non_none_types[0]

            bq_type = type_mapping.get(col_type, bigquery.enums.SqlTypeNames.STRING)
            schema.append(bigquery.SchemaField(col_name, bq_type))

        return schema