sqlspec-0.12.1-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -2,12 +2,14 @@ import contextlib
2
2
  import logging
3
3
  from collections.abc import Iterator
4
4
  from contextlib import contextmanager
5
+ from dataclasses import replace
5
6
  from decimal import Decimal
6
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
7
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
7
8
 
8
9
  from adbc_driver_manager.dbapi import Connection, Cursor
9
10
 
10
11
  from sqlspec.driver import SyncDriverAdapterProtocol
12
+ from sqlspec.driver.connection import managed_transaction_sync
11
13
  from sqlspec.driver.mixins import (
12
14
  SQLTranslatorMixin,
13
15
  SyncPipelinedExecutionMixin,
@@ -15,11 +17,12 @@ from sqlspec.driver.mixins import (
15
17
  ToSchemaMixin,
16
18
  TypeCoercionMixin,
17
19
  )
20
+ from sqlspec.driver.parameters import normalize_parameter_sequence
18
21
  from sqlspec.exceptions import wrap_exceptions
19
22
  from sqlspec.statement.parameters import ParameterStyle
20
- from sqlspec.statement.result import ArrowResult, DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
23
+ from sqlspec.statement.result import ArrowResult, SQLResult
21
24
  from sqlspec.statement.sql import SQL, SQLConfig
22
- from sqlspec.typing import DictRow, ModelDTOT, RowT, is_dict_with_field
25
+ from sqlspec.typing import DictRow, RowT
23
26
  from sqlspec.utils.serializers import to_json
24
27
 
25
28
  if TYPE_CHECKING:
@@ -65,8 +68,15 @@ class AdbcDriver(
65
68
  config: "Optional[SQLConfig]" = None,
66
69
  default_row_type: "type[DictRow]" = DictRow,
67
70
  ) -> None:
71
+ dialect = self._get_dialect(connection)
72
+ if config and not config.dialect:
73
+ config = replace(config, dialect=dialect)
74
+ elif not config:
75
+ # Create config with dialect
76
+ config = SQLConfig(dialect=dialect)
77
+
68
78
  super().__init__(connection=connection, config=config, default_row_type=default_row_type)
69
- self.dialect: DialectType = self._get_dialect(connection)
79
+ self.dialect: DialectType = dialect
70
80
  self.default_parameter_style = self._get_parameter_style_for_dialect(self.dialect)
71
81
  # Override supported parameter styles based on actual dialect capabilities
72
82
  self.supported_parameter_styles = self._get_supported_parameter_styles_for_dialect(self.dialect)
@@ -169,13 +179,13 @@ class AdbcDriver(
169
179
 
170
180
  def _execute_statement(
171
181
  self, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
172
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
182
+ ) -> SQLResult[RowT]:
173
183
  if statement.is_script:
174
184
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
175
185
  return self._execute_script(sql, connection=connection, **kwargs)
176
186
 
177
- # Determine if we need to convert parameter style
178
187
  detected_styles = {p.style for p in statement.parameter_info}
188
+
179
189
  target_style = self.default_parameter_style
180
190
  unsupported_styles = detected_styles - set(self.supported_parameter_styles)
181
191
 
@@ -196,69 +206,107 @@ class AdbcDriver(
196
206
 
197
207
  def _execute(
198
208
  self, sql: str, parameters: Any, statement: SQL, connection: Optional["AdbcConnection"] = None, **kwargs: Any
199
- ) -> Union[SelectResultDict, DMLResultDict]:
200
- conn = self._connection(connection)
201
- with self._get_cursor(conn) as cursor:
202
- # ADBC expects parameters as a list for most drivers
203
- if parameters is not None and not isinstance(parameters, (list, tuple)):
204
- cursor_params = [parameters]
209
+ ) -> SQLResult[RowT]:
210
+ # Use provided connection or driver's default connection
211
+ conn = connection if connection is not None else self._connection(None)
212
+
213
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
214
+ normalized_params = normalize_parameter_sequence(parameters)
215
+ if normalized_params is not None and not isinstance(normalized_params, (list, tuple)):
216
+ cursor_params = [normalized_params]
205
217
  else:
206
- cursor_params = parameters # type: ignore[assignment]
207
-
208
- try:
209
- cursor.execute(sql, cursor_params or [])
210
- except Exception as e:
211
- # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
212
- if self.dialect == "postgres":
213
- with contextlib.suppress(Exception):
214
- cursor.execute("ROLLBACK")
215
- raise e from e
216
-
217
- if self.returns_rows(statement.expression):
218
- fetched_data = cursor.fetchall()
219
- column_names = [col[0] for col in cursor.description or []]
220
- result: SelectResultDict = {
221
- "data": fetched_data,
222
- "column_names": column_names,
223
- "rows_affected": len(fetched_data),
224
- }
225
- return result
226
-
227
- dml_result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
228
- return dml_result
218
+ cursor_params = normalized_params
219
+
220
+ with self._get_cursor(txn_conn) as cursor:
221
+ try:
222
+ cursor.execute(sql, cursor_params or [])
223
+ except Exception as e:
224
+ # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
225
+ if self.dialect == "postgres":
226
+ with contextlib.suppress(Exception):
227
+ cursor.execute("ROLLBACK")
228
+ raise e from e
229
+
230
+ if self.returns_rows(statement.expression):
231
+ fetched_data = cursor.fetchall()
232
+ column_names = [col[0] for col in cursor.description or []]
233
+
234
+ if fetched_data and isinstance(fetched_data[0], tuple):
235
+ dict_data: list[dict[Any, Any]] = [dict(zip(column_names, row)) for row in fetched_data]
236
+ else:
237
+ dict_data = fetched_data # type: ignore[assignment]
238
+
239
+ return SQLResult(
240
+ statement=statement,
241
+ data=cast("list[RowT]", dict_data),
242
+ column_names=column_names,
243
+ rows_affected=len(dict_data),
244
+ operation_type="SELECT",
245
+ )
246
+
247
+ operation_type = self._determine_operation_type(statement)
248
+ return SQLResult(
249
+ statement=statement,
250
+ data=cast("list[RowT]", []),
251
+ rows_affected=cursor.rowcount,
252
+ operation_type=operation_type,
253
+ metadata={"status_message": "OK"},
254
+ )
229
255
 
230
256
  def _execute_many(
231
257
  self, sql: str, param_list: Any, connection: Optional["AdbcConnection"] = None, **kwargs: Any
232
- ) -> DMLResultDict:
233
- conn = self._connection(connection)
234
- with self._get_cursor(conn) as cursor:
235
- try:
236
- cursor.executemany(sql, param_list or [])
237
- except Exception as e:
238
- if self.dialect == "postgres":
239
- with contextlib.suppress(Exception):
240
- cursor.execute("ROLLBACK")
241
- # Always re-raise the original exception
242
- raise e from e
243
-
244
- result: DMLResultDict = {"rows_affected": cursor.rowcount, "status_message": "OK"}
245
- return result
258
+ ) -> SQLResult[RowT]:
259
+ # Use provided connection or driver's default connection
260
+ conn = connection if connection is not None else self._connection(None)
261
+
262
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
263
+ # Normalize parameter list using consolidated utility
264
+ normalized_param_list = normalize_parameter_sequence(param_list)
265
+
266
+ with self._get_cursor(txn_conn) as cursor:
267
+ try:
268
+ cursor.executemany(sql, normalized_param_list or [])
269
+ except Exception as e:
270
+ if self.dialect == "postgres":
271
+ with contextlib.suppress(Exception):
272
+ cursor.execute("ROLLBACK")
273
+ # Always re-raise the original exception
274
+ raise e from e
275
+
276
+ return SQLResult(
277
+ statement=SQL(sql, _dialect=self.dialect),
278
+ data=[],
279
+ rows_affected=cursor.rowcount,
280
+ operation_type="EXECUTE",
281
+ metadata={"status_message": "OK"},
282
+ )
246
283
 
247
284
  def _execute_script(
248
285
  self, script: str, connection: Optional["AdbcConnection"] = None, **kwargs: Any
249
- ) -> ScriptResultDict:
250
- conn = self._connection(connection)
251
- # ADBC drivers don't support multiple statements in a single execute
252
- # Use the shared implementation to split the script
253
- statements = self._split_script_statements(script)
254
-
255
- executed_count = 0
256
- with self._get_cursor(conn) as cursor:
257
- for statement in statements:
258
- executed_count += self._execute_single_script_statement(cursor, statement)
259
-
260
- result: ScriptResultDict = {"statements_executed": executed_count, "status_message": "SCRIPT EXECUTED"}
261
- return result
286
+ ) -> SQLResult[RowT]:
287
+ # Use provided connection or driver's default connection
288
+ conn = connection if connection is not None else self._connection(None)
289
+
290
+ with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
291
+ # ADBC drivers don't support multiple statements in a single execute
292
+ statements = self._split_script_statements(script)
293
+
294
+ executed_count = 0
295
+ with self._get_cursor(txn_conn) as cursor:
296
+ for statement in statements:
297
+ if statement.strip():
298
+ self._execute_single_script_statement(cursor, statement)
299
+ executed_count += 1
300
+
301
+ return SQLResult(
302
+ statement=SQL(script, _dialect=self.dialect).as_script(),
303
+ data=[],
304
+ rows_affected=0,
305
+ operation_type="SCRIPT",
306
+ metadata={"status_message": "SCRIPT EXECUTED"},
307
+ total_statements=executed_count,
308
+ successful_statements=executed_count,
309
+ )
262
310
 
