sqlspec 0.14.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. See the registry's release page for more details.

Files changed (159)
  1. sqlspec/__init__.py +50 -25
  2. sqlspec/__main__.py +1 -1
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +480 -121
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +115 -260
  10. sqlspec/adapters/adbc/driver.py +462 -367
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +199 -129
  14. sqlspec/adapters/aiosqlite/driver.py +230 -269
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -168
  18. sqlspec/adapters/asyncmy/driver.py +260 -225
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +82 -181
  22. sqlspec/adapters/asyncpg/driver.py +285 -383
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -258
  26. sqlspec/adapters/bigquery/driver.py +474 -646
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +415 -351
  30. sqlspec/adapters/duckdb/driver.py +343 -413
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -379
  34. sqlspec/adapters/oracledb/driver.py +507 -560
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -254
  38. sqlspec/adapters/psqlpy/driver.py +505 -234
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -403
  42. sqlspec/adapters/psycopg/driver.py +706 -872
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +202 -118
  46. sqlspec/adapters/sqlite/driver.py +264 -303
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder → builder}/_base.py +120 -55
  50. sqlspec/{statement/builder → builder}/_column.py +17 -6
  51. sqlspec/{statement/builder → builder}/_ddl.py +46 -79
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
  53. sqlspec/{statement/builder → builder}/_delete.py +6 -25
  54. sqlspec/{statement/builder → builder}/_insert.py +18 -65
  55. sqlspec/builder/_merge.py +56 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +8 -11
  57. sqlspec/{statement/builder → builder}/_select.py +11 -56
  58. sqlspec/{statement/builder → builder}/_update.py +12 -18
  59. sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
  60. sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
  61. sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +34 -18
  62. sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
  63. sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +19 -9
  64. sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
  65. sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
  66. sqlspec/{statement/builder → builder}/mixins/_select_operations.py +25 -38
  67. sqlspec/{statement/builder → builder}/mixins/_update_operations.py +15 -16
  68. sqlspec/{statement/builder → builder}/mixins/_where_clause.py +210 -137
  69. sqlspec/cli.py +4 -5
  70. sqlspec/config.py +180 -133
  71. sqlspec/core/__init__.py +63 -0
  72. sqlspec/core/cache.py +873 -0
  73. sqlspec/core/compiler.py +396 -0
  74. sqlspec/core/filters.py +830 -0
  75. sqlspec/core/hashing.py +310 -0
  76. sqlspec/core/parameters.py +1209 -0
  77. sqlspec/core/result.py +664 -0
  78. sqlspec/{statement → core}/splitter.py +321 -191
  79. sqlspec/core/statement.py +666 -0
  80. sqlspec/driver/__init__.py +7 -10
  81. sqlspec/driver/_async.py +387 -176
  82. sqlspec/driver/_common.py +527 -289
  83. sqlspec/driver/_sync.py +390 -172
  84. sqlspec/driver/mixins/__init__.py +2 -19
  85. sqlspec/driver/mixins/_result_tools.py +164 -0
  86. sqlspec/driver/mixins/_sql_translator.py +6 -3
  87. sqlspec/exceptions.py +5 -252
  88. sqlspec/extensions/aiosql/adapter.py +93 -96
  89. sqlspec/extensions/litestar/cli.py +1 -1
  90. sqlspec/extensions/litestar/config.py +0 -1
  91. sqlspec/extensions/litestar/handlers.py +15 -26
  92. sqlspec/extensions/litestar/plugin.py +18 -16
  93. sqlspec/extensions/litestar/providers.py +17 -52
  94. sqlspec/loader.py +424 -105
  95. sqlspec/migrations/__init__.py +12 -0
  96. sqlspec/migrations/base.py +92 -68
  97. sqlspec/migrations/commands.py +24 -106
  98. sqlspec/migrations/loaders.py +402 -0
  99. sqlspec/migrations/runner.py +49 -51
  100. sqlspec/migrations/tracker.py +31 -44
  101. sqlspec/migrations/utils.py +64 -24
  102. sqlspec/protocols.py +7 -183
  103. sqlspec/storage/__init__.py +1 -1
  104. sqlspec/storage/backends/base.py +37 -40
  105. sqlspec/storage/backends/fsspec.py +136 -112
  106. sqlspec/storage/backends/obstore.py +138 -160
  107. sqlspec/storage/capabilities.py +5 -4
  108. sqlspec/storage/registry.py +57 -106
  109. sqlspec/typing.py +136 -115
  110. sqlspec/utils/__init__.py +2 -3
  111. sqlspec/utils/correlation.py +0 -3
  112. sqlspec/utils/deprecation.py +6 -6
  113. sqlspec/utils/fixtures.py +6 -6
  114. sqlspec/utils/logging.py +0 -2
  115. sqlspec/utils/module_loader.py +7 -12
  116. sqlspec/utils/singleton.py +0 -1
  117. sqlspec/utils/sync_tools.py +17 -38
  118. sqlspec/utils/text.py +12 -51
  119. sqlspec/utils/type_guards.py +443 -232
  120. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/METADATA +7 -2
  121. sqlspec-0.16.0.dist-info/RECORD +134 -0
  122. sqlspec/adapters/adbc/transformers.py +0 -108
  123. sqlspec/driver/connection.py +0 -207
  124. sqlspec/driver/mixins/_cache.py +0 -114
  125. sqlspec/driver/mixins/_csv_writer.py +0 -91
  126. sqlspec/driver/mixins/_pipeline.py +0 -508
  127. sqlspec/driver/mixins/_query_tools.py +0 -796
  128. sqlspec/driver/mixins/_result_utils.py +0 -138
  129. sqlspec/driver/mixins/_storage.py +0 -912
  130. sqlspec/driver/mixins/_type_coercion.py +0 -128
  131. sqlspec/driver/parameters.py +0 -138
  132. sqlspec/statement/__init__.py +0 -21
  133. sqlspec/statement/builder/_merge.py +0 -95
  134. sqlspec/statement/cache.py +0 -50
  135. sqlspec/statement/filters.py +0 -625
  136. sqlspec/statement/parameters.py +0 -956
  137. sqlspec/statement/pipelines/__init__.py +0 -210
  138. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  139. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  140. sqlspec/statement/pipelines/context.py +0 -109
  141. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  142. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  143. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  144. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  145. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  146. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  147. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  148. sqlspec/statement/pipelines/validators/_performance.py +0 -714
  149. sqlspec/statement/pipelines/validators/_security.py +0 -967
  150. sqlspec/statement/result.py +0 -435
  151. sqlspec/statement/sql.py +0 -1774
  152. sqlspec/utils/cached_property.py +0 -25
  153. sqlspec/utils/statement_hashing.py +0 -203
  154. sqlspec-0.14.1.dist-info/RECORD +0 -145
  155. sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
  156. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/WHEEL +0 -0
  157. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/entry_points.txt +0 -0
  158. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/LICENSE +0 -0
  159. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,32 +1,34 @@
