sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of sqlspec might be problematic.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/bigquery/driver.py

@@ -21,6 +21,7 @@ from google.cloud.bigquery import (
 from google.cloud.bigquery.table import Row as BigQueryRow

 from sqlspec.driver import SyncDriverAdapterProtocol
+from sqlspec.driver.connection import managed_transaction_sync
 from sqlspec.driver.mixins import (
     SQLTranslatorMixin,
     SyncPipelinedExecutionMixin,
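`managed_transaction_sync` lives in the new `sqlspec/driver/connection.py` module (+207 lines, not expanded in this diff). Judging only from the call sites below, it wraps a connection in commit/rollback handling and degrades to a plain pass-through for clients that expose no transaction API, such as BigQuery's `Client`. A speculative sketch of that contract, not the shipped implementation:

```python
from contextlib import contextmanager
from typing import Any, Iterator

@contextmanager
def managed_transaction_sync(conn: Any, auto_commit: bool = True) -> Iterator[Any]:
    """Speculative stand-in; the real helper is in sqlspec/driver/connection.py."""
    if not hasattr(conn, "commit"):
        # e.g. a BigQuery Client: no transaction API, just pass through
        yield conn
        return
    try:
        yield conn
        if auto_commit:
            conn.commit()
    except Exception:
        conn.rollback()
        raise
```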
@@ -28,11 +29,12 @@ from sqlspec.driver.mixins import (
     ToSchemaMixin,
     TypeCoercionMixin,
 )
+from sqlspec.driver.parameters import normalize_parameter_sequence
 from sqlspec.exceptions import SQLSpecError
-from sqlspec.statement.parameters import ParameterStyle
-from sqlspec.statement.result import ArrowResult,
+from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
+from sqlspec.statement.result import ArrowResult, SQLResult
 from sqlspec.statement.sql import SQL, SQLConfig
-from sqlspec.typing import DictRow,
+from sqlspec.typing import DictRow, RowT
 from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
@@ -140,6 +142,10 @@ class BigQueryDriver(
         Raises:
             SQLSpecError: If value type is not supported.
         """
+        if value is None:
+            # BigQuery handles NULL values without explicit type
+            return ("STRING", None)  # Use STRING type for NULL values
+
         value_type = type(value)
         if value_type is datetime.datetime:
             return ("TIMESTAMP" if value.tzinfo else "DATETIME", None)
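The new `None` guard matters because every BigQuery query parameter must carry a declared type, even when its value is NULL; without a type the parameter cannot be built at all. With the official client this looks like the following (the parameter name `p0` is just an example):

```python
from google.cloud import bigquery

# A NULL parameter still needs a declared type; the driver now defaults to STRING.
null_param = bigquery.ScalarQueryParameter("p0", "STRING", None)
job_config = bigquery.QueryJobConfig(query_parameters=[null_param])
# client.query("SELECT @p0 IS NULL AS is_null", job_config=job_config)
# would report is_null = true.
```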
@@ -158,7 +164,6 @@ class BigQueryDriver(
         if value_type in type_map:
             return type_map[value_type]

-        # Handle lists/tuples for ARRAY type
         if isinstance(value, (list, tuple)):
             if not value:
                 msg = "Cannot determine BigQuery ARRAY type for empty sequence. Provide typed empty array or ensure context implies type."
@@ -192,8 +197,7 @@ class BigQueryDriver(
         for name, value in params_dict.items():
             param_name_for_bq = name.lstrip("@")

-
-            actual_value = value.value if hasattr(value, "value") else value
+            actual_value = getattr(value, "value", value)

             param_type, array_element_type = self._get_bq_param_type(actual_value)

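The rewritten line relies on `getattr` with a default, which folds the `hasattr` check and the attribute access into one call: objects that carry a `.value` attribute (enum members, wrapped parameters) get unwrapped, everything else passes through untouched.

```python
import enum

class Color(enum.Enum):
    RED = "red"

# Unwraps objects exposing .value, leaves plain values alone.
assert getattr(Color.RED, "value", Color.RED) == "red"
assert getattr(42, "value", 42) == 42
```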
@@ -238,18 +242,14 @@ class BigQueryDriver(
         """
         conn = connection or self.connection

-        # Build final job configuration
         final_job_config = QueryJobConfig()

-        # Apply default configuration if available
         if self._default_query_job_config:
             self._copy_job_config_attrs(self._default_query_job_config, final_job_config)

-        # Apply override configuration if provided
         if job_config:
             self._copy_job_config_attrs(job_config, final_job_config)

-        # Set query parameters
         final_job_config.query_parameters = bq_query_parameters or []

         # Debug log the actual parameters being sent
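`_copy_job_config_attrs` itself is not part of this diff; the merge order above (defaults first, then the per-call override) implies a last-writer-wins copy of explicitly set attributes. A speculative sketch of such a helper, with an illustrative attribute list:

```python
from google.cloud.bigquery import QueryJobConfig

def copy_job_config_attrs(source: QueryJobConfig, target: QueryJobConfig) -> None:
    """Speculative stand-in for _copy_job_config_attrs; the attribute list is illustrative."""
    for attr in ("priority", "use_query_cache", "maximum_bytes_billed", "labels"):
        value = getattr(source, attr, None)
        if value is not None:
            setattr(target, attr, value)  # later copies win, matching the merge order above
```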
@@ -266,14 +266,11 @@ class BigQueryDriver(
         )
         query_job = conn.query(sql_str, job_config=final_job_config)

-        # Get the auto-generated job ID for callbacks
         if self.on_job_start and query_job.job_id:
             with contextlib.suppress(Exception):
-                # Callback errors should not interfere with job execution
                 self.on_job_start(query_job.job_id)
         if self.on_job_complete and query_job.job_id:
             with contextlib.suppress(Exception):
-                # Callback errors should not interfere with job execution
                 self.on_job_complete(query_job.job_id, query_job)

         return query_job
@@ -290,15 +287,21 @@ class BigQueryDriver(
         """
         return [dict(row) for row in rows_iterator]  # type: ignore[misc]

-    def _handle_select_job(self, query_job: QueryJob) ->
+    def _handle_select_job(self, query_job: QueryJob, statement: SQL) -> SQLResult[RowT]:
         """Handle a query job that is expected to return rows."""
         job_result = query_job.result()
         rows_list = self._rows_to_results(iter(job_result))
         column_names = [field.name for field in query_job.schema] if query_job.schema else []

-        return
+        return SQLResult(
+            statement=statement,
+            data=rows_list,
+            column_names=column_names,
+            rows_affected=len(rows_list),
+            operation_type="SELECT",
+        )

-    def _handle_dml_job(self, query_job: QueryJob) ->
+    def _handle_dml_job(self, query_job: QueryJob, statement: SQL) -> SQLResult[RowT]:
         """Handle a DML job.

         Note: BigQuery emulators (e.g., goccy/bigquery-emulator) may report 0 rows affected
@@ -324,7 +327,14 @@ class BigQueryDriver(
             )
             num_affected = 1  # Assume at least one row was affected

-
+        operation_type = self._determine_operation_type(statement)
+        return SQLResult(
+            statement=statement,
+            data=cast("list[RowT]", []),
+            rows_affected=num_affected or 0,
+            operation_type=operation_type,
+            metadata={"status_message": f"OK - job_id: {query_job.job_id}"},
+        )

     def _compile_bigquery_compatible(self, statement: SQL, target_style: ParameterStyle) -> tuple[str, Any]:
         """Compile SQL statement for BigQuery.
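Both handlers now return a unified `SQLResult` instead of the result dicts that the removed `_wrap_select_result`/`_wrap_execute_result` pair used to translate. Assuming the constructor accepts exactly the keyword arguments visible in these hunks, assembling and inspecting a result would look like:

```python
from sqlspec.statement.result import SQLResult
from sqlspec.statement.sql import SQL

# Mirrors how _handle_select_job assembles its return value.
result = SQLResult(
    statement=SQL("SELECT 1 AS x"),
    data=[{"x": 1}],
    column_names=["x"],
    rows_affected=1,
    operation_type="SELECT",
)
assert result.operation_type == "SELECT"
assert result.data[0]["x"] == 1
```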
@@ -336,12 +346,18 @@ class BigQueryDriver(

     def _execute_statement(
         self, statement: SQL, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
+    ) -> SQLResult[RowT]:
         if statement.is_script:
             sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
             return self._execute_script(sql, connection=connection, **kwargs)

-        detected_styles =
+        detected_styles = set()
+        sql_str = statement.to_sql(placeholder_style=None)  # Get raw SQL
+        validator = self.config.parameter_validator if self.config else ParameterValidator()
+        param_infos = validator.extract_parameters(sql_str)
+        if param_infos:
+            detected_styles = {p.style for p in param_infos}
+
         target_style = self.default_parameter_style

         unsupported_styles = detected_styles - set(self.supported_parameter_styles)
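Parameter-style detection now works from the raw SQL via `ParameterValidator.extract_parameters` instead of precomputed statement state. The validator's internals are not part of this diff; as a rough sketch of the idea, detection amounts to scanning for placeholder patterns and collecting the styles found (string literals and comments are ignored here for brevity):

```python
import re
from enum import Enum

class ParameterStyle(str, Enum):
    """Simplified stand-in for sqlspec's ParameterStyle enum."""
    NAMED_AT = "named_at"   # @name (BigQuery)
    QMARK = "qmark"         # ?
    NUMERIC = "numeric"     # $1, $2, ...

PLACEHOLDER_PATTERNS = {
    ParameterStyle.NAMED_AT: re.compile(r"@\w+"),
    ParameterStyle.QMARK: re.compile(r"\?"),
    ParameterStyle.NUMERIC: re.compile(r"\$\d+"),
}

def detect_styles(sql: str) -> "set[ParameterStyle]":
    return {style for style, pattern in PLACEHOLDER_PATTERNS.items() if pattern.search(sql)}

assert detect_styles("SELECT * FROM t WHERE a = @param_0") == {ParameterStyle.NAMED_AT}
```

Anything detected outside `supported_parameter_styles` is then recompiled to the driver's default style before execution.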
@@ -359,167 +375,116 @@ class BigQueryDriver(
             return self._execute_many(sql, params, connection=connection, **kwargs)

         sql, params = self._compile_bigquery_compatible(statement, target_style)
-        logger.debug("compile() returned - sql: %r, params: %r", sql, params)
         params = self._process_parameters(params)
-        logger.debug("after _process_parameters - params: %r", params)
         return self._execute(sql, params, statement, connection=connection, **kwargs)

     def _execute(
         self, sql: str, parameters: Any, statement: SQL, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-                k: v
-                for k, v in converted_params.items()
-                if k.startswith("param_") or (not k.startswith("_") and k not in {"dialect", "config"})
-            }
-        elif isinstance(converted_params, (list, tuple)):
-            # Convert positional parameters to named parameters for BigQuery
-            # Use param_N to match the compiled SQL placeholders
-            param_dict = {f"param_{i}": val for i, val in enumerate(converted_params)}
-        else:
-            # Single scalar parameter
-            param_dict = {"param_0": converted_params}
+    ) -> SQLResult[RowT]:
+        # Use provided connection or driver's default connection
+        conn = connection if connection is not None else self._connection(None)
+
+        # BigQuery doesn't have traditional transactions, but we'll use the pattern for consistency
+        # The managed_transaction_sync will just pass through for BigQuery Client objects
+        with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
+            # Normalize parameters using consolidated utility
+            normalized_params = normalize_parameter_sequence(parameters)
+            param_dict: dict[str, Any] = {}
+            if normalized_params:
+                if isinstance(normalized_params[0], dict):
+                    param_dict = normalized_params[0]
+                else:
+                    param_dict = {f"param_{i}": val for i, val in enumerate(normalized_params)}

-
+            bq_params = self._prepare_bq_query_parameters(param_dict)

-
+            query_job = self._run_query_job(sql, bq_params, connection=txn_conn)

-
-
-
-            return self.
-            return self._handle_dml_job(query_job)
+            query_schema = getattr(query_job, "schema", None)
+            if query_job.statement_type == "SELECT" or (query_schema is not None and len(query_schema) > 0):
+                return self._handle_select_job(query_job, statement)
+            return self._handle_dml_job(query_job, statement)

     def _execute_many(
         self, sql: str, param_list: Any, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
-        # Use
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ) -> SQLResult[RowT]:
+        # Use provided connection or driver's default connection
+        conn = connection if connection is not None else self._connection(None)
+
+        with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
+            # Normalize parameter list using consolidated utility
+            normalized_param_list = normalize_parameter_sequence(param_list)
+
+            # Use a multi-statement script for batch execution
+            script_parts = []
+            all_params: dict[str, Any] = {}
+            param_counter = 0
+
+            for params in normalized_param_list or []:
+                if isinstance(params, dict):
+                    param_dict = params
+                elif isinstance(params, (list, tuple)):
+                    param_dict = {f"param_{i}": val for i, val in enumerate(params)}
+                else:
+                    param_dict = {"param_0": params}
+
+                # Remap parameters to be unique across the entire script
+                param_mapping = {}
+                current_sql = sql
+                for key, value in param_dict.items():
+                    new_key = f"p_{param_counter}"
+                    param_counter += 1
+                    param_mapping[key] = new_key
+                    all_params[new_key] = value
+
+                for old_key, new_key in param_mapping.items():
+                    current_sql = current_sql.replace(f"@{old_key}", f"@{new_key}")
+
+                script_parts.append(current_sql)
+
+            # Execute as a single script
+            full_script = ";\n".join(script_parts)
+            bq_params = self._prepare_bq_query_parameters(all_params)
+            # Filter out kwargs that _run_query_job doesn't expect
+            query_kwargs = {k: v for k, v in kwargs.items() if k not in {"parameters", "is_many"}}
+            query_job = self._run_query_job(full_script, bq_params, connection=txn_conn, **query_kwargs)
+
+            # Wait for the job to complete
+            query_job.result(timeout=kwargs.get("bq_job_timeout"))
+            total_rowcount = query_job.num_dml_affected_rows or 0
+
+            return SQLResult(
+                statement=SQL(sql, _dialect=self.dialect),
+                data=[],
+                rows_affected=total_rowcount,
+                operation_type="EXECUTE",
+                metadata={"status_message": f"OK - executed batch job {query_job.job_id}"},
+            )

     def _execute_script(
         self, script: str, connection: Optional[BigQueryConnection] = None, **kwargs: Any
-    ) ->
-        #
-
-        statements = self._split_script_statements(script)
-
-        for statement in statements:
-            if statement:
-                query_job = self._run_query_job(statement, [], connection=connection)
-                query_job.result(timeout=kwargs.get("bq_job_timeout"))
-
-        return {"statements_executed": len(statements), "status_message": "SCRIPT EXECUTED"}
-
-    def _wrap_select_result(
-        self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
-    ) -> "Union[SQLResult[RowT], SQLResult[ModelDTOT]]":
-        if schema_type:
-            return cast(
-                "SQLResult[ModelDTOT]",
-                SQLResult(
-                    statement=statement,
-                    data=cast("list[ModelDTOT]", list(self.to_schema(data=result["data"], schema_type=schema_type))),
-                    column_names=result["column_names"],
-                    rows_affected=result["rows_affected"],
-                    operation_type="SELECT",
-                ),
-            )
+    ) -> SQLResult[RowT]:
+        # Use provided connection or driver's default connection
+        conn = connection if connection is not None else self._connection(None)

-
-
-
-
-
-
-
-
-            ),
-        )
+        with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
+            # BigQuery does not support multi-statement scripts in a single job
+            statements = self._split_script_statements(script)
+
+            for statement in statements:
+                if statement:
+                    query_job = self._run_query_job(statement, [], connection=txn_conn)
+                    query_job.result(timeout=kwargs.get("bq_job_timeout"))

-
-
-    ) -> "SQLResult[RowT]":
-        operation_type = "UNKNOWN"
-        if statement.expression:
-            operation_type = str(statement.expression.key).upper()
-        if "statements_executed" in result:
-            return SQLResult[RowT](
-                statement=statement,
+        return SQLResult(
+            statement=SQL(script, _dialect=self.dialect).as_script(),
             data=[],
             rows_affected=0,
             operation_type="SCRIPT",
-            metadata={
-
-
-            },
+            metadata={"status_message": "SCRIPT EXECUTED"},
+            total_statements=len(statements),
+            successful_statements=len(statements),
         )
-        if "rows_affected" in result:
-            dml_result = cast("DMLResultDict", result)
-            rows_affected = dml_result["rows_affected"]
-            status_message = dml_result.get("status_message", "")
-            return SQLResult[RowT](
-                statement=statement,
-                data=[],
-                rows_affected=rows_affected,
-                operation_type=operation_type,
-                metadata={"status_message": status_message},
-            )
-        msg = f"Unexpected result type: {type(result)}"
-        raise ValueError(msg)

     def _connection(self, connection: "Optional[Client]" = None) -> "Client":
         """Get the connection to use for the operation."""
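The rewritten `_execute_many` replaces per-row execution with one multi-statement script: each row's placeholders are renamed to globally unique `@p_N` names, the per-row statements are joined with `;`, and the whole batch runs as a single job. The remapping step, re-derived as a standalone function (`build_batch_script` is an illustrative name, not sqlspec API):

```python
from typing import Any

def build_batch_script(sql: str, param_list: "list[dict[str, Any]]") -> "tuple[str, dict[str, Any]]":
    """Rename @placeholders per row so one script can carry every row's parameters."""
    script_parts: "list[str]" = []
    all_params: "dict[str, Any]" = {}
    counter = 0
    for params in param_list:
        current_sql = sql
        for key, value in params.items():
            new_key = f"p_{counter}"
            counter += 1
            all_params[new_key] = value
            # Note: like the driver code above, a plain str.replace would also hit
            # prefixed names such as @param_10 when renaming @param_1.
            current_sql = current_sql.replace(f"@{key}", f"@{new_key}")
        script_parts.append(current_sql)
    return ";\n".join(script_parts), all_params

script, merged = build_batch_script(
    "INSERT INTO t (a, b) VALUES (@param_0, @param_1)",
    [{"param_0": 1, "param_1": 2}, {"param_0": 3, "param_1": 4}],
)
# script joins two INSERTs with ";\n"
# merged == {"p_0": 1, "p_1": 2, "p_2": 3, "p_3": 4}
```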
@@ -553,7 +518,6 @@ class BigQueryDriver(
         if destination_str.startswith("gs://"):
             return self._export_to_gcs_native(query, destination_str, format, **options)

-        # For non-GCS destinations, check if staging is configured
         staging_bucket = options.get("gcs_staging_bucket") or getattr(self.config, "gcs_staging_bucket", None)
         if not staging_bucket:
             # Fall back to fetch + write for non-GCS destinations without staging
@@ -605,12 +569,10 @@ class BigQueryDriver(
         temp_table_id = f"temp_export_{uuid.uuid4().hex[:8]}"
         dataset_id = getattr(self.connection, "default_dataset", None) or options.get("dataset", "temp")

-        # Create a temporary table with query results
         query_with_table = f"CREATE OR REPLACE TABLE `{dataset_id}.{temp_table_id}` AS {query}"
         create_job = self._run_query_job(query_with_table, [])
         create_job.result()

-        # Get row count
         count_query = f"SELECT COUNT(*) as cnt FROM `{dataset_id}.{temp_table_id}`"
         count_job = self._run_query_job(count_query, [])
         count_result = list(count_job.result())
@@ -620,7 +582,6 @@ class BigQueryDriver(
         # Configure extract job
         extract_config = ExtractJobConfig(**options)  # type: ignore[no-untyped-call]

-        # Set format
         format_mapping = {
             "parquet": SourceFormat.PARQUET,
             "csv": SourceFormat.CSV,
@@ -629,7 +590,6 @@ class BigQueryDriver(
         }
         extract_config.destination_format = format_mapping.get(format, SourceFormat.PARQUET)

-        # Extract table to GCS
         table_ref = self.connection.dataset(dataset_id).table(temp_table_id)
         extract_job = self.connection.extract_table(table_ref, gcs_uri, job_config=extract_config)
         extract_job.result()
@@ -662,7 +622,6 @@ class BigQueryDriver(
         Returns:
             ArrowResult with native Arrow table
         """
-
         # Execute the query directly with BigQuery to get the QueryJob
         params = sql.get_parameters(style=self.default_parameter_style)
         params_dict: dict[str, Any] = {}
@@ -742,7 +701,6 @@ class BigQueryDriver(
         raise ValueError(msg)

         # Use BigQuery's native Arrow loading
-        # Convert Arrow table to bytes for direct loading

         import pyarrow.parquet as pq

sqlspec/adapters/duckdb/config.py

@@ -2,7 +2,6 @@

 import logging
 from contextlib import contextmanager
-from dataclasses import replace
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypedDict

 import duckdb
@@ -336,11 +335,9 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
         # DuckDB connect() only accepts database, read_only, and config parameters
         connect_params: dict[str, Any] = {}

-        # Set database if provided
         if hasattr(self, "database") and self.database is not None:
             connect_params["database"] = self.database

-        # Set read_only if provided
         if hasattr(self, "read_only") and self.read_only is not None:
             connect_params["read_only"] = self.read_only

@@ -352,7 +349,6 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
         if value is not None and value is not Empty:
             config_dict[field] = value

-        # Add extras to config dict
         config_dict.update(self.extras)

         # If we have config parameters, add them
@@ -475,15 +471,16 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
         @contextmanager
         def session_manager() -> "Generator[DuckDBDriver, None, None]":
             with self.provide_connection(*args, **kwargs) as connection:
-                # Create statement config with parameter style info if not already set
                 statement_config = self.statement_config
+                # Inject parameter style info if not already set
                 if statement_config.allowed_parameter_styles is None:
+                    from dataclasses import replace
+
                     statement_config = replace(
                         statement_config,
                         allowed_parameter_styles=self.supported_parameter_styles,
                         target_parameter_style=self.preferred_parameter_style,
                     )
-
                 driver = self.driver_type(connection=connection, config=statement_config)
                 yield driver

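Moving `from dataclasses import replace` inside the branch defers the import to the one path that needs it (the module-level import was dropped in the first hunk of this file). `replace` builds a new instance of a, possibly frozen, dataclass with selected fields overridden, which is why it suits an immutable statement config. A minimal illustration with a hypothetical config class:

```python
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Optional, Tuple

@dataclass(frozen=True)
class StatementConfig:
    """Hypothetical stand-in for sqlspec's statement config."""
    allowed_parameter_styles: Optional[Tuple[str, ...]] = None
    target_parameter_style: Optional[str] = None

base = StatementConfig()
configured = replace(base, allowed_parameter_styles=("qmark",), target_parameter_style="qmark")
assert base.allowed_parameter_styles is None           # original untouched
assert configured.target_parameter_style == "qmark"    # new copy overridden
```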