sql-testing-library 0.6.0__tar.gz → 0.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/CHANGELOG.md +12 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/PKG-INFO +2 -1
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/README.md +1 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/pyproject.toml +1 -1
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/athena.py +15 -1
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/base.py +10 -1
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/bigquery.py +63 -3
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/redshift.py +15 -1
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/snowflake.py +22 -1
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/trino.py +17 -1
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_core.py +185 -7
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_pytest_plugin.py +96 -2
- sql_testing_library-0.7.1/src/sql_testing_library/_sql_logger.py +385 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_sql_utils.py +37 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/LICENSE +0 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/__init__.py +0 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/__init__.py +0 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_exceptions.py +0 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_mock_table.py +0 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_types.py +0 -0
- {sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/py.typed +0 -0
|
@@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
|
|
|
5
5
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|
6
6
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|
7
7
|
|
|
8
|
+
## 0.7.1 (2025-06-06)
|
|
9
|
+
|
|
10
|
+
### Fix
|
|
11
|
+
|
|
12
|
+
- **array**: array handling logic + sql logging improvement (#95)
|
|
13
|
+
|
|
14
|
+
## 0.7.0 (2025-06-06)
|
|
15
|
+
|
|
16
|
+
### Feat
|
|
17
|
+
|
|
18
|
+
- **sqllogging**: added support for logging sql logs for debugging failed tests (#94)
|
|
19
|
+
|
|
8
20
|
## 0.6.0 (2025-06-05)
|
|
9
21
|
|
|
10
22
|
### Feat
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: sql-testing-library
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.7.1
|
|
4
4
|
Summary: A powerful Python framework for unit testing SQL queries across BigQuery, Snowflake, Redshift, Athena, and Trino with mock data
|
|
5
5
|
License: MIT
|
|
6
6
|
Keywords: sql,testing,unit-testing,mock-data,database-testing,bigquery,snowflake,redshift,athena,trino,data-engineering,etl-testing,sql-validation,query-testing
|
|
@@ -108,6 +108,7 @@ For more details on our journey and the engineering challenges we solved, read t
|
|
|
108
108
|
- **CTE or Physical Tables**: Automatic fallback for query size limits
|
|
109
109
|
- **Type-Safe Results**: Deserialize results to Pydantic models
|
|
110
110
|
- **Pytest Integration**: Seamless testing with `@sql_test` decorator
|
|
111
|
+
- **SQL Logging**: Comprehensive SQL logging with formatted output, error traces, and temp table queries
|
|
111
112
|
|
|
112
113
|
## Data Types Support
|
|
113
114
|
|
|
@@ -51,6 +51,7 @@ For more details on our journey and the engineering challenges we solved, read t
|
|
|
51
51
|
- **CTE or Physical Tables**: Automatic fallback for query size limits
|
|
52
52
|
- **Type-Safe Results**: Deserialize results to Pydantic models
|
|
53
53
|
- **Pytest Integration**: Seamless testing with `@sql_test` decorator
|
|
54
|
+
- **SQL Logging**: Comprehensive SQL logging with formatted output, error traces, and temp table queries
|
|
54
55
|
|
|
55
56
|
## Data Types Support
|
|
56
57
|
|
|
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
|
|
|
4
4
|
|
|
5
5
|
[tool.poetry]
|
|
6
6
|
name = "sql-testing-library"
|
|
7
|
-
version = "0.
|
|
7
|
+
version = "0.7.1"
|
|
8
8
|
description = "A powerful Python framework for unit testing SQL queries across BigQuery, Snowflake, Redshift, Athena, and Trino with mock data"
|
|
9
9
|
authors = ["Gurmeet Saran <gurmeetx@gmail.com>", "Kushal Thakkar <kushal.thakkar@gmail.com>"]
|
|
10
10
|
maintainers = ["Gurmeet Saran <gurmeetx@gmail.com>", "Kushal Thakkar <kushal.thakkar@gmail.com>"]
|
{sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/athena.py
RENAMED
|
@@ -4,7 +4,7 @@ import logging
|
|
|
4
4
|
import time
|
|
5
5
|
from datetime import date, datetime
|
|
6
6
|
from decimal import Decimal
|
|
7
|
-
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
|
|
7
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
@@ -136,6 +136,20 @@ class AthenaAdapter(DatabaseAdapter):
|
|
|
136
136
|
|
|
137
137
|
return qualified_table_name
|
|
138
138
|
|
|
139
|
+
def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
|
|
140
|
+
"""Create a temporary table and return both table name and SQL."""
|
|
141
|
+
timestamp = int(time.time() * 1000)
|
|
142
|
+
temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
|
|
143
|
+
qualified_table_name = f"{self.database}.{temp_table_name}"
|
|
144
|
+
|
|
145
|
+
# Generate CTAS statement (CREATE TABLE AS SELECT)
|
|
146
|
+
ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
|
|
147
|
+
|
|
148
|
+
# Execute CTAS query
|
|
149
|
+
self.execute_query(ctas_sql)
|
|
150
|
+
|
|
151
|
+
return qualified_table_name, ctas_sql
|
|
152
|
+
|
|
139
153
|
def cleanup_temp_tables(self, table_names: List[str]) -> None:
|
|
140
154
|
"""Clean up temporary tables."""
|
|
141
155
|
for full_table_name in table_names:
|
{sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/base.py
RENAMED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Base database adapter interface."""
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import TYPE_CHECKING, Any, List, Optional
|
|
4
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
if TYPE_CHECKING:
|
|
@@ -30,6 +30,15 @@ class DatabaseAdapter(ABC):
|
|
|
30
30
|
"""Create a temporary table with mock data. Returns temp table name."""
|
|
31
31
|
pass
|
|
32
32
|
|
|
33
|
+
@abstractmethod
|
|
34
|
+
def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
|
|
35
|
+
"""Create a temporary table and return both table name and SQL.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
Tuple of (temp_table_name, create_table_sql)
|
|
39
|
+
"""
|
|
40
|
+
pass
|
|
41
|
+
|
|
33
42
|
@abstractmethod
|
|
34
43
|
def cleanup_temp_tables(self, table_names: List[str]) -> None:
|
|
35
44
|
"""Clean up temporary tables."""
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from datetime import date, datetime
|
|
5
5
|
from decimal import Decimal
|
|
6
|
-
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
|
|
6
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
@@ -88,6 +88,54 @@ class BigQueryAdapter(DatabaseAdapter):
|
|
|
88
88
|
|
|
89
89
|
return table_id
|
|
90
90
|
|
|
91
|
+
def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
|
|
92
|
+
"""Create temporary table and return both table name and SQL."""
|
|
93
|
+
import time
|
|
94
|
+
|
|
95
|
+
temp_table_name = f"temp_{mock_table.get_table_name()}_{int(time.time() * 1000)}"
|
|
96
|
+
table_id = f"{self.project_id}.{self.dataset_id}.{temp_table_name}"
|
|
97
|
+
|
|
98
|
+
# Generate CREATE TABLE SQL
|
|
99
|
+
schema = self._get_bigquery_schema(mock_table)
|
|
100
|
+
column_defs = []
|
|
101
|
+
for field in schema:
|
|
102
|
+
column_defs.append(f"`{field.name}` {field.field_type}")
|
|
103
|
+
|
|
104
|
+
columns_sql = ",\n ".join(column_defs)
|
|
105
|
+
create_sql = f"CREATE TABLE `{table_id}` (\n {columns_sql}\n)"
|
|
106
|
+
|
|
107
|
+
# Get insert SQL for the data
|
|
108
|
+
df = mock_table.to_dataframe()
|
|
109
|
+
if not df.empty:
|
|
110
|
+
# Generate INSERT statement
|
|
111
|
+
values_rows = []
|
|
112
|
+
for _, row in df.iterrows():
|
|
113
|
+
values = []
|
|
114
|
+
for col in df.columns:
|
|
115
|
+
value = row[col]
|
|
116
|
+
col_type = mock_table.get_column_types().get(col, str)
|
|
117
|
+
formatted_value = self.format_value_for_cte(value, col_type)
|
|
118
|
+
values.append(formatted_value)
|
|
119
|
+
values_rows.append(f"({', '.join(values)})")
|
|
120
|
+
|
|
121
|
+
values_sql = ",\n".join(values_rows)
|
|
122
|
+
insert_sql = f"INSERT INTO `{table_id}` VALUES\n{values_sql}"
|
|
123
|
+
full_sql = f"{create_sql};\n\n{insert_sql};"
|
|
124
|
+
else:
|
|
125
|
+
full_sql = create_sql + ";"
|
|
126
|
+
|
|
127
|
+
# Actually create the table
|
|
128
|
+
table = bigquery.Table(table_id, schema=schema)
|
|
129
|
+
table = self.client.create_table(table)
|
|
130
|
+
|
|
131
|
+
# Insert data if any
|
|
132
|
+
if not df.empty:
|
|
133
|
+
job_config = bigquery.LoadJobConfig()
|
|
134
|
+
job = self.client.load_table_from_dataframe(df, table, job_config=job_config)
|
|
135
|
+
job.result()
|
|
136
|
+
|
|
137
|
+
return table_id, full_sql
|
|
138
|
+
|
|
91
139
|
def cleanup_temp_tables(self, table_names: List[str]) -> None:
|
|
92
140
|
"""Delete temporary tables."""
|
|
93
141
|
for table_name in table_names:
|
|
@@ -130,7 +178,19 @@ class BigQueryAdapter(DatabaseAdapter):
|
|
|
130
178
|
if non_none_types:
|
|
131
179
|
col_type = non_none_types[0]
|
|
132
180
|
|
|
133
|
-
|
|
134
|
-
|
|
181
|
+
# Handle List/Array types
|
|
182
|
+
if hasattr(col_type, "__origin__") and col_type.__origin__ is list:
|
|
183
|
+
# Get the element type from List[T]
|
|
184
|
+
element_type = get_args(col_type)[0] if get_args(col_type) else str
|
|
185
|
+
|
|
186
|
+
# Map element type to BigQuery type
|
|
187
|
+
element_bq_type = type_mapping.get(element_type, bigquery.enums.SqlTypeNames.STRING)
|
|
188
|
+
|
|
189
|
+
# Create field with mode=REPEATED for arrays
|
|
190
|
+
schema.append(bigquery.SchemaField(col_name, element_bq_type, mode="REPEATED"))
|
|
191
|
+
else:
|
|
192
|
+
# Handle scalar types
|
|
193
|
+
bq_type = type_mapping.get(col_type, bigquery.enums.SqlTypeNames.STRING)
|
|
194
|
+
schema.append(bigquery.SchemaField(col_name, bq_type))
|
|
135
195
|
|
|
136
196
|
return schema
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import time
|
|
4
4
|
from datetime import date, datetime
|
|
5
5
|
from decimal import Decimal
|
|
6
|
-
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
|
|
6
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
@@ -122,6 +122,20 @@ class RedshiftAdapter(DatabaseAdapter):
|
|
|
122
122
|
# Return just the table name, no schema prefix needed for temp tables
|
|
123
123
|
return temp_table_name
|
|
124
124
|
|
|
125
|
+
def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
|
|
126
|
+
"""Create a temporary table and return both table name and SQL."""
|
|
127
|
+
timestamp = int(time.time() * 1000)
|
|
128
|
+
temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
|
|
129
|
+
|
|
130
|
+
# Generate CTAS statement (CREATE TABLE AS SELECT)
|
|
131
|
+
ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
|
|
132
|
+
|
|
133
|
+
# Execute CTAS query
|
|
134
|
+
self.execute_query(ctas_sql)
|
|
135
|
+
|
|
136
|
+
# Return just the table name and the SQL
|
|
137
|
+
return temp_table_name, ctas_sql
|
|
138
|
+
|
|
125
139
|
def cleanup_temp_tables(self, table_names: List[str]) -> None:
|
|
126
140
|
"""Clean up temporary tables."""
|
|
127
141
|
# Redshift temporary tables are automatically dropped at the end of the session
|
|
@@ -4,7 +4,7 @@ import logging
|
|
|
4
4
|
import time
|
|
5
5
|
from datetime import date, datetime
|
|
6
6
|
from decimal import Decimal
|
|
7
|
-
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, get_args
|
|
7
|
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union, get_args
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
@@ -152,6 +152,27 @@ class SnowflakeAdapter(DatabaseAdapter):
|
|
|
152
152
|
|
|
153
153
|
return qualified_table_name
|
|
154
154
|
|
|
155
|
+
def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
|
|
156
|
+
"""Create a temporary table and return both table name and SQL."""
|
|
157
|
+
timestamp = int(time.time() * 1000)
|
|
158
|
+
temp_table_name = f"TEMP_{mock_table.get_table_name()}_{timestamp}"
|
|
159
|
+
|
|
160
|
+
# Use the adapter's configured database and schema for temporary tables
|
|
161
|
+
# This avoids permission issues with creating schemas in other databases
|
|
162
|
+
target_schema = self.schema
|
|
163
|
+
|
|
164
|
+
# For temporary tables, Snowflake doesn't support full database qualification
|
|
165
|
+
# Return schema.table format for temporary tables
|
|
166
|
+
qualified_table_name = f"{target_schema}.{temp_table_name}"
|
|
167
|
+
|
|
168
|
+
# Generate CTAS statement (CREATE TABLE AS SELECT)
|
|
169
|
+
ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table, target_schema)
|
|
170
|
+
|
|
171
|
+
# Execute CTAS query
|
|
172
|
+
self.execute_query(ctas_sql)
|
|
173
|
+
|
|
174
|
+
return qualified_table_name, ctas_sql
|
|
175
|
+
|
|
155
176
|
def cleanup_temp_tables(self, table_names: List[str]) -> None:
|
|
156
177
|
"""Clean up temporary tables."""
|
|
157
178
|
for full_table_name in table_names:
|
{sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_adapters/trino.py
RENAMED
|
@@ -4,7 +4,7 @@ import logging
|
|
|
4
4
|
import time
|
|
5
5
|
from datetime import date, datetime
|
|
6
6
|
from decimal import Decimal
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, get_args
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, get_args
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
@@ -132,6 +132,22 @@ class TrinoAdapter(DatabaseAdapter):
|
|
|
132
132
|
|
|
133
133
|
return qualified_table_name
|
|
134
134
|
|
|
135
|
+
def create_temp_table_with_sql(self, mock_table: BaseMockTable) -> Tuple[str, str]:
|
|
136
|
+
"""Create a temporary table and return both table name and SQL."""
|
|
137
|
+
timestamp = int(time.time() * 1000)
|
|
138
|
+
temp_table_name = f"temp_{mock_table.get_table_name()}_{timestamp}"
|
|
139
|
+
|
|
140
|
+
# In Trino, tables are qualified with catalog and schema
|
|
141
|
+
qualified_table_name = f"{self.catalog}.{self.schema}.{temp_table_name}"
|
|
142
|
+
|
|
143
|
+
# Generate CTAS statement (CREATE TABLE AS SELECT)
|
|
144
|
+
ctas_sql = self._generate_ctas_sql(temp_table_name, mock_table)
|
|
145
|
+
|
|
146
|
+
# Execute CTAS query
|
|
147
|
+
self.execute_query(ctas_sql)
|
|
148
|
+
|
|
149
|
+
return qualified_table_name, ctas_sql
|
|
150
|
+
|
|
135
151
|
def cleanup_temp_tables(self, table_names: List[str]) -> None:
|
|
136
152
|
"""Clean up temporary tables."""
|
|
137
153
|
for full_table_name in table_names:
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Core SQL testing framework."""
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from typing import (
|
|
5
6
|
TYPE_CHECKING,
|
|
@@ -27,6 +28,7 @@ from ._exceptions import (
|
|
|
27
28
|
TypeConversionError,
|
|
28
29
|
)
|
|
29
30
|
from ._mock_table import BaseMockTable
|
|
31
|
+
from ._sql_logger import SQLLogger
|
|
30
32
|
|
|
31
33
|
|
|
32
34
|
# Type for adapter types
|
|
@@ -34,6 +36,9 @@ AdapterType = Literal["bigquery", "athena", "redshift", "trino", "snowflake"]
|
|
|
34
36
|
|
|
35
37
|
T = TypeVar("T")
|
|
36
38
|
|
|
39
|
+
# Global storage for SQL execution data (used by pytest plugin)
|
|
40
|
+
sql_test_execution_data: Dict[str, Dict[str, Any]] = {}
|
|
41
|
+
|
|
37
42
|
|
|
38
43
|
@dataclass
|
|
39
44
|
class SQLTestCase(Generic[T]):
|
|
@@ -48,6 +53,7 @@ class SQLTestCase(Generic[T]):
|
|
|
48
53
|
use_physical_tables: bool = False
|
|
49
54
|
description: Optional[str] = None
|
|
50
55
|
adapter_type: Optional[AdapterType] = None
|
|
56
|
+
log_sql: Optional[bool] = None
|
|
51
57
|
# Backward compatibility
|
|
52
58
|
execution_database: Optional[str] = None
|
|
53
59
|
|
|
@@ -84,21 +90,34 @@ class SQLTestCase(Generic[T]):
|
|
|
84
90
|
class SQLTestFramework:
|
|
85
91
|
"""Main framework for executing SQL tests."""
|
|
86
92
|
|
|
87
|
-
def __init__(self, adapter: DatabaseAdapter) -> None:
|
|
93
|
+
def __init__(self, adapter: DatabaseAdapter, sql_logger: Optional[SQLLogger] = None) -> None:
|
|
88
94
|
self.adapter = adapter
|
|
89
95
|
self.type_converter = self.adapter.get_type_converter()
|
|
90
96
|
self.temp_tables: List[str] = []
|
|
97
|
+
self.sql_logger = sql_logger or SQLLogger()
|
|
91
98
|
|
|
92
|
-
def run_test(
|
|
99
|
+
def run_test(
|
|
100
|
+
self, test_case: SQLTestCase[T], test_context: Optional[Dict[str, Any]] = None
|
|
101
|
+
) -> List[T]:
|
|
93
102
|
"""
|
|
94
103
|
Execute a test case and return deserialized results.
|
|
95
104
|
|
|
96
105
|
Args:
|
|
97
106
|
test_case: The test case to execute
|
|
107
|
+
test_context: Optional context dictionary with test metadata
|
|
98
108
|
|
|
99
109
|
Returns:
|
|
100
110
|
List of result objects of type test_case.result_class
|
|
101
111
|
"""
|
|
112
|
+
import time
|
|
113
|
+
|
|
114
|
+
# Track execution time
|
|
115
|
+
start_time = time.time()
|
|
116
|
+
final_query = ""
|
|
117
|
+
error_message = None
|
|
118
|
+
row_count = None
|
|
119
|
+
temp_table_queries: List[str] = [] # Track temp table creation queries
|
|
120
|
+
|
|
102
121
|
try:
|
|
103
122
|
# Validate required fields
|
|
104
123
|
if test_case.mock_tables is None:
|
|
@@ -130,7 +149,7 @@ class SQLTestFramework:
|
|
|
130
149
|
if test_case.use_physical_tables:
|
|
131
150
|
# Create physical temporary tables
|
|
132
151
|
final_query = self._execute_with_physical_tables(
|
|
133
|
-
test_case.query, table_mapping, test_case.mock_tables
|
|
152
|
+
test_case.query, table_mapping, test_case.mock_tables, temp_table_queries
|
|
134
153
|
)
|
|
135
154
|
else:
|
|
136
155
|
# Generate query with CTEs
|
|
@@ -150,8 +169,134 @@ class SQLTestFramework:
|
|
|
150
169
|
# Execute query
|
|
151
170
|
result_df = self.adapter.execute_query(final_query)
|
|
152
171
|
|
|
172
|
+
# Track row count
|
|
173
|
+
row_count = len(result_df) if result_df is not None else 0
|
|
174
|
+
|
|
153
175
|
# Convert results to typed objects
|
|
154
|
-
|
|
176
|
+
results = self._deserialize_results(result_df, test_case.result_class)
|
|
177
|
+
|
|
178
|
+
# Log SQL if enabled (success case)
|
|
179
|
+
execution_time = time.time() - start_time
|
|
180
|
+
|
|
181
|
+
# Store execution data for potential logging on test failure
|
|
182
|
+
if test_context:
|
|
183
|
+
test_id = test_context.get("test_id")
|
|
184
|
+
if test_id:
|
|
185
|
+
sql_test_execution_data[test_id] = {
|
|
186
|
+
"sql": final_query,
|
|
187
|
+
"test_name": test_context.get("test_name", "unknown_test"),
|
|
188
|
+
"test_class": test_context.get("test_class"),
|
|
189
|
+
"test_file": test_context.get("test_file"),
|
|
190
|
+
"metadata": {
|
|
191
|
+
"query": test_case.query,
|
|
192
|
+
"default_namespace": test_case.default_namespace,
|
|
193
|
+
"mock_tables": test_case.mock_tables,
|
|
194
|
+
"adapter_type": self.adapter.__class__.__name__.replace(
|
|
195
|
+
"Adapter", ""
|
|
196
|
+
).lower(),
|
|
197
|
+
"use_physical_tables": test_case.use_physical_tables,
|
|
198
|
+
"execution_time": execution_time,
|
|
199
|
+
"row_count": row_count,
|
|
200
|
+
"error": None,
|
|
201
|
+
"temp_table_queries": temp_table_queries,
|
|
202
|
+
},
|
|
203
|
+
"sql_logger": self.sql_logger,
|
|
204
|
+
"log_sql": test_case.log_sql,
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
if self.sql_logger.should_log(test_case.log_sql):
|
|
208
|
+
# Get test context info
|
|
209
|
+
test_name = (
|
|
210
|
+
test_context.get("test_name", "unknown_test")
|
|
211
|
+
if test_context
|
|
212
|
+
else "unknown_test"
|
|
213
|
+
)
|
|
214
|
+
test_class = test_context.get("test_class") if test_context else None
|
|
215
|
+
test_file = test_context.get("test_file") if test_context else None
|
|
216
|
+
|
|
217
|
+
metadata = {
|
|
218
|
+
"query": test_case.query,
|
|
219
|
+
"default_namespace": test_case.default_namespace,
|
|
220
|
+
"mock_tables": test_case.mock_tables,
|
|
221
|
+
"adapter_type": self.adapter.get_sqlglot_dialect(),
|
|
222
|
+
"adapter_name": self.adapter.__class__.__name__.replace("Adapter", "").lower(),
|
|
223
|
+
"use_physical_tables": test_case.use_physical_tables,
|
|
224
|
+
"execution_time": execution_time,
|
|
225
|
+
"row_count": row_count,
|
|
226
|
+
"error": None,
|
|
227
|
+
"temp_table_queries": temp_table_queries,
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
# Log SQL immediately
|
|
231
|
+
log_path = self.sql_logger.log_sql(
|
|
232
|
+
sql=final_query,
|
|
233
|
+
test_name=test_name,
|
|
234
|
+
test_class=test_class,
|
|
235
|
+
test_file=test_file,
|
|
236
|
+
failed=False,
|
|
237
|
+
metadata=metadata,
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
# Print log location if environment variable is set
|
|
241
|
+
if os.environ.get("SQL_TEST_LOG_ALL", "").lower() in ("true", "1", "yes"):
|
|
242
|
+
import sys
|
|
243
|
+
|
|
244
|
+
print(f"\nSQL logged to: file://{log_path}", file=sys.stderr) # noqa: T201
|
|
245
|
+
sys.stderr.flush()
|
|
246
|
+
|
|
247
|
+
return results
|
|
248
|
+
|
|
249
|
+
except Exception as e:
|
|
250
|
+
# Store exception information for potential logging by pytest hook
|
|
251
|
+
execution_time = time.time() - start_time
|
|
252
|
+
|
|
253
|
+
# Capture full error details including traceback
|
|
254
|
+
import traceback
|
|
255
|
+
|
|
256
|
+
error_message = str(e)
|
|
257
|
+
error_traceback = traceback.format_exc()
|
|
258
|
+
|
|
259
|
+
# Store execution data for pytest hook to potentially log
|
|
260
|
+
if test_context and test_case.log_sql is not False:
|
|
261
|
+
test_id = test_context.get("test_id")
|
|
262
|
+
if test_id:
|
|
263
|
+
# Update the execution data with error information
|
|
264
|
+
if test_id in sql_test_execution_data:
|
|
265
|
+
sql_test_execution_data[test_id]["metadata"]["error"] = error_message
|
|
266
|
+
sql_test_execution_data[test_id]["metadata"]["error_traceback"] = (
|
|
267
|
+
error_traceback
|
|
268
|
+
)
|
|
269
|
+
sql_test_execution_data[test_id]["metadata"]["execution_time"] = (
|
|
270
|
+
execution_time
|
|
271
|
+
)
|
|
272
|
+
sql_test_execution_data[test_id]["metadata"]["row_count"] = row_count
|
|
273
|
+
else:
|
|
274
|
+
# If we haven't stored data yet (error happened early), store it now
|
|
275
|
+
sql_test_execution_data[test_id] = {
|
|
276
|
+
"sql": final_query if "final_query" in locals() else test_case.query,
|
|
277
|
+
"test_name": test_context.get("test_name", "unknown_test"),
|
|
278
|
+
"test_class": test_context.get("test_class"),
|
|
279
|
+
"test_file": test_context.get("test_file"),
|
|
280
|
+
"metadata": {
|
|
281
|
+
"query": test_case.query,
|
|
282
|
+
"default_namespace": test_case.default_namespace,
|
|
283
|
+
"mock_tables": test_case.mock_tables,
|
|
284
|
+
"adapter_type": self.adapter.get_sqlglot_dialect(),
|
|
285
|
+
"adapter_name": self.adapter.__class__.__name__.replace(
|
|
286
|
+
"Adapter", ""
|
|
287
|
+
).lower(),
|
|
288
|
+
"use_physical_tables": test_case.use_physical_tables,
|
|
289
|
+
"execution_time": execution_time,
|
|
290
|
+
"row_count": row_count,
|
|
291
|
+
"error": error_message,
|
|
292
|
+
"error_traceback": error_traceback,
|
|
293
|
+
"temp_table_queries": temp_table_queries,
|
|
294
|
+
},
|
|
295
|
+
"sql_logger": self.sql_logger,
|
|
296
|
+
"log_sql": test_case.log_sql,
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
raise
|
|
155
300
|
|
|
156
301
|
finally:
|
|
157
302
|
# Cleanup any temporary tables
|
|
@@ -439,15 +584,48 @@ class SQLTestFramework:
|
|
|
439
584
|
query: str,
|
|
440
585
|
table_mapping: Dict[str, BaseMockTable],
|
|
441
586
|
mock_tables: List[BaseMockTable],
|
|
587
|
+
temp_table_queries: List[str],
|
|
442
588
|
) -> str:
|
|
443
589
|
"""Execute query using physical temporary tables."""
|
|
444
590
|
# Create physical tables
|
|
445
591
|
replacement_mapping = {}
|
|
446
592
|
|
|
447
593
|
for original_name, mock_table in table_mapping.items():
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
594
|
+
try:
|
|
595
|
+
# Check if adapter has method to get temp table SQL
|
|
596
|
+
if hasattr(self.adapter, "create_temp_table_with_sql"):
|
|
597
|
+
temp_table_name, create_sql = self.adapter.create_temp_table_with_sql(
|
|
598
|
+
mock_table
|
|
599
|
+
)
|
|
600
|
+
temp_table_queries.append(create_sql)
|
|
601
|
+
else:
|
|
602
|
+
temp_table_name = self.adapter.create_temp_table(mock_table)
|
|
603
|
+
# Try to generate approximate SQL for logging
|
|
604
|
+
temp_table_queries.append(
|
|
605
|
+
f"-- CREATE TEMP TABLE {temp_table_name} (SQL not captured)"
|
|
606
|
+
)
|
|
607
|
+
|
|
608
|
+
self.temp_tables.append(temp_table_name)
|
|
609
|
+
replacement_mapping[original_name] = temp_table_name
|
|
610
|
+
except Exception:
|
|
611
|
+
# If table creation fails, still try to capture the SQL for debugging
|
|
612
|
+
if hasattr(self.adapter, "create_temp_table_with_sql") and hasattr(
|
|
613
|
+
mock_table, "get_table_name"
|
|
614
|
+
):
|
|
615
|
+
try:
|
|
616
|
+
temp_table_name, ctas_sql = self.adapter.create_temp_table_with_sql(
|
|
617
|
+
mock_table
|
|
618
|
+
)
|
|
619
|
+
temp_table_queries.append(ctas_sql)
|
|
620
|
+
replacement_mapping[original_name] = temp_table_name
|
|
621
|
+
except Exception:
|
|
622
|
+
# If even SQL generation fails, add a placeholder
|
|
623
|
+
temp_table_queries.append(
|
|
624
|
+
f"-- CREATE TEMP TABLE for {original_name} (SQL generation failed)"
|
|
625
|
+
)
|
|
626
|
+
replacement_mapping[original_name] = f"temp_{original_name}_failed"
|
|
627
|
+
# Re-raise the original exception
|
|
628
|
+
raise
|
|
451
629
|
|
|
452
630
|
# Replace table names and return modified query
|
|
453
631
|
return self._replace_table_names_in_query(query, replacement_mapping)
|
{sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_pytest_plugin.py
RENAMED
|
@@ -10,7 +10,7 @@ import pytest
|
|
|
10
10
|
from _pytest.nodes import Item
|
|
11
11
|
|
|
12
12
|
from ._adapters.base import DatabaseAdapter
|
|
13
|
-
from ._core import AdapterType, SQLTestCase, SQLTestFramework
|
|
13
|
+
from ._core import AdapterType, SQLTestCase, SQLTestFramework, sql_test_execution_data
|
|
14
14
|
from ._mock_table import BaseMockTable
|
|
15
15
|
|
|
16
16
|
|
|
@@ -335,12 +335,16 @@ class SQLTestDecorator:
|
|
|
335
335
|
# Global instance
|
|
336
336
|
_sql_test_decorator = SQLTestDecorator()
|
|
337
337
|
|
|
338
|
+
# Global SQL execution context for logging
|
|
339
|
+
_sql_execution_context: Dict[str, Any] = {}
|
|
340
|
+
|
|
338
341
|
|
|
339
342
|
def sql_test(
|
|
340
343
|
mock_tables: Optional[List[BaseMockTable]] = None,
|
|
341
344
|
result_class: Optional[Type[T]] = None,
|
|
342
345
|
use_physical_tables: Optional[bool] = None,
|
|
343
346
|
adapter_type: Optional[AdapterType] = None,
|
|
347
|
+
log_sql: Optional[bool] = None,
|
|
344
348
|
) -> Callable[[Callable[[], SQLTestCase[T]]], Callable[[], List[T]]]:
|
|
345
349
|
"""
|
|
346
350
|
Decorator to mark a function as a SQL test.
|
|
@@ -360,6 +364,8 @@ def sql_test(
|
|
|
360
364
|
(e.g., 'bigquery', 'athena').
|
|
361
365
|
If provided, overrides adapter_type in SQLTestCase and uses config
|
|
362
366
|
from [sql_testing.{adapter_type}] section.
|
|
367
|
+
log_sql: Optional flag to log the generated SQL to a file.
|
|
368
|
+
If provided, overrides log_sql in SQLTestCase.
|
|
363
369
|
"""
|
|
364
370
|
|
|
365
371
|
def decorator(func: Callable[[], SQLTestCase[T]]) -> Callable[[], List[T]]:
|
|
@@ -396,15 +402,54 @@ def sql_test(
|
|
|
396
402
|
if adapter_type is not None:
|
|
397
403
|
test_case.adapter_type = adapter_type
|
|
398
404
|
|
|
405
|
+
if log_sql is not None:
|
|
406
|
+
test_case.log_sql = log_sql
|
|
407
|
+
|
|
399
408
|
# Get framework and execute test
|
|
400
409
|
framework = _sql_test_decorator.get_framework(test_case.adapter_type)
|
|
401
|
-
|
|
410
|
+
|
|
411
|
+
# Create test context for logging
|
|
412
|
+
test_context = {}
|
|
413
|
+
|
|
414
|
+
# Try to get test metadata from the current pytest context
|
|
415
|
+
import inspect
|
|
416
|
+
|
|
417
|
+
frame = inspect.currentframe()
|
|
418
|
+
while frame:
|
|
419
|
+
frame_locals = frame.f_locals
|
|
420
|
+
if "item" in frame_locals and hasattr(frame_locals["item"], "name"):
|
|
421
|
+
item = frame_locals["item"]
|
|
422
|
+
test_context["test_name"] = item.name
|
|
423
|
+
test_context["test_class"] = item.cls.__name__ if item.cls else None
|
|
424
|
+
test_context["test_file"] = (
|
|
425
|
+
str(item.fspath) if hasattr(item, "fspath") else None
|
|
426
|
+
)
|
|
427
|
+
# Create a unique test ID
|
|
428
|
+
test_context["test_id"] = str(id(item))
|
|
429
|
+
break
|
|
430
|
+
frame = frame.f_back
|
|
431
|
+
|
|
432
|
+
# If we couldn't get test context from stack, try to get it from function name
|
|
433
|
+
if not test_context:
|
|
434
|
+
test_context["test_name"] = func.__name__
|
|
435
|
+
test_context["test_file"] = inspect.getfile(func) if func is not None else None
|
|
436
|
+
# Create a unique test ID
|
|
437
|
+
test_context["test_id"] = f"{func.__name__}_{id(func)}"
|
|
438
|
+
|
|
439
|
+
results: List[T] = framework.run_test(test_case, test_context)
|
|
402
440
|
|
|
403
441
|
return results
|
|
404
442
|
|
|
405
443
|
# Mark function as SQL test
|
|
406
444
|
wrapper._sql_test_decorated = True # type: ignore
|
|
407
445
|
wrapper._original_func = func # type: ignore
|
|
446
|
+
wrapper._decorator_params = { # type: ignore
|
|
447
|
+
"mock_tables": mock_tables,
|
|
448
|
+
"result_class": result_class,
|
|
449
|
+
"use_physical_tables": use_physical_tables,
|
|
450
|
+
"adapter_type": adapter_type,
|
|
451
|
+
"log_sql": log_sql,
|
|
452
|
+
}
|
|
408
453
|
|
|
409
454
|
return wrapper
|
|
410
455
|
|
|
@@ -449,3 +494,52 @@ def pytest_runtest_call(item: Item) -> None:
|
|
|
449
494
|
else:
|
|
450
495
|
# Use default pytest execution
|
|
451
496
|
item.runtest()
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def pytest_runtest_makereport(item: Item, call: Any) -> None:
    """Hook to log SQL when tests fail (including assertion failures)."""
    # Only the "call" phase carries the actual test outcome; ignore setup/teardown.
    if call.when != "call":
        return

    test_id = str(id(item))
    data = sql_test_execution_data.get(test_id)

    # On failure, write the captured SQL out unless logging was explicitly disabled.
    if call.excinfo is not None and data is not None and data.get("log_sql") is not False:
        import sys
        import traceback

        metadata = data["metadata"].copy()
        # Overwrite any stored error info with the actual pytest error,
        # which may differ from what was recorded during execution.
        metadata["error"] = str(call.excinfo.value)
        metadata["error_traceback"] = "".join(
            traceback.format_exception(call.excinfo.type, call.excinfo.value, call.excinfo.tb)
        )

        log_path = data["sql_logger"].log_sql(
            sql=data["sql"],
            test_name=data["test_name"],
            test_class=data["test_class"],
            test_file=data["test_file"],
            failed=True,
            metadata=metadata,
        )

        # Surface the log location on stderr so it shows up in test output.
        print(f"\nSQL logged to: file://{log_path}", file=sys.stderr)  # noqa: T201
        sys.stderr.flush()

    # Drop the stored data for this test whether it passed or failed.
    sql_test_execution_data.pop(test_id, None)
|
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""SQL logging functionality for test cases."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
from sqlglot import parse_one
|
|
10
|
+
|
|
11
|
+
from ._mock_table import BaseMockTable
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SQLLogger:
|
|
15
|
+
"""Handles SQL logging for test cases."""
|
|
16
|
+
|
|
17
|
+
# Class variable to store the run directory for the current test session
|
|
18
|
+
_run_directory: Optional[Path] = None
|
|
19
|
+
_run_id: Optional[str] = None
|
|
20
|
+
|
|
21
|
+
def __init__(self, log_dir: Optional[str] = None) -> None:
|
|
22
|
+
"""Initialize SQL logger.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
log_dir: Directory to store SQL log files. If None, uses .sql_logs in project root.
|
|
26
|
+
"""
|
|
27
|
+
if log_dir is None:
|
|
28
|
+
# Check environment variable first
|
|
29
|
+
env_log_dir = os.environ.get("SQL_TEST_LOG_DIR")
|
|
30
|
+
if env_log_dir:
|
|
31
|
+
self.log_dir = Path(env_log_dir)
|
|
32
|
+
else:
|
|
33
|
+
# Try to find the project root by looking for specific project files
|
|
34
|
+
current_path = Path.cwd()
|
|
35
|
+
|
|
36
|
+
# Look for definitive project root markers (in order of preference)
|
|
37
|
+
# These are files that typically only exist at project root
|
|
38
|
+
root_markers = ["pyproject.toml", "setup.py", "setup.cfg", "tox.ini"]
|
|
39
|
+
|
|
40
|
+
# Search up the directory tree for project root
|
|
41
|
+
project_root = None
|
|
42
|
+
search_path = current_path
|
|
43
|
+
|
|
44
|
+
while search_path != search_path.parent:
|
|
45
|
+
# Check for root markers
|
|
46
|
+
if any((search_path / marker).exists() for marker in root_markers):
|
|
47
|
+
project_root = search_path
|
|
48
|
+
break
|
|
49
|
+
|
|
50
|
+
# Also check for .git directory (but not .git file which could be a submodule)
|
|
51
|
+
if (search_path / ".git").is_dir():
|
|
52
|
+
project_root = search_path
|
|
53
|
+
break
|
|
54
|
+
|
|
55
|
+
search_path = search_path.parent
|
|
56
|
+
|
|
57
|
+
# If we found a project root, use it; otherwise fall back to current directory
|
|
58
|
+
if project_root:
|
|
59
|
+
self.log_dir = project_root / ".sql_logs"
|
|
60
|
+
else:
|
|
61
|
+
# Fall back to current directory if project root not found
|
|
62
|
+
self.log_dir = Path(".sql_logs")
|
|
63
|
+
else:
|
|
64
|
+
self.log_dir = Path(log_dir)
|
|
65
|
+
|
|
66
|
+
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
67
|
+
self._logged_files: List[str] = []
|
|
68
|
+
|
|
69
|
+
def _ensure_run_directory(self) -> Path:
|
|
70
|
+
"""Ensure run directory exists, creating it if necessary.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Path to the run directory
|
|
74
|
+
"""
|
|
75
|
+
# Create run directory if not already created for this session
|
|
76
|
+
if SQLLogger._run_directory is None:
|
|
77
|
+
# Generate run ID with timestamp
|
|
78
|
+
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
|
|
79
|
+
SQLLogger._run_id = f"runid_{timestamp}"
|
|
80
|
+
SQLLogger._run_directory = self.log_dir / SQLLogger._run_id
|
|
81
|
+
SQLLogger._run_directory.mkdir(parents=True, exist_ok=True)
|
|
82
|
+
return SQLLogger._run_directory
|
|
83
|
+
|
|
84
|
+
def should_log(self, log_sql: Optional[bool] = None) -> bool:
|
|
85
|
+
"""Determine if SQL should be logged based on environment and parameters.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
log_sql: Explicit parameter from test case
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
True if SQL should be logged
|
|
92
|
+
"""
|
|
93
|
+
# If explicitly set in test case, use that
|
|
94
|
+
if log_sql is not None:
|
|
95
|
+
return log_sql
|
|
96
|
+
|
|
97
|
+
# Check environment variable
|
|
98
|
+
return os.environ.get("SQL_TEST_LOG_ALL", "").lower() in ("true", "1", "yes")
|
|
99
|
+
|
|
100
|
+
def generate_filename(
|
|
101
|
+
self,
|
|
102
|
+
test_name: str,
|
|
103
|
+
test_class: Optional[str] = None,
|
|
104
|
+
test_file: Optional[str] = None,
|
|
105
|
+
failed: bool = False,
|
|
106
|
+
) -> str:
|
|
107
|
+
"""Generate a unique filename for the SQL log.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
test_name: Name of the test function
|
|
111
|
+
test_class: Name of the test class (if any)
|
|
112
|
+
test_file: Path to the test file
|
|
113
|
+
failed: Whether the test failed
|
|
114
|
+
|
|
115
|
+
Returns:
|
|
116
|
+
Generated filename
|
|
117
|
+
"""
|
|
118
|
+
# Clean test name for filesystem (including square brackets)
|
|
119
|
+
clean_name = re.sub(r'[<>:"/\\|?*\[\]]', "_", test_name)
|
|
120
|
+
|
|
121
|
+
# Build filename components
|
|
122
|
+
components = []
|
|
123
|
+
|
|
124
|
+
# Add test file name (without path and extension)
|
|
125
|
+
if test_file:
|
|
126
|
+
file_base = Path(test_file).stem
|
|
127
|
+
components.append(file_base)
|
|
128
|
+
|
|
129
|
+
# Add class name if present
|
|
130
|
+
if test_class:
|
|
131
|
+
clean_class = re.sub(r'[<>:"/\\|?*\[\]]', "_", test_class)
|
|
132
|
+
components.append(clean_class)
|
|
133
|
+
|
|
134
|
+
# Add test name
|
|
135
|
+
components.append(clean_name)
|
|
136
|
+
|
|
137
|
+
# Add status indicator
|
|
138
|
+
if failed:
|
|
139
|
+
components.append("FAILED")
|
|
140
|
+
|
|
141
|
+
# Add timestamp for uniqueness
|
|
142
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Milliseconds
|
|
143
|
+
components.append(timestamp)
|
|
144
|
+
|
|
145
|
+
# Join with double underscore for clarity
|
|
146
|
+
filename = "__".join(components) + ".sql"
|
|
147
|
+
|
|
148
|
+
return filename
|
|
149
|
+
|
|
150
|
+
def format_sql(self, sql: str, dialect: Optional[str] = None) -> str:
|
|
151
|
+
"""Format SQL query for better readability.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
sql: SQL query to format
|
|
155
|
+
dialect: SQL dialect (e.g., 'bigquery', 'athena')
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
Formatted SQL
|
|
159
|
+
"""
|
|
160
|
+
try:
|
|
161
|
+
# Parse and format using sqlglot
|
|
162
|
+
parsed = parse_one(sql, dialect=dialect)
|
|
163
|
+
# IMPORTANT: Pass dialect to sql() to preserve dialect-specific syntax
|
|
164
|
+
formatted = parsed.sql(pretty=True, pad=2, dialect=dialect)
|
|
165
|
+
return formatted
|
|
166
|
+
except Exception:
|
|
167
|
+
# If formatting fails, return original
|
|
168
|
+
return sql
|
|
169
|
+
|
|
170
|
+
    def create_metadata_header(
        self,
        test_name: str,
        test_class: Optional[str] = None,
        test_file: Optional[str] = None,
        query: str = "",
        default_namespace: Optional[str] = None,
        mock_tables: Optional[List[BaseMockTable]] = None,
        adapter_type: Optional[str] = None,
        use_physical_tables: bool = False,
        execution_time: Optional[float] = None,
        row_count: Optional[int] = None,
        error: Optional[str] = None,
        error_traceback: Optional[str] = None,
        temp_table_queries: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Create a metadata header for the SQL file.

        Builds a block of `--` SQL comments describing the test run: who ran
        it, against what adapter, with which mock tables, the original query,
        any error details, and (when physical tables were used) the temp-table
        creation queries. The caller appends the transformed SQL after this
        header.

        Args:
            test_name: Name of the test function.
            test_class: Test class name, if any.
            test_file: Path to the test file.
            query: The original (pre-transformation) query; embedded as comments.
            default_namespace: Default namespace used for table resolution.
            mock_tables: Mock tables used by the test; row counts and column
                names are included when available.
            adapter_type: Adapter/dialect name; also used to format temp-table SQL.
            use_physical_tables: Whether physical temp tables were used.
            execution_time: Query execution time in seconds.
            row_count: Number of result rows.
            error: Error message if the test failed; its presence switches the
                status line from SUCCESS to FAILED.
            error_traceback: Full traceback text, embedded line-by-line as comments.
            temp_table_queries: CREATE statements for temp tables (only emitted
                when use_physical_tables is True).
            **kwargs: Extra metadata; only "adapter_name" is read here, shown
                as "Database" when it differs from adapter_type.

        Returns:
            Formatted metadata header as SQL comments
        """
        lines = [
            "-- SQL Test Case Log",
            "-- " + "=" * 78,
            f"-- Generated: {datetime.now().isoformat()}",
            f"-- Run ID: {SQLLogger._run_id}",
            f"-- Test Name: {test_name}",
        ]

        if test_class:
            lines.append(f"-- Test Class: {test_class}")

        if test_file:
            lines.append(f"-- Test File: {test_file}")

        if adapter_type:
            lines.append(f"-- Adapter: {adapter_type}")

        # Show adapter name if different from sqlglot dialect
        adapter_name = kwargs.get("adapter_name")
        if adapter_name and adapter_name != adapter_type:
            lines.append(f"-- Database: {adapter_name}")

        if default_namespace:
            lines.append(f"-- Default Namespace: {default_namespace}")

        lines.append(f"-- Use Physical Tables: {use_physical_tables}")

        if execution_time is not None:
            lines.append(f"-- Execution Time: {execution_time:.3f} seconds")

        if row_count is not None:
            lines.append(f"-- Result Rows: {row_count}")

        if error:
            lines.extend(
                [
                    "-- Status: FAILED",
                    "-- Error:",
                ]
            )
            # Each line of the error message becomes its own SQL comment.
            for line in error.strip().split("\n"):
                lines.append(f"-- {line}")

            # Add full error traceback if available
            if error_traceback:
                lines.extend(
                    [
                        "",
                        "-- Full Error Details:",
                        "-- " + "-" * 78,
                    ]
                )
                # Add each line of the traceback as a SQL comment
                for line in error_traceback.strip().split("\n"):
                    lines.append(f"-- {line}")
        else:
            lines.append("-- Status: SUCCESS")

        # Add mock tables information
        if mock_tables:
            lines.extend(
                [
                    "",
                    "-- Mock Tables:",
                    "-- " + "-" * 78,
                ]
            )
            for table in mock_tables:
                lines.append(f"-- Table: {table.get_table_name()}")
                # Get row count from data
                if hasattr(table, "data") and table.data:
                    lines.append(f"-- Rows: {len(table.data)}")
                # Get column names from first row or column types
                if hasattr(table, "get_column_types"):
                    columns = list(table.get_column_types().keys())
                    if columns:
                        lines.append(f"-- Columns: {', '.join(columns)}")

        # Add original query
        lines.extend(
            [
                "",
                "-- Original Query:",
                "-- " + "-" * 78,
            ]
        )
        # Comment out each line of the original query
        for line in query.split("\n"):
            lines.append(f"-- {line}")

        # Add temp table queries if physical tables were used
        if use_physical_tables and temp_table_queries:
            lines.extend(
                [
                    "",
                    "-- Temporary Table Creation Queries:",
                    "-- " + "-" * 78,
                    "",
                ]
            )
            for i, temp_query in enumerate(temp_table_queries, 1):
                lines.append(f"-- Query {i}:")
                lines.append("")
                # Format the temp table SQL
                # NOTE: temp-table SQL is emitted *uncommented* so it can be
                # copy-pasted and executed directly from the log file.
                formatted_temp_sql = self.format_sql(temp_query, dialect=adapter_type)
                lines.append(formatted_temp_sql)
                lines.append("")

        # Trailing separator; the transformed query is appended after this
        # header by log_sql().
        lines.extend(
            [
                "",
                "-- Transformed Query:",
                "-- " + "=" * 78,
                "",
            ]
        )

        return "\n".join(lines)
|
|
310
|
+
|
|
311
|
+
def log_sql(
|
|
312
|
+
self,
|
|
313
|
+
sql: str,
|
|
314
|
+
test_name: str,
|
|
315
|
+
test_class: Optional[str] = None,
|
|
316
|
+
test_file: Optional[str] = None,
|
|
317
|
+
failed: bool = False,
|
|
318
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
319
|
+
) -> str:
|
|
320
|
+
"""Log SQL to a file and return the file path.
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
sql: The transformed SQL query to log
|
|
324
|
+
test_name: Name of the test
|
|
325
|
+
test_class: Test class name
|
|
326
|
+
test_file: Test file path
|
|
327
|
+
failed: Whether the test failed
|
|
328
|
+
metadata: Additional metadata to include
|
|
329
|
+
|
|
330
|
+
Returns:
|
|
331
|
+
Path to the created SQL file
|
|
332
|
+
"""
|
|
333
|
+
# Generate filename
|
|
334
|
+
filename = self.generate_filename(test_name, test_class, test_file, failed)
|
|
335
|
+
|
|
336
|
+
# Ensure run directory exists (lazy creation)
|
|
337
|
+
run_directory = self._ensure_run_directory()
|
|
338
|
+
filepath = run_directory / filename
|
|
339
|
+
|
|
340
|
+
# Prepare metadata
|
|
341
|
+
if metadata is None:
|
|
342
|
+
metadata = {}
|
|
343
|
+
|
|
344
|
+
# Create header
|
|
345
|
+
header = self.create_metadata_header(
|
|
346
|
+
test_name=test_name, test_class=test_class, test_file=test_file, **metadata
|
|
347
|
+
)
|
|
348
|
+
|
|
349
|
+
# Format SQL
|
|
350
|
+
dialect = metadata.get("adapter_type")
|
|
351
|
+
formatted_sql = self.format_sql(sql, dialect)
|
|
352
|
+
|
|
353
|
+
# Write to file
|
|
354
|
+
content = header + formatted_sql
|
|
355
|
+
filepath.write_text(content, encoding="utf-8")
|
|
356
|
+
|
|
357
|
+
# Track logged file
|
|
358
|
+
self._logged_files.append(str(filepath))
|
|
359
|
+
|
|
360
|
+
# Return absolute path for clickable URLs
|
|
361
|
+
return str(filepath.absolute())
|
|
362
|
+
|
|
363
|
+
def get_logged_files(self) -> List[str]:
|
|
364
|
+
"""Get list of files logged in this session."""
|
|
365
|
+
return self._logged_files.copy()
|
|
366
|
+
|
|
367
|
+
def clear_logged_files(self) -> None:
|
|
368
|
+
"""Clear the list of logged files."""
|
|
369
|
+
self._logged_files = []
|
|
370
|
+
|
|
371
|
+
@classmethod
|
|
372
|
+
def get_run_directory(cls) -> Optional[Path]:
|
|
373
|
+
"""Get the current run directory."""
|
|
374
|
+
return cls._run_directory
|
|
375
|
+
|
|
376
|
+
@classmethod
|
|
377
|
+
def get_run_id(cls) -> Optional[str]:
|
|
378
|
+
"""Get the current run ID."""
|
|
379
|
+
return cls._run_id
|
|
380
|
+
|
|
381
|
+
@classmethod
|
|
382
|
+
def reset_run_directory(cls) -> None:
|
|
383
|
+
"""Reset the run directory (useful for testing)."""
|
|
384
|
+
cls._run_directory = None
|
|
385
|
+
cls._run_id = None
|
{sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_sql_utils.py
RENAMED
|
@@ -82,6 +82,7 @@ def format_sql_value(value: Any, column_type: Type, dialect: str = "standard") -
|
|
|
82
82
|
"""
|
|
83
83
|
from datetime import date, datetime
|
|
84
84
|
from decimal import Decimal
|
|
85
|
+
from typing import get_args
|
|
85
86
|
|
|
86
87
|
import pandas as pd
|
|
87
88
|
|
|
@@ -89,6 +90,42 @@ def format_sql_value(value: Any, column_type: Type, dialect: str = "standard") -
|
|
|
89
90
|
# Note: pd.isna() doesn't work on lists/arrays, so check for None first
|
|
90
91
|
# and only use pd.isna() on scalar values
|
|
91
92
|
if value is None or (not isinstance(value, (list, tuple)) and pd.isna(value)):
|
|
93
|
+
# Check if column_type is a List type
|
|
94
|
+
if hasattr(column_type, "__origin__") and column_type.__origin__ is list:
|
|
95
|
+
# Get the element type from List[T]
|
|
96
|
+
element_type = get_args(column_type)[0] if get_args(column_type) else str
|
|
97
|
+
|
|
98
|
+
if dialect in ("athena", "trino"):
|
|
99
|
+
# Map Python types to SQL types for array elements
|
|
100
|
+
if element_type == Decimal:
|
|
101
|
+
sql_element_type = "DECIMAL(38,9)"
|
|
102
|
+
elif element_type is int:
|
|
103
|
+
sql_element_type = "INTEGER" if dialect == "athena" else "BIGINT"
|
|
104
|
+
elif element_type is float:
|
|
105
|
+
sql_element_type = "DOUBLE"
|
|
106
|
+
elif element_type is bool:
|
|
107
|
+
sql_element_type = "BOOLEAN"
|
|
108
|
+
elif element_type is date:
|
|
109
|
+
sql_element_type = "DATE"
|
|
110
|
+
elif element_type == datetime:
|
|
111
|
+
sql_element_type = "TIMESTAMP"
|
|
112
|
+
else:
|
|
113
|
+
sql_element_type = "VARCHAR"
|
|
114
|
+
|
|
115
|
+
return f"CAST(NULL AS ARRAY({sql_element_type}))"
|
|
116
|
+
elif dialect == "bigquery":
|
|
117
|
+
# BigQuery doesn't need explicit NULL array casting
|
|
118
|
+
return "NULL"
|
|
119
|
+
elif dialect == "redshift":
|
|
120
|
+
# Redshift SUPER type handles NULL arrays
|
|
121
|
+
return "NULL::SUPER"
|
|
122
|
+
elif dialect == "snowflake":
|
|
123
|
+
# Snowflake VARIANT type handles NULL arrays
|
|
124
|
+
return "NULL::VARIANT"
|
|
125
|
+
else:
|
|
126
|
+
return "NULL"
|
|
127
|
+
|
|
128
|
+
# Handle non-array NULL values
|
|
92
129
|
if dialect == "redshift":
|
|
93
130
|
# Redshift needs type-specific NULL casting
|
|
94
131
|
if column_type == Decimal:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_exceptions.py
RENAMED
|
File without changes
|
{sql_testing_library-0.6.0 → sql_testing_library-0.7.1}/src/sql_testing_library/_mock_table.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|