sqlspec 0.14.1__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +50 -25
- sqlspec/__main__.py +1 -1
- sqlspec/__metadata__.py +1 -3
- sqlspec/_serialization.py +1 -2
- sqlspec/_sql.py +256 -120
- sqlspec/_typing.py +278 -142
- sqlspec/adapters/adbc/__init__.py +4 -3
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +115 -260
- sqlspec/adapters/adbc/driver.py +462 -367
- sqlspec/adapters/aiosqlite/__init__.py +18 -3
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +199 -129
- sqlspec/adapters/aiosqlite/driver.py +230 -269
- sqlspec/adapters/asyncmy/__init__.py +18 -3
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +80 -168
- sqlspec/adapters/asyncmy/driver.py +260 -225
- sqlspec/adapters/asyncpg/__init__.py +19 -4
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +82 -181
- sqlspec/adapters/asyncpg/driver.py +285 -383
- sqlspec/adapters/bigquery/__init__.py +17 -3
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +191 -258
- sqlspec/adapters/bigquery/driver.py +474 -646
- sqlspec/adapters/duckdb/__init__.py +14 -3
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +415 -351
- sqlspec/adapters/duckdb/driver.py +343 -413
- sqlspec/adapters/oracledb/__init__.py +19 -5
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +123 -379
- sqlspec/adapters/oracledb/driver.py +507 -560
- sqlspec/adapters/psqlpy/__init__.py +13 -3
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +93 -254
- sqlspec/adapters/psqlpy/driver.py +505 -234
- sqlspec/adapters/psycopg/__init__.py +19 -5
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +143 -403
- sqlspec/adapters/psycopg/driver.py +706 -872
- sqlspec/adapters/sqlite/__init__.py +14 -3
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +202 -118
- sqlspec/adapters/sqlite/driver.py +264 -303
- sqlspec/base.py +105 -9
- sqlspec/{statement/builder → builder}/__init__.py +12 -14
- sqlspec/{statement/builder → builder}/_base.py +120 -55
- sqlspec/{statement/builder → builder}/_column.py +17 -6
- sqlspec/{statement/builder → builder}/_ddl.py +46 -79
- sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
- sqlspec/{statement/builder → builder}/_delete.py +6 -25
- sqlspec/{statement/builder → builder}/_insert.py +6 -64
- sqlspec/builder/_merge.py +56 -0
- sqlspec/{statement/builder → builder}/_parsing_utils.py +3 -10
- sqlspec/{statement/builder → builder}/_select.py +11 -56
- sqlspec/{statement/builder → builder}/_update.py +12 -18
- sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
- sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
- sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +22 -16
- sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
- sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +3 -5
- sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
- sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
- sqlspec/{statement/builder → builder}/mixins/_select_operations.py +21 -36
- sqlspec/{statement/builder → builder}/mixins/_update_operations.py +3 -14
- sqlspec/{statement/builder → builder}/mixins/_where_clause.py +52 -79
- sqlspec/cli.py +4 -5
- sqlspec/config.py +180 -133
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.py +873 -0
- sqlspec/core/compiler.py +396 -0
- sqlspec/core/filters.py +828 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.py +1209 -0
- sqlspec/core/result.py +664 -0
- sqlspec/{statement → core}/splitter.py +321 -191
- sqlspec/core/statement.py +651 -0
- sqlspec/driver/__init__.py +7 -10
- sqlspec/driver/_async.py +387 -176
- sqlspec/driver/_common.py +527 -289
- sqlspec/driver/_sync.py +390 -172
- sqlspec/driver/mixins/__init__.py +2 -19
- sqlspec/driver/mixins/_result_tools.py +168 -0
- sqlspec/driver/mixins/_sql_translator.py +6 -3
- sqlspec/exceptions.py +5 -252
- sqlspec/extensions/aiosql/adapter.py +93 -96
- sqlspec/extensions/litestar/config.py +0 -1
- sqlspec/extensions/litestar/handlers.py +15 -26
- sqlspec/extensions/litestar/plugin.py +16 -14
- sqlspec/extensions/litestar/providers.py +17 -52
- sqlspec/loader.py +424 -105
- sqlspec/migrations/__init__.py +12 -0
- sqlspec/migrations/base.py +92 -68
- sqlspec/migrations/commands.py +24 -106
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +49 -51
- sqlspec/migrations/tracker.py +31 -44
- sqlspec/migrations/utils.py +64 -24
- sqlspec/protocols.py +7 -183
- sqlspec/storage/__init__.py +1 -1
- sqlspec/storage/backends/base.py +37 -40
- sqlspec/storage/backends/fsspec.py +136 -112
- sqlspec/storage/backends/obstore.py +138 -160
- sqlspec/storage/capabilities.py +5 -4
- sqlspec/storage/registry.py +57 -106
- sqlspec/typing.py +136 -115
- sqlspec/utils/__init__.py +2 -3
- sqlspec/utils/correlation.py +0 -3
- sqlspec/utils/deprecation.py +6 -6
- sqlspec/utils/fixtures.py +6 -6
- sqlspec/utils/logging.py +0 -2
- sqlspec/utils/module_loader.py +7 -12
- sqlspec/utils/singleton.py +0 -1
- sqlspec/utils/sync_tools.py +16 -37
- sqlspec/utils/text.py +12 -51
- sqlspec/utils/type_guards.py +443 -232
- {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/METADATA +7 -2
- sqlspec-0.15.0.dist-info/RECORD +134 -0
- sqlspec/adapters/adbc/transformers.py +0 -108
- sqlspec/driver/connection.py +0 -207
- sqlspec/driver/mixins/_cache.py +0 -114
- sqlspec/driver/mixins/_csv_writer.py +0 -91
- sqlspec/driver/mixins/_pipeline.py +0 -508
- sqlspec/driver/mixins/_query_tools.py +0 -796
- sqlspec/driver/mixins/_result_utils.py +0 -138
- sqlspec/driver/mixins/_storage.py +0 -912
- sqlspec/driver/mixins/_type_coercion.py +0 -128
- sqlspec/driver/parameters.py +0 -138
- sqlspec/statement/__init__.py +0 -21
- sqlspec/statement/builder/_merge.py +0 -95
- sqlspec/statement/cache.py +0 -50
- sqlspec/statement/filters.py +0 -625
- sqlspec/statement/parameters.py +0 -956
- sqlspec/statement/pipelines/__init__.py +0 -210
- sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
- sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
- sqlspec/statement/pipelines/context.py +0 -109
- sqlspec/statement/pipelines/transformers/__init__.py +0 -7
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
- sqlspec/statement/pipelines/validators/__init__.py +0 -23
- sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
- sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
- sqlspec/statement/pipelines/validators/_performance.py +0 -714
- sqlspec/statement/pipelines/validators/_security.py +0 -967
- sqlspec/statement/result.py +0 -435
- sqlspec/statement/sql.py +0 -1774
- sqlspec/utils/cached_property.py +0 -25
- sqlspec/utils/statement_hashing.py +0 -203
- sqlspec-0.14.1.dist-info/RECORD +0 -145
- /sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.14.1.dist-info → sqlspec-0.15.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,32 +1,34 @@
|
|
|
1
1
|
"""AsyncPG database configuration with direct field-based configuration."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
from collections.abc import
|
|
4
|
+
from collections.abc import Callable
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
|
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypedDict, Union
|
|
7
7
|
|
|
8
8
|
from asyncpg import Connection, Record
|
|
9
9
|
from asyncpg import create_pool as asyncpg_create_pool
|
|
10
10
|
from asyncpg.connection import ConnectionMeta
|
|
11
11
|
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
|
|
12
|
-
from typing_extensions import NotRequired
|
|
12
|
+
from typing_extensions import NotRequired
|
|
13
13
|
|
|
14
|
-
from sqlspec.adapters.asyncpg.
|
|
14
|
+
from sqlspec.adapters.asyncpg._types import AsyncpgConnection
|
|
15
|
+
from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, asyncpg_statement_config
|
|
15
16
|
from sqlspec.config import AsyncDatabaseConfig
|
|
16
|
-
from sqlspec.statement.sql import SQLConfig
|
|
17
|
-
from sqlspec.typing import DictRow, Empty
|
|
18
17
|
from sqlspec.utils.serializers import from_json, to_json
|
|
19
18
|
|
|
20
19
|
if TYPE_CHECKING:
|
|
21
20
|
from asyncio.events import AbstractEventLoop
|
|
21
|
+
from collections.abc import AsyncGenerator, Awaitable
|
|
22
22
|
|
|
23
|
+
from sqlspec.core.statement import StatementConfig
|
|
23
24
|
|
|
24
|
-
|
|
25
|
+
|
|
26
|
+
__all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig")
|
|
25
27
|
|
|
26
28
|
logger = logging.getLogger("sqlspec")
|
|
27
29
|
|
|
28
30
|
|
|
29
|
-
class
|
|
31
|
+
class AsyncpgConnectionConfig(TypedDict, total=False):
|
|
30
32
|
"""TypedDict for AsyncPG connection parameters."""
|
|
31
33
|
|
|
32
34
|
dsn: NotRequired[str]
|
|
@@ -35,7 +37,7 @@ class AsyncpgConnectionParams(TypedDict, total=False):
|
|
|
35
37
|
user: NotRequired[str]
|
|
36
38
|
password: NotRequired[str]
|
|
37
39
|
database: NotRequired[str]
|
|
38
|
-
ssl: NotRequired[Any]
|
|
40
|
+
ssl: NotRequired[Any]
|
|
39
41
|
passfile: NotRequired[str]
|
|
40
42
|
direct_tls: NotRequired[bool]
|
|
41
43
|
connect_timeout: NotRequired[float]
|
|
@@ -46,7 +48,7 @@ class AsyncpgConnectionParams(TypedDict, total=False):
|
|
|
46
48
|
server_settings: NotRequired[dict[str, str]]
|
|
47
49
|
|
|
48
50
|
|
|
49
|
-
class
|
|
51
|
+
class AsyncpgPoolConfig(AsyncpgConnectionConfig, total=False):
|
|
50
52
|
"""TypedDict for AsyncPG pool parameters, inheriting connection parameters."""
|
|
51
53
|
|
|
52
54
|
min_size: NotRequired[int]
|
|
@@ -58,189 +60,92 @@ class AsyncpgPoolParams(AsyncpgConnectionParams, total=False):
|
|
|
58
60
|
loop: NotRequired["AbstractEventLoop"]
|
|
59
61
|
connection_class: NotRequired[type["AsyncpgConnection"]]
|
|
60
62
|
record_class: NotRequired[type[Record]]
|
|
63
|
+
extra: NotRequired[dict[str, Any]]
|
|
61
64
|
|
|
62
65
|
|
|
63
|
-
class
|
|
64
|
-
"""TypedDict for
|
|
66
|
+
class AsyncpgDriverFeatures(TypedDict, total=False):
|
|
67
|
+
"""TypedDict for AsyncPG driver features configuration."""
|
|
65
68
|
|
|
66
|
-
statement_config: NotRequired[SQLConfig]
|
|
67
|
-
default_row_type: NotRequired[type[DictRow]]
|
|
68
69
|
json_serializer: NotRequired[Callable[[Any], str]]
|
|
69
70
|
json_deserializer: NotRequired[Callable[[str], Any]]
|
|
70
|
-
pool_instance: NotRequired["Pool[Record]"]
|
|
71
|
-
extras: NotRequired[dict[str, Any]]
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
CONNECTION_FIELDS = {
|
|
75
|
-
"dsn",
|
|
76
|
-
"host",
|
|
77
|
-
"port",
|
|
78
|
-
"user",
|
|
79
|
-
"password",
|
|
80
|
-
"database",
|
|
81
|
-
"ssl",
|
|
82
|
-
"passfile",
|
|
83
|
-
"direct_tls",
|
|
84
|
-
"connect_timeout",
|
|
85
|
-
"command_timeout",
|
|
86
|
-
"statement_cache_size",
|
|
87
|
-
"max_cached_statement_lifetime",
|
|
88
|
-
"max_cacheable_statement_size",
|
|
89
|
-
"server_settings",
|
|
90
|
-
}
|
|
91
|
-
POOL_FIELDS = CONNECTION_FIELDS.union(
|
|
92
|
-
{
|
|
93
|
-
"min_size",
|
|
94
|
-
"max_size",
|
|
95
|
-
"max_queries",
|
|
96
|
-
"max_inactive_connection_lifetime",
|
|
97
|
-
"setup",
|
|
98
|
-
"init",
|
|
99
|
-
"loop",
|
|
100
|
-
"connection_class",
|
|
101
|
-
"record_class",
|
|
102
|
-
}
|
|
103
|
-
)
|
|
104
71
|
|
|
105
72
|
|
|
106
73
|
class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", AsyncpgDriver]):
|
|
107
74
|
"""Configuration for AsyncPG database connections using TypedDict."""
|
|
108
75
|
|
|
109
|
-
driver_type: type[AsyncpgDriver] = AsyncpgDriver
|
|
110
|
-
connection_type: type[AsyncpgConnection] = type(AsyncpgConnection) # type: ignore[assignment]
|
|
111
|
-
supported_parameter_styles: ClassVar[tuple[str, ...]] = ("numeric",)
|
|
112
|
-
default_parameter_style: ClassVar[str] = "numeric"
|
|
113
|
-
|
|
114
|
-
def __init__(self, **kwargs: "Unpack[DriverParameters]") -> None:
|
|
115
|
-
"""Initialize AsyncPG configuration."""
|
|
116
|
-
# Known fields that are part of the config
|
|
117
|
-
known_fields = {
|
|
118
|
-
"dsn",
|
|
119
|
-
"host",
|
|
120
|
-
"port",
|
|
121
|
-
"user",
|
|
122
|
-
"password",
|
|
123
|
-
"database",
|
|
124
|
-
"ssl",
|
|
125
|
-
"passfile",
|
|
126
|
-
"direct_tls",
|
|
127
|
-
"connect_timeout",
|
|
128
|
-
"command_timeout",
|
|
129
|
-
"statement_cache_size",
|
|
130
|
-
"max_cached_statement_lifetime",
|
|
131
|
-
"max_cacheable_statement_size",
|
|
132
|
-
"server_settings",
|
|
133
|
-
"min_size",
|
|
134
|
-
"max_size",
|
|
135
|
-
"max_queries",
|
|
136
|
-
"max_inactive_connection_lifetime",
|
|
137
|
-
"setup",
|
|
138
|
-
"init",
|
|
139
|
-
"loop",
|
|
140
|
-
"connection_class",
|
|
141
|
-
"record_class",
|
|
142
|
-
"extras",
|
|
143
|
-
"statement_config",
|
|
144
|
-
"default_row_type",
|
|
145
|
-
"json_serializer",
|
|
146
|
-
"json_deserializer",
|
|
147
|
-
"pool_instance",
|
|
148
|
-
}
|
|
149
|
-
|
|
150
|
-
self.dsn = kwargs.get("dsn")
|
|
151
|
-
self.host = kwargs.get("host")
|
|
152
|
-
self.port = kwargs.get("port")
|
|
153
|
-
self.user = kwargs.get("user")
|
|
154
|
-
self.password = kwargs.get("password")
|
|
155
|
-
self.database = kwargs.get("database")
|
|
156
|
-
self.ssl = kwargs.get("ssl")
|
|
157
|
-
self.passfile = kwargs.get("passfile")
|
|
158
|
-
self.direct_tls = kwargs.get("direct_tls")
|
|
159
|
-
self.connect_timeout = kwargs.get("connect_timeout")
|
|
160
|
-
self.command_timeout = kwargs.get("command_timeout")
|
|
161
|
-
self.statement_cache_size = kwargs.get("statement_cache_size")
|
|
162
|
-
self.max_cached_statement_lifetime = kwargs.get("max_cached_statement_lifetime")
|
|
163
|
-
self.max_cacheable_statement_size = kwargs.get("max_cacheable_statement_size")
|
|
164
|
-
self.server_settings = kwargs.get("server_settings")
|
|
165
|
-
self.min_size = kwargs.get("min_size")
|
|
166
|
-
self.max_size = kwargs.get("max_size")
|
|
167
|
-
self.max_queries = kwargs.get("max_queries")
|
|
168
|
-
self.max_inactive_connection_lifetime = kwargs.get("max_inactive_connection_lifetime")
|
|
169
|
-
self.setup = kwargs.get("setup")
|
|
170
|
-
self.init = kwargs.get("init")
|
|
171
|
-
self.loop = kwargs.get("loop")
|
|
172
|
-
self.connection_class = kwargs.get("connection_class")
|
|
173
|
-
self.record_class = kwargs.get("record_class")
|
|
174
|
-
|
|
175
|
-
# Collect unknown parameters into extras
|
|
176
|
-
provided_extras = kwargs.get("extras", {})
|
|
177
|
-
unknown_params = {k: v for k, v in kwargs.items() if k not in known_fields}
|
|
178
|
-
self.extras = {**provided_extras, **unknown_params}
|
|
179
|
-
|
|
180
|
-
self.statement_config = (
|
|
181
|
-
SQLConfig() if kwargs.get("statement_config") is None else kwargs.get("statement_config")
|
|
182
|
-
)
|
|
183
|
-
self.default_row_type = kwargs.get("default_row_type", dict[str, Any])
|
|
184
|
-
self.json_serializer = kwargs.get("json_serializer", to_json)
|
|
185
|
-
self.json_deserializer = kwargs.get("json_deserializer", from_json)
|
|
186
|
-
pool_instance_from_kwargs = kwargs.get("pool_instance")
|
|
187
|
-
|
|
188
|
-
super().__init__()
|
|
189
|
-
|
|
190
|
-
# Override prepared statements to True for PostgreSQL since it supports them well
|
|
191
|
-
self.enable_prepared_statements = kwargs.get("enable_prepared_statements", True) # type: ignore[assignment]
|
|
192
|
-
|
|
193
|
-
if pool_instance_from_kwargs is not None:
|
|
194
|
-
self.pool_instance = pool_instance_from_kwargs
|
|
76
|
+
driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver
|
|
77
|
+
connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment]
|
|
195
78
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
*,
|
|
82
|
+
pool_config: "Optional[Union[AsyncpgPoolConfig, dict[str, Any]]]" = None,
|
|
83
|
+
pool_instance: "Optional[Pool[Record]]" = None,
|
|
84
|
+
migration_config: "Optional[dict[str, Any]]" = None,
|
|
85
|
+
statement_config: "Optional[StatementConfig]" = None,
|
|
86
|
+
driver_features: "Optional[Union[AsyncpgDriverFeatures, dict[str, Any]]]" = None,
|
|
87
|
+
) -> None:
|
|
88
|
+
"""Initialize AsyncPG configuration.
|
|
199
89
|
|
|
200
|
-
|
|
90
|
+
Args:
|
|
91
|
+
pool_config: Pool configuration parameters (TypedDict or dict)
|
|
92
|
+
pool_instance: Existing pool instance to use
|
|
93
|
+
migration_config: Migration configuration
|
|
94
|
+
statement_config: Statement configuration override
|
|
95
|
+
driver_features: Driver features configuration (TypedDict or dict)
|
|
201
96
|
"""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
97
|
+
features_dict: dict[str, Any] = dict(driver_features) if driver_features else {}
|
|
98
|
+
|
|
99
|
+
if "json_serializer" not in features_dict:
|
|
100
|
+
features_dict["json_serializer"] = to_json
|
|
101
|
+
if "json_deserializer" not in features_dict:
|
|
102
|
+
features_dict["json_deserializer"] = from_json
|
|
103
|
+
super().__init__(
|
|
104
|
+
pool_config=dict(pool_config) if pool_config else {},
|
|
105
|
+
pool_instance=pool_instance,
|
|
106
|
+
migration_config=migration_config,
|
|
107
|
+
statement_config=statement_config or asyncpg_statement_config,
|
|
108
|
+
driver_features=features_dict,
|
|
109
|
+
)
|
|
212
110
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
"""Return the full pool configuration as a dict for asyncpg.create_pool().
|
|
111
|
+
def _get_pool_config_dict(self) -> "dict[str, Any]":
|
|
112
|
+
"""Get pool configuration as plain dict for external library.
|
|
216
113
|
|
|
217
114
|
Returns:
|
|
218
|
-
|
|
115
|
+
Dictionary with pool parameters, filtering out None values.
|
|
219
116
|
"""
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
225
|
-
}
|
|
226
|
-
|
|
227
|
-
# Merge extras parameters
|
|
228
|
-
config.update(self.extras)
|
|
229
|
-
|
|
230
|
-
return config
|
|
117
|
+
config: dict[str, Any] = dict(self.pool_config)
|
|
118
|
+
extras = config.pop("extra", {})
|
|
119
|
+
config.update(extras)
|
|
120
|
+
return {k: v for k, v in config.items() if v is not None}
|
|
231
121
|
|
|
232
122
|
async def _create_pool(self) -> "Pool[Record]":
|
|
233
123
|
"""Create the actual async connection pool."""
|
|
234
|
-
|
|
235
|
-
|
|
124
|
+
config = self._get_pool_config_dict()
|
|
125
|
+
|
|
126
|
+
if "init" not in config:
|
|
127
|
+
config["init"] = self._init_pgvector_connection
|
|
128
|
+
|
|
129
|
+
return await asyncpg_create_pool(**config)
|
|
130
|
+
|
|
131
|
+
async def _init_pgvector_connection(self, connection: "AsyncpgConnection") -> None:
|
|
132
|
+
"""Initialize pgvector support for asyncpg connections."""
|
|
133
|
+
try:
|
|
134
|
+
import pgvector.asyncpg
|
|
135
|
+
|
|
136
|
+
await pgvector.asyncpg.register_vector(connection)
|
|
137
|
+
except ImportError:
|
|
138
|
+
pass
|
|
139
|
+
except Exception as e:
|
|
140
|
+
logger.debug("Failed to register pgvector for asyncpg: %s", e)
|
|
236
141
|
|
|
237
142
|
async def _close_pool(self) -> None:
|
|
238
143
|
"""Close the actual async connection pool."""
|
|
239
144
|
if self.pool_instance:
|
|
240
145
|
await self.pool_instance.close()
|
|
241
146
|
|
|
242
|
-
async def create_connection(self) -> AsyncpgConnection:
|
|
243
|
-
"""Create a single async connection
|
|
147
|
+
async def create_connection(self) -> "AsyncpgConnection":
|
|
148
|
+
"""Create a single async connection from the pool.
|
|
244
149
|
|
|
245
150
|
Returns:
|
|
246
151
|
An AsyncPG connection instance.
|
|
@@ -250,7 +155,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
250
155
|
return await self.pool_instance.acquire()
|
|
251
156
|
|
|
252
157
|
@asynccontextmanager
|
|
253
|
-
async def provide_connection(self, *args: Any, **kwargs: Any) -> AsyncGenerator[AsyncpgConnection, None]:
|
|
158
|
+
async def provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncGenerator[AsyncpgConnection, None]":
|
|
254
159
|
"""Provide an async connection context manager.
|
|
255
160
|
|
|
256
161
|
Args:
|
|
@@ -271,28 +176,22 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
271
176
|
await self.pool_instance.release(connection)
|
|
272
177
|
|
|
273
178
|
@asynccontextmanager
|
|
274
|
-
async def provide_session(
|
|
179
|
+
async def provide_session(
|
|
180
|
+
self, *args: Any, statement_config: "Optional[StatementConfig]" = None, **kwargs: Any
|
|
181
|
+
) -> "AsyncGenerator[AsyncpgDriver, None]":
|
|
275
182
|
"""Provide an async driver session context manager.
|
|
276
183
|
|
|
277
184
|
Args:
|
|
278
185
|
*args: Additional arguments.
|
|
186
|
+
statement_config: Optional statement configuration override.
|
|
279
187
|
**kwargs: Additional keyword arguments.
|
|
280
188
|
|
|
281
189
|
Yields:
|
|
282
190
|
An AsyncpgDriver instance.
|
|
283
191
|
"""
|
|
284
192
|
async with self.provide_connection(*args, **kwargs) as connection:
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
if statement_config is not None and statement_config.allowed_parameter_styles is None:
|
|
288
|
-
from dataclasses import replace
|
|
289
|
-
|
|
290
|
-
statement_config = replace(
|
|
291
|
-
statement_config,
|
|
292
|
-
allowed_parameter_styles=self.supported_parameter_styles,
|
|
293
|
-
default_parameter_style=self.default_parameter_style,
|
|
294
|
-
)
|
|
295
|
-
yield self.driver_type(connection=connection, config=statement_config)
|
|
193
|
+
final_statement_config = statement_config or self.statement_config or asyncpg_statement_config
|
|
194
|
+
yield self.driver_type(connection=connection, statement_config=final_statement_config)
|
|
296
195
|
|
|
297
196
|
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
|
|
298
197
|
"""Provide async pool instance.
|
|
@@ -313,6 +212,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
313
212
|
Returns:
|
|
314
213
|
Dictionary mapping type names to types.
|
|
315
214
|
"""
|
|
215
|
+
|
|
316
216
|
namespace = super().get_signature_namespace()
|
|
317
217
|
namespace.update(
|
|
318
218
|
{
|
|
@@ -322,7 +222,8 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
322
222
|
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
|
|
323
223
|
"ConnectionMeta": ConnectionMeta,
|
|
324
224
|
"Record": Record,
|
|
325
|
-
"AsyncpgConnection":
|
|
225
|
+
"AsyncpgConnection": AsyncpgConnection, # type: ignore[dict-item]
|
|
226
|
+
"AsyncpgCursor": AsyncpgCursor,
|
|
326
227
|
}
|
|
327
228
|
)
|
|
328
229
|
return namespace
|