sqlspec 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

Potentially problematic release. This version of sqlspec might be problematic.

Files changed (47)
  1. sqlspec/__init__.py +2 -1
  2. sqlspec/adapters/adbc/__init__.py +2 -1
  3. sqlspec/adapters/adbc/config.py +7 -13
  4. sqlspec/adapters/adbc/driver.py +160 -21
  5. sqlspec/adapters/aiosqlite/__init__.py +2 -1
  6. sqlspec/adapters/aiosqlite/config.py +10 -12
  7. sqlspec/adapters/aiosqlite/driver.py +160 -22
  8. sqlspec/adapters/asyncmy/__init__.py +2 -1
  9. sqlspec/adapters/asyncmy/driver.py +158 -22
  10. sqlspec/adapters/asyncpg/config.py +1 -3
  11. sqlspec/adapters/asyncpg/driver.py +143 -5
  12. sqlspec/adapters/bigquery/__init__.py +4 -0
  13. sqlspec/adapters/bigquery/config/__init__.py +3 -0
  14. sqlspec/adapters/bigquery/config/_common.py +40 -0
  15. sqlspec/adapters/bigquery/config/_sync.py +87 -0
  16. sqlspec/adapters/bigquery/driver.py +701 -0
  17. sqlspec/adapters/duckdb/__init__.py +2 -1
  18. sqlspec/adapters/duckdb/config.py +17 -18
  19. sqlspec/adapters/duckdb/driver.py +165 -27
  20. sqlspec/adapters/oracledb/__init__.py +8 -1
  21. sqlspec/adapters/oracledb/config/_asyncio.py +7 -8
  22. sqlspec/adapters/oracledb/config/_sync.py +6 -7
  23. sqlspec/adapters/oracledb/driver.py +311 -42
  24. sqlspec/adapters/psqlpy/__init__.py +9 -0
  25. sqlspec/adapters/psqlpy/config.py +11 -19
  26. sqlspec/adapters/psqlpy/driver.py +171 -19
  27. sqlspec/adapters/psycopg/__init__.py +8 -1
  28. sqlspec/adapters/psycopg/config/__init__.py +10 -0
  29. sqlspec/adapters/psycopg/config/_async.py +6 -7
  30. sqlspec/adapters/psycopg/config/_sync.py +7 -8
  31. sqlspec/adapters/psycopg/driver.py +344 -86
  32. sqlspec/adapters/sqlite/__init__.py +2 -1
  33. sqlspec/adapters/sqlite/config.py +12 -11
  34. sqlspec/adapters/sqlite/driver.py +160 -51
  35. sqlspec/base.py +402 -63
  36. sqlspec/exceptions.py +9 -0
  37. sqlspec/extensions/litestar/config.py +3 -11
  38. sqlspec/extensions/litestar/handlers.py +2 -1
  39. sqlspec/extensions/litestar/plugin.py +6 -2
  40. sqlspec/mixins.py +156 -0
  41. sqlspec/typing.py +19 -1
  42. {sqlspec-0.9.0.dist-info → sqlspec-0.10.0.dist-info}/METADATA +147 -3
  43. sqlspec-0.10.0.dist-info/RECORD +67 -0
  44. sqlspec-0.9.0.dist-info/RECORD +0 -61
  45. {sqlspec-0.9.0.dist-info → sqlspec-0.10.0.dist-info}/WHEEL +0 -0
  46. {sqlspec-0.9.0.dist-info → sqlspec-0.10.0.dist-info}/licenses/LICENSE +0 -0
  47. {sqlspec-0.9.0.dist-info → sqlspec-0.10.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/bigquery/driver.py (new file)
@@ -0,0 +1,701 @@
+import contextlib
+import datetime
+from collections.abc import Iterator, Sequence
+from decimal import Decimal
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    ClassVar,
+    Optional,
+    Union,
+    cast,
+    overload,
+)
+
+import sqlglot
+from google.cloud import bigquery
+from google.cloud.bigquery import Client
+from google.cloud.bigquery.job import QueryJob, QueryJobConfig
+from google.cloud.exceptions import NotFound
+
+from sqlspec.base import SyncDriverAdapterProtocol
+from sqlspec.exceptions import NotFoundError, SQLSpecError
+from sqlspec.mixins import (
+    SQLTranslatorMixin,
+    SyncArrowBulkOperationsMixin,
+    SyncParquetExportMixin,
+)
+from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T
+
+if TYPE_CHECKING:
+    from google.cloud.bigquery import SchemaField
+    from google.cloud.bigquery.table import Row
+
+__all__ = ("BigQueryConnection", "BigQueryDriver")
+
+BigQueryConnection = Client
+
+
+class BigQueryDriver(
+    SyncDriverAdapterProtocol["BigQueryConnection"],
+    SyncArrowBulkOperationsMixin["BigQueryConnection"],
+    SyncParquetExportMixin["BigQueryConnection"],
+    SQLTranslatorMixin["BigQueryConnection"],
+):
+    """Synchronous BigQuery Driver Adapter."""
+
+    dialect: str = "bigquery"
+    connection: "BigQueryConnection"
+    __supports_arrow__: ClassVar[bool] = True
+
+    def __init__(self, connection: "BigQueryConnection", **kwargs: Any) -> None:
+        super().__init__(connection=connection)
+        self._default_query_job_config = kwargs.get("default_query_job_config") or getattr(
+            connection, "default_query_job_config", None
+        )
+
+    @staticmethod
+    def _get_bq_param_type(value: Any) -> "tuple[Optional[str], Optional[str]]": # noqa: PLR0911, PLR0912
+        if isinstance(value, bool):
+            return "BOOL", None
+        if isinstance(value, int):
+            return "INT64", None
+        if isinstance(value, float):
+            return "FLOAT64", None
+        if isinstance(value, Decimal):
+            # Precision/scale might matter, but BQ client handles conversion.
+            # Defaulting to BIGNUMERIC, NUMERIC might be desired in some cases though (User change)
+            return "BIGNUMERIC", None
+        if isinstance(value, str):
+            return "STRING", None
+        if isinstance(value, bytes):
+            return "BYTES", None
+        if isinstance(value, datetime.date):
+            return "DATE", None
+        # DATETIME is for timezone-naive values
+        if isinstance(value, datetime.datetime) and value.tzinfo is None:
+            return "DATETIME", None
+        # TIMESTAMP is for timezone-aware values
+        if isinstance(value, datetime.datetime) and value.tzinfo is not None:
+            return "TIMESTAMP", None
+        if isinstance(value, datetime.time):
+            return "TIME", None
+
+        # Handle Arrays - Determine element type
+        if isinstance(value, (list, tuple)):
+            if not value:
+                # Cannot determine type of empty array, BQ requires type.
+                # Raise or default? Defaulting is risky. Let's raise.
+                msg = "Cannot determine BigQuery ARRAY type for empty sequence."
+                raise SQLSpecError(msg)
+            # Infer type from first element
+            first_element = value[0]
+            element_type, _ = BigQueryDriver._get_bq_param_type(first_element)
+            if element_type is None:
+                msg = f"Unsupported element type in ARRAY: {type(first_element)}"
+                raise SQLSpecError(msg)
+            return "ARRAY", element_type
+
+        # Handle Structs (basic dict mapping) - Requires careful handling
+        # if isinstance(value, dict):
+        #     # This requires recursive type mapping for sub-fields.
+        #     # For simplicity, users might need to construct StructQueryParameter manually.
+        #     # return "STRUCT", None # Placeholder if implementing # noqa: ERA001
+        #     raise SQLSpecError("Automatic STRUCT mapping not implemented. Please use bigquery.StructQueryParameter.") # noqa: ERA001
+
+        return None, None # Unsupported type
+
+    def _run_query_job( # noqa: C901, PLR0912, PLR0915 (User change)
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        connection: "Optional[BigQueryConnection]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        is_script: bool = False,
+        **kwargs: Any,
+    ) -> "QueryJob":
+        conn = self._connection(connection)
+
+        # Determine the final job config, creating a new one if necessary
+        # to avoid modifying a shared default config.
+        if job_config:
+            final_job_config = job_config # Use the provided config directly
+        elif self._default_query_job_config:
+            final_job_config = QueryJobConfig()
+        else:
+            final_job_config = QueryJobConfig() # Create a fresh config
+
+        # --- Parameter Handling Logic --- Start
+        params: Union[dict[str, Any], list[Any], None] = None
+        param_style: Optional[str] = None # 'named' (@), 'qmark' (?)
+        use_preformatted_params = False
+        final_sql = sql # Default to original SQL
+
+        # Check for pre-formatted BQ parameters first
+        if (
+            isinstance(parameters, (list, tuple))
+            and parameters
+            and all(isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in parameters)
+        ):
+            if kwargs:
+                msg = "Cannot mix pre-formatted BigQuery parameters with keyword arguments."
+                raise SQLSpecError(msg)
+            use_preformatted_params = True
+            final_job_config.query_parameters = list(parameters)
+            # Keep final_sql = sql, as it should match the pre-formatted named params
+
+        # Determine parameter style and merge standard parameters ONLY if not preformatted
+        if not use_preformatted_params:
+            if isinstance(parameters, dict):
+                params = {**parameters, **kwargs}
+                param_style = "named"
+            elif isinstance(parameters, (list, tuple)):
+                if kwargs:
+                    msg = "Cannot mix positional parameters with keyword arguments."
+                    raise SQLSpecError(msg)
+                # Check if it's primitives for qmark style
+                if all(
+                    not isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in parameters
+                ):
+                    params = list(parameters)
+                    param_style = "qmark"
+                else:
+                    # Mixed list or non-BQ parameter objects
+                    msg = "Invalid mix of parameter types in list. Use only primitive values or only BigQuery QueryParameter objects."
+                    raise SQLSpecError(msg)
+
+            elif kwargs:
+                params = kwargs
+                param_style = "named"
+            elif parameters is not None and not isinstance(
+                parameters, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)
+            ):
+                # Could be a single primitive value for positional
+                params = [parameters]
+                param_style = "qmark"
+            elif parameters is not None: # Single BQ parameter object
+                msg = "Single BigQuery QueryParameter objects should be passed within a list."
+                raise SQLSpecError(msg)
+
+        # Use sqlglot to transpile ONLY if not a script and not preformatted
+        if not is_script and not use_preformatted_params:
+            try:
+                # Transpile for syntax normalization/dialect conversion if needed
+                # Use BigQuery dialect for both reading and writing
+                final_sql = sqlglot.transpile(sql, read=self.dialect, write=self.dialect)[0]
+            except Exception as e:
+                # Catch potential sqlglot errors
+                msg = f"SQL transpilation failed using sqlglot: {e!s}" # Adjusted message
+                raise SQLSpecError(msg) from e
+        # else: If preformatted_params, final_sql remains the original sql
+
+        # --- Parameter Handling Logic --- (Moved outside the transpilation try/except)
+        # Prepare BQ parameters based on style, ONLY if not preformatted
+        if not use_preformatted_params:
+            if param_style == "named" and params:
+                # Convert dict params to BQ ScalarQueryParameter
+                if isinstance(params, dict):
+                    final_job_config.query_parameters = [
+                        bigquery.ScalarQueryParameter(name, self._get_bq_param_type(value)[0], value)
+                        for name, value in params.items()
+                    ]
+                else:
+                    # This path should ideally not be reached if param_style logic is correct
+                    msg = f"Internal error: Parameter style is 'named' but parameters are not a dict: {type(params)}"
+                    raise SQLSpecError(msg)
+            elif param_style == "qmark" and params:
+                # Convert list params to BQ ScalarQueryParameter
+                final_job_config.query_parameters = [
+                    bigquery.ScalarQueryParameter(None, self._get_bq_param_type(value)[0], value) for value in params
+                ]
+
+        # --- Parameter Handling Logic --- End
+
+        # Determine which kwargs to pass to the actual query method.
+        # We only want to pass kwargs that were *not* treated as SQL parameters.
+        final_query_kwargs = {}
+        if parameters is not None and kwargs: # Params came via arg, kwargs are separate
+            final_query_kwargs = kwargs
+        # Else: If params came via kwargs, they are already handled, so don't pass them again.
+
+        # Execute query
+        return conn.query(
+            final_sql,
+            job_config=final_job_config,
+            **final_query_kwargs, # Pass only relevant kwargs
+        )
+
+    @staticmethod
+    def _rows_to_results(
+        rows: "Iterator[Row]",
+        schema: "Sequence[SchemaField]",
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+    ) -> Sequence[Union[ModelDTOT, dict[str, Any]]]:
+        processed_results = []
+        # Create a quick lookup map for schema fields from the passed schema
+        schema_map = {field.name: field for field in schema}
+
+        for row in rows:
+            # row here is now a Row object from the iterator
+            row_dict = {}
+            for key, value in row.items(): # Use row.items() on the Row object
+                field = schema_map.get(key)
+                # Workaround remains the same
+                if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
+                    try:
+                        parsed_value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)
+                        row_dict[key] = parsed_value
+                    except ValueError:
+                        row_dict[key] = value # type: ignore[assignment]
+                else:
+                    row_dict[key] = value
+            # Use the processed dictionary for the final result
+            if schema_type:
+                processed_results.append(schema_type(**row_dict))
+            else:
+                processed_results.append(row_dict) # type: ignore[arg-type]
+        if schema_type:
+            return cast("Sequence[ModelDTOT]", processed_results)
+        return cast("Sequence[dict[str, Any]]", processed_results)
+
+    @overload
+    def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Sequence[dict[str, Any]]": ...
+    @overload
+    def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "Sequence[ModelDTOT]": ...
+    def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
+        query_job = self._run_query_job(sql, parameters, connection, job_config, **kwargs)
+        return self._rows_to_results(query_job.result(), query_job.result().schema, schema_type)
+
+    @overload
+    def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "dict[str, Any]": ...
+    @overload
+    def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "ModelDTOT": ...
+    def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> "Union[ModelDTOT, dict[str, Any]]":
+        query_job = self._run_query_job(sql, parameters, connection, job_config, **kwargs)
+        rows_iterator = query_job.result()
+        try:
+            # Pass the iterator containing only the first row to _rows_to_results
+            # This ensures the timestamp workaround is applied consistently.
+            # We need to pass the original iterator for schema access, but only consume one row.
+            first_row = next(rows_iterator)
+            # Create a simple iterator yielding only the first row for processing
+            single_row_iter = iter([first_row])
+            # We need RowIterator type for schema, create mock/proxy if needed, or pass schema
+            # Let's try passing schema directly to _rows_to_results (requires modifying it)
+            results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type)
+            return results[0]
+        except StopIteration:
+            msg = "No result found when one was expected"
+            raise NotFoundError(msg) from None
+
+    @overload
+    def select_one_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Optional[dict[str, Any]]": ...
+    @overload
+    def select_one_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "Optional[ModelDTOT]": ...
+    def select_one_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
+        query_job = self._run_query_job(sql, parameters, connection, job_config, **kwargs)
+        rows_iterator = query_job.result()
+        try:
+            first_row = next(rows_iterator)
+            # Create a simple iterator yielding only the first row for processing
+            single_row_iter = iter([first_row])
+            # Pass schema directly
+            results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type)
+            return results[0]
+        except StopIteration:
+            return None
+
+    @overload
+    def select_value(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "Optional[type[T]]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> Union[T, Any]: ...
+    @overload
+    def select_value(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "type[T]",
+        **kwargs: Any,
+    ) -> "T": ...
+    def select_value(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "Optional[type[T]]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> Union[T, Any]:
+        query_job = self._run_query_job(
+            sql=sql, parameters=parameters, connection=connection, job_config=job_config, **kwargs
+        )
+        rows = query_job.result()
+        try:
+            first_row = next(iter(rows))
+            value = first_row[0]
+            # Apply timestamp workaround if necessary
+            field = rows.schema[0] # Get schema for the first column
+            if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
+                with contextlib.suppress(ValueError):
+                    value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)
+
+            return cast("T", value) if schema_type else value
+        except (StopIteration, IndexError):
+            msg = "No value found when one was expected"
+            raise NotFoundError(msg) from None
+
+    @overload
+    def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Optional[Any]": ...
+    @overload
+    def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "type[T]",
+        **kwargs: Any,
+    ) -> "Optional[T]": ...
+    def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "Optional[type[T]]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> "Optional[Union[T, Any]]":
+        query_job = self._run_query_job(
+            sql=sql, parameters=parameters, connection=connection, job_config=job_config, **kwargs
+        )
+        rows = query_job.result()
+        try:
+            first_row = next(iter(rows))
+            value = first_row[0]
+            # Apply timestamp workaround if necessary
+            field = rows.schema[0] # Get schema for the first column
+            if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value:
+                with contextlib.suppress(ValueError):
+                    value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc)
+
+            return cast("T", value) if schema_type else value
+        except (StopIteration, IndexError):
+            return None
+
+    def insert_update_delete(
+        self,
+        sql: str,
+        parameters: Optional[StatementParameterType] = None,
+        /,
+        *,
+        connection: Optional["BigQueryConnection"] = None,
+        job_config: Optional[QueryJobConfig] = None,
+        **kwargs: Any,
+    ) -> int:
+        """Executes INSERT, UPDATE, DELETE and returns affected row count.
+
+        Returns:
+            int: The number of rows affected by the DML statement.
+        """
+        query_job = self._run_query_job(
+            sql=sql, parameters=parameters, connection=connection, job_config=job_config, **kwargs
+        )
+        # DML statements might not return rows, check job properties
+        # num_dml_affected_rows might be None initially, wait might be needed
+        query_job.result() # Ensure completion
+        return query_job.num_dml_affected_rows or 0 # Return 0 if None
+
+    @overload
+    def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "dict[str, Any]": ...
+    @overload
+    def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "ModelDTOT": ...
+    def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> Union[ModelDTOT, dict[str, Any]]:
+        """BigQuery DML RETURNING equivalent is complex, often requires temp tables or scripting."""
+        msg = "BigQuery does not support `RETURNING` clauses directly in the same way as some other SQL databases. Consider multi-statement queries or alternative approaches."
+        raise NotImplementedError(msg)
+
+    def execute_script(
+        self,
+        sql: str, # Expecting a script here
+        parameters: "Optional[StatementParameterType]" = None, # Parameters might be complex in scripts
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> str:
+        """Executes a BigQuery script and returns the job ID.
+
+        Returns:
+            str: The job ID of the executed script.
+        """
+        query_job = self._run_query_job(
+            sql=sql,
+            parameters=parameters,
+            connection=connection,
+            job_config=job_config,
+            is_script=True,
+            **kwargs,
+        )
+        return str(query_job.job_id)
+
+    # --- Mixin Implementations ---
+
+    def select_arrow( # pyright: ignore # noqa: PLR0912
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[BigQueryConnection]" = None,
+        job_config: "Optional[QueryJobConfig]" = None,
+        **kwargs: Any,
+    ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType]
+        conn = self._connection(connection)
+        final_job_config = job_config or self._default_query_job_config or QueryJobConfig()
+
+        # Determine parameter style and merge parameters (Similar to _run_query_job)
+        params: Union[dict[str, Any], list[Any], None] = None
+        param_style: Optional[str] = None # 'named' (@), 'qmark' (?)
+
+        if isinstance(parameters, dict):
+            params = {**parameters, **kwargs}
+            param_style = "named"
+        elif isinstance(parameters, (list, tuple)):
+            if kwargs:
+                msg = "Cannot mix positional parameters with keyword arguments."
+                raise SQLSpecError(msg)
+            params = list(parameters)
+            param_style = "qmark"
+        elif kwargs:
+            params = kwargs
+            param_style = "named"
+        elif parameters is not None:
+            params = [parameters]
+            param_style = "qmark"
+
+        # Use sqlglot to transpile and bind parameters
+        try:
+            transpiled_sql = sqlglot.transpile(sql, args=params or {}, read=None, write=self.dialect)[0]
+        except Exception as e:
+            msg = f"SQL transpilation/binding failed using sqlglot: {e!s}"
+            raise SQLSpecError(msg) from e
+
+        # Prepare BigQuery specific parameters if named style was used
+        if param_style == "named" and params:
+            if not isinstance(params, dict):
+                # This should be logically impossible due to how param_style is set
+                msg = "Internal error: named parameter style detected but params is not a dict."
+                raise SQLSpecError(msg)
+            query_parameters = []
+            for key, value in params.items():
+                param_type, array_element_type = self._get_bq_param_type(value)
+
+                if param_type == "ARRAY" and array_element_type:
+                    query_parameters.append(bigquery.ArrayQueryParameter(key, array_element_type, value))
+                elif param_type:
+                    query_parameters.append(bigquery.ScalarQueryParameter(key, param_type, value)) # type: ignore[arg-type]
+                else:
+                    msg = f"Unsupported parameter type for BigQuery Arrow named parameter '{key}': {type(value)}"
+                    raise SQLSpecError(msg)
+            final_job_config.query_parameters = query_parameters
+        elif param_style == "qmark" and params:
+            # Positional params handled by client library
+            pass
+
+        # Execute the query and get Arrow table
+        try:
+            query_job = conn.query(transpiled_sql, job_config=final_job_config)
+            arrow_table = query_job.to_arrow() # Waits for job completion
+
+        except Exception as e:
+            msg = f"BigQuery Arrow query execution failed: {e!s}"
+            raise SQLSpecError(msg) from e
+        return arrow_table
+
+    def select_to_parquet(
+        self,
+        sql: str, # Expects table ID: project.dataset.table
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        destination_uri: "Optional[str]" = None,
+        connection: "Optional[BigQueryConnection]" = None,
+        job_config: "Optional[bigquery.ExtractJobConfig]" = None,
+        **kwargs: Any,
+    ) -> None:
+        """Exports a BigQuery table to Parquet files in Google Cloud Storage.
+
+        Raises:
+            NotImplementedError: If the SQL is not a fully qualified table ID or if parameters are provided.
+            NotFoundError: If the source table is not found.
+            SQLSpecError: If the Parquet export fails.
+        """
+        if destination_uri is None:
+            msg = "destination_uri is required"
+            raise SQLSpecError(msg)
+        conn = self._connection(connection)
+        if "." not in sql or parameters is not None:
+            msg = "select_to_parquet currently expects a fully qualified table ID (project.dataset.table) as the `sql` argument and no `parameters`."
+            raise NotImplementedError(msg)
+
+        source_table_ref = bigquery.TableReference.from_string(sql, default_project=conn.project)
+
+        final_extract_config = job_config or bigquery.ExtractJobConfig() # type: ignore[no-untyped-call]
+        final_extract_config.destination_format = bigquery.DestinationFormat.PARQUET
+
+        try:
+            extract_job = conn.extract_table(
+                source_table_ref,
+                destination_uri,
+                job_config=final_extract_config,
+                # Location is correctly inferred by the client library
+            )
+            extract_job.result() # Wait for completion
+
+        except NotFound:
+            msg = f"Source table not found for Parquet export: {source_table_ref}"
+            raise NotFoundError(msg) from None
+        except Exception as e:
+            msg = f"BigQuery Parquet export failed: {e!s}"
+            raise SQLSpecError(msg) from e
+        if extract_job.errors:
+            msg = f"BigQuery Parquet export failed: {extract_job.errors}"
+            raise SQLSpecError(msg)
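
Usage sketch (not taken from the package's documentation): the snippet below exercises only methods visible in the new driver above, and the project, dataset, and table names are illustrative assumptions.

    from google.cloud.bigquery import Client

    from sqlspec.adapters.bigquery.driver import BigQueryDriver

    # BigQueryConnection is an alias for google.cloud.bigquery.Client, so a plain Client works here.
    client = Client(project="my-project")  # hypothetical project ID
    driver = BigQueryDriver(connection=client)

    # Named parameters are mapped to ScalarQueryParameter objects via _get_bq_param_type().
    rows = driver.select(
        "SELECT name, created_at FROM `my-project.my_dataset.users` WHERE age >= @min_age",
        {"min_age": 21},
    )

    # DML helpers return the affected row count (num_dml_affected_rows, or 0 when the job reports none).
    deleted = driver.insert_update_delete(
        "DELETE FROM `my-project.my_dataset.users` WHERE active = @active",
        {"active": False},
    )

select_arrow() and select_to_parquet() follow the same connection and parameter conventions; select_to_parquet() additionally requires a destination_uri pointing at Google Cloud Storage.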