asyncpg-typed 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asyncpg_typed/__init__.py +437 -154
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.3.dist-info}/METADATA +121 -41
- asyncpg_typed-0.1.3.dist-info/RECORD +8 -0
- asyncpg_typed-0.1.2.dist-info/RECORD +0 -8
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.3.dist-info}/WHEEL +0 -0
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.3.dist-info}/top_level.txt +0 -0
- {asyncpg_typed-0.1.2.dist-info → asyncpg_typed-0.1.3.dist-info}/zip-safe +0 -0
asyncpg_typed/__init__.py
CHANGED
|
@@ -4,7 +4,7 @@ Type-safe queries for asyncpg.
|
|
|
4
4
|
:see: https://github.com/hunyadi/asyncpg_typed
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
__version__ = "0.1.
|
|
7
|
+
__version__ = "0.1.3"
|
|
8
8
|
__author__ = "Levente Hunyadi"
|
|
9
9
|
__copyright__ = "Copyright 2025, Levente Hunyadi"
|
|
10
10
|
__license__ = "MIT"
|
|
@@ -16,12 +16,14 @@ import sys
|
|
|
16
16
|
import typing
|
|
17
17
|
from abc import abstractmethod
|
|
18
18
|
from collections.abc import Callable, Iterable, Sequence
|
|
19
|
+
from dataclasses import dataclass
|
|
19
20
|
from datetime import date, datetime, time, timedelta
|
|
20
21
|
from decimal import Decimal
|
|
21
22
|
from functools import reduce
|
|
22
23
|
from io import StringIO
|
|
24
|
+
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
|
|
23
25
|
from types import UnionType
|
|
24
|
-
from typing import Any, Protocol, TypeAlias, TypeVar, Union, get_args, get_origin, overload
|
|
26
|
+
from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
|
|
25
27
|
from uuid import UUID
|
|
26
28
|
|
|
27
29
|
import asyncpg
|
|
@@ -38,9 +40,24 @@ RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["Json
|
|
|
38
40
|
|
|
39
41
|
# A Python type that a query parameter or result-set column may be declared with:
# either a plain class (e.g. `int`) or a union built with `|` (e.g. `int | None`).
TargetType: TypeAlias = type[Any] | UnionType

# A database connection accepted by the library: a dedicated `asyncpg` connection,
# or a proxy for a connection acquired from an `asyncpg` connection pool.
Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TypeMismatchError(TypeError):
    """Raised when a prepared statement takes or returns a PostgreSQL type incompatible with the declared Python type."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class EnumMismatchError(TypeError):
    """Raised when a prepared statement takes or returns a PostgreSQL enum type whose permitted set of values differs from what is declared in Python."""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class NoneTypeError(TypeError):
    """Raised when a column marked as required contains a `NULL` value."""
|
|
56
|
+
|
|
57
|
+
|
|
41
58
|
if sys.version_info >= (3, 11):
|
|
42
59
|
|
|
43
|
-
def is_enum_type(typ:
|
|
60
|
+
def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
|
|
44
61
|
"""
|
|
45
62
|
`True` if the specified type is an enumeration type.
|
|
46
63
|
"""
|
|
@@ -49,7 +66,7 @@ if sys.version_info >= (3, 11):
|
|
|
49
66
|
|
|
50
67
|
else:
|
|
51
68
|
|
|
52
|
-
def is_enum_type(typ:
|
|
69
|
+
def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
|
|
53
70
|
"""
|
|
54
71
|
`True` if the specified type is an enumeration type.
|
|
55
72
|
"""
|
|
@@ -91,6 +108,14 @@ def is_json_type(tp: Any) -> bool:
|
|
|
91
108
|
return tp in [JsonType, RequiredJsonType]
|
|
92
109
|
|
|
93
110
|
|
|
111
|
+
def is_inet_type(tp: Any) -> bool:
|
|
112
|
+
"""
|
|
113
|
+
`True` if the type represents an IP address or network.
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
|
|
117
|
+
|
|
118
|
+
|
|
94
119
|
def make_union_type(tpl: list[Any]) -> UnionType:
|
|
95
120
|
"""
|
|
96
121
|
Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
|
|
@@ -119,25 +144,54 @@ def get_required_type(tp: Any) -> Any:
|
|
|
119
144
|
return type(None)
|
|
120
145
|
|
|
121
146
|
|
|
122
|
-
|
|
123
|
-
if typing.TYPE_CHECKING:
|
|
147
|
+
def _standard_json_decoder() -> Callable[[str], JsonType]:
    """
    Returns the `decode` method of a standard-library JSON decoder.

    Used as the fallback de-serializer when `orjson` is not available.
    """

    import json

    return json.JSONDecoder().decode
|
|
131
152
|
|
|
132
|
-
_json_converter = orjson.loads
|
|
133
|
-
except ModuleNotFoundError:
|
|
134
|
-
import json
|
|
135
153
|
|
|
136
|
-
|
|
137
|
-
|
|
154
|
+
def _json_decoder() -> Callable[[str], JsonType]:
    """
    Picks the JSON de-serializer to use.

    At run time, `orjson.loads` is returned when `orjson` can be imported; otherwise the standard-library
    decoder is used. Static type checkers only ever see the standard-library path.
    """

    if not typing.TYPE_CHECKING:
        try:
            import orjson
        except ModuleNotFoundError:
            pass
        else:
            return orjson.loads
    return _standard_json_decoder()


# de-serializer for `json` and `jsonb` values, which asyncpg returns as `str`
JSON_DECODER = _json_decoder()
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _standard_json_encoder() -> Callable[[JsonType], str]:
    """
    Returns the `encode` method of a compact standard-library JSON encoder.

    Used as the fallback serializer when `orjson` is not available. The encoder has the following properties:
    - it writes non-ASCII characters as-is;
    - it emits no insignificant whitespace;
    - it rejects `NaN` and infinities.
    """

    import json

    encoder = json.JSONEncoder(
        ensure_ascii=False,
        separators=(",", ":"),
        allow_nan=False,
    )
    return encoder.encode
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _json_encoder() -> Callable[[JsonType], str]:
    """
    Picks the JSON serializer to use.

    At run time, a wrapper around `orjson.dumps` is returned when `orjson` can be imported; the wrapper decodes
    the `bytes` output to `str`. Otherwise the standard-library encoder is used. Static type checkers only ever
    see the standard-library path.
    """

    if not typing.TYPE_CHECKING:
        try:
            import orjson
        except ModuleNotFoundError:
            pass
        else:

            def _wrap(value: JsonType) -> str:
                return orjson.dumps(value).decode()

            return _wrap
    return _standard_json_encoder()


# serializer for values bound to `json` and `jsonb` parameters, which asyncpg expects as `str`
JSON_ENCODER = _json_encoder()
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
|
|
141
195
|
"""
|
|
142
196
|
Returns a callable that takes a wire type and returns a target type.
|
|
143
197
|
|
|
@@ -147,65 +201,165 @@ def get_converter_for(tp: Any) -> Callable[[Any], Any]:
|
|
|
147
201
|
|
|
148
202
|
if is_json_type(tp):
|
|
149
203
|
# asyncpg returns fields of type `json` and `jsonb` as `str`, which must be de-serialized
|
|
150
|
-
return
|
|
204
|
+
return JSON_DECODER
|
|
151
205
|
else:
|
|
152
206
|
# target data types that require conversion must have a single-argument `__init__` that takes an object of the source type
|
|
153
207
|
return tp
|
|
154
208
|
|
|
155
209
|
|
|
210
|
+
def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
    """
    Returns a callable that takes a source type and returns a wire type.

    A source type is one of the types supported by the library.
    A wire type is one of the types returned by asyncpg.

    :param tp: Declared Python type of a query parameter.
    :returns: The JSON serializer for JSON types; the type itself (used as a single-argument constructor) otherwise.
    """

    if not is_json_type(tp):
        # any other type acts as its own converter: `tp(value)` yields a value of the declared type
        return tp

    # asyncpg takes `json` and `jsonb` parameters as `str`, so structured values are serialized first
    return JSON_ENCODER
|
|
224
|
+
|
|
225
|
+
|
|
156
226
|
# maps PostgreSQL internal type names of built-in types (e.g. `int4`, `bpchar`) to the Python types asyncpg
# converts them to/from; a declared Python type is accepted if it is one of the listed alternatives
_NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
    # boolean type
    "bool": (bool,),
    # numeric types
    "int2": (int,),
    "int4": (int,),
    "int8": (int,),
    "float4": (float,),
    "float8": (float,),
    "numeric": (Decimal,),
    # object identifier type
    "oid": (int,),
    # monetary type (asyncpg exchanges `money` values as text)
    "money": (str,),
    # date and time types
    "date": (date,),
    "time": (time,),
    "timetz": (time,),
    "timestamp": (datetime,),
    "timestamptz": (datetime,),
    "interval": (timedelta,),
    # character sequence types
    "bpchar": (str,),
    "varchar": (str,),
    "text": (str,),
    "name": (str,),
    # binary sequence types
    "bytea": (bytes,),
    # bit string types
    "bit": (asyncpg.BitString,),
    "varbit": (asyncpg.BitString,),
    # unique identifier type
    "uuid": (UUID,),
    # address types
    "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
    "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
    "macaddr": (str,),
    "macaddr8": (str,),
    # JSON type
    "json": (str, RequiredJsonType),
    "jsonb": (str, RequiredJsonType),
    # XML type
    "xml": (str,),
    # geometric types
    "point": (asyncpg.Point,),
    "line": (asyncpg.Line,),
    "lseg": (asyncpg.LineSegment,),
    "box": (asyncpg.Box,),
    "path": (asyncpg.Path,),
    "polygon": (asyncpg.Polygon,),
    "circle": (asyncpg.Circle,),
    # range types
    "int4range": (asyncpg.Range[int],),
    "int4multirange": (list[asyncpg.Range[int]],),
    "int8range": (asyncpg.Range[int],),
    "int8multirange": (list[asyncpg.Range[int]],),
    "numrange": (asyncpg.Range[Decimal],),
    "nummultirange": (list[asyncpg.Range[Decimal]],),
    "tsrange": (asyncpg.Range[datetime],),
    "tsmultirange": (list[asyncpg.Range[datetime]],),
    "tstzrange": (asyncpg.Range[datetime],),
    "tstzmultirange": (list[asyncpg.Range[datetime]],),
    "daterange": (asyncpg.Range[date],),
    "datemultirange": (list[asyncpg.Range[date]],),
}
|
|
180
284
|
|
|
181
285
|
|
|
182
|
-
def
|
|
286
|
+
def type_to_str(tp: Any) -> str:
|
|
287
|
+
"Emits a friendly name for a type."
|
|
288
|
+
|
|
289
|
+
if isinstance(tp, type):
|
|
290
|
+
return tp.__name__
|
|
291
|
+
else:
|
|
292
|
+
return str(tp)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class _TypeVerifier:
    """
    Verifies if the Python target type can represent the PostgreSQL source type.
    """

    # connection used to look up enumeration labels in the PostgreSQL catalog
    _connection: Connection

    def __init__(self, connection: Connection) -> None:
        self._connection = connection

    async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
        """
        Verifies if a Python enumeration type matches a PostgreSQL enumeration type.

        :param pg_name: Description of the parameter or column, used in error messages.
        :param pg_type: PostgreSQL type of the parameter or column.
        :param data_type: Python enumeration type declared for the parameter or column.
        :raises TypeMismatchError: If a member of the Python enumeration has a value that is not `str`.
        :raises EnumMismatchError: If the PostgreSQL labels and the Python values are not the same set.
        """

        for e in data_type:
            if not isinstance(e.value, str):
                raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")

        py_values = {e.value for e in data_type}

        # labels are kept in PostgreSQL declaration order so that error messages list values deterministically
        rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
        db_labels: list[str] = [row[0] for row in rows]
        db_values = set(db_labels)

        db_extra = [label for label in db_labels if label not in py_values]
        if db_extra:
            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in db_extra)}")

        # members are listed in Python definition order
        py_extra = [e.value for e in data_type if e.value not in db_values]
        if py_extra:
            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in py_extra)}")

    async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
        """
        Verifies if the Python target type can represent the PostgreSQL source type.

        * Types in the `pg_catalog` schema are checked against `_NAME_TO_TYPE`. A Python enumeration is only
          accepted for the text types `bpchar`, `varchar` and `text`.
        * Composite types are accepted as-is.
        * Any other (custom) type is handled as follows:
          - a Python enumeration is compared label by label;
          - a type for which `is_standard_type` holds is rejected;
          - everything else is accepted (presumably handled by a codec registered on the connection).

        :param pg_name: Description of the parameter or column, used in error messages.
        :param pg_type: PostgreSQL type of the parameter or column.
        :param data_type: Python type declared for the parameter or column.
        :raises TypeMismatchError: If the declared type cannot represent the PostgreSQL type.
        :raises EnumMismatchError: If enumeration values differ between Python and PostgreSQL.
        """

        if pg_type.schema == "pg_catalog":  # well-known PostgreSQL types
            if is_enum_type(data_type):
                if pg_type.name not in ["bpchar", "varchar", "text"]:
                    raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
            else:
                expected_types = _NAME_TO_TYPE.get(pg_type.name)
                if expected_types is None:
                    raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
                elif data_type not in expected_types:
                    raise TypeMismatchError(
                        f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; "
                        f"got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`, which converts to one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
                    )
        elif pg_type.kind == "composite":  # PostgreSQL composite types
            # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
            pass
        else:  # custom PostgreSQL types
            if is_enum_type(data_type):
                await self._check_enum_type(pg_name, pg_type, data_type)
            elif is_standard_type(data_type):
                raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
            else:
                # user-defined types registered with `conn.set_type_codec()`
                pass
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@dataclass(frozen=True)
class _SQLPlaceholder:
    """
    An immutable query placeholder.

    Holds the 1-based ordinal of the placeholder (as in `$1`, `$2`, ...) and the Python type declared for it.
    """

    ordinal: int
    data_type: TargetType

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}({self.ordinal}, {self.data_type!r})"
|
|
211
365
|
|
|
@@ -215,33 +369,51 @@ class _SQLObject:
|
|
|
215
369
|
Associates input and output type information with a SQL statement.
|
|
216
370
|
"""
|
|
217
371
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
372
|
+
# placeholders of the input parameters, each with its normalized declared type
_parameter_data_types: tuple[_SQLPlaceholder, ...]
# normalized declared type of each result-set column
_resultset_data_types: tuple[TargetType, ...]
# bit-field: bit i is set if parameter i is converted before it is sent
_parameter_cast: int
# converter of each parameter (applied only where the matching `_parameter_cast` bit is set)
_parameter_converters: tuple[Callable[[Any], Any], ...]
# bit-field: bit i is set if column i is declared as required (non-optional)
_required: int
# bit-field: bit i is set if column i is converted after it is received
_resultset_cast: int
# converter of each result-set column (applied only where the matching `_resultset_cast` bit is set)
_resultset_converters: tuple[Callable[[Any], Any], ...]

@property
def parameter_data_types(self) -> tuple[_SQLPlaceholder, ...]:
    "Placeholders of the input parameters, ordered by ordinal."
    return self._parameter_data_types

@property
def resultset_data_types(self) -> tuple[TargetType, ...]:
    "Declared types of the result-set columns, as normalized by `get_required_type`."
    return self._resultset_data_types
|
|
223
387
|
|
|
224
388
|
def __init__(
    self,
    input_data_types: tuple[TargetType, ...],
    output_data_types: tuple[TargetType, ...],
) -> None:
    """
    Records the declared types of parameters and columns, and pre-computes how values are converted.

    :param input_data_types: Declared Python type of each query parameter, in ordinal order.
    :param output_data_types: Declared Python type of each result-set column; optional types mark nullable columns.
    """

    placeholders = tuple(
        _SQLPlaceholder(ordinal, get_required_type(declared))
        for ordinal, declared in enumerate(input_data_types, start=1)
    )
    columns = tuple(get_required_type(declared) for declared in output_data_types)
    self._parameter_data_types = placeholders
    self._resultset_data_types = columns

    # bit-field of parameters that require cast or serialization (1: apply conversion; 0: forward value as-is)
    self._parameter_cast = sum(
        1 << bit for bit, placeholder in enumerate(placeholders) if is_json_type(placeholder.data_type)
    )
    self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in placeholders)

    # bit-field of required columns (1: required; 0: optional);
    # checked on the declared types, before `get_required_type` normalization
    self._required = sum(1 << bit for bit, declared in enumerate(output_data_types) if not is_optional_type(declared))

    # bit-field of columns that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
    self._resultset_cast = sum(
        1 << bit
        for bit, column_type in enumerate(columns)
        if is_enum_type(column_type) or is_json_type(column_type) or is_inet_type(column_type)
    )
    self._resultset_converters = tuple(get_output_converter_for(column_type) for column_type in columns)
|
|
245
417
|
|
|
246
418
|
def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
|
|
247
419
|
"""
|
|
@@ -249,12 +421,12 @@ class _SQLObject:
|
|
|
249
421
|
"""
|
|
250
422
|
|
|
251
423
|
for col_index in range(len(row)):
|
|
252
|
-
if (self.
|
|
424
|
+
if (self._required >> col_index & 1) and row[col_index] is None:
|
|
253
425
|
if row_index is not None:
|
|
254
426
|
row_col_spec = f"row #{row_index} and column #{col_index}"
|
|
255
427
|
else:
|
|
256
428
|
row_col_spec = f"column #{col_index}"
|
|
257
|
-
raise
|
|
429
|
+
raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
|
|
258
430
|
|
|
259
431
|
def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
|
|
260
432
|
"""
|
|
@@ -264,7 +436,7 @@ class _SQLObject:
|
|
|
264
436
|
if not rows:
|
|
265
437
|
return
|
|
266
438
|
|
|
267
|
-
required = self.
|
|
439
|
+
required = self._required
|
|
268
440
|
if not required:
|
|
269
441
|
return
|
|
270
442
|
|
|
@@ -317,7 +489,7 @@ class _SQLObject:
|
|
|
317
489
|
Verifies if declared types match actual value types in a single row.
|
|
318
490
|
"""
|
|
319
491
|
|
|
320
|
-
required = self.
|
|
492
|
+
required = self._required
|
|
321
493
|
if not required:
|
|
322
494
|
return
|
|
323
495
|
|
|
@@ -361,8 +533,72 @@ class _SQLObject:
|
|
|
361
533
|
Verifies if the declared type matches the actual value type.
|
|
362
534
|
"""
|
|
363
535
|
|
|
364
|
-
if self.
|
|
365
|
-
raise
|
|
536
|
+
if self._required and value is None:
|
|
537
|
+
raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
|
|
538
|
+
|
|
539
|
+
def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
    """
    Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.

    Conversion is lazy: argument tuples are produced one at a time as the result is iterated.

    :param arg_lists: Argument tuples, each ordered by placeholder ordinal.
    :returns: The original argument sequences when no parameter needs conversion; otherwise new converted tuples.
    """

    mask = self._parameter_cast
    if not mask:
        # nothing to serialize: hand the argument sequences through untouched
        yield from arg_lists
        return

    convert = self._parameter_converters
    for args in arg_lists:
        yield tuple(
            convert[index](item) if item is not None and (mask >> index) & 1 else item
            for index, item in enumerate(args)
        )
