sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec has been flagged as potentially problematic; see the package registry's advisory for details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -621
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -431
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +218 -436
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +417 -487
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +600 -553
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +392 -406
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +548 -921
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -533
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +725 -0
  31. sqlspec/adapters/psycopg/driver.py +734 -694
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +242 -405
  35. sqlspec/base.py +220 -784
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
  137. sqlspec-0.12.1.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -331
  150. sqlspec/mixins.py +0 -305
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.1.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/config.py ADDED
@@ -0,0 +1,354 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass, field
3
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union
4
+
5
+ from sqlspec.typing import ConnectionT, PoolT # pyright: ignore
6
+ from sqlspec.utils.logging import get_logger
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Awaitable
10
+ from contextlib import AbstractAsyncContextManager, AbstractContextManager
11
+
12
+ from sqlglot.dialects.dialect import DialectType
13
+
14
+ from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
15
+ from sqlspec.statement.result import StatementResult
16
+
17
+
18
# Alias for a statement result: either dict rows or rows of any other type.
StatementResultType = Union[
    "StatementResult[dict[str, Any]]",
    "StatementResult[Any]",
]


# Public API of this module.
__all__ = (
    "AsyncConfigT", "AsyncDatabaseConfig", "ConfigT", "DatabaseConfigProtocol",
    "DriverT", "GenericPoolConfig", "NoPoolAsyncConfig", "NoPoolSyncConfig",
    "StatementResultType", "SyncConfigT", "SyncDatabaseConfig",
)
34
+
35
# Type variables bound to the configuration base classes defined in this module.
# Bounds are string forward references because the classes are declared further down.
# DatabaseConfig bases are generic over (ConnectionT, PoolT, DriverT); the NoPool bases
# fix PoolT to None and are generic over (ConnectionT, DriverT).
AsyncConfigT = TypeVar("AsyncConfigT", bound="Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]]")
SyncConfigT = TypeVar("SyncConfigT", bound="Union[SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]")
# Any configuration, sync or async, pooled or not (flattened: a nested Union adds nothing).
ConfigT = TypeVar(
    "ConfigT",
    bound=(
        "Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any], "
        "SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]"
    ),
)
# Both driver protocols are generic over (ConnectionT, RowT), so each needs two type
# arguments; the previous single-argument form is rejected by static type checkers.
DriverT = TypeVar("DriverT", bound="Union[SyncDriverAdapterProtocol[Any, Any], AsyncDriverAdapterProtocol[Any, Any]]")
42
+
43
# Module-level logger obtained from sqlspec's logging helper under the name "config".
logger = get_logger("config")
44
+
45
+
46
@dataclass
class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
    """Protocol defining the interface for database configurations.

    Concrete adapters provide connection, session and pool management by implementing
    the abstract methods below. This class defines an identity-based ``__hash__``.
    """

    # Note: __slots__ cannot be used with dataclass fields in Python < 3.10
    # Concrete subclasses can still use __slots__ for any additional attributes
    __slots__ = ()

    # Capability flags, declared per class (ClassVar) rather than per instance.
    is_async: "ClassVar[bool]" = field(init=False, default=False)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_arrow_import: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_arrow_export: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_parquet_import: "ClassVar[bool]" = field(init=False, default=False)
    supports_native_parquet_export: "ClassVar[bool]" = field(init=False, default=False)
    # Not set by __init__; concrete subclasses are expected to provide these.
    connection_type: "type[ConnectionT]" = field(init=False, repr=False, hash=False, compare=False)
    driver_type: "type[DriverT]" = field(init=False, repr=False, hash=False, compare=False)
    pool_instance: "Optional[PoolT]" = field(default=None)
    default_row_type: "type[Any]" = field(init=False)
    # Cache for the lazily resolved ``dialect`` property.
    _dialect: "DialectType" = field(default=None, init=False, repr=False, hash=False, compare=False)

    # Parameter styles supported by this database adapter (e.g., ('qmark', 'named_colon')).
    supported_parameter_styles: "ClassVar[tuple[str, ...]]" = ()

    # The preferred/native parameter style for this database.
    preferred_parameter_style: "ClassVar[str]" = "none"

    def __hash__(self) -> int:
        # Hash by identity: each configuration object is a distinct entity.
        return id(self)

    @property
    def dialect(self) -> "DialectType":
        """Get the SQL dialect type lazily.

        The value is resolved on first access via ``_get_dialect()`` and cached in
        ``_dialect``. Adapters needing dynamic detection (e.g. ADBC, which supports
        multiple databases) can override ``_get_dialect()``.

        Returns:
            The SQL dialect type for this database.
        """
        if self._dialect is None:
            self._dialect = self._get_dialect()  # type: ignore[misc]
        return self._dialect

    def _get_dialect(self) -> "DialectType":
        """Get the dialect for this database configuration.

        By default, this reads the ``dialect`` attribute of ``driver_type``.

        Returns:
            The SQL dialect type.
        """
        # Get dialect from driver_class (all drivers must have a dialect attribute)
        return self.driver_type.dialect

    @abstractmethod
    def create_connection(self) -> "Union[ConnectionT, Awaitable[ConnectionT]]":
        """Create and return a new database connection."""
        raise NotImplementedError

    @abstractmethod
    def provide_connection(
        self, *args: Any, **kwargs: Any
    ) -> "Union[AbstractContextManager[ConnectionT], AbstractAsyncContextManager[ConnectionT]]":
        """Provide a database connection context manager."""
        raise NotImplementedError

    @abstractmethod
    def provide_session(
        self, *args: Any, **kwargs: Any
    ) -> "Union[AbstractContextManager[DriverT], AbstractAsyncContextManager[DriverT]]":
        """Provide a database session context manager."""
        raise NotImplementedError

    @property
    @abstractmethod
    def connection_config_dict(self) -> "dict[str, Any]":
        """Return the connection configuration as a dict."""
        raise NotImplementedError

    @abstractmethod
    def create_pool(self) -> "Union[PoolT, Awaitable[PoolT]]":
        """Create and return connection pool."""
        raise NotImplementedError

    @abstractmethod
    def close_pool(self) -> "Optional[Awaitable[None]]":
        """Terminate the connection pool."""
        raise NotImplementedError

    @abstractmethod
    def provide_pool(
        self, *args: Any, **kwargs: Any
    ) -> "Union[PoolT, Awaitable[PoolT], AbstractContextManager[PoolT], AbstractAsyncContextManager[PoolT]]":
        """Provide pool instance."""
        raise NotImplementedError

    def get_signature_namespace(self) -> "dict[str, type[Any]]":
        """Get the signature namespace for this database configuration.

        Returns a mapping of type names to types that should be registered with
        Litestar's signature namespace to prevent serialization attempts on
        database-specific types. The mapping contains:

        - the driver type (if set),
        - this configuration class,
        - the connection type, or each member of it when it is a Union/generic alias.
          ``NoneType`` contributed by ``Optional[...]`` is skipped.

        Returns:
            Dictionary mapping type names to types.
        """
        namespace: dict[str, type[Any]] = {}

        driver_type = getattr(self, "driver_type", None)
        if driver_type:
            namespace[driver_type.__name__] = driver_type

        namespace[self.__class__.__name__] = self.__class__

        connection_type = getattr(self, "connection_type", None)
        if connection_type:
            # Union/generic aliases (like AsyncPG's Union[Connection, PoolConnectionProxy])
            # expose their members through ``__args__``; plain classes are used as-is.
            members = getattr(connection_type, "__args__", None)
            if members is None:
                members = (connection_type,)
            for member in members:
                # Optional[X] contributes NoneType, which is not a database type.
                if member is None or member is type(None):
                    continue
                name = getattr(member, "__name__", None)
                if name:
                    namespace[name] = member

        return namespace
