sql-testing-library 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_testing_library/__init__.py +42 -0
- sql_testing_library/_adapters/__init__.py +15 -0
- sql_testing_library/_adapters/athena.py +309 -0
- sql_testing_library/_adapters/base.py +49 -0
- sql_testing_library/_adapters/bigquery.py +139 -0
- sql_testing_library/_adapters/redshift.py +219 -0
- sql_testing_library/_adapters/snowflake.py +270 -0
- sql_testing_library/_adapters/trino.py +263 -0
- sql_testing_library/_core.py +502 -0
- sql_testing_library/_exceptions.py +55 -0
- sql_testing_library/_mock_table.py +200 -0
- sql_testing_library/_pytest_plugin.py +451 -0
- sql_testing_library/_sql_utils.py +225 -0
- sql_testing_library/_types.py +142 -0
- sql_testing_library/py.typed +0 -0
- sql_testing_library-0.4.0.dist-info/LICENSE +21 -0
- sql_testing_library-0.4.0.dist-info/METADATA +956 -0
- sql_testing_library-0.4.0.dist-info/RECORD +20 -0
- sql_testing_library-0.4.0.dist-info/WHEEL +4 -0
- sql_testing_library-0.4.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""Trino adapter implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from datetime import date, datetime
|
|
6
|
+
from decimal import Decimal
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, get_args
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
# Heavy import moved to function level for better performance
|
|
14
|
+
from .._mock_table import BaseMockTable
|
|
15
|
+
from .._types import BaseTypeConverter
|
|
16
|
+
from .base import DatabaseAdapter
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Capability flag consumed by TrinoAdapter.__init__: when the optional
# `trino` client is not installed we degrade to a helpful ImportError with
# install instructions instead of failing at import time of this module.
HAS_TRINO = True

try:
    # This is a separate import to keep the module type
    # for type checking, even if the module fails to import
    import trino as _trino_module  # noqa: F401
