sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -621
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -431
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +218 -436
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +417 -487
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +600 -553
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +392 -406
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +548 -921
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -533
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +725 -0
  31. sqlspec/adapters/psycopg/driver.py +734 -694
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +242 -405
  35. sqlspec/base.py +220 -784
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
  137. sqlspec-0.12.1.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -331
  150. sqlspec/mixins.py +0 -305
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.1.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,338 @@
1
+ """Common driver attributes and utilities."""
2
+
3
+ import re
4
+ from abc import ABC
5
+ from collections.abc import Mapping, Sequence
6
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional
7
+
8
+ import sqlglot
9
+ from sqlglot import exp
10
+ from sqlglot.tokens import TokenType
11
+
12
+ from sqlspec.exceptions import NotFoundError
13
+ from sqlspec.statement import SQLConfig
14
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
15
+ from sqlspec.statement.splitter import split_sql_script
16
+ from sqlspec.typing import ConnectionT, DictRow, RowT, T
17
+ from sqlspec.utils.logging import get_logger
18
+
19
+ if TYPE_CHECKING:
20
+ from sqlglot.dialects.dialect import DialectType
21
+
22
+
23
+ __all__ = ("CommonDriverAttributesMixin",)
24
+
25
+
26
+ logger = get_logger("driver")
27
+
28
+
29
class CommonDriverAttributesMixin(ABC, Generic[ConnectionT, RowT]):
    """Common attributes and methods for driver adapters.

    Concrete drivers set ``dialect``, ``supported_parameter_styles`` and
    ``default_parameter_style``. The ``supports_native_*`` capability flags
    default to ``False`` and are overridden by drivers that support them.
    """

    __slots__ = ("config", "connection", "default_row_type")

    dialect: "DialectType"
    """The SQL dialect supported by the underlying database driver."""
    supported_parameter_styles: "tuple[ParameterStyle, ...]"
    """The parameter styles supported by this driver."""
    default_parameter_style: "ParameterStyle"
    """The default parameter style to convert to when unsupported style is detected."""
    supports_native_parquet_export: "ClassVar[bool]" = False
    """Indicates if the driver supports native Parquet export operations."""
    supports_native_parquet_import: "ClassVar[bool]" = False
    """Indicates if the driver supports native Parquet import operations."""
    supports_native_arrow_export: "ClassVar[bool]" = False
    """Indicates if the driver supports native Arrow export operations."""
    supports_native_arrow_import: "ClassVar[bool]" = False
    """Indicates if the driver supports native Arrow import operations."""

    # Expression node types that are treated as always producing a result set.
    _ROW_RETURNING_EXPRESSIONS: "ClassVar[tuple[type[exp.Expression], ...]]" = (
        exp.Select,
        exp.Values,
        exp.Table,
        exp.Show,
        exp.Describe,
        exp.Pragma,
        exp.Command,
    )
    # Leading keywords that mark a row-returning statement in the tokenizer fallback.
    _ROW_RETURNING_TOKEN_TYPES: "ClassVar[frozenset[TokenType]]" = frozenset(
        {
            TokenType.SELECT,
            TokenType.WITH,
            TokenType.VALUES,
            TokenType.TABLE,
            TokenType.SHOW,
            TokenType.DESCRIBE,
            TokenType.PRAGMA,
        }
    )
    # Tokens skipped while searching for the leading keyword.
    _SKIPPED_TOKEN_TYPES: "ClassVar[frozenset[TokenType]]" = frozenset({TokenType.COMMENT, TokenType.SEMICOLON})
    # Common SQL placeholders: ?, %s, $1, $2, :name, %(name)s. Compiled once rather than per call.
    _PLACEHOLDER_PATTERN: "ClassVar[re.Pattern[str]]" = re.compile(r"(\?|%s|\$\d+|:\w+|%\(\w+\)s)")
    # Placeholder styles for which the driver expects parameters as a mapping.
    _DICT_PARAMETER_STYLES: "ClassVar[frozenset[ParameterStyle]]" = frozenset(
        {
            ParameterStyle.NAMED_COLON,
            ParameterStyle.POSITIONAL_COLON,
            ParameterStyle.NAMED_AT,
            ParameterStyle.NAMED_DOLLAR,
            ParameterStyle.NAMED_PYFORMAT,
        }
    )

    def __init__(
        self,
        connection: "ConnectionT",
        config: "Optional[SQLConfig]" = None,
        default_row_type: "type[DictRow]" = dict[str, Any],
    ) -> None:
        """Initialize with connection, config, and default_row_type.

        Args:
            connection: The database connection
            config: SQL statement configuration; a default ``SQLConfig()`` is used when omitted.
            default_row_type: Default row type for results (DictRow, TupleRow, etc.)
        """
        self.connection = connection
        self.config = config or SQLConfig()
        self.default_row_type = default_row_type or dict[str, Any]

    def _connection(self, connection: "Optional[ConnectionT]" = None) -> "ConnectionT":
        """Return ``connection`` when one is given, otherwise the driver's own connection.

        An explicit ``is not None`` check is used so that a connection object
        which happens to be falsy (e.g. defines ``__len__``/``__bool__``) is
        still honored instead of being silently replaced.
        """
        return connection if connection is not None else self.connection

    def returns_rows(self, expression: "Optional[exp.Expression]") -> bool:
        """Check if the SQL expression is expected to return rows.

        Args:
            expression: The parsed SQL expression, or ``None``.

        Returns:
            True for SELECT, VALUES, TABLE, SHOW, DESCRIBE, PRAGMA and generic
            ``Command`` nodes; for a top-level WITH, the result for its last
            expression; for INSERT/UPDATE/DELETE, whether a RETURNING clause is
            present; for unparsed ``Anonymous`` nodes, a best-effort heuristic.
            False for ``None`` and everything else.
        """
        if expression is None:
            return False
        if isinstance(expression, self._ROW_RETURNING_EXPRESSIONS):
            return True
        if isinstance(expression, exp.With) and expression.expressions:
            return self.returns_rows(expression.expressions[-1])
        if isinstance(expression, (exp.Insert, exp.Update, exp.Delete)):
            return bool(expression.find(exp.Returning))
        # Handle Anonymous expressions (failed to parse) using a best-effort approach.
        if isinstance(expression, exp.Anonymous):
            return self._check_anonymous_returns_rows(expression)
        return False

    def _check_anonymous_returns_rows(self, expression: "exp.Anonymous") -> bool:
        """Check if an Anonymous expression returns rows using best-effort methods.

        This method handles SQL that failed to parse (often due to database-specific
        placeholders) by:
        1. Attempting to re-parse with placeholders replaced by a literal ``1``
        2. Using the tokenizer as a fallback to inspect the leading keyword

        Args:
            expression: The Anonymous expression to check

        Returns:
            True if the expression likely returns rows
        """
        sql_text = str(expression.this) if expression.this else ""
        if not sql_text.strip():
            return False

        # Approach 1: re-parse with placeholders replaced by a dummy literal.
        try:
            sanitized_sql = self._PLACEHOLDER_PATTERN.sub("1", sql_text)
            # Only re-parse when a placeholder was actually replaced; otherwise the
            # text is unchanged and parsing would yield the same result as before.
            if sanitized_sql != sql_text:
                parsed = sqlglot.parse_one(sanitized_sql, read=None)
                # Anything that parsed into a concrete node is classified by the
                # main dispatcher; another Anonymous falls through to the tokenizer.
                if not isinstance(parsed, exp.Anonymous):
                    return self.returns_rows(parsed)
        except Exception:
            logger.debug("Could not parse using placeholders. Using tokenizer. %s", sql_text)

        # Approach 2: inspect the first meaningful token.
        try:
            tokens = sqlglot.tokenize(sql_text, read=None)
        except Exception:
            return False
        for token in tokens:
            if token.token_type in self._SKIPPED_TOKEN_TYPES:
                continue
            return token.token_type in self._ROW_RETURNING_TOKEN_TYPES
        return False

    @staticmethod
    def check_not_found(item_or_none: "Optional[T]" = None) -> "T":
        """Raise :exc:`sqlspec.exceptions.NotFoundError` if ``item_or_none`` is ``None``.

        Args:
            item_or_none: Item to be tested for existence.

        Raises:
            NotFoundError: If ``item_or_none`` is ``None``

        Returns:
            The item, if it exists.
        """
        if item_or_none is None:
            msg = "No result found when one was expected"
            raise NotFoundError(msg)
        return item_or_none

    def _convert_parameters_to_driver_format(  # noqa: C901
        self, sql: str, parameters: Any, target_style: "Optional[ParameterStyle]" = None
    ) -> Any:
        """Convert parameters to the format expected by the driver, but only when necessary.

        This method analyzes the SQL to understand what parameter style is used
        and only converts when there's a mismatch between provided parameters
        and what the driver expects.

        Args:
            sql: The SQL string with placeholders
            parameters: The parameters in any format (dict, list, tuple, scalar)
            target_style: Optional override for the target parameter style. A single,
                unambiguous style detected in ``sql`` takes precedence over it.

        Returns:
            Parameters in the format expected by the database driver, or ``None``
            when ``parameters`` is ``None`` or the SQL contains no placeholders.
        """
        if parameters is None:
            return None

        param_info_list = ParameterValidator().extract_parameters(sql)
        if not param_info_list:
            # No placeholders in the SQL, so there is nothing to bind.
            return None

        if target_style is None:
            target_style = self.default_parameter_style

        # If the SQL uses exactly one placeholder style, the parameters must match it.
        actual_styles = {p.style for p in param_info_list if p.style}
        if len(actual_styles) == 1:
            target_style = next(iter(actual_styles))

        driver_expects_dict = target_style in self._DICT_PARAMETER_STYLES

        # dict/list/tuple are Mapping/Sequence subclasses; str/bytes are sequences
        # of characters, not parameter lists, so they are treated as scalars.
        params_are_dict = isinstance(parameters, Mapping)
        params_are_sequence = isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes))

        def _is_generated_key(key: Any) -> bool:
            # Keys generated for unnamed placeholders look like ``param_0``, ``param_1``, ...
            # Non-string keys (e.g. ``{0: ...}``) must not raise AttributeError here.
            return isinstance(key, str) and key.startswith("param_") and key[6:].isdigit()

        # Single scalar parameter.
        if len(param_info_list) == 1 and not params_are_dict and not params_are_sequence:
            if not driver_expects_dict:
                return [parameters]
            param_info = param_info_list[0]
            return {param_info.name or f"param_{param_info.ordinal}": parameters}

        if driver_expects_dict and params_are_dict:
            if target_style == ParameterStyle.POSITIONAL_COLON and all(
                p.name and p.name.isdigit() for p in param_info_list
            ):
                # SQL has numeric placeholders (e.g. Oracle ``:1``) but the mapping may
                # use other keys; re-key by position only when the expected keys are missing.
                numeric_keys_expected = {p.name for p in param_info_list if p.name}
                if not numeric_keys_expected.issubset(parameters.keys()):
                    param_values = list(parameters.values())
                    numeric_result: dict[str, Any] = {
                        p.name: param_values[p.ordinal]
                        for p in param_info_list
                        if p.name and p.ordinal < len(param_values)
                    }
                    return numeric_result
            # Otherwise the mapping already matches - return as-is.
            return parameters

        if not driver_expects_dict and params_are_sequence:
            # Formats match - return as-is.
            return parameters

        if driver_expects_dict and params_are_sequence:
            # Positional -> mapping: use the placeholder name from the SQL (string keys
            # even for numeric placeholders), or ``param_N`` for unnamed placeholders.
            # ``zip`` stops at the shorter of placeholders and values.
            dict_result: dict[str, Any] = {
                (param_info.name or f"param_{index}"): value
                for index, (param_info, value) in enumerate(zip(param_info_list, parameters))
            }
            return dict_result

        if not driver_expects_dict and params_are_dict:
            # Mapping -> positional. Generated ``param_N`` keys are read back in order.
            if all(_is_generated_key(key) for key in parameters):
                generated_keys = (f"param_{i}" for i in range(len(param_info_list)))
                return [parameters[key] for key in generated_keys if key in parameters]

            # Named mapping: follow placeholder order in the SQL, falling back to the
            # generated key and finally to the value at the same position.
            param_values = list(parameters.values())
            positional_params: list[Any] = []
            for param_info in param_info_list:
                generated_key = f"param_{param_info.ordinal}"
                if param_info.name and param_info.name in parameters:
                    positional_params.append(parameters[param_info.name])
                elif generated_key in parameters:
                    positional_params.append(parameters[generated_key])
                elif param_info.ordinal < len(param_values):
                    positional_params.append(param_values[param_info.ordinal])
            return positional_params or param_values

        # Not expected to be reached; pass the parameters through unchanged.
        return parameters

    def _split_script_statements(self, script: str, strip_trailing_semicolon: bool = False) -> list[str]:
        """Split a SQL script into individual statements.

        Delegates to ``split_sql_script`` using this driver's dialect (converted
        with ``str``).

        Args:
            script: The SQL script to split
            strip_trailing_semicolon: If True, remove trailing semicolons from statements

        Returns:
            A list of individual SQL statements

        Note:
            This is useful for databases that don't natively support
            multi-statement execution.
        """
        # The split_sql_script function handles dialect mapping and fallback.
        return split_sql_script(script, dialect=str(self.dialect), strip_trailing_semicolon=strip_trailing_semicolon)
