sqlspec: sqlspec-0.13.1-py3-none-any.whl → sqlspec-0.16.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. (The original page linked to further details; that link is not preserved in this copy.)

Files changed (185):
  1. sqlspec/__init__.py +71 -8
  2. sqlspec/__main__.py +12 -0
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +930 -136
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +116 -285
  10. sqlspec/adapters/adbc/driver.py +462 -340
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +202 -150
  14. sqlspec/adapters/aiosqlite/driver.py +226 -247
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -199
  18. sqlspec/adapters/asyncmy/driver.py +257 -215
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +81 -214
  22. sqlspec/adapters/asyncpg/driver.py +284 -359
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -299
  26. sqlspec/adapters/bigquery/driver.py +474 -634
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +414 -397
  30. sqlspec/adapters/duckdb/driver.py +342 -393
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -458
  34. sqlspec/adapters/oracledb/driver.py +505 -531
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -307
  38. sqlspec/adapters/psqlpy/driver.py +504 -213
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -472
  42. sqlspec/adapters/psycopg/driver.py +704 -825
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +208 -142
  46. sqlspec/adapters/sqlite/driver.py +263 -278
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder/base.py → builder/_base.py} +184 -86
  50. sqlspec/{statement/builder/column.py → builder/_column.py} +97 -60
  51. sqlspec/{statement/builder/ddl.py → builder/_ddl.py} +61 -131
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +4 -10
  53. sqlspec/{statement/builder/delete.py → builder/_delete.py} +10 -30
  54. sqlspec/builder/_insert.py +421 -0
  55. sqlspec/builder/_merge.py +71 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +49 -26
  57. sqlspec/builder/_select.py +170 -0
  58. sqlspec/{statement/builder/update.py → builder/_update.py} +16 -20
  59. sqlspec/builder/mixins/__init__.py +55 -0
  60. sqlspec/builder/mixins/_cte_and_set_ops.py +222 -0
  61. sqlspec/{statement/builder/mixins/_delete_from.py → builder/mixins/_delete_operations.py} +8 -1
  62. sqlspec/builder/mixins/_insert_operations.py +244 -0
  63. sqlspec/{statement/builder/mixins/_join.py → builder/mixins/_join_operations.py} +45 -13
  64. sqlspec/{statement/builder/mixins/_merge_clauses.py → builder/mixins/_merge_operations.py} +188 -30
  65. sqlspec/builder/mixins/_order_limit_operations.py +135 -0
  66. sqlspec/builder/mixins/_pivot_operations.py +153 -0
  67. sqlspec/builder/mixins/_select_operations.py +604 -0
  68. sqlspec/builder/mixins/_update_operations.py +202 -0
  69. sqlspec/builder/mixins/_where_clause.py +644 -0
  70. sqlspec/cli.py +247 -0
  71. sqlspec/config.py +183 -138
  72. sqlspec/core/__init__.py +63 -0
  73. sqlspec/core/cache.py +871 -0
  74. sqlspec/core/compiler.py +417 -0
  75. sqlspec/core/filters.py +830 -0
  76. sqlspec/core/hashing.py +310 -0
  77. sqlspec/core/parameters.py +1237 -0
  78. sqlspec/core/result.py +677 -0
  79. sqlspec/{statement → core}/splitter.py +321 -191
  80. sqlspec/core/statement.py +676 -0
  81. sqlspec/driver/__init__.py +7 -10
  82. sqlspec/driver/_async.py +422 -163
  83. sqlspec/driver/_common.py +545 -287
  84. sqlspec/driver/_sync.py +426 -160
  85. sqlspec/driver/mixins/__init__.py +2 -13
  86. sqlspec/driver/mixins/_result_tools.py +193 -0
  87. sqlspec/driver/mixins/_sql_translator.py +65 -14
  88. sqlspec/exceptions.py +5 -252
  89. sqlspec/extensions/aiosql/adapter.py +93 -96
  90. sqlspec/extensions/litestar/__init__.py +2 -1
  91. sqlspec/extensions/litestar/cli.py +48 -0
  92. sqlspec/extensions/litestar/config.py +0 -1
  93. sqlspec/extensions/litestar/handlers.py +15 -26
  94. sqlspec/extensions/litestar/plugin.py +21 -16
  95. sqlspec/extensions/litestar/providers.py +17 -52
  96. sqlspec/loader.py +423 -104
  97. sqlspec/migrations/__init__.py +35 -0
  98. sqlspec/migrations/base.py +414 -0
  99. sqlspec/migrations/commands.py +443 -0
  100. sqlspec/migrations/loaders.py +402 -0
  101. sqlspec/migrations/runner.py +213 -0
  102. sqlspec/migrations/tracker.py +140 -0
  103. sqlspec/migrations/utils.py +129 -0
  104. sqlspec/protocols.py +51 -186
  105. sqlspec/storage/__init__.py +1 -1
  106. sqlspec/storage/backends/base.py +37 -40
  107. sqlspec/storage/backends/fsspec.py +136 -112
  108. sqlspec/storage/backends/obstore.py +138 -160
  109. sqlspec/storage/capabilities.py +5 -4
  110. sqlspec/storage/registry.py +57 -106
  111. sqlspec/typing.py +136 -115
  112. sqlspec/utils/__init__.py +2 -2
  113. sqlspec/utils/correlation.py +0 -3
  114. sqlspec/utils/deprecation.py +6 -6
  115. sqlspec/utils/fixtures.py +6 -6
  116. sqlspec/utils/logging.py +0 -2
  117. sqlspec/utils/module_loader.py +7 -12
  118. sqlspec/utils/singleton.py +0 -1
  119. sqlspec/utils/sync_tools.py +17 -38
  120. sqlspec/utils/text.py +12 -51
  121. sqlspec/utils/type_guards.py +482 -235
  122. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/METADATA +7 -2
  123. sqlspec-0.16.2.dist-info/RECORD +134 -0
  124. sqlspec-0.16.2.dist-info/entry_points.txt +2 -0
  125. sqlspec/driver/connection.py +0 -207
  126. sqlspec/driver/mixins/_csv_writer.py +0 -91
  127. sqlspec/driver/mixins/_pipeline.py +0 -512
  128. sqlspec/driver/mixins/_result_utils.py +0 -140
  129. sqlspec/driver/mixins/_storage.py +0 -926
  130. sqlspec/driver/mixins/_type_coercion.py +0 -130
  131. sqlspec/driver/parameters.py +0 -138
  132. sqlspec/service/__init__.py +0 -4
  133. sqlspec/service/_util.py +0 -147
  134. sqlspec/service/base.py +0 -1131
  135. sqlspec/service/pagination.py +0 -26
  136. sqlspec/statement/__init__.py +0 -21
  137. sqlspec/statement/builder/insert.py +0 -288
  138. sqlspec/statement/builder/merge.py +0 -95
  139. sqlspec/statement/builder/mixins/__init__.py +0 -65
  140. sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
  141. sqlspec/statement/builder/mixins/_case_builder.py +0 -91
  142. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
  143. sqlspec/statement/builder/mixins/_from.py +0 -63
  144. sqlspec/statement/builder/mixins/_group_by.py +0 -118
  145. sqlspec/statement/builder/mixins/_having.py +0 -35
  146. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
  147. sqlspec/statement/builder/mixins/_insert_into.py +0 -36
  148. sqlspec/statement/builder/mixins/_insert_values.py +0 -67
  149. sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
  150. sqlspec/statement/builder/mixins/_order_by.py +0 -46
  151. sqlspec/statement/builder/mixins/_pivot.py +0 -79
  152. sqlspec/statement/builder/mixins/_returning.py +0 -37
  153. sqlspec/statement/builder/mixins/_select_columns.py +0 -61
  154. sqlspec/statement/builder/mixins/_set_ops.py +0 -122
  155. sqlspec/statement/builder/mixins/_unpivot.py +0 -77
  156. sqlspec/statement/builder/mixins/_update_from.py +0 -55
  157. sqlspec/statement/builder/mixins/_update_set.py +0 -94
  158. sqlspec/statement/builder/mixins/_update_table.py +0 -29
  159. sqlspec/statement/builder/mixins/_where.py +0 -401
  160. sqlspec/statement/builder/mixins/_window_functions.py +0 -86
  161. sqlspec/statement/builder/select.py +0 -221
  162. sqlspec/statement/filters.py +0 -596
  163. sqlspec/statement/parameter_manager.py +0 -220
  164. sqlspec/statement/parameters.py +0 -867
  165. sqlspec/statement/pipelines/__init__.py +0 -210
  166. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  167. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  168. sqlspec/statement/pipelines/context.py +0 -115
  169. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  170. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  171. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  172. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  173. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  174. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  175. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  176. sqlspec/statement/pipelines/validators/_performance.py +0 -718
  177. sqlspec/statement/pipelines/validators/_security.py +0 -967
  178. sqlspec/statement/result.py +0 -435
  179. sqlspec/statement/sql.py +0 -1704
  180. sqlspec/statement/sql_compiler.py +0 -140
  181. sqlspec/utils/cached_property.py +0 -25
  182. sqlspec-0.13.1.dist-info/RECORD +0 -150
  183. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/WHEEL +0 -0
  184. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/licenses/LICENSE +0 -0
  185. {sqlspec-0.13.1.dist-info → sqlspec-0.16.2.dist-info}/licenses/NOTICE +0 -0