1
1
  """AsyncPG database configuration with direct field-based configuration."""
2
2
 
3
3
  import logging
4
- from collections.abc import AsyncGenerator, Awaitable, Callable
4
+ from collections.abc import Callable
5
5
  from contextlib import asynccontextmanager
6
- from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
6
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypedDict, Union
7
7
 
8
8
  from asyncpg import Connection, Record
9
9
  from asyncpg import create_pool as asyncpg_create_pool
10
10
  from asyncpg.connection import ConnectionMeta
11
11
  from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
12
- from typing_extensions import NotRequired, Unpack
12
+ from typing_extensions import NotRequired
13
13
 
14
- from sqlspec.adapters.asyncpg.driver import AsyncpgConnection, AsyncpgDriver
14
+ from sqlspec.adapters.asyncpg._types import AsyncpgConnection
15
+ from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, asyncpg_statement_config
15
16
  from sqlspec.config import AsyncDatabaseConfig
16
- from sqlspec.statement.sql import SQLConfig
17
- from sqlspec.typing import DictRow, Empty
18
17
  from sqlspec.utils.serializers import from_json, to_json
19
18
 
20
19
  if TYPE_CHECKING:
21
20
  from asyncio.events import AbstractEventLoop
21
+ from collections.abc import AsyncGenerator, Awaitable
22
22
 
23
+ from sqlspec.core.statement import StatementConfig
23
24
 
24
- __all__ = ("CONNECTION_FIELDS", "POOL_FIELDS", "AsyncpgConfig")
25
+
26
+ __all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig")
25
27
 
26
28
  logger = logging.getLogger("sqlspec")
27
29
 
28
30
 
29
- class AsyncpgConnectionParams(TypedDict, total=False):
31
+ class AsyncpgConnectionConfig(TypedDict, total=False):
30
32
  """TypedDict for AsyncPG connection parameters."""
31
33
 
32
34
  dsn: NotRequired[str]
@@ -35,7 +37,7 @@ class AsyncpgConnectionParams(TypedDict, total=False):
35
37
  user: NotRequired[str]
36
38
  password: NotRequired[str]
37
39
  database: NotRequired[str]
38
- ssl: NotRequired[Any] # Can be bool, SSLContext, or specific string
40
+ ssl: NotRequired[Any]
39
41
  passfile: NotRequired[str]