263
311
  def _execute_single_script_statement(self, cursor: "Cursor", statement: str) -> int:
264
312
  """Execute a single statement from a script and handle errors.
@@ -273,7 +321,7 @@ class AdbcDriver(
273
321
  try:
274
322
  cursor.execute(statement)
275
323
  except Exception as e:
276
- # Rollback transaction on error for PostgreSQL
324
+ # Rollback transaction on error for PostgreSQL to avoid "current transaction is aborted" errors
277
325
  if self.dialect == "postgres":
278
326
  with contextlib.suppress(Exception):
279
327
  cursor.execute("ROLLBACK")
@@ -281,59 +329,6 @@ class AdbcDriver(
281
329
  else:
282
330
  return 1
283
331
 
284
- def _wrap_select_result(
285
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
286
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
287
- # result must be a dict with keys: data, column_names, rows_affected
288
-
289
- rows_as_dicts = [dict(zip(result["column_names"], row)) for row in result["data"]]
290
-
291
- if schema_type:
292
- return SQLResult[ModelDTOT](
293
- statement=statement,
294
- data=list(self.to_schema(data=rows_as_dicts, schema_type=schema_type)),
295
- column_names=result["column_names"],
296
- rows_affected=result["rows_affected"],
297
- operation_type="SELECT",
298
- )
299
- return SQLResult[RowT](
300
- statement=statement,
301
- data=rows_as_dicts,
302
- column_names=result["column_names"],
303
- rows_affected=result["rows_affected"],
304
- operation_type="SELECT",
305
- )
306
-
307
- def _wrap_execute_result(
308
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
309
- ) -> SQLResult[RowT]:
310
- operation_type = (
311
- str(statement.expression.key).upper()
312
- if statement.expression and hasattr(statement.expression, "key")
313
- else "UNKNOWN"
314
- )
315
-
316
- # Handle TypedDict results
317
- if is_dict_with_field(result, "statements_executed"):
318
- return SQLResult[RowT](
319
- statement=statement,
320
- data=[],
321
- rows_affected=0,
322
- total_statements=result["statements_executed"],
323
- operation_type="SCRIPT", # Scripts always have operation_type SCRIPT
324
- metadata={"status_message": result["status_message"]},
325
- )
326
- if is_dict_with_field(result, "rows_affected"):
327
- return SQLResult[RowT](
328
- statement=statement,
329
- data=[],
330
- rows_affected=result["rows_affected"],
331
- operation_type=operation_type,
332
- metadata={"status_message": result["status_message"]},
333
- )
334
- msg = f"Unexpected result type: {type(result)}"
335
- raise ValueError(msg)
336
-
337
332
  def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
338
333
  """ADBC native Arrow table fetching.
