sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release: this version of sqlspec might be problematic.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/bigquery/driver.py

@@ -1,6 +1,8 @@
+import contextlib
 import datetime
 import io
 import logging
+import uuid
 from collections.abc import Iterator
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
@@ -8,15 +10,18 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
 from google.cloud.bigquery import (
     ArrayQueryParameter,
     Client,
+    ExtractJobConfig,
     LoadJobConfig,
     QueryJob,
     QueryJobConfig,
     ScalarQueryParameter,
+    SourceFormat,
     WriteDisposition,
 )
 from google.cloud.bigquery.table import Row as BigQueryRow

 from sqlspec.driver import SyncDriverAdapterProtocol
+from sqlspec.driver.connection import managed_transaction_sync
 from sqlspec.driver.mixins import (
     SQLTranslatorMixin,
     SyncPipelinedExecutionMixin,
@@ -24,14 +29,17 @@ from sqlspec.driver.mixins import (
     ToSchemaMixin,
     TypeCoercionMixin,
 )
+from sqlspec.driver.parameters import normalize_parameter_sequence
 from sqlspec.exceptions import SQLSpecError
-from sqlspec.statement.parameters import ParameterStyle
-from sqlspec.statement.result import ArrowResult,
+from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
+from sqlspec.statement.result import ArrowResult, SQLResult
 from sqlspec.statement.sql import SQL, SQLConfig
-from sqlspec.typing import DictRow,
+from sqlspec.typing import DictRow, RowT
 from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
+    from pathlib import Path
+
     from sqlglot.dialects.dialect import DialectType

@@ -134,6 +142,10 @@ class BigQueryDriver(
         Raises:
             SQLSpecError: If value type is not supported.
         """
+        if value is None:
+            # BigQuery handles NULL values without explicit type
+            return ("STRING", None)  # Use STRING type for NULL values
+
         value_type = type(value)
         if value_type is datetime.datetime:
             return ("TIMESTAMP" if value.tzinfo else "DATETIME", None)
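The new `None` branch means the driver can bind SQL NULL as a typed query parameter. A minimal standalone sketch of that rule using `google.cloud.bigquery.ScalarQueryParameter`; the `bq_param` helper below is illustrative only and not part of sqlspec:

    from google.cloud.bigquery import ScalarQueryParameter

    def bq_param(name: str, value: object) -> ScalarQueryParameter:
        # Mirrors the diff's rule: None maps to a STRING-typed NULL, because
        # BigQuery still requires a declared type for NULL scalar parameters.
        if value is None:
            return ScalarQueryParameter(name, "STRING", None)
        if isinstance(value, bool):  # check bool before int (bool subclasses int)
            return ScalarQueryParameter(name, "BOOL", value)
        if isinstance(value, int):
            return ScalarQueryParameter(name, "INT64", value)
        return ScalarQueryParameter(name, "STRING", str(value))

    print(bq_param("p_0", None))  # name 'p_0', type 'STRING', value None
    print(bq_param("p_1", 42))    # name 'p_1', type 'INT64', value 42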
@@ -152,7 +164,6 @@ class BigQueryDriver(
         if value_type in type_map:
             return type_map[value_type]

-        # Handle lists/tuples for ARRAY type
         if isinstance(value, (list, tuple)):
             if not value:
                 msg = "Cannot determine BigQuery ARRAY type for empty sequence. Provide typed empty array or ensure context implies type."
@@ -186,8 +197,7 @@ class BigQueryDriver(
         for name, value in params_dict.items():
             param_name_for_bq = name.lstrip("@")

-
-            actual_value = value.value if hasattr(value, "value") else value
+            actual_value = getattr(value, "value", value)

             param_type, array_element_type = self._get_bq_param_type(actual_value)

@@ -232,18 +242,14 @@ class BigQueryDriver(
         """
         conn = connection or self.connection

-        # Build final job configuration
         final_job_config = QueryJobConfig()

-        # Apply default configuration if available
         if self._default_query_job_config:
             self._copy_job_config_attrs(self._default_query_job_config, final_job_config)

-        # Apply override configuration if provided
         if job_config:
             self._copy_job_config_attrs(job_config, final_job_config)

-        # Set query parameters
         final_job_config.query_parameters = bq_query_parameters or []

         # Debug log the actual parameters being sent
@@ -258,23 +264,14 @@ class BigQueryDriver(
                 param_value,
                 type(param_value),
             )
-        # Let BigQuery generate the job ID to avoid collisions
-        # This is the recommended approach for production code and works better with emulators
-        logger.warning("About to send to BigQuery - SQL: %r", sql_str)
-        logger.warning("Query parameters in job config: %r", final_job_config.query_parameters)
         query_job = conn.query(sql_str, job_config=final_job_config)

-        # Get the auto-generated job ID for callbacks
         if self.on_job_start and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
                 self.on_job_start(query_job.job_id)
-            except Exception as e:
-                logger.warning("Job start callback failed: %s", str(e), extra={"adapter": "bigquery"})
         if self.on_job_complete and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
                 self.on_job_complete(query_job.job_id, query_job)
-            except Exception as e:
-                logger.warning("Job complete callback failed: %s", str(e), extra={"adapter": "bigquery"})

         return query_job

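The callback guards now use `contextlib.suppress` instead of try/except blocks that logged the failure. The control flow is equivalent, but the suppressed exception is no longer logged. A small standalone comparison:

    import contextlib

    def flaky_callback(job_id: str) -> None:
        raise RuntimeError("callback failure")

    # Old shape: swallow the error and log it.
    try:
        flaky_callback("job-123")
    except Exception:
        pass  # the removed code emitted logger.warning(...) here

    # New shape: same control flow, no handler body, and nothing is logged.
    with contextlib.suppress(Exception):
        flaky_callback("job-123")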
@@ -290,15 +287,21 @@ class BigQueryDriver(
         """
         return [dict(row) for row in rows_iterator]  # type: ignore[misc]

-    def _handle_select_job(self, query_job: QueryJob) ->
+    def _handle_select_job(self, query_job: QueryJob, statement: SQL) -> SQLResult[RowT]:
         """Handle a query job that is expected to return rows."""
         job_result = query_job.result()
         rows_list = self._rows_to_results(iter(job_result))
         column_names = [field.name for field in query_job.schema] if query_job.schema else []

-        return
+        return SQLResult(
+            statement=statement,
+            data=rows_list,
+            column_names=column_names,
+            rows_affected=len(rows_list),
+            operation_type="SELECT",
+        )

-    def _handle_dml_job(self, query_job: QueryJob) ->
+    def _handle_dml_job(self, query_job: QueryJob, statement: SQL) -> SQLResult[RowT]:
         """Handle a DML job.

         Note: BigQuery emulators (e.g., goccy/bigquery-emulator) may report 0 rows affected
@@ -324,7 +327,14 @@ class BigQueryDriver(
             )
             num_affected = 1  # Assume at least one row was affected

-
+        operation_type = self._determine_operation_type(statement)
+        return SQLResult(
+            statement=statement,
+            data=cast("list[RowT]", []),
+            rows_affected=num_affected or 0,
+            operation_type=operation_type,
+            metadata={"status_message": f"OK - job_id: {query_job.job_id}"},
+        )

     def _compile_bigquery_compatible(self, statement: SQL, target_style: ParameterStyle) -> tuple[str, Any]:
         """Compile SQL statement for BigQuery.
@@ -336,12 +346,18 @@ class BigQueryDriver(

     def _execute_statement(
         self, statement: SQL, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
+    ) -> SQLResult[RowT]:
         if statement.is_script:
             sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
             return self._execute_script(sql, connection=connection, **kwargs)

-        detected_styles =
+        detected_styles = set()
+        sql_str = statement.to_sql(placeholder_style=None)  # Get raw SQL
+        validator = self.config.parameter_validator if self.config else ParameterValidator()
+        param_infos = validator.extract_parameters(sql_str)
+        if param_infos:
+            detected_styles = {p.style for p in param_infos}
+
         target_style = self.default_parameter_style

         unsupported_styles = detected_styles - set(self.supported_parameter_styles)
@@ -359,167 +375,116 @@ class BigQueryDriver(
             return self._execute_many(sql, params, connection=connection, **kwargs)

         sql, params = self._compile_bigquery_compatible(statement, target_style)
-        logger.debug("compile() returned - sql: %r, params: %r", sql, params)
         params = self._process_parameters(params)
-        logger.debug("after _process_parameters - params: %r", params)
         return self._execute(sql, params, statement, connection=connection, **kwargs)

     def _execute(
         self, sql: str, parameters: Any, statement: SQL, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
-        #
[… 13 removed lines not rendered in the original diff view …]
-                k: v
-                for k, v in converted_params.items()
-                if k.startswith("param_") or (not k.startswith("_") and k not in {"dialect", "config"})
-            }
-        elif isinstance(converted_params, (list, tuple)):
-            # Convert positional parameters to named parameters for BigQuery
-            # Use param_N to match the compiled SQL placeholders
-            param_dict = {f"param_{i}": val for i, val in enumerate(converted_params)}
-        else:
-            # Single scalar parameter
-            param_dict = {"param_0": converted_params}
+    ) -> SQLResult[RowT]:
+        # Use provided connection or driver's default connection
+        conn = connection if connection is not None else self._connection(None)
+
+        # BigQuery doesn't have traditional transactions, but we'll use the pattern for consistency
+        # The managed_transaction_sync will just pass through for BigQuery Client objects
+        with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
+            # Normalize parameters using consolidated utility
+            normalized_params = normalize_parameter_sequence(parameters)
+            param_dict: dict[str, Any] = {}
+            if normalized_params:
+                if isinstance(normalized_params[0], dict):
+                    param_dict = normalized_params[0]
+                else:
+                    param_dict = {f"param_{i}": val for i, val in enumerate(normalized_params)}

-
+            bq_params = self._prepare_bq_query_parameters(param_dict)

-
+            query_job = self._run_query_job(sql, bq_params, connection=txn_conn)

-
-
-
-            return self.
-        return self._handle_dml_job(query_job)
+            query_schema = getattr(query_job, "schema", None)
+            if query_job.statement_type == "SELECT" or (query_schema is not None and len(query_schema) > 0):
+                return self._handle_select_job(query_job, statement)
+            return self._handle_dml_job(query_job, statement)

     def _execute_many(
         self, sql: str, param_list: Any, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
-        # Use
[… 42 removed lines not rendered in the original diff view …]
+    ) -> SQLResult[RowT]:
+        # Use provided connection or driver's default connection
+        conn = connection if connection is not None else self._connection(None)
+
+        with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
+            # Normalize parameter list using consolidated utility
+            normalized_param_list = normalize_parameter_sequence(param_list)
+
+            # Use a multi-statement script for batch execution
+            script_parts = []
+            all_params: dict[str, Any] = {}
+            param_counter = 0
+
+            for params in normalized_param_list or []:
+                if isinstance(params, dict):
+                    param_dict = params
+                elif isinstance(params, (list, tuple)):
+                    param_dict = {f"param_{i}": val for i, val in enumerate(params)}
+                else:
+                    param_dict = {"param_0": params}
+
+                # Remap parameters to be unique across the entire script
+                param_mapping = {}
+                current_sql = sql
+                for key, value in param_dict.items():
+                    new_key = f"p_{param_counter}"
+                    param_counter += 1
+                    param_mapping[key] = new_key
+                    all_params[new_key] = value
+
+                for old_key, new_key in param_mapping.items():
+                    current_sql = current_sql.replace(f"@{old_key}", f"@{new_key}")
+
+                script_parts.append(current_sql)
+
+            # Execute as a single script
+            full_script = ";\n".join(script_parts)
+            bq_params = self._prepare_bq_query_parameters(all_params)
+            # Filter out kwargs that _run_query_job doesn't expect
+            query_kwargs = {k: v for k, v in kwargs.items() if k not in {"parameters", "is_many"}}
+            query_job = self._run_query_job(full_script, bq_params, connection=txn_conn, **query_kwargs)
+
+            # Wait for the job to complete
+            query_job.result(timeout=kwargs.get("bq_job_timeout"))
+            total_rowcount = query_job.num_dml_affected_rows or 0
+
+            return SQLResult(
+                statement=SQL(sql, _dialect=self.dialect),
+                data=[],
+                rows_affected=total_rowcount,
+                operation_type="EXECUTE",
+                metadata={"status_message": f"OK - executed batch job {query_job.job_id}"},
+            )

     def _execute_script(
         self, script: str, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
-        #
-
-        statements = self._split_script_statements(script)
-
-        for statement in statements:
-            if statement:
-                query_job = self._run_query_job(statement, [], connection=connection)
-                query_job.result(timeout=kwargs.get("bq_job_timeout"))
-
-        return {"statements_executed": len(statements), "status_message": "SCRIPT EXECUTED"}
-
-    def _wrap_select_result(
-        self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
-    ) -> "Union[SQLResult[RowT], SQLResult[ModelDTOT]]":
-        if schema_type:
-            return cast(
-                "SQLResult[ModelDTOT]",
-                SQLResult(
-                    statement=statement,
-                    data=cast("list[ModelDTOT]", list(self.to_schema(data=result["data"], schema_type=schema_type))),
-                    column_names=result["column_names"],
-                    rows_affected=result["rows_affected"],
-                    operation_type="SELECT",
-                ),
-            )
+    ) -> SQLResult[RowT]:
+        # Use provided connection or driver's default connection
+        conn = connection if connection is not None else self._connection(None)

[… 8 removed lines not rendered in the original diff view …]
-            ),
-        )
+        with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
+            # BigQuery does not support multi-statement scripts in a single job
+            statements = self._split_script_statements(script)
+
+            for statement in statements:
+                if statement:
+                    query_job = self._run_query_job(statement, [], connection=txn_conn)
+                    query_job.result(timeout=kwargs.get("bq_job_timeout"))

-
-
-    ) -> "SQLResult[RowT]":
-        operation_type = "UNKNOWN"
-        if statement.expression:
-            operation_type = str(statement.expression.key).upper()
-        if "statements_executed" in result:
-            return SQLResult[RowT](
-                statement=statement,
+            return SQLResult(
+                statement=SQL(script, _dialect=self.dialect).as_script(),
                 data=[],
                 rows_affected=0,
                 operation_type="SCRIPT",
-                metadata={
-
-
-                },
-            )
-        if "rows_affected" in result:
-            dml_result = cast("DMLResultDict", result)
-            rows_affected = dml_result["rows_affected"]
-            status_message = dml_result.get("status_message", "")
-            return SQLResult[RowT](
-                statement=statement,
-                data=[],
-                rows_affected=rows_affected,
-                operation_type=operation_type,
-                metadata={"status_message": status_message},
+                metadata={"status_message": "SCRIPT EXECUTED"},
+                total_statements=len(statements),
+                successful_statements=len(statements),
             )
-        msg = f"Unexpected result type: {type(result)}"
-        raise ValueError(msg)

     def _connection(self, connection: "Optional[Client]" = None) -> "Client":
         """Get the connection to use for the operation."""
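The rewritten `_execute_many` turns a batch into a single multi-statement script by renaming each parameter set's `@name` placeholders to globally unique `@p_N` names. A self-contained sketch of that remapping; the `build_batch_script` helper is illustrative only and not part of sqlspec:

    def build_batch_script(sql: str, param_sets: list[dict[str, object]]) -> tuple[str, dict[str, object]]:
        """Rename @name placeholders so each statement in the batch gets unique parameters."""
        parts: list[str] = []
        all_params: dict[str, object] = {}
        counter = 0
        for params in param_sets:
            stmt = sql
            for key, value in params.items():
                new_key = f"p_{counter}"
                counter += 1
                all_params[new_key] = value
                stmt = stmt.replace(f"@{key}", f"@{new_key}")
            parts.append(stmt)
        return ";\n".join(parts), all_params

    script, params = build_batch_script(
        "INSERT INTO t (a, b) VALUES (@param_0, @param_1)",
        [{"param_0": 1, "param_1": "x"}, {"param_0": 2, "param_1": "y"}],
    )
    # script:
    #   INSERT INTO t (a, b) VALUES (@p_0, @p_1);
    #   INSERT INTO t (a, b) VALUES (@p_2, @p_3)
    # params: {"p_0": 1, "p_1": "x", "p_2": 2, "p_3": "y"}

Per the comments in the hunk, `managed_transaction_sync` is expected to pass a BigQuery `Client` straight through, since the client exposes no begin/commit API; its implementation lives in the new `sqlspec/driver/connection.py`, which is not shown in this section.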
@@ -529,28 +494,115 @@ class BigQueryDriver(
     # BigQuery Native Export Support
     # ============================================================================

-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
-        """BigQuery native export implementation.
+    def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
+        """BigQuery native export implementation with automatic GCS staging.

-        For
-
+        For GCS URIs, uses direct export. For other locations, automatically stages
+        through a temporary GCS location and transfers to the final destination.

         Args:
             query: SQL query to execute
-            destination_uri: Destination URI (local file path
+            destination_uri: Destination URI (local file path, gs:// URI, or Path object)
             format: Export format (parquet, csv, json, avro)
-            **options: Additional export options
+            **options: Additional export options including 'gcs_staging_bucket'

         Returns:
             Number of rows exported

         Raises:
-            NotImplementedError:
+            NotImplementedError: If no staging bucket is configured for non-GCS destinations
         """
-
-
-
-
+        destination_str = str(destination_uri)
+
+        # If it's already a GCS URI, use direct export
+        if destination_str.startswith("gs://"):
+            return self._export_to_gcs_native(query, destination_str, format, **options)
+
+        staging_bucket = options.get("gcs_staging_bucket") or getattr(self.config, "gcs_staging_bucket", None)
+        if not staging_bucket:
+            # Fall back to fetch + write for non-GCS destinations without staging
+            msg = "BigQuery native export requires GCS staging bucket for non-GCS destinations"
+            raise NotImplementedError(msg)
+
+        # Generate temporary GCS path
+        from datetime import timezone
+
+        timestamp = datetime.datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+        temp_filename = f"bigquery_export_{timestamp}_{uuid.uuid4().hex[:8]}.{format}"
+        temp_gcs_uri = f"gs://{staging_bucket}/temp_exports/{temp_filename}"
+
+        try:
+            # Export to temporary GCS location
+            rows_exported = self._export_to_gcs_native(query, temp_gcs_uri, format, **options)
+
+            # Transfer from GCS to final destination using storage backend
+            backend, path = self._resolve_backend_and_path(destination_str)
+            gcs_backend = self._get_storage_backend(temp_gcs_uri)
+
+            # Download from GCS and upload to final destination
+            data = gcs_backend.read_bytes(temp_gcs_uri)
+            backend.write_bytes(path, data)
+
+            return rows_exported
+        finally:
+            # Clean up temporary file
+            try:
+                gcs_backend = self._get_storage_backend(temp_gcs_uri)
+                gcs_backend.delete(temp_gcs_uri)
+            except Exception as e:
+                logger.warning("Failed to clean up temporary GCS file %s: %s", temp_gcs_uri, e)
+
+    def _export_to_gcs_native(self, query: str, gcs_uri: str, format: str, **options: Any) -> int:
+        """Direct BigQuery export to GCS.
+
+        Args:
+            query: SQL query to execute
+            gcs_uri: GCS destination URI (must start with gs://)
+            format: Export format (parquet, csv, json, avro)
+            **options: Additional export options
+
+        Returns:
+            Number of rows exported
+        """
+        # First, run the query and store results in a temporary table
+
+        temp_table_id = f"temp_export_{uuid.uuid4().hex[:8]}"
+        dataset_id = getattr(self.connection, "default_dataset", None) or options.get("dataset", "temp")
+
+        query_with_table = f"CREATE OR REPLACE TABLE `{dataset_id}.{temp_table_id}` AS {query}"
+        create_job = self._run_query_job(query_with_table, [])
+        create_job.result()
+
+        count_query = f"SELECT COUNT(*) as cnt FROM `{dataset_id}.{temp_table_id}`"
+        count_job = self._run_query_job(count_query, [])
+        count_result = list(count_job.result())
+        row_count = count_result[0]["cnt"] if count_result else 0
+
+        try:
+            # Configure extract job
+            extract_config = ExtractJobConfig(**options)  # type: ignore[no-untyped-call]
+
+            format_mapping = {
+                "parquet": SourceFormat.PARQUET,
+                "csv": SourceFormat.CSV,
+                "json": SourceFormat.NEWLINE_DELIMITED_JSON,
+                "avro": SourceFormat.AVRO,
+            }
+            extract_config.destination_format = format_mapping.get(format, SourceFormat.PARQUET)
+
+            table_ref = self.connection.dataset(dataset_id).table(temp_table_id)
+            extract_job = self.connection.extract_table(table_ref, gcs_uri, job_config=extract_config)
+            extract_job.result()
+
+            return row_count
+        finally:
+            # Clean up temporary table
+            try:
+                delete_query = f"DROP TABLE IF EXISTS `{dataset_id}.{temp_table_id}`"
+                delete_job = self._run_query_job(delete_query, [])
+                delete_job.result()
+            except Exception as e:
+                logger.warning("Failed to clean up temporary table %s: %s", temp_table_id, e)

     # ============================================================================
     # BigQuery Native Arrow Support
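A hedged usage sketch of the staging flow described in the docstring above. The `driver` object, project and dataset names, and bucket name are placeholders; only `_export_native`, its format argument, and the `gcs_staging_bucket` option appear in this diff:

    # Non-GCS destination: results are staged under gs://<staging bucket>/temp_exports/...
    # and then copied to the local path via the storage backend.
    rows = driver._export_native(
        "SELECT id, name FROM `my_project.my_dataset.users`",
        "/tmp/users.parquet",
        "parquet",
        gcs_staging_bucket="my-staging-bucket",
    )
    print(f"exported {rows} rows")

    # gs:// destination: skips staging and extracts directly to the bucket.
    rows = driver._export_native(
        "SELECT id, name FROM `my_project.my_dataset.users`",
        "gs://my-staging-bucket/exports/users.parquet",
        "parquet",
    )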
@@ -570,7 +622,6 @@ class BigQueryDriver(
         Returns:
             ArrowResult with native Arrow table
         """
-
         # Execute the query directly with BigQuery to get the QueryJob
         params = sql.get_parameters(style=self.default_parameter_style)
         params_dict: dict[str, Any] = {}
@@ -650,7 +701,6 @@ class BigQueryDriver(
             raise ValueError(msg)

         # Use BigQuery's native Arrow loading
-        # Convert Arrow table to bytes for direct loading

         import pyarrow.parquet as pq

sqlspec/adapters/duckdb/config.py

@@ -2,7 +2,6 @@

 import logging
 from contextlib import contextmanager
-from dataclasses import replace
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypedDict

 import duckdb
@@ -336,11 +335,9 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
         # DuckDB connect() only accepts database, read_only, and config parameters
         connect_params: dict[str, Any] = {}

-        # Set database if provided
         if hasattr(self, "database") and self.database is not None:
             connect_params["database"] = self.database

-        # Set read_only if provided
         if hasattr(self, "read_only") and self.read_only is not None:
             connect_params["read_only"] = self.read_only

@@ -352,7 +349,6 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
             if value is not None and value is not Empty:
                 config_dict[field] = value

-        # Add extras to config dict
         config_dict.update(self.extras)

         # If we have config parameters, add them
@@ -475,15 +471,16 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
         @contextmanager
         def session_manager() -> "Generator[DuckDBDriver, None, None]":
             with self.provide_connection(*args, **kwargs) as connection:
-                # Create statement config with parameter style info if not already set
                 statement_config = self.statement_config
+                # Inject parameter style info if not already set
                 if statement_config.allowed_parameter_styles is None:
+                    from dataclasses import replace
+
                     statement_config = replace(
                         statement_config,
                         allowed_parameter_styles=self.supported_parameter_styles,
                         target_parameter_style=self.preferred_parameter_style,
                     )
-
                 driver = self.driver_type(connection=connection, config=statement_config)
                 yield driver

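The DuckDB hunk above defers the `dataclasses.replace` import to the branch that needs it and uses it to copy the statement config with the adapter's parameter styles filled in. A tiny standalone illustration of `replace`; the `StatementConfig` shape and the style names here are invented for the example:

    from dataclasses import dataclass, replace
    from typing import Optional

    @dataclass(frozen=True)
    class StatementConfig:  # simplified stand-in, not sqlspec's real class
        allowed_parameter_styles: Optional[tuple[str, ...]] = None
        target_parameter_style: Optional[str] = None

    base = StatementConfig()
    configured = replace(
        base,
        allowed_parameter_styles=("qmark", "numeric"),
        target_parameter_style="qmark",
    )
    print(configured.allowed_parameter_styles)  # ('qmark', 'numeric')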