|
|
550
|
+
|
|
551
|
+
def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
    """
    Converts Python query arguments to PostgreSQL parameters.

    :param arg_list: Argument values, ordered by placeholder ordinal.
    :returns: A new tuple. Non-`None` arguments whose bit is set in the parameter cast bit-field are converted.
    """

    mask = self._parameter_cast
    if not mask:
        return tuple(arg_list)

    convert = self._parameter_converters
    return tuple(
        convert[index](item) if item is not None and (mask >> index) & 1 else item
        for index, item in enumerate(arg_list)
    )
|
|
562
|
+
|
|
563
|
+
def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
    """
    Converts columns in the PostgreSQL result-set to their corresponding Python target types.

    :param rows: List of rows returned by PostgreSQL.
    :returns: List of tuples with each tuple element having the configured Python target type.
    """

    mask = self._resultset_cast
    if not mask:
        # no column needs conversion: copy each record's values into a plain tuple
        return [tuple(record) for record in rows]

    convert = self._resultset_converters
    return [
        tuple(
            convert[index](item) if item is not None and (mask >> index) & 1 else item
            for index, item in enumerate(record)
        )
        for record in rows
    ]
|
|
577
|
+
|
|
578
|
+
def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
    """
    Converts columns in the PostgreSQL result-set to their corresponding Python target types.

    :param row: A single row returned by PostgreSQL.
    :returns: A tuple with each tuple element having the configured Python target type.
    """

    mask = self._resultset_cast
    if not mask:
        # no column needs conversion: copy the record's values into a plain tuple
        return tuple(row)

    convert = self._resultset_converters
    return tuple(
        convert[index](item) if item is not None and (mask >> index) & 1 else item
        for index, item in enumerate(row)
    )
