sqlspec 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (50)
  1. sqlspec/__init__.py +15 -0
  2. sqlspec/_serialization.py +16 -2
  3. sqlspec/_typing.py +1 -1
  4. sqlspec/adapters/adbc/__init__.py +7 -0
  5. sqlspec/adapters/adbc/config.py +160 -17
  6. sqlspec/adapters/adbc/driver.py +333 -0
  7. sqlspec/adapters/aiosqlite/__init__.py +6 -2
  8. sqlspec/adapters/aiosqlite/config.py +25 -7
  9. sqlspec/adapters/aiosqlite/driver.py +275 -0
  10. sqlspec/adapters/asyncmy/__init__.py +7 -2
  11. sqlspec/adapters/asyncmy/config.py +75 -14
  12. sqlspec/adapters/asyncmy/driver.py +255 -0
  13. sqlspec/adapters/asyncpg/__init__.py +9 -0
  14. sqlspec/adapters/asyncpg/config.py +99 -20
  15. sqlspec/adapters/asyncpg/driver.py +288 -0
  16. sqlspec/adapters/duckdb/__init__.py +6 -2
  17. sqlspec/adapters/duckdb/config.py +197 -15
  18. sqlspec/adapters/duckdb/driver.py +225 -0
  19. sqlspec/adapters/oracledb/__init__.py +11 -8
  20. sqlspec/adapters/oracledb/config/__init__.py +6 -6
  21. sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
  22. sqlspec/adapters/oracledb/config/_common.py +1 -1
  23. sqlspec/adapters/oracledb/config/_sync.py +99 -14
  24. sqlspec/adapters/oracledb/driver.py +498 -0
  25. sqlspec/adapters/psycopg/__init__.py +11 -0
  26. sqlspec/adapters/psycopg/config/__init__.py +6 -6
  27. sqlspec/adapters/psycopg/config/_async.py +105 -13
  28. sqlspec/adapters/psycopg/config/_common.py +2 -2
  29. sqlspec/adapters/psycopg/config/_sync.py +105 -13
  30. sqlspec/adapters/psycopg/driver.py +616 -0
  31. sqlspec/adapters/sqlite/__init__.py +7 -0
  32. sqlspec/adapters/sqlite/config.py +25 -7
  33. sqlspec/adapters/sqlite/driver.py +303 -0
  34. sqlspec/base.py +416 -36
  35. sqlspec/extensions/litestar/__init__.py +19 -0
  36. sqlspec/extensions/litestar/_utils.py +56 -0
  37. sqlspec/extensions/litestar/config.py +81 -0
  38. sqlspec/extensions/litestar/handlers.py +188 -0
  39. sqlspec/extensions/litestar/plugin.py +103 -11
  40. sqlspec/typing.py +72 -17
  41. sqlspec/utils/__init__.py +3 -0
  42. sqlspec/utils/deprecation.py +1 -1
  43. sqlspec/utils/fixtures.py +4 -5
  44. sqlspec/utils/sync_tools.py +335 -0
  45. {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
  46. sqlspec-0.8.0.dist-info/RECORD +57 -0
  47. sqlspec-0.7.0.dist-info/RECORD +0 -46
  48. {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
  49. {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
  50. {sqlspec-0.7.0.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/base.py CHANGED
@@ -1,8 +1,23 @@
1
+ # ruff: noqa: PLR6301
2
+ import re
1
3
  from abc import ABC, abstractmethod
2
4
  from collections.abc import AsyncGenerator, Awaitable, Generator
3
5
  from contextlib import AbstractAsyncContextManager, AbstractContextManager
4
- from dataclasses import dataclass
5
- from typing import Annotated, Any, ClassVar, Generic, TypeVar, Union, cast, overload
6
+ from dataclasses import dataclass, field
7
+ from typing import (
8
+ Annotated,
9
+ Any,
10
+ ClassVar,
11
+ Generic,
12
+ Optional,
13
+ TypeVar,
14
+ Union,
15
+ cast,
16
+ overload,
17
+ )
18
+
19
+ from sqlspec.exceptions import NotFoundError
20
+ from sqlspec.typing import ModelDTOT, StatementParameterType
6
21
 
7
22
  __all__ = (
8
23
  "AsyncDatabaseConfig",
@@ -13,16 +28,34 @@ __all__ = (
13
28
  "SyncDatabaseConfig",
14
29
  )
15
30
 
31
T = TypeVar("T")
ConnectionT = TypeVar("ConnectionT")
PoolT = TypeVar("PoolT")
PoolT_co = TypeVar("PoolT_co", covariant=True)
AsyncConfigT = TypeVar("AsyncConfigT", bound="Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]]")
SyncConfigT = TypeVar("SyncConfigT", bound="Union[SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]")
ConfigT = TypeVar(
    "ConfigT",
    bound="Union[Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]], SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]",
)
DriverT = TypeVar("DriverT", bound="Union[SyncDriverAdapterProtocol[Any], AsyncDriverAdapterProtocol[Any]]")

# Regex that finds ``:name`` placeholders while skipping quoted text.
#
# finditer() tries the alternatives in order at each position. A complete
# double- or single-quoted literal is therefore consumed whole by the
# ``dquote``/``squote`` branch, and any ``:name`` inside it is reported only
# as a quote match. Callers must skip matches whose quote groups are set.
#
# The ``(?<!:)`` lookbehind rejects a colon that directly follows another
# colon, so PostgreSQL-style ``::type`` casts are not treated as placeholders.
# Because it is zero-width, a placeholder at the very start of the SQL or
# right after a closing quote (``'x':y``) is still found.
# The ``lead`` group is kept (always empty) so existing code that references
# it by name keeps working. ``var_name`` always starts one character after the
# placeholder's colon.
# Handles basic cases only; comments and dollar-quoted strings are not recognised.
PARAM_REGEX = re.compile(
    r"(?P<dquote>\"(?:[^\"]|\"\")*\")|"  # Double-quoted strings
    r"(?P<squote>'(?:[^']|'')*')|"  # Single-quoted strings
    r"(?<!:)(?P<lead>):(?P<var_name>[a-zA-Z_][a-zA-Z0-9_]*)"  # :param placeholder
)
20
50
 
21
51
 
22
52
  @dataclass
23
- class DatabaseConfigProtocol(Generic[ConnectionT, PoolT], ABC):
53
+ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
24
54
  """Protocol defining the interface for database configurations."""
25
55
 
56
+ connection_type: "type[ConnectionT]" = field(init=False)
57
+ driver_type: "type[DriverT]" = field(init=False)
58
+ pool_instance: "Optional[PoolT]" = field(default=None)
26
59
  __is_async__: ClassVar[bool] = False
27
60
  __supports_connection_pooling__: ClassVar[bool] = False
28
61
 
@@ -59,6 +92,11 @@ class DatabaseConfigProtocol(Generic[ConnectionT, PoolT], ABC):
59
92
  """Create and return connection pool."""
60
93
  raise NotImplementedError
61
94
 
95
+ @abstractmethod
96
+ def close_pool(self) -> Optional[Awaitable[None]]:
97
+ """Terminate the connection pool."""
98
+ raise NotImplementedError
99
+
62
100
  @abstractmethod
63
101
  def provide_pool(
64
102
  self,
@@ -79,31 +117,39 @@ class DatabaseConfigProtocol(Generic[ConnectionT, PoolT], ABC):
79
117
  return self.__supports_connection_pooling__
80
118
 
81
119
 
82
class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for synchronous configurations whose backend offers no connection pool.

    Every pool hook below is a no-op that returns ``None``.
    """

    __is_async__ = False
    __supports_connection_pooling__ = False
    pool_instance: None = None

    def create_pool(self) -> None:
        """No-op: there is no pool to create for this backend."""
        return None

    def close_pool(self) -> None:
        """No-op: there is no pool to close for this backend."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No-op: there is no pool to provide for this backend."""
        return None
95
137
 
96
138
 
97
class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
    """Base class for asynchronous configurations whose backend offers no connection pool.

    The coroutine hooks ``create_pool``/``close_pool`` resolve to ``None``.
    ``provide_pool`` is a plain method that returns ``None`` immediately.
    """

    __is_async__ = True
    __supports_connection_pooling__ = False
    pool_instance: None = None

    async def create_pool(self) -> None:
        """No-op coroutine: there is no pool to create for this backend."""
        return None

    async def close_pool(self) -> None:
        """No-op coroutine: there is no pool to close for this backend."""
        return None

    def provide_pool(self, *args: Any, **kwargs: Any) -> None:
        """No-op: there is no pool to provide for this backend."""
        return None
@@ -115,7 +161,7 @@ class GenericPoolConfig:
115
161
 
116
162
 
117
163
  @dataclass
118
- class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT]):
164
+ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
119
165
  """Generic Sync Database Configuration."""
120
166
 
121
167
  __is_async__ = False
@@ -123,18 +169,20 @@ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT]):
123
169
 
124
170
 
125
171
  @dataclass
126
class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
    """Generic asynchronous database configuration that supports connection pooling."""

    __supports_connection_pooling__ = True
    __is_async__ = True
131
177
 
132
178
 
133
- class ConfigManager:
134
- """Type-safe configuration manager with literal inference."""
179
+ class SQLSpec:
180
+ """Type-safe configuration manager and registry for database connections and pools."""
181
+
182
+ __slots__ = ("_configs",)
135
183
 
136
184
  def __init__(self) -> None:
137
- self._configs: dict[Any, DatabaseConfigProtocol[Any, Any]] = {}
185
+ self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
138
186
 
139
187
  @overload
140
188
  def add_config(self, config: SyncConfigT) -> type[SyncConfigT]: ...
@@ -149,7 +197,11 @@ class ConfigManager:
149
197
  AsyncConfigT,
150
198
  ],
151
199
  ) -> Union[Annotated[type[SyncConfigT], int], Annotated[type[AsyncConfigT], int]]: # pyright: ignore[reportInvalidTypeVarUse]