--- a/sqlspec/adapters/asyncpg/config.py
+++ b/sqlspec/adapters/asyncpg/config.py
@@ -1,34 +1,34 @@
1
1
  """AsyncPG database configuration with direct field-based configuration."""
2
2
 
3
3
  import logging
4
- from collections.abc import AsyncGenerator, Awaitable, Callable
4
+ from collections.abc import Callable
5
5
  from contextlib import asynccontextmanager
6
- from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
6
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypedDict, Union
7
7
 
8
8
  from asyncpg import Connection, Record
9
9
  from asyncpg import create_pool as asyncpg_create_pool
10
10
  from asyncpg.connection import ConnectionMeta
11
11
  from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
12
- from typing_extensions import NotRequired, Unpack
12
+ from typing_extensions import NotRequired
13
13
 
14
- from sqlspec.adapters.asyncpg.driver import AsyncpgConnection, AsyncpgDriver
14
+ from sqlspec.adapters.asyncpg._types import AsyncpgConnection
15
+ from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, asyncpg_statement_config
15
16
  from sqlspec.config import AsyncDatabaseConfig
16
- from sqlspec.statement.sql import SQLConfig
17
- from sqlspec.typing import DictRow, Empty
18
17
  from sqlspec.utils.serializers import from_json, to_json