40
42
  direct_tls: NotRequired[bool]
41
43
  connect_timeout: NotRequired[float]
@@ -46,7 +48,7 @@ class AsyncpgConnectionParams(TypedDict, total=False):
46
48
  server_settings: NotRequired[dict[str, str]]
47
49
 
48
50
 
49
- class AsyncpgPoolParams(AsyncpgConnectionParams, total=False):
51
+ class AsyncpgPoolConfig(AsyncpgConnectionConfig, total=False):
50
52
  """TypedDict for AsyncPG pool parameters, inheriting connection parameters."""
51
53
 
52
54
  min_size: NotRequired[int]
@@ -58,189 +60,92 @@ class AsyncpgPoolParams(AsyncpgConnectionParams, total=False):
58
60
  loop: NotRequired["AbstractEventLoop"]
59
61
  connection_class: NotRequired[type["AsyncpgConnection"]]
60
62
  record_class: NotRequired[type[Record]]
63
+ extra: NotRequired[dict[str, Any]]
61
64
 
62
65
 
63
- class DriverParameters(AsyncpgPoolParams, total=False):
64
- """TypedDict for additional parameters that can be passed to AsyncPG."""
66
+ class AsyncpgDriverFeatures(TypedDict, total=False):
67
+ """TypedDict for AsyncPG driver features configuration."""
65
68
 
66
- statement_config: NotRequired[SQLConfig]
67
- default_row_type: NotRequired[type[DictRow]]
68
69
  json_serializer: NotRequired[Callable[[Any], str]]
69
70
  json_deserializer: NotRequired[Callable[[str], Any]]
70
- pool_instance: NotRequired["Pool[Record]"]
71
- extras: NotRequired[dict[str, Any]]
72
-
73
-
74
- CONNECTION_FIELDS = {
75
- "dsn",
76
- "host",
77
- "port",
78
- "user",
79
- "password",
80
- "database",
81
- "ssl",
82
- "passfile",
83
- "direct_tls",
84
- "connect_timeout",
85
- "command_timeout",
86
- "statement_cache_size",
87
- "max_cached_statement_lifetime",
88
- "max_cacheable_statement_size",
89
- "server_settings",
90
- }
91
- POOL_FIELDS = CONNECTION_FIELDS.union(
92
- {
93
- "min_size",
94
- "max_size",
95
- "max_queries",
96
- "max_inactive_connection_lifetime",
97
- "setup",
98
- "init",
99
- "loop",
100
- "connection_class",
101
- "record_class",
102
- }
103
- )
104
71
 
105
72
 
106
73
  class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", AsyncpgDriver]):
107
74
  """Configuration for AsyncPG database connections using TypedDict."""
108
75
 