178
+
179
+
180
@dataclass
class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for sync database configurations that do not implement a pool.

    The pool hooks are no-ops that return ``None``; connection and session
    provisioning must be implemented by concrete subclasses.
    """

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=False)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=False)
    pool_instance: None = None

    def __hash__(self) -> int:
        # @dataclass with the default eq=True sets __hash__ to None on any decorated class
        # that does not define __hash__ in its own body. Re-declare the identity hash so
        # instances stay hashable like DatabaseConfigProtocol intends.
        return id(self)

    def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    def create_pool(self) -> None:
        """No pool to create; always returns ``None``."""
        return None

    def close_pool(self) -> None:
        """No pool to close; always returns ``None``."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No pool to provide; always returns ``None``."""
        return None
210
+
211
+
212
@dataclass
class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for async database configurations that do not implement a pool.

    ``create_pool``/``close_pool`` are coroutines that resolve to ``None``, and
    ``provide_pool`` returns ``None`` directly. Connection and session provisioning
    must be implemented by concrete subclasses.
    """

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=True)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=False)
    pool_instance: None = None

    def __hash__(self) -> int:
        # @dataclass with the default eq=True sets __hash__ to None on any decorated class
        # that does not define __hash__ in its own body. Re-declare the identity hash so
        # instances stay hashable like DatabaseConfigProtocol intends.
        return id(self)

    async def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    async def create_pool(self) -> None:
        """No pool to create; resolves to ``None``."""
        return None

    async def close_pool(self) -> None:
        """No pool to close; resolves to ``None``."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No pool to provide; returns ``None``."""
        return None