19
18
 
20
19
  if TYPE_CHECKING:
21
20
  from asyncio.events import AbstractEventLoop
21
+ from collections.abc import AsyncGenerator, Awaitable
22
22
 
23
- from sqlglot.dialects.dialect import DialectType
23
+ from sqlspec.core.statement import StatementConfig
24
24
 
25
25
 
26
- __all__ = ("CONNECTION_FIELDS", "POOL_FIELDS", "AsyncpgConfig")
26
+ __all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig")
27
27
 
28
28
  logger = logging.getLogger("sqlspec")
29
29
 
30
30
 
31
- class AsyncpgConnectionParams(TypedDict, total=False):
31
+ class AsyncpgConnectionConfig(TypedDict, total=False):
32
32
  """TypedDict for AsyncPG connection parameters."""
33
33
 
34
34
  dsn: NotRequired[str]
@@ -37,7 +37,7 @@ class AsyncpgConnectionParams(TypedDict, total=False):
37
37
  user: NotRequired[str]
38
38
  password: NotRequired[str]
39
39
  database: NotRequired[str]
40
- ssl: NotRequired[Any] # Can be bool, SSLContext, or specific string
40
+ ssl: NotRequired[Any]
41
41
  passfile: NotRequired[str]
42
42
  direct_tls: NotRequired[bool]
43
43
  connect_timeout: NotRequired[float]
@@ -48,7 +48,7 @@ class AsyncpgConnectionParams(TypedDict, total=False):
48
48
  server_settings: NotRequired[dict[str, str]]
49
49
 
50
50
 
51
- class AsyncpgPoolParams(AsyncpgConnectionParams, total=False):
51
+ class AsyncpgPoolConfig(AsyncpgConnectionConfig, total=False):
52
52
  """TypedDict for AsyncPG pool parameters, inheriting connection parameters."""
53
53
 
54
54
  min_size: NotRequired[int]
@@ -60,221 +60,92 @@ class AsyncpgPoolParams(AsyncpgConnectionParams, total=False):
60
60
  loop: NotRequired["AbstractEventLoop"]
61
61
  connection_class: NotRequired[type["AsyncpgConnection"]]
62
62
  record_class: NotRequired[type[Record]]
63
+ extra: NotRequired[dict[str, Any]]
63
64
 
64
65
 
65
- class DriverParameters(AsyncpgPoolParams, total=False):
66
- """TypedDict for additional parameters that can be passed to AsyncPG."""
66
+ class AsyncpgDriverFeatures(TypedDict, total=False):
67
+ """TypedDict for AsyncPG driver features configuration."""
67
68
 