109
- driver_type: type[AsyncpgDriver] = AsyncpgDriver
110
- connection_type: type[AsyncpgConnection] = type(AsyncpgConnection) # type: ignore[assignment]
111
- supported_parameter_styles: ClassVar[tuple[str, ...]] = ("numeric",)
112
- default_parameter_style: ClassVar[str] = "numeric"
113
-
114
- def __init__(self, **kwargs: "Unpack[DriverParameters]") -> None:
115
- """Initialize AsyncPG configuration."""
116
- # Known fields that are part of the config
117
- known_fields = {
118
- "dsn",
119
- "host",
120
- "port",
121
- "user",
122
- "password",
123
- "database",
124
- "ssl",
125
- "passfile",
126
- "direct_tls",
127
- "connect_timeout",
128
- "command_timeout",
129
- "statement_cache_size",
130
- "max_cached_statement_lifetime",
131
- "max_cacheable_statement_size",
132
- "server_settings",
133
- "min_size",
134
- "max_size",
135
- "max_queries",
136
- "max_inactive_connection_lifetime",
137
- "setup",
138
- "init",
139
- "loop",
140
- "connection_class",
141
- "record_class",
142
- "extras",
143
- "statement_config",
144
- "default_row_type",
145
- "json_serializer",
146
- "json_deserializer",
147
- "pool_instance",
148
- }
149
-
150
- self.dsn = kwargs.get("dsn")
151
- self.host = kwargs.get("host")
152
- self.port = kwargs.get("port")
153
- self.user = kwargs.get("user")
154
- self.password = kwargs.get("password")
155
- self.database = kwargs.get("database")
156
- self.ssl = kwargs.get("ssl")
157
- self.passfile = kwargs.get("passfile")
158
- self.direct_tls = kwargs.get("direct_tls")
159
- self.connect_timeout = kwargs.get("connect_timeout")
160
- self.command_timeout = kwargs.get("command_timeout")
161
- self.statement_cache_size = kwargs.get("statement_cache_size")
162
- self.max_cached_statement_lifetime = kwargs.get("max_cached_statement_lifetime")
163
- self.max_cacheable_statement_size = kwargs.get("max_cacheable_statement_size")
164
- self.server_settings = kwargs.get("server_settings")
165
- self.min_size = kwargs.get("min_size")
166
- self.max_size = kwargs.get("max_size")
167
- self.max_queries = kwargs.get("max_queries")
168
- self.max_inactive_connection_lifetime = kwargs.get("max_inactive_connection_lifetime")
169
- self.setup = kwargs.get("setup")
170
- self.init = kwargs.get("init")
171
- self.loop = kwargs.get("loop")
172
- self.connection_class = kwargs.get("connection_class")
173
- self.record_class = kwargs.get("record_class")
174
-
175
- # Collect unknown parameters into extras
176
- provided_extras = kwargs.get("extras", {})
177
- unknown_params = {k: v for k, v in kwargs.items() if k not in known_fields}
178
- self.extras = {**provided_extras, **unknown_params}
179
-
180
- self.statement_config = (
181
- SQLConfig() if kwargs.get("statement_config") is None else kwargs.get("statement_config")
182
- )
183
- self.default_row_type = kwargs.get("default_row_type", dict[str, Any])
184
- self.json_serializer = kwargs.get("json_serializer", to_json)
185
- self.json_deserializer = kwargs.get("json_deserializer", from_json)
186
- pool_instance_from_kwargs = kwargs.get("pool_instance")
187
-
188
- super().__init__()
189
-
190
- # Override prepared statements to True for PostgreSQL since it supports them well
191
- self.enable_prepared_statements = kwargs.get("enable_prepared_statements", True) # type: ignore[assignment]
192
-
193
- if pool_instance_from_kwargs is not None:
194
- self.pool_instance = pool_instance_from_kwargs
76
+ driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver
77
+ connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment]
195
78
 
196
- @property
197
- def connection_config_dict(self) -> dict[str, Any]:
198
- """Return the connection configuration as a dict for asyncpg.connect().
79
+ def __init__(
80
+ self,
81
+ *,
82
+ pool_config: "Optional[Union[AsyncpgPoolConfig, dict[str, Any]]]" = None,
83
+ pool_instance: "Optional[Pool[Record]]" = None,
84
+ migration_config: "Optional[dict[str, Any]]" = None,
85
+ statement_config: "Optional[StatementConfig]" = None,
86
+ driver_features: "Optional[Union[AsyncpgDriverFeatures, dict[str, Any]]]" = None,
87
+ ) -> None:
88
+ """Initialize AsyncPG configuration.
199
89
 
200
- This method filters out pool-specific parameters that are not valid for asyncpg.connect().
90
+ Args:
91
+ pool_config: Pool configuration parameters (TypedDict or dict)
92
+ pool_instance: Existing pool instance to use
93
+ migration_config: Migration configuration
94
+ statement_config: Statement configuration override
95
+ driver_features: Driver features configuration (TypedDict or dict)
201
96
  """
202
- # Gather non-None connection parameters
203
- config = {
204
- field: getattr(self, field)
205
- for field in CONNECTION_FIELDS
206
- if getattr(self, field, None) is not None and getattr(self, field) is not Empty
207
- }
208
-
209
- config.update(self.extras)
210
-
211
- return config
97
+ features_dict: dict[str, Any] = dict(driver_features) if driver_features else {}
98
+
99
+ if "json_serializer" not in features_dict:
100
+ features_dict["json_serializer"] = to_json
101
+ if "json_deserializer" not in features_dict:
102
+ features_dict["json_deserializer"] = from_json
103
+ super().__init__(
104
+ pool_config=dict(pool_config) if pool_config else {},
105
+ pool_instance=pool_instance,
106
+ migration_config=migration_config,
107
+ statement_config=statement_config or asyncpg_statement_config,
108
+ driver_features=features_dict,
109
+ )
212
110
 
213
- @property
214
- def pool_config_dict(self) -> dict[str, Any]:
215
- """Return the full pool configuration as a dict for asyncpg.create_pool().
111
+ def _get_pool_config_dict(self) -> "dict[str, Any]":
112
+ """Get pool configuration as plain dict for external library.
216
113
 
217
114
  Returns:
218
- A dictionary containing all pool configuration parameters.
115
+ Dictionary with pool parameters, filtering out None values.
219
116
  """