|
|
592
|
+
|
|
593
|
+
def convert_value(self, value: Any) -> Any:
    """
    Converts a single PostgreSQL value to its corresponding Python target type.

    The value is treated as the first result-set column, so only that column's conversion bit is consulted.

    :param value: A single value returned by PostgreSQL.
    :returns: A converted value having the configured Python target type.
    """

    # test bit 0 only: a cast flag on another column must not trigger the first column's converter
    # (e.g. `datetime(datetime_value)` would raise `TypeError`)
    if value is not None and self._resultset_cast & 1:
        return self._resultset_converters[0](value)
    return value
|
|
366
602
|
|
|
367
603
|
@abstractmethod
|
|
368
604
|
def query(self) -> str:
|
|
@@ -388,8 +624,8 @@ if sys.version_info >= (3, 14):
|
|
|
388
624
|
A SQL query specified with the Python t-string syntax.
|
|
389
625
|
"""
|
|
390
626
|
|
|
391
|
-
|
|
392
|
-
|
|
627
|
+
_strings: tuple[str, ...]
|
|
628
|
+
_placeholders: tuple[_SQLPlaceholder, ...]
|
|
393
629
|
|
|
394
630
|
def __init__(
|
|
395
631
|
self,
|
|
@@ -408,7 +644,7 @@ if sys.version_info >= (3, 14):
|
|
|
408
644
|
if not isinstance(ip.value, int):
|
|
409
645
|
raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")
|
|
410
646
|
|
|
411
|
-
self.
|
|
647
|
+
self._strings = template.strings
|
|
412
648
|
|
|
413
649
|
if len(self.parameter_data_types) > 0:
|
|
414
650
|
|
|
@@ -418,16 +654,16 @@ if sys.version_info >= (3, 14):
|
|
|
418
654
|
raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {len(self.parameter_data_types)}")
|
|
419
655
|
return self.parameter_data_types[int(ip.value) - 1]
|
|
420
656
|
|
|
421
|
-
self.
|
|
657
|
+
self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
|
|
422
658
|
else:
|
|
423
|
-
self.
|
|
659
|
+
self._placeholders = ()
|
|
424
660
|
|
|
425
661
|
def query(self) -> str:
|
|
426
662
|
buf = StringIO()
|
|
427
|
-
for s, p in zip(self.
|
|
663
|
+
for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
|
|
428
664
|
buf.write(s)
|
|
429
665
|
buf.write(f"${p.ordinal}")
|
|
430
|
-
buf.write(self.
|
|
666
|
+
buf.write(self._strings[-1])
|
|
431
667
|
return buf.getvalue()
|
|
432
668
|
|
|
433
669
|
else:
|
|
@@ -439,7 +675,7 @@ class _SQLString(_SQLObject):
|
|
|
439
675
|
A SQL query specified as a plain string (e.g. f-string).
|
|
440
676
|
"""
|
|
441
677
|
|
|
442
|
-
|
|
678
|
+
_sql: str
|
|
443
679
|
|
|
444
680
|
def __init__(
|
|
445
681
|
self,
|
|
@@ -449,10 +685,10 @@ class _SQLString(_SQLObject):
|
|
|
449
685
|
resultset: tuple[TargetType, ...],
|
|
450
686
|
) -> None:
|
|
451
687
|
super().__init__(args, resultset)
|
|
452
|
-
self.
|
|
688
|
+
self._sql = sql
|
|
453
689
|
|
|
454
690
|
def query(self) -> str:
    "Returns the SQL text exactly as it was supplied."

    sql_text = self._sql
    return sql_text