68
- statement_config: NotRequired[SQLConfig]
69
- default_row_type: NotRequired[type[DictRow]]
70
69
  json_serializer: NotRequired[Callable[[Any], str]]
71
70
  json_deserializer: NotRequired[Callable[[str], Any]]
72
- pool_instance: NotRequired["Pool[Record]"]
73
- extras: NotRequired[dict[str, Any]]
74
-
75
-
76
- CONNECTION_FIELDS = {
77
- "dsn",
78
- "host",
79
- "port",
80
- "user",
81
- "password",
82
- "database",
83
- "ssl",
84
- "passfile",
85
- "direct_tls",
86
- "connect_timeout",
87
- "command_timeout",
88
- "statement_cache_size",
89
- "max_cached_statement_lifetime",
90
- "max_cacheable_statement_size",
91
- "server_settings",
92
- }
93
- POOL_FIELDS = CONNECTION_FIELDS.union(
94
- {
95
- "min_size",
96
- "max_size",
97
- "max_queries",
98
- "max_inactive_connection_lifetime",
99
- "setup",
100
- "init",
101
- "loop",
102
- "connection_class",
103
- "record_class",
104
- }
105
- )
106
71
 
107
72
 
108
73
  class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", AsyncpgDriver]):
109
74
  """Configuration for AsyncPG database connections using TypedDict."""
110
75
 
111
- __slots__ = (
112
- "_dialect",
113
- "command_timeout",
114
- "connect_timeout",
115
- "connection_class",
116
- "database",
117
- "default_row_type",
118
- "direct_tls",
119
- "dsn",
120
- "extras",
121
- "host",
122
- "init",
123
- "json_deserializer",
124
- "json_serializer",
125
- "loop",
126
- "max_cacheable_statement_size",
127
- "max_cached_statement_lifetime",
128
- "max_inactive_connection_lifetime",
129
- "max_queries",
130
- "max_size",
131
- "min_size",
132
- "passfile",
133
- "password",
134
- "pool_instance",
135
- "port",
136
- "record_class",
137
- "server_settings",
138
- "setup",
139
- "ssl",
140
- "statement_cache_size",
141
- "statement_config",
142
- "user",
143
- )
144
-
145
- driver_type: type[AsyncpgDriver] = AsyncpgDriver
146
- connection_type: type[AsyncpgConnection] = type(AsyncpgConnection) # type: ignore[assignment]
147
- supported_parameter_styles: ClassVar[tuple[str, ...]] = ("numeric",)
148
- preferred_parameter_style: ClassVar[str] = "numeric"
149
-
150
- def __init__(self, **kwargs: "Unpack[DriverParameters]") -> None:
151
- """Initialize AsyncPG configuration."""
152
- # Known fields that are part of the config
153
- known_fields = {
154
- "dsn",
155
- "host",
156
- "port",
157
- "user",
158
- "password",
159
- "database",
160
- "ssl",
161
- "passfile",
162
- "direct_tls",
163
- "connect_timeout",
164
- "command_timeout",
165
- "statement_cache_size",
166
- "max_cached_statement_lifetime",
167
- "max_cacheable_statement_size",
168
- "server_settings",
169
- "min_size",
170
- "max_size",
171
- "max_queries",
172
- "max_inactive_connection_lifetime",
173
- "setup",
174
- "init",
175
- "loop",
176
- "connection_class",
177
- "record_class",
178
- "extras",
179
- "statement_config",
180
- "default_row_type",
181
- "json_serializer",
182
- "json_deserializer",
183
- "pool_instance",
184
- }
185
-
186
- self.dsn = kwargs.get("dsn")
187
- self.host = kwargs.get("host")
188
- self.port = kwargs.get("port")
189
- self.user = kwargs.get("user")
190
- self.password = kwargs.get("password")
191
- self.database = kwargs.get("database")
192
- self.ssl = kwargs.get("ssl")
193
- self.passfile = kwargs.get("passfile")
194
- self.direct_tls = kwargs.get("direct_tls")
195
- self.connect_timeout = kwargs.get("connect_timeout")
196
- self.command_timeout = kwargs.get("command_timeout")
197
- self.statement_cache_size = kwargs.get("statement_cache_size")
198
- self.max_cached_statement_lifetime = kwargs.get("max_cached_statement_lifetime")
199
- self.max_cacheable_statement_size = kwargs.get("max_cacheable_statement_size")
200
- self.server_settings = kwargs.get("server_settings")
201
- self.min_size = kwargs.get("min_size")
202
- self.max_size = kwargs.get("max_size")
203
- self.max_queries = kwargs.get("max_queries")
204
- self.max_inactive_connection_lifetime = kwargs.get("max_inactive_connection_lifetime")
205
- self.setup = kwargs.get("setup")
206
- self.init = kwargs.get("init")
207
- self.loop = kwargs.get("loop")
208
- self.connection_class = kwargs.get("connection_class")
209
- self.record_class = kwargs.get("record_class")
210
-
211
- # Collect unknown parameters into extras
212
- provided_extras = kwargs.get("extras", {})
213
- unknown_params = {k: v for k, v in kwargs.items() if k not in known_fields}
214
- self.extras = {**provided_extras, **unknown_params}
215
-
216
- self.statement_config = (
217
- SQLConfig() if kwargs.get("statement_config") is None else kwargs.get("statement_config")
218
- )
219
- self.default_row_type = kwargs.get("default_row_type", dict[str, Any])
220
- self.json_serializer = kwargs.get("json_serializer", to_json)
221
- self.json_deserializer = kwargs.get("json_deserializer", from_json)
222
- pool_instance_from_kwargs = kwargs.get("pool_instance")
223
- self._dialect: DialectType = None
224
-
225
- super().__init__()
76
+ driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver
77
+ connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment]
226
78
 
