sqlspec 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sqlspec might be problematic.
- sqlspec/_typing.py +39 -6
- sqlspec/adapters/adbc/__init__.py +2 -2
- sqlspec/adapters/adbc/config.py +34 -11
- sqlspec/adapters/adbc/driver.py +302 -111
- sqlspec/adapters/aiosqlite/__init__.py +2 -2
- sqlspec/adapters/aiosqlite/config.py +2 -2
- sqlspec/adapters/aiosqlite/driver.py +164 -42
- sqlspec/adapters/asyncmy/__init__.py +3 -3
- sqlspec/adapters/asyncmy/config.py +11 -12
- sqlspec/adapters/asyncmy/driver.py +161 -37
- sqlspec/adapters/asyncpg/__init__.py +5 -5
- sqlspec/adapters/asyncpg/config.py +17 -19
- sqlspec/adapters/asyncpg/driver.py +386 -96
- sqlspec/adapters/duckdb/__init__.py +2 -2
- sqlspec/adapters/duckdb/config.py +2 -2
- sqlspec/adapters/duckdb/driver.py +190 -60
- sqlspec/adapters/oracledb/__init__.py +8 -8
- sqlspec/adapters/oracledb/config/__init__.py +6 -6
- sqlspec/adapters/oracledb/config/_asyncio.py +9 -10
- sqlspec/adapters/oracledb/config/_sync.py +8 -9
- sqlspec/adapters/oracledb/driver.py +384 -45
- sqlspec/adapters/psqlpy/__init__.py +0 -0
- sqlspec/adapters/psqlpy/config.py +250 -0
- sqlspec/adapters/psqlpy/driver.py +481 -0
- sqlspec/adapters/psycopg/__init__.py +10 -5
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +12 -12
- sqlspec/adapters/psycopg/config/_sync.py +13 -13
- sqlspec/adapters/psycopg/driver.py +432 -222
- sqlspec/adapters/sqlite/__init__.py +2 -2
- sqlspec/adapters/sqlite/config.py +2 -2
- sqlspec/adapters/sqlite/driver.py +176 -72
- sqlspec/base.py +687 -161
- sqlspec/exceptions.py +30 -0
- sqlspec/extensions/litestar/config.py +6 -0
- sqlspec/extensions/litestar/handlers.py +25 -0
- sqlspec/extensions/litestar/plugin.py +8 -1
- sqlspec/statement.py +373 -0
- sqlspec/typing.py +10 -1
- {sqlspec-0.8.0.dist-info → sqlspec-0.9.1.dist-info}/METADATA +144 -2
- sqlspec-0.9.1.dist-info/RECORD +61 -0
- sqlspec-0.8.0.dist-info/RECORD +0 -57
- {sqlspec-0.8.0.dist-info → sqlspec-0.9.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.8.0.dist-info → sqlspec-0.9.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.8.0.dist-info → sqlspec-0.9.1.dist-info}/licenses/NOTICE +0 -0

sqlspec/adapters/psqlpy/driver.py (new file)
@@ -0,0 +1,481 @@
+# ruff: noqa: PLR0915, PLR0914, PLR0912, C901
+"""Psqlpy Driver Implementation."""
+
+import logging
+import re
+from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
+
+from psqlpy.exceptions import RustPSQLDriverPyBaseError
+
+from sqlspec.base import AsyncDriverAdapterProtocol, T
+from sqlspec.exceptions import SQLParsingError
+from sqlspec.statement import PARAM_REGEX, SQLStatement
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from psqlpy import Connection, QueryResult
+
+    from sqlspec.typing import ModelDTOT, StatementParameterType
+
+__all__ = ("PsqlpyDriver",)
+
+
+# Regex to find '?' placeholders, skipping those inside quotes or SQL comments
+QMARK_REGEX = re.compile(
+    r"""(?P<dquote>"[^"]*") | # Double-quoted strings
+    (?P<squote>\'[^\']*\') | # Single-quoted strings
+    (?P<comment>--[^\n]*|/\*.*?\*/) | # SQL comments (single/multi-line)
+    (?P<qmark>\?) # The question mark placeholder
+    """,
+    re.VERBOSE | re.DOTALL,
+)
+logger = logging.getLogger("sqlspec")
+
+
+class PsqlpyDriver(AsyncDriverAdapterProtocol["Connection"]):
+    """Psqlpy Postgres Driver Adapter."""
+
+    connection: "Connection"
+    dialect: str = "postgres"
+
+    def __init__(self, connection: "Connection") -> None:
+        self.connection = connection
+
+    def _process_sql_params(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        **kwargs: Any,
+    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
+        """Process SQL and parameters for psqlpy.
+
+        psqlpy uses $1, $2 style parameters natively.
+        This method converts '?' (tuple/list) and ':name' (dict) styles to $n.
+        It relies on SQLStatement for initial parameter validation and merging.
+
+        Args:
+            sql: The SQL to process.
+            parameters: The parameters to process.
+            kwargs: Additional keyword arguments.
+
+        Raises:
+            SQLParsingError: If the SQL is invalid.
+
+        Returns:
+            A tuple of the processed SQL and parameters.
+        """
+        stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None)
+        sql, parameters = stmt.process()
+
+        # Case 1: Parameters are a dictionary
+        if isinstance(parameters, dict):
+            processed_sql_parts: list[str] = []
+            ordered_params = []
+            last_end = 0
+            param_index = 1
+            found_params_regex: list[str] = []
+
+            for match in PARAM_REGEX.finditer(sql):
+                if match.group("dquote") or match.group("squote") or match.group("comment"):
+                    continue
+
+                if match.group("var_name"):  # Finds :var_name
+                    var_name = match.group("var_name")
+                    found_params_regex.append(var_name)
+                    start = match.start("var_name") - 1
+                    end = match.end("var_name")
+
+                    if var_name not in parameters:
+                        msg = f"Named parameter ':{var_name}' missing from parameters. SQL: {sql}"
+                        raise SQLParsingError(msg)
+
+                    processed_sql_parts.extend((sql[last_end:start], f"${param_index}"))
+                    ordered_params.append(parameters[var_name])
+                    last_end = end
+                    param_index += 1
+
+            processed_sql_parts.append(sql[last_end:])
+            final_sql = "".join(processed_sql_parts)
+
+            if not found_params_regex and parameters:
+                logger.warning(
+                    "Dict params provided (%s), but no :name placeholders found. SQL: %s",
+                    list(parameters.keys()),
+                    sql,
+                )
+                return sql, ()
+
+            provided_keys = set(parameters.keys())
+            found_keys = set(found_params_regex)
+            unused_keys = provided_keys - found_keys
+            if unused_keys:
+                logger.warning("Unused parameters provided: %s. SQL: %s", unused_keys, sql)
+
+            return final_sql, tuple(ordered_params)
+
+        # Case 2: Parameters are a sequence/scalar
+        if isinstance(parameters, (list, tuple)):
+            sequence_processed_parts: list[str] = []
+            param_index = 1
+            last_end = 0
+            qmark_found = False
+
+            for match in QMARK_REGEX.finditer(sql):
+                if match.group("dquote") or match.group("squote") or match.group("comment"):
+                    continue
+
+                if match.group("qmark"):
+                    qmark_found = True
+                    start = match.start("qmark")
+                    end = match.end("qmark")
+                    sequence_processed_parts.extend((sql[last_end:start], f"${param_index}"))
+                    last_end = end
+                    param_index += 1
+
+            sequence_processed_parts.append(sql[last_end:])
+            final_sql = "".join(sequence_processed_parts)
+
+            if parameters and not qmark_found:
+                logger.warning("Sequence parameters provided, but no '?' placeholders found. SQL: %s", sql)
+                return sql, parameters
+
+            expected_params = param_index - 1
+            actual_params = len(parameters)
+            if expected_params != actual_params:
+                msg = f"Parameter count mismatch: Expected {expected_params}, got {actual_params}. SQL: {final_sql}"
+                raise SQLParsingError(msg)
+
+            return final_sql, parameters
+
+        # Case 3: Parameters are None
+        if PARAM_REGEX.search(sql) or QMARK_REGEX.search(sql):
+            # Perform a simpler check if any placeholders might exist if no params are given
+            for match in PARAM_REGEX.finditer(sql):
+                if not (match.group("dquote") or match.group("squote") or match.group("comment")) and match.group(
+                    "var_name"
+                ):
+                    msg = f"SQL contains named parameters (:name) but no parameters provided. SQL: {sql}"
+                    raise SQLParsingError(msg)
+            for match in QMARK_REGEX.finditer(sql):
+                if not (match.group("dquote") or match.group("squote") or match.group("comment")) and match.group(
+                    "qmark"
+                ):
+                    msg = f"SQL contains positional parameters (?) but no parameters provided. SQL: {sql}"
+                    raise SQLParsingError(msg)
+
+        return sql, ()
+
+    # --- Public API Methods --- #
+    @overload
+    async def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Sequence[dict[str, Any]]": ...
+    @overload
+    async def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "Sequence[ModelDTOT]": ...
+    async def select(
+        self,
+        sql: str,
+        parameters: Optional["StatementParameterType"] = None,
+        /,
+        *,
+        connection: Optional["Connection"] = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        **kwargs: Any,
+    ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []  # psqlpy expects a list/tuple
+
+        results: QueryResult = await connection.fetch(sql, parameters=parameters)
+
+        if schema_type is None:
+            return cast("list[dict[str, Any]]", results.result())
+        return results.as_class(as_class=schema_type)
+
+    @overload
+    async def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "dict[str, Any]": ...
+    @overload
+    async def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "ModelDTOT": ...
+    async def select_one(
+        self,
+        sql: str,
+        parameters: Optional["StatementParameterType"] = None,
+        /,
+        *,
+        connection: Optional["Connection"] = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        **kwargs: Any,
+    ) -> "Union[ModelDTOT, dict[str, Any]]":
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []
+
+        result = await connection.fetch(sql, parameters=parameters)
+
+        if schema_type is None:
+            result = cast("list[dict[str, Any]]", result.result())  # type: ignore[assignment]
+            return cast("dict[str, Any]", result[0])  # type: ignore[index]
+        return result.as_class(as_class=schema_type)[0]
+
+    @overload
+    async def select_one_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Optional[dict[str, Any]]": ...
+    @overload
+    async def select_one_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "Optional[ModelDTOT]": ...
+    async def select_one_or_none(
+        self,
+        sql: str,
+        parameters: Optional["StatementParameterType"] = None,
+        /,
+        *,
+        connection: Optional["Connection"] = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        **kwargs: Any,
+    ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []
+
+        result = await connection.fetch(sql, parameters=parameters)
+        if schema_type is None:
+            result = cast("list[dict[str, Any]]", result.result())  # type: ignore[assignment]
+            if len(result) == 0:  # type: ignore[arg-type]
+                return None
+            return cast("dict[str, Any]", result[0])  # type: ignore[index]
+        result = cast("list[ModelDTOT]", result.as_class(as_class=schema_type))  # type: ignore[assignment]
+        if len(result) == 0:  # type: ignore[arg-type]
+            return None
+        return cast("ModelDTOT", result[0])  # type: ignore[index]
+
+    @overload
+    async def select_value(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Any": ...
+    @overload
+    async def select_value(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "type[T]",
+        **kwargs: Any,
+    ) -> "T": ...
+    async def select_value(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "Optional[type[T]]" = None,
+        **kwargs: Any,
+    ) -> "Union[T, Any]":
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []
+
+        value = await connection.fetch_val(sql, parameters=parameters)
+
+        if schema_type is None:
+            return value
+        return schema_type(value)  # type: ignore[call-arg]
+
+    @overload
+    async def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Optional[Any]": ...
+    @overload
+    async def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "type[T]",
+        **kwargs: Any,
+    ) -> "Optional[T]": ...
+    async def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "Optional[type[T]]" = None,
+        **kwargs: Any,
+    ) -> "Optional[Union[T, Any]]":
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []
+        try:
+            value = await connection.fetch_val(sql, parameters=parameters)
+        except RustPSQLDriverPyBaseError:
+            return None
+
+        if value is None:
+            return None
+        if schema_type is None:
+            return value
+        return schema_type(value)  # type: ignore[call-arg]
+
+    async def insert_update_delete(
+        self,
+        sql: str,
+        parameters: Optional["StatementParameterType"] = None,
+        /,
+        *,
+        connection: Optional["Connection"] = None,
+        **kwargs: Any,
+    ) -> int:
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []
+
+        await connection.execute(sql, parameters=parameters)
+        # For INSERT/UPDATE/DELETE, psqlpy returns an empty list but the operation succeeded
+        # if no error was raised
+        return 1
+
+    @overload
+    async def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "dict[str, Any]": ...
+    @overload
+    async def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[Connection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "ModelDTOT": ...
+    async def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: Optional["StatementParameterType"] = None,
+        /,
+        *,
+        connection: Optional["Connection"] = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        **kwargs: Any,
+    ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []
+
+        result = await connection.execute(sql, parameters=parameters)
+        if schema_type is None:
+            result = result.result()  # type: ignore[assignment]
+            if len(result) == 0:  # type: ignore[arg-type]
+                return None
+            return cast("dict[str, Any]", result[0])  # type: ignore[index]
+        result = result.as_class(as_class=schema_type)  # type: ignore[assignment]
+        if len(result) == 0:  # type: ignore[arg-type]
+            return None
+        return cast("ModelDTOT", result[0])  # type: ignore[index]
+
+    async def execute_script(
+        self,
+        sql: str,
+        parameters: Optional["StatementParameterType"] = None,
+        /,
+        *,
+        connection: Optional["Connection"] = None,
+        **kwargs: Any,
+    ) -> str:
+        connection = self._connection(connection)
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        parameters = parameters or []
+
+        await connection.execute(sql, parameters=parameters)
+        return sql
+
+    def _connection(self, connection: Optional["Connection"] = None) -> "Connection":
+        """Get the connection to use.
+
+        Args:
+            connection: Optional connection to use. If not provided, use the default connection.
+
+        Returns:
+            The connection to use.
+        """
+        return connection or self.connection
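Note: the `_process_sql_params` method above converts `:name` (dict) and `?` (sequence) placeholders into psqlpy's native `$1, $2, ...` style while skipping anything inside quoted strings or SQL comments. The minimal sketch below mirrors only the qmark branch, reusing the same `QMARK_REGEX`; the sample SQL and the driver wiring in the trailing comments are illustrative assumptions, not part of this release.

import re

# Same pattern as the new driver's QMARK_REGEX: quoted strings and comments are
# matched first so a literal '?' inside them is never rewritten.
QMARK_REGEX = re.compile(
    r"""(?P<dquote>"[^"]*") |          # Double-quoted identifiers
    (?P<squote>'[^']*') |              # Single-quoted strings
    (?P<comment>--[^\n]*|/\*.*?\*/) |  # SQL comments (single/multi-line)
    (?P<qmark>\?)                      # The question mark placeholder
    """,
    re.VERBOSE | re.DOTALL,
)


def qmarks_to_dollar(sql: str) -> str:
    """Rewrite bare '?' placeholders to $1, $2, ... as the qmark branch above does."""
    parts: list[str] = []
    last_end = 0
    index = 1
    for match in QMARK_REGEX.finditer(sql):
        if match.group("qmark"):  # quoted/commented matches fall through untouched
            parts.extend((sql[last_end:match.start("qmark")], f"${index}"))
            last_end = match.end("qmark")
            index += 1
    parts.append(sql[last_end:])
    return "".join(parts)


print(qmarks_to_dollar("SELECT * FROM users WHERE id = ? AND note = '?'"))
# -> SELECT * FROM users WHERE id = $1 AND note = '?'

# Hypothetical use of the adapter itself (obtaining the psqlpy connection is up to you):
#   driver = PsqlpyDriver(connection=conn)
#   rows = await driver.select("SELECT * FROM users WHERE id = ?", (1,))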
sqlspec/adapters/psycopg/__init__.py (removed identifiers are truncated in this view)
@@ -1,11 +1,16 @@
-from sqlspec.adapters.psycopg.config import
+from sqlspec.adapters.psycopg.config import (
+    PsycopgAsyncConfig,
+    PsycopgAsyncPoolConfig,
+    PsycopgSyncConfig,
+    PsycopgSyncPoolConfig,
+)
 from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver
 
 __all__ = (
-    "
+    "PsycopgAsyncConfig",
     "PsycopgAsyncDriver",
-    "
-    "
+    "PsycopgAsyncPoolConfig",
+    "PsycopgSyncConfig",
     "PsycopgSyncDriver",
-    "
+    "PsycopgSyncPoolConfig",
 )
sqlspec/adapters/psycopg/config/__init__.py
@@ -1,9 +1,9 @@
-from sqlspec.adapters.psycopg.config._async import
-from sqlspec.adapters.psycopg.config._sync import
+from sqlspec.adapters.psycopg.config._async import PsycopgAsyncConfig, PsycopgAsyncPoolConfig
+from sqlspec.adapters.psycopg.config._sync import PsycopgSyncConfig, PsycopgSyncPoolConfig
 
 __all__ = (
-    "
-    "
-    "
-    "
+    "PsycopgAsyncConfig",
+    "PsycopgAsyncPoolConfig",
+    "PsycopgSyncConfig",
+    "PsycopgSyncPoolConfig",
 )
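The two hunks above rename the exported psycopg configuration classes. A hedged sketch of the new import surface follows; the `conninfo` keyword on the pool config is an assumption borrowed from psycopg_pool's `ConnectionPool` signature and is not shown anywhere in this diff.

from sqlspec.adapters.psycopg import PsycopgSyncConfig, PsycopgSyncPoolConfig

# Hypothetical construction -- field names other than `pool_config` are assumptions.
config = PsycopgSyncConfig(
    pool_config=PsycopgSyncPoolConfig(conninfo="postgresql://app:secret@localhost/app_db"),
)
pool = config.provide_pool()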
sqlspec/adapters/psycopg/config/_async.py
@@ -16,18 +16,18 @@ if TYPE_CHECKING:
 
 
 __all__ = (
-    "
-    "
+    "PsycopgAsyncConfig",
+    "PsycopgAsyncPoolConfig",
 )
 
 
 @dataclass
-class
+class PsycopgAsyncPoolConfig(PsycopgGenericPoolConfig[AsyncConnection, AsyncConnectionPool]):
     """Async Psycopg Pool Config"""
 
 
 @dataclass
-class
+class PsycopgAsyncConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, PsycopgAsyncDriver]):
     """Async Psycopg database Configuration.
 
     This class provides the base configuration for Psycopg database connections, extending
@@ -37,7 +37,7 @@ class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, Psy
     with both synchronous and asynchronous connections.([2](https://www.psycopg.org/psycopg3/docs/api/connections.html))
     """
 
-    pool_config: "Optional[
+    pool_config: "Optional[PsycopgAsyncPoolConfig]" = None
     """Psycopg Pool configuration"""
     pool_instance: "Optional[AsyncConnectionPool]" = None
     """Optional pool to use"""
@@ -71,7 +71,7 @@ class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, Psy
                 self.pool_config,
                 exclude_empty=True,
                 convert_nested=False,
-                exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type"}),
+                exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type", "open"}),
             )
         msg = "You must provide a 'pool_config' for this adapter."
         raise ImproperConfigurationError(msg)
@@ -94,7 +94,7 @@ class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, Psy
            raise ImproperConfigurationError(msg)
 
     async def create_connection(self) -> "AsyncConnection":
-        """Create and return a new psycopg async connection.
+        """Create and return a new psycopg async connection from the pool.
 
         Returns:
             An AsyncConnection instance.
@@ -103,9 +103,8 @@ class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, Psy
             ImproperConfigurationError: If the connection could not be created.
         """
         try:
-
-
-            return await AsyncConnection.connect(**self.connection_config_dict)
+            pool = await self.provide_pool()
+            return await pool.getconn()
         except Exception as e:
             msg = f"Could not configure the Psycopg connection. Error: {e!s}"
             raise ImproperConfigurationError(msg) from e
@@ -128,10 +127,11 @@ class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, Psy
            raise ImproperConfigurationError(msg)
 
         pool_config = self.pool_config_dict
-        self.pool_instance = AsyncConnectionPool(**pool_config)
+        self.pool_instance = AsyncConnectionPool(open=False, **pool_config)
         if self.pool_instance is None:  # pyright: ignore[reportUnnecessaryComparison]
             msg = "Could not configure the 'pool_instance'. Please check your configuration."  # type: ignore[unreachable]
             raise ImproperConfigurationError(msg)
+        await self.pool_instance.open()
         return self.pool_instance
 
     def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[AsyncConnectionPool]":
@@ -150,7 +150,7 @@ class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, Psy
             AsyncConnection: A database connection from the pool.
         """
         pool = await self.provide_pool(*args, **kwargs)
-        async with pool.connection() as connection:
+        async with pool, pool.connection() as connection:
             yield connection
 
     @asynccontextmanager
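The async hunks above change pool handling in two ways: the pool is now constructed with `open=False` and opened explicitly, and `create_connection()` checks a connection out of that pool via `getconn()` instead of calling `AsyncConnection.connect()`. A minimal sketch of the same pattern written directly against psycopg_pool (the connection string and query are assumptions):

import asyncio

from psycopg_pool import AsyncConnectionPool


async def main() -> None:
    # Construct closed, then open explicitly -- the pattern the diff above adopts,
    # which also sidesteps psycopg_pool's warning about opening in the constructor.
    pool = AsyncConnectionPool("postgresql://app:secret@localhost/app_db", open=False)
    await pool.open()
    try:
        conn = await pool.getconn()  # what the new create_connection() does
        try:
            async with conn.cursor() as cur:
                await cur.execute("SELECT 1")
                print(await cur.fetchone())
        finally:
            await pool.putconn(conn)
    finally:
        await pool.close()


asyncio.run(main())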
sqlspec/adapters/psycopg/config/_sync.py
@@ -16,18 +16,18 @@ if TYPE_CHECKING:
 
 
 __all__ = (
-    "
-    "
+    "PsycopgSyncConfig",
+    "PsycopgSyncPoolConfig",
 )
 
 
 @dataclass
-class
+class PsycopgSyncPoolConfig(PsycopgGenericPoolConfig[Connection, ConnectionPool]):
     """Sync Psycopg Pool Config"""
 
 
 @dataclass
-class
+class PsycopgSyncConfig(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriver]):
     """Sync Psycopg database Configuration.
     This class provides the base configuration for Psycopg database connections, extending
     the generic database configuration with Psycopg-specific settings.([1](https://www.psycopg.org/psycopg3/docs/api/connections.html))
@@ -36,7 +36,7 @@ class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriv
     with both synchronous and asynchronous connections.([2](https://www.psycopg.org/psycopg3/docs/api/connections.html))
     """
 
-    pool_config: "Optional[
+    pool_config: "Optional[PsycopgSyncPoolConfig]" = None
     """Psycopg Pool configuration"""
     pool_instance: "Optional[ConnectionPool]" = None
     """Optional pool to use"""
@@ -70,7 +70,7 @@ class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriv
                 self.pool_config,
                 exclude_empty=True,
                 convert_nested=False,
-                exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type"}),
+                exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type", "open"}),
             )
         msg = "You must provide a 'pool_config' for this adapter."
         raise ImproperConfigurationError(msg)
@@ -87,13 +87,13 @@ class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriv
                 self.pool_config,
                 exclude_empty=True,
                 convert_nested=False,
-                exclude={"pool_instance", "connection_type", "driver_type"},
+                exclude={"pool_instance", "connection_type", "driver_type", "open"},
             )
         msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
         raise ImproperConfigurationError(msg)
 
     def create_connection(self) -> "Connection":
-        """Create and return a new psycopg connection.
+        """Create and return a new psycopg connection from the pool.
 
         Returns:
             A Connection instance.
@@ -102,9 +102,8 @@ class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriv
             ImproperConfigurationError: If the connection could not be created.
         """
         try:
-
-
-            return connect(**self.connection_config_dict)
+            pool = self.provide_pool()
+            return pool.getconn()
         except Exception as e:
             msg = f"Could not configure the Psycopg connection. Error: {e!s}"
             raise ImproperConfigurationError(msg) from e
@@ -127,10 +126,11 @@ class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriv
            raise ImproperConfigurationError(msg)
 
         pool_config = self.pool_config_dict
-        self.pool_instance = ConnectionPool(**pool_config)
+        self.pool_instance = ConnectionPool(open=False, **pool_config)
         if self.pool_instance is None:  # pyright: ignore[reportUnnecessaryComparison]
             msg = "Could not configure the 'pool_instance'. Please check your configuration."  # type: ignore[unreachable]
             raise ImproperConfigurationError(msg)
+        self.pool_instance.open()
         return self.pool_instance
 
     def provide_pool(self, *args: "Any", **kwargs: "Any") -> "ConnectionPool":
@@ -149,7 +149,7 @@ class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriv
             Connection: A database connection from the pool.
         """
         pool = self.provide_pool(*args, **kwargs)
-        with pool.connection() as connection:
+        with pool, pool.connection() as connection:
             yield connection
 
     @contextmanager
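Finally, both `provide_connection` variants now enter the pool itself as a context manager alongside `pool.connection()`, so the pool is opened on entry and closed when the caller's block exits. The same shape against psycopg_pool's sync API (connection string and query assumed):

from psycopg_pool import ConnectionPool

pool = ConnectionPool("postgresql://app:secret@localhost/app_db", open=False)

# `with pool` opens the pool on entry and closes it on exit; `pool.connection()`
# checks a connection out and returns it to the pool automatically.
with pool, pool.connection() as connection:
    with connection.cursor() as cur:
        cur.execute("SELECT 1")
        print(cur.fetchone())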