@@ -0,0 +1,261 @@
1
+ """Synchronous driver protocol implementation."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
5
+
6
+ from sqlspec.driver._common import CommonDriverAttributesMixin
7
+ from sqlspec.statement.builder import DeleteBuilder, InsertBuilder, QueryBuilder, SelectBuilder, UpdateBuilder
8
+ from sqlspec.statement.filters import StatementFilter
9
+ from sqlspec.statement.result import SQLResult
10
+ from sqlspec.statement.sql import SQL, SQLConfig, Statement
11
+ from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
12
+ from sqlspec.utils.logging import get_logger
13
+
14
+ logger = get_logger("sqlspec")
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict
19
+
20
+ __all__ = ("SyncDriverAdapterProtocol",)
21
+
22
+
23
+ EMPTY_FILTERS: "list[StatementFilter]" = []
24
+
25
+
26
class SyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT], ABC):
    """Base class for synchronous driver adapters.

    Concrete drivers implement ``_execute_statement``, ``_wrap_select_result``
    and ``_wrap_execute_result``. This class builds the public ``execute``,
    ``execute_many`` and ``execute_script`` entry points on top of them.
    """

    __slots__ = ()

    def __init__(
        self,
        connection: "ConnectionT",
        config: "Optional[SQLConfig]" = None,
        default_row_type: "type[DictRow]" = DictRow,
    ) -> None:
        """Initialize sync driver adapter.

        Args:
            connection: The database connection
            config: SQL statement configuration
            default_row_type: Default row type for results (DictRow, TupleRow, etc.)
        """
        # Initialize CommonDriverAttributes part
        super().__init__(connection=connection, config=config, default_row_type=default_row_type)

    def _build_statement(
        self,
        statement: "Union[Statement, QueryBuilder[Any]]",
        *parameters: "Union[StatementParameters, StatementFilter]",
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQL":
        """Normalize ``statement`` into a ``SQL`` object.

        Query builders are converted with ``to_statement`` (the extra
        ``parameters``/``kwargs`` are not applied in that case). An existing
        ``SQL`` object is returned unchanged unless extra parameters or kwargs
        are given, in which case a new ``SQL`` is built from its raw SQL text.
        Anything else is wrapped in a new ``SQL`` using this driver's dialect.
        """
        # Use driver's config if none provided
        _config = _config or self.config

        if isinstance(statement, QueryBuilder):
            return statement.to_statement(config=_config)
        if isinstance(statement, SQL):
            if parameters or kwargs:
                # Create a new SQL object with the same SQL but the given parameters.
                return SQL(statement._sql, *parameters, _dialect=self.dialect, _config=_config, **kwargs)
            return statement
        return SQL(statement, *parameters, _dialect=self.dialect, _config=_config, **kwargs)

    @abstractmethod
    def _execute_statement(
        self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
    ) -> "Union[SelectResultDict, DMLResultDict, ScriptResultDict]":
        """Actual execution implementation by concrete drivers, using the raw connection.

        Returns one of the standardized result dictionaries based on the statement type.
        """
        raise NotImplementedError

    @abstractmethod
    def _wrap_select_result(
        self,
        statement: "SQL",
        result: "SelectResultDict",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Convert a row-returning result dictionary into an ``SQLResult``."""
        raise NotImplementedError

    @abstractmethod
    def _wrap_execute_result(
        self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
    ) -> "SQLResult[RowT]":
        """Convert a DML or script result dictionary into an ``SQLResult``."""
        raise NotImplementedError

    @overload
    def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    def execute(
        self,
        statement: "Union[InsertBuilder, UpdateBuilder, DeleteBuilder]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    def execute(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    def execute(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    def execute(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Execute a single statement and wrap the driver result.

        Row-returning statements (per ``returns_rows``) are wrapped with
        ``_wrap_select_result`` (honoring ``schema_type``); all others with
        ``_wrap_execute_result``. ``kwargs`` are forwarded to statement
        construction, execution and wrapping.
        """
        sql_statement = self._build_statement(statement, *parameters, _config=_config or self.config, **kwargs)
        result = self._execute_statement(statement=sql_statement, connection=self._connection(_connection), **kwargs)

        if self.returns_rows(sql_statement.expression):
            return self._wrap_select_result(
                sql_statement, cast("SelectResultDict", result), schema_type=schema_type, **kwargs
            )
        return self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    def execute_many(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute one statement once per parameter set.

        The first non-filter positional argument is taken as the sequence of
        parameter sets (converted to a list when it is iterable, otherwise
        dropped). ``StatementFilter`` arguments are forwarded to statement
        construction so they are applied instead of being silently ignored.
        """
        # Separate parameter sets from filters.
        param_sequences: list[Any] = []
        filters: list[StatementFilter] = []
        for param in parameters:
            if isinstance(param, StatementFilter):
                filters.append(param)
            else:
                param_sequences.append(param)

        # Use the first parameter as the sequence for execute_many.
        param_sequence = param_sequences[0] if param_sequences else None
        # Normalize to a list (tuples and other iterables included); non-iterables become None.
        if param_sequence is not None and not isinstance(param_sequence, list):
            param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None

        sql_statement = self._build_statement(
            statement, *filters, _config=_config or self.config, **kwargs
        ).as_many(param_sequence)

        result = self._execute_statement(
            statement=sql_statement,
            connection=self._connection(_connection),
            parameters=param_sequence,
            is_many=True,
            **kwargs,
        )
        return self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    def execute_script(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute a multi-statement script.

        Validation is disabled (and ``strict_mode`` turned off) for scripts by
        building a copy of the active config. If the driver returns a plain
        string, a synthetic single-statement ``SCRIPT`` result is produced;
        otherwise the result dictionary is wrapped by ``_wrap_execute_result``.
        """
        # Separate parameters from filters.
        param_values: list[Any] = []
        filters: list[StatementFilter] = []
        for param in parameters:
            if isinstance(param, StatementFilter):
                filters.append(param)
            else:
                param_values.append(param)

        # Use first parameter as the primary parameter value, or None if no parameters.
        primary_params = param_values[0] if param_values else None

        script_config = _config or self.config
        if script_config.enable_validation:
            script_config = SQLConfig(
                enable_parsing=script_config.enable_parsing,
                enable_validation=False,
                enable_transformations=script_config.enable_transformations,
                enable_analysis=script_config.enable_analysis,
                strict_mode=False,
                cache_parsed_expression=script_config.cache_parsed_expression,
                parameter_converter=script_config.parameter_converter,
                parameter_validator=script_config.parameter_validator,
                analysis_cache_size=script_config.analysis_cache_size,
                allowed_parameter_styles=script_config.allowed_parameter_styles,
                target_parameter_style=script_config.target_parameter_style,
                allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
            )

        sql_statement = SQL(statement, primary_params, *filters, _dialect=self.dialect, _config=script_config, **kwargs)
        sql_statement = sql_statement.as_script()
        script_output = self._execute_statement(
            statement=sql_statement, connection=self._connection(_connection), is_script=True, **kwargs
        )
        if isinstance(script_output, str):
            result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
            result.total_statements = 1
            result.successful_statements = 1
            return result
        # Wrap the ScriptResultDict using the driver's wrapper.
        return self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)
@@ -0,0 +1,17 @@
1
+ """Driver mixins for instrumentation, storage, and utilities."""
2
+
3
+ from sqlspec.driver.mixins._pipeline import AsyncPipelinedExecutionMixin, SyncPipelinedExecutionMixin
4
+ from sqlspec.driver.mixins._result_utils import ToSchemaMixin
5
+ from sqlspec.driver.mixins._sql_translator import SQLTranslatorMixin
6
+ from sqlspec.driver.mixins._storage import AsyncStorageMixin, SyncStorageMixin
7
+ from sqlspec.driver.mixins._type_coercion import TypeCoercionMixin
8
+
9
+ __all__ = (
10
+ "AsyncPipelinedExecutionMixin",
11
+ "AsyncStorageMixin",
12
+ "SQLTranslatorMixin",
13
+ "SyncPipelinedExecutionMixin",
14
+ "SyncStorageMixin",
15
+ "ToSchemaMixin",
16
+ "TypeCoercionMixin",
17
+ )