227
- if pool_instance_from_kwargs is not None:
228
- self.pool_instance = pool_instance_from_kwargs
79
+ def __init__(
80
+ self,
81
+ *,
82
+ pool_config: "Optional[Union[AsyncpgPoolConfig, dict[str, Any]]]" = None,
83
+ pool_instance: "Optional[Pool[Record]]" = None,
84
+ migration_config: "Optional[dict[str, Any]]" = None,
85
+ statement_config: "Optional[StatementConfig]" = None,
86
+ driver_features: "Optional[Union[AsyncpgDriverFeatures, dict[str, Any]]]" = None,
87
+ ) -> None:
88
+ """Initialize AsyncPG configuration.
229
89
 
230
- @property
231
- def connection_config_dict(self) -> dict[str, Any]:
232
- """Return the connection configuration as a dict for asyncpg.connect().
233
-
234
- This method filters out pool-specific parameters that are not valid for asyncpg.connect().
90
+ Args:
91
+ pool_config: Pool configuration parameters (TypedDict or dict)
92
+ pool_instance: Existing pool instance to use
93
+ migration_config: Migration configuration
94
+ statement_config: Statement configuration override
95
+ driver_features: Driver features configuration (TypedDict or dict)
235
96
  """
236
- # Gather non-None connection parameters
237
- config = {
238
- field: getattr(self, field)
239
- for field in CONNECTION_FIELDS
240
- if getattr(self, field, None) is not None and getattr(self, field) is not Empty
241
- }
242
-
243
- config.update(self.extras)
244
-
245
- return config
97
+ features_dict: dict[str, Any] = dict(driver_features) if driver_features else {}
98
+
99
+ if "json_serializer" not in features_dict:
100
+ features_dict["json_serializer"] = to_json
101
+ if "json_deserializer" not in features_dict:
102
+ features_dict["json_deserializer"] = from_json
103
+ super().__init__(
104
+ pool_config=dict(pool_config) if pool_config else {},
105
+ pool_instance=pool_instance,
106
+ migration_config=migration_config,
107
+ statement_config=statement_config or asyncpg_statement_config,
108
+ driver_features=features_dict,
109
+ )
246
110
 
247
- @property
248
- def pool_config_dict(self) -> dict[str, Any]:
249
- """Return the full pool configuration as a dict for asyncpg.create_pool().
111
+ def _get_pool_config_dict(self) -> "dict[str, Any]":
112
+ """Get pool configuration as plain dict for external library.
250
113
 
251
114
  Returns:
252
- A dictionary containing all pool configuration parameters.
115
+ Dictionary with pool parameters, filtering out None values.
253
116
  """
254
- # All AsyncPG parameter names (connection + pool)
255
- config = {
256
- field: getattr(self, field)
257
- for field in POOL_FIELDS
258
- if getattr(self, field, None) is not None and getattr(self, field) is not Empty
259
- }
260
-
261
- # Merge extras parameters
262
- config.update(self.extras)
263
-
264
- return config
117
+ config: dict[str, Any] = dict(self.pool_config)
118
+ extras = config.pop("extra", {})
119
+ config.update(extras)
120
+ return {k: v for k, v in config.items() if v is not None}
265
121
 