220
- # All AsyncPG parameter names (connection + pool)
221
- config = {
222
- field: getattr(self, field)
223
- for field in POOL_FIELDS
224
- if getattr(self, field, None) is not None and getattr(self, field) is not Empty
225
- }
226
-
227
- # Merge extras parameters
228
- config.update(self.extras)
229
-
230
- return config
117
+ config: dict[str, Any] = dict(self.pool_config)
118
+ extras = config.pop("extra", {})
119
+ config.update(extras)
120
+ return {k: v for k, v in config.items() if v is not None}
231
121
 
232
122
  async def _create_pool(self) -> "Pool[Record]":
233
123
  """Create the actual async connection pool."""
234
- pool_args = self.pool_config_dict
235
- return await asyncpg_create_pool(**pool_args)
124
+ config = self._get_pool_config_dict()
125
+
126
+ if "init" not in config:
127
+ config["init"] = self._init_pgvector_connection
128
+
129
+ return await asyncpg_create_pool(**config)
130
+
131
+ async def _init_pgvector_connection(self, connection: "AsyncpgConnection") -> None:
132
+ """Initialize pgvector support for asyncpg connections."""
133
+ try:
134
+ import pgvector.asyncpg
135
+
136
+ await pgvector.asyncpg.register_vector(connection)
137
+ except ImportError:
138
+ pass
139
+ except Exception as e:
140
+ logger.debug("Failed to register pgvector for asyncpg: %s", e)
236
141
 
237
142
  async def _close_pool(self) -> None:
238
143
  """Close the actual async connection pool."""
239
144
  if self.pool_instance:
240
145
  await self.pool_instance.close()
241
146
 
242
- async def create_connection(self) -> AsyncpgConnection:
243
- """Create a single async connection (not from pool).
147
+ async def create_connection(self) -> "AsyncpgConnection":
148
+ """Create a single async connection from the pool.
244
149
 
245
150
  Returns:
246
151
  An AsyncPG connection instance.
@@ -250,7 +155,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
250
155
  return await self.pool_instance.acquire()
251
156
 
252
157
  @asynccontextmanager
253
- async def provide_connection(self, *args: Any, **kwargs: Any) -> AsyncGenerator[AsyncpgConnection, None]:
158
+ async def provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncGenerator[AsyncpgConnection, None]":
254
159
  """Provide an async connection context manager.
255
160
 
256
161
  Args:
@@ -271,28 +176,22 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
271
176
  await self.pool_instance.release(connection)
272
177
 
273
178
  @asynccontextmanager
274
- async def provide_session(self, *args: Any, **kwargs: Any) -> AsyncGenerator[AsyncpgDriver, None]:
179
+ async def provide_session(
180
+ self, *args: Any, statement_config: "Optional[StatementConfig]" = None, **kwargs: Any
181
+ ) -> "AsyncGenerator[AsyncpgDriver, None]":
275
182
  """Provide an async driver session context manager.
276
183
 
277
184
  Args:
278
185
  *args: Additional arguments.
186
+ statement_config: Optional statement configuration override.
279
187
  **kwargs: Additional keyword arguments.
280
188
 
281
189
  Yields:
282
190
  An AsyncpgDriver instance.
283
191
  """
284
192
  async with self.provide_connection(*args, **kwargs) as connection:
285
- statement_config = self.statement_config
286
- # Inject parameter style info if not already set
287
- if statement_config is not None and statement_config.allowed_parameter_styles is None:
288
- from dataclasses import replace
289
-
290
- statement_config = replace(
291
- statement_config,
292
- allowed_parameter_styles=self.supported_parameter_styles,
293
- default_parameter_style=self.default_parameter_style,
294
- )
295
- yield self.driver_type(connection=connection, config=statement_config)
193
+ final_statement_config = statement_config or self.statement_config or asyncpg_statement_config
194
+ yield self.driver_type(connection=connection, statement_config=final_statement_config)
296
195
 
297
196
  async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
298
197
  """Provide async pool instance.
@@ -313,6 +212,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
313
212
  Returns:
314
213
  Dictionary mapping type names to types.
315
214
  """
215
+
316
216
  namespace = super().get_signature_namespace()
317
217
  namespace.update(
318
218
  {
@@ -322,7 +222,8 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
322
222
  "PoolConnectionProxyMeta": PoolConnectionProxyMeta,
323
223
  "ConnectionMeta": ConnectionMeta,
324
224
  "Record": Record,
325
- "AsyncpgConnection": type(AsyncpgConnection),
225
+ "AsyncpgConnection": AsyncpgConnection, # type: ignore[dict-item]
226
+ "AsyncpgCursor": AsyncpgCursor,
326
227
  }
327
228
  )
328
229
  return namespace