sqlspec 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (49)
  1. sqlspec/__init__.py +15 -0
  2. sqlspec/_serialization.py +16 -2
  3. sqlspec/_typing.py +1 -1
  4. sqlspec/adapters/adbc/__init__.py +7 -0
  5. sqlspec/adapters/adbc/config.py +160 -17
  6. sqlspec/adapters/adbc/driver.py +333 -0
  7. sqlspec/adapters/aiosqlite/__init__.py +6 -2
  8. sqlspec/adapters/aiosqlite/config.py +25 -7
  9. sqlspec/adapters/aiosqlite/driver.py +275 -0
  10. sqlspec/adapters/asyncmy/__init__.py +7 -2
  11. sqlspec/adapters/asyncmy/config.py +75 -14
  12. sqlspec/adapters/asyncmy/driver.py +255 -0
  13. sqlspec/adapters/asyncpg/__init__.py +9 -0
  14. sqlspec/adapters/asyncpg/config.py +99 -20
  15. sqlspec/adapters/asyncpg/driver.py +288 -0
  16. sqlspec/adapters/duckdb/__init__.py +6 -2
  17. sqlspec/adapters/duckdb/config.py +195 -13
  18. sqlspec/adapters/duckdb/driver.py +225 -0
  19. sqlspec/adapters/oracledb/__init__.py +11 -8
  20. sqlspec/adapters/oracledb/config/__init__.py +6 -6
  21. sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
  22. sqlspec/adapters/oracledb/config/_common.py +1 -1
  23. sqlspec/adapters/oracledb/config/_sync.py +99 -14
  24. sqlspec/adapters/oracledb/driver.py +498 -0
  25. sqlspec/adapters/psycopg/__init__.py +11 -0
  26. sqlspec/adapters/psycopg/config/__init__.py +6 -6
  27. sqlspec/adapters/psycopg/config/_async.py +105 -13
  28. sqlspec/adapters/psycopg/config/_common.py +2 -2
  29. sqlspec/adapters/psycopg/config/_sync.py +105 -13
  30. sqlspec/adapters/psycopg/driver.py +616 -0
  31. sqlspec/adapters/sqlite/__init__.py +7 -0
  32. sqlspec/adapters/sqlite/config.py +25 -7
  33. sqlspec/adapters/sqlite/driver.py +303 -0
  34. sqlspec/base.py +416 -36
  35. sqlspec/extensions/litestar/__init__.py +19 -0
  36. sqlspec/extensions/litestar/_utils.py +56 -0
  37. sqlspec/extensions/litestar/config.py +81 -0
  38. sqlspec/extensions/litestar/handlers.py +188 -0
  39. sqlspec/extensions/litestar/plugin.py +100 -11
  40. sqlspec/typing.py +72 -17
  41. sqlspec/utils/__init__.py +3 -0
  42. sqlspec/utils/fixtures.py +4 -5
  43. sqlspec/utils/sync_tools.py +335 -0
  44. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
  45. sqlspec-0.8.0.dist-info/RECORD +57 -0
  46. sqlspec-0.7.1.dist-info/RECORD +0 -46
  47. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
  48. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
  49. {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,225 @@
1
+ from contextlib import contextmanager
2
+ from typing import TYPE_CHECKING, Any, Optional, Union, cast
3
+
4
+ from sqlspec.base import SyncDriverAdapterProtocol, T
5
+
6
+ if TYPE_CHECKING:
7
+ from collections.abc import Generator
8
+
9
+ from duckdb import DuckDBPyConnection
10
+
11
+ from sqlspec.typing import ModelDTOT, StatementParameterType
12
+
13
+ __all__ = ("DuckDBDriver",)
14
+
15
+
16
class DuckDBDriver(SyncDriverAdapterProtocol["DuckDBPyConnection"]):
    """DuckDB Sync Driver Adapter.

    Runs statements on a ``DuckDBPyConnection`` and maps result rows to dicts, or to
    ``schema_type`` instances when one is given. Statements written with ``:name``
    placeholders plus a dict of parameters are rewritten to DuckDB's positional ``?``
    style by :meth:`_process_sql_params`.
    """

    connection: "DuckDBPyConnection"
    use_cursor: bool = True
    # param_style is inherited from CommonDriverAttributes

    def __init__(self, connection: "DuckDBPyConnection", use_cursor: bool = True) -> None:
        """Initialize the driver.

        Args:
            connection: Connection used when a method is called without one.
            use_cursor: When True, each statement runs on a fresh ``connection.cursor()``
                that is closed afterwards; when False, statements run directly on the
                connection object.
        """
        self.connection = connection
        self.use_cursor = use_cursor

    # --- Helper Methods --- #
    def _cursor(self, connection: "DuckDBPyConnection") -> "DuckDBPyConnection":
        """Return a new cursor when ``use_cursor`` is set, otherwise the connection itself."""
        if self.use_cursor:
            # DuckDB cursors are themselves DuckDBPyConnection objects (hence the return type).
            return connection.cursor()
        return connection

    @contextmanager
    def _with_cursor(self, connection: "DuckDBPyConnection") -> "Generator[DuckDBPyConnection, None, None]":
        """Yield the object to execute on, closing it afterwards only if it is a dedicated cursor."""
        if self.use_cursor:
            cursor = self._cursor(connection)
            try:
                yield cursor
            finally:
                cursor.close()
        else:
            # The connection is owned by the caller/config, so it is never closed here.
            yield connection

    # --- Public API Methods --- #

    def select(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["DuckDBPyConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "list[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch all rows.

        Returns:
            A list of dicts keyed by column name, or of ``schema_type(**row)`` instances;
            an empty list when the query returns no rows.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            results = cursor.fetchall()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if not results:
                return []

            column_names = [col[0] for col in cursor.description or []]  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]

            if schema_type is not None:
                return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results]  # pyright: ignore[reportUnknownArgumentType]
            return [dict(zip(column_names, row)) for row in results]  # pyright: ignore[reportUnknownArgumentType]

    def select_one(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["DuckDBPyConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelDTOT, dict[str, Any]]":
        """Fetch the first row; a missing row is handled by ``check_not_found``.

        Returns:
            The row as a dict keyed by column name, or as a ``schema_type`` instance.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            result = self.check_not_found(result)  # pyright: ignore

            column_names = [col[0] for col in cursor.description or []]  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            if schema_type is not None:
                return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))  # pyright: ignore[reportUnknownArgumentType]
            return dict(zip(column_names, result))  # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType]

    def select_one_or_none(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["DuckDBPyConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
        """Fetch the first row, or ``None`` when the query returns no rows."""
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if result is None:
                return None

            column_names = [col[0] for col in cursor.description or []]  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            if schema_type is not None:
                return cast("ModelDTOT", schema_type(**dict(zip(column_names, result))))  # pyright: ignore[reportUnknownArgumentType]
            return dict(zip(column_names, result))  # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType]

    def select_value(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[DuckDBPyConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Union[T, Any]":
        """Fetch the first column of the first row, optionally converted with ``schema_type``."""
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)

        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            result = self.check_not_found(result)  # pyright: ignore
            if schema_type is None:
                return result[0]  # pyright: ignore
            return schema_type(result[0])  # type: ignore[call-arg]

    def select_value_or_none(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        /,
        connection: "Optional[DuckDBPyConnection]" = None,
        schema_type: "Optional[type[T]]" = None,
    ) -> "Optional[Union[T, Any]]":
        """Fetch the first column of the first row, or ``None`` when there is no row."""
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchone()  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            if result is None:
                return None
            if schema_type is None:
                return result[0]  # pyright: ignore
            return schema_type(result[0])  # type: ignore[call-arg]

    def insert_update_delete(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["DuckDBPyConnection"] = None,
    ) -> int:
        """Execute a DML statement.

        Returns:
            The cursor's ``rowcount`` attribute if it exists, otherwise -1.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            return getattr(cursor, "rowcount", -1)  # pyright: ignore[reportUnknownMemberType]

    def insert_update_delete_returning(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["DuckDBPyConnection"] = None,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
        """Execute a DML statement with a RETURNING clause and map the first returned row.

        Returns:
            The first returned row as a dict or ``schema_type`` instance, or ``None``
            when nothing was returned.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            result = cursor.fetchall()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            if not result:
                return None  # pyright: ignore[reportUnknownArgumentType]
            column_names = [col[0] for col in cursor.description or []]  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            if schema_type is not None:
                return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0]))))  # pyright: ignore[reportUnknownArgumentType]
            return dict(zip(column_names, result[0]))  # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType]

    def _process_sql_params(
        self, sql: str, parameters: "Optional[StatementParameterType]" = None
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Process SQL query and parameters for DB-API execution.

        Converts named parameters (``:name``) to positional parameters (``?``) for DuckDB.
        Placeholders are found with a tokenizing regex rather than ``str.replace`` so that:

        * values are bound in the order the placeholders appear in the SQL, not in the
          dict's insertion order;
        * a name that is a prefix of another (``:id`` vs ``:id_2``) is never partially
          rewritten;
        * a name used more than once has its value bound once per occurrence;
        * ``::`` type casts and text inside quoted literals/identifiers are left alone.

        Placeholders whose name is not a key of ``parameters`` are left untouched. If no
        placeholder is rewritten, the dict's values are passed positionally in insertion
        order (the previous behavior).

        Args:
            sql: The SQL query string.
            parameters: The parameters for the query (dict, tuple, list, or None).

        Returns:
            A tuple containing the processed SQL string and the processed parameters.
        """
        if not isinstance(parameters, dict) or not parameters:
            # Not a dict (or an empty one): let the driver handle tuples/lists/None directly.
            return sql, parameters

        import re

        named: "dict[str, Any]" = parameters
        placeholder_re = re.compile(
            r"'(?:[^']|'')*'"  # single-quoted string literal ('' is an escaped quote)
            r'|"(?:[^"]|"")*"'  # double-quoted identifier
            r"|(?<![:\w]):([A-Za-z_]\w*)"  # :name, but not the second half of a :: cast
        )
        ordered_values: "list[Any]" = []

        def _substitute(match: "re.Match[str]") -> str:
            name = match.group(1)
            if name is None or name not in named:
                # Quoted text or an unknown name: keep exactly as written.
                return match.group(0)
            ordered_values.append(named[name])
            return "?"

        processed_sql = placeholder_re.sub(_substitute, sql)
        if not ordered_values:
            return sql, tuple(named.values())
        return processed_sql, tuple(ordered_values)

    def execute_script(
        self,
        sql: str,
        parameters: Optional["StatementParameterType"] = None,
        /,
        connection: Optional["DuckDBPyConnection"] = None,
    ) -> str:
        """Execute a statement without fetching results.

        Returns:
            The cursor's ``statusmessage`` attribute if it exists, otherwise ``"DONE"``.
        """
        connection = self._connection(connection)
        sql, parameters = self._process_sql_params(sql, parameters)
        with self._with_cursor(connection) as cursor:
            cursor.execute(sql, parameters)  # pyright: ignore[reportUnknownMemberType]
            return cast("str", getattr(cursor, "statusmessage", "DONE"))  # pyright: ignore[reportUnknownMemberType]
@@ -1,13 +1,16 @@
1
1
  from sqlspec.adapters.oracledb.config import (
2
- OracleAsyncDatabaseConfig,
3
- OracleAsyncPoolConfig,
4
- OracleSyncDatabaseConfig,
5
- OracleSyncPoolConfig,
2
+ OracleAsync,
3
+ OracleAsyncPool,
4
+ OracleSync,
5
+ OracleSyncPool,
6
6
  )
7
+ from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver
7
8
 
8
9
  __all__ = (
9
- "OracleAsyncDatabaseConfig",
10
- "OracleAsyncPoolConfig",
11
- "OracleSyncDatabaseConfig",
12
- "OracleSyncPoolConfig",
10
+ "OracleAsync",
11
+ "OracleAsyncDriver",
12
+ "OracleAsyncPool",
13
+ "OracleSync",
14
+ "OracleSyncDriver",
15
+ "OracleSyncPool",
13
16
  )
@@ -1,9 +1,9 @@
1
- from sqlspec.adapters.oracledb.config._asyncio import OracleAsyncDatabaseConfig, OracleAsyncPoolConfig
2
- from sqlspec.adapters.oracledb.config._sync import OracleSyncDatabaseConfig, OracleSyncPoolConfig
1
+ from sqlspec.adapters.oracledb.config._asyncio import OracleAsync, OracleAsyncPool
2
+ from sqlspec.adapters.oracledb.config._sync import OracleSync, OracleSyncPool
3
3
 
4
4
  __all__ = (
5
- "OracleAsyncDatabaseConfig",
6
- "OracleAsyncPoolConfig",
7
- "OracleSyncDatabaseConfig",
8
- "OracleSyncPoolConfig",
5
+ "OracleAsync",
6
+ "OracleAsyncPool",
7
+ "OracleSync",
8
+ "OracleSyncPool",
9
9
  )
@@ -1,14 +1,12 @@
1
1
  from contextlib import asynccontextmanager
2
- from dataclasses import dataclass
2
+ from dataclasses import dataclass, field
3
3
  from typing import TYPE_CHECKING, Any, Optional
4
4
 
5
5
  from oracledb import create_pool_async as oracledb_create_pool # pyright: ignore[reportUnknownVariableType]
6
6
  from oracledb.connection import AsyncConnection
7
- from oracledb.pool import AsyncConnectionPool
8
7
 
9
- from sqlspec.adapters.oracledb.config._common import (
10
- OracleGenericPoolConfig,
11
- )
8
+ from sqlspec.adapters.oracledb.config._common import OracleGenericPoolConfig
9
+ from sqlspec.adapters.oracledb.driver import OracleAsyncDriver
12
10
  from sqlspec.base import AsyncDatabaseConfig
13
11
  from sqlspec.exceptions import ImproperConfigurationError
14
12
  from sqlspec.typing import dataclass_to_dict
@@ -16,20 +14,22 @@ from sqlspec.typing import dataclass_to_dict
16
14
  if TYPE_CHECKING:
17
15
  from collections.abc import AsyncGenerator, Awaitable
18
16
 
17
+ from oracledb.pool import AsyncConnectionPool
18
+
19
19
 
20
20
  __all__ = (
21
- "OracleAsyncDatabaseConfig",
22
- "OracleAsyncPoolConfig",
21
+ "OracleAsync",
22
+ "OracleAsyncPool",
23
23
  )
24
24
 
25
25
 
26
26
@dataclass
class OracleAsyncPool(OracleGenericPoolConfig["AsyncConnection", "AsyncConnectionPool"]):
    """Pool configuration for asynchronous python-oracledb connections.

    Binds :class:`OracleGenericPoolConfig` to ``AsyncConnection`` (connection type)
    and ``AsyncConnectionPool`` (pool type); it declares no fields of its own.
    """
29
29
 
30
30
 
31
31
  @dataclass
32
- class OracleAsyncDatabaseConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool]):
32
+ class OracleAsync(AsyncDatabaseConfig["AsyncConnection", "AsyncConnectionPool", "OracleAsyncDriver"]):
33
33
  """Oracle Async database Configuration.
34
34
 
35
35
  This class provides the base configuration for Oracle database connections, extending
@@ -42,30 +42,99 @@ class OracleAsyncDatabaseConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnec
42
42
  options.([2](https://python-oracledb.readthedocs.io/en/latest/user_guide/tuning.html))
43
43
  """
44
44
 
45
- pool_config: "Optional[OracleAsyncPoolConfig]" = None
45
+ pool_config: "Optional[OracleAsyncPool]" = None
46
46
  """Oracle Pool configuration"""
47
47
  pool_instance: "Optional[AsyncConnectionPool]" = None
48
48
  """Optional pool to use.
49
49
 
50
50
  If set, the plugin will use the provided pool rather than instantiate one.
51
51
  """
52
+ connection_type: "type[AsyncConnection]" = field(init=False, default_factory=lambda: AsyncConnection)
53
+ """Connection class to use.
54
+
55
+ Defaults to :class:`AsyncConnection`.
56
+ """
57
+ driver_type: "type[OracleAsyncDriver]" = field(init=False, default_factory=lambda: OracleAsyncDriver) # type: ignore[type-abstract,unused-ignore]
58
+ """Driver class to use.
59
+
60
+ Defaults to :class:`OracleAsyncDriver`.
61
+ """
62
+
63
@property
def connection_config_dict(self) -> "dict[str, Any]":
    """Return the connection configuration as a dict.

    The pool configuration is reused for standalone connections, minus options that
    only oracledb's pool constructor accepts. Forwarding those to
    ``oracledb.connect_async`` would make the call fail.

    Returns:
        A string keyed dict of config kwargs for the ``oracledb.connect_async`` function.

    Raises:
        ImproperConfigurationError: If no ``pool_config`` is provided.
    """
    if self.pool_config is not None:
        # Keyword arguments understood by oracledb's create_pool / create_pool_async
        # but not by connect / connect_async.
        pool_only_params = {
            "min",
            "max",
            "increment",
            "connectiontype",
            "getmode",
            "homogeneous",
            "timeout",
            "wait_timeout",
            "max_lifetime_session",
            "session_callback",
            "max_sessions_per_shard",
            "soda_metadata_cache",
            "ping_interval",
            "ping_timeout",
        }
        return dataclass_to_dict(
            self.pool_config,
            exclude_empty=True,
            convert_nested=False,
            exclude=pool_only_params | {"pool_instance", "connection_type", "driver_type"},
        )
    msg = "You must provide a 'pool_config' for this adapter."
    raise ImproperConfigurationError(msg)
52
92
 
53
93
@property
def pool_config_dict(self) -> "dict[str, Any]":
    """Return the pool configuration as a dict.

    Returns:
        A string keyed dict of config kwargs for the python-oracledb
        :func:`create_pool_async <oracledb.create_pool_async>` function.

    Raises:
        ImproperConfigurationError: If ``pool_config`` is ``None`` (for example when
            only a ``pool_instance`` was supplied).
    """
    if self.pool_config is not None:
        return dataclass_to_dict(
            self.pool_config,
            exclude_empty=True,
            convert_nested=False,
            exclude={"pool_instance", "connection_type", "driver_type"},
        )
    msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
    raise ImproperConfigurationError(msg)
65
113
 
114
async def create_connection(self) -> "AsyncConnection":
    """Open a new standalone (non-pooled) oracledb async connection.

    Connection keyword arguments are taken from ``connection_config_dict``.

    Returns:
        An AsyncConnection instance.

    Raises:
        ImproperConfigurationError: Wrapping any exception raised while importing
            oracledb, building the connection kwargs, or connecting.
    """
    try:
        import oracledb

        connect_kwargs = self.connection_config_dict
        new_connection = await oracledb.connect_async(**connect_kwargs)
    except Exception as exc:
        msg = f"Could not configure the Oracle async connection. Error: {exc!s}"
        raise ImproperConfigurationError(msg) from exc
    return new_connection  # type: ignore[no-any-return]
130
+
66
131
  async def create_pool(self) -> "AsyncConnectionPool":
67
132
  """Return a pool. If none exists yet, create one.
68
133
 
134
+ Raises:
135
+ ImproperConfigurationError: If neither pool_config nor pool_instance are provided,
136
+ or if the pool could not be configured.
137
+
69
138
  Returns:
70
139
  Getter that returns the pool instance used by the plugin.
71
140
  """
@@ -95,9 +164,25 @@ class OracleAsyncDatabaseConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnec
95
164
async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AsyncConnection, None]":
    """Borrow a connection from the pool for the duration of the context.

    Args:
        *args: Forwarded to ``provide_pool``.
        **kwargs: Forwarded to ``provide_pool``.

    Yields:
        AsyncConnection: A connection obtained through the pool's ``acquire()``
        async context manager, which is exited when this context ends.
    """
    pool = await self.provide_pool(*args, **kwargs)
    async with pool.acquire() as pooled_connection:  # pyright: ignore[reportUnknownMemberType]
        yield pooled_connection
173
+
174
@asynccontextmanager
async def provide_session(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[OracleAsyncDriver, None]":
    """Provide a driver bound to a pooled connection.

    Args:
        *args: Forwarded to ``provide_connection``.
        **kwargs: Forwarded to ``provide_connection``.

    Yields:
        OracleAsyncDriver: An instance of ``driver_type`` wrapping the borrowed connection.
    """
    async with self.provide_connection(*args, **kwargs) as pooled_connection:
        session = self.driver_type(pooled_connection)
        yield session
183
+
184
async def close_pool(self) -> None:
    """Close the pool held in ``pool_instance`` (if any) and clear the reference."""
    pool = self.pool_instance
    if pool is None:
        return
    await pool.close()
    self.pool_instance = None
@@ -27,7 +27,7 @@ PoolT = TypeVar("PoolT", bound="Union[ConnectionPool, AsyncConnectionPool]")
27
27
 
28
28
 
29
29
  @dataclass
30
- class OracleGenericPoolConfig(Generic[ConnectionT, PoolT], GenericPoolConfig):
30
+ class OracleGenericPoolConfig(GenericPoolConfig, Generic[ConnectionT, PoolT]):
31
31
  """Configuration for Oracle database connection pools.
32
32
 
33
33
  This class provides configuration options for both synchronous and asynchronous Oracle
@@ -1,35 +1,35 @@
1
1
  from contextlib import contextmanager
2
- from dataclasses import dataclass
2
+ from dataclasses import dataclass, field
3
3
  from typing import TYPE_CHECKING, Any, Optional
4
4
 
5
5
  from oracledb import create_pool as oracledb_create_pool # pyright: ignore[reportUnknownVariableType]
6
6
  from oracledb.connection import Connection
7
- from oracledb.pool import ConnectionPool
8
7
 
9
- from sqlspec.adapters.oracledb.config._common import (
10
- OracleGenericPoolConfig,
11
- )
8
+ from sqlspec.adapters.oracledb.config._common import OracleGenericPoolConfig
9
+ from sqlspec.adapters.oracledb.driver import OracleSyncDriver
12
10
  from sqlspec.base import SyncDatabaseConfig
13
11
  from sqlspec.exceptions import ImproperConfigurationError
14
12
  from sqlspec.typing import dataclass_to_dict
15
13
 
16
14
  if TYPE_CHECKING:
17
15
  from collections.abc import Generator
18
- from typing import Any
16
+
17
+ from oracledb.pool import ConnectionPool
18
+
19
19
 
20
20
  __all__ = (
21
- "OracleSyncDatabaseConfig",
22
- "OracleSyncPoolConfig",
21
+ "OracleSync",
22
+ "OracleSyncPool",
23
23
  )
24
24
 
25
25
 
26
26
@dataclass
class OracleSyncPool(OracleGenericPoolConfig["Connection", "ConnectionPool"]):
    """Pool configuration for synchronous python-oracledb connections.

    Binds :class:`OracleGenericPoolConfig` to ``Connection`` (connection type) and
    ``ConnectionPool`` (pool type); it declares no fields of its own.
    """
29
29
 
30
30
 
31
31
  @dataclass
32
- class OracleSyncDatabaseConfig(SyncDatabaseConfig[Connection, ConnectionPool]):
32
+ class OracleSync(SyncDatabaseConfig["Connection", "ConnectionPool", "OracleSyncDriver"]):
33
33
  """Oracle Sync database Configuration.
34
34
 
35
35
  This class provides the base configuration for Oracle database connections, extending
@@ -42,30 +42,99 @@ class OracleSyncDatabaseConfig(SyncDatabaseConfig[Connection, ConnectionPool]):
42
42
  options.([2](https://python-oracledb.readthedocs.io/en/latest/user_guide/tuning.html))
43
43
  """
44
44
 
45
- pool_config: "Optional[OracleSyncPoolConfig]" = None
45
+ pool_config: "Optional[OracleSyncPool]" = None
46
46
  """Oracle Pool configuration"""
47
47
  pool_instance: "Optional[ConnectionPool]" = None
48
48
  """Optional pool to use.
49
49
 
50
50
  If set, the plugin will use the provided pool rather than instantiate one.
51
51
  """
52
+ connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) # pyright: ignore
53
+ """Connection class to use.
54
+
55
+ Defaults to :class:`Connection`.
56
+ """
57
+ driver_type: "type[OracleSyncDriver]" = field(init=False, default_factory=lambda: OracleSyncDriver) # type: ignore[type-abstract,unused-ignore]
58
+ """Driver class to use.
59
+
60
+ Defaults to :class:`OracleSyncDriver`.
61
+ """
62
+
63
@property
def connection_config_dict(self) -> "dict[str, Any]":
    """Return the connection configuration as a dict.

    The pool configuration is reused for standalone connections, minus options that
    only oracledb's pool constructor accepts. Forwarding those to ``oracledb.connect``
    would make the call fail.

    Returns:
        A string keyed dict of config kwargs for the ``oracledb.connect`` function.

    Raises:
        ImproperConfigurationError: If no ``pool_config`` is provided.
    """
    if self.pool_config is not None:
        # Keyword arguments understood by oracledb's create_pool but not by connect.
        pool_only_params = {
            "min",
            "max",
            "increment",
            "connectiontype",
            "getmode",
            "homogeneous",
            "timeout",
            "wait_timeout",
            "max_lifetime_session",
            "session_callback",
            "max_sessions_per_shard",
            "soda_metadata_cache",
            "ping_interval",
            "ping_timeout",
        }
        return dataclass_to_dict(
            self.pool_config,
            exclude_empty=True,
            convert_nested=False,
            exclude=pool_only_params | {"pool_instance", "connection_type", "driver_type"},
        )
    msg = "You must provide a 'pool_config' for this adapter."
    raise ImproperConfigurationError(msg)
52
92
 
53
93
@property
def pool_config_dict(self) -> "dict[str, Any]":
    """Return the pool configuration as a dict.

    Returns:
        A string keyed dict of config kwargs for the python-oracledb
        :func:`create_pool <oracledb.create_pool>` function.

    Raises:
        ImproperConfigurationError: If ``pool_config`` is ``None`` (for example when
            only a ``pool_instance`` was supplied).
    """
    if self.pool_config is not None:
        return dataclass_to_dict(
            self.pool_config,
            exclude_empty=True,
            convert_nested=False,
            exclude={"pool_instance", "connection_type", "driver_type"},
        )
    msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
    raise ImproperConfigurationError(msg)
65
113
 
114
def create_connection(self) -> "Connection":
    """Open a new standalone (non-pooled) oracledb connection.

    Connection keyword arguments are taken from ``connection_config_dict``.

    Returns:
        A Connection instance.

    Raises:
        ImproperConfigurationError: Wrapping any exception raised while importing
            oracledb, building the connection kwargs, or connecting.
    """
    try:
        import oracledb

        connect_kwargs = self.connection_config_dict
        new_connection = oracledb.connect(**connect_kwargs)
    except Exception as exc:
        msg = f"Could not configure the Oracle connection. Error: {exc!s}"
        raise ImproperConfigurationError(msg) from exc
    return new_connection
130
+
66
131
  def create_pool(self) -> "ConnectionPool":
67
132
  """Return a pool. If none exists yet, create one.
68
133
 
134
+ Raises:
135
+ ImproperConfigurationError: If neither pool_config nor pool_instance is provided,
136
+ or if the pool could not be configured.
137
+
69
138
  Returns:
70
139
  Getter that returns the pool instance used by the plugin.
71
140
  """
@@ -95,9 +164,25 @@ class OracleSyncDatabaseConfig(SyncDatabaseConfig[Connection, ConnectionPool]):
95
164
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]":
    """Borrow a connection from the pool for the duration of the context.

    Args:
        *args: Forwarded to ``provide_pool``.
        **kwargs: Forwarded to ``provide_pool``.

    Yields:
        Connection: A connection obtained through the pool's ``acquire()`` context
        manager, which is exited when this context ends.
    """
    pool = self.provide_pool(*args, **kwargs)
    with pool.acquire() as pooled_connection:  # pyright: ignore[reportUnknownMemberType]
        yield pooled_connection
173
+
174
@contextmanager
def provide_session(self, *args: "Any", **kwargs: "Any") -> "Generator[OracleSyncDriver, None, None]":
    """Provide a driver bound to a pooled connection.

    Args:
        *args: Forwarded to ``provide_connection``.
        **kwargs: Forwarded to ``provide_connection``.

    Yields:
        OracleSyncDriver: An instance of ``driver_type`` wrapping the borrowed connection.
    """
    with self.provide_connection(*args, **kwargs) as pooled_connection:
        session = self.driver_type(pooled_connection)
        yield session
183
+
184
def close_pool(self) -> None:
    """Close the pool held in ``pool_instance`` (if any) and clear the reference."""
    pool = self.pool_instance
    if pool is None:
        return
    pool.close()
    self.pool_instance = None