152
- """Add a new configuration to the manager."""
200
+ """Add a new configuration to the manager.
201
+
202
+ Returns:
203
+ A unique type key that can be used to retrieve the configuration later.
204
+ """
153
205
  key = Annotated[type(config), id(config)] # type: ignore[valid-type]
154
206
  self._configs[key] = config
155
207
  return key # type: ignore[return-value] # pyright: ignore[reportReturnType]
@@ -162,9 +214,16 @@ class ConfigManager:
162
214
 
163
215
  def get_config(
164
216
  self,
165
- name: Union[type[DatabaseConfigProtocol[ConnectionT, PoolT]], Any],
166
- ) -> DatabaseConfigProtocol[ConnectionT, PoolT]:
167
- """Retrieve a configuration by its type."""
217
+ name: Union[type[DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]], Any],
218
+ ) -> DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]:
219
+ """Retrieve a configuration by its type.
220
+
221
+ Returns:
222
+ DatabaseConfigProtocol: The configuration instance for the given type.
223
+
224
+ Raises:
225
+ KeyError: If no configuration is found for the given type.
226
+ """
168
227
  config = self._configs.get(name)
169
228
  if not config:
170
229
  msg = f"No configuration found for {name}"
@@ -175,8 +234,8 @@ class ConfigManager:
175
234
  def get_connection(
176
235
  self,
177
236
  name: Union[
178
- type[NoPoolSyncConfig[ConnectionT]],
179
- type[SyncDatabaseConfig[ConnectionT, PoolT]], # pyright: ignore[reportInvalidTypeVarUse]
237
+ type[NoPoolSyncConfig[ConnectionT, DriverT]],
238
+ type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]], # pyright: ignore[reportInvalidTypeVarUse]
180
239
  ],