|
|
456
692
|
|
|
457
693
|
|
|
458
694
|
class _SQL(Protocol):
|
|
@@ -461,80 +697,69 @@ class _SQL(Protocol):
|
|
|
461
697
|
"""
|
|
462
698
|
|
|
463
699
|
|
|
464
|
-
Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
|
|
465
|
-
|
|
466
|
-
|
|
467
700
|
class _SQLImpl(_SQL):
    """
    Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).

    Before arguments are sent, they are converted with the wrapped `_SQLObject`.
    After rows are fetched, they are converted and then checked by the wrapped `_SQLObject`.
    """

    # query text, declared types and conversion plan
    _sql: _SQLObject

    def __init__(self, sql: _SQLObject) -> None:
        self._sql = sql

    def __str__(self) -> str:
        return str(self._sql)

    def __repr__(self) -> str:
        return repr(self._sql)

    async def _prepare(self, connection: Connection) -> PreparedStatement:
        """
        Prepares the statement and checks the declared Python types against the statement's PostgreSQL types.

        Errors reported by `_TypeVerifier` propagate to the caller.

        :raises ValueError: If the number of declared parameters or result columns differs from the statement's.
        """

        stmt = await connection.prepare(self._sql.query())

        verifier = _TypeVerifier(connection)
        for param, placeholder in zip(stmt.get_parameters(), self._sql.parameter_data_types, strict=True):
            await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)
        for attr, data_type in zip(stmt.get_attributes(), self._sql.resultset_data_types, strict=True):
            await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)

        return stmt

    async def execute(self, connection: Connection, *args: Any) -> None:
        # runs directly on the connection, without the prepare-and-verify step used by the other methods
        await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))

    async def executemany(self, connection: Connection, args: Iterable[Sequence[Any]]) -> None:
        stmt = await self._prepare(connection)
        await stmt.executemany(self._sql.convert_arg_lists(args))

    def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
        "Converts fetched rows to the declared Python types, and validates them with `check_rows`."

        resultset = self._sql.convert_rows(rows)
        self._sql.check_rows(resultset)
        return resultset

    async def fetch(self, connection: Connection, *args: Any) -> list[tuple[Any, ...]]:
        stmt = await self._prepare(connection)
        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
        return self._cast_fetch(rows)

    async def fetchmany(self, connection: Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
        stmt = await self._prepare(connection)
        rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
        return self._cast_fetch(rows)

    async def fetchrow(self, connection: Connection, *args: Any) -> tuple[Any, ...] | None:
        stmt = await self._prepare(connection)
        row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
        if row is None:
            return None
        resultset = self._sql.convert_row(row)
        self._sql.check_row(resultset)
        return resultset

    async def fetchval(self, connection: Connection, *args: Any) -> Any:
        stmt = await self._prepare(connection)
        value = await stmt.fetchval(*self._sql.convert_arg_list(args))
        result = self._sql.convert_value(value)
        self._sql.check_value(result)
        return result
|
|
539
764
|
|
|
540
765
|
|
|
@@ -547,9 +772,7 @@ R2 = TypeVar("R2")
|
|
|
547
772
|
RX = TypeVarTuple("RX")
|
|
548
773
|
|
|
549
774
|
|
|
550
|
-
### START OF AUTO-GENERATED BLOCK ###
|
|
551
|
-
|
|
552
|
-
|
|
775
|
+
### START OF AUTO-GENERATED BLOCK FOR Protocol ###
|
|
553
776
|
class SQL_P0(Protocol):
|
|
554
777
|
@abstractmethod
|
|
555
778
|
async def execute(self, connection: Connection) -> None: ...
|
|
@@ -598,52 +821,116 @@ class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
|
|
|
598
821
|
async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None: ...
|
|
599
822
|
|
|
600
823
|
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
@overload
|
|
604
|
-
def sql(stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
|
|
605
|
-
@overload
|
|
606
|
-
def sql(stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
|
|
607
|
-
@overload
|
|
608
|
-
def sql(stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
|
|
609
|
-
@overload
|
|
610
|
-
def sql(stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
|
|
611
|
-
@overload
|
|
612
|
-
def sql(stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
|
|
613
|
-
@overload
|
|
614
|
-
def sql(stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
|
|
615
|
-
@overload
|
|
616
|
-
def sql(stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
|
|
617
|
-
@overload
|
|
618
|
-
def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
|
|
619
|
-
@overload
|
|
620
|
-
def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
|
|
621
|
-
@overload
|
|
622
|
-
def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
|
|
623
|
-
@overload
|
|
624
|
-
def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
### END OF AUTO-GENERATED BLOCK ###
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
def sql(
|
|
631
|
-
stmt: SQLExpression,
|
|
632
|
-
*,
|
|
633
|
-
args: type[Any] | None = None,
|
|
634
|
-
resultset: type[Any] | None = None,
|
|
635
|
-
arg: type[Any] | None = None,
|
|
636
|
-
result: type[Any] | None = None,
|
|
637
|
-
) -> _SQL:
|
|
638
|
-
"""
|
|
639
|
-
Creates a SQL statement with associated type information.
|
|
824
|
+
### END OF AUTO-GENERATED BLOCK FOR Protocol ###
|
|
825
|
+
|
|
640
826
|
|
|
641
|
-
|
|
642
|
-
:param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
|
|
643
|
-
:param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
|
|
644
|
-
:param arg: Type signature for a single input parameter (e.g. `int`).
|
|
645
|
-
:param result: Type signature for a single result column (e.g. `UUID`).
|
|
827
|
+
class SQLFactory:
|
|
646
828
|
"""
|
|
829
|
+
Creates type-safe SQL queries.
|
|
830
|
+
"""
|
|
831
|
+
|
|
832
|
+
### START OF AUTO-GENERATED BLOCK FOR sql ###
|
|
833
|
+
@overload
|
|
834
|
+
def sql(self, stmt: SQLExpression) -> SQL_P0: ...
|
|
835
|
+
@overload
|
|
836
|
+
def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
|
|
837
|
+
@overload
|
|
838
|
+
def sql(self, stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
|
|
839
|
+
@overload
|
|
840
|
+
def sql(self, stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
|
|
841
|
+
@overload
|
|
842
|
+
def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
|
|
843
|
+
@overload
|
|
844
|
+
def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
|
|
845
|
+
@overload
|
|
846
|
+
def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
|
|
847
|
+
@overload
|
|
848
|
+
def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
|
|
849
|
+
@overload
|
|
850
|
+
def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
|
|
851
|
+
@overload
|
|
852
|
+
def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
|
|
853
|
+
@overload
|
|
854
|
+
def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
|
|
855
|
+
@overload
|
|
856
|
+
def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
|
|
857
|
+
|
|
858
|
+
### END OF AUTO-GENERATED BLOCK FOR sql ###
|
|
859
|
+
|
|
860
|
+
def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
|
|
861
|
+
"""
|
|
862
|
+
Creates a SQL statement with associated type information.
|
|
863
|
+
|
|
864
|
+
:param stmt: SQL statement as a literal string or template.
|
|
865
|
+
:param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
|
|
866
|
+
:param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
|
|
867
|
+
:param arg: Type signature for a single input parameter (e.g. `int`).
|
|
868
|
+
:param result: Type signature for a single result column (e.g. `UUID`).
|
|
869
|
+
"""
|
|
870
|
+
|
|
871
|
+
input_data_types, output_data_types = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
|
|
872
|
+
|
|
873
|
+
obj: _SQLObject
|
|
874
|
+
if sys.version_info >= (3, 14):
|
|
875
|
+
match stmt:
|
|
876
|
+
case Template():
|
|
877
|
+
obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
|
|
878
|
+
case str():
|
|
879
|
+
obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
|
|
880
|
+
else:
|
|
881
|
+
obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
|
|
882
|
+
|
|
883
|
+
return _SQLImpl(obj)
|
|
884
|
+
|
|
885
|
+
### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
|
|
886
|
+
@overload
|
|
887
|
+
def unsafe_sql(self, stmt: str) -> SQL_P0: ...
|
|
888
|
+
@overload
|
|
889
|
+
def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
|
|
890
|
+
@overload
|
|
891
|
+
def unsafe_sql(self, stmt: str, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
|
|
892
|
+
@overload
|
|
893
|
+
def unsafe_sql(self, stmt: str, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
|
|
894
|
+
@overload
|
|
895
|
+
def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
|
|
896
|
+
@overload
|
|
897
|
+
def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
|
|
898
|
+
@overload
|
|
899
|
+
def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
|
|
900
|
+
@overload
|
|
901
|
+
def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
|
|
902
|
+
@overload
|
|
903
|
+
def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
|
|
904
|
+
@overload
|
|
905
|
+
def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
|
|
906
|
+
@overload
|
|
907
|
+
def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
|
|
908
|
+
@overload
|
|
909
|
+
def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
|
|
910
|
+
|
|
911
|
+
### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
|
|
912
|
+
|
|
913
|
+
def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
|
|
914
|
+
"""
|
|
915
|
+
Creates a SQL statement with associated type information from a string.
|
|
916
|
+
|
|
917
|
+
This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
|
|
918
|
+
a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.
|
|
919
|
+
|
|
920
|
+
:param stmt: SQL statement as a string (or f-string).
|
|
921
|
+
:param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
|
|
922
|
+
:param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
|
|
923
|
+
:param arg: Type signature for a single input parameter (e.g. `int`).
|
|
924
|
+
:param result: Type signature for a single result column (e.g. `UUID`).
|
|
925
|
+
"""
|
|
926
|
+
|
|
927
|
+
input_data_types, output_data_types = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
|
|
928
|
+
obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
|
|
929
|
+
return _SQLImpl(obj)
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
def _sql_args_resultset(*, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
|
|
933
|
+
"Parses an argument/resultset signature into input/output types."
|
|
647
934
|
|
|
648
935
|
if args is not None and arg is not None:
|
|
649
936
|
raise TypeError("expected: either `args` or `arg`; got: both")
|
|
@@ -668,14 +955,10 @@ def sql(
|
|
|
668
955
|
else:
|
|
669
956
|
output_data_types = ()
|
|
670
957
|
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
|
|
676
|
-
case str():
|
|
677
|
-
obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
|
|
678
|
-
else:
|
|
679
|
-
obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
|
|
958
|
+
return input_data_types, output_data_types
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
FACTORY: SQLFactory = SQLFactory()
|
|
680
962
|
|
|
681
|
-
|
|
963
|
+
sql = FACTORY.sql
|
|
964
|
+
unsafe_sql = FACTORY.unsafe_sql
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: asyncpg_typed
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Summary: Type-safe queries for asyncpg
|
|
5
5
|
Author-email: Levente Hunyadi <hunyadi@gmail.com>
|
|
6
6
|
Maintainer-email: Levente Hunyadi <hunyadi@gmail.com>
|
|
@@ -113,6 +113,8 @@ def sql(
|
|
|
113
113
|
) -> _SQL: ...
|
|
114
114
|
```
|
|
115
115
|
|
|
116
|
+
#### Parameters to factory function
|
|
117
|
+
|
|
116
118
|
The parameter `stmt` represents a SQL expression, either as a literal string or a template (i.e. a *t-string*).
|
|
117
119
|
|
|
118
120
|
If the expression is a string, it can have PostgreSQL parameter placeholders such as `$1`, `$2` or `$3`:
|
|
@@ -127,24 +129,62 @@ If the expression is a *t-string*, it can have replacement fields that evaluate
|
|
|
127
129
|
t"INSERT INTO table_name (col_1, col_2, col_3) VALUES ({1}, {2}, {3});"
|
|
128
130
|
```
|
|
129
131
|
|
|
130
|
-
The parameters `args` and `resultset` take a `tuple` of several types `Px` or `Rx
|
|
132
|
+
The parameters `args` and `resultset` take a `tuple` of several types `Px` or `Rx`.
|
|
133
|
+
|
|
134
|
+
The parameters `arg` and `result` take a single type `P` or `R`. Passing a simple type (e.g. `type[T]`) directly via `arg` and `result` is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`) via `args` and `resultset`.
|
|
135
|
+
|
|
136
|
+
The number of types in `args` must correspond to the number of query parameters. (This is validated on calling `sql(...)` for the *t-string* syntax.) The number of types in `resultset` must correspond to the number of columns returned by the query.
|
|
137
|
+
|
|
138
|
+
#### Argument and resultset types
|
|
139
|
+
|
|
140
|
+
When passing Python types via the parameters `args` and `resultset`, each type may be any of the following:
|
|
131
141
|
|
|
132
142
|
* (required) simple type
|
|
133
143
|
* optional simple type (`T | None`)
|
|
144
|
+
* special union type
|
|
134
145
|
|
|
135
146
|
Simple types include:
|
|
136
147
|
|
|
137
148
|
* `bool`
|
|
138
|
-
*
|
|
139
|
-
* `
|
|
140
|
-
* `
|
|
141
|
-
* `
|
|
142
|
-
*
|
|
143
|
-
* `datetime.
|
|
149
|
+
* numeric types:
|
|
150
|
+
* `int`
|
|
151
|
+
* `float`
|
|
152
|
+
* `decimal.Decimal`
|
|
153
|
+
* date and time types:
|
|
154
|
+
* `datetime.date`
|
|
155
|
+
* `datetime.time`
|
|
156
|
+
* `datetime.datetime`
|
|
157
|
+
* `datetime.timedelta`
|
|
144
158
|
* `str`
|
|
145
159
|
* `bytes`
|
|
146
160
|
* `uuid.UUID`
|
|
147
|
-
*
|
|
161
|
+
* types defined in the module [ipaddress](https://docs.python.org/3/library/ipaddress.html):
|
|
162
|
+
* `ipaddress.IPv4Address`
|
|
163
|
+
* `ipaddress.IPv6Address`
|
|
164
|
+
* `ipaddress.IPv4Network`
|
|
165
|
+
* `ipaddress.IPv6Network`
|
|
166
|
+
* [asyncpg representations](https://magicstack.github.io/asyncpg/current/api/index.html#module-asyncpg.types) of PostgreSQL geometric types:
|
|
167
|
+
* `asyncpg.Point`
|
|
168
|
+
* `asyncpg.Line`
|
|
169
|
+
* `asyncpg.LineSegment`
|
|
170
|
+
* `asyncpg.Box`
|
|
171
|
+
* `asyncpg.Path`
|
|
172
|
+
* `asyncpg.Polygon`
|
|
173
|
+
* `asyncpg.Circle`
|
|
174
|
+
* concrete types of [asyncpg.Range](https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.types.Range):
|
|
175
|
+
* `asyncpg.Range[int]`
|
|
176
|
+
* `asyncpg.Range[Decimal]`
|
|
177
|
+
* `asyncpg.Range[date]`
|
|
178
|
+
* `asyncpg.Range[datetime]`
|
|
179
|
+
* a user-defined enumeration class that derives from `StrEnum`
|
|
180
|
+
|
|
181
|
+
Custom Python types corresponding to PostgreSQL scalar or [composite types](https://www.postgresql.org/docs/current/rowtypes.html) are permitted. These types need to be pre-registered with [set_type_codec](https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.connection.Connection.set_type_codec) passing an encoder, a decoder and typically `format="tuple"`.
|
|
182
|
+
|
|
183
|
+
In general, union types are not allowed. However, there are notable exceptions. Special union types are as follows:
|
|
184
|
+
|
|
185
|
+
* `JsonType` to represent an object reconstructed from a JSON string
|
|
186
|
+
* `IPv4Address | IPv6Address` to denote either an IPv4 or IPv6 address
|
|
187
|
+
* `IPv4Network | IPv6Network` to denote either an IPv4 or IPv6 network definition
|
|
148
188
|
|
|
149
189
|
Types are grouped together with `tuple`:
|
|
150
190
|
|
|
@@ -152,39 +192,74 @@ Types are grouped together with `tuple`:
|
|
|
152
192
|
tuple[bool, int, str | None]
|
|
153
193
|
```
|
|
154
194
|
|
|
155
|
-
|
|
195
|
+
Both `args` and `resultset` types must be compatible with their corresponding PostgreSQL query parameter types and resultset column types, respectively. The following table shows the mapping between PostgreSQL and Python types. When there are multiple options separated by a slash, either of the types can be specified as a source or target type.
|
|
156
196
|
|
|
157
|
-
|
|
197
|
+
| PostgreSQL type | Python type |
|
|
198
|
+
| ---------------------------- | ---------------------------------- |
|
|
199
|
+
| `bool` | `bool` |
|
|
200
|
+
| `smallint` | `int` |
|
|
201
|
+
| `integer` | `int` |
|
|
202
|
+
| `bigint` | `int` |
|
|
203
|
+
| `real`/`float4` | `float` |
|
|
204
|
+
| `double precision`/`float8` | `float` |
|
|
205
|
+
| `decimal`/`numeric` | `Decimal` |
|
|
206
|
+
| `date` | `date` |
|
|
207
|
+
| `time` | `time` (naive) |
|
|
208
|
+
| `timetz` | `time` (tz) |
|
|
209
|
+
| `timestamp` | `datetime` (naive) |
|
|
210
|
+
| `timestamptz` | `datetime` (tz) |
|
|
211
|
+
| `interval` | `timedelta` |
|
|
212
|
+
| `char(N)` | `str` |
|
|
213
|
+
| `varchar(N)` | `str` |
|
|
214
|
+
| `text` | `str` |
|
|
215
|
+
| `bytea` | `bytes` |
|
|
216
|
+
| `uuid` | `UUID` |
|
|
217
|
+
| `cidr` | `IPvXNetwork` |
|
|
218
|
+
| `inet` | `IPvXNetwork`/`IPvXAddress` |
|
|
219
|
+
| `macaddr` | `str` |
|
|
220
|
+
| `macaddr8` | `str` |
|
|
221
|
+
| `json` | `str`/`JsonType` |
|
|
222
|
+
| `jsonb` | `str`/`JsonType` |
|
|
223
|
+
| `xml` | `str` |
|
|
224
|
+
| any enumeration type | `E: StrEnum` |
|
|
225
|
+
| `point` | `asyncpg.Point` |
|
|
226
|
+
| `line` | `asyncpg.Line` |
|
|
227
|
+
| `lseg` | `asyncpg.LineSegment` |
|
|
228
|
+
| `box` | `asyncpg.Box` |
|
|
229
|
+
| `path` | `asyncpg.Path` |
|
|
230
|
+
| `polygon` | `asyncpg.Polygon` |
|
|
231
|
+
| `circle` | `asyncpg.Circle` |
|
|
232
|
+
| `int4range` | `asyncpg.Range[int]` |
|
|
233
|
+
| `int8range` | `asyncpg.Range[int]` |
|
|
234
|
+
| `numrange` | `asyncpg.Range[Decimal]` |
|
|
235
|
+
| `tsrange` | `asyncpg.Range[datetime]` (naive) |
|
|
236
|
+
| `tstzrange` | `asyncpg.Range[datetime]` (tz) |
|
|
237
|
+
| `daterange` | `asyncpg.Range[date]` |
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
PostgreSQL types `json` and `jsonb` are [returned by asyncpg](https://magicstack.github.io/asyncpg/current/usage.html#type-conversion) as Python type `str`. However, if we specify the union type `JsonType` in `args` or `resultset`, the JSON string is parsed as if by calling `json.loads()`. If the library `orjson` is present, its faster routines are invoked instead of the slower standard library implementation in the module `json`.
|
|
241
|
+
|
|
242
|
+
`JsonType` is defined in the module `asyncpg_typed` as follows:
|
|
158
243
|
|
|
159
|
-
|
|
244
|
+
```python
|
|
245
|
+
JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
|
|
246
|
+
```
|
|
247
|
+
|
|
248
|
+
`IPvXNetwork` is a shorthand for any of the following:
|
|
249
|
+
|
|
250
|
+
* `IPv4Network`
|
|
251
|
+
* `IPv6Network`
|
|
252
|
+
* their union type `IPv4Network | IPv6Network`
|
|
160
253
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
| `numeric` | `Decimal` |
|
|
171
|
-
| `date` | `date` |
|
|
172
|
-
| `time` | `time` (naive) |
|
|
173
|
-
| `timetz` | `time` (tz) |
|
|
174
|
-
| `timestamp` | `datetime` (naive) |
|
|
175
|
-
| `timestamptz` | `datetime` (tz) |
|
|
176
|
-
| `interval` | `timedelta` |
|
|
177
|
-
| `char(N)` | `str` |
|
|
178
|
-
| `varchar(N)` | `str` |
|
|
179
|
-
| `text` | `str` |
|
|
180
|
-
| `bytea` | `bytes` |
|
|
181
|
-
| `json` | `str`/`JsonType` |
|
|
182
|
-
| `jsonb` | `str`/`JsonType` |
|
|
183
|
-
| `xml` | `str` |
|
|
184
|
-
| `uuid` | `UUID` |
|
|
185
|
-
| enumeration | `E: StrEnum` |
|
|
186
|
-
|
|
187
|
-
PostgreSQL types `json` and `jsonb` are [returned by asyncpg](https://magicstack.github.io/asyncpg/current/usage.html#type-conversion) as Python type `str`. However, if we specify the union type `JsonType` in `args` or `resultset`, the JSON string is parsed as if by calling `json.loads()`. (`JsonType` is defined in the module `asyncpg_typed`.) If the library `orjson` is present, its faster routines are invoked instead of the slower standard library implementation in the module `json`.
|
|
254
|
+
`IPvXAddress` stands for any of the following:
|
|
255
|
+
|
|
256
|
+
* `IPv4Address`
|
|
257
|
+
* `IPv6Address`
|
|
258
|
+
* their union type `IPv4Address | IPv6Address`
|
|
259
|
+
|
|
260
|
+
#### SQL statement as an f-string
|
|
261
|
+
|
|
262
|
+
In addition to the `sql` function, SQL objects can be created with the functionally identical `unsafe_sql` function. Unlike its safer alternative, `unsafe_sql` takes a first parameter of plain type `str`, which allows passing an f-string. This is useful when we want to inject the value of a Python variable at a location where PostgreSQL syntax does not permit binding parameters, e.g. to substitute the name of a database table when dynamically creating a SQL statement.
|
|
188
263
|
|
|
189
264
|
### Using a SQL object
|
|
190
265
|
|
|
@@ -207,7 +282,12 @@ async def fetchval(self, connection: Connection, *args: *P) -> R1: ...
|
|
|
207
282
|
|
|
208
283
|
Only those functions are prompted on code completion that make sense in the context of the given number of input and output arguments. Specifically, `fetchval` is available only for a single type passed to `resultset`, and `executemany` and `fetchmany` are available only if the query takes (one or more) parameters.
|
|
209
284
|
|
|
285
|
+
#### Run-time behavior
|
|
286
|
+
|
|
287
|
+
When a call such as `sql.executemany(conn, records)` or `sql.fetch(conn, param1, param2)` is made on a `SQL` object at run time, the library invokes `connection.prepare(sql)` to create a `PreparedStatement` and compares the actual statement signature against the expected Python types. If the expected and actual signatures don't match, an exception `TypeMismatchError` (subclass of `TypeError`) is raised.
|
|
288
|
+
|
|
289
|
+
The set of values for an enumeration type is validated when a prepared statement is created. The string values declared in a Python `StrEnum` are compared against the values listed in PostgreSQL `CREATE TYPE ... AS ENUM` by querying the system table `pg_enum`. If there are missing or extra values on either side, an exception `EnumMismatchError` (subclass of `TypeError`) is raised.
|
|
210
290
|
|
|
211
|
-
|
|
291
|
+
Unfortunately, PostgreSQL doesn't propagate nullability via prepared statements: resultset types that are declared as required (e.g. `T` as opposed to `T | None`) are validated at run time. When a `None` value is encountered for a required type, an exception `NoneTypeError` (subclass of `TypeError`) is raised.
|
|
212
292
|
|
|
213
|
-
|
|
293
|
+
The PostgreSQL types `cidr` and `inet` don't differentiate between IPv4 and IPv6 network definitions or addresses, so semantically they yield a union type. If you specify a more restrictive type (e.g. `IPv4Address` instead of `IPv4Address | IPv6Address`), the resultset data is validated dynamically at run time.
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
asyncpg_typed/__init__.py,sha256=pDwWTWeNqXtw0Z0YrHRu_kneHu20X2SggFWK6aczbY8,38766
|
|
2
|
+
asyncpg_typed/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
asyncpg_typed-0.1.3.dist-info/licenses/LICENSE,sha256=rx4jD36wX8TyLZaR2HEOJ6TphFPjKUqoCSSYWzwWNRk,1093
|
|
4
|
+
asyncpg_typed-0.1.3.dist-info/METADATA,sha256=LTGsagnYy0YHn33DUpIEfkRh63mNyH1rdRxCnpyTNZk,15353
|
|
5
|
+
asyncpg_typed-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
+
asyncpg_typed-0.1.3.dist-info/top_level.txt,sha256=T0X1nWnXRTi5a5oTErGy572ORDbM9UV9wfhRXWLsaoY,14
|
|
7
|
+
asyncpg_typed-0.1.3.dist-info/zip-safe,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
|
|
8
|
+
asyncpg_typed-0.1.3.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
asyncpg_typed/__init__.py,sha256=Z9UqmIr2QcSpGe7qC-ddMDDkwnJSGg5mm1dqiWPKYQM,24915
|
|
2
|
-
asyncpg_typed/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
asyncpg_typed-0.1.2.dist-info/licenses/LICENSE,sha256=rx4jD36wX8TyLZaR2HEOJ6TphFPjKUqoCSSYWzwWNRk,1093
|
|
4
|
-
asyncpg_typed-0.1.2.dist-info/METADATA,sha256=9wNzfDUQWAOhedM3g3cx_TYYlaaDjlqTrNq1qEqcK0k,9932
|
|
5
|
-
asyncpg_typed-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
-
asyncpg_typed-0.1.2.dist-info/top_level.txt,sha256=T0X1nWnXRTi5a5oTErGy572ORDbM9UV9wfhRXWLsaoY,14
|
|
7
|
-
asyncpg_typed-0.1.2.dist-info/zip-safe,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
|
|
8
|
-
asyncpg_typed-0.1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|