242
+
243
+
244
@dataclass
class GenericPoolConfig:
    """Generic Database Pool Configuration.

    A field-less dataclass that declares empty ``__slots__``, so its instances carry
    no per-instance ``__dict__``.
    """

    __slots__ = tuple()
249
+
250
+
251
@dataclass
class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
    """Generic Sync Database Configuration.

    Pool lifecycle:

    - ``create_pool`` returns the cached ``pool_instance``, or builds one through the
      subclass hook ``_create_pool`` and caches it.
    - ``close_pool`` tears the pool down through ``_close_pool`` and then forgets the
      cached instance, so a later ``create_pool``/``provide_pool`` builds a fresh pool.
    """

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=False)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=True)

    def __hash__(self) -> int:
        # @dataclass with the default eq=True sets __hash__ to None on any decorated class
        # that does not define __hash__ in its own body. Re-declare the identity hash so
        # instances stay hashable like DatabaseConfigProtocol intends.
        return id(self)

    def create_pool(self) -> PoolT:
        """Create pool with instrumentation.

        Returns:
            The cached pool if one exists, otherwise a newly created (and cached) pool.
        """
        if self.pool_instance is not None:
            return self.pool_instance
        self.pool_instance = self._create_pool()  # type: ignore[misc]
        return self.pool_instance

    def close_pool(self) -> None:
        """Close pool with instrumentation.

        Delegates to ``_close_pool`` and then clears ``pool_instance``. If
        ``_close_pool`` raises, the reference is kept so the close can be retried.
        """
        self._close_pool()
        # Without this, create_pool()/provide_pool() would keep returning the closed pool.
        self.pool_instance = None

    def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
        """Provide pool instance, creating it on first use."""
        if self.pool_instance is None:
            self.pool_instance = self.create_pool()  # type: ignore[misc]
        return self.pool_instance

    def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    @abstractmethod
    def _create_pool(self) -> PoolT:
        """Actual pool creation implementation."""
        raise NotImplementedError

    @abstractmethod
    def _close_pool(self) -> None:
        """Actual pool destruction implementation."""
        raise NotImplementedError
