sqlspec 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic.
- sqlspec/__init__.py +2 -1
- sqlspec/adapters/adbc/__init__.py +2 -1
- sqlspec/adapters/adbc/config.py +7 -13
- sqlspec/adapters/adbc/driver.py +37 -30
- sqlspec/adapters/aiosqlite/__init__.py +2 -1
- sqlspec/adapters/aiosqlite/config.py +10 -12
- sqlspec/adapters/aiosqlite/driver.py +36 -31
- sqlspec/adapters/asyncmy/__init__.py +2 -1
- sqlspec/adapters/asyncmy/driver.py +34 -31
- sqlspec/adapters/asyncpg/config.py +1 -3
- sqlspec/adapters/asyncpg/driver.py +7 -3
- sqlspec/adapters/bigquery/__init__.py +4 -0
- sqlspec/adapters/bigquery/config/__init__.py +3 -0
- sqlspec/adapters/bigquery/config/_common.py +40 -0
- sqlspec/adapters/bigquery/config/_sync.py +87 -0
- sqlspec/adapters/bigquery/driver.py +701 -0
- sqlspec/adapters/duckdb/__init__.py +2 -1
- sqlspec/adapters/duckdb/config.py +17 -18
- sqlspec/adapters/duckdb/driver.py +38 -30
- sqlspec/adapters/oracledb/__init__.py +8 -1
- sqlspec/adapters/oracledb/config/_asyncio.py +7 -8
- sqlspec/adapters/oracledb/config/_sync.py +6 -7
- sqlspec/adapters/oracledb/driver.py +65 -62
- sqlspec/adapters/psqlpy/__init__.py +9 -0
- sqlspec/adapters/psqlpy/config.py +5 -5
- sqlspec/adapters/psqlpy/driver.py +34 -28
- sqlspec/adapters/psycopg/__init__.py +8 -1
- sqlspec/adapters/psycopg/config/__init__.py +10 -0
- sqlspec/adapters/psycopg/config/_async.py +6 -7
- sqlspec/adapters/psycopg/config/_sync.py +7 -8
- sqlspec/adapters/psycopg/driver.py +63 -53
- sqlspec/adapters/sqlite/__init__.py +2 -1
- sqlspec/adapters/sqlite/config.py +12 -11
- sqlspec/adapters/sqlite/driver.py +36 -29
- sqlspec/base.py +1 -66
- sqlspec/exceptions.py +9 -0
- sqlspec/extensions/litestar/config.py +3 -11
- sqlspec/extensions/litestar/handlers.py +2 -1
- sqlspec/extensions/litestar/plugin.py +4 -2
- sqlspec/mixins.py +156 -0
- sqlspec/typing.py +19 -1
- {sqlspec-0.9.1.dist-info → sqlspec-0.10.0.dist-info}/METADATA +8 -3
- sqlspec-0.10.0.dist-info/RECORD +67 -0
- sqlspec-0.9.1.dist-info/RECORD +0 -61
- {sqlspec-0.9.1.dist-info → sqlspec-0.10.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.9.1.dist-info → sqlspec-0.10.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.9.1.dist-info → sqlspec-0.10.0.dist-info}/licenses/NOTICE +0 -0
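The headline change in this release is a new BigQuery adapter (sqlspec/adapters/bigquery/) with a synchronous driver, shown in full in the diff below. As a rough orientation, here is a minimal usage sketch based on that driver code; the project, dataset, and table names are hypothetical placeholders, and the import path assumes the driver module shown below rather than any re-export.

from google.cloud.bigquery import Client

from sqlspec.adapters.bigquery.driver import BigQueryDriver

# BigQueryConnection is an alias for google.cloud.bigquery.Client in the new driver.
client = Client(project="my-project")
driver = BigQueryDriver(client)

# Dict parameters are converted to named (@-style) BigQuery query parameters.
rows = driver.select(
    "SELECT name, created_at FROM my_dataset.users WHERE country = @country",
    {"country": "NZ"},
)

# select_one_or_none returns None instead of raising NotFoundError when no row matches.
user = driver.select_one_or_none(
    "SELECT name FROM my_dataset.users WHERE id = @id",
    {"id": 123},
)

# DML helpers wait for the job and return num_dml_affected_rows (0 if unset).
deleted = driver.insert_update_delete(
    "DELETE FROM my_dataset.users WHERE id = @id",
    {"id": 123},
)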
sqlspec/adapters/bigquery/driver.py +701 -0
@@ -0,0 +1,701 @@
import contextlib
import datetime
from collections.abc import Iterator, Sequence
from decimal import Decimal
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Optional,
    Union,
    cast,
    overload,
)

import sqlglot
from google.cloud import bigquery
from google.cloud.bigquery import Client
from google.cloud.bigquery.job import QueryJob, QueryJobConfig
from google.cloud.exceptions import NotFound

from sqlspec.base import SyncDriverAdapterProtocol
from sqlspec.exceptions import NotFoundError, SQLSpecError
from sqlspec.mixins import (
    SQLTranslatorMixin,
    SyncArrowBulkOperationsMixin,
    SyncParquetExportMixin,
)
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T

if TYPE_CHECKING:
    from google.cloud.bigquery import SchemaField
    from google.cloud.bigquery.table import Row

__all__ = ("BigQueryConnection", "BigQueryDriver")

BigQueryConnection = Client


class BigQueryDriver(
    SyncDriverAdapterProtocol["BigQueryConnection"],
    SyncArrowBulkOperationsMixin["BigQueryConnection"],
    SyncParquetExportMixin["BigQueryConnection"],
    SQLTranslatorMixin["BigQueryConnection"],
):
    """Synchronous BigQuery Driver Adapter."""

    dialect: str = "bigquery"
    connection: "BigQueryConnection"
    __supports_arrow__: ClassVar[bool] = True

    def __init__(self, connection: "BigQueryConnection", **kwargs: Any) -> None:
        super().__init__(connection=connection)
        self._default_query_job_config = kwargs.get("default_query_job_config") or getattr(
            connection, "default_query_job_config", None
        )

    @staticmethod
    def _get_bq_param_type(value: Any) -> "tuple[Optional[str], Optional[str]]":  # noqa: PLR0911, PLR0912
        if isinstance(value, bool):
            return "BOOL", None
        if isinstance(value, int):
            return "INT64", None
        if isinstance(value, float):
            return "FLOAT64", None
        if isinstance(value, Decimal):
            # Precision/scale might matter, but BQ client handles conversion.
            # Defaulting to BIGNUMERIC, NUMERIC might be desired in some cases though (User change)
            return "BIGNUMERIC", None
        if isinstance(value, str):
            return "STRING", None
        if isinstance(value, bytes):
            return "BYTES", None
        if isinstance(value, datetime.date):
            return "DATE", None
        # DATETIME is for timezone-naive values
        if isinstance(value, datetime.datetime) and value.tzinfo is None:
            return "DATETIME", None
        # TIMESTAMP is for timezone-aware values
        if isinstance(value, datetime.datetime) and value.tzinfo is not None:
            return "TIMESTAMP", None
        if isinstance(value, datetime.time):
            return "TIME", None

        # Handle Arrays - Determine element type
        if isinstance(value, (list, tuple)):
            if not value:
                # Cannot determine type of empty array, BQ requires type.
                # Raise or default? Defaulting is risky. Let's raise.
                msg = "Cannot determine BigQuery ARRAY type for empty sequence."
                raise SQLSpecError(msg)
            # Infer type from first element
            first_element = value[0]
            element_type, _ = BigQueryDriver._get_bq_param_type(first_element)
            if element_type is None:
                msg = f"Unsupported element type in ARRAY: {type(first_element)}"
                raise SQLSpecError(msg)
            return "ARRAY", element_type

        # Handle Structs (basic dict mapping) - Requires careful handling
        # if isinstance(value, dict):
        #     # This requires recursive type mapping for sub-fields.
        #     # For simplicity, users might need to construct StructQueryParameter manually.
        #     # return "STRUCT", None  # Placeholder if implementing  # noqa: ERA001
        #     raise SQLSpecError("Automatic STRUCT mapping not implemented. Please use bigquery.StructQueryParameter.")  # noqa: ERA001

        return None, None  # Unsupported type

    def _run_query_job(  # noqa: C901, PLR0912, PLR0915 (User change)
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        connection: "Optional[BigQueryConnection]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        is_script: bool = False,
        **kwargs: Any,
    ) -> "QueryJob":
        conn = self._connection(connection)

        # Determine the final job config, creating a new one if necessary
        # to avoid modifying a shared default config.
        if job_config:
            final_job_config = job_config  # Use the provided config directly
        elif self._default_query_job_config:
            final_job_config = QueryJobConfig()
        else:
            final_job_config = QueryJobConfig()  # Create a fresh config

        # --- Parameter Handling Logic --- Start
        params: Union[dict[str, Any], list[Any], None] = None
        param_style: Optional[str] = None  # 'named' (@), 'qmark' (?)
        use_preformatted_params = False
        final_sql = sql  # Default to original SQL

        # Check for pre-formatted BQ parameters first
        if (
            isinstance(parameters, (list, tuple))
            and parameters
            and all(isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in parameters)
        ):
            if kwargs:
                msg = "Cannot mix pre-formatted BigQuery parameters with keyword arguments."
                raise SQLSpecError(msg)
            use_preformatted_params = True
            final_job_config.query_parameters = list(parameters)
            # Keep final_sql = sql, as it should match the pre-formatted named params

        # Determine parameter style and merge standard parameters ONLY if not preformatted
        if not use_preformatted_params:
            if isinstance(parameters, dict):
                params = {**parameters, **kwargs}
                param_style = "named"
            elif isinstance(parameters, (list, tuple)):
                if kwargs:
                    msg = "Cannot mix positional parameters with keyword arguments."
                    raise SQLSpecError(msg)
                # Check if it's primitives for qmark style
                if all(
                    not isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in parameters
                ):
                    params = list(parameters)
                    param_style = "qmark"
                else:
                    # Mixed list or non-BQ parameter objects
                    msg = "Invalid mix of parameter types in list. Use only primitive values or only BigQuery QueryParameter objects."
                    raise SQLSpecError(msg)

            elif kwargs:
                params = kwargs
                param_style = "named"
            elif parameters is not None and not isinstance(
                parameters, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)
            ):
                # Could be a single primitive value for positional
                params = [parameters]
                param_style = "qmark"
            elif parameters is not None:  # Single BQ parameter object
                msg = "Single BigQuery QueryParameter objects should be passed within a list."
                raise SQLSpecError(msg)

        # Use sqlglot to transpile ONLY if not a script and not preformatted
        if not is_script and not use_preformatted_params:
            try:
                # Transpile for syntax normalization/dialect conversion if needed
                # Use BigQuery dialect for both reading and writing
                final_sql = sqlglot.transpile(sql, read=self.dialect, write=self.dialect)[0]
            except Exception as e:
                # Catch potential sqlglot errors
                msg = f"SQL transpilation failed using sqlglot: {e!s}"  # Adjusted message
                raise SQLSpecError(msg) from e
        # else: If preformatted_params, final_sql remains the original sql

        # --- Parameter Handling Logic --- (Moved outside the transpilation try/except)
        # Prepare BQ parameters based on style, ONLY if not preformatted
        if not use_preformatted_params:
            if param_style == "named" and params:
                # Convert dict params to BQ ScalarQueryParameter
                if isinstance(params, dict):
                    final_job_config.query_parameters = [
                        bigquery.ScalarQueryParameter(name, self._get_bq_param_type(value)[0], value)
                        for name, value in params.items()
                    ]
                else:
                    # This path should ideally not be reached if param_style logic is correct
                    msg = f"Internal error: Parameter style is 'named' but parameters are not a dict: {type(params)}"
                    raise SQLSpecError(msg)
            elif param_style == "qmark" and params:
                # Convert list params to BQ ScalarQueryParameter
                final_job_config.query_parameters = [
                    bigquery.ScalarQueryParameter(None, self._get_bq_param_type(value)[0], value) for value in params
                ]

        # --- Parameter Handling Logic --- End

        # Determine which kwargs to pass to the actual query method.
        # We only want to pass kwargs that were *not* treated as SQL parameters.
        final_query_kwargs = {}
        if parameters is not None and kwargs:  # Params came via arg, kwargs are separate
            final_query_kwargs = kwargs
        # Else: If params came via kwargs, they are already handled, so don't pass them again.

        # Execute query
        return conn.query(
            final_sql,
            job_config=final_job_config,
            **final_query_kwargs,  # Pass only relevant kwargs
        )

    @staticmethod
    def _rows_to_results(
        rows: "Iterator[Row]",
        schema: "Sequence[SchemaField]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> Sequence[Union[ModelDTOT, dict[str, Any]]]:
        processed_results = []
        # Create a quick lookup map for schema fields from the passed schema
        schema_map = {field.name: field for field in schema}

        for row in rows:
            # row here is now a Row object from the iterator
            row_dict = {}
            for key, value in row.items():  # Use row.items() on the Row object
                field = schema_map.get(key)
                # Workaround remains the same
                if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
                    try:
                        parsed_value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)
                        row_dict[key] = parsed_value
                    except ValueError:
                        row_dict[key] = value  # type: ignore[assignment]
                else:
                    row_dict[key] = value
            # Use the processed dictionary for the final result
            if schema_type:
                processed_results.append(schema_type(**row_dict))
            else:
                processed_results.append(row_dict)  # type: ignore[arg-type]
        if schema_type:
            return cast("Sequence[ModelDTOT]", processed_results)
        return cast("Sequence[dict[str, Any]]", processed_results)

    @overload
    def select(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: None = None,
        **kwargs: Any,
    ) -> "Sequence[dict[str, Any]]": ...
    @overload
    def select(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "type[ModelDTOT]",
        **kwargs: Any,
    ) -> "Sequence[ModelDTOT]": ...
    def select(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
        query_job = self._run_query_job(sql, parameters, connection, job_config, **kwargs)
        return self._rows_to_results(query_job.result(), query_job.result().schema, schema_type)

    @overload
    def select_one(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: None = None,
        **kwargs: Any,
    ) -> "dict[str, Any]": ...
    @overload
    def select_one(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "type[ModelDTOT]",
        **kwargs: Any,
    ) -> "ModelDTOT": ...
    def select_one(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        query_job = self._run_query_job(sql, parameters, connection, job_config, **kwargs)
        rows_iterator = query_job.result()
        try:
            # Pass the iterator containing only the first row to _rows_to_results
            # This ensures the timestamp workaround is applied consistently.
            # We need to pass the original iterator for schema access, but only consume one row.
            first_row = next(rows_iterator)
            # Create a simple iterator yielding only the first row for processing
            single_row_iter = iter([first_row])
            # We need RowIterator type for schema, create mock/proxy if needed, or pass schema
            # Let's try passing schema directly to _rows_to_results (requires modifying it)
            results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type)
            return results[0]
        except StopIteration:
            msg = "No result found when one was expected"
            raise NotFoundError(msg) from None

    @overload
    def select_one_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: None = None,
        **kwargs: Any,
    ) -> "Optional[dict[str, Any]]": ...
    @overload
    def select_one_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "type[ModelDTOT]",
        **kwargs: Any,
    ) -> "Optional[ModelDTOT]": ...
    def select_one_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        query_job = self._run_query_job(sql, parameters, connection, job_config, **kwargs)
        rows_iterator = query_job.result()
        try:
            first_row = next(rows_iterator)
            # Create a simple iterator yielding only the first row for processing
            single_row_iter = iter([first_row])
            # Pass schema directly
            results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type)
            return results[0]
        except StopIteration:
            return None

    @overload
    def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> Union[T, Any]: ...
    @overload
    def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "type[T]",
        **kwargs: Any,
    ) -> "T": ...
    def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> Union[T, Any]:
        query_job = self._run_query_job(
            sql=sql, parameters=parameters, connection=connection, job_config=job_config, **kwargs
        )
        rows = query_job.result()
        try:
            first_row = next(iter(rows))
            value = first_row[0]
            # Apply timestamp workaround if necessary
            field = rows.schema[0]  # Get schema for the first column
            if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
                with contextlib.suppress(ValueError):
                    value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)

            return cast("T", value) if schema_type else value
        except (StopIteration, IndexError):
            msg = "No value found when one was expected"
            raise NotFoundError(msg) from None

    @overload
    def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: None = None,
        **kwargs: Any,
    ) -> "Optional[Any]": ...
    @overload
    def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "type[T]",
        **kwargs: Any,
    ) -> "Optional[T]": ...
    def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> "Optional[Union[T, Any]]":
        query_job = self._run_query_job(
            sql=sql, parameters=parameters, connection=connection, job_config=job_config, **kwargs
        )
        rows = query_job.result()
        try:
            first_row = next(iter(rows))
            value = first_row[0]
            # Apply timestamp workaround if necessary
            field = rows.schema[0]  # Get schema for the first column
            if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
                with contextlib.suppress(ValueError):
                    value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)

            return cast("T", value) if schema_type else value
        except (StopIteration, IndexError):
            return None

    def insert_update_delete(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        *,
        connection: Optional["BigQueryConnection"] = None,
        job_config: Optional[QueryJobConfig] = None,
        **kwargs: Any,
    ) -> int:
        """Executes INSERT, UPDATE, DELETE and returns affected row count.

        Returns:
            int: The number of rows affected by the DML statement.
        """
        query_job = self._run_query_job(
            sql=sql, parameters=parameters, connection=connection, job_config=job_config, **kwargs
        )
        # DML statements might not return rows, check job properties
        # num_dml_affected_rows might be None initially, wait might be needed
        query_job.result()  # Ensure completion
        return query_job.num_dml_affected_rows or 0  # Return 0 if None

    @overload
    def insert_update_delete_returning(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: None = None,
        **kwargs: Any,
    ) -> "dict[str, Any]": ...
    @overload
    def insert_update_delete_returning(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "type[ModelDTOT]",
        **kwargs: Any,
    ) -> "ModelDTOT": ...
    def insert_update_delete_returning(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> Union[ModelDTOT, dict[str, Any]]:
        """BigQuery DML RETURNING equivalent is complex, often requires temp tables or scripting."""
        msg = "BigQuery does not support `RETURNING` clauses directly in the same way as some other SQL databases. Consider multi-statement queries or alternative approaches."
        raise NotImplementedError(msg)

    def execute_script(
        self,
        sql: str,  # Expecting a script here
        parameters: "Optional[StatementParameterType]" = None,  # Parameters might be complex in scripts
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> str:
        """Executes a BigQuery script and returns the job ID.

        Returns:
            str: The job ID of the executed script.
        """
        query_job = self._run_query_job(
            sql=sql,
            parameters=parameters,
            connection=connection,
            job_config=job_config,
            is_script=True,
            **kwargs,
        )
        return str(query_job.job_id)

    # --- Mixin Implementations ---

    def select_arrow(  # pyright: ignore  # noqa: PLR0912
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        connection: "Optional[BigQueryConnection]" = None,
        job_config: "Optional[QueryJobConfig]" = None,
        **kwargs: Any,
    ) -> "ArrowTable":  # pyright: ignore[reportUnknownReturnType]
        conn = self._connection(connection)
        final_job_config = job_config or self._default_query_job_config or QueryJobConfig()

        # Determine parameter style and merge parameters (Similar to _run_query_job)
        params: Union[dict[str, Any], list[Any], None] = None
        param_style: Optional[str] = None  # 'named' (@), 'qmark' (?)

        if isinstance(parameters, dict):
            params = {**parameters, **kwargs}
            param_style = "named"
        elif isinstance(parameters, (list, tuple)):
            if kwargs:
                msg = "Cannot mix positional parameters with keyword arguments."
                raise SQLSpecError(msg)
            params = list(parameters)
            param_style = "qmark"
        elif kwargs:
            params = kwargs
            param_style = "named"
        elif parameters is not None:
            params = [parameters]
            param_style = "qmark"

        # Use sqlglot to transpile and bind parameters
        try:
            transpiled_sql = sqlglot.transpile(sql, args=params or {}, read=None, write=self.dialect)[0]
        except Exception as e:
            msg = f"SQL transpilation/binding failed using sqlglot: {e!s}"
            raise SQLSpecError(msg) from e

        # Prepare BigQuery specific parameters if named style was used
        if param_style == "named" and params:
            if not isinstance(params, dict):
                # This should be logically impossible due to how param_style is set
                msg = "Internal error: named parameter style detected but params is not a dict."
                raise SQLSpecError(msg)
            query_parameters = []
            for key, value in params.items():
                param_type, array_element_type = self._get_bq_param_type(value)

                if param_type == "ARRAY" and array_element_type:
                    query_parameters.append(bigquery.ArrayQueryParameter(key, array_element_type, value))
                elif param_type:
                    query_parameters.append(bigquery.ScalarQueryParameter(key, param_type, value))  # type: ignore[arg-type]
                else:
                    msg = f"Unsupported parameter type for BigQuery Arrow named parameter '{key}': {type(value)}"
                    raise SQLSpecError(msg)
            final_job_config.query_parameters = query_parameters
        elif param_style == "qmark" and params:
            # Positional params handled by client library
            pass

        # Execute the query and get Arrow table
        try:
            query_job = conn.query(transpiled_sql, job_config=final_job_config)
            arrow_table = query_job.to_arrow()  # Waits for job completion

        except Exception as e:
            msg = f"BigQuery Arrow query execution failed: {e!s}"
            raise SQLSpecError(msg) from e
        return arrow_table

    def select_to_parquet(
        self,
        sql: str,  # Expects table ID: project.dataset.table
        parameters: "Optional[StatementParameterType]" = None,
        /,
        *,
        destination_uri: "Optional[str]" = None,
        connection: "Optional[BigQueryConnection]" = None,
        job_config: "Optional[bigquery.ExtractJobConfig]" = None,
        **kwargs: Any,
    ) -> None:
        """Exports a BigQuery table to Parquet files in Google Cloud Storage.

        Raises:
            NotImplementedError: If the SQL is not a fully qualified table ID or if parameters are provided.
            NotFoundError: If the source table is not found.
            SQLSpecError: If the Parquet export fails.
        """
        if destination_uri is None:
            msg = "destination_uri is required"
            raise SQLSpecError(msg)
        conn = self._connection(connection)
        if "." not in sql or parameters is not None:
            msg = "select_to_parquet currently expects a fully qualified table ID (project.dataset.table) as the `sql` argument and no `parameters`."
            raise NotImplementedError(msg)

        source_table_ref = bigquery.TableReference.from_string(sql, default_project=conn.project)

        final_extract_config = job_config or bigquery.ExtractJobConfig()  # type: ignore[no-untyped-call]
        final_extract_config.destination_format = bigquery.DestinationFormat.PARQUET

        try:
            extract_job = conn.extract_table(
                source_table_ref,
                destination_uri,
                job_config=final_extract_config,
                # Location is correctly inferred by the client library
            )
            extract_job.result()  # Wait for completion

        except NotFound:
            msg = f"Source table not found for Parquet export: {source_table_ref}"
            raise NotFoundError(msg) from None
        except Exception as e:
            msg = f"BigQuery Parquet export failed: {e!s}"
            raise SQLSpecError(msg) from e
        if extract_job.errors:
            msg = f"BigQuery Parquet export failed: {extract_job.errors}"
            raise SQLSpecError(msg)