339
334
 
@@ -379,10 +374,17 @@ class AdbcDriver(
379
374
 
380
375
  conn = self._connection(None)
381
376
  with self._get_cursor(conn) as cursor:
382
- # Handle different modes
383
377
  if mode == "replace":
384
- cursor.execute(SQL(f"TRUNCATE TABLE {table_name}").to_sql(placeholder_style=ParameterStyle.STATIC))
378
+ cursor.execute(
379
+ SQL(f"TRUNCATE TABLE {table_name}", _dialect=self.dialect).to_sql(
380
+ placeholder_style=ParameterStyle.STATIC
381
+ )
382
+ )
385
383
  elif mode == "create":
386
384
  msg = "'create' mode is not supported for ADBC ingestion"
387
385
  raise NotImplementedError(msg)
388
386
  return cursor.adbc_ingest(table_name, table, mode=mode, **options) # type: ignore[arg-type]
387
+
388
+ def _connection(self, connection: Optional["AdbcConnection"] = None) -> "AdbcConnection":
389
+ """Get the connection to use for the operation."""
390
+ return connection or self.connection
@@ -3,7 +3,6 @@
3
3
  import logging
4
4
  from collections.abc import AsyncGenerator
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional
8
7
 
9
8
  import aiosqlite
@@ -172,15 +171,16 @@ class AiosqliteConfig(AsyncDatabaseConfig[AiosqliteConnection, None, AiosqliteDr
172
171
  An AiosqliteDriver instance.
173
172
  """
174
173
  async with self.provide_connection(*args, **kwargs) as connection:
175
- # Create statement config with parameter style info if not already set
176
174
  statement_config = self.statement_config
175
+ # Inject parameter style info if not already set
177
176
  if statement_config.allowed_parameter_styles is None:
177
+ from dataclasses import replace
178
+
178
179
  statement_config = replace(
179
180
  statement_config,
180
181
  allowed_parameter_styles=self.supported_parameter_styles,
181
182
  target_parameter_style=self.preferred_parameter_style,
182
183
  )
183
-
184
184
  yield self.driver_type(connection=connection, config=statement_config)
185
185
 
186
186
  async def provide_pool(self, *args: Any, **kwargs: Any) -> None: