sql-testing-library 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,263 @@
1
+ """Trino adapter implementation."""
2
+
3
+ import logging
4
+ import time
5
+ from datetime import date, datetime
6
+ from decimal import Decimal
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, get_args
8
+
9
+
10
+ if TYPE_CHECKING:
11
+ import pandas as pd
12
+
13
+ # Heavy import moved to function level for better performance
14
+ from .._mock_table import BaseMockTable
15
+ from .._types import BaseTypeConverter
16
+ from .base import DatabaseAdapter
17
+
18
+
19
+ HAS_TRINO = True
20
+
21
+ try:
22
+ # This is a separate import to keep the module type
23
+ # for type checking, even if the module fails to import
24
+ import trino as _trino_module # noqa: F401
25
+ except ImportError:
26
+ HAS_TRINO = False
27
+
28
+
29
class TrinoTypeConverter(BaseTypeConverter):
    """Type converter for values returned by Trino.

    The Trino DB-API client already yields native Python objects for most
    types, so the generic conversion logic in ``BaseTypeConverter`` is
    sufficient; this subclass exists as an extension point.
    """

    def convert(self, value: Any, target_type: Type) -> Any:
        """Convert a Trino result ``value`` to ``target_type`` via the base rules."""
        return BaseTypeConverter.convert(self, value, target_type)