except ImportError:
    HAS_TRINO = False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TrinoTypeConverter(BaseTypeConverter):
    """Type converter for Trino query results.

    The Trino DBAPI driver already returns native Python objects for most
    column types, so the generic conversion in the base class is sufficient;
    this subclass exists mainly as a Trino-specific extension point.
    """

    def convert(self, value: Any, target_type: Type) -> Any:
        """Convert a Trino result value to the requested target type."""
        # Delegate straight to the shared base-converter logic.
        converted = super().convert(value, target_type)
        return converted
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TrinoAdapter(DatabaseAdapter):
    """Trino adapter for SQL testing.

    Connects to a Trino cluster through the DBAPI client and provides the
    hooks the core library needs: executing queries into DataFrames,
    materializing mock tables as temporary tables, and cleaning them up.
    """

    # Python -> Trino column type mapping used when declaring table schemas.
    # Unknown types fall back to VARCHAR.
    _TYPE_MAPPING: Dict[type, str] = {
        str: "VARCHAR",
        int: "BIGINT",
        float: "DOUBLE",
        bool: "BOOLEAN",
        date: "DATE",
        datetime: "TIMESTAMP",
        Decimal: "DECIMAL(38,9)",
    }

    def __init__(
        self,
        host: str,
        port: int = 8080,
        user: Optional[str] = None,
        catalog: str = "memory",
        schema: str = "default",
        http_scheme: str = "http",
        auth: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store connection parameters and eagerly open a connection.

        Args:
            host: Trino coordinator hostname.
            port: Coordinator port (default 8080).
            user: User name passed to the server, if any.
            catalog: Default catalog for unqualified table names.
            schema: Default schema for unqualified table names.
            http_scheme: "http" or "https".
            auth: Optional authentication object/config for the trino client.

        Raises:
            ImportError: If the optional ``trino`` client is not installed.
        """
        if not HAS_TRINO:
            raise ImportError(
                "Trino adapter requires trino client. "
                "Install with: pip install sql-testing-library[trino]"
            )

        self.host = host
        self.port = port
        self.user = user
        self.catalog = catalog
        self.schema = schema
        self.http_scheme = http_scheme
        self.auth = auth
        self.conn = None

        # Connect eagerly so invalid connection parameters fail fast,
        # at adapter construction time rather than on first query.
        self._get_connection()

    def _get_connection(self) -> Any:
        """Get the cached connection to Trino, creating it on first use."""
        import trino

        if self.conn is None:
            self.conn = trino.dbapi.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                catalog=self.catalog,
                schema=self.schema,
                http_scheme=self.http_scheme,
                auth=self.auth,
            )

        return self.conn

    def get_sqlglot_dialect(self) -> str:
        """Return Trino dialect name for sqlglot."""
        return "trino"

    def execute_query(self, query: str) -> "pd.DataFrame":
        """Execute a query and return its results as a DataFrame.

        Statements that produce no result set (DDL, etc.) return an empty
        DataFrame. The cursor is always closed, even when execution raises.
        """
        import pandas as pd

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(query)

            # cursor.description is only populated for row-returning queries.
            if cursor.description:
                columns = [col[0] for col in cursor.description]
                rows = cursor.fetchall()
                return pd.DataFrame(rows, columns=columns)

            return pd.DataFrame()
        finally:
            # Fix: the original leaked the cursor on every call.
            cursor.close()

    def create_temp_table(self, mock_table: BaseMockTable) -> str:
        """Create a temporary table in Trino using CREATE TABLE AS SELECT.

        Returns:
            The fully qualified (catalog.schema.table) name of the new table.
        """
        # Millisecond timestamp keeps concurrent test tables distinct.
        timestamp = int(time.time() * 1000)
        temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"

        # In Trino, tables are qualified with catalog and schema.
        qualified_table_name = f"{self.catalog}.{self.schema}.{temp_table_name}"

        # Generate and execute the CTAS (CREATE TABLE AS SELECT) statement.
        ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
        self.execute_query(ctas_sql)

        return qualified_table_name

    def cleanup_temp_tables(self, table_names: List[str]) -> None:
        """Best-effort drop of temporary tables; failures are logged, not raised."""
        for full_table_name in table_names:
            try:
                # Names may arrive as table, schema.table, or catalog.schema.table;
                # fill in the adapter defaults for any missing qualifier.
                table_parts = full_table_name.split(".")
                if len(table_parts) == 3:
                    catalog, schema, table = table_parts
                elif len(table_parts) == 2:
                    catalog = self.catalog
                    schema, table = table_parts
                else:
                    catalog, schema, table = self.catalog, self.schema, full_table_name

                self.execute_query(
                    f'DROP TABLE IF EXISTS {catalog}.{schema}."{table}"'
                )
            except Exception as e:
                # Lazy %-formatting; logging.warning already conveys severity,
                # so the old "Warning:" prefix was redundant.
                logging.warning("Failed to drop table %s: %s", full_table_name, e)

    def format_value_for_cte(self, value: Any, column_type: type) -> str:
        """Format a Python value as a Trino SQL literal for a VALUES/SELECT clause."""
        from .._sql_utils import format_sql_value

        return format_sql_value(value, column_type, dialect="trino")

    def get_type_converter(self) -> BaseTypeConverter:
        """Get the Trino-specific type converter."""
        return TrinoTypeConverter()

    def get_query_size_limit(self) -> Optional[int]:
        """Return query size limit in bytes for Trino.

        Trino doesn't document a hard limit, so use a conservative default.
        """
        return 16 * 1024 * 1024  # 16MB

    @staticmethod
    def _unwrap_optional(col_type: type) -> type:
        """Return T for Optional[T]; any other type passes through unchanged."""
        if hasattr(col_type, "__origin__") and col_type.__origin__ is Union:
            non_none_types = [arg for arg in get_args(col_type) if arg is not type(None)]
            if non_none_types:
                return non_none_types[0]
        return col_type

    def _generate_ctas_sql(self, table_name: str, mock_table: BaseMockTable) -> str:
        """Generate a CREATE TABLE (AS SELECT) statement for the mock table."""
        df = mock_table.to_dataframe()
        column_types = mock_table.get_column_types()

        # Qualify with schema but not catalog: the catalog comes from the
        # current session context.
        qualified_table = f"{self.schema}.{table_name}"

        if df.empty:
            return self._empty_table_sql(qualified_table, column_types)
        return self._ctas_with_data_sql(qualified_table, df, column_types)

    def _empty_table_sql(
        self, qualified_table: str, column_types: Dict[str, type]
    ) -> str:
        """Build CREATE TABLE DDL with explicit columns for an empty mock table."""
        column_defs = []
        for col_name, col_type in column_types.items():
            trino_type = self._TYPE_MAPPING.get(self._unwrap_optional(col_type), "VARCHAR")
            column_defs.append(f'"{col_name}" {trino_type}')

        columns_sql = ",\n                ".join(column_defs)

        # The memory catalog doesn't support table properties like format.
        if self.catalog == "memory":
            return f"""
            CREATE TABLE {qualified_table} (
                {columns_sql}
            )
            """
        return f"""
            CREATE TABLE {qualified_table} (
                {columns_sql}
            )
            WITH (format = 'ORC')
            """

    def _ctas_with_data_sql(
        self, qualified_table: str, df: "pd.DataFrame", column_types: Dict[str, type]
    ) -> str:
        """Build a CTAS statement whose SELECT inlines every row as literals."""
        columns = list(df.columns)

        def _row_literals(row: Any) -> List[str]:
            # One SQL literal per column, formatted for the Trino dialect.
            return [
                self.format_value_for_cte(row[col_name], column_types.get(col_name, str))
                for col_name in columns
            ]

        # First row carries the column aliases; later rows rely on UNION ALL
        # positional matching.
        first_row_exprs = [
            f'{literal} AS "{col_name}"'
            for col_name, literal in zip(columns, _row_literals(df.iloc[0]))
        ]
        select_sql = f"SELECT {', '.join(first_row_exprs)}"

        for i in range(1, len(df)):
            select_sql += f"\nUNION ALL SELECT {', '.join(_row_literals(df.iloc[i]))}"

        # The memory catalog doesn't support table properties like format.
        if self.catalog == "memory":
            return f"""
            CREATE TABLE {qualified_table}
            AS {select_sql}
            """
        return f"""
            CREATE TABLE {qualified_table}
            WITH (format = 'ORC')
            AS {select_sql}
            """
|