181
240
  ) -> ConnectionT: ...
182
241
 
@@ -184,44 +243,365 @@ class ConfigManager:
184
243
  def get_connection(
185
244
  self,
186
245
  name: Union[
187
- type[NoPoolAsyncConfig[ConnectionT]],
188
- type[AsyncDatabaseConfig[ConnectionT, PoolT]], # pyright: ignore[reportInvalidTypeVarUse]
246
+ type[NoPoolAsyncConfig[ConnectionT, DriverT]],
247
+ type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]], # pyright: ignore[reportInvalidTypeVarUse]
189
248
  ],
190
249
  ) -> Awaitable[ConnectionT]: ...
191
250
 
192
251
  def get_connection(
193
252
  self,
194
253
  name: Union[
195
- type[NoPoolSyncConfig[ConnectionT]],
196
- type[NoPoolAsyncConfig[ConnectionT]],
197
- type[SyncDatabaseConfig[ConnectionT, PoolT]],
198
- type[AsyncDatabaseConfig[ConnectionT, PoolT]],
254
+ type[NoPoolSyncConfig[ConnectionT, DriverT]],
255
+ type[NoPoolAsyncConfig[ConnectionT, DriverT]],
256
+ type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
257
+ type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
199
258
  ],
200
259
  ) -> Union[ConnectionT, Awaitable[ConnectionT]]:
201
- """Create and return a connection from the specified configuration."""
260
+ """Create and return a connection from the specified configuration.
261
+
262
+ Args:
263
+ name: The configuration type to use for creating the connection.
264
+
265
+ Returns:
266
+ Either a connection instance or an awaitable that resolves to a connection,
267
+ depending on whether the configuration is sync or async.
268
+ """
202
269
  config = self.get_config(name)
203
270
  return config.create_connection()
204
271
 
205
272
  @overload
206
- def get_pool(self, name: type[Union[NoPoolSyncConfig[ConnectionT], NoPoolAsyncConfig[ConnectionT]]]) -> None: ... # pyright: ignore[reportInvalidTypeVarUse]
273
+ def get_pool(
274
+ self, name: type[Union[NoPoolSyncConfig[ConnectionT, DriverT], NoPoolAsyncConfig[ConnectionT, DriverT]]]
275
+ ) -> None: ... # pyright: ignore[reportInvalidTypeVarUse]
207
276
 
208
277
  @overload
209
- def get_pool(self, name: type[SyncDatabaseConfig[ConnectionT, PoolT]]) -> type[PoolT]: ... # pyright: ignore[reportInvalidTypeVarUse]
278
+ def get_pool(self, name: type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]) -> type[PoolT]: ... # pyright: ignore[reportInvalidTypeVarUse]
210
279
 
211
280
  @overload
212
- def get_pool(self, name: type[AsyncDatabaseConfig[ConnectionT, PoolT]]) -> Awaitable[type[PoolT]]: ... # pyright: ignore[reportInvalidTypeVarUse]
281
+ def get_pool(self, name: type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]) -> Awaitable[type[PoolT]]: ... # pyright: ignore[reportInvalidTypeVarUse]
213
282
 
214
283
  def get_pool(
215
284
  self,
216
285
  name: Union[
217
- type[NoPoolSyncConfig[ConnectionT]],
218
- type[NoPoolAsyncConfig[ConnectionT]],
219
- type[SyncDatabaseConfig[ConnectionT, PoolT]],
220
- type[AsyncDatabaseConfig[ConnectionT, PoolT]],
286
+ type[NoPoolSyncConfig[ConnectionT, DriverT]],
287
+ type[NoPoolAsyncConfig[ConnectionT, DriverT]],
288
+ type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
289
+ type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
221
290
  ],
222
291
  ) -> Union[type[PoolT], Awaitable[type[PoolT]], None]:
