asyncpg-typed 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asyncpg_typed/__init__.py +660 -172
- asyncpg_typed-0.1.4.dist-info/METADATA +322 -0
- asyncpg_typed-0.1.4.dist-info/RECORD +8 -0
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.4.dist-info}/licenses/LICENSE +1 -1
- asyncpg_typed-0.1.2.dist-info/METADATA +0 -213
- asyncpg_typed-0.1.2.dist-info/RECORD +0 -8
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.4.dist-info}/WHEEL +0 -0
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.4.dist-info}/top_level.txt +0 -0
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.4.dist-info}/zip-safe +0 -0
asyncpg_typed/__init__.py
CHANGED
```diff
@@ -4,9 +4,9 @@ Type-safe queries for asyncpg.
 :see: https://github.com/hunyadi/asyncpg_typed
 """
 
-__version__ = "0.1.2"
+__version__ = "0.1.4"
 __author__ = "Levente Hunyadi"
-__copyright__ = "Copyright 2025, Levente Hunyadi"
+__copyright__ = "Copyright 2025-2026, Levente Hunyadi"
 __license__ = "MIT"
 __maintainer__ = "Levente Hunyadi"
 __status__ = "Production"
@@ -16,12 +16,14 @@ import sys
 import typing
 from abc import abstractmethod
 from collections.abc import Callable, Iterable, Sequence
+from dataclasses import dataclass
 from datetime import date, datetime, time, timedelta
 from decimal import Decimal
 from functools import reduce
 from io import StringIO
+from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
 from types import UnionType
-from typing import Any, Protocol, TypeAlias, TypeVar, Union, get_args, get_origin, overload
+from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
 from uuid import UUID
 
 import asyncpg
@@ -38,9 +40,32 @@ RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["Json
 
 TargetType: TypeAlias = type[Any] | UnionType
 
+Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
+
+
+class CountMismatchError(TypeError):
+    "Raised when a prepared statement takes or returns a different number of parameters or columns than declared in Python."
+
+
+class TypeMismatchError(TypeError):
+    "Raised when a prepared statement takes or returns a PostgreSQL type incompatible with the declared Python type."
+
+
+class NameMismatchError(TypeError):
+    "Raised when the name of a result-set column differs from what is declared in Python."
+
+
+class EnumMismatchError(TypeError):
+    "Raised when a prepared statement takes or returns a PostgreSQL enum type whose permitted set of values differs from what is declared in Python."
+
+
+class NoneTypeError(TypeError):
+    "Raised when a column marked as required contains a `NULL` value."
+
+
 if sys.version_info >= (3, 11):
 
-    def is_enum_type(typ:
+    def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
         """
         `True` if the specified type is an enumeration type.
         """
```
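All five new exception types derive from `TypeError`, so callers can catch a specific mismatch or the whole family at once. A minimal sketch of how calling code might surface these checks, assuming a hypothetical `book` table and an open connection:

```python
# Sketch: catching the new validation errors; `book` is a hypothetical table.
from asyncpg_typed import CountMismatchError, TypeMismatchError, sql

stmt = sql("SELECT id FROM book WHERE title = $1", arg=str, result=int)

async def lookup(conn, title: str) -> int | None:
    try:
        # validation runs when the statement is prepared against the connection
        return await stmt.fetchval(conn, title)
    except (CountMismatchError, TypeMismatchError) as exc:
        # the declared Python signature no longer matches the database schema
        raise RuntimeError(f"schema drift: {exc}") from exc
```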
```diff
@@ -49,7 +74,7 @@ if sys.version_info >= (3, 11):
 
 else:
 
-    def is_enum_type(typ:
+    def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
         """
         `True` if the specified type is an enumeration type.
         """
@@ -91,6 +116,14 @@ def is_json_type(tp: Any) -> bool:
     return tp in [JsonType, RequiredJsonType]
 
 
+def is_inet_type(tp: Any) -> bool:
+    """
+    `True` if the type represents an IP address or network.
+    """
+
+    return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
+
+
 def make_union_type(tpl: list[Any]) -> UnionType:
     """
     Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
```
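`is_inet_type` feeds the new result-set cast mask further below: values of `inet` and `cidr` columns are passed through the declared `ipaddress` constructor on the way out. Declaring such columns might look like this (hypothetical `host` table):

```python
# Sketch: declaring address-typed columns; `host` is a hypothetical table.
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network

from asyncpg_typed import sql

# an `inet` column may hold either protocol version, hence the unions
stmt = sql(
    "SELECT addr, subnet FROM host WHERE name = $1",
    arg=str,
    resultset=tuple[IPv4Address | IPv6Address, IPv4Network | IPv6Network],
)
```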
```diff
@@ -119,25 +152,54 @@ def get_required_type(tp: Any) -> Any:
     return type(None)
 
 
-
-if typing.TYPE_CHECKING:
+def _standard_json_decoder() -> Callable[[str], JsonType]:
     import json
 
     _json_decoder = json.JSONDecoder()
-    _json_converter = _json_decoder.decode
-else:
-    try:
-        import orjson
+    return _json_decoder.decode
+
+
+def _json_decoder() -> Callable[[str], JsonType]:
+    if typing.TYPE_CHECKING:
+        return _standard_json_decoder()
+    else:
+        try:
+            import orjson
+
+            return orjson.loads
+        except ModuleNotFoundError:
+            return _standard_json_decoder()
 
-        _json_converter = orjson.loads
-    except ModuleNotFoundError:
-        import json
 
-        _json_decoder = json.JSONDecoder()
-        _json_converter = _json_decoder.decode
+JSON_DECODER = _json_decoder()
 
 
-def get_converter_for(tp: Any) -> Callable[[Any], Any]:
+def _standard_json_encoder() -> Callable[[JsonType], str]:
+    import json
+
+    _json_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
+    return _json_encoder.encode
+
+
+def _json_encoder() -> Callable[[JsonType], str]:
+    if typing.TYPE_CHECKING:
+        return _standard_json_encoder()
+    else:
+        try:
+            import orjson
+
+            def _wrap(value: JsonType) -> str:
+                return orjson.dumps(value).decode()
+
+            return _wrap
+        except ModuleNotFoundError:
+            return _standard_json_encoder()
+
+
+JSON_ENCODER = _json_encoder()
+
+
+def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
     """
     Returns a callable that takes a wire type and returns a target type.
 
```
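The module-scope `try`/`except` around `orjson` is gone; both codecs are now resolved once, at import time, by small factory functions, and the `typing.TYPE_CHECKING` branch pins the declared `Callable` signature to the stdlib implementation. A standalone sketch of the same selection pattern:

```python
# Standalone sketch of the import-time codec selection used above.
import json
from collections.abc import Callable

def _pick_encoder() -> Callable[[object], str]:
    try:
        import orjson  # optional dependency; prefer it when installed
    except ModuleNotFoundError:
        encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
        return encoder.encode
    return lambda value: orjson.dumps(value).decode()  # orjson emits bytes

ENCODE = _pick_encoder()  # resolved once, not on every call
print(ENCODE({"id": 1, "tags": ["a", "b"]}))  # {"id":1,"tags":["a","b"]}
```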
```diff
@@ -147,101 +209,268 @@ def get_converter_for(tp: Any) -> Callable[[Any], Any]:
 
     if is_json_type(tp):
         # asyncpg returns fields of type `json` and `jsonb` as `str`, which must be de-serialized
-        return
+        return JSON_DECODER
     else:
         # target data types that require conversion must have a single-argument `__init__` that takes an object of the source type
         return tp
 
 
+def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
+    """
+    Returns a callable that takes a source type and returns a wire type.
+
+    A source type is one of the types supported by the library.
+    A wire type is one of the types returned by asyncpg.
+    """
+
+    if is_json_type(tp):
+        # asyncpg expects fields of type `json` and `jsonb` as `str`, which must be serialized
+        return JSON_ENCODER
+    else:
+        # source data types that require conversion must have a single-argument `__init__` that takes an object of the source type
+        return tp
+
+
 # maps PostgreSQL internal type names to compatible Python types
-
+_NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
+    # boolean type
     "bool": (bool,),
+    # numeric types
     "int2": (int,),
     "int4": (int,),
     "int8": (int,),
     "float4": (float,),
     "float8": (float,),
     "numeric": (Decimal,),
+    # date and time types
     "date": (date,),
     "time": (time,),
     "timetz": (time,),
     "timestamp": (datetime,),
     "timestamptz": (datetime,),
     "interval": (timedelta,),
+    # character sequence types
     "bpchar": (str,),
     "varchar": (str,),
     "text": (str,),
+    # binary sequence types
     "bytea": (bytes,),
+    # unique identifier type
+    "uuid": (UUID,),
+    # address types
+    "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
+    "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
+    "macaddr": (str,),
+    "macaddr8": (str,),
+    # JSON type
     "json": (str, RequiredJsonType),
     "jsonb": (str, RequiredJsonType),
-    "uuid": (UUID,),
+    # XML type
     "xml": (str,),
+    # geometric types
+    "point": (asyncpg.Point,),
+    "line": (asyncpg.Line,),
+    "lseg": (asyncpg.LineSegment,),
+    "box": (asyncpg.Box,),
+    "path": (asyncpg.Path,),
+    "polygon": (asyncpg.Polygon,),
+    "circle": (asyncpg.Circle,),
+    # range types
+    "int4range": (asyncpg.Range[int],),
+    "int4multirange": (list[asyncpg.Range[int]],),
+    "int8range": (asyncpg.Range[int],),
+    "int8multirange": (list[asyncpg.Range[int]],),
+    "numrange": (asyncpg.Range[Decimal],),
+    "nummultirange": (list[asyncpg.Range[Decimal]],),
+    "tsrange": (asyncpg.Range[datetime],),
+    "tsmultirange": (list[asyncpg.Range[datetime]],),
+    "tstzrange": (asyncpg.Range[datetime],),
+    "tstzmultirange": (list[asyncpg.Range[datetime]],),
+    "daterange": (asyncpg.Range[date],),
+    "datemultirange": (list[asyncpg.Range[date]],),
 }
 
 
-def
+def type_to_str(tp: Any) -> str:
+    "Emits a friendly name for a type."
+
+    if isinstance(tp, type):
+        return tp.__name__
+    else:
+        return str(tp)
+
+
+class _TypeVerifier:
     """
     Verifies if the Python target type can represent the PostgreSQL source type.
     """
 
-
-    if is_enum_type(data_type):
-        return name in ["bpchar", "varchar", "text"]
+    _connection: Connection
 
-
-
-
-
-
+    def __init__(self, connection: Connection) -> None:
+        self._connection = connection
+
+    async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
+        """
+        Verifies if a Python enumeration type matches a PostgreSQL enumeration type.
+        """
+
+        for e in data_type:
+            if not isinstance(e.value, str):
+                raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")
+
+        py_values = set(e.value for e in data_type)
 
-
-
+        rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
+        db_values = set(row[0] for row in rows)
 
+        db_extra = db_values - py_values
+        if db_extra:
+            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in db_extra)})")
 
-class _Placeholder:
+        py_extra = py_values - db_values
+        if py_extra:
+            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in py_extra)})")
+
+    async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
+        """
+        Verifies if the Python target type can represent the PostgreSQL source type.
+        """
+
+        if pg_type.schema == "pg_catalog":  # well-known PostgreSQL types
+            if is_enum_type(data_type):
+                if pg_type.name not in ["bpchar", "varchar", "text"]:
+                    raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
+            elif pg_type.kind == "array" and get_origin(data_type) is list:
+                if not pg_type.name.endswith("[]"):
+                    raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of array")
+
+                expected_types = _NAME_TO_TYPE.get(pg_type.name[:-2])
+                if expected_types is None:
+                    raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
+                elif get_args(data_type)[0] not in expected_types:
+                    if len(expected_types) == 1:
+                        target = f"the Python type `{type_to_str(expected_types[0])}`"
+                    else:
+                        target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
+                    raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` whose elements convert to {target}")
+            else:
+                expected_types = _NAME_TO_TYPE.get(pg_type.name)
+                if expected_types is None:
+                    raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
+                elif data_type not in expected_types:
+                    if len(expected_types) == 1:
+                        target = f"the Python type `{type_to_str(expected_types[0])}`"
+                    else:
+                        target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
+                    raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` which converts to {target}")
+        elif pg_type.kind == "composite":  # PostgreSQL composite types
+            # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
+            pass
+        else:  # custom PostgreSQL types
+            if is_enum_type(data_type):
+                await self._check_enum_type(pg_name, pg_type, data_type)
+            elif is_standard_type(data_type):
+                raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
+            else:
+                # user-defined types registered with `conn.set_type_codec()`
+                pass
+
+
+@dataclass(frozen=True)
+class _Placeholder:
     ordinal: int
     data_type: TargetType
 
-    def __init__(self, ordinal: int, data_type: TargetType) -> None:
-        self.ordinal = ordinal
-        self.data_type = data_type
-
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.ordinal}, {self.data_type!r})"
 
 
+class _ResultsetWrapper:
+    "Wraps result-set rows into a tuple or named tuple."
+
+    init: Callable[..., tuple[Any, ...]] | None
+    iterable: Callable[[Iterable[Any]], tuple[Any, ...]]
+
+    def __init__(self, init: Callable[..., tuple[Any, ...]] | None, iterable: Callable[[Iterable[Any]], tuple[Any, ...]]) -> None:
+        """
+        Initializes a result-set wrapper.
+
+        :param init: Initializer function that takes as many arguments as columns in the result-set.
+        :param iterable: Initializer function that takes an iterable over columns of a result-set row.
+        """
+
+        self.init = init
+        self.iterable = iterable
+
+
 class _SQLObject:
     """
     Associates input and output type information with a SQL statement.
     """
 
-
-
-
-
-
+    _parameter_data_types: tuple[_Placeholder, ...]
+    _resultset_data_types: tuple[TargetType, ...]
+    _resultset_column_names: tuple[str, ...] | None
+    _resultset_wrapper: _ResultsetWrapper
+    _parameter_cast: int
+    _parameter_converters: tuple[Callable[[Any], Any], ...]
+    _required: int
+    _resultset_cast: int
+    _resultset_converters: tuple[Callable[[Any], Any], ...]
+
+    @property
+    def parameter_data_types(self) -> tuple[_Placeholder, ...]:
+        "Expected inbound parameter data types."
+
+        return self._parameter_data_types
+
+    @property
+    def resultset_data_types(self) -> tuple[TargetType, ...]:
+        "Expected column data types in the result-set."
+
+        return self._resultset_data_types
+
+    @property
+    def resultset_column_names(self) -> tuple[str, ...] | None:
+        "Expected column names in the result-set."
+
+        return self._resultset_column_names
 
     def __init__(
         self,
-
-
+        *,
+        args: tuple[TargetType, ...],
+        resultset: tuple[TargetType, ...],
+        names: tuple[str, ...] | None,
+        wrapper: _ResultsetWrapper,
     ) -> None:
-        self.
-        self.
+        self._parameter_data_types = tuple(_Placeholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(args, start=1))
+        self._resultset_data_types = tuple(get_required_type(data_type) for data_type in resultset)
+        self._resultset_column_names = names
+        self._resultset_wrapper = wrapper
+
+        # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
+        parameter_cast = 0
+        for index, placeholder in enumerate(self._parameter_data_types):
+            parameter_cast |= is_json_type(placeholder.data_type) << index
+        self._parameter_cast = parameter_cast
+
+        self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in self._parameter_data_types)
 
         # create a bit-field of required types (1: required; 0: optional)
         required = 0
-        for index, data_type in enumerate(
+        for index, data_type in enumerate(resultset):
             required |= (not is_optional_type(data_type)) << index
-        self.
+        self._required = required
 
-        # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
-
-        for index, data_type in enumerate(self.
-
-        self.
+        # create a bit-field of types that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
+        resultset_cast = 0
+        for index, data_type in enumerate(self._resultset_data_types):
+            resultset_cast |= (is_enum_type(data_type) or is_json_type(data_type) or is_inet_type(data_type)) << index
+        self._resultset_cast = resultset_cast
 
-        self.
+        self._resultset_converters = tuple(get_output_converter_for(data_type) for data_type in self._resultset_data_types)
 
     def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
         """
```
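The rewritten `_SQLObject.__init__` precomputes three integer bit-fields (`_parameter_cast`, `_required`, `_resultset_cast`), one bit per parameter or column, so the per-row hot path reduces to shifts and masks instead of repeated type introspection. The bookkeeping in isolation:

```python
# Sketch of the bit-field bookkeeping: one bit per column, LSB = column 0.
is_optional = [False, True, False]  # column 1 may be NULL

required = 0
for index, optional in enumerate(is_optional):
    required |= (not optional) << index
assert required == 0b101

# per-row check: which required columns hold None?
row = ("x", None, None)
violations = [i for i in range(len(row)) if (required >> i & 1) and row[i] is None]
assert violations == [2]
```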
```diff
@@ -249,12 +478,12 @@ class _SQLObject:
         """
 
         for col_index in range(len(row)):
-            if (self.
+            if (self._required >> col_index & 1) and row[col_index] is None:
                 if row_index is not None:
                     row_col_spec = f"row #{row_index} and column #{col_index}"
                 else:
                     row_col_spec = f"column #{col_index}"
-                raise
+                raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
 
     def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
         """
@@ -264,7 +493,7 @@ class _SQLObject:
         if not rows:
             return
 
-        required = self.
+        required = self._required
         if not required:
             return
 
@@ -317,7 +546,7 @@ class _SQLObject:
         Verifies if declared types match actual value types in a single row.
         """
 
-        required = self.
+        required = self._required
         if not required:
             return
 
@@ -356,13 +585,187 @@ class _SQLObject:
             case _:
                 self._raise_required_is_none(row)
 
+    def check_column(self, column: list[Any]) -> None:
+        """
+        Verifies if the declared type matches the actual value type of a single-column resultset.
+        """
+
+        if self._required:
+            for i, value in enumerate(column):
+                if value is None:
+                    raise NoneTypeError(f"expected: {self._resultset_data_types[0]} in row #{i}; got: NULL")
+
     def check_value(self, value: Any) -> None:
         """
         Verifies if the declared type matches the actual value type.
         """
 
-        if self.
-            raise
+        if self._required and value is None:
+            raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
+
+    def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
+        """
+        Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.
+        """
+
+        cast = self._parameter_cast
+        if cast:
+            converters = self._parameter_converters
+            yield from (tuple((converters[i](value) if (value := arg[i]) is not None and cast >> i & 1 else value) for i in range(len(arg))) for arg in arg_lists)
+        else:
+            yield from arg_lists
+
+    def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
+        """
+        Converts Python query arguments to PostgreSQL parameters.
+        """
+
+        cast = self._parameter_cast
+        if cast:
+            converters = self._parameter_converters
+            return tuple((converters[i](value) if (value := arg_list[i]) is not None and cast >> i & 1 else value) for i in range(len(arg_list)))
+        else:
+            return tuple(value for value in arg_list)
+
+    def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
+        """
+        Converts columns in the PostgreSQL result-set to their corresponding Python target types.
+
+        :param rows: List of rows returned by PostgreSQL.
+        :returns: List of tuples with each tuple element having the configured Python target type.
+        """
+
+        if not rows:
+            return []
+
+        init_wrapper = self._resultset_wrapper.init
+        iterable_wrapper = self._resultset_wrapper.iterable
+        cast = self._resultset_cast
+        if not cast:
+            return [iterable_wrapper(row.values()) for row in rows]
+
+        columns = len(rows[0])
+        match columns:
+            case 1:
+                converter = self._resultset_converters[0]
+                if init_wrapper is not None:
+                    return [init_wrapper(converter(value) if (value := row[0]) is not None else value) for row in rows]
+                else:
+                    return [(converter(value) if (value := row[0]) is not None else value,) for row in rows]
+            case 2:
+                conv_a, conv_b = self._resultset_converters
+                cast_a = cast >> 0 & 1
+                cast_b = cast >> 1 & 1
+                if init_wrapper is not None:
+                    return [
+                        init_wrapper(
+                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
+                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
+                        )
+                        for row in rows
+                    ]
+                else:
+                    return [
+                        (
+                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
+                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
+                        )
+                        for row in rows
+                    ]
+            case 3:
+                conv_a, conv_b, conv_c = self._resultset_converters
+                cast_a = cast >> 0 & 1
+                cast_b = cast >> 1 & 1
+                cast_c = cast >> 2 & 1
+                if init_wrapper is not None:
+                    return [
+                        init_wrapper(
+                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
+                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
+                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
+                        )
+                        for row in rows
+                    ]
+                else:
+                    return [
+                        (
+                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
+                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
+                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
+                        )
+                        for row in rows
+                    ]
+            case 4:
+                conv_a, conv_b, conv_c, conv_d = self._resultset_converters
+                cast_a = cast >> 0 & 1
+                cast_b = cast >> 1 & 1
+                cast_c = cast >> 2 & 1
+                cast_d = cast >> 3 & 1
+                if init_wrapper is not None:
+                    return [
+                        init_wrapper(
+                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
+                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
+                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
+                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
+                        )
+                        for row in rows
+                    ]
+                else:
+                    return [
+                        (
+                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
+                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
+                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
+                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
+                        )
+                        for row in rows
+                    ]
+            case _:
+                converters = self._resultset_converters
+                return [iterable_wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns)) for row in rows]
+
+    def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
+        """
+        Converts columns in the PostgreSQL result-set to their corresponding Python target types.
+
+        :param row: A single row returned by PostgreSQL.
+        :returns: A tuple with each tuple element having the configured Python target type.
+        """
+
+        wrapper = self._resultset_wrapper.iterable
+        cast = self._resultset_cast
+        if cast:
+            converters = self._resultset_converters
+            columns = len(row)
+            return wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns))
+        else:
+            return wrapper(row.values())
+
+    def convert_column(self, rows: list[asyncpg.Record]) -> list[Any]:
+        """
+        Converts a single column in the PostgreSQL result-set to its corresponding Python target type.
+
+        :param rows: List of rows returned by PostgreSQL.
+        :returns: List of values having the configured Python target type.
+        """
+
+        cast = self._resultset_cast
+        if cast:
+            converter = self._resultset_converters[0]
+            return [(converter(value) if (value := row[0]) is not None else value) for row in rows]
+        else:
+            return [row[0] for row in rows]
+
+    def convert_value(self, value: Any) -> Any:
+        """
+        Converts a single PostgreSQL value to its corresponding Python target type.
+
+        :param value: A single value returned by PostgreSQL.
+        :returns: A converted value having the configured Python target type.
+        """
+
+        return self._resultset_converters[0](value) if value is not None and self._resultset_cast else value
 
     @abstractmethod
     def query(self) -> str:
```
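`convert_rows` unrolls the one- to four-column cases, binding each converter and its mask bit to a local before the list comprehension runs; wider result sets fall back to the generic per-cell expression in `case _`. That generic expression in isolation:

```python
# Sketch of the generic per-cell expression from the `case _` branch.
converters = [str.upper, int]  # stand-ins for the real converters
cast = 0b01                    # only column 0 needs conversion
rows = [("a", 1), (None, 2)]

converted = [
    tuple(
        converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value
        for i in range(len(row))
    )
    for row in rows
]
assert converted == [("A", 1), (None, 2)]
```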
```diff
@@ -388,8 +791,8 @@ if sys.version_info >= (3, 14):
         A SQL query specified with the Python t-string syntax.
         """
 
-
-
+        _strings: tuple[str, ...]
+        _placeholders: tuple[_Placeholder, ...]
 
         def __init__(
             self,
@@ -397,8 +800,10 @@
             *,
             args: tuple[TargetType, ...],
             resultset: tuple[TargetType, ...],
+            names: tuple[str, ...] | None,
+            wrapper: _ResultsetWrapper,
         ) -> None:
-            super().__init__(args, resultset)
+            super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
 
             for ip in template.interpolations:
                 if ip.conversion is not None:
@@ -408,26 +813,26 @@
                 if not isinstance(ip.value, int):
                     raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")
 
-            self.
+            self._strings = template.strings
 
             if len(self.parameter_data_types) > 0:
 
-                def _to_placeholder(ip: Interpolation) ->
+                def _to_placeholder(ip: Interpolation) -> _Placeholder:
                     ordinal = int(ip.value)
                     if not (0 < ordinal <= len(self.parameter_data_types)):
                         raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {len(self.parameter_data_types)}")
                     return self.parameter_data_types[int(ip.value) - 1]
 
-                self.
+                self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
             else:
-                self.
+                self._placeholders = ()
 
         def query(self) -> str:
             buf = StringIO()
-            for s, p in zip(self.
+            for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
                 buf.write(s)
                 buf.write(f"${p.ordinal}")
-            buf.write(self.
+            buf.write(self._strings[-1])
             return buf.getvalue()
 
 else:
```
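On Python 3.14, each interpolation in the t-string must evaluate to a 1-based parameter ordinal; `query()` then stitches the literal segments back together around `$n` placeholders. A sketch (3.14+ only, hypothetical `city` table):

```python
# Sketch (Python 3.14+): t-string interpolations are parameter ordinals.
from asyncpg_typed import sql

# renders as "SELECT name FROM city WHERE id = $1"
stmt = sql(t"SELECT name FROM city WHERE id = {1}", arg=int, result=str)
```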
```diff
@@ -439,7 +844,7 @@ class _SQLString(_SQLObject):
     A SQL query specified as a plain string (e.g. f-string).
     """
 
-
+    _sql: str
 
     def __init__(
         self,
@@ -447,12 +852,14 @@
         *,
         args: tuple[TargetType, ...],
         resultset: tuple[TargetType, ...],
+        names: tuple[str, ...] | None,
+        wrapper: _ResultsetWrapper,
     ) -> None:
-        super().__init__(args, resultset)
-        self.
+        super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
+        self._sql = sql
 
     def query(self) -> str:
-        return self.
+        return self._sql
 
 
 class _SQL(Protocol):
@@ -461,80 +868,90 @@ class _SQL(Protocol):
     """
 
 
-Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
-
-
 class _SQLImpl(_SQL):
     """
     Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).
     """
 
-
+    _sql: _SQLObject
 
     def __init__(self, sql: _SQLObject) -> None:
-        self.
+        self._sql = sql
 
     def __str__(self) -> str:
-        return str(self.
+        return str(self._sql)
 
     def __repr__(self) -> str:
-        return repr(self.
+        return repr(self._sql)
 
     async def _prepare(self, connection: Connection) -> PreparedStatement:
-        stmt = await connection.prepare(self.
-
-
-
-
+        stmt = await connection.prepare(self._sql.query())
+
+        verifier = _TypeVerifier(connection)
+
+        input_count = len(self._sql.parameter_data_types)
+        parameter_count = len(stmt.get_parameters())
+        if parameter_count != input_count:
+            raise CountMismatchError(f"expected: PostgreSQL query to take {input_count} parameter(s); got: {parameter_count}")
+        for param, placeholder in zip(stmt.get_parameters(), self._sql.parameter_data_types, strict=True):
+            await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)
+
+        output_count = len(self._sql.resultset_data_types)
+        column_count = len(stmt.get_attributes())
+        if column_count != output_count:
+            raise CountMismatchError(f"expected: PostgreSQL query to return {output_count} column(s) in result-set; got: {column_count}")
+        if self._sql.resultset_column_names is not None:
+            for index, attr, name in zip(range(output_count), stmt.get_attributes(), self._sql.resultset_column_names, strict=True):
+                if attr.name != name:
+                    raise NameMismatchError(f"expected: Python field name `{name}` to match PostgreSQL result-set column name `{attr.name}` for index #{index}")
+        for attr, data_type in zip(stmt.get_attributes(), self._sql.resultset_data_types, strict=True):
+            await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)
 
         return stmt
 
     async def execute(self, connection: asyncpg.Connection, *args: Any) -> None:
-        await connection.execute(self.
+        await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))
 
     async def executemany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> None:
        stmt = await self._prepare(connection)
-        await stmt.executemany(args)
+        await stmt.executemany(self._sql.convert_arg_lists(args))
 
     def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
-
-
-            converters = self.sql.converters
-            resultset = [tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row))) for row in rows]
-        else:
-            resultset = [tuple(value for value in row) for row in rows]
-        self.sql.check_rows(resultset)
+        resultset = self._sql.convert_rows(rows)
+        self._sql.check_rows(resultset)
         return resultset
 
     async def fetch(self, connection: asyncpg.Connection, *args: Any) -> list[tuple[Any, ...]]:
         stmt = await self._prepare(connection)
-        rows = await stmt.fetch(*args)
+        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
         return self._cast_fetch(rows)
 
     async def fetchmany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
         stmt = await self._prepare(connection)
-        rows = await stmt.fetchmany(args)
+        rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
         return self._cast_fetch(rows)
 
     async def fetchrow(self, connection: asyncpg.Connection, *args: Any) -> tuple[Any, ...] | None:
         stmt = await self._prepare(connection)
-        row = await stmt.fetchrow(*args)
+        row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
         if row is None:
             return None
-
-
-            converters = self.sql.converters
-            resultset = tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row)))
-        else:
-            resultset = tuple(value for value in row)
-        self.sql.check_row(resultset)
+        resultset = self._sql.convert_row(row)
+        self._sql.check_row(resultset)
         return resultset
 
+    async def fetchcol(self, connection: asyncpg.Connection, *args: Any) -> list[Any]:
+        stmt = await self._prepare(connection)
+        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
+        column = self._sql.convert_column(rows)
+        self._sql.check_column(column)
+        return column
+
     async def fetchval(self, connection: asyncpg.Connection, *args: Any) -> Any:
         stmt = await self._prepare(connection)
-        value = await stmt.fetchval(*args)
-        result = self.
-        self.
+        value = await stmt.fetchval(*self._sql.convert_arg_list(args))
+        result = self._sql.convert_value(value)
+        self._sql.check_value(result)
         return result
 
 
```
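`_prepare` now fails fast, before any row is fetched: parameter and column counts raise `CountMismatchError`, a named-tuple field that disagrees with a column name raises `NameMismatchError`, and each column type is checked by `_TypeVerifier`. The new `fetchcol` verb complements `fetchval`, returning one column across all rows; a sketch with a hypothetical `book` table:

```python
# Sketch: fetchcol returns a single column as a list; `book` is hypothetical.
from asyncpg_typed import sql

titles = sql("SELECT title FROM book WHERE year >= $1", arg=int, result=str)

async def titles_since(conn, year: int) -> list[str]:
    return await titles.fetchcol(conn, year)  # list[str], NULLs rejected
```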
```diff
@@ -547,9 +964,7 @@ R2 = TypeVar("R2")
 RX = TypeVarTuple("RX")
 
 
-### START OF AUTO-GENERATED BLOCK ###
-
-
+### START OF AUTO-GENERATED BLOCK FOR Protocol ###
 class SQL_P0(Protocol):
     @abstractmethod
     async def execute(self, connection: Connection) -> None: ...
@@ -561,6 +976,8 @@ class SQL_R1_P0(SQL_P0, Protocol[R1]):
     @abstractmethod
     async def fetchrow(self, connection: Connection) -> tuple[R1] | None: ...
     @abstractmethod
+    async def fetchcol(self, connection: Connection) -> list[R1]: ...
+    @abstractmethod
     async def fetchval(self, connection: Connection) -> R1: ...
 
 
@@ -586,6 +1003,8 @@ class SQL_R1_PX(SQL_PX[Unpack[PX]], Protocol[R1, Unpack[PX]]):
     @abstractmethod
     async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> tuple[R1] | None: ...
     @abstractmethod
+    async def fetchcol(self, connection: Connection, *args: Unpack[PX]) -> list[R1]: ...
+    @abstractmethod
     async def fetchval(self, connection: Connection, *args: Unpack[PX]) -> R1: ...
 
 
@@ -598,84 +1017,153 @@ class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
     async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None: ...
 
 
-@overload
-def sql(stmt: SQLExpression) -> SQL_P0: ...
-@overload
-def sql(stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
-@overload
-def sql(stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
-@overload
-def sql(stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
-@overload
-def sql(stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
-@overload
-def sql(stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
-@overload
-def sql(stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
-@overload
-def sql(stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
-@overload
-def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
-@overload
-def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
-@overload
-def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
-@overload
-def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
-
-
-### END OF AUTO-GENERATED BLOCK ###
-
-
-def sql(
-    stmt: SQLExpression,
-    *,
-    args: type[Any] | None = None,
-    resultset: type[Any] | None = None,
-    arg: type[Any] | None = None,
-    result: type[Any] | None = None,
-) -> _SQL:
-    """
-    Creates a SQL statement with associated type information.
+### END OF AUTO-GENERATED BLOCK FOR Protocol ###
+
+RS = TypeVar("RS", bound=tuple[Any, ...])
 
-
-
-
-
-    :param result: Type signature for a single result column (e.g. `UUID`).
+
+class SQLFactory:
+    """
+    Creates type-safe SQL queries.
     """
 
+    ### START OF AUTO-GENERATED BLOCK FOR sql ###
+    @overload
+    def sql(self, stmt: SQLExpression) -> SQL_P0: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
+    @overload
+    def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
+
+    ### END OF AUTO-GENERATED BLOCK FOR sql ###
+
+    def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
+        """
+        Creates a SQL statement with associated type information.
+
+        :param stmt: SQL statement as a literal string or template.
+        :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
+        :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
+        :param arg: Type signature for a single input parameter (e.g. `int`).
+        :param result: Type signature for a single result column (e.g. `UUID`).
+        """
+
+        input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
+
+        obj: _SQLObject
+        if sys.version_info >= (3, 14):
+            match stmt:
+                case Template():
+                    obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
+                case str():
+                    obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
+        else:
+            obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
+
+        return _SQLImpl(obj)
+
+    ### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
+    @overload
+    def unsafe_sql(self, stmt: str) -> SQL_P0: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
+    @overload
+    def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
+
+    ### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
+
+    def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
+        """
+        Creates a SQL statement with associated type information from a string.
+
+        This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
+        a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.
+
+        :param stmt: SQL statement as a string (or f-string).
+        :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
+        :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
+        :param arg: Type signature for a single input parameter (e.g. `int`).
+        :param result: Type signature for a single result column (e.g. `UUID`).
+        """
+
+        input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
+        obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
+        return _SQLImpl(obj)
+
+
+def _sql_args_resultset(
+    *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None
+) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[str, ...] | None, _ResultsetWrapper]:
+    "Parses an argument/resultset signature into input/output types."
+
     if args is not None and arg is not None:
         raise TypeError("expected: either `args` or `arg`; got: both")
     if resultset is not None and result is not None:
         raise TypeError("expected: either `resultset` or `result`; got: both")
 
     if args is not None:
-        if
-
-
+        if hasattr(args, "_asdict") and hasattr(args, "_fields"):
+            # named tuple
+            input_data_types = tuple(tp for tp in args.__annotations__.values())
+        else:
+            # regular tuple
+            if get_origin(args) is not tuple:
+                raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {args}")
+            input_data_types = get_args(args)
     elif arg is not None:
         input_data_types = (arg,)
     else:
         input_data_types = ()
 
     if resultset is not None:
-        if
-
-
-
-
+        if hasattr(resultset, "_asdict") and hasattr(resultset, "_fields") and hasattr(resultset, "_make"):
+            # named tuple
+            output_data_types = tuple(tp for tp in resultset.__annotations__.values())
+            names = tuple(f for f in resultset._fields)
+            wrapper = _ResultsetWrapper(resultset, resultset._make)
+        else:
+            # regular tuple
+            if get_origin(resultset) is not tuple:
+                raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {resultset}")
+            output_data_types = get_args(resultset)
+            names = None
+            wrapper = _ResultsetWrapper(None, tuple)
     else:
-
+        if result is not None:
+            output_data_types = (result,)
+        else:
+            output_data_types = ()
+        names = None
+        wrapper = _ResultsetWrapper(None, tuple)
 
-    obj: _SQLObject
-    if sys.version_info >= (3, 14):
-        match stmt:
-            case Template():
-                obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
-            case str():
-                obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
-    else:
-        obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
+    return input_data_types, output_data_types, names, wrapper
+
+
+FACTORY: SQLFactory = SQLFactory()
 
-    return _SQLImpl(obj)
+sql = FACTORY.sql
+unsafe_sql = FACTORY.unsafe_sql
```