36
+
37
+
38
class TrinoAdapter(DatabaseAdapter):
    """Trino adapter for SQL testing."""

    def __init__(
        self,
        host: str,
        port: int = 8080,
        user: Optional[str] = None,
        catalog: str = "memory",
        schema: str = "default",
        http_scheme: str = "http",
        auth: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store connection settings and eagerly open a Trino connection.

        Raises:
            ImportError: If the optional ``trino`` client package is not
                installed.
        """
        if not HAS_TRINO:
            raise ImportError(
                "Trino adapter requires trino client. "
                "Install with: pip install sql-testing-library[trino]"
            )

        # Connection coordinates.
        self.host = host
        self.port = port
        self.http_scheme = http_scheme

        # Session context and credentials.
        self.user = user
        self.catalog = catalog
        self.schema = schema
        self.auth = auth

        # Lazily-created DB-API connection, cached by _get_connection().
        self.conn = None

        # Connect immediately so invalid parameters fail at construction
        # time rather than on the first query.
        self._get_connection()
68
+
69
+ def _get_connection(self) -> Any:
70
+ """Get or create a connection to Trino."""
71
+ import trino
72
+
73
+ # Create a new connection if needed
74
+ if self.conn is None:
75
+ self.conn = trino.dbapi.connect(
76
+ host=self.host,
77
+ port=self.port,
78
+ user=self.user,
79
+ catalog=self.catalog,
80
+ schema=self.schema,
81
+ http_scheme=self.http_scheme,
82
+ auth=self.auth,
83
+ )
84
+
85
+ return self.conn
86
+
87
+ def get_sqlglot_dialect(self) -> str:
88
+ """Return Trino dialect for sqlglot."""
89
+ return "trino"
90
+
91
+ def execute_query(self, query: str) -> "pd.DataFrame":
92
+ """Execute query and return results as DataFrame."""
93
+ import pandas as pd
94
+
95
+ conn = self._get_connection()
96
+
97
+ # Execute query
98
+ cursor = conn.cursor()
99
+ cursor.execute(query)
100
+
101
+ # If this is a SELECT query, return results
102
+ if cursor.description:
103
+ # Get column names from cursor description
104
+ columns = [col[0] for col in cursor.description]
105
+
106
+ # Fetch all rows
107
+ rows = cursor.fetchall()
108
+
109
+ # Create DataFrame from rows
110
+ return pd.DataFrame(rows, columns=columns)
111
+
112
+ # For non-SELECT queries
113
+ return pd.DataFrame()
114
+
115
+ def create_temp_table(self, mock_table: BaseMockTable) -> str:
116
+ """Create a temporary table in Trino using CREATE TABLE AS SELECT."""
117
+ timestamp = int(time.time() * 1000)
118
+ temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
119
+
120
+ # In Trino, tables are qualified with catalog and schema
121
+ qualified_table_name = f"{self.catalog}.{self.schema}.{temp_table_name}"
122
+
123
+ # Generate CTAS statement (CREATE TABLE AS SELECT)
124
+ ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
125
+
126
+ # Execute CTAS query
127
+ self.execute_query(ctas_sql)
128
+
129
+ return qualified_table_name
130
+
131
+ def cleanup_temp_tables(self, table_names: List[str]) -> None:
132
+ """Clean up temporary tables."""
133
+ for full_table_name in table_names:
134
+ try:
135
+ # Extract just the table name from the fully qualified name
136
+ # Table names can be catalog.schema.table or schema.table
137
+ table_parts = full_table_name.split(".")
138
+ if len(table_parts) == 3:
139
+ # catalog.schema.table format
140
+ catalog, schema, table = table_parts
141
+ drop_query = f'DROP TABLE IF EXISTS {catalog}.{schema}."{table}"'
142
+ elif len(table_parts) == 2:
143
+ # schema.table format, use default catalog
144
+ schema, table = table_parts
145
+ drop_query = f'DROP TABLE IF EXISTS {self.catalog}.{schema}."{table}"'
146
+ else:
147
+ # Just table name, use default catalog and schema
148
+ table = full_table_name
149
+ drop_query = f'DROP TABLE IF EXISTS {self.catalog}.{self.schema}."{table}"'
150
+
151
+ self.execute_query(drop_query)
152
+ except Exception as e:
153
+ logging.warning(f"Warning: Failed to drop table {full_table_name}: {e}")
154
+
155
+ def format_value_for_cte(self, value: Any, column_type: type) -> str:
156
+ """Format value for Trino CTE VALUES clause."""
157
+ from .._sql_utils import format_sql_value
158
+
159
+ return format_sql_value(value, column_type, dialect="trino")
160
+
161
+ def get_type_converter(self) -> BaseTypeConverter:
162
+ """Get Trino-specific type converter."""
163
+ return TrinoTypeConverter()
164
+
165
+ def get_query_size_limit(self) -> Optional[int]:
166
+ """Return query size limit in bytes for Trino."""
167
+ # Trino doesn't have a documented size limit, but we'll use a reasonable default
168
+ return 16 * 1024 * 1024 # 16MB
169
+
170
+ def _generate_ctas_sql(self, table_name: str, mock_table: BaseMockTable) -> str:
171
+ """Generate CREATE TABLE AS SELECT (CTAS) statement for Trino."""
172
+ df = mock_table.to_dataframe()
173
+ column_types = mock_table.get_column_types()
174
+ columns = list(df.columns)
175
+
176
+ # Qualify table name with schema but not catalog
177
+ # Catalog is specified in the current session context
178
+ qualified_table = f"{self.schema}.{table_name}"
179
+
180
+ if df.empty:
181
+ # For empty tables, create an empty table with correct schema
182
+ # Type mapping from Python types to Trino types
183
+ type_mapping = {
184
+ str: "VARCHAR",
185
+ int: "BIGINT",
186
+ float: "DOUBLE",
187
+ bool: "BOOLEAN",
188
+ date: "DATE",
189
+ datetime: "TIMESTAMP",
190
+ Decimal: "DECIMAL(38,9)",
191
+ }
192
+
193
+ # Generate column definitions
194
+ column_defs = []
195
+ for col_name, col_type in column_types.items():
196
+ # Handle Optional types
197
+ if hasattr(col_type, "__origin__") and col_type.__origin__ is Union:
198
+ # Extract the non-None type from Optional[T]
199
+ non_none_types = [arg for arg in get_args(col_type) if arg is not type(None)]
200
+ if non_none_types:
201
+ col_type = non_none_types[0]
202
+
203
+ trino_type = type_mapping.get(col_type, "VARCHAR")
204
+ column_defs.append(f'"{col_name}" {trino_type}')
205
+
206
+ columns_sql = ",\n ".join(column_defs)
207
+
208
+ # Create an empty table with the correct schema
209
+ # Memory catalog doesn't support table properties like format
210
+ if self.catalog == "memory":
211
+ return f"""
212
+ CREATE TABLE {qualified_table} (
213
+ {columns_sql}
214
+ )
215
+ """
216
+ else:
217
+ return f"""
218
+ CREATE TABLE {qualified_table} (
219
+ {columns_sql}
220
+ )
221
+ WITH (format = 'ORC')
222
+ """
223
+ else:
224
+ # For tables with data, use CTAS with a VALUES clause
225
+ # Build a SELECT statement with literal values for the first row
226
+ select_expressions = []
227
+
228
+ # Generate column expressions for the first row
229
+ first_row = df.iloc[0]
230
+ for col_name in columns:
231
+ col_type = column_types.get(col_name, str)
232
+ value = first_row[col_name]
233
+ formatted_value = self.format_value_for_cte(value, col_type)
234
+ select_expressions.append(f'{formatted_value} AS "{col_name}"')
235
+
236
+ # Start with the first row in the SELECT
237
+ select_sql = f"SELECT {', '.join(select_expressions)}"
238
+
239
+ # Add UNION ALL for each additional row
240
+ for i in range(1, len(df)):
241
+ row = df.iloc[i]
242
+ row_values = []
243
+ for col_name in columns:
244
+ col_type = column_types.get(col_name, str)
245
+ value = row[col_name]
246
+ formatted_value = self.format_value_for_cte(value, col_type)
247
+ row_values.append(formatted_value)
248
+
249
+ select_sql += f"\nUNION ALL SELECT {', '.join(row_values)}"
250
+
251
+ # Create the CTAS statement
252
+ # Memory catalog doesn't support table properties like format
253
+ if self.catalog == "memory":
254
+ return f"""
255
+ CREATE TABLE {qualified_table}
256
+ AS {select_sql}
257
+ """
258
+ else:
259
+ return f"""
260
+ CREATE TABLE {qualified_table}
261
+ WITH (format = 'ORC')
262
+ AS {select_sql}
263
+ """