223
- """Create and return a connection pool from the specified configuration."""
292
+ """Create and return a connection pool from the specified configuration.
293
+
294
+ Args:
295
+ name: The configuration type to use for creating the pool.
296
+
297
+ Returns:
298
+ Either a pool instance, an awaitable that resolves to a pool instance, or None
299
+ if the configuration does not support connection pooling.
300
+ """
224
301
  config = self.get_config(name)
225
- if isinstance(config, (NoPoolSyncConfig, NoPoolAsyncConfig)):
226
- return None
227
- return cast("Union[type[PoolT], Awaitable[type[PoolT]]]", config.create_pool())
302
+ if config.support_connection_pooling:
303
+ return cast("Union[type[PoolT], Awaitable[type[PoolT]]]", config.create_pool())
304
+ return None
305
+
306
+ def close_pool(
307
+ self,
308
+ name: Union[
309
+ type[NoPoolSyncConfig[ConnectionT, DriverT]],
310
+ type[NoPoolAsyncConfig[ConnectionT, DriverT]],
311
+ type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
312
+ type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]],
313
+ ],
314
+ ) -> Optional[Awaitable[None]]:
315
+ """Close the connection pool for the specified configuration.
316
+
317
+ Args:
318
+ name: The configuration type whose pool to close.
319
+
320
+ Returns:
321
+ An awaitable if the configuration is async, otherwise None.
322
+ """
323
+ config = self.get_config(name)
324
+ if config.support_connection_pooling:
325
+ return config.close_pool()
326
+ return None
327
+
328
+
329
class CommonDriverAttributes(Generic[ConnectionT]):
    """Common attributes and methods for driver adapters."""

    param_style: str = "?"
    """The parameter style placeholder supported by the underlying database driver (e.g., '?', '%s')."""
    connection: ConnectionT
    """The connection to the underlying database."""

    def _connection(self, connection: "Optional[ConnectionT]" = None) -> "ConnectionT":
        """Return ``connection`` when given, otherwise the adapter's ``self.connection``.

        Args:
            connection: An explicit connection to use instead of the default one.

        Returns:
            The connection to execute against.
        """
        return connection if connection is not None else self.connection

    @staticmethod
    def check_not_found(item_or_none: Optional[T] = None) -> T:
        """Raise :exc:`sqlspec.exceptions.NotFoundError` if ``item_or_none`` is ``None``.

        Args:
            item_or_none: Item to be tested for existence.

        Raises:
            NotFoundError: If ``item_or_none`` is ``None``

        Returns:
            The item, if it exists.
        """
        if item_or_none is None:
            msg = "No result found when one was expected"
            raise NotFoundError(msg)
        return item_or_none

    def _process_sql_statement(self, sql: str) -> str:
        """Perform any preprocessing of the SQL query string if needed.

        The default implementation returns the SQL unchanged. Drivers may override it.

        Args:
            sql: The SQL query string.

        Returns:
            The processed SQL query string.
        """
        return sql

    def _process_sql_params(
        self, sql: str, parameters: "Optional[StatementParameterType]" = None
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Process SQL query and parameters for DB-API execution.

        If ``parameters`` is a non-empty dictionary, every ``:name`` placeholder
        outside quoted literals is replaced by ``self.param_style``. The matching
        values are collected, in placeholder order, into a tuple.

        Args:
            sql: The SQL query string.
            parameters: The parameters for the query (dict, tuple, list, or None).

        Returns:
            A ``(sql, parameters)`` pair. The SQL is always passed through
            ``_process_sql_statement``. The parameters are a tuple when the input
            was a non-empty dict. Otherwise the original ``parameters`` object
            (tuple, list, ``None`` or an empty dict) is returned unchanged.

        Raises:
            ValueError: If a named parameter in the SQL is not found in the dictionary
                or if a parameter in the dictionary is not used in the SQL.
        """
        if not isinstance(parameters, dict) or not parameters:
            # Not a dict (or an empty one): assume positional/no params and let
            # the underlying driver handle tuples/lists directly.
            return self._process_sql_statement(sql), parameters

        sql_parts: list[str] = []
        ordered_values: list[Any] = []
        used_names: set[str] = set()
        last_end = 0

        for match in PARAM_REGEX.finditer(sql):
            if match.group("dquote") is not None or match.group("squote") is not None:
                # Placeholders inside quoted literals are left untouched.
                continue

            var_name = match.group("var_name")
            if var_name is None:  # Should not happen with the regex, but safeguard
                continue

            if var_name not in parameters:
                msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary."
                raise ValueError(msg)

            # The placeholder's ':' sits immediately before the name. Copy the SQL
            # only up to that colon so the whole ':name' token is replaced by the
            # driver's positional marker. Copying up to the name would leave a
            # stray ':' behind.
            colon_pos = match.start("var_name") - 1
            sql_parts.append(sql[last_end:colon_pos])
            sql_parts.append(self.param_style)
            ordered_values.append(parameters[var_name])
            used_names.add(var_name)
            last_end = match.end("var_name")

        # Append the rest of the SQL string after the last placeholder.
        sql_parts.append(sql[last_end:])

        # Strict mode: every supplied parameter must be referenced by the SQL.
        unused_params = set(parameters) - used_names
        if unused_params:
            msg = f"Parameters provided but not found in SQL: {unused_params}"
            raise ValueError(msg)

        # Pass the rewritten SQL through the driver-specific processor.
        return self._process_sql_statement("".join(sql_parts)), tuple(ordered_values)
435
+
436
+
437
class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]):
    """Abstract interface for synchronous driver adapters.

    An adapter wraps one connection (``self.connection``). Each query method
    takes ``sql`` and ``parameters`` positionally. It also takes an optional
    ``connection`` (presumably resolved through ``_connection``, so it falls back
    to ``self.connection``; this depends on the concrete adapter). Row-returning
    methods also take an optional ``schema_type``.
    """

    connection: ConnectionT

    def __init__(self, connection: ConnectionT) -> None:
        """Bind the adapter to ``connection``."""
        self.connection = connection

    @abstractmethod
    def select(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Return all result rows, as ``schema_type`` instances or dicts per the annotation."""

    @abstractmethod
    def select_one(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Return a single row; the missing-row behaviour is left to implementations."""

    @abstractmethod
    def select_one_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Return a single row, or ``None`` when there is none."""

    @abstractmethod
    def select_value(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Union[Any, T]":
        """Return a single scalar value, optionally typed as ``schema_type``."""

    @abstractmethod
    def select_value_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Optional[Union[Any, T]]":
        """Return a single scalar value, or ``None`` when there is none."""

    @abstractmethod
    def insert_update_delete(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> int:
        """Execute a data-modifying statement and return an int (presumably the affected row count)."""

    @abstractmethod
    def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a data-modifying statement and return the row it yields, if any."""

    @abstractmethod
    def execute_script(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> str:
        """Execute a SQL script; the meaning of the returned string is adapter-defined."""
520
+
521
+
522
class AsyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]):
    """Abstract interface for asynchronous driver adapters.

    It mirrors the synchronous adapter interface, but every query method is a
    coroutine. Each query method takes ``sql`` and ``parameters`` positionally,
    plus an optional ``connection``. Row-returning methods also take an optional
    ``schema_type``.
    """

    connection: ConnectionT

    def __init__(self, connection: ConnectionT) -> None:
        """Bind the adapter to ``connection``."""
        self.connection = connection

    @abstractmethod
    async def select(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Return all result rows, as ``schema_type`` instances or dicts per the annotation."""

    @abstractmethod
    async def select_one(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Return a single row; the missing-row behaviour is left to implementations."""

    @abstractmethod
    async def select_one_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Return a single row, or ``None`` when there is none."""

    @abstractmethod
    async def select_value(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Union[Any, T]":
        """Return a single scalar value, optionally typed as ``schema_type``."""

    @abstractmethod
    async def select_value_or_none(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[T]] = None,
    ) -> "Optional[Union[Any, T]]":
        """Return a single scalar value, or ``None`` when there is none."""

    @abstractmethod
    async def insert_update_delete(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> int:
        """Execute a data-modifying statement and return an int (presumably the affected row count)."""

    @abstractmethod
    async def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
        schema_type: Optional[type[ModelDTOT]] = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a data-modifying statement and return the row it yields, if any."""

    @abstractmethod
    async def execute_script(
        self,
        sql: str,
        parameters: Optional[StatementParameterType] = None,
        /,
        connection: Optional[ConnectionT] = None,
    ) -> str:
        """Execute a SQL script; the meaning of the returned string is adapter-defined."""


# Either flavour of driver adapter, parameterised by the connection type.
DriverAdapterProtocol = Union[SyncDriverAdapterProtocol[ConnectionT], AsyncDriverAdapterProtocol[ConnectionT]]
sqlspec/extensions/litestar/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from sqlspec.extensions.litestar.config import DatabaseConfig
2
+ from sqlspec.extensions.litestar.handlers import (
3
+ autocommit_handler_maker,
4
+ connection_provider_maker,
5
+ lifespan_handler_maker,
6
+ manual_handler_maker,
7
+ pool_provider_maker,
8
+ )
9
+ from sqlspec.extensions.litestar.plugin import SQLSpec
10
+
11
# Public API of the Litestar integration; every name is re-exported from the
# config, handlers and plugin submodules imported above.
__all__ = (
    "DatabaseConfig",
    "SQLSpec",
    "autocommit_handler_maker",
    "connection_provider_maker",
    "lifespan_handler_maker",
    "manual_handler_maker",
    "pool_provider_maker",
)
sqlspec/extensions/litestar/_utils.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import TYPE_CHECKING, Any
2
+
3
+ if TYPE_CHECKING:
4
+ from litestar.types import Scope
5
+
6
__all__ = (
    "delete_sqlspec_scope_state",
    "get_sqlspec_scope_state",
    "set_sqlspec_scope_state",
)

# Key of the dict, stored inside the connection scope, that holds sqlspec's private values.
_SCOPE_NAMESPACE = "_sqlspec"


def get_sqlspec_scope_state(scope: "Scope", key: str, default: Any = None, pop: bool = False) -> Any:
    """Get an internal value from connection scope state.

    Note:
        The sqlspec namespace dict is created in ``scope`` on first access, but
        ``key`` itself is never written. A missing key returns ``default``, like
        ``dict.get()``. With ``pop=True`` the key is removed and its value is
        returned, like ``dict.pop()``; ``default`` is returned if the key is absent.

    Args:
        scope: The connection scope.
        key: Key to get from internal namespace in scope state.
        default: Value returned when ``key`` is not present.
        pop: Boolean flag dictating whether the value should be deleted from the state.

    Returns:
        Value mapped to ``key`` in internal connection scope namespace, or ``default``.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    if pop:
        return namespace.pop(key, default)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
    return namespace.get(key, default)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]


def set_sqlspec_scope_state(scope: "Scope", key: str, value: Any) -> None:
    """Set an internal value in connection scope state.

    Creates the sqlspec namespace dict in ``scope`` if it does not exist yet.

    Args:
        scope: The connection scope.
        key: Key to set under internal namespace in scope state.
        value: Value for key.
    """
    scope.setdefault(_SCOPE_NAMESPACE, {})[key] = value  # type: ignore[misc]


def delete_sqlspec_scope_state(scope: "Scope", key: str) -> None:
    """Remove an internal value from connection scope state.

    Args:
        scope: The connection scope.
        key: Key to remove from the internal namespace in scope state.

    Raises:
        KeyError: If ``key`` is not present in the namespace.
    """
    del scope.setdefault(_SCOPE_NAMESPACE, {})[key]  # type: ignore[misc]