302
+
303
+
304
@dataclass
class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
    """Generic Async Database Configuration.

    Pool lifecycle:

    - ``create_pool`` returns the cached ``pool_instance``, or awaits the subclass hook
      ``_create_pool`` and caches the result.
    - ``close_pool`` awaits ``_close_pool`` and then forgets the cached instance, so a
      later ``create_pool``/``provide_pool`` builds a fresh pool.
    """

    __slots__ = ()

    is_async: "ClassVar[bool]" = field(init=False, default=True)
    supports_connection_pooling: "ClassVar[bool]" = field(init=False, default=True)

    def __hash__(self) -> int:
        # @dataclass with the default eq=True sets __hash__ to None on any decorated class
        # that does not define __hash__ in its own body. Re-declare the identity hash so
        # instances stay hashable like DatabaseConfigProtocol intends.
        return id(self)

    async def create_pool(self) -> PoolT:
        """Create pool with instrumentation.

        Returns:
            The cached pool if one exists, otherwise a newly created (and cached) pool.
        """
        if self.pool_instance is not None:
            return self.pool_instance
        self.pool_instance = await self._create_pool()  # type: ignore[misc]
        return self.pool_instance

    async def close_pool(self) -> None:
        """Close pool with instrumentation.

        Awaits ``_close_pool`` and then clears ``pool_instance``. If ``_close_pool``
        raises, the reference is kept so the close can be retried.
        """
        await self._close_pool()
        # Without this, create_pool()/provide_pool() would keep returning the closed pool.
        self.pool_instance = None

    async def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
        """Provide pool instance, creating it on first use."""
        if self.pool_instance is None:
            self.pool_instance = await self.create_pool()  # type: ignore[misc]
        return self.pool_instance

    async def create_connection(self) -> ConnectionT:
        """Create connection with instrumentation."""
        raise NotImplementedError

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[ConnectionT]":
        """Provide connection with instrumentation."""
        raise NotImplementedError

    def provide_session(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[DriverT]":
        """Provide session with instrumentation."""
        raise NotImplementedError

    @abstractmethod
    async def _create_pool(self) -> PoolT:
        """Actual async pool creation implementation."""
        raise NotImplementedError

    @abstractmethod
    async def _close_pool(self) -> None:
        """Actual async pool destruction implementation."""
        raise NotImplementedError
sqlspec/driver/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ """Driver protocols and base classes for database adapters."""
2
+
3
+ from typing import Union
4
+
5
+ from sqlspec.driver import mixins
6
+ from sqlspec.driver._async import AsyncDriverAdapterProtocol
7
+ from sqlspec.driver._common import CommonDriverAttributesMixin
8
+ from sqlspec.driver._sync import SyncDriverAdapterProtocol
9
+ from sqlspec.typing import ConnectionT, RowT
10
+
11
__all__ = (
    "AsyncDriverAdapterProtocol", "CommonDriverAttributesMixin", "DriverAdapterProtocol",
    "SyncDriverAdapterProtocol", "mixins",
)

# Type alias for convenience: any driver, sync or async, parameterised by the
# connection type and the row type.
DriverAdapterProtocol = Union[
    SyncDriverAdapterProtocol[ConnectionT, RowT],
    AsyncDriverAdapterProtocol[ConnectionT, RowT],
]
sqlspec/driver/_async.py ADDED
@@ -0,0 +1,252 @@
1
+ """Asynchronous driver protocol implementation."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
5
+
6
+ from sqlspec.driver._common import CommonDriverAttributesMixin
7
+ from sqlspec.statement.builder import DeleteBuilder, InsertBuilder, QueryBuilder, SelectBuilder, UpdateBuilder
8
+ from sqlspec.statement.filters import StatementFilter
9
+ from sqlspec.statement.result import SQLResult
10
+ from sqlspec.statement.sql import SQL, SQLConfig, Statement
11
+ from sqlspec.typing import ConnectionT, DictRow, ModelDTOT, RowT, StatementParameters
12
+
13
+ if TYPE_CHECKING:
14
+ from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict
15
+
16
__all__ = ("AsyncDriverAdapterProtocol",)


# Shared empty filter list.
# NOTE(review): this is one mutable list object shared by every importer — treat it as read-only.
EMPTY_FILTERS: "list[StatementFilter]" = list()
20
+
21
+
22
class AsyncDriverAdapterProtocol(CommonDriverAttributesMixin[ConnectionT, RowT], ABC):
    """Base class for asynchronous database driver adapters.

    Concrete drivers implement three hooks:

    - ``_execute_statement``: run a built :class:`SQL` statement on a raw connection.
    - ``_wrap_select_result`` / ``_wrap_execute_result``: turn the returned result
      dictionaries into :class:`SQLResult` objects.

    The public ``execute``, ``execute_many`` and ``execute_script`` coroutines build
    the statement, dispatch it and pick the matching wrapper.
    """

    __slots__ = ()

    def __init__(
        self,
        connection: "ConnectionT",
        config: "Optional[SQLConfig]" = None,
        default_row_type: "type[DictRow]" = DictRow,
    ) -> None:
        """Initialize async driver adapter.

        Args:
            connection: The database connection
            config: SQL statement configuration
            default_row_type: Default row type for results (DictRow, TupleRow, etc.)
        """
        super().__init__(connection=connection, config=config, default_row_type=default_row_type)

    @staticmethod
    def _partition_statement_args(
        parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
    ) -> "tuple[list[StatementParameters], list[StatementFilter]]":
        """Split positional statement arguments into plain parameter values and filters.

        Relative order is preserved within each group.
        """
        values: "list[StatementParameters]" = []
        filters: "list[StatementFilter]" = []
        for arg in parameters:
            if isinstance(arg, StatementFilter):
                filters.append(arg)
            else:
                values.append(arg)
        return values, filters

    def _build_statement(
        self,
        statement: "Union[Statement, QueryBuilder[Any]]",
        *parameters: "Union[StatementParameters, StatementFilter]",
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQL":
        """Normalize ``statement`` into a :class:`SQL` object.

        Builders are rendered with ``to_statement``, and existing ``SQL`` objects are
        returned unchanged; in both cases ``parameters`` and ``kwargs`` are not applied.
        Anything else is wrapped in a new ``SQL`` using this driver's dialect.
        """
        # Use driver's config if none provided
        _config = _config or self.config

        if isinstance(statement, QueryBuilder):
            return statement.to_statement(config=_config)
        # If statement is already a SQL object, return it as-is
        if isinstance(statement, SQL):
            return statement
        return SQL(statement, *parameters, _dialect=self.dialect, _config=_config, **kwargs)

    @abstractmethod
    async def _execute_statement(
        self, statement: "SQL", connection: "Optional[ConnectionT]" = None, **kwargs: Any
    ) -> "Union[SelectResultDict, DMLResultDict, ScriptResultDict]":
        """Actual execution implementation by concrete drivers, using the raw connection.

        Returns one of the standardized result dictionaries based on the statement type.
        """
        raise NotImplementedError

    @abstractmethod
    async def _wrap_select_result(
        self,
        statement: "SQL",
        result: "SelectResultDict",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Wrap a row-returning result, optionally mapping rows onto ``schema_type``."""
        raise NotImplementedError

    @abstractmethod
    async def _wrap_execute_result(
        self, statement: "SQL", result: "Union[DMLResultDict, ScriptResultDict]", **kwargs: Any
    ) -> "SQLResult[RowT]":
        """Wrap a DML or script result."""
        raise NotImplementedError

    # Type-safe overloads based on the refactor plan pattern
    @overload
    async def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    async def execute(
        self,
        statement: "SelectBuilder",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    async def execute(
        self,
        statement: "Union[InsertBuilder, UpdateBuilder, DeleteBuilder]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    @overload
    async def execute(
        self,
        statement: "Union[str, SQL]",  # exp.Expression
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[ModelDTOT]": ...

    @overload
    async def execute(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]": ...

    async def execute(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
        """Execute a single statement.

        Statements for which ``returns_rows`` is true are wrapped with
        ``_wrap_select_result`` (optionally mapped onto ``schema_type``); all others
        are wrapped with ``_wrap_execute_result``.
        """
        # _build_statement falls back to self.config when _config is None.
        sql_statement = self._build_statement(statement, *parameters, _config=_config, **kwargs)
        result = await self._execute_statement(
            statement=sql_statement, connection=self._connection(_connection), **kwargs
        )

        if self.returns_rows(sql_statement.expression):
            return await self._wrap_select_result(
                sql_statement, cast("SelectResultDict", result), schema_type=schema_type, **kwargs
            )
        return await self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    async def execute_many(
        self,
        statement: "Union[SQL, Statement, QueryBuilder[Any]]",  # QueryBuilder for DMLs will likely not return rows.
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute a statement once per parameter set.

        The first non-filter positional argument is the sequence of parameter sets;
        any further non-filter arguments are ignored. ``StatementFilter`` arguments
        are forwarded to statement construction, as in ``execute``. Note that
        ``_build_statement`` does not apply them when ``statement`` is already a
        ``SQL`` object or a ``QueryBuilder``.
        """
        param_values, filters = self._partition_statement_args(parameters)

        # Use first parameter as the sequence for execute_many
        param_sequence = param_values[0] if param_values else None
        # Normalize to a list: tuples and other iterables are materialized,
        # non-iterables are discarded.
        if param_sequence is not None and not isinstance(param_sequence, list):
            param_sequence = list(param_sequence) if hasattr(param_sequence, "__iter__") else None

        # Filters were previously collected and then silently dropped; forward them.
        sql_statement = self._build_statement(statement, *filters, _config=_config, **kwargs)
        sql_statement = sql_statement.as_many(param_sequence)
        result = await self._execute_statement(
            statement=sql_statement,
            connection=self._connection(_connection),
            parameters=param_sequence,
            is_many=True,
            **kwargs,
        )
        return await self._wrap_execute_result(
            sql_statement, cast("Union[DMLResultDict, ScriptResultDict]", result), **kwargs
        )

    async def execute_script(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "SQLResult[RowT]":
        """Execute a multi-statement script.

        Only the first non-filter positional argument is used as the script's
        parameters. Validation and strict mode are disabled for the script.
        """
        param_values, filters = self._partition_statement_args(parameters)

        # Use first parameter as the primary parameter value, or None if no parameters
        primary_params = param_values[0] if param_values else None

        script_config = _config or self.config
        if script_config.enable_validation:
            # Copy the config with validation and strict mode switched off.
            # NOTE(review): presumably because validators target single statements — confirm.
            # Fields are copied explicitly, so any SQLConfig field not listed here reverts
            # to its default.
            script_config = SQLConfig(
                enable_parsing=script_config.enable_parsing,
                enable_validation=False,
                enable_transformations=script_config.enable_transformations,
                enable_analysis=script_config.enable_analysis,
                strict_mode=False,
                cache_parsed_expression=script_config.cache_parsed_expression,
                parameter_converter=script_config.parameter_converter,
                parameter_validator=script_config.parameter_validator,
                analysis_cache_size=script_config.analysis_cache_size,
                allowed_parameter_styles=script_config.allowed_parameter_styles,
                target_parameter_style=script_config.target_parameter_style,
                allow_mixed_parameter_styles=script_config.allow_mixed_parameter_styles,
            )
        sql_statement = SQL(statement, primary_params, *filters, _dialect=self.dialect, _config=script_config, **kwargs)
        sql_statement = sql_statement.as_script()
        script_output = await self._execute_statement(
            statement=sql_statement, connection=self._connection(_connection), is_script=True, **kwargs
        )
        # Some drivers return a plain status string for scripts; report it as one successful statement.
        if isinstance(script_output, str):
            result = SQLResult[RowT](statement=sql_statement, data=[], operation_type="SCRIPT")
            result.total_statements = 1
            result.successful_statements = 1
            return result
        # Wrap the ScriptResultDict using the driver's wrapper
        return await self._wrap_execute_result(sql_statement, cast("ScriptResultDict", script_output), **kwargs)