sqlspec 0.10.1__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlspec/adapters/adbc/config.py +1 -1
- sqlspec/adapters/adbc/driver.py +340 -192
- sqlspec/adapters/aiosqlite/driver.py +183 -129
- sqlspec/adapters/asyncmy/driver.py +168 -88
- sqlspec/adapters/asyncpg/config.py +3 -1
- sqlspec/adapters/asyncpg/driver.py +208 -259
- sqlspec/adapters/bigquery/driver.py +184 -264
- sqlspec/adapters/duckdb/driver.py +172 -110
- sqlspec/adapters/oracledb/driver.py +274 -160
- sqlspec/adapters/psqlpy/driver.py +274 -211
- sqlspec/adapters/psycopg/driver.py +196 -283
- sqlspec/adapters/sqlite/driver.py +154 -142
- sqlspec/base.py +56 -85
- sqlspec/extensions/litestar/__init__.py +3 -12
- sqlspec/extensions/litestar/config.py +22 -7
- sqlspec/extensions/litestar/handlers.py +142 -85
- sqlspec/extensions/litestar/plugin.py +9 -8
- sqlspec/extensions/litestar/providers.py +521 -0
- sqlspec/filters.py +215 -11
- sqlspec/mixins.py +161 -12
- sqlspec/statement.py +276 -271
- sqlspec/typing.py +18 -1
- sqlspec/utils/__init__.py +2 -2
- sqlspec/utils/singleton.py +35 -0
- sqlspec/utils/sync_tools.py +90 -151
- sqlspec/utils/text.py +68 -5
- {sqlspec-0.10.1.dist-info → sqlspec-0.11.1.dist-info}/METADATA +8 -1
- {sqlspec-0.10.1.dist-info → sqlspec-0.11.1.dist-info}/RECORD +31 -29
- {sqlspec-0.10.1.dist-info → sqlspec-0.11.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.10.1.dist-info → sqlspec-0.11.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.10.1.dist-info → sqlspec-0.11.1.dist-info}/licenses/NOTICE +0 -0
The detailed diff below covers sqlspec/adapters/bigquery/driver.py (+184 -264). Deleted lines whose text the diff viewer did not capture appear as bare "-" markers.

@@ -1,6 +1,7 @@
 import contextlib
 import datetime
-
+import logging
+from collections.abc import Iterator, Mapping, Sequence
 from decimal import Decimal
 from typing import (
     TYPE_CHECKING,
@@ -12,19 +13,21 @@ from typing import (
     overload,
 )
 
-import sqlglot
 from google.cloud import bigquery
 from google.cloud.bigquery import Client
 from google.cloud.bigquery.job import QueryJob, QueryJobConfig
 from google.cloud.exceptions import NotFound
 
 from sqlspec.base import SyncDriverAdapterProtocol
-from sqlspec.exceptions import NotFoundError, SQLSpecError
+from sqlspec.exceptions import NotFoundError, ParameterStyleMismatchError, SQLSpecError
+from sqlspec.filters import StatementFilter
 from sqlspec.mixins import (
+    ResultConverter,
     SQLTranslatorMixin,
     SyncArrowBulkOperationsMixin,
     SyncParquetExportMixin,
 )
+from sqlspec.statement import SQLStatement
 from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T
 
 if TYPE_CHECKING:
@@ -35,12 +38,15 @@ __all__ = ("BigQueryConnection", "BigQueryDriver")
 
 BigQueryConnection = Client
 
+logger = logging.getLogger("sqlspec")
+
 
 class BigQueryDriver(
     SyncDriverAdapterProtocol["BigQueryConnection"],
     SyncArrowBulkOperationsMixin["BigQueryConnection"],
     SyncParquetExportMixin["BigQueryConnection"],
     SQLTranslatorMixin["BigQueryConnection"],
+    ResultConverter,
 ):
     """Synchronous BigQuery Driver Adapter."""
 
@@ -55,7 +61,7 @@ class BigQueryDriver(
     )
 
     @staticmethod
-    def _get_bq_param_type(value: Any) -> "tuple[Optional[str], Optional[str]]":
+    def _get_bq_param_type(value: Any) -> "tuple[Optional[str], Optional[str]]":
         if isinstance(value, bool):
             return "BOOL", None
         if isinstance(value, int):
@@ -63,8 +69,6 @@ class BigQueryDriver(
         if isinstance(value, float):
             return "FLOAT64", None
         if isinstance(value, Decimal):
-            # Precision/scale might matter, but BQ client handles conversion.
-            # Defaulting to BIGNUMERIC, NUMERIC might be desired in some cases though (User change)
             return "BIGNUMERIC", None
         if isinstance(value, str):
             return "STRING", None
@@ -72,23 +76,17 @@ class BigQueryDriver(
             return "BYTES", None
         if isinstance(value, datetime.date):
             return "DATE", None
-        # DATETIME is for timezone-naive values
        if isinstance(value, datetime.datetime) and value.tzinfo is None:
             return "DATETIME", None
-        # TIMESTAMP is for timezone-aware values
         if isinstance(value, datetime.datetime) and value.tzinfo is not None:
             return "TIMESTAMP", None
         if isinstance(value, datetime.time):
             return "TIME", None
 
-        # Handle Arrays - Determine element type
         if isinstance(value, (list, tuple)):
             if not value:
-                # Cannot determine type of empty array, BQ requires type.
-                # Raise or default? Defaulting is risky. Let's raise.
                 msg = "Cannot determine BigQuery ARRAY type for empty sequence."
                 raise SQLSpecError(msg)
-            # Infer type from first element
             first_element = value[0]
             element_type, _ = BigQueryDriver._get_bq_param_type(first_element)
             if element_type is None:
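The `_get_bq_param_type` helper maps Python values onto BigQuery parameter type names; note that the `bool` check must precede the `int` check because `bool` subclasses `int`. A minimal standalone restatement of the mapping (illustrative only, not sqlspec's API; the datetime check is ordered before the date check here to sidestep the fact that `datetime.datetime` subclasses `datetime.date`):

```python
import datetime
from decimal import Decimal
from typing import Any, Optional

def bq_param_type(value: Any) -> Optional[str]:
    """Illustrative mapping of Python values to BigQuery parameter type names."""
    if isinstance(value, bool):  # must precede int: bool is a subclass of int
        return "BOOL"
    if isinstance(value, int):
        return "INT64"
    if isinstance(value, float):
        return "FLOAT64"
    if isinstance(value, Decimal):
        return "BIGNUMERIC"
    if isinstance(value, str):
        return "STRING"
    if isinstance(value, bytes):
        return "BYTES"
    if isinstance(value, datetime.datetime):  # before date: datetime subclasses date
        return "TIMESTAMP" if value.tzinfo is not None else "DATETIME"
    if isinstance(value, datetime.date):
        return "DATE"
    if isinstance(value, datetime.time):
        return "TIME"
    return None

assert bq_param_type(True) == "BOOL"  # not INT64
assert bq_param_type(Decimal("1.5")) == "BIGNUMERIC"
```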
@@ -96,19 +94,68 @@ class BigQueryDriver(
                 raise SQLSpecError(msg)
             return "ARRAY", element_type
 
-
-
-
-
-
-
+        return None, None
+
+    def _process_sql_params(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        *filters: "StatementFilter",
+        **kwargs: Any,
+    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
+        """Process SQL and parameters using SQLStatement with dialect support.
+
+        This method also handles the separation of StatementFilter instances that might be
+        passed in the 'parameters' argument.
+
+        Args:
+            sql: The SQL statement to process.
+            parameters: The parameters to bind to the statement. This can be a
+                Mapping (dict), Sequence (list/tuple), a single StatementFilter, or None.
+            *filters: Additional statement filters to apply.
+            **kwargs: Additional keyword arguments (treated as named parameters for the SQL statement).
+
+        Raises:
+            ParameterStyleMismatchError: If pre-formatted BigQuery parameters are mixed with keyword arguments.
+
+        Returns:
+            A tuple of (processed_sql, processed_parameters) ready for execution.
+        """
+        passed_parameters: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None
+        combined_filters_list: list[StatementFilter] = list(filters)
+
+        if parameters is not None:
+            if isinstance(parameters, StatementFilter):
+                combined_filters_list.insert(0, parameters)
+            else:
+                passed_parameters = parameters
+
+        if (
+            isinstance(passed_parameters, (list, tuple))
+            and passed_parameters
+            and all(
+                isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in passed_parameters
+            )
+        ):
+            if kwargs:
+                msg = "Cannot mix pre-formatted BigQuery parameters with keyword arguments."
+                raise ParameterStyleMismatchError(msg)
+            return sql, passed_parameters
+
+        statement = SQLStatement(sql, passed_parameters, kwargs=kwargs, dialect=self.dialect)
+
+        for filter_obj in combined_filters_list:
+            statement = statement.apply_filter(filter_obj)
 
-
+        processed_sql, processed_params, _ = statement.process()
 
-
+        return processed_sql, processed_params
+
+    def _run_query_job(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         is_script: bool = False,
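The core of the new `_process_sql_params` is the routing rule for the first positional argument: a `StatementFilter` lands in the filter list, anything else is treated as bind parameters. A self-contained restatement of just that rule (the `StatementFilter` class below is a stand-in, not the sqlspec type):

```python
from typing import Any, Optional, Union

class StatementFilter:  # stand-in for sqlspec.filters.StatementFilter
    pass

ParamsT = Optional[Union[dict, list, tuple]]

def split_params_and_filters(
    parameters: Union[ParamsT, StatementFilter] = None,
    *filters: StatementFilter,
) -> "tuple[ParamsT, list[StatementFilter]]":
    combined: list[StatementFilter] = list(filters)
    passed: ParamsT = None
    if parameters is not None:
        if isinstance(parameters, StatementFilter):
            combined.insert(0, parameters)  # a filter passed in the parameters slot
        else:
            passed = parameters  # actual bind parameters
    return passed, combined

f = StatementFilter()
assert split_params_and_filters(f) == (None, [f])
assert split_params_and_filters({"id": 1}, f) == ({"id": 1}, [f])
```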
@@ -116,131 +163,71 @@ class BigQueryDriver(
     ) -> "QueryJob":
         conn = self._connection(connection)
 
-        # Determine the final job config, creating a new one if necessary
-        # to avoid modifying a shared default config.
         if job_config:
-            final_job_config = job_config
+            final_job_config = job_config
         elif self._default_query_job_config:
-            final_job_config = QueryJobConfig()
+            final_job_config = QueryJobConfig.from_api_repr(self._default_query_job_config.to_api_repr())  # type: ignore[assignment]
         else:
-            final_job_config = QueryJobConfig()
+            final_job_config = QueryJobConfig()
 
-
-        params: Union[dict[str, Any], list[Any], None] = None
-        param_style: Optional[str] = None  # 'named' (@), 'qmark' (?)
-        use_preformatted_params = False
-        final_sql = sql  # Default to original SQL
+        final_sql, processed_params = self._process_sql_params(sql, parameters, *filters, **kwargs)
 
-        # Check for pre-formatted BQ parameters first
         if (
-            isinstance(
-            and
-            and all(
+            isinstance(processed_params, (list, tuple))
+            and processed_params
+            and all(
+                isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in processed_params
+            )
         ):
-
-
-
-
-
-
-
-
-
-
-
-            param_style = "named"
-        elif isinstance(parameters, (list, tuple)):
-            if kwargs:
-                msg = "Cannot mix positional parameters with keyword arguments."
-                raise SQLSpecError(msg)
-            # Check if it's primitives for qmark style
-            if all(
-                not isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in parameters
-            ):
-                params = list(parameters)
-                param_style = "qmark"
-            else:
-                # Mixed list or non-BQ parameter objects
-                msg = "Invalid mix of parameter types in list. Use only primitive values or only BigQuery QueryParameter objects."
-                raise SQLSpecError(msg)
-
-        elif kwargs:
-            params = kwargs
-            param_style = "named"
-        elif parameters is not None and not isinstance(
-            parameters, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)
-        ):
-            # Could be a single primitive value for positional
-            params = [parameters]
-            param_style = "qmark"
-        elif parameters is not None:  # Single BQ parameter object
-            msg = "Single BigQuery QueryParameter objects should be passed within a list."
-            raise SQLSpecError(msg)
-
-        # Use sqlglot to transpile ONLY if not a script and not preformatted
-        if not is_script and not use_preformatted_params:
-            try:
-                # Transpile for syntax normalization/dialect conversion if needed
-                # Use BigQuery dialect for both reading and writing
-                final_sql = sqlglot.transpile(sql, read=self.dialect, write=self.dialect)[0]
-            except Exception as e:
-                # Catch potential sqlglot errors
-                msg = f"SQL transpilation failed using sqlglot: {e!s}"  # Adjusted message
-                raise SQLSpecError(msg) from e
-            # else: If preformatted_params, final_sql remains the original sql
-
-        # --- Parameter Handling Logic --- (Moved outside the transpilation try/except)
-        # Prepare BQ parameters based on style, ONLY if not preformatted
-        if not use_preformatted_params:
-            if param_style == "named" and params:
-                # Convert dict params to BQ ScalarQueryParameter
-                if isinstance(params, dict):
-                    final_job_config.query_parameters = [
-                        bigquery.ScalarQueryParameter(name, self._get_bq_param_type(value)[0], value)
-                        for name, value in params.items()
-                    ]
-                else:
-                    # This path should ideally not be reached if param_style logic is correct
-                    msg = f"Internal error: Parameter style is 'named' but parameters are not a dict: {type(params)}"
-                    raise SQLSpecError(msg)
-            elif param_style == "qmark" and params:
-                # Convert list params to BQ ScalarQueryParameter
-                final_job_config.query_parameters = [
-                    bigquery.ScalarQueryParameter(None, self._get_bq_param_type(value)[0], value) for value in params
-                ]
-
-        # --- Parameter Handling Logic --- End
+            final_job_config.query_parameters = list(processed_params)
+        elif isinstance(processed_params, dict):
+            final_job_config.query_parameters = [
+                bigquery.ScalarQueryParameter(name, self._get_bq_param_type(value)[0], value)
+                for name, value in processed_params.items()
+            ]
+        elif isinstance(processed_params, (list, tuple)):
+            final_job_config.query_parameters = [
+                bigquery.ScalarQueryParameter(None, self._get_bq_param_type(value)[0], value)
+                for value in processed_params
+            ]
 
-        # Determine which kwargs to pass to the actual query method.
-        # We only want to pass kwargs that were *not* treated as SQL parameters.
         final_query_kwargs = {}
-        if parameters is not None and kwargs:
+        if parameters is not None and kwargs:
             final_query_kwargs = kwargs
-        # Else: If params came via kwargs, they are already handled, so don't pass them again.
 
-        # Execute query
         return conn.query(
             final_sql,
-            job_config=final_job_config,
-            **final_query_kwargs,
+            job_config=final_job_config,  # pyright: ignore
+            **final_query_kwargs,
         )
 
-    @
+    @overload
     def _rows_to_results(
+        self,
+        rows: "Iterator[Row]",
+        schema: "Sequence[SchemaField]",
+        schema_type: "type[ModelDTOT]",
+    ) -> Sequence[ModelDTOT]: ...
+    @overload
+    def _rows_to_results(
+        self,
+        rows: "Iterator[Row]",
+        schema: "Sequence[SchemaField]",
+        schema_type: None = None,
+    ) -> Sequence[dict[str, Any]]: ...
+    def _rows_to_results(
+        self,
         rows: "Iterator[Row]",
         schema: "Sequence[SchemaField]",
         schema_type: "Optional[type[ModelDTOT]]" = None,
     ) -> Sequence[Union[ModelDTOT, dict[str, Any]]]:
         processed_results = []
-        # Create a quick lookup map for schema fields from the passed schema
         schema_map = {field.name: field for field in schema}
 
         for row in rows:
-            # row here is now a Row object from the iterator
             row_dict = {}
-            for key, value in row.items():
+            for key, value in row.items():
                 field = schema_map.get(key)
-                # Workaround remains the same
                 if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
                     try:
                         parsed_value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)
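`_run_query_job` now builds `QueryJobConfig.query_parameters` in three mutually exclusive branches: pre-formatted parameter objects pass through, dicts become named `ScalarQueryParameter`s, and sequences become positional (unnamed) ones. A sketch of the dict branch using the real `google-cloud-bigquery` constructor (the truncated type mapping is illustrative):

```python
from google.cloud import bigquery

def to_named_query_parameters(params: dict) -> "list[bigquery.ScalarQueryParameter]":
    def bq_type(value) -> str:
        # illustrative subset of the driver's _get_bq_param_type mapping
        if isinstance(value, bool):
            return "BOOL"
        if isinstance(value, int):
            return "INT64"
        return "STRING"

    return [
        bigquery.ScalarQueryParameter(name, bq_type(value), value)
        for name, value in params.items()
    ]

# These are referenced as @id / @active in the SQL text.
print(to_named_query_parameters({"id": 7, "active": True}))
```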
@@ -249,22 +236,15 @@ class BigQueryDriver(
                         row_dict[key] = value  # type: ignore[assignment]
                 else:
                     row_dict[key] = value
-
-
-                processed_results.append(schema_type(**row_dict))
-            else:
-                processed_results.append(row_dict)  # type: ignore[arg-type]
-        if schema_type:
-            return cast("Sequence[ModelDTOT]", processed_results)
-        return cast("Sequence[dict[str, Any]]", processed_results)
+            processed_results.append(row_dict)
+        return self.to_schema(processed_results, schema_type=schema_type)
 
     @overload
     def select(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: None = None,
         **kwargs: Any,
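The `_rows_to_results` rewrite delegates model conversion to `ResultConverter.to_schema` and adds `@overload` signatures so type checkers can narrow the return type on `schema_type`. The pattern in isolation (names illustrative):

```python
from typing import Any, Optional, TypeVar, Union, overload

T = TypeVar("T")

@overload
def convert(rows: "list[dict[str, Any]]", schema_type: "type[T]") -> "list[T]": ...
@overload
def convert(rows: "list[dict[str, Any]]", schema_type: None = None) -> "list[dict[str, Any]]": ...
def convert(
    rows: "list[dict[str, Any]]",
    schema_type: "Optional[type[T]]" = None,
) -> "Union[list[T], list[dict[str, Any]]]":
    if schema_type is None:
        return rows  # plain dict rows when no model type is requested
    return [schema_type(**row) for row in rows]
```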
@@ -274,8 +254,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "type[ModelDTOT]",
         **kwargs: Any,
@@ -284,14 +263,15 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "Optional[type[ModelDTOT]]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
     ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
-        query_job = self._run_query_job(
+        query_job = self._run_query_job(
+            sql, parameters, *filters, connection=connection, job_config=job_config, **kwargs
+        )
         return self._rows_to_results(query_job.result(), query_job.result().schema, schema_type)
 
     @overload
@@ -299,8 +279,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: None = None,
         **kwargs: Any,
@@ -310,8 +289,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "type[ModelDTOT]",
         **kwargs: Any,
@@ -320,24 +298,19 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "Optional[type[ModelDTOT]]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
     ) -> "Union[ModelDTOT, dict[str, Any]]":
-        query_job = self._run_query_job(
+        query_job = self._run_query_job(
+            sql, parameters, *filters, connection=connection, job_config=job_config, **kwargs
+        )
         rows_iterator = query_job.result()
         try:
-            # Pass the iterator containing only the first row to _rows_to_results
-            # This ensures the timestamp workaround is applied consistently.
-            # We need to pass the original iterator for schema access, but only consume one row.
             first_row = next(rows_iterator)
-            # Create a simple iterator yielding only the first row for processing
             single_row_iter = iter([first_row])
-            # We need RowIterator type for schema, create mock/proxy if needed, or pass schema
-            # Let's try passing schema directly to _rows_to_results (requires modifying it)
             results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type)
             return results[0]
         except StopIteration:
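`select_one` keeps its trick of consuming exactly one row and wrapping it back into an iterator, so the shared `_rows_to_results` path (including the TIMESTAMP workaround) applies to single-row reads too. The technique in isolation:

```python
rows = iter([{"id": 1}, {"id": 2}])

def process(rows_iter):
    # stand-in for _rows_to_results: handles any iterable of rows
    return [dict(r) for r in rows_iter]

try:
    first = next(rows)                  # consume exactly one row
    result = process(iter([first]))[0]  # reuse the sequence-processing helper
except StopIteration:
    result = None                       # query returned no rows

assert result == {"id": 1}
```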
@@ -349,8 +322,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: None = None,
         **kwargs: Any,
@@ -360,8 +332,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "type[ModelDTOT]",
         **kwargs: Any,
@@ -370,20 +341,19 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "Optional[type[ModelDTOT]]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
     ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
-        query_job = self._run_query_job(
+        query_job = self._run_query_job(
+            sql, parameters, *filters, connection=connection, job_config=job_config, **kwargs
+        )
         rows_iterator = query_job.result()
         try:
             first_row = next(rows_iterator)
-            # Create a simple iterator yielding only the first row for processing
             single_row_iter = iter([first_row])
-            # Pass schema directly
             results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type)
             return results[0]
         except StopIteration:
@@ -394,8 +364,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "Optional[type[T]]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
@@ -406,8 +375,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "type[T]",
         **kwargs: Any,
@@ -416,22 +384,20 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "Optional[type[T]]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
     ) -> Union[T, Any]:
         query_job = self._run_query_job(
-            sql
+            sql, parameters, *filters, connection=connection, job_config=job_config, **kwargs
         )
         rows = query_job.result()
         try:
             first_row = next(iter(rows))
             value = first_row[0]
-
-            field = rows.schema[0]  # Get schema for the first column
+            field = rows.schema[0]
             if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
                 with contextlib.suppress(ValueError):
                     value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)
@@ -446,8 +412,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: None = None,
         **kwargs: Any,
@@ -457,8 +422,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "type[T]",
         **kwargs: Any,
@@ -467,22 +431,25 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "Optional[type[T]]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
     ) -> "Optional[Union[T, Any]]":
         query_job = self._run_query_job(
-            sql
+            sql,
+            parameters,
+            *filters,
+            connection=connection,
+            job_config=job_config,
+            **kwargs,
         )
         rows = query_job.result()
         try:
             first_row = next(iter(rows))
             value = first_row[0]
-
-            field = rows.schema[0]  # Get schema for the first column
+            field = rows.schema[0]
             if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
                 with contextlib.suppress(ValueError):
                     value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)
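Several methods retain the workaround for TIMESTAMP values that surface as fractional epoch-second strings: parse them into timezone-aware UTC datetimes and silently keep the original value if parsing fails. A runnable restatement:

```python
import contextlib
import datetime

value = "1700000000.123456"  # epoch seconds with a fractional part
if isinstance(value, str) and "." in value:
    with contextlib.suppress(ValueError):
        value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)

print(value)  # 2023-11-14 22:13:20.123456+00:00
```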
@@ -495,32 +462,23 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: Optional[StatementParameterType] = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: Optional["BigQueryConnection"] = None,
         job_config: Optional[QueryJobConfig] = None,
         **kwargs: Any,
     ) -> int:
-        """Executes INSERT, UPDATE, DELETE and returns affected row count.
-
-        Returns:
-            int: The number of rows affected by the DML statement.
-        """
         query_job = self._run_query_job(
-            sql
+            sql, parameters, *filters, connection=connection, job_config=job_config, **kwargs
         )
-
-
-        query_job.result()  # Ensure completion
-        return query_job.num_dml_affected_rows or 0  # Return 0 if None
+        query_job.result()
+        return query_job.num_dml_affected_rows or 0
 
     @overload
     def insert_update_delete_returning(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: None = None,
         **kwargs: Any,
@@ -530,8 +488,7 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "type[ModelDTOT]",
         **kwargs: Any,
@@ -540,35 +497,26 @@ class BigQueryDriver(
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         schema_type: "Optional[type[ModelDTOT]]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
     ) -> Union[ModelDTOT, dict[str, Any]]:
-        """BigQuery DML RETURNING equivalent is complex, often requires temp tables or scripting."""
         msg = "BigQuery does not support `RETURNING` clauses directly in the same way as some other SQL databases. Consider multi-statement queries or alternative approaches."
         raise NotImplementedError(msg)
 
     def execute_script(
         self,
-        sql: str,
-        parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
         connection: "Optional[BigQueryConnection]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
     ) -> str:
-        """Executes a BigQuery script and returns the job ID.
-
-        Returns:
-            str: The job ID of the executed script.
-        """
         query_job = self._run_query_job(
-            sql
-            parameters
+            sql,
+            parameters,
             connection=connection,
             job_config=job_config,
             is_script=True,
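`insert_update_delete` now blocks on `QueryJob.result()` and returns `num_dml_affected_rows or 0` (the attribute is `None` when BigQuery reports no count), while `execute_script` runs with `is_script=True` and returns the job ID. A usage sketch, assuming an already-configured `BigQueryDriver` instance named `driver` and hypothetical table names (setup elided):

```python
# assumes: driver = BigQueryDriver(...)  # configured elsewhere
affected = driver.insert_update_delete(
    "UPDATE my_dataset.users SET active = FALSE WHERE last_seen < @cutoff",
    {"cutoff": "2024-01-01"},
)
print(affected)  # int; 0 when no affected-row count is reported

job_id = driver.execute_script(
    "CREATE TEMP TABLE t AS SELECT 1 AS x; SELECT * FROM t;"
)
print(job_id)    # str(QueryJob.job_id)
```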
@@ -576,14 +524,11 @@ class BigQueryDriver(
         )
         return str(query_job.job_id)
 
-    #
-
-    def select_arrow(  # pyright: ignore # noqa: PLR0912
+    def select_arrow(  # pyright: ignore
         self,
         sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         connection: "Optional[BigQueryConnection]" = None,
         job_config: "Optional[QueryJobConfig]" = None,
         **kwargs: Any,
@@ -591,41 +536,11 @@ class BigQueryDriver(
         conn = self._connection(connection)
         final_job_config = job_config or self._default_query_job_config or QueryJobConfig()
 
-
-        params: Union[dict[str, Any], list[Any], None] = None
-        param_style: Optional[str] = None  # 'named' (@), 'qmark' (?)
-
-        if isinstance(parameters, dict):
-            params = {**parameters, **kwargs}
-            param_style = "named"
-        elif isinstance(parameters, (list, tuple)):
-            if kwargs:
-                msg = "Cannot mix positional parameters with keyword arguments."
-                raise SQLSpecError(msg)
-            params = list(parameters)
-            param_style = "qmark"
-        elif kwargs:
-            params = kwargs
-            param_style = "named"
-        elif parameters is not None:
-            params = [parameters]
-            param_style = "qmark"
-
-        # Use sqlglot to transpile and bind parameters
-        try:
-            transpiled_sql = sqlglot.transpile(sql, args=params or {}, read=None, write=self.dialect)[0]
-        except Exception as e:
-            msg = f"SQL transpilation/binding failed using sqlglot: {e!s}"
-            raise SQLSpecError(msg) from e
+        processed_sql, processed_params = self._process_sql_params(sql, parameters, *filters, **kwargs)
 
-
-        if param_style == "named" and params:
-            if not isinstance(params, dict):
-                # This should be logically impossible due to how param_style is set
-                msg = "Internal error: named parameter style detected but params is not a dict."
-                raise SQLSpecError(msg)
+        if isinstance(processed_params, dict):
             query_parameters = []
-            for key, value in
+            for key, value in processed_params.items():
                 param_type, array_element_type = self._get_bq_param_type(value)
 
                 if param_type == "ARRAY" and array_element_type:
@@ -636,15 +551,15 @@ class BigQueryDriver(
                     msg = f"Unsupported parameter type for BigQuery Arrow named parameter '{key}': {type(value)}"
                     raise SQLSpecError(msg)
             final_job_config.query_parameters = query_parameters
-        elif
-
-
+        elif isinstance(processed_params, (list, tuple)):
+            final_job_config.query_parameters = [
+                bigquery.ScalarQueryParameter(None, self._get_bq_param_type(value)[0], value)
+                for value in processed_params
+            ]
 
-        # Execute the query and get Arrow table
         try:
-            query_job = conn.query(
-            arrow_table = query_job.to_arrow()
-
+            query_job = conn.query(processed_sql, job_config=final_job_config)
+            arrow_table = query_job.to_arrow()
         except Exception as e:
             msg = f"BigQuery Arrow query execution failed: {e!s}"
             raise SQLSpecError(msg) from e
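`select_arrow` now shares `_process_sql_params` instead of carrying its own parameter-style detection, then executes the job and materializes results via the client library's `QueryJob.to_arrow()`. A usage sketch, again assuming a configured `driver` and hypothetical dataset names:

```python
# assumes: driver = BigQueryDriver(...)  # configured elsewhere
table = driver.select_arrow(
    "SELECT name, age FROM my_dataset.users WHERE age > ?",
    [21],  # positional values become unnamed ScalarQueryParameter(None, ...) entries
)
print(table.num_rows, table.column_names)  # pyarrow.Table accessors
```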
@@ -652,31 +567,34 @@ class BigQueryDriver(
 
     def select_to_parquet(
         self,
-        sql: str,
+        sql: str,
         parameters: "Optional[StatementParameterType]" = None,
-        /,
-        *,
+        *filters: "StatementFilter",
         destination_uri: "Optional[str]" = None,
         connection: "Optional[BigQueryConnection]" = None,
         job_config: "Optional[bigquery.ExtractJobConfig]" = None,
         **kwargs: Any,
     ) -> None:
-        """Exports a BigQuery table to Parquet files in Google Cloud Storage.
-
-        Raises:
-            NotImplementedError: If the SQL is not a fully qualified table ID or if parameters are provided.
-            NotFoundError: If the source table is not found.
-            SQLSpecError: If the Parquet export fails.
-        """
         if destination_uri is None:
             msg = "destination_uri is required"
             raise SQLSpecError(msg)
         conn = self._connection(connection)
-
-
+
+        if parameters is not None:
+            msg = (
+                "select_to_parquet expects a fully qualified table ID (e.g., 'project.dataset.table') "
+                "as the `sql` argument and does not support `parameters`."
+            )
             raise NotImplementedError(msg)
 
-
+        try:
+            source_table_ref = bigquery.TableReference.from_string(sql, default_project=conn.project)
+        except ValueError as e:
+            msg = (
+                "select_to_parquet expects a fully qualified table ID (e.g., 'project.dataset.table') "
+                f"as the `sql` argument. Parsing failed for input '{sql}': {e!s}"
+            )
+            raise NotImplementedError(msg) from e
 
         final_extract_config = job_config or bigquery.ExtractJobConfig()  # type: ignore[no-untyped-call]
         final_extract_config.destination_format = bigquery.DestinationFormat.PARQUET
@@ -686,9 +604,8 @@ class BigQueryDriver(
                 source_table_ref,
                 destination_uri,
                 job_config=final_extract_config,
-                # Location is correctly inferred by the client library
             )
-            extract_job.result()
+            extract_job.result()
 
         except NotFound:
             msg = f"Source table not found for Parquet export: {source_table_ref}"
@@ -699,3 +616,6 @@ class BigQueryDriver(
         if extract_job.errors:
             msg = f"BigQuery Parquet export failed: {extract_job.errors}"
             raise SQLSpecError(msg)
+
+    def _connection(self, connection: "Optional[BigQueryConnection]" = None) -> "BigQueryConnection":
+        return connection or self.connection
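`select_to_parquet` now makes its contract explicit: the `sql` argument must be a fully qualified table ID, validated with the client library's `bigquery.TableReference.from_string`, and `parameters` are rejected outright. The new `_connection` helper is the usual fall-back-to-default pattern. A usage sketch with hypothetical project and bucket names:

```python
# assumes: driver = BigQueryDriver(...)  # configured elsewhere
driver.select_to_parquet(
    "my-project.my_dataset.users",  # a table ID, not a SQL query
    destination_uri="gs://my-bucket/exports/users-*.parquet",
)
```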