266
122
  async def _create_pool(self) -> "Pool[Record]":
267
123
  """Create the actual async connection pool."""
268
- pool_args = self.pool_config_dict
269
- return await asyncpg_create_pool(**pool_args)
124
+ config = self._get_pool_config_dict()
125
+
126
+ if "init" not in config:
127
+ config["init"] = self._init_pgvector_connection
128
+
129
+ return await asyncpg_create_pool(**config)
130
+
131
+ async def _init_pgvector_connection(self, connection: "AsyncpgConnection") -> None:
132
+ """Initialize pgvector support for asyncpg connections."""
133
+ try:
134
+ import pgvector.asyncpg
135
+
136
+ await pgvector.asyncpg.register_vector(connection)
137
+ except ImportError:
138
+ pass
139
+ except Exception as e:
140
+ logger.debug("Failed to register pgvector for asyncpg: %s", e)
270
141
 
271
142
  async def _close_pool(self) -> None:
272
143
  """Close the actual async connection pool."""
273
144
  if self.pool_instance:
274
145
  await self.pool_instance.close()
275
146
 
276
- async def create_connection(self) -> AsyncpgConnection:
277
- """Create a single async connection (not from pool).
147
+ async def create_connection(self) -> "AsyncpgConnection":
148
+ """Create a single async connection from the pool.
278
149
 
279
150
  Returns:
280
151
  An AsyncPG connection instance.
@@ -284,7 +155,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
284
155
  return await self.pool_instance.acquire()
285
156
 
286
157
  @asynccontextmanager
287
- async def provide_connection(self, *args: Any, **kwargs: Any) -> AsyncGenerator[AsyncpgConnection, None]:
158
+ async def provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncGenerator[AsyncpgConnection, None]":
288
159
  """Provide an async connection context manager.
289
160
 
290
161
  Args:
@@ -305,28 +176,22 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
305
176
  await self.pool_instance.release(connection)
306
177
 
307
178
  @asynccontextmanager
308
- async def provide_session(self, *args: Any, **kwargs: Any) -> AsyncGenerator[AsyncpgDriver, None]:
179
+ async def provide_session(
180
+ self, *args: Any, statement_config: "Optional[StatementConfig]" = None, **kwargs: Any
181
+ ) -> "AsyncGenerator[AsyncpgDriver, None]":
309
182
  """Provide an async driver session context manager.
310
183
 
311
184
  Args:
312
185
  *args: Additional arguments.
186
+ statement_config: Optional statement configuration override.
313
187
  **kwargs: Additional keyword arguments.
314
188
 
315
189
  Yields:
316
190
  An AsyncpgDriver instance.
317
191
  """
318
192
  async with self.provide_connection(*args, **kwargs) as connection:
319
- statement_config = self.statement_config
320
- # Inject parameter style info if not already set
321
- if statement_config is not None and statement_config.allowed_parameter_styles is None:
322
- from dataclasses import replace
323
-
324
- statement_config = replace(
325
- statement_config,
326
- allowed_parameter_styles=self.supported_parameter_styles,
327
- target_parameter_style=self.preferred_parameter_style,
328
- )
329
- yield self.driver_type(connection=connection, config=statement_config)
193
+ final_statement_config = statement_config or self.statement_config or asyncpg_statement_config
194
+ yield self.driver_type(connection=connection, statement_config=final_statement_config)
330
195
 
331
196
  async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
332
197
  """Provide async pool instance.
@@ -347,6 +212,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
347
212
  Returns:
348
213
  Dictionary mapping type names to types.
349
214
  """
215
+
350
216
  namespace = super().get_signature_namespace()
351
217
  namespace.update(
352
218
  {
@@ -356,7 +222,8 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
356
222
  "PoolConnectionProxyMeta": PoolConnectionProxyMeta,
357
223
  "ConnectionMeta": ConnectionMeta,
358
224
  "Record": Record,
359
- "AsyncpgConnection": type(AsyncpgConnection),
225
+ "AsyncpgConnection": AsyncpgConnection, # type: ignore[dict-item]
226
+ "AsyncpgCursor": AsyncpgCursor,
360
227
  }
361
228
  )
362
229
  return namespace