sql-testing-library 0.6.0__tar.gz → 0.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21) hide show
  1. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/CHANGELOG.md +12 -0
  2. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/PKG-INFO +2 -1
  3. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/README.md +1 -0
  4. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/pyproject.toml +1 -1
  5. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/athena.py +15 -1
  6. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/base.py +10 -1
  7. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/bigquery.py +63 -3
  8. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/redshift.py +15 -1
  9. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/snowflake.py +22 -1
  10. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/trino.py +17 -1
  11. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_core.py +185 -7
  12. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_pytest_plugin.py +96 -2
  13. sql_testing_library-0.7.1/src/sql_testing_library/_sql_logger.py +385 -0
  14. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_sql_utils.py +37 -0
  15. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/LICENSE +0 -0
  16. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/__init__.py +0 -0
  17. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/__init__.py +0 -0
  18. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_exceptions.py +0 -0
  19. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_mock_table.py +0 -0
  20. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_types.py +0 -0
  21. {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/py.typed +0 -0
@@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
5
5
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6
6
  and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
7
 
8
+ ## 0.7.1 (2025-06-06)
9
+
10
+ ### Fix
11
+
12
+ - **array**: array handling logic + sql logging improvement (#95)
13
+
14
+ ## 0.7.0 (2025-06-06)
15
+
16
+ ### Feat
17
+
18
+ - **sqllogging**: added support for logging sql logs for debugging failed tests (#94)
19
+
8
20
  ## 0.6.0 (2025-06-05)
9
21
 
10
22
  ### Feat
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sql-testing-library
3
- Version: 0.6.0
3
+ Version: 0.7.1
4
4
  Summary: A powerful Python framework for unit testing SQL queries across BigQuery, Snowflake, Redshift, Athena, and Trino with mock data
5
5
  License: MIT
6
6
  Keywords: sql,testing,unit-testing,mock-data,database-testing,bigquery,snowflake,redshift,athena,trino,data-engineering,etl-testing,sql-validation,query-testing
@@ -108,6 +108,7 @@ For more details on our journey and the engineering challenges we solved, read t
108
108
  - **CTE or Physical Tables**: Automatic fallback for query size limits
109
109
  - **Type-Safe Results**: Deserialize results to Pydantic models
110
110
  - **Pytest Integration**: Seamless testing with `@sql_test` decorator
111
+ - **SQL Logging**: Comprehensive SQL logging with formatted output, error traces, and temp table queries
111
112
 
112
113
  ## Data Types Support
113
114
 
@@ -51,6 +51,7 @@ For more details on our journey and the engineering challenges we solved, read t
51
51
  - **CTE or Physical Tables**: Automatic fallback for query size limits
52
52
  - **Type-Safe Results**: Deserialize results to Pydantic models
53
53
  - **Pytest Integration**: Seamless testing with `@sql_test` decorator
54
+ - **SQL Logging**: Comprehensive SQL logging with formatted output, error traces, and temp table queries
54
55
 
55
56
  ## Data Types Support
56
57
 
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "sql-testing-library"
7
- version = "0.6.0"
7
+ version = "0.7.1"
8
8
  description = "A powerful Python framework for unit testing SQL queries across BigQuery, Snowflake, Redshift, Athena, and Trino with mock data"
9
9
  authors = ["Gurmeet Saran <gurmeetx@gmail.com>", "Kushal Thakkar <kushal.thakkar@gmail.com>"]
10
10
  maintainers = ["Gurmeet Saran <gurmeetx@gmail.com>", "Kushal Thakkar <kushal.thakkar@gmail.com>"]
@@ -4,7 +4,7 @@ import logging
4
4
  import time
5
5
  from datetime import date, datetime
6
6
  from decimal import Decimal
7
- from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
8
8
 
9
9
 
10
10
  if TYPE_CHECKING:
@@ -136,6 +136,20 @@ class AthenaAdapter(DatabaseAdapter):
136
136
 
137
137
  return qualified_table_name
138
138
 
139
+ def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
140
+ """Create a temporary table and return both table name and SQL."""
141
+ timestamp = int(time.time() * 1000)
142
+ temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
143
+ qualified_table_name = f"{self.database}.{temp_table_name}"
144
+
145
+ # Generate CTAS statement (CREATE TABLE AS SELECT)
146
+ ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
147
+
148
+ # Execute CTAS query
149
+ self.execute_query(ctas_sql)
150
+
151
+ return qualified_table_name, ctas_sql
152
+
139
153
  def cleanup_temp_tables(self, table_names: List[str]) -> None:
140
154
  """Clean up temporary tables."""
141
155
  for full_table_name in table_names:
@@ -1,7 +1,7 @@
1
1
  """Base database adapter interface."""
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any, List, Optional
4
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple
5
5
 
6
6
 
7
7
  if TYPE_CHECKING:
@@ -30,6 +30,15 @@ class DatabaseAdapter(ABC):
30
30
  """Create a temporary table with mock data. Returns temp table name."""
31
31
  pass
32
32
 
33
+ @abstractmethod
34
+ def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
35
+ """Create a temporary table and return both table name and SQL.
36
+
37
+ Returns:
38
+ Tuple of (temp_table_name, create_table_sql)
39
+ """
40
+ pass
41
+
33
42
  @abstractmethod
34
43
  def cleanup_temp_tables(self, table_names: List[str]) -> None:
35
44
  """Clean up temporary tables."""
@@ -3,7 +3,7 @@
3
3
  import logging
4
4
  from datetime import date, datetime
5
5
  from decimal import Decimal
6
- from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
6
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
7
7
 
8
8
 
9
9
  if TYPE_CHECKING:
@@ -88,6 +88,54 @@ class BigQueryAdapter(DatabaseAdapter):
88
88
 
89
89
  return table_id
90
90
 
91
+ def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
92
+ """Create temporary table and return both table name and SQL."""
93
+ import time
94
+
95
+ temp_table_name = f"temp_{mock_table.get_table_name()}_{int(time.time() * 1000)}"
96
+ table_id = f"{self.project_id}.{self.dataset_id}.{temp_table_name}"
97
+
98
+ # Generate CREATE TABLE SQL
99
+ schema = self._get_bigquery_schema(mock_table)
100
+ column_defs = []
101
+ for field in schema:
102
+ column_defs.append(f"`{field.name}` {field.field_type}")
103
+
104
+ columns_sql = ",\n ".join(column_defs)
105
+ create_sql = f"CREATE TABLE `{table_id}` (\n {columns_sql}\n)"
106
+
107
+ # Get insert SQL for the data
108
+ df = mock_table.to_dataframe()
109
+ if not df.empty:
110
+ # Generate INSERT statement
111
+ values_rows = []
112
+ for _, row in df.iterrows():
113
+ values = []
114
+ for col in df.columns:
115
+ value = row[col]
116
+ col_type = mock_table.get_column_types().get(col, str)
117
+ formatted_value = self.format_value_for_cte(value, col_type)
118
+ values.append(formatted_value)
119
+ values_rows.append(f"({', '.join(values)})")
120
+
121
+ values_sql = ",\n".join(values_rows)
122
+ insert_sql = f"INSERT INTO `{table_id}` VALUES\n{values_sql}"
123
+ full_sql = f"{create_sql};\n\n{insert_sql};"
124
+ else:
125
+ full_sql = create_sql + ";"
126
+
127
+ # Actually create the table
128
+ table = bigquery.Table(table_id, schema=schema)
129
+ table = self.client.create_table(table)
130
+
131
+ # Insert data if any
132
+ if not df.empty:
133
+ job_config = bigquery.LoadJobConfig()
134
+ job = self.client.load_table_from_dataframe(df, table, job_config=job_config)
135
+ job.result()
136
+
137
+ return table_id, full_sql
138
+
91
139
  def cleanup_temp_tables(self, table_names: List[str]) -> None:
92
140
  """Delete temporary tables."""
93
141
  for table_name in table_names:
@@ -130,7 +178,19 @@ class BigQueryAdapter(DatabaseAdapter):
130
178
  if non_none_types:
131
179
  col_type = non_none_types[0]
132
180
 
133
- bq_type = type_mapping.get(col_type, bigquery.enums.SqlTypeNames.STRING)
134
- schema.append(bigquery.SchemaField(col_name, bq_type))
181
+ # Handle List/Array types
182
+ if hasattr(col_type, "__origin__") and col_type.__origin__ is list:
183
+ # Get the element type from List[T]
184
+ element_type = get_args(col_type)[0] if get_args(col_type) else str
185
+
186
+ # Map element type to BigQuery type
187
+ element_bq_type = type_mapping.get(element_type, bigquery.enums.SqlTypeNames.STRING)
188
+
189
+ # Create field with mode=REPEATED for arrays
190
+ schema.append(bigquery.SchemaField(col_name, element_bq_type, mode="REPEATED"))
191
+ else:
192
+ # Handle scalar types
193
+ bq_type = type_mapping.get(col_type, bigquery.enums.SqlTypeNames.STRING)
194
+ schema.append(bigquery.SchemaField(col_name, bq_type))
135
195
 
136
196
  return schema
@@ -3,7 +3,7 @@
3
3
  import time
4
4
  from datetime import date, datetime
5
5
  from decimal import Decimal
6
- from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
6
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
7
7
 
8
8
 
9
9
  if TYPE_CHECKING:
@@ -122,6 +122,20 @@ class RedshiftAdapter(DatabaseAdapter):
122
122
  # Return just the table name, no schema prefix needed for temp tables
123
123
  return temp_table_name
124
124
 
125
+ def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
126
+ """Create a temporary table and return both table name and SQL."""
127
+ timestamp = int(time.time() * 1000)
128
+ temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
129
+
130
+ # Generate CTAS statement (CREATE TABLE AS SELECT)
131
+ ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
132
+
133
+ # Execute CTAS query
134
+ self.execute_query(ctas_sql)
135
+
136
+ # Return just the table name and the SQL
137
+ return temp_table_name, ctas_sql
138
+
125
139
  def cleanup_temp_tables(self, table_names: List[str]) -> None:
126
140
  """Clean up temporary tables."""
127
141
  # Redshift temporary tables are automatically dropped at the end of the session
@@ -4,7 +4,7 @@ import logging
4
4
  import time
5
5
  from datetime import date, datetime
6
6
  from decimal import Decimal
7
- from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
8
8
 
9
9
 
10
10
  if TYPE_CHECKING:
@@ -152,6 +152,27 @@ class SnowflakeAdapter(DatabaseAdapter):
152
152
 
153
153
  return qualified_table_name
154
154
 
155
+ def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
156
+ """Create a temporary table and return both table name and SQL."""
157
+ timestamp = int(time.time() * 1000)
158
+ temp_table_name = f"TEMP_{mock_table.get_table_name()}_{timestamp}"
159
+
160
+ # Use the adapter's configured database and schema for temporary tables
161
+ # This avoids permission issues with creating schemas in other databases
162
+ target_schema = self.schema
163
+
164
+ # For temporary tables, Snowflake doesn't support full database qualification
165
+ # Return schema.table format for temporary tables
166
+ qualified_table_name = f"{target_schema}.{temp_table_name}"
167
+
168
+ # Generate CTAS statement (CREATE TABLE AS SELECT)
169
+ ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table, target_schema)
170
+
171
+ # Execute CTAS query
172
+ self.execute_query(ctas_sql)
173
+
174
+ return qualified_table_name, ctas_sql
175
+
155
176
  def cleanup_temp_tables(self, table_names: List[str]) -> None:
156
177
  """Clean up temporary tables."""
157
178
  for full_table_name in table_names:
@@ -4,7 +4,7 @@ import logging
4
4
  import time
5
5
  from datetime import date, datetime
6
6
  from decimal import Decimal
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, get_args
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, get_args
8
8
 
9
9
 
10
10
  if TYPE_CHECKING:
@@ -132,6 +132,22 @@ class TrinoAdapter(DatabaseAdapter):
132
132
 
133
133
  return qualified_table_name
134
134
 
135
+ def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
136
+ """Create a temporary table and return both table name and SQL."""
137
+ timestamp = int(time.time() * 1000)
138
+ temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
139
+
140
+ # In Trino, tables are qualified with catalog and schema
141
+ qualified_table_name = f"{self.catalog}.{self.schema}.{temp_table_name}"
142
+
143
+ # Generate CTAS statement (CREATE TABLE AS SELECT)
144
+ ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
145
+
146
+ # Execute CTAS query
147
+ self.execute_query(ctas_sql)
148
+
149
+ return qualified_table_name, ctas_sql
150
+
135
151
  def cleanup_temp_tables(self, table_names: List[str]) -> None:
136
152
  """Clean up temporary tables."""
137
153
  for full_table_name in table_names:
@@ -1,5 +1,6 @@
1
1
  """Core SQL testing framework."""
2
2
 
3
+ import os
3
4
  from dataclasses import dataclass
4
5
  from typing import (
5
6
  TYPE_CHECKING,
@@ -27,6 +28,7 @@ from ._exceptions import (
27
28
  TypeConversionError,
28
29
  )
29
30
  from ._mock_table import BaseMockTable
31
+ from ._sql_logger import SQLLogger
30
32
 
31
33
 
32
34
  # Type for adapter types
@@ -34,6 +36,9 @@ AdapterType = Literal["bigquery", "athena", "redshift", "trino", "snowflake"]
34
36
 
35
37
  T = TypeVar("T")
36
38
 
39
+ # Global storage for SQL execution data (used by pytest plugin)
40
+ sql_test_execution_data: Dict[str, Dict[str, Any]] = {}
41
+
37
42
 
38
43
  @dataclass
39
44
  class SQLTestCase(Generic[T]):
@@ -48,6 +53,7 @@ class SQLTestCase(Generic[T]):
48
53
  use_physical_tables: bool = False
49
54
  description: Optional[str] = None
50
55
  adapter_type: Optional[AdapterType] = None
56
+ log_sql: Optional[bool] = None
51
57
  # Backward compatibility
52
58
  execution_database: Optional[str] = None
53
59
 
@@ -84,21 +90,34 @@ class SQLTestCase(Generic[T]):
84
90
  class SQLTestFramework:
85
91
  """Main framework for executing SQL tests."""
86
92
 
87
- def __init__(self, adapter: DatabaseAdapter) -> None:
93
+ def __init__(self, adapter: DatabaseAdapter, sql_logger: Optional[SQLLogger] = None) -> None:
88
94
  self.adapter = adapter
89
95
  self.type_converter = self.adapter.get_type_converter()
90
96
  self.temp_tables: List[str] = []
97
+ self.sql_logger = sql_logger or SQLLogger()
91
98
 
92
- def run_test(self, test_case: SQLTestCase[T]) -> List[T]:
99
+ def run_test(
100
+ self, test_case: SQLTestCase[T], test_context: Optional[Dict[str, Any]] = None
101
+ ) -> List[T]:
93
102
  """
94
103
  Execute a test case and return deserialized results.
95
104
 
96
105
  Args:
97
106
  test_case: The test case to execute
107
+ test_context: Optional context dictionary with test metadata
98
108
 
99
109
  Returns:
100
110
  List of result objects of type test_case.result_class
101
111
  """
112
+ import time
113
+
114
+ # Track execution time
115
+ start_time = time.time()
116
+ final_query = ""
117
+ error_message = None
118
+ row_count = None
119
+ temp_table_queries: List[str] = [] # Track temp table creation queries
120
+
102
121
  try:
103
122
  # Validate required fields
104
123
  if test_case.mock_tables is None:
@@ -130,7 +149,7 @@ class SQLTestFramework:
130
149
  if test_case.use_physical_tables:
131
150
  # Create physical temporary tables
132
151
  final_query = self._execute_with_physical_tables(
133
- test_case.query, table_mapping, test_case.mock_tables
152
+ test_case.query, table_mapping, test_case.mock_tables, temp_table_queries
134
153
  )
135
154
  else:
136
155
  # Generate query with CTEs
@@ -150,8 +169,134 @@ class SQLTestFramework:
150
169
  # Execute query
151
170
  result_df = self.adapter.execute_query(final_query)
152
171
 
172
+ # Track row count
173
+ row_count = len(result_df) if result_df is not None else 0
174
+
153
175
  # Convert results to typed objects
154
- return self._deserialize_results(result_df, test_case.result_class)
176
+ results = self._deserialize_results(result_df, test_case.result_class)
177
+
178
+ # Log SQL if enabled (success case)
179
+ execution_time = time.time() - start_time
180
+
181
+ # Store execution data for potential logging on test failure
182
+ if test_context:
183
+ test_id = test_context.get("test_id")
184
+ if test_id:
185
+ sql_test_execution_data[test_id] = {
186
+ "sql": final_query,
187
+ "test_name": test_context.get("test_name", "unknown_test"),
188
+ "test_class": test_context.get("test_class"),
189
+ "test_file": test_context.get("test_file"),
190
+ "metadata": {
191
+ "query": test_case.query,
192
+ "default_namespace": test_case.default_namespace,
193
+ "mock_tables": test_case.mock_tables,
194
+ "adapter_type": self.adapter.__class__.__name__.replace(
195
+ "Adapter", ""
196
+ ).lower(),
197
+ "use_physical_tables": test_case.use_physical_tables,
198
+ "execution_time": execution_time,
199
+ "row_count": row_count,
200
+ "error": None,
201
+ "temp_table_queries": temp_table_queries,
202
+ },
203
+ "sql_logger": self.sql_logger,
204
+ "log_sql": test_case.log_sql,
205
+ }
206
+
207
+ if self.sql_logger.should_log(test_case.log_sql):
208
+ # Get test context info
209
+ test_name = (
210
+ test_context.get("test_name", "unknown_test")
211
+ if test_context
212
+ else "unknown_test"
213
+ )
214
+ test_class = test_context.get("test_class") if test_context else None
215
+ test_file = test_context.get("test_file") if test_context else None
216
+
217
+ metadata = {
218
+ "query": test_case.query,
219
+ "default_namespace": test_case.default_namespace,
220
+ "mock_tables": test_case.mock_tables,
221
+ "adapter_type": self.adapter.get_sqlglot_dialect(),
222
+ "adapter_name": self.adapter.__class__.__name__.replace("Adapter", "").lower(),
223
+ "use_physical_tables": test_case.use_physical_tables,
224
+ "execution_time": execution_time,
225
+ "row_count": row_count,
226
+ "error": None,
227
+ "temp_table_queries": temp_table_queries,
228
+ }
229
+
230
+ # Log SQL immediately
231
+ log_path = self.sql_logger.log_sql(
232
+ sql=final_query,
233
+ test_name=test_name,
234
+ test_class=test_class,
235
+ test_file=test_file,
236
+ failed=False,
237
+ metadata=metadata,
238
+ )
239
+
240
+ # Print log location if environment variable is set
241
+ if os.environ.get("SQL_TEST_LOG_ALL", "").lower() in ("true", "1", "yes"):
242
+ import sys
243
+
244
+ print(f"\nSQL logged to: file://{log_path}", file=sys.stderr) # noqa: T201
245
+ sys.stderr.flush()
246
+
247
+ return results
248
+
249
+ except Exception as e:
250
+ # Store exception information for potential logging by pytest hook
251
+ execution_time = time.time() - start_time
252
+
253
+ # Capture full error details including traceback
254
+ import traceback
255
+
256
+ error_message = str(e)
257
+ error_traceback = traceback.format_exc()
258
+
259
+ # Store execution data for pytest hook to potentially log
260
+ if test_context and test_case.log_sql is not False:
261
+ test_id = test_context.get("test_id")
262
+ if test_id:
263
+ # Update the execution data with error information
264
+ if test_id in sql_test_execution_data:
265
+ sql_test_execution_data[test_id]["metadata"]["error"] = error_message
266
+ sql_test_execution_data[test_id]["metadata"]["error_traceback"] = (
267
+ error_traceback
268
+ )
269
+ sql_test_execution_data[test_id]["metadata"]["execution_time"] = (
270
+ execution_time
271
+ )
272
+ sql_test_execution_data[test_id]["metadata"]["row_count"] = row_count
273
+ else:
274
+ # If we haven't stored data yet (error happened early), store it now
275
+ sql_test_execution_data[test_id] = {
276
+ "sql": final_query if "final_query" in locals() else test_case.query,
277
+ "test_name": test_context.get("test_name", "unknown_test"),
278
+ "test_class": test_context.get("test_class"),
279
+ "test_file": test_context.get("test_file"),
280
+ "metadata": {
281
+ "query": test_case.query,
282
+ "default_namespace": test_case.default_namespace,
283
+ "mock_tables": test_case.mock_tables,
284
+ "adapter_type": self.adapter.get_sqlglot_dialect(),
285
+ "adapter_name": self.adapter.__class__.__name__.replace(
286
+ "Adapter", ""
287
+ ).lower(),
288
+ "use_physical_tables": test_case.use_physical_tables,
289
+ "execution_time": execution_time,
290
+ "row_count": row_count,
291
+ "error": error_message,
292
+ "error_traceback": error_traceback,
293
+ "temp_table_queries": temp_table_queries,
294
+ },
295
+ "sql_logger": self.sql_logger,
296
+ "log_sql": test_case.log_sql,
297
+ }
298
+
299
+ raise
155
300
 
156
301
  finally:
157
302
  # Cleanup any temporary tables
@@ -439,15 +584,48 @@ class SQLTestFramework:
439
584
  query: str,
440
585
  table_mapping: Dict[str, BaseMockTable],
441
586
  mock_tables: List[BaseMockTable],
587
+ temp_table_queries: List[str],
442
588
  ) -> str:
443
589
  """Execute query using physical temporary tables."""
444
590
  # Create physical tables
445
591
  replacement_mapping = {}
446
592
 
447
593
  for original_name, mock_table in table_mapping.items():
448
- temp_table_name = self.adapter.create_temp_table(mock_table)
449
- self.temp_tables.append(temp_table_name)
450
- replacement_mapping[original_name] = temp_table_name
594
+ try:
595
+ # Check if adapter has method to get temp table SQL
596
+ if hasattr(self.adapter, "create_temp_table_with_sql"):
597
+ temp_table_name, create_sql = self.adapter.create_temp_table_with_sql(
598
+ mock_table
599
+ )
600
+ temp_table_queries.append(create_sql)
601
+ else:
602
+ temp_table_name = self.adapter.create_temp_table(mock_table)
603
+ # Try to generate approximate SQL for logging
604
+ temp_table_queries.append(
605
+ f"-- CREATE TEMP TABLE {temp_table_name} (SQL not captured)"
606
+ )
607
+
608
+ self.temp_tables.append(temp_table_name)
609
+ replacement_mapping[original_name] = temp_table_name
610
+ except Exception:
611
+ # If table creation fails, still try to capture the SQL for debugging
612
+ if hasattr(self.adapter, "create_temp_table_with_sql") and hasattr(
613
+ mock_table, "get_table_name"
614
+ ):
615
+ try:
616
+ temp_table_name, ctas_sql = self.adapter.create_temp_table_with_sql(
617
+ mock_table
618
+ )
619
+ temp_table_queries.append(ctas_sql)
620
+ replacement_mapping[original_name] = temp_table_name
621
+ except Exception:
622
+ # If even SQL generation fails, add a placeholder
623
+ temp_table_queries.append(
624
+ f"-- CREATE TEMP TABLE for {original_name} (SQL generation failed)"
625
+ )
626
+ replacement_mapping[original_name] = f"temp_{original_name}_failed"
627
+ # Re-raise the original exception
628
+ raise
451
629
 
452
630
  # Replace table names and return modified query
453
631
  return self._replace_table_names_in_query(query, replacement_mapping)
@@ -10,7 +10,7 @@ import pytest
10
10
  from _pytest.nodes import Item
11
11
 
12
12
  from ._adapters.base import DatabaseAdapter
13
- from ._core import AdapterType, SQLTestCase, SQLTestFramework
13
+ from ._core import AdapterType, SQLTestCase, SQLTestFramework, sql_test_execution_data
14
14
  from ._mock_table import BaseMockTable
15
15
 
16
16
 
@@ -335,12 +335,16 @@ class SQLTestDecorator:
335
335
  # Global instance
336
336
  _sql_test_decorator = SQLTestDecorator()
337
337
 
338
+ # Global SQL execution context for logging
339
+ _sql_execution_context: Dict[str, Any] = {}
340
+
338
341
 
339
342
  def sql_test(
340
343
  mock_tables: Optional[List[BaseMockTable]] = None,
341
344
  result_class: Optional[Type[T]] = None,
342
345
  use_physical_tables: Optional[bool] = None,
343
346
  adapter_type: Optional[AdapterType] = None,
347
+ log_sql: Optional[bool] = None,
344
348
  ) -> Callable[[Callable[[], SQLTestCase[T]]], Callable[[], List[T]]]:
345
349
  """
346
350
  Decorator to mark a function as a SQL test.
@@ -360,6 +364,8 @@ def sql_test(
360
364
  (e.g., 'bigquery', 'athena').
361
365
  If provided, overrides adapter_type in SQLTestCase and uses config
362
366
  from [sql_testing.{adapter_type}] section.
367
+ log_sql: Optional flag to log the generated SQL to a file.
368
+ If provided, overrides log_sql in SQLTestCase.
363
369
  """
364
370
 
365
371
  def decorator(func: Callable[[], SQLTestCase[T]]) -> Callable[[], List[T]]:
@@ -396,15 +402,54 @@ def sql_test(
396
402
  if adapter_type is not None:
397
403
  test_case.adapter_type = adapter_type
398
404
 
405
+ if log_sql is not None:
406
+ test_case.log_sql = log_sql
407
+
399
408
  # Get framework and execute test
400
409
  framework = _sql_test_decorator.get_framework(test_case.adapter_type)
401
- results: List[T] = framework.run_test(test_case)
410
+
411
+ # Create test context for logging
412
+ test_context = {}
413
+
414
+ # Try to get test metadata from the current pytest context
415
+ import inspect
416
+
417
+ frame = inspect.currentframe()
418
+ while frame:
419
+ frame_locals = frame.f_locals
420
+ if "item" in frame_locals and hasattr(frame_locals["item"], "name"):
421
+ item = frame_locals["item"]
422
+ test_context["test_name"] = item.name
423
+ test_context["test_class"] = item.cls.__name__ if item.cls else None
424
+ test_context["test_file"] = (
425
+ str(item.fspath) if hasattr(item, "fspath") else None
426
+ )
427
+ # Create a unique test ID
428
+ test_context["test_id"] = str(id(item))
429
+ break
430
+ frame = frame.f_back
431
+
432
+ # If we couldn't get test context from stack, try to get it from function name
433
+ if not test_context:
434
+ test_context["test_name"] = func.__name__
435
+ test_context["test_file"] = inspect.getfile(func) if func is not None else None
436
+ # Create a unique test ID
437
+ test_context["test_id"] = f"{func.__name__}_{id(func)}"
438
+
439
+ results: List[T] = framework.run_test(test_case, test_context)
402
440
 
403
441
  return results
404
442
 
405
443
  # Mark function as SQL test
406
444
  wrapper._sql_test_decorated = True # type: ignore
407
445
  wrapper._original_func = func # type: ignore
446
+ wrapper._decorator_params = { # type: ignore
447
+ "mock_tables": mock_tables,
448
+ "result_class": result_class,
449
+ "use_physical_tables": use_physical_tables,
450
+ "adapter_type": adapter_type,
451
+ "log_sql": log_sql,
452
+ }
408
453
 
409
454
  return wrapper
410
455
 
@@ -449,3 +494,52 @@ def pytest_runtest_call(item: Item) -> None:
449
494
  else:
450
495
  # Use default pytest execution
451
496
  item.runtest()
497
+
498
+
499
+ def pytest_runtest_makereport(item: Item, call: Any) -> None:
500
+ """Hook to log SQL when tests fail (including assertion failures)."""
501
+ # We want to log after the test call phase
502
+ if call.when == "call":
503
+ test_id = str(id(item))
504
+
505
+ if call.excinfo is not None:
506
+ # Test failed - check if we have SQL execution data for this test
507
+ if test_id in sql_test_execution_data:
508
+ data = sql_test_execution_data[test_id]
509
+ sql_logger = data["sql_logger"]
510
+ log_sql = data.get("log_sql")
511
+
512
+ # Only log if log_sql is not False
513
+ if log_sql is not False:
514
+ # Capture the assertion error details
515
+ import traceback
516
+
517
+ metadata = data["metadata"].copy()
518
+ # Update error info with the actual pytest error
519
+ # (might be different from stored error)
520
+ metadata["error"] = str(call.excinfo.value)
521
+ metadata["error_traceback"] = "".join(
522
+ traceback.format_exception(
523
+ call.excinfo.type, call.excinfo.value, call.excinfo.tb
524
+ )
525
+ )
526
+
527
+ # Log the SQL
528
+ log_path = sql_logger.log_sql(
529
+ sql=data["sql"],
530
+ test_name=data["test_name"],
531
+ test_class=data["test_class"],
532
+ test_file=data["test_file"],
533
+ failed=True,
534
+ metadata=metadata,
535
+ )
536
+
537
+ # Print log location
538
+ import sys
539
+
540
+ print(f"\nSQL logged to: file://{log_path}", file=sys.stderr) # noqa: T201
541
+ sys.stderr.flush()
542
+
543
+ # Clean up the stored data after the test (whether it passed or failed)
544
+ if test_id in sql_test_execution_data:
545
+ del sql_test_execution_data[test_id]
@@ -0,0 +1,385 @@
1
+ """SQL logging functionality for test cases."""
2
+
3
+ import os
4
+ import re
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ from sqlglot import parse_one
10
+
11
+ from ._mock_table import BaseMockTable
12
+
13
+
14
class SQLLogger:
    """Handles SQL logging for test cases.

    Writes each logged query to a ``.sql`` file inside a per-session run
    directory (``<log_dir>/runid_<timestamp>/``), prefixed with a metadata
    header rendered as SQL comments so the file remains directly runnable.
    """

    # Class variables so a single run directory / run ID is shared by every
    # SQLLogger instance created during one test session.
    _run_directory: Optional[Path] = None
    _run_id: Optional[str] = None

    def __init__(self, log_dir: Optional[str] = None) -> None:
        """Initialize SQL logger.

        Args:
            log_dir: Directory to store SQL log files. If None, uses .sql_logs in project root.

        Resolution order when ``log_dir`` is None: the SQL_TEST_LOG_DIR
        environment variable, then ``<project root>/.sql_logs`` (project
        root found by walking up from the CWD), then ``./.sql_logs``.
        Note: the base log directory is created eagerly here, even if
        nothing ends up being logged this session.
        """
        if log_dir is None:
            # Check environment variable first
            env_log_dir = os.environ.get("SQL_TEST_LOG_DIR")
            if env_log_dir:
                self.log_dir = Path(env_log_dir)
            else:
                # Try to find the project root by looking for specific project files
                current_path = Path.cwd()

                # Look for definitive project root markers (in order of preference)
                # These are files that typically only exist at project root
                root_markers = ["pyproject.toml", "setup.py", "setup.cfg", "tox.ini"]

                # Search up the directory tree for project root
                project_root = None
                search_path = current_path

                # Loop terminates at the filesystem root, where a path
                # equals its own parent.
                while search_path != search_path.parent:
                    # Check for root markers
                    if any((search_path / marker).exists() for marker in root_markers):
                        project_root = search_path
                        break

                    # Also check for .git directory (but not .git file which could be a submodule)
                    if (search_path / ".git").is_dir():
                        project_root = search_path
                        break

                    search_path = search_path.parent

                # If we found a project root, use it; otherwise fall back to current directory
                if project_root:
                    self.log_dir = project_root / ".sql_logs"
                else:
                    # Fall back to current directory if project root not found
                    self.log_dir = Path(".sql_logs")
        else:
            self.log_dir = Path(log_dir)

        self.log_dir.mkdir(parents=True, exist_ok=True)
        # Absolute paths of files written by this instance during the session.
        self._logged_files: List[str] = []

    def _ensure_run_directory(self) -> Path:
        """Ensure run directory exists, creating it if necessary.

        The run directory is created lazily on first log and is shared
        (class-level) by all SQLLogger instances for the session.

        Returns:
            Path to the run directory
        """
        # Create run directory if not already created for this session
        if SQLLogger._run_directory is None:
            # Generate run ID with timestamp
            timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
            SQLLogger._run_id = f"runid_{timestamp}"
            SQLLogger._run_directory = self.log_dir / SQLLogger._run_id
            SQLLogger._run_directory.mkdir(parents=True, exist_ok=True)
        return SQLLogger._run_directory

    def should_log(self, log_sql: Optional[bool] = None) -> bool:
        """Determine if SQL should be logged based on environment and parameters.

        Args:
            log_sql: Explicit parameter from test case

        Returns:
            True if SQL should be logged
        """
        # If explicitly set in test case, use that (it overrides the env var)
        if log_sql is not None:
            return log_sql

        # Check environment variable (any of "true"/"1"/"yes", case-insensitive)
        return os.environ.get("SQL_TEST_LOG_ALL", "").lower() in ("true", "1", "yes")

    def generate_filename(
        self,
        test_name: str,
        test_class: Optional[str] = None,
        test_file: Optional[str] = None,
        failed: bool = False,
    ) -> str:
        """Generate a unique filename for the SQL log.

        Format: ``<file>__<class>__<test>[__FAILED]__<timestamp>.sql``,
        with filesystem-unsafe characters replaced by underscores.

        Args:
            test_name: Name of the test function
            test_class: Name of the test class (if any)
            test_file: Path to the test file
            failed: Whether the test failed

        Returns:
            Generated filename
        """
        # Clean test name for filesystem (including square brackets,
        # which appear in parametrized pytest test IDs)
        clean_name = re.sub(r'[<>:"/\\|?*\[\]]', "_", test_name)

        # Build filename components
        components = []

        # Add test file name (without path and extension)
        if test_file:
            file_base = Path(test_file).stem
            components.append(file_base)

        # Add class name if present
        if test_class:
            clean_class = re.sub(r'[<>:"/\\|?*\[\]]', "_", test_class)
            components.append(clean_class)

        # Add test name
        components.append(clean_name)

        # Add status indicator
        if failed:
            components.append("FAILED")

        # Add timestamp for uniqueness (microseconds truncated to milliseconds)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]  # Milliseconds
        components.append(timestamp)

        # Join with double underscore for clarity
        filename = "__".join(components) + ".sql"

        return filename

    def format_sql(self, sql: str, dialect: Optional[str] = None) -> str:
        """Format SQL query for better readability.

        Best-effort: any parse/format failure falls back to returning the
        input unchanged, so logging never breaks on unusual SQL.

        Args:
            sql: SQL query to format
            dialect: SQL dialect (e.g., 'bigquery', 'athena')

        Returns:
            Formatted SQL
        """
        try:
            # Parse and format using sqlglot
            parsed = parse_one(sql, dialect=dialect)
            # IMPORTANT: Pass dialect to sql() to preserve dialect-specific syntax
            formatted = parsed.sql(pretty=True, pad=2, dialect=dialect)
            return formatted
        except Exception:
            # If formatting fails, return original
            return sql

    def create_metadata_header(
        self,
        test_name: str,
        test_class: Optional[str] = None,
        test_file: Optional[str] = None,
        query: str = "",
        default_namespace: Optional[str] = None,
        mock_tables: Optional[List[BaseMockTable]] = None,
        adapter_type: Optional[str] = None,
        use_physical_tables: bool = False,
        execution_time: Optional[float] = None,
        row_count: Optional[int] = None,
        error: Optional[str] = None,
        error_traceback: Optional[str] = None,
        temp_table_queries: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Create a metadata header for the SQL file.

        Every line is a SQL ``--`` comment (plus blank separator lines),
        so the header can be prepended to the query without invalidating
        it. Unknown metadata keys are accepted via **kwargs; only
        ``adapter_name`` is currently read from them.

        Returns:
            Formatted metadata header as SQL comments
        """
        lines = [
            "-- SQL Test Case Log",
            "-- " + "=" * 78,
            f"-- Generated: {datetime.now().isoformat()}",
            f"-- Run ID: {SQLLogger._run_id}",
            f"-- Test Name: {test_name}",
        ]

        if test_class:
            lines.append(f"-- Test Class: {test_class}")

        if test_file:
            lines.append(f"-- Test File: {test_file}")

        if adapter_type:
            lines.append(f"-- Adapter: {adapter_type}")

        # Show adapter name if different from sqlglot dialect
        adapter_name = kwargs.get("adapter_name")
        if adapter_name and adapter_name != adapter_type:
            lines.append(f"-- Database: {adapter_name}")

        if default_namespace:
            lines.append(f"-- Default Namespace: {default_namespace}")

        lines.append(f"-- Use Physical Tables: {use_physical_tables}")

        if execution_time is not None:
            lines.append(f"-- Execution Time: {execution_time:.3f} seconds")

        if row_count is not None:
            lines.append(f"-- Result Rows: {row_count}")

        if error:
            lines.extend(
                [
                    "-- Status: FAILED",
                    "-- Error:",
                ]
            )
            for line in error.strip().split("\n"):
                lines.append(f"-- {line}")

            # Add full error traceback if available
            if error_traceback:
                lines.extend(
                    [
                        "",
                        "-- Full Error Details:",
                        "-- " + "-" * 78,
                    ]
                )
                # Add each line of the traceback as a SQL comment
                for line in error_traceback.strip().split("\n"):
                    lines.append(f"-- {line}")
        else:
            lines.append("-- Status: SUCCESS")

        # Add mock tables information
        if mock_tables:
            lines.extend(
                [
                    "",
                    "-- Mock Tables:",
                    "-- " + "-" * 78,
                ]
            )
            for table in mock_tables:
                lines.append(f"-- Table: {table.get_table_name()}")
                # Get row count from data
                if hasattr(table, "data") and table.data:
                    lines.append(f"-- Rows: {len(table.data)}")
                # Get column names from first row or column types
                if hasattr(table, "get_column_types"):
                    columns = list(table.get_column_types().keys())
                    if columns:
                        lines.append(f"-- Columns: {', '.join(columns)}")

        # Add original query
        lines.extend(
            [
                "",
                "-- Original Query:",
                "-- " + "-" * 78,
            ]
        )
        # Comment out each line of the original query
        for line in query.split("\n"):
            lines.append(f"-- {line}")

        # Add temp table queries if physical tables were used.
        # These are emitted as live (uncommented) SQL so the log file can
        # recreate the physical tables when replayed.
        if use_physical_tables and temp_table_queries:
            lines.extend(
                [
                    "",
                    "-- Temporary Table Creation Queries:",
                    "-- " + "-" * 78,
                    "",
                ]
            )
            for i, temp_query in enumerate(temp_table_queries, 1):
                lines.append(f"-- Query {i}:")
                lines.append("")
                # Format the temp table SQL
                formatted_temp_sql = self.format_sql(temp_query, dialect=adapter_type)
                lines.append(formatted_temp_sql)
                lines.append("")

        lines.extend(
            [
                "",
                "-- Transformed Query:",
                "-- " + "=" * 78,
                "",
            ]
        )

        return "\n".join(lines)

    def log_sql(
        self,
        sql: str,
        test_name: str,
        test_class: Optional[str] = None,
        test_file: Optional[str] = None,
        failed: bool = False,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Log SQL to a file and return the file path.

        Args:
            sql: The transformed SQL query to log
            test_name: Name of the test
            test_class: Test class name
            test_file: Test file path
            failed: Whether the test failed
            metadata: Additional metadata to include; forwarded as keyword
                arguments to create_metadata_header. The "adapter_type"
                entry (if any) is also used as the sqlglot dialect when
                formatting the SQL.

        Returns:
            Path to the created SQL file
        """
        # Generate filename
        filename = self.generate_filename(test_name, test_class, test_file, failed)

        # Ensure run directory exists (lazy creation)
        run_directory = self._ensure_run_directory()
        filepath = run_directory / filename

        # Prepare metadata
        if metadata is None:
            metadata = {}

        # Create header
        header = self.create_metadata_header(
            test_name=test_name, test_class=test_class, test_file=test_file, **metadata
        )

        # Format SQL
        dialect = metadata.get("adapter_type")
        formatted_sql = self.format_sql(sql, dialect)

        # Write to file (header already ends with a newline separator)
        content = header + formatted_sql
        filepath.write_text(content, encoding="utf-8")

        # Track logged file
        self._logged_files.append(str(filepath))

        # Return absolute path for clickable URLs
        return str(filepath.absolute())

    def get_logged_files(self) -> List[str]:
        """Get list of files logged in this session."""
        return self._logged_files.copy()

    def clear_logged_files(self) -> None:
        """Clear the list of logged files."""
        self._logged_files = []

    @classmethod
    def get_run_directory(cls) -> Optional[Path]:
        """Get the current run directory."""
        return cls._run_directory

    @classmethod
    def get_run_id(cls) -> Optional[str]:
        """Get the current run ID."""
        return cls._run_id

    @classmethod
    def reset_run_directory(cls) -> None:
        """Reset the run directory (useful for testing)."""
        cls._run_directory = None
        cls._run_id = None
@@ -82,6 +82,7 @@ def format_sql_value(value: Any, column_type: Type, dialect: str = "standard") -
82
82
  """
83
83
  from datetime import date, datetime
84
84
  from decimal import Decimal
85
+ from typing import get_args
85
86
 
86
87
  import pandas as pd
87
88
 
@@ -89,6 +90,42 @@ def format_sql_value(value: Any, column_type: Type, dialect: str = "standard") -
89
90
  # Note: pd.isna() doesn't work on lists/arrays, so check for None first
90
91
  # and only use pd.isna() on scalar values
91
92
  if value is None or (not isinstance(value, (list, tuple)) and pd.isna(value)):
93
+ # Check if column_type is a List type
94
+ if hasattr(column_type, "__origin__") and column_type.__origin__ is list:
95
+ # Get the element type from List[T]
96
+ element_type = get_args(column_type)[0] if get_args(column_type) else str
97
+
98
+ if dialect in ("athena", "trino"):
99
+ # Map Python types to SQL types for array elements
100
+ if element_type == Decimal:
101
+ sql_element_type = "DECIMAL(38,9)"
102
+ elif element_type is int:
103
+ sql_element_type = "INTEGER" if dialect == "athena" else "BIGINT"
104
+ elif element_type is float:
105
+ sql_element_type = "DOUBLE"
106
+ elif element_type is bool:
107
+ sql_element_type = "BOOLEAN"
108
+ elif element_type is date:
109
+ sql_element_type = "DATE"
110
+ elif element_type == datetime:
111
+ sql_element_type = "TIMESTAMP"
112
+ else:
113
+ sql_element_type = "VARCHAR"
114
+
115
+ return f"CAST(NULL AS ARRAY({sql_element_type}))"
116
+ elif dialect == "bigquery":
117
+ # BigQuery doesn't need explicit NULL array casting
118
+ return "NULL"
119
+ elif dialect == "redshift":
120
+ # Redshift SUPER type handles NULL arrays
121
+ return "NULL::SUPER"
122
+ elif dialect == "snowflake":
123
+ # Snowflake VARIANT type handles NULL arrays
124
+ return "NULL::VARIANT"
125
+ else:
126
+ return "NULL"
127
+
128
+ # Handle non-array NULL values
92
129
  if dialect == "redshift":
93
130
  # Redshift needs type-specific NULL casting
94
131
  if column_type == Decimal: