sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of sqlspec has been flagged as potentially problematic.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/bigquery/driver.py

@@ -1,6 +1,8 @@
+import contextlib
 import datetime
 import io
 import logging
+import uuid
 from collections.abc import Iterator
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
@@ -8,15 +10,18 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
 from google.cloud.bigquery import (
     ArrayQueryParameter,
     Client,
+    ExtractJobConfig,
     LoadJobConfig,
     QueryJob,
     QueryJobConfig,
     ScalarQueryParameter,
+    SourceFormat,
     WriteDisposition,
 )
 from google.cloud.bigquery.table import Row as BigQueryRow

 from sqlspec.driver import SyncDriverAdapterProtocol
+from sqlspec.driver.connection import managed_transaction_sync
 from sqlspec.driver.mixins import (
     SQLTranslatorMixin,
     SyncPipelinedExecutionMixin,
@@ -24,14 +29,17 @@ from sqlspec.driver.mixins import (
     ToSchemaMixin,
     TypeCoercionMixin,
 )
+from sqlspec.driver.parameters import normalize_parameter_sequence
 from sqlspec.exceptions import SQLSpecError
-from sqlspec.statement.parameters import ParameterStyle
-from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
+from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
+from sqlspec.statement.result import ArrowResult, SQLResult
 from sqlspec.statement.sql import SQL, SQLConfig
-from sqlspec.typing import DictRow, ModelDTOT, RowT
+from sqlspec.typing import DictRow, RowT
 from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
+    from pathlib import Path
+
     from sqlglot.dialects.dialect import DialectType


@@ -134,6 +142,10 @@ class BigQueryDriver(
         Raises:
             SQLSpecError: If value type is not supported.
         """
+        if value is None:
+            # BigQuery handles NULL values without explicit type
+            return ("STRING", None)  # Use STRING type for NULL values
+
         value_type = type(value)
         if value_type is datetime.datetime:
             return ("TIMESTAMP" if value.tzinfo else "DATETIME", None)
@@ -152,7 +164,6 @@ class BigQueryDriver(
         if value_type in type_map:
             return type_map[value_type]

-        # Handle lists/tuples for ARRAY type
         if isinstance(value, (list, tuple)):
             if not value:
                 msg = "Cannot determine BigQuery ARRAY type for empty sequence. Provide typed empty array or ensure context implies type."
@@ -186,8 +197,7 @@ class BigQueryDriver(
         for name, value in params_dict.items():
             param_name_for_bq = name.lstrip("@")

-            # Extract value from TypedParameter if needed
-            actual_value = value.value if hasattr(value, "value") else value
+            actual_value = getattr(value, "value", value)

             param_type, array_element_type = self._get_bq_param_type(actual_value)

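
For reference, the parameter preparation above ultimately produces ScalarQueryParameter / ArrayQueryParameter objects for the BigQuery client. A standalone sketch of that mapping (heavily simplified; infer_bq_type and to_bq_parameters are hypothetical stand-ins for the driver's private helpers and cover only a few types):

from google.cloud.bigquery import ArrayQueryParameter, ScalarQueryParameter

def infer_bq_type(value):
    # Hypothetical, simplified stand-in for _get_bq_param_type:
    # returns (scalar_type, array_element_type).
    if value is None:
        return ("STRING", None)  # NULLs are sent as typed STRING parameters
    if isinstance(value, bool):
        return ("BOOL", None)
    if isinstance(value, int):
        return ("INT64", None)
    if isinstance(value, float):
        return ("FLOAT64", None)
    if isinstance(value, (list, tuple)) and value:
        return ("ARRAY", infer_bq_type(value[0])[0])
    return ("STRING", None)

def to_bq_parameters(params: dict) -> list:
    bq_params = []
    for name, value in params.items():
        param_type, element_type = infer_bq_type(value)
        if element_type is not None:
            bq_params.append(ArrayQueryParameter(name.lstrip("@"), element_type, value))
        else:
            bq_params.append(ScalarQueryParameter(name.lstrip("@"), param_type, value))
    return bq_params

print(to_bq_parameters({"param_0": 42, "param_1": ["a", "b"], "param_2": None}))
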
@@ -232,18 +242,14 @@ class BigQueryDriver(
         """
         conn = connection or self.connection

-        # Build final job configuration
         final_job_config = QueryJobConfig()

-        # Apply default configuration if available
         if self._default_query_job_config:
             self._copy_job_config_attrs(self._default_query_job_config, final_job_config)

-        # Apply override configuration if provided
         if job_config:
             self._copy_job_config_attrs(job_config, final_job_config)

-        # Set query parameters
         final_job_config.query_parameters = bq_query_parameters or []

         # Debug log the actual parameters being sent
@@ -258,23 +264,14 @@ class BigQueryDriver(
                 param_value,
                 type(param_value),
             )
-        # Let BigQuery generate the job ID to avoid collisions
-        # This is the recommended approach for production code and works better with emulators
-        logger.warning("About to send to BigQuery - SQL: %r", sql_str)
-        logger.warning("Query parameters in job config: %r", final_job_config.query_parameters)
         query_job = conn.query(sql_str, job_config=final_job_config)

-        # Get the auto-generated job ID for callbacks
         if self.on_job_start and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
                 self.on_job_start(query_job.job_id)
-            except Exception as e:
-                logger.warning("Job start callback failed: %s", str(e), extra={"adapter": "bigquery"})
         if self.on_job_complete and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
                 self.on_job_complete(query_job.job_id, query_job)
-            except Exception as e:
-                logger.warning("Job complete callback failed: %s", str(e), extra={"adapter": "bigquery"})

         return query_job

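
The callback hunk above replaces try/except blocks that logged hook failures with contextlib.suppress, which silently swallows any exception the hook raises. A minimal illustration of the two patterns:

import contextlib

def flaky_hook(job_id: str) -> None:
    raise RuntimeError("hook failed")

# Old pattern: catch and log (log call omitted here)
try:
    flaky_hook("job-123")
except Exception:
    pass  # previously: logger.warning("Job start callback failed: ...")

# New pattern: same effect, one statement, no logging
with contextlib.suppress(Exception):
    flaky_hook("job-123")
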
@@ -290,15 +287,21 @@ class BigQueryDriver(
         """
         return [dict(row) for row in rows_iterator]  # type: ignore[misc]

-    def _handle_select_job(self, query_job: QueryJob) -> SelectResultDict:
+    def _handle_select_job(self, query_job: QueryJob, statement: SQL) -> SQLResult[RowT]:
         """Handle a query job that is expected to return rows."""
         job_result = query_job.result()
         rows_list = self._rows_to_results(iter(job_result))
         column_names = [field.name for field in query_job.schema] if query_job.schema else []

-        return {"data": rows_list, "column_names": column_names, "rows_affected": len(rows_list)}
+        return SQLResult(
+            statement=statement,
+            data=rows_list,
+            column_names=column_names,
+            rows_affected=len(rows_list),
+            operation_type="SELECT",
+        )

-    def _handle_dml_job(self, query_job: QueryJob) -> DMLResultDict:
+    def _handle_dml_job(self, query_job: QueryJob, statement: SQL) -> SQLResult[RowT]:
         """Handle a DML job.

         Note: BigQuery emulators (e.g., goccy/bigquery-emulator) may report 0 rows affected
@@ -324,7 +327,14 @@ class BigQueryDriver(
             )
             num_affected = 1  # Assume at least one row was affected

-        return {"rows_affected": num_affected or 0, "status_message": f"OK - job_id: {query_job.job_id}"}
+        operation_type = self._determine_operation_type(statement)
+        return SQLResult(
+            statement=statement,
+            data=cast("list[RowT]", []),
+            rows_affected=num_affected or 0,
+            operation_type=operation_type,
+            metadata={"status_message": f"OK - job_id: {query_job.job_id}"},
+        )

     def _compile_bigquery_compatible(self, statement: SQL, target_style: ParameterStyle) -> tuple[str, Any]:
         """Compile SQL statement for BigQuery.
  """Compile SQL statement for BigQuery.
@@ -336,12 +346,18 @@ class BigQueryDriver(
336
346
 
337
347
  def _execute_statement(
338
348
  self, statement: SQL, connection: Optional[BigQueryConnection] = None, **kwargs: Any
339
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
349
+ ) -> SQLResult[RowT]:
340
350
  if statement.is_script:
341
351
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
342
352
  return self._execute_script(sql, connection=connection, **kwargs)
343
353
 
344
- detected_styles = {p.style for p in statement.parameter_info}
354
+ detected_styles = set()
355
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
356
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
357
+ param_infos = validator.extract_parameters(sql_str)
358
+ if param_infos:
359
+ detected_styles = {p.style for p in param_infos}
360
+
345
361
  target_style = self.default_parameter_style
346
362
 
347
363
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
@@ -359,167 +375,116 @@ class BigQueryDriver(
359
375
  return self._execute_many(sql, params, connection=connection, **kwargs)
360
376
 
361
377
  sql, params = self._compile_bigquery_compatible(statement, target_style)
362
- logger.debug("compile() returned - sql: %r, params: %r", sql, params)
363
378
  params = self._process_parameters(params)
364
- logger.debug("after _process_parameters - params: %r", params)
365
379
  return self._execute(sql, params, statement, connection=connection, **kwargs)
366
380
 
367
381
  def _execute(
368
382
  self, sql: str, parameters: Any, statement: SQL, connection: Optional[BigQueryConnection] = None, **kwargs: Any
369
- ) -> Union[SelectResultDict, DMLResultDict]:
370
- # SQL should already be in correct format from compile()
371
- converted_sql = sql
372
- # Parameters are already in the correct format from compile()
373
- converted_params = parameters
374
-
375
- # Prepare BigQuery parameters
376
- # Convert various parameter formats to dict format for BigQuery
377
- param_dict: dict[str, Any]
378
- if converted_params is None:
379
- param_dict = {}
380
- elif isinstance(converted_params, dict):
381
- # Filter out non-parameter keys (dialect, config, etc.)
382
- # Real parameters start with 'param_' or are user-provided named parameters
383
- param_dict = {
384
- k: v
385
- for k, v in converted_params.items()
386
- if k.startswith("param_") or (not k.startswith("_") and k not in {"dialect", "config"})
387
- }
388
- elif isinstance(converted_params, (list, tuple)):
389
- # Convert positional parameters to named parameters for BigQuery
390
- # Use param_N to match the compiled SQL placeholders
391
- param_dict = {f"param_{i}": val for i, val in enumerate(converted_params)}
392
- else:
393
- # Single scalar parameter
394
- param_dict = {"param_0": converted_params}
383
+ ) -> SQLResult[RowT]:
384
+ # Use provided connection or driver's default connection
385
+ conn = connection if connection is not None else self._connection(None)
386
+
387
+ # BigQuery doesn't have traditional transactions, but we'll use the pattern for consistency
388
+ # The managed_transaction_sync will just pass through for BigQuery Client objects
389
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
390
+ # Normalize parameters using consolidated utility
391
+ normalized_params = normalize_parameter_sequence(parameters)
392
+ param_dict: dict[str, Any] = {}
393
+ if normalized_params:
394
+ if isinstance(normalized_params[0], dict):
395
+ param_dict = normalized_params[0]
396
+ else:
397
+ param_dict = {f"param_{i}": val for i, val in enumerate(normalized_params)}
395
398
 
396
- bq_params = self._prepare_bq_query_parameters(param_dict)
399
+ bq_params = self._prepare_bq_query_parameters(param_dict)
397
400
 
398
- query_job = self._run_query_job(converted_sql, bq_params, connection=connection)
401
+ query_job = self._run_query_job(sql, bq_params, connection=txn_conn)
399
402
 
400
- if query_job.statement_type == "SELECT" or (
401
- hasattr(query_job, "schema") and query_job.schema and len(query_job.schema) > 0
402
- ):
403
- return self._handle_select_job(query_job)
404
- return self._handle_dml_job(query_job)
403
+ query_schema = getattr(query_job, "schema", None)
404
+ if query_job.statement_type == "SELECT" or (query_schema is not None and len(query_schema) > 0):
405
+ return self._handle_select_job(query_job, statement)
406
+ return self._handle_dml_job(query_job, statement)
405
407
 
406
408
  def _execute_many(
407
409
  self, sql: str, param_list: Any, connection: Optional[BigQueryConnection] = None, **kwargs: Any
408
- ) -> DMLResultDict:
409
- # Use a multi-statement script for batch execution
410
- script_parts = []
411
- all_params: dict[str, Any] = {}
412
- param_counter = 0
413
-
414
- for params in param_list or []:
415
- # Convert various parameter formats to dict format for BigQuery
416
- if isinstance(params, dict):
417
- param_dict = params
418
- elif isinstance(params, (list, tuple)):
419
- # Convert positional parameters to named parameters matching SQL placeholders
420
- param_dict = {f"param_{i}": val for i, val in enumerate(params)}
421
- else:
422
- # Single scalar parameter
423
- param_dict = {"param_0": params}
424
-
425
- # Remap parameters to be unique across the entire script
426
- param_mapping = {}
427
- current_sql = sql
428
- for key, value in param_dict.items():
429
- new_key = f"p_{param_counter}"
430
- param_counter += 1
431
- param_mapping[key] = new_key
432
- all_params[new_key] = value
433
-
434
- # Replace placeholders in the SQL for this statement
435
- for old_key, new_key in param_mapping.items():
436
- current_sql = current_sql.replace(f"@{old_key}", f"@{new_key}")
437
-
438
- script_parts.append(current_sql)
439
-
440
- # Execute as a single script
441
- full_script = ";\n".join(script_parts)
442
- bq_params = self._prepare_bq_query_parameters(all_params)
443
- # Filter out kwargs that _run_query_job doesn't expect
444
- query_kwargs = {k: v for k, v in kwargs.items() if k not in {"parameters", "is_many"}}
445
- query_job = self._run_query_job(full_script, bq_params, connection=connection, **query_kwargs)
446
-
447
- # Wait for the job to complete
448
- query_job.result(timeout=kwargs.get("bq_job_timeout"))
449
- total_rowcount = query_job.num_dml_affected_rows or 0
450
-
451
- return {"rows_affected": total_rowcount, "status_message": f"OK - executed batch job {query_job.job_id}"}
410
+ ) -> SQLResult[RowT]:
411
+ # Use provided connection or driver's default connection
412
+ conn = connection if connection is not None else self._connection(None)
413
+
414
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
415
+ # Normalize parameter list using consolidated utility
416
+ normalized_param_list = normalize_parameter_sequence(param_list)
417
+
418
+ # Use a multi-statement script for batch execution
419
+ script_parts = []
420
+ all_params: dict[str, Any] = {}
421
+ param_counter = 0
422
+
423
+ for params in normalized_param_list or []:
424
+ if isinstance(params, dict):
425
+ param_dict = params
426
+ elif isinstance(params, (list, tuple)):
427
+ param_dict = {f"param_{i}": val for i, val in enumerate(params)}
428
+ else:
429
+ param_dict = {"param_0": params}
430
+
431
+ # Remap parameters to be unique across the entire script
432
+ param_mapping = {}
433
+ current_sql = sql
434
+ for key, value in param_dict.items():
435
+ new_key = f"p_{param_counter}"
436
+ param_counter += 1
437
+ param_mapping[key] = new_key
438
+ all_params[new_key] = value
439
+
440
+ for old_key, new_key in param_mapping.items():
441
+ current_sql = current_sql.replace(f"@{old_key}", f"@{new_key}")
442
+
443
+ script_parts.append(current_sql)
444
+
445
+ # Execute as a single script
446
+ full_script = ";\n".join(script_parts)
447
+ bq_params = self._prepare_bq_query_parameters(all_params)
448
+ # Filter out kwargs that _run_query_job doesn't expect
449
+ query_kwargs = {k: v for k, v in kwargs.items() if k not in {"parameters", "is_many"}}
450
+ query_job = self._run_query_job(full_script, bq_params, connection=txn_conn, **query_kwargs)
451
+
452
+ # Wait for the job to complete
453
+ query_job.result(timeout=kwargs.get("bq_job_timeout"))
454
+ total_rowcount = query_job.num_dml_affected_rows or 0
455
+
456
+ return SQLResult(
457
+ statement=SQL(sql, _dialect=self.dialect),
458
+ data=[],
459
+ rows_affected=total_rowcount,
460
+ operation_type="EXECUTE",
461
+ metadata={"status_message": f"OK - executed batch job {query_job.job_id}"},
462
+ )
452
463
 
453
464
  def _execute_script(
454
465
  self, script: str, connection: Optional[BigQueryConnection] = None, **kwargs: Any
455
- ) -> ScriptResultDict:
456
- # BigQuery does not support multi-statement scripts in a single job
457
- # Use the shared implementation to split and execute statements individually
458
- statements = self._split_script_statements(script)
459
-
460
- for statement in statements:
461
- if statement:
462
- query_job = self._run_query_job(statement, [], connection=connection)
463
- query_job.result(timeout=kwargs.get("bq_job_timeout"))
464
-
465
- return {"statements_executed": len(statements), "status_message": "SCRIPT EXECUTED"}
466
-
467
- def _wrap_select_result(
468
- self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
469
- ) -> "Union[SQLResult[RowT], SQLResult[ModelDTOT]]":
470
- if schema_type:
471
- return cast(
472
- "SQLResult[ModelDTOT]",
473
- SQLResult(
474
- statement=statement,
475
- data=cast("list[ModelDTOT]", list(self.to_schema(data=result["data"], schema_type=schema_type))),
476
- column_names=result["column_names"],
477
- rows_affected=result["rows_affected"],
478
- operation_type="SELECT",
479
- ),
480
- )
466
+ ) -> SQLResult[RowT]:
467
+ # Use provided connection or driver's default connection
468
+ conn = connection if connection is not None else self._connection(None)
481
469
 
482
- return cast(
483
- "SQLResult[RowT]",
484
- SQLResult(
485
- statement=statement,
486
- data=result["data"],
487
- column_names=result["column_names"],
488
- operation_type="SELECT",
489
- rows_affected=result["rows_affected"],
490
- ),
491
- )
470
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
471
+ # BigQuery does not support multi-statement scripts in a single job
472
+ statements = self._split_script_statements(script)
473
+
474
+ for statement in statements:
475
+ if statement:
476
+ query_job = self._run_query_job(statement, [], connection=txn_conn)
477
+ query_job.result(timeout=kwargs.get("bq_job_timeout"))
492
478
 
493
- def _wrap_execute_result(
494
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
495
- ) -> "SQLResult[RowT]":
496
- operation_type = "UNKNOWN"
497
- if statement.expression:
498
- operation_type = str(statement.expression.key).upper()
499
- if "statements_executed" in result:
500
- return SQLResult[RowT](
501
- statement=statement,
479
+ return SQLResult(
480
+ statement=SQL(script, _dialect=self.dialect).as_script(),
502
481
  data=[],
503
482
  rows_affected=0,
504
483
  operation_type="SCRIPT",
505
- metadata={
506
- "status_message": result.get("status_message", ""),
507
- "statements_executed": result.get("statements_executed", -1),
508
- },
509
- )
510
- if "rows_affected" in result:
511
- dml_result = cast("DMLResultDict", result)
512
- rows_affected = dml_result["rows_affected"]
513
- status_message = dml_result.get("status_message", "")
514
- return SQLResult[RowT](
515
- statement=statement,
516
- data=[],
517
- rows_affected=rows_affected,
518
- operation_type=operation_type,
519
- metadata={"status_message": status_message},
484
+ metadata={"status_message": "SCRIPT EXECUTED"},
485
+ total_statements=len(statements),
486
+ successful_statements=len(statements),
520
487
  )
521
- msg = f"Unexpected result type: {type(result)}"
522
- raise ValueError(msg)
523
488
 
524
489
  def _connection(self, connection: "Optional[Client]" = None) -> "Client":
525
490
  """Get the connection to use for the operation."""
@@ -529,28 +494,115 @@ class BigQueryDriver(
     # BigQuery Native Export Support
     # ============================================================================

-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
-        """BigQuery native export implementation.
+    def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
+        """BigQuery native export implementation with automatic GCS staging.

-        For local files, BigQuery doesn't support direct export, so we raise NotImplementedError
-        to trigger the fallback mechanism that uses fetch + write.
+        For GCS URIs, uses direct export. For other locations, automatically stages
+        through a temporary GCS location and transfers to the final destination.

         Args:
             query: SQL query to execute
-            destination_uri: Destination URI (local file path or gs:// URI)
+            destination_uri: Destination URI (local file path, gs:// URI, or Path object)
             format: Export format (parquet, csv, json, avro)
-            **options: Additional export options
+            **options: Additional export options including 'gcs_staging_bucket'

         Returns:
             Number of rows exported

         Raises:
-            NotImplementedError: Always, to trigger fallback to fetch + write
+            NotImplementedError: If no staging bucket is configured for non-GCS destinations
         """
-        # BigQuery only supports native export to GCS, not local files
-        # By raising NotImplementedError, the mixin will fall back to fetch + write
-        msg = "BigQuery native export only supports GCS URIs, using fallback for local files"
-        raise NotImplementedError(msg)
+        destination_str = str(destination_uri)
+
+        # If it's already a GCS URI, use direct export
+        if destination_str.startswith("gs://"):
+            return self._export_to_gcs_native(query, destination_str, format, **options)
+
+        staging_bucket = options.get("gcs_staging_bucket") or getattr(self.config, "gcs_staging_bucket", None)
+        if not staging_bucket:
+            # Fall back to fetch + write for non-GCS destinations without staging
+            msg = "BigQuery native export requires GCS staging bucket for non-GCS destinations"
+            raise NotImplementedError(msg)
+
+        # Generate temporary GCS path
+        from datetime import timezone
+
+        timestamp = datetime.datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+        temp_filename = f"bigquery_export_{timestamp}_{uuid.uuid4().hex[:8]}.{format}"
+        temp_gcs_uri = f"gs://{staging_bucket}/temp_exports/{temp_filename}"
+
+        try:
+            # Export to temporary GCS location
+            rows_exported = self._export_to_gcs_native(query, temp_gcs_uri, format, **options)
+
+            # Transfer from GCS to final destination using storage backend
+            backend, path = self._resolve_backend_and_path(destination_str)
+            gcs_backend = self._get_storage_backend(temp_gcs_uri)
+
+            # Download from GCS and upload to final destination
+            data = gcs_backend.read_bytes(temp_gcs_uri)
+            backend.write_bytes(path, data)
+
+            return rows_exported
+        finally:
+            # Clean up temporary file
+            try:
+                gcs_backend = self._get_storage_backend(temp_gcs_uri)
+                gcs_backend.delete(temp_gcs_uri)
+            except Exception as e:
+                logger.warning("Failed to clean up temporary GCS file %s: %s", temp_gcs_uri, e)
+
+    def _export_to_gcs_native(self, query: str, gcs_uri: str, format: str, **options: Any) -> int:
+        """Direct BigQuery export to GCS.
+
+        Args:
+            query: SQL query to execute
+            gcs_uri: GCS destination URI (must start with gs://)
+            format: Export format (parquet, csv, json, avro)
+            **options: Additional export options
+
+        Returns:
+            Number of rows exported
+        """
+        # First, run the query and store results in a temporary table
+
+        temp_table_id = f"temp_export_{uuid.uuid4().hex[:8]}"
+        dataset_id = getattr(self.connection, "default_dataset", None) or options.get("dataset", "temp")
+
+        query_with_table = f"CREATE OR REPLACE TABLE `{dataset_id}.{temp_table_id}` AS {query}"
+        create_job = self._run_query_job(query_with_table, [])
+        create_job.result()
+
+        count_query = f"SELECT COUNT(*) as cnt FROM `{dataset_id}.{temp_table_id}`"
+        count_job = self._run_query_job(count_query, [])
+        count_result = list(count_job.result())
+        row_count = count_result[0]["cnt"] if count_result else 0
+
+        try:
+            # Configure extract job
+            extract_config = ExtractJobConfig(**options)  # type: ignore[no-untyped-call]
+
+            format_mapping = {
+                "parquet": SourceFormat.PARQUET,
+                "csv": SourceFormat.CSV,
+                "json": SourceFormat.NEWLINE_DELIMITED_JSON,
+                "avro": SourceFormat.AVRO,
+            }
+            extract_config.destination_format = format_mapping.get(format, SourceFormat.PARQUET)
+
+            table_ref = self.connection.dataset(dataset_id).table(temp_table_id)
+            extract_job = self.connection.extract_table(table_ref, gcs_uri, job_config=extract_config)
+            extract_job.result()
+
+            return row_count
+        finally:
+            # Clean up temporary table
+            try:
+                delete_query = f"DROP TABLE IF EXISTS `{dataset_id}.{temp_table_id}`"
+                delete_job = self._run_query_job(delete_query, [])
+                delete_job.result()
+            except Exception as e:
+                logger.warning("Failed to clean up temporary table %s: %s", temp_table_id, e)

     # ============================================================================
     # BigQuery Native Arrow Support
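
_export_to_gcs_native materializes the query into a temporary table and then runs a BigQuery extract job against it. A hedged sketch of just the extract step using the google-cloud-bigquery client directly (dataset, table, and bucket names are placeholders; the SourceFormat constant mirrors its usage in this diff):

from google.cloud import bigquery

client = bigquery.Client()
table_ref = client.dataset("temp").table("temp_export_deadbeef")

extract_config = bigquery.ExtractJobConfig()
# The diff reuses SourceFormat constants for the destination format.
extract_config.destination_format = bigquery.SourceFormat.PARQUET

extract_job = client.extract_table(
    table_ref,
    "gs://my-staging-bucket/temp_exports/example.parquet",
    job_config=extract_config,
)
extract_job.result()  # block until the extract job finishes
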
@@ -570,7 +622,6 @@ class BigQueryDriver(
         Returns:
             ArrowResult with native Arrow table
         """
-
         # Execute the query directly with BigQuery to get the QueryJob
         params = sql.get_parameters(style=self.default_parameter_style)
         params_dict: dict[str, Any] = {}
@@ -650,7 +701,6 @@ class BigQueryDriver(
             raise ValueError(msg)

         # Use BigQuery's native Arrow loading
-        # Convert Arrow table to bytes for direct loading

         import pyarrow.parquet as pq
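
The pyarrow.parquet import above suggests the Arrow ingestion path serializes the Arrow table to an in-memory Parquet buffer before loading it into BigQuery. A hedged sketch of that approach with the public BigQuery client API (dataset and table names are placeholders; the actual sqlspec implementation is not shown in this hunk):

import io

import pyarrow as pa
import pyarrow.parquet as pq
from google.cloud import bigquery

table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

buffer = io.BytesIO()
pq.write_table(table, buffer)  # serialize the Arrow table as Parquet bytes
buffer.seek(0)

client = bigquery.Client()
job_config = bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.PARQUET)
load_job = client.load_table_from_file(buffer, "my_dataset.my_table", job_config=job_config)
load_job.result()  # wait for the load job to complete
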

sqlspec/adapters/duckdb/config.py

@@ -2,7 +2,6 @@

 import logging
 from contextlib import contextmanager
-from dataclasses import replace
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypedDict

 import duckdb
@@ -336,11 +335,9 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
         # DuckDB connect() only accepts database, read_only, and config parameters
         connect_params: dict[str, Any] = {}

-        # Set database if provided
         if hasattr(self, "database") and self.database is not None:
             connect_params["database"] = self.database

-        # Set read_only if provided
         if hasattr(self, "read_only") and self.read_only is not None:
             connect_params["read_only"] = self.read_only

@@ -352,7 +349,6 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
             if value is not None and value is not Empty:
                 config_dict[field] = value

-        # Add extras to config dict
         config_dict.update(self.extras)

         # If we have config parameters, add them
@@ -475,15 +471,16 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBConnection, DuckDBDriver]):
         @contextmanager
         def session_manager() -> "Generator[DuckDBDriver, None, None]":
             with self.provide_connection(*args, **kwargs) as connection:
-                # Create statement config with parameter style info if not already set
                 statement_config = self.statement_config
+                # Inject parameter style info if not already set
                 if statement_config.allowed_parameter_styles is None:
+                    from dataclasses import replace
+
                     statement_config = replace(
                         statement_config,
                         allowed_parameter_styles=self.supported_parameter_styles,
                         target_parameter_style=self.preferred_parameter_style,
                     )
-
                 driver = self.driver_type(connection=connection, config=statement_config)
                 yield driver
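
The session_manager hunk narrows the dataclasses.replace import to the branch that needs it and uses replace to derive a per-session statement config. A small self-contained illustration of that pattern (StatementConfig here is a simplified stand-in, not sqlspec's class; the two field names come from the hunk above and the style values are placeholders):

from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class StatementConfig:
    allowed_parameter_styles: Optional[tuple[str, ...]] = None
    target_parameter_style: Optional[str] = None

base = StatementConfig()
derived = base
if derived.allowed_parameter_styles is None:
    # replace() returns a new frozen instance with only these fields changed
    derived = replace(
        derived,
        allowed_parameter_styles=("qmark", "numeric"),
        target_parameter_style="qmark",
    )
print(derived)
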