sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of sqlspec might be problematic.

Files changed (212)
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +55 -25
  3. sqlspec/_typing.py +155 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +880 -0
  7. sqlspec/adapters/adbc/config.py +62 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +74 -2
  9. sqlspec/adapters/adbc/driver.py +226 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +44 -50
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  16. sqlspec/adapters/aiosqlite/config.py +86 -16
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
  18. sqlspec/adapters/aiosqlite/driver.py +127 -38
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +1 -1
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  26. sqlspec/adapters/asyncmy/config.py +59 -17
  27. sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
  28. sqlspec/adapters/asyncmy/driver.py +293 -62
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  36. sqlspec/adapters/asyncpg/config.py +57 -36
  37. sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
  38. sqlspec/adapters/asyncpg/driver.py +153 -23
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +585 -0
  44. sqlspec/adapters/bigquery/config.py +36 -11
  45. sqlspec/adapters/bigquery/data_dictionary.py +42 -2
  46. sqlspec/adapters/bigquery/driver.py +489 -144
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +55 -23
  50. sqlspec/adapters/duckdb/_types.py +2 -2
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +563 -0
  53. sqlspec/adapters/duckdb/config.py +79 -21
  54. sqlspec/adapters/duckdb/data_dictionary.py +41 -2
  55. sqlspec/adapters/duckdb/driver.py +225 -44
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +5 -5
  59. sqlspec/adapters/duckdb/type_converter.py +51 -21
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1628 -0
  64. sqlspec/adapters/oracledb/config.py +120 -36
  65. sqlspec/adapters/oracledb/data_dictionary.py +87 -20
  66. sqlspec/adapters/oracledb/driver.py +475 -86
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  69. sqlspec/adapters/oracledb/migrations.py +316 -25
  70. sqlspec/adapters/oracledb/type_converter.py +91 -16
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  75. sqlspec/adapters/psqlpy/config.py +45 -19
  76. sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
  77. sqlspec/adapters/psqlpy/driver.py +108 -41
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +40 -11
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +962 -0
  85. sqlspec/adapters/psycopg/config.py +65 -37
  86. sqlspec/adapters/psycopg/data_dictionary.py +91 -3
  87. sqlspec/adapters/psycopg/driver.py +200 -78
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +582 -0
  95. sqlspec/adapters/sqlite/config.py +85 -16
  96. sqlspec/adapters/sqlite/data_dictionary.py +34 -2
  97. sqlspec/adapters/sqlite/driver.py +120 -52
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +5 -5
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +91 -58
  104. sqlspec/builder/_column.py +5 -5
  105. sqlspec/builder/_ddl.py +98 -89
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +41 -44
  109. sqlspec/builder/_insert.py +5 -82
  110. sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +9 -11
  113. sqlspec/builder/_select.py +1313 -25
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +76 -69
  116. sqlspec/config.py +331 -62
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +18 -18
  119. sqlspec/core/compiler.py +6 -8
  120. sqlspec/core/filters.py +55 -47
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +76 -45
  123. sqlspec/core/result.py +234 -47
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +32 -31
  126. sqlspec/core/type_conversion.py +3 -2
  127. sqlspec/driver/__init__.py +1 -3
  128. sqlspec/driver/_async.py +183 -160
  129. sqlspec/driver/_common.py +197 -109
  130. sqlspec/driver/_sync.py +189 -161
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +70 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +69 -61
  142. sqlspec/extensions/fastapi/__init__.py +21 -0
  143. sqlspec/extensions/fastapi/extension.py +331 -0
  144. sqlspec/extensions/fastapi/providers.py +543 -0
  145. sqlspec/extensions/flask/__init__.py +36 -0
  146. sqlspec/extensions/flask/_state.py +71 -0
  147. sqlspec/extensions/flask/_utils.py +40 -0
  148. sqlspec/extensions/flask/extension.py +389 -0
  149. sqlspec/extensions/litestar/__init__.py +21 -4
  150. sqlspec/extensions/litestar/cli.py +54 -10
  151. sqlspec/extensions/litestar/config.py +56 -266
  152. sqlspec/extensions/litestar/handlers.py +46 -17
  153. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  154. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  155. sqlspec/extensions/litestar/plugin.py +349 -224
  156. sqlspec/extensions/litestar/providers.py +25 -25
  157. sqlspec/extensions/litestar/store.py +265 -0
  158. sqlspec/extensions/starlette/__init__.py +10 -0
  159. sqlspec/extensions/starlette/_state.py +25 -0
  160. sqlspec/extensions/starlette/_utils.py +52 -0
  161. sqlspec/extensions/starlette/extension.py +254 -0
  162. sqlspec/extensions/starlette/middleware.py +154 -0
  163. sqlspec/loader.py +30 -49
  164. sqlspec/migrations/base.py +200 -76
  165. sqlspec/migrations/commands.py +591 -62
  166. sqlspec/migrations/context.py +6 -9
  167. sqlspec/migrations/fix.py +199 -0
  168. sqlspec/migrations/loaders.py +47 -19
  169. sqlspec/migrations/runner.py +241 -75
  170. sqlspec/migrations/tracker.py +237 -21
  171. sqlspec/migrations/utils.py +51 -3
  172. sqlspec/migrations/validation.py +177 -0
  173. sqlspec/protocols.py +106 -36
  174. sqlspec/storage/_utils.py +85 -0
  175. sqlspec/storage/backends/fsspec.py +133 -107
  176. sqlspec/storage/backends/local.py +78 -51
  177. sqlspec/storage/backends/obstore.py +276 -168
  178. sqlspec/storage/registry.py +75 -39
  179. sqlspec/typing.py +30 -84
  180. sqlspec/utils/__init__.py +25 -4
  181. sqlspec/utils/arrow_helpers.py +81 -0
  182. sqlspec/utils/config_resolver.py +6 -6
  183. sqlspec/utils/correlation.py +4 -5
  184. sqlspec/utils/data_transformation.py +3 -2
  185. sqlspec/utils/deprecation.py +9 -8
  186. sqlspec/utils/fixtures.py +4 -4
  187. sqlspec/utils/logging.py +46 -6
  188. sqlspec/utils/module_loader.py +205 -5
  189. sqlspec/utils/portal.py +311 -0
  190. sqlspec/utils/schema.py +288 -0
  191. sqlspec/utils/serializers.py +113 -4
  192. sqlspec/utils/sync_tools.py +36 -22
  193. sqlspec/utils/text.py +1 -2
  194. sqlspec/utils/type_guards.py +136 -20
  195. sqlspec/utils/version.py +433 -0
  196. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
  197. sqlspec-0.28.0.dist-info/RECORD +221 -0
  198. sqlspec/builder/mixins/__init__.py +0 -55
  199. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
  200. sqlspec/builder/mixins/_delete_operations.py +0 -50
  201. sqlspec/builder/mixins/_insert_operations.py +0 -282
  202. sqlspec/builder/mixins/_merge_operations.py +0 -698
  203. sqlspec/builder/mixins/_order_limit_operations.py +0 -145
  204. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  205. sqlspec/builder/mixins/_select_operations.py +0 -930
  206. sqlspec/builder/mixins/_update_operations.py +0 -199
  207. sqlspec/builder/mixins/_where_clause.py +0 -1298
  208. sqlspec-0.26.0.dist-info/RECORD +0 -157
  209. sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
  210. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
  211. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
  212. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
@@ -7,7 +7,7 @@ type coercion, error handling, and query job management.
 import datetime
 import logging
 from decimal import Decimal
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any

 import sqlglot
 import sqlglot.expressions as exp
@@ -16,29 +16,44 @@ from google.cloud.exceptions import GoogleCloudError

 from sqlspec.adapters.bigquery._types import BigQueryConnection
 from sqlspec.adapters.bigquery.type_converter import BigQueryTypeConverter
-from sqlspec.core.cache import get_cache_config
-from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
-from sqlspec.core.statement import StatementConfig
-from sqlspec.driver import SyncDriverAdapterBase
-from sqlspec.driver._common import ExecutionResult
-from sqlspec.exceptions import SQLParsingError, SQLSpecError
+from sqlspec.core import ParameterStyle, ParameterStyleConfig, StatementConfig, get_cache_config
+from sqlspec.driver import ExecutionResult, SyncDriverAdapterBase
+from sqlspec.exceptions import (
+    DatabaseConnectionError,
+    DataError,
+    NotFoundError,
+    OperationalError,
+    SQLParsingError,
+    SQLSpecError,
+    UniqueViolationError,
+)
 from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
+    from collections.abc import Callable
     from contextlib import AbstractContextManager

-    from sqlspec.core.result import SQLResult
-    from sqlspec.core.statement import SQL
-    from sqlspec.driver._sync import SyncDataDictionaryBase
+    from sqlspec.builder import QueryBuilder
+    from sqlspec.core import SQL, SQLResult, Statement, StatementFilter
+    from sqlspec.core.result import ArrowResult
+    from sqlspec.driver import SyncDataDictionaryBase
+    from sqlspec.typing import StatementParameters

 logger = logging.getLogger(__name__)

 __all__ = ("BigQueryCursor", "BigQueryDriver", "BigQueryExceptionHandler", "bigquery_statement_config")

-_type_converter = BigQueryTypeConverter()
+HTTP_CONFLICT = 409
+HTTP_NOT_FOUND = 404
+HTTP_BAD_REQUEST = 400
+HTTP_FORBIDDEN = 403
+HTTP_SERVER_ERROR = 500
+

+_default_type_converter = BigQueryTypeConverter()

-_BQ_TYPE_MAP: dict[type, tuple[str, Optional[str]]] = {
+
+_BQ_TYPE_MAP: dict[type, tuple[str, str | None]] = {
     bool: ("BOOL", None),
     int: ("INT64", None),
     float: ("FLOAT64", None),
@@ -51,7 +66,134 @@ _BQ_TYPE_MAP: dict[type, tuple[str, Optional[str]]] = {
 }


-def _get_bq_param_type(value: Any) -> tuple[Optional[str], Optional[str]]:
+def _create_array_parameter(name: str, value: Any, array_type: str) -> ArrayQueryParameter:
+    """Create BigQuery ARRAY parameter.
+
+    Args:
+        name: Parameter name.
+        value: Array value (converted to list, empty list if None).
+        array_type: BigQuery array element type.
+
+    Returns:
+        ArrayQueryParameter instance.
+    """
+    return ArrayQueryParameter(name, array_type, [] if value is None else list(value))
+
+
+def _create_json_parameter(name: str, value: Any, json_serializer: "Callable[[Any], str]") -> ScalarQueryParameter:
+    """Create BigQuery JSON parameter as STRING type.
+
+    Args:
+        name: Parameter name.
+        value: JSON-serializable value.
+        json_serializer: Function to serialize to JSON string.
+
+    Returns:
+        ScalarQueryParameter with STRING type.
+    """
+    return ScalarQueryParameter(name, "STRING", json_serializer(value))
+
+
+def _create_scalar_parameter(name: str, value: Any, param_type: str) -> ScalarQueryParameter:
+    """Create BigQuery scalar parameter.
+
+    Args:
+        name: Parameter name.
+        value: Scalar value.
+        param_type: BigQuery parameter type (INT64, FLOAT64, etc.).
+
+    Returns:
+        ScalarQueryParameter instance.
+    """
+    return ScalarQueryParameter(name, param_type, value)
+
+
+def _create_literal_node(value: Any, json_serializer: "Callable[[Any], str]") -> "exp.Expression":
+    """Create a SQLGlot literal expression from a Python value.
+
+    Args:
+        value: Python value to convert to SQLGlot literal.
+        json_serializer: Function to serialize dict/list to JSON string.
+
+    Returns:
+        SQLGlot expression representing the literal value.
+    """
+    if value is None:
+        return exp.Null()
+    if isinstance(value, bool):
+        return exp.Boolean(this=value)
+    if isinstance(value, (int, float)):
+        return exp.Literal.number(str(value))
+    if isinstance(value, str):
+        return exp.Literal.string(value)
+    if isinstance(value, (list, tuple)):
+        items = [_create_literal_node(item, json_serializer) for item in value]
+        return exp.Array(expressions=items)
+    if isinstance(value, dict):
+        json_str = json_serializer(value)
+        return exp.Literal.string(json_str)
+
+    return exp.Literal.string(str(value))
+
+
+def _replace_placeholder_node(
+    node: "exp.Expression",
+    parameters: Any,
+    placeholder_counter: dict[str, int],
+    json_serializer: "Callable[[Any], str]",
+) -> "exp.Expression":
+    """Replace placeholder or parameter nodes with literal values.
+
+    Handles both positional placeholders (?) and named parameters (@name, :name).
+    Converts values to SQLGlot literal expressions for safe embedding in SQL.
+
+    Args:
+        node: SQLGlot expression node to check and potentially replace.
+        parameters: Parameter values (dict, list, or tuple).
+        placeholder_counter: Mutable counter dict for positional placeholders.
+        json_serializer: Function to serialize dict/list to JSON string.
+
+    Returns:
+        Literal expression if replacement made, otherwise original node.
+    """
+    if isinstance(node, exp.Placeholder):
+        if isinstance(parameters, (list, tuple)):
+            current_index = placeholder_counter["index"]
+            placeholder_counter["index"] += 1
+            if current_index < len(parameters):
+                return _create_literal_node(parameters[current_index], json_serializer)
+        return node
+
+    if isinstance(node, exp.Parameter):
+        param_name = str(node.this) if hasattr(node.this, "__str__") else node.this
+
+        if isinstance(parameters, dict):
+            possible_names = [param_name, f"@{param_name}", f":{param_name}", f"param_{param_name}"]
+            for name in possible_names:
+                if name in parameters:
+                    actual_value = getattr(parameters[name], "value", parameters[name])
+                    return _create_literal_node(actual_value, json_serializer)
+            return node
+
+        if isinstance(parameters, (list, tuple)):
+            try:
+                if param_name.startswith("param_"):
+                    param_index = int(param_name[6:])
+                    if param_index < len(parameters):
+                        return _create_literal_node(parameters[param_index], json_serializer)
+
+                if param_name.isdigit():
+                    param_index = int(param_name)
+                    if param_index < len(parameters):
+                        return _create_literal_node(parameters[param_index], json_serializer)
+            except (ValueError, IndexError, AttributeError):
+                pass
+        return node
+
+    return node
+
+
+def _get_bq_param_type(value: Any) -> tuple[str | None, str | None]:
     """Determine BigQuery parameter type from Python value.

     Args:
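Annotation: the placeholder-to-literal helpers above are now module-level and serializer-aware. A minimal standalone sketch of the same sqlglot pattern, with json.dumps standing in for the configurable serializer (the function name here is illustrative, not part of the package):

    import json

    import sqlglot
    import sqlglot.expressions as exp

    def inline_positional_params(sql: str, params: list) -> str:
        counter = {"index": 0}

        def replace(node: exp.Expression) -> exp.Expression:
            # Mirror the ? handling: consume one positional value per placeholder.
            if isinstance(node, exp.Placeholder) and counter["index"] < len(params):
                value = params[counter["index"]]
                counter["index"] += 1
                if isinstance(value, str):
                    return exp.Literal.string(value)
                if isinstance(value, dict):
                    return exp.Literal.string(json.dumps(value))
                return exp.Literal.number(str(value))
            return node

        ast = sqlglot.parse_one(sql, dialect="bigquery")
        return ast.transform(replace).sql(dialect="bigquery")

    # SELECT * FROM t WHERE id = 42 AND name = 'alice'
    print(inline_positional_params("SELECT * FROM t WHERE id = ? AND name = ?", [42, "alice"]))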
@@ -84,20 +226,30 @@ def _get_bq_param_type(value: Any) -> tuple[Optional[str], Optional[str]]:
     return None, None


-_BQ_PARAM_CREATOR_MAP: dict[str, Any] = {
-    "ARRAY": lambda name, value, array_type: ArrayQueryParameter(
-        name, array_type, [] if value is None else list(value)
-    ),
-    "JSON": lambda name, value, _: ScalarQueryParameter(name, "STRING", to_json(value)),
-    "SCALAR": lambda name, value, param_type: ScalarQueryParameter(name, param_type, value),
-}
+def _get_bq_param_creator_map(json_serializer: "Callable[[Any], str]") -> dict[str, Any]:
+    """Get BigQuery parameter creator map with configurable JSON serializer.
+
+    Args:
+        json_serializer: Function to serialize dict/list to JSON string.
+
+    Returns:
+        Dictionary mapping parameter types to creator functions.
+    """
+    return {
+        "ARRAY": _create_array_parameter,
+        "JSON": lambda name, value, _: _create_json_parameter(name, value, json_serializer),
+        "SCALAR": _create_scalar_parameter,
+    }


-def _create_bq_parameters(parameters: Any) -> "list[Union[ArrayQueryParameter, ScalarQueryParameter]]":
+def _create_bq_parameters(
+    parameters: Any, json_serializer: "Callable[[Any], str]"
+) -> "list[ArrayQueryParameter | ScalarQueryParameter]":
     """Create BigQuery QueryParameter objects from parameters.

     Args:
         parameters: Dict of named parameters or list of positional parameters
+        json_serializer: Function to serialize dict/list to JSON string

     Returns:
         List of BigQuery QueryParameter objects
@@ -105,7 +257,8 @@ def _create_bq_parameters(parameters: Any) -> "list[Union[ArrayQueryParameter, S
     if not parameters:
         return []

-    bq_parameters: list[Union[ArrayQueryParameter, ScalarQueryParameter]] = []
+    bq_parameters: list[ArrayQueryParameter | ScalarQueryParameter] = []
+    param_creator_map = _get_bq_param_creator_map(json_serializer)

     if isinstance(parameters, dict):
         for name, value in parameters.items():
@@ -114,13 +267,13 @@ def _create_bq_parameters(parameters: Any) -> "list[Union[ArrayQueryParameter, S
             param_type, array_element_type = _get_bq_param_type(actual_value)

             if param_type == "ARRAY" and array_element_type:
-                creator = _BQ_PARAM_CREATOR_MAP["ARRAY"]
+                creator = param_creator_map["ARRAY"]
                 bq_parameters.append(creator(param_name_for_bq, actual_value, array_element_type))
             elif param_type == "JSON":
-                creator = _BQ_PARAM_CREATOR_MAP["JSON"]
+                creator = param_creator_map["JSON"]
                 bq_parameters.append(creator(param_name_for_bq, actual_value, None))
             elif param_type:
-                creator = _BQ_PARAM_CREATOR_MAP["SCALAR"]
+                creator = param_creator_map["SCALAR"]
                 bq_parameters.append(creator(param_name_for_bq, actual_value, param_type))
             else:
                 msg = f"Unsupported BigQuery parameter type for value of param '{name}': {type(actual_value)}"
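Annotation: the three creator paths resolve to the google-cloud-bigquery parameter classes directly. For reference, a hedged sketch of what each path produces (requires google-cloud-bigquery; json.dumps stands in for the configurable json_serializer):

    import json

    from google.cloud.bigquery import ArrayQueryParameter, ScalarQueryParameter

    # SCALAR path: plain Python scalars carry the type inferred by _get_bq_param_type.
    age = ScalarQueryParameter("age", "INT64", 18)
    # ARRAY path: the element type comes from _get_bq_param_type's second slot.
    tags = ArrayQueryParameter("tags", "STRING", ["a", "b"])
    # JSON path: dicts are serialized and shipped as STRING.
    profile = ScalarQueryParameter("profile", "STRING", json.dumps({"tier": "pro"}))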
@@ -133,21 +286,33 @@ def _create_bq_parameters(parameters: Any) -> "list[Union[ArrayQueryParameter, S
     return bq_parameters


-bigquery_type_coercion_map = {
-    tuple: list,
-    bool: lambda x: x,
-    int: lambda x: x,
-    float: lambda x: x,
-    str: _type_converter.convert_if_detected,
-    bytes: lambda x: x,
-    datetime.datetime: lambda x: x,
-    datetime.date: lambda x: x,
-    datetime.time: lambda x: x,
-    Decimal: lambda x: x,
-    dict: lambda x: x,
-    list: lambda x: x,
-    type(None): lambda _: None,
-}
+def _get_bigquery_type_coercion_map(type_converter: BigQueryTypeConverter) -> dict[type, Any]:
+    """Get BigQuery type coercion map with configurable type converter.
+
+    Args:
+        type_converter: BigQuery type converter instance
+
+    Returns:
+        Type coercion map for BigQuery
+    """
+    return {
+        tuple: list,
+        bool: lambda x: x,
+        int: lambda x: x,
+        float: lambda x: x,
+        str: type_converter.convert_if_detected,
+        bytes: lambda x: x,
+        datetime.datetime: lambda x: x,
+        datetime.date: lambda x: x,
+        datetime.time: lambda x: x,
+        Decimal: lambda x: x,
+        dict: lambda x: x,
+        list: lambda x: x,
+        type(None): lambda _: None,
+    }
+
+
+bigquery_type_coercion_map = _get_bigquery_type_coercion_map(_default_type_converter)


 bigquery_statement_config = StatementConfig(
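Annotation: because the coercion map is now produced by a factory, a map can be built around a converter with different behavior. A small sketch, assuming the enable_uuid_conversion keyword that the new driver __init__ below passes to BigQueryTypeConverter:

    from sqlspec.adapters.bigquery.type_converter import BigQueryTypeConverter

    # str values route through convert_if_detected on this converter instance,
    # here with UUID detection turned off.
    converter = BigQueryTypeConverter(enable_uuid_conversion=False)
    coercion_map = _get_bigquery_type_coercion_map(converter)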
@@ -176,7 +341,7 @@ class BigQueryCursor:

     def __init__(self, connection: "BigQueryConnection") -> None:
         self.connection = connection
-        self.job: Optional[QueryJob] = None
+        self.job: QueryJob | None = None

     def __enter__(self) -> "BigQueryConnection":
         return self.connection
@@ -195,7 +360,11 @@


 class BigQueryExceptionHandler:
-    """Custom sync context manager for handling BigQuery database exceptions."""
+    """Context manager for handling BigQuery API exceptions.
+
+    Maps HTTP status codes and error reasons to specific SQLSpec exceptions
+    for better error handling in application code.
+    """

     __slots__ = ()

@@ -203,28 +372,82 @@ class BigQueryExceptionHandler:
         return None

     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        _ = exc_tb
         if exc_type is None:
             return
-
         if issubclass(exc_type, GoogleCloudError):
-            e = exc_val
-            error_msg = str(e).lower()
-            if "syntax" in error_msg or "invalid" in error_msg:
-                msg = f"BigQuery SQL syntax error: {e}"
-                raise SQLParsingError(msg) from e
-            if "permission" in error_msg or "access" in error_msg:
-                msg = f"BigQuery access error: {e}"
-                raise SQLSpecError(msg) from e
-            msg = f"BigQuery cloud error: {e}"
-            raise SQLSpecError(msg) from e
-        if issubclass(exc_type, Exception):
-            e = exc_val
-            error_msg = str(e).lower()
-            if "parse" in error_msg or "syntax" in error_msg:
-                msg = f"SQL parsing failed: {e}"
-                raise SQLParsingError(msg) from e
-            msg = f"Unexpected BigQuery operation error: {e}"
-            raise SQLSpecError(msg) from e
+            self._map_bigquery_exception(exc_val)
+
+    def _map_bigquery_exception(self, e: Any) -> None:
+        """Map BigQuery exception to SQLSpec exception.
+
+        Args:
+            e: Google API exception instance
+        """
+        status_code = getattr(e, "code", None)
+        error_msg = str(e).lower()
+
+        if status_code == HTTP_CONFLICT or "already exists" in error_msg:
+            self._raise_unique_violation(e, status_code)
+        elif status_code == HTTP_NOT_FOUND or "not found" in error_msg:
+            self._raise_not_found_error(e, status_code)
+        elif status_code == HTTP_BAD_REQUEST:
+            self._handle_bad_request(e, status_code, error_msg)
+        elif status_code == HTTP_FORBIDDEN:
+            self._raise_connection_error(e, status_code)
+        elif status_code and status_code >= HTTP_SERVER_ERROR:
+            self._raise_operational_error(e, status_code)
+        else:
+            self._raise_generic_error(e, status_code)
+
+    def _handle_bad_request(self, e: Any, code: "int | None", error_msg: str) -> None:
+        """Handle 400 Bad Request errors.
+
+        Args:
+            e: Exception instance
+            code: HTTP status code
+            error_msg: Lowercase error message
+        """
+        if "syntax" in error_msg or "invalid query" in error_msg:
+            self._raise_parsing_error(e, code)
+        elif "type" in error_msg or "format" in error_msg:
+            self._raise_data_error(e, code)
+        else:
+            self._raise_generic_error(e, code)
+
+    def _raise_unique_violation(self, e: Any, code: "int | None") -> None:
+        code_str = f"[HTTP {code}]" if code else ""
+        msg = f"BigQuery resource already exists {code_str}: {e}"
+        raise UniqueViolationError(msg) from e
+
+    def _raise_not_found_error(self, e: Any, code: "int | None") -> None:
+        code_str = f"[HTTP {code}]" if code else ""
+        msg = f"BigQuery resource not found {code_str}: {e}"
+        raise NotFoundError(msg) from e
+
+    def _raise_parsing_error(self, e: Any, code: "int | None") -> None:
+        code_str = f"[HTTP {code}]" if code else ""
+        msg = f"BigQuery query syntax error {code_str}: {e}"
+        raise SQLParsingError(msg) from e
+
+    def _raise_data_error(self, e: Any, code: "int | None") -> None:
+        code_str = f"[HTTP {code}]" if code else ""
+        msg = f"BigQuery data error {code_str}: {e}"
+        raise DataError(msg) from e
+
+    def _raise_connection_error(self, e: Any, code: "int | None") -> None:
+        code_str = f"[HTTP {code}]" if code else ""
+        msg = f"BigQuery permission denied {code_str}: {e}"
+        raise DatabaseConnectionError(msg) from e
+
+    def _raise_operational_error(self, e: Any, code: "int | None") -> None:
+        code_str = f"[HTTP {code}]" if code else ""
+        msg = f"BigQuery operational error {code_str}: {e}"
+        raise OperationalError(msg) from e
+
+    def _raise_generic_error(self, e: Any, code: "int | None") -> None:
+        msg = f"BigQuery error [HTTP {code}]: {e}" if code else f"BigQuery error: {e}"
+        raise SQLSpecError(msg) from e


 class BigQueryDriver(SyncDriverAdapterBase):
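Annotation: with the status-code mapping, application code can branch on exception type instead of scraping message strings. A hedged usage sketch, assuming statements run through a driver instance so the handler wraps the Google API call:

    from sqlspec.exceptions import NotFoundError, UniqueViolationError

    try:
        driver.execute("SELECT * FROM dataset.missing_table")  # hypothetical table
    except NotFoundError:
        ...  # HTTP 404 now surfaces as NotFoundError rather than a generic SQLSpecError
    except UniqueViolationError:
        ...  # HTTP 409, e.g. CREATE TABLE on an existing table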
@@ -234,29 +457,53 @@ class BigQueryDriver(SyncDriverAdapterBase):
     type coercion, error handling, and query job management.
     """

-    __slots__ = ("_data_dictionary", "_default_query_job_config")
+    __slots__ = ("_data_dictionary", "_default_query_job_config", "_json_serializer", "_type_converter")
     dialect = "bigquery"

     def __init__(
         self,
         connection: BigQueryConnection,
-        statement_config: "Optional[StatementConfig]" = None,
-        driver_features: "Optional[dict[str, Any]]" = None,
+        statement_config: "StatementConfig | None" = None,
+        driver_features: "dict[str, Any] | None" = None,
     ) -> None:
+        features = driver_features or {}
+
+        json_serializer = features.get("json_serializer")
+        if json_serializer is None:
+            json_serializer = to_json
+
+        self._json_serializer: Callable[[Any], str] = json_serializer
+
+        enable_uuid_conversion = features.get("enable_uuid_conversion", True)
+        self._type_converter = BigQueryTypeConverter(enable_uuid_conversion=enable_uuid_conversion)
+
         if statement_config is None:
             cache_config = get_cache_config()
-            statement_config = bigquery_statement_config.replace(
-                enable_caching=cache_config.compiled_cache_enabled,
+            type_coercion_map = _get_bigquery_type_coercion_map(self._type_converter)
+
+            param_config = ParameterStyleConfig(
+                default_parameter_style=ParameterStyle.NAMED_AT,
+                supported_parameter_styles={ParameterStyle.NAMED_AT, ParameterStyle.QMARK},
+                default_execution_parameter_style=ParameterStyle.NAMED_AT,
+                supported_execution_parameter_styles={ParameterStyle.NAMED_AT},
+                type_coercion_map=type_coercion_map,
+                has_native_list_expansion=True,
+                needs_static_script_compilation=False,
+                preserve_original_params_for_many=True,
+            )
+
+            statement_config = StatementConfig(
+                dialect="bigquery",
+                parameter_config=param_config,
                 enable_parsing=True,
                 enable_validation=True,
-                dialect="bigquery",
+                enable_caching=cache_config.compiled_cache_enabled,
+                enable_parameter_type_wrapping=True,
             )

         super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
-        self._default_query_job_config: Optional[QueryJobConfig] = (driver_features or {}).get(
-            "default_query_job_config"
-        )
-        self._data_dictionary: Optional[SyncDataDictionaryBase] = None
+        self._default_query_job_config: QueryJobConfig | None = (driver_features or {}).get("default_query_job_config")
+        self._data_dictionary: SyncDataDictionaryBase | None = None

     def with_cursor(self, connection: "BigQueryConnection") -> "BigQueryCursor":
         """Create context manager for cursor management.
@@ -279,20 +526,39 @@ class BigQueryDriver(SyncDriverAdapterBase):
         """Handle database-specific exceptions and wrap them appropriately."""
         return BigQueryExceptionHandler()

+    def _should_copy_attribute(self, attr: str, source_config: QueryJobConfig) -> bool:
+        """Check if attribute should be copied between job configs.
+
+        Args:
+            attr: Attribute name to check.
+            source_config: Source configuration object.
+
+        Returns:
+            True if attribute should be copied, False otherwise.
+        """
+        if attr.startswith("_"):
+            return False
+
+        try:
+            value = getattr(source_config, attr)
+            return value is not None and not callable(value)
+        except (AttributeError, TypeError):
+            return False
+
     def _copy_job_config_attrs(self, source_config: QueryJobConfig, target_config: QueryJobConfig) -> None:
         """Copy non-private attributes from source config to target config.

         Args:
-            source_config: Configuration to copy attributes from
-            target_config: Configuration to copy attributes to
+            source_config: Configuration to copy attributes from.
+            target_config: Configuration to copy attributes to.
         """
         for attr in dir(source_config):
-            if attr.startswith("_"):
+            if not self._should_copy_attribute(attr, source_config):
                 continue
+
             try:
                 value = getattr(source_config, attr)
-                if value is not None and not callable(value):
-                    setattr(target_config, attr, value)
+                setattr(target_config, attr, value)
             except (AttributeError, TypeError):
                 continue

@@ -300,8 +566,8 @@ class BigQueryDriver(SyncDriverAdapterBase):
         self,
         sql_str: str,
         parameters: Any,
-        connection: Optional[BigQueryConnection] = None,
-        job_config: Optional[QueryJobConfig] = None,
+        connection: BigQueryConnection | None = None,
+        job_config: QueryJobConfig | None = None,
     ) -> QueryJob:
         """Execute a BigQuery job with configuration support.

@@ -324,7 +590,7 @@
         if job_config:
             self._copy_job_config_attrs(job_config, final_job_config)

-        bq_parameters = _create_bq_parameters(parameters)
+        bq_parameters = _create_bq_parameters(parameters, self._json_serializer)
        final_job_config.query_parameters = bq_parameters

         return conn.query(sql_str, job_config=final_job_config)
@@ -341,7 +607,7 @@
         """
         return [dict(row) for row in rows_iterator]

-    def _try_special_handling(self, cursor: "Any", statement: "SQL") -> "Optional[SQLResult]":
+    def _try_special_handling(self, cursor: "Any", statement: "SQL") -> "SQLResult | None":
         """Hook for BigQuery-specific special operations.

         BigQuery doesn't have complex special operations like PostgreSQL COPY,
@@ -360,12 +626,15 @@
     def _transform_ast_with_literals(self, sql: str, parameters: Any) -> str:
         """Transform SQL AST by replacing placeholders with literal values.

+        Used for BigQuery script execution and execute_many operations where
+        parameter binding is not supported. Safely embeds values as SQL literals.
+
         Args:
-            sql: SQL string to transform
-            parameters: Parameters to embed as literals
+            sql: SQL string to transform.
+            parameters: Parameters to embed as literals.

         Returns:
-            Transformed SQL string with literals embedded
+            Transformed SQL string with literals embedded.
         """
         if not parameters:
             return sql
@@ -377,70 +646,12 @@

         placeholder_counter = {"index": 0}

-        def replace_placeholder(node: exp.Expression) -> exp.Expression:
-            """Replace placeholder nodes with literal values."""
-            if isinstance(node, exp.Placeholder):
-                if isinstance(parameters, (list, tuple)):
-                    current_index = placeholder_counter["index"]
-                    placeholder_counter["index"] += 1
-                    if current_index < len(parameters):
-                        return self._create_literal_node(parameters[current_index])
-                return node
-            if isinstance(node, exp.Parameter):
-                param_name = str(node.this) if hasattr(node.this, "__str__") else node.this
-                if isinstance(parameters, dict):
-                    possible_names = [param_name, f"@{param_name}", f":{param_name}", f"param_{param_name}"]
-                    for name in possible_names:
-                        if name in parameters:
-                            actual_value = getattr(parameters[name], "value", parameters[name])
-                            return self._create_literal_node(actual_value)
-                    return node
-                if isinstance(parameters, (list, tuple)):
-                    try:
-                        if param_name.startswith("param_"):
-                            param_index = int(param_name[6:])
-                            if param_index < len(parameters):
-                                return self._create_literal_node(parameters[param_index])
-
-                        if param_name.isdigit():
-                            param_index = int(param_name)
-                            if param_index < len(parameters):
-                                return self._create_literal_node(parameters[param_index])
-                    except (ValueError, IndexError, AttributeError):
-                        pass
-                return node
-            return node
-
-        transformed_ast = ast.transform(replace_placeholder)
+        transformed_ast = ast.transform(
+            lambda node: _replace_placeholder_node(node, parameters, placeholder_counter, self._json_serializer)
+        )

         return transformed_ast.sql(dialect="bigquery")

-    def _create_literal_node(self, value: Any) -> "exp.Expression":
-        """Create a SQLGlot literal expression from a Python value.
-
-        Args:
-            value: Python value to convert to SQLGlot literal
-
-        Returns:
-            SQLGlot expression representing the literal value
-        """
-        if value is None:
-            return exp.Null()
-        if isinstance(value, bool):
-            return exp.Boolean(this=value)
-        if isinstance(value, (int, float)):
-            return exp.Literal.number(str(value))
-        if isinstance(value, str):
-            return exp.Literal.string(value)
-        if isinstance(value, (list, tuple)):
-            items = [self._create_literal_node(item) for item in value]
-            return exp.Array(expressions=items)
-        if isinstance(value, dict):
-            json_str = to_json(value)
-            return exp.Literal.string(json_str)
-
-        return exp.Literal.string(str(value))
-
     def _execute_script(self, cursor: Any, statement: "SQL") -> ExecutionResult:
         """Execute SQL script with statement splitting and parameter handling.

@@ -550,3 +761,137 @@

         self._data_dictionary = BigQuerySyncDataDictionary()
         return self._data_dictionary
+
+    def _storage_api_available(self) -> bool:
+        """Check if BigQuery Storage API is available.
+
+        Returns:
+            True if Storage API is available and working, False otherwise
+        """
+        try:
+            from google.cloud import bigquery_storage_v1  # type: ignore[attr-defined]
+
+            # Try to create client (will fail if API not enabled or credentials missing)
+            _ = bigquery_storage_v1.BigQueryReadClient()
+        except ImportError:
+            # Package not installed
+            return False
+        except Exception:
+            # API not enabled or permissions issue
+            return False
+        else:
+            return True
+
+    def select_to_arrow(
+        self,
+        statement: "Statement | QueryBuilder",
+        /,
+        *parameters: "StatementParameters | StatementFilter",
+        statement_config: "StatementConfig | None" = None,
+        return_format: str = "table",
+        native_only: bool = False,
+        batch_size: int | None = None,
+        arrow_schema: Any = None,
+        **kwargs: Any,
+    ) -> "ArrowResult":
+        """Execute query and return results as Apache Arrow (BigQuery native with Storage API).
+
+        BigQuery provides native Arrow via Storage API (query_job.to_arrow()).
+        Requires google-cloud-bigquery-storage package and API enabled.
+        Falls back to dict conversion if Storage API not available.
+
+        Args:
+            statement: SQL statement, string, or QueryBuilder
+            *parameters: Query parameters or filters
+            statement_config: Optional statement configuration override
+            return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch
+            native_only: If True, raise error if Storage API unavailable (default: False)
+            batch_size: Batch size hint (for future streaming implementation)
+            arrow_schema: Optional pyarrow.Schema for type casting
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            ArrowResult with native Arrow data (if Storage API available) or converted data
+
+        Raises:
+            MissingDependencyError: If pyarrow not installed, or if Storage API not available and native_only=True
+            SQLExecutionError: If query execution fails
+
+        Example:
+            >>> # Will use native Arrow if Storage API available, otherwise converts
+            >>> result = driver.select_to_arrow(
+            ...     "SELECT * FROM dataset.users WHERE age > @age",
+            ...     {"age": 18},
+            ... )
+            >>> df = result.to_pandas()
+
+            >>> # Force native Arrow (raises if Storage API unavailable)
+            >>> result = driver.select_to_arrow(
+            ...     "SELECT * FROM dataset.users", native_only=True
+            ... )
+        """
+        from sqlspec.utils.module_loader import ensure_pyarrow
+
+        ensure_pyarrow()
+
+        # Check Storage API availability
+        if not self._storage_api_available():
+            if native_only:
+                from sqlspec.exceptions import MissingDependencyError
+
+                msg = (
+                    "BigQuery native Arrow requires Storage API.\n"
+                    "1. Install: pip install google-cloud-bigquery-storage\n"
+                    "2. Enable API: https://console.cloud.google.com/apis/library/bigquerystorage.googleapis.com\n"
+                    "3. Grant permissions: roles/bigquery.dataViewer"
+                )
+                raise MissingDependencyError(
+                    package="google-cloud-bigquery-storage", install_package="google-cloud-bigquery-storage"
+                ) from RuntimeError(msg)
+
+            # Fallback to conversion path
+            result: ArrowResult = super().select_to_arrow(
+                statement,
+                *parameters,
+                statement_config=statement_config,
+                return_format=return_format,
+                native_only=native_only,
+                batch_size=batch_size,
+                arrow_schema=arrow_schema,
+                **kwargs,
+            )
+            return result
+
+        # Use native path with Storage API
+        import pyarrow as pa
+
+        from sqlspec.core.result import create_arrow_result
+
+        # Prepare statement
+        config = statement_config or self.statement_config
+        prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
+
+        # Get compiled SQL and parameters
+        sql, driver_params = self._get_compiled_sql(prepared_statement, config)
+
+        # Execute query using existing _run_query_job method
+        with self.handle_database_exceptions():
+            query_job = self._run_query_job(sql, driver_params)
+            query_job.result()  # Wait for completion
+
+            # Native Arrow via Storage API
+            arrow_table = query_job.to_arrow()
+
+        # Apply schema casting if requested
+        if arrow_schema is not None:
+            arrow_table = arrow_table.cast(arrow_schema)
+
+        # Convert to batch if requested
+        if return_format == "batch":
+            batches = arrow_table.to_batches()
+            arrow_data: Any = batches[0] if batches else pa.RecordBatch.from_pydict({})
+        else:
+            arrow_data = arrow_table
+
+        # Create ArrowResult
+        return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows)
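Annotation: the same call works whether or not the Storage API is enabled; only the transport differs. A hedged usage sketch (driver instance and dataset names are assumptions; to_pandas is the conversion helper shown in the docstring above):

    # Assumes a configured BigQueryDriver `driver` and pyarrow installed.
    result = driver.select_to_arrow(
        "SELECT name, age FROM dataset.users WHERE age > @age",
        {"age": 18},
    )
    df = result.to_pandas()
    print(df.head())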