asyncpg-typed 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asyncpg_typed/__init__.py CHANGED
@@ -1,964 +1,1261 @@
1
- """
2
- Type-safe queries for asyncpg.
3
-
4
- :see: https://github.com/hunyadi/asyncpg_typed
5
- """
6
-
7
- __version__ = "0.1.3"
8
- __author__ = "Levente Hunyadi"
9
- __copyright__ = "Copyright 2025, Levente Hunyadi"
10
- __license__ = "MIT"
11
- __maintainer__ = "Levente Hunyadi"
12
- __status__ = "Production"
13
-
14
- import enum
15
- import sys
16
- import typing
17
- from abc import abstractmethod
18
- from collections.abc import Callable, Iterable, Sequence
19
- from dataclasses import dataclass
20
- from datetime import date, datetime, time, timedelta
21
- from decimal import Decimal
22
- from functools import reduce
23
- from io import StringIO
24
- from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
25
- from types import UnionType
26
- from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
27
- from uuid import UUID
28
-
29
- import asyncpg
30
- from asyncpg.prepared_stmt import PreparedStatement
31
-
32
- if sys.version_info < (3, 11):
33
- from typing_extensions import LiteralString, TypeVarTuple, Unpack
34
- else:
35
- from typing import LiteralString, TypeVarTuple, Unpack
36
-
37
# Any value that a JSON document may decode to, including `null` (represented as `None`).
JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]

# A JSON value other than `null`; listed in `_NAME_TO_TYPE` as an acceptable Python type for `json` and `jsonb`.
RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]

# A Python type that a query parameter or resultset column may be declared with: a class or an `A | B` union.
TargetType: TypeAlias = type[Any] | UnionType

# A connection that statements can be run on: either a plain connection or asyncpg's pool connection proxy.
Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
44
-
45
-
46
class TypeMismatchError(TypeError):
    """
    Signals that a PostgreSQL type taken or returned by a prepared statement cannot be represented by the Python type declared for it.
    """
48
-
49
-
50
class EnumMismatchError(TypeError):
    """
    Signals that the set of values permitted by a PostgreSQL enumeration type differs from the values of the declared Python enumeration.
    """
52
-
53
-
54
class NoneTypeError(TypeError):
    """
    Signals that a `NULL` value was found in a column whose declared Python type is not optional.
    """
56
-
57
-
58
- if sys.version_info >= (3, 11):
59
-
60
- def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
61
- """
62
- `True` if the specified type is an enumeration type.
63
- """
64
-
65
- return isinstance(typ, enum.EnumType)
66
-
67
- else:
68
-
69
- def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
70
- """
71
- `True` if the specified type is an enumeration type.
72
- """
73
-
74
- # use an explicit isinstance(..., type) check to filter out special forms like generics
75
- return isinstance(typ, type) and issubclass(typ, enum.Enum)
76
-
77
-
78
- def is_union_type(tp: Any) -> bool:
79
- """
80
- `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
81
- """
82
-
83
- origin = get_origin(tp)
84
- return origin is Union or origin is UnionType
85
-
86
-
87
- def is_optional_type(tp: Any) -> bool:
88
- """
89
- `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
90
- """
91
-
92
- return is_union_type(tp) and any(a is type(None) for a in get_args(tp))
93
-
94
-
95
- def is_standard_type(tp: Any) -> bool:
96
- """
97
- `True` if the type represents a built-in or a well-known standard type.
98
- """
99
-
100
- return tp.__module__ == "builtins" or tp.__module__ == UnionType.__module__
101
-
102
-
103
- def is_json_type(tp: Any) -> bool:
104
- """
105
- `True` if the type represents an object de-serialized from a JSON string.
106
- """
107
-
108
- return tp in [JsonType, RequiredJsonType]
109
-
110
-
111
- def is_inet_type(tp: Any) -> bool:
112
- """
113
- `True` if the type represents an IP address or network.
114
- """
115
-
116
- return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
117
-
118
-
119
- def make_union_type(tpl: list[Any]) -> UnionType:
120
- """
121
- Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
122
- """
123
-
124
- if len(tpl) < 2:
125
- raise ValueError("expected: at least two types to make a `UnionType`")
126
-
127
- return reduce(lambda a, b: a | b, tpl)
128
-
129
-
130
- def get_required_type(tp: Any) -> Any:
131
- """
132
- Removes `None` from an optional type (i.e. a union type that has `None` as a member).
133
- """
134
-
135
- if not is_optional_type(tp):
136
- return tp
137
-
138
- tpl = [a for a in get_args(tp) if a is not type(None)]
139
- if len(tpl) > 1:
140
- return make_union_type(tpl)
141
- elif len(tpl) > 0:
142
- return tpl[0]
143
- else:
144
- return type(None)
145
-
146
-
147
- def _standard_json_decoder() -> Callable[[str], JsonType]:
148
- import json
149
-
150
- _json_decoder = json.JSONDecoder()
151
- return _json_decoder.decode
152
-
153
-
154
- def _json_decoder() -> Callable[[str], JsonType]:
155
- if typing.TYPE_CHECKING:
156
- return _standard_json_decoder()
157
- else:
158
- try:
159
- import orjson
160
-
161
- return orjson.loads
162
- except ModuleNotFoundError:
163
- return _standard_json_decoder()
164
-
165
-
166
- JSON_DECODER = _json_decoder()
167
-
168
-
169
- def _standard_json_encoder() -> Callable[[JsonType], str]:
170
- import json
171
-
172
- _json_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
173
- return _json_encoder.encode
174
-
175
-
176
- def _json_encoder() -> Callable[[JsonType], str]:
177
- if typing.TYPE_CHECKING:
178
- return _standard_json_encoder()
179
- else:
180
- try:
181
- import orjson
182
-
183
- def _wrap(value: JsonType) -> str:
184
- return orjson.dumps(value).decode()
185
-
186
- return _wrap
187
- except ModuleNotFoundError:
188
- return _standard_json_encoder()
189
-
190
-
191
- JSON_ENCODER = _json_encoder()
192
-
193
-
194
def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
    """
    Picks the callable that turns a wire value (as returned by asyncpg) into a value of the target type `tp`.

    :param tp: A target type supported by the library.
    :returns: The JSON decoder for JSON types; otherwise `tp` itself, used as a one-argument constructor.
    """

    if not is_json_type(tp):
        # any other target type that needs conversion must be constructible from the wire value with a single argument
        return tp
    # asyncpg hands back `json` and `jsonb` values as text, which still has to be parsed
    return JSON_DECODER
208
-
209
-
210
def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
    """
    Picks the callable that turns a value of the source type `tp` into the wire value asyncpg expects.

    :param tp: A source type supported by the library.
    :returns: The JSON encoder for JSON types; otherwise `tp` itself, used as a one-argument constructor.
    """

    if not is_json_type(tp):
        # any other source type that needs conversion must be constructible from the value with a single argument
        return tp
    # asyncpg expects `json` and `jsonb` parameters as text, so the value has to be serialized first
    return JSON_ENCODER
224
-
225
-
226
# Maps the names of well-known PostgreSQL types (those in the `pg_catalog` schema, e.g. `int4`, `bpchar`) to the
# Python types that a parameter or column of that PostgreSQL type may be declared with.
# `_TypeVerifier.check_data_type` raises `TypeMismatchError` for a PostgreSQL type missing from this table, or when the
# declared Python type is not listed for it.
_NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
    # boolean type
    "bool": (bool,),
    # numeric types
    "int2": (int,),
    "int4": (int,),
    "int8": (int,),
    "float4": (float,),
    "float8": (float,),
    "numeric": (Decimal,),
    # date and time types
    "date": (date,),
    "time": (time,),
    "timetz": (time,),
    "timestamp": (datetime,),
    "timestamptz": (datetime,),
    "interval": (timedelta,),
    # character sequence types (`bpchar` is the internal name of blank-padded `char(n)`)
    "bpchar": (str,),
    "varchar": (str,),
    "text": (str,),
    # binary sequence types
    "bytea": (bytes,),
    # unique identifier type
    "uuid": (UUID,),
    # address types; a single address family or the union of both families may be declared
    "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
    "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
    "macaddr": (str,),
    "macaddr8": (str,),
    # JSON type; `str` receives the raw JSON text, `RequiredJsonType` the decoded value
    "json": (str, RequiredJsonType),
    "jsonb": (str, RequiredJsonType),
    # XML type
    "xml": (str,),
    # geometric types
    "point": (asyncpg.Point,),
    "line": (asyncpg.Line,),
    "lseg": (asyncpg.LineSegment,),
    "box": (asyncpg.Box,),
    "path": (asyncpg.Path,),
    "polygon": (asyncpg.Polygon,),
    "circle": (asyncpg.Circle,),
    # range types; multiranges are declared as a list of ranges
    "int4range": (asyncpg.Range[int],),
    "int4multirange": (list[asyncpg.Range[int]],),
    "int8range": (asyncpg.Range[int],),
    "int8multirange": (list[asyncpg.Range[int]],),
    "numrange": (asyncpg.Range[Decimal],),
    "nummultirange": (list[asyncpg.Range[Decimal]],),
    "tsrange": (asyncpg.Range[datetime],),
    "tsmultirange": (list[asyncpg.Range[datetime]],),
    "tstzrange": (asyncpg.Range[datetime],),
    "tstzmultirange": (list[asyncpg.Range[datetime]],),
    "daterange": (asyncpg.Range[date],),
    "datemultirange": (list[asyncpg.Range[date]],),
}
284
-
285
-
286
- def type_to_str(tp: Any) -> str:
287
- "Emits a friendly name for a type."
288
-
289
- if isinstance(tp, type):
290
- return tp.__name__
291
- else:
292
- return str(tp)
293
-
294
-
295
class _TypeVerifier:
    """
    Verifies if the Python target type can represent the PostgreSQL source type.

    Well-known types (schema `pg_catalog`) are checked against `_NAME_TO_TYPE`. PostgreSQL enumeration types are
    compared value by value with the declared Python enumeration, which requires querying `pg_enum` over the connection.
    """

    _connection: Connection

    def __init__(self, connection: Connection) -> None:
        """
        :param connection: Connection used to look up the labels of PostgreSQL enumeration types.
        """

        self._connection = connection

    async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
        """
        Verifies if a Python enumeration type matches a PostgreSQL enumeration type.

        :param pg_name: Description of the parameter or column being checked, used in error messages.
        :param pg_type: PostgreSQL enumeration type reported for the parameter or column.
        :param data_type: Python enumeration type declared for the parameter or column.
        :raises TypeMismatchError: If a member of the Python enumeration has a non-`str` value.
        :raises EnumMismatchError: If the Python values and the PostgreSQL labels are not the same set.
        """

        # PostgreSQL enum labels are strings, so only a string-valued Python enumeration can match them
        for e in data_type:
            if not isinstance(e.value, str):
                raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")

        py_values = {e.value for e in data_type}

        rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
        db_values = {row[0] for row in rows}

        # differences are listed in sorted order: the iteration order of a set of strings varies between runs,
        # which would otherwise make the error messages non-deterministic
        db_extra = db_values - py_values
        if db_extra:
            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in sorted(db_extra))}")

        py_extra = py_values - db_values
        if py_extra:
            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in sorted(py_extra))}")

    async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
        """
        Verifies if the Python target type can represent the PostgreSQL source type.

        :param pg_name: Description of the parameter or column being checked, used in error messages.
        :param pg_type: PostgreSQL type reported for the parameter or column.
        :param data_type: Python type declared for the parameter or column, with `None` already stripped.
        :raises TypeMismatchError: If the declared Python type cannot represent the PostgreSQL type.
        :raises EnumMismatchError: If a Python enumeration and a PostgreSQL enumeration have different values.
        """

        if pg_type.schema == "pg_catalog":  # well-known PostgreSQL types
            if is_enum_type(data_type):
                # a Python enumeration may also be stored in a plain character column
                if pg_type.name not in ["bpchar", "varchar", "text"]:
                    raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
            else:
                expected_types = _NAME_TO_TYPE.get(pg_type.name)
                if expected_types is None:
                    raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
                elif data_type not in expected_types:
                    raise TypeMismatchError(
                        f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; "
                        f"got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`, which converts to one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
                    )
        elif pg_type.kind == "composite":  # PostgreSQL composite types
            # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
            pass
        else:  # custom PostgreSQL types
            if is_enum_type(data_type):
                await self._check_enum_type(pg_name, pg_type, data_type)
            elif is_standard_type(data_type):
                raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
            else:
                # user-defined types registered with `conn.set_type_codec()`
                pass
356
-
357
-
358
@dataclass(frozen=True)
class _SQLPlaceholder:
    """
    A numbered query parameter (`$1`, `$2`, ...) paired with the Python type declared for it.
    """

    # position of the parameter, counting from 1 as PostgreSQL placeholders do
    ordinal: int
    # Python type declared for the parameter
    data_type: TargetType

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}({self.ordinal}, {self.data_type!r})"
365
-
366
-
367
- class _SQLObject:
368
- """
369
- Associates input and output type information with a SQL statement.
370
- """
371
-
372
- _parameter_data_types: tuple[_SQLPlaceholder, ...]
373
- _resultset_data_types: tuple[TargetType, ...]
374
- _parameter_cast: int
375
- _parameter_converters: tuple[Callable[[Any], Any], ...]
376
- _required: int
377
- _resultset_cast: int
378
- _resultset_converters: tuple[Callable[[Any], Any], ...]
379
-
380
- @property
381
- def parameter_data_types(self) -> tuple[_SQLPlaceholder, ...]:
382
- return self._parameter_data_types
383
-
384
- @property
385
- def resultset_data_types(self) -> tuple[TargetType, ...]:
386
- return self._resultset_data_types
387
-
388
- def __init__(
389
- self,
390
- input_data_types: tuple[TargetType, ...],
391
- output_data_types: tuple[TargetType, ...],
392
- ) -> None:
393
- self._parameter_data_types = tuple(_SQLPlaceholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(input_data_types, start=1))
394
- self._resultset_data_types = tuple(get_required_type(data_type) for data_type in output_data_types)
395
-
396
- # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
397
- parameter_cast = 0
398
- for index, placeholder in enumerate(self._parameter_data_types):
399
- parameter_cast |= is_json_type(placeholder.data_type) << index
400
- self._parameter_cast = parameter_cast
401
-
402
- self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in self._parameter_data_types)
403
-
404
- # create a bit-field of required types (1: required; 0: optional)
405
- required = 0
406
- for index, data_type in enumerate(output_data_types):
407
- required |= (not is_optional_type(data_type)) << index
408
- self._required = required
409
-
410
- # create a bit-field of types that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
411
- resultset_cast = 0
412
- for index, data_type in enumerate(self._resultset_data_types):
413
- resultset_cast |= (is_enum_type(data_type) or is_json_type(data_type) or is_inet_type(data_type)) << index
414
- self._resultset_cast = resultset_cast
415
-
416
- self._resultset_converters = tuple(get_output_converter_for(data_type) for data_type in self._resultset_data_types)
417
-
418
- def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
419
- """
420
- Raises an error with the index of the first column value that is of a required type but has been assigned a value of `None`.
421
- """
422
-
423
- for col_index in range(len(row)):
424
- if (self._required >> col_index & 1) and row[col_index] is None:
425
- if row_index is not None:
426
- row_col_spec = f"row #{row_index} and column #{col_index}"
427
- else:
428
- row_col_spec = f"column #{col_index}"
429
- raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
430
-
431
- def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
432
- """
433
- Verifies if declared types match actual value types in a resultset.
434
- """
435
-
436
- if not rows:
437
- return
438
-
439
- required = self._required
440
- if not required:
441
- return
442
-
443
- match len(rows[0]):
444
- case 1:
445
- for r, row in enumerate(rows):
446
- if required & (row[0] is None):
447
- self._raise_required_is_none(row, r)
448
- case 2:
449
- for r, row in enumerate(rows):
450
- a, b = row
451
- if required & ((a is None) | (b is None) << 1):
452
- self._raise_required_is_none(row, r)
453
- case 3:
454
- for r, row in enumerate(rows):
455
- a, b, c = row
456
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2):
457
- self._raise_required_is_none(row, r)
458
- case 4:
459
- for r, row in enumerate(rows):
460
- a, b, c, d = row
461
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3):
462
- self._raise_required_is_none(row, r)
463
- case 5:
464
- for r, row in enumerate(rows):
465
- a, b, c, d, e = row
466
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4):
467
- self._raise_required_is_none(row, r)
468
- case 6:
469
- for r, row in enumerate(rows):
470
- a, b, c, d, e, f = row
471
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5):
472
- self._raise_required_is_none(row, r)
473
- case 7:
474
- for r, row in enumerate(rows):
475
- a, b, c, d, e, f, g = row
476
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6):
477
- self._raise_required_is_none(row, r)
478
- case 8:
479
- for r, row in enumerate(rows):
480
- a, b, c, d, e, f, g, h = row
481
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6 | (h is None) << 7):
482
- self._raise_required_is_none(row, r)
483
- case _:
484
- for r, row in enumerate(rows):
485
- self._raise_required_is_none(row, r)
486
-
487
- def check_row(self, row: tuple[Any, ...]) -> None:
488
- """
489
- Verifies if declared types match actual value types in a single row.
490
- """
491
-
492
- required = self._required
493
- if not required:
494
- return
495
-
496
- match len(row):
497
- case 1:
498
- if required & (row[0] is None):
499
- self._raise_required_is_none(row)
500
- case 2:
501
- a, b = row
502
- if required & ((a is None) | (b is None) << 1):
503
- self._raise_required_is_none(row)
504
- case 3:
505
- a, b, c = row
506
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2):
507
- self._raise_required_is_none(row)
508
- case 4:
509
- a, b, c, d = row
510
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3):
511
- self._raise_required_is_none(row)
512
- case 5:
513
- a, b, c, d, e = row
514
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4):
515
- self._raise_required_is_none(row)
516
- case 6:
517
- a, b, c, d, e, f = row
518
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5):
519
- self._raise_required_is_none(row)
520
- case 7:
521
- a, b, c, d, e, f, g = row
522
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6):
523
- self._raise_required_is_none(row)
524
- case 8:
525
- a, b, c, d, e, f, g, h = row
526
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6 | (h is None) << 7):
527
- self._raise_required_is_none(row)
528
- case _:
529
- self._raise_required_is_none(row)
530
-
531
- def check_value(self, value: Any) -> None:
532
- """
533
- Verifies if the declared type matches the actual value type.
534
- """
535
-
536
- if self._required and value is None:
537
- raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
538
-
539
- def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
540
- """
541
- Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.
542
- """
543
-
544
- cast = self._parameter_cast
545
- if cast:
546
- converters = self._parameter_converters
547
- yield from (tuple((converters[i](value) if (value := arg[i]) is not None and cast >> i & 1 else value) for i in range(len(arg))) for arg in arg_lists)
548
- else:
549
- yield from arg_lists
550
-
551
- def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
552
- """
553
- Converts Python query arguments to PostgreSQL parameters.
554
- """
555
-
556
- cast = self._parameter_cast
557
- if cast:
558
- converters = self._parameter_converters
559
- return tuple((converters[i](value) if (value := arg_list[i]) is not None and cast >> i & 1 else value) for i in range(len(arg_list)))
560
- else:
561
- return tuple(value for value in arg_list)
562
-
563
- def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
564
- """
565
- Converts columns in the PostgreSQL result-set to their corresponding Python target types.
566
-
567
- :param rows: List of rows returned by PostgreSQL.
568
- :returns: List of tuples with each tuple element having the configured Python target type.
569
- """
570
-
571
- cast = self._resultset_cast
572
- if cast:
573
- converters = self._resultset_converters
574
- return [tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row))) for row in rows]
575
- else:
576
- return [tuple(value for value in row) for row in rows]
577
-
578
- def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
579
- """
580
- Converts columns in the PostgreSQL result-set to their corresponding Python target types.
581
-
582
- :param row: A single row returned by PostgreSQL.
583
- :returns: A tuple with each tuple element having the configured Python target type.
584
- """
585
-
586
- cast = self._resultset_cast
587
- if cast:
588
- converters = self._resultset_converters
589
- return tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row)))
590
- else:
591
- return tuple(value for value in row)
592
-
593
- def convert_value(self, value: Any) -> Any:
594
- """
595
- Converts a single PostgreSQL value to its corresponding Python target type.
596
-
597
- :param value: A single value returned by PostgreSQL.
598
- :returns: A converted value having the configured Python target type.
599
- """
600
-
601
- return self._resultset_converters[0](value) if value is not None and self._resultset_cast else value
602
-
603
- @abstractmethod
604
- def query(self) -> str:
605
- """
606
- Returns a SQL query string with PostgreSQL ordinal placeholders.
607
- """
608
- ...
609
-
610
- def __repr__(self) -> str:
611
- return f"{self.__class__.__name__}({self.query()!r})"
612
-
613
- def __str__(self) -> str:
614
- return self.query()
615
-
616
-
617
if sys.version_info >= (3, 14):
    from string.templatelib import Interpolation, Template  # type: ignore[import-not-found]

    SQLExpression: TypeAlias = Template | LiteralString

    class _SQLTemplate(_SQLObject):
        """
        A SQL query specified with the Python t-string syntax.

        Every interpolation must evaluate to a plain integer: the 1-based ordinal of the query parameter it stands for.
        `query()` renders each interpolation as the corresponding `$N` placeholder.
        """

        _strings: tuple[str, ...]
        _placeholders: tuple[_SQLPlaceholder, ...]

        def __init__(
            self,
            template: Template,
            *,
            args: tuple[TargetType, ...],
            resultset: tuple[TargetType, ...],
        ) -> None:
            """
            :param template: The t-string holding the SQL text and parameter ordinals.
            :param args: Declared Python types of the query parameters, in ordinal order.
            :param resultset: Declared Python types of the resultset columns, in column order.
            :raises TypeError: If an interpolation applies a conversion or format spec, or is not an integer.
            :raises IndexError: If an interpolation refers to an ordinal with no declared parameter type.
            """

            super().__init__(args, resultset)

            for ip in template.interpolations:
                if ip.conversion is not None:
                    raise TypeError(f"interpolation `{ip.expression}` expected to apply no conversion")
                if ip.format_spec:
                    raise TypeError(f"interpolation `{ip.expression}` expected to apply no format spec")
                if not isinstance(ip.value, int):
                    raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")

            self._strings = template.strings

            parameter_count = len(self.parameter_data_types)

            def _to_placeholder(ip: Interpolation) -> _SQLPlaceholder:
                # maps an interpolation to the placeholder of the declared parameter it refers to
                ordinal = int(ip.value)
                if not (0 < ordinal <= parameter_count):
                    raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {parameter_count}")
                return self.parameter_data_types[ordinal - 1]

            # every interpolation is validated even when no parameters are declared, so that a stray interpolation
            # is reported here with a clear message instead of failing later inside `query()`
            self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)

        def query(self) -> str:
            """
            Renders the template, replacing each interpolation with its `$N` placeholder.
            """

            buf = StringIO()
            for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
                buf.write(s)
                buf.write(f"${p.ordinal}")
            buf.write(self._strings[-1])
            return buf.getvalue()

else:
    SQLExpression = LiteralString
671
-
672
-
673
class _SQLString(_SQLObject):
    """
    A SQL query given as an ordinary string (e.g. an f-string), passed to PostgreSQL exactly as written.
    """

    _sql: str

    def query(self) -> str:
        # the text is used verbatim; no placeholder rewriting takes place
        return self._sql

    def __init__(
        self,
        sql: str,
        *,
        args: tuple[TargetType, ...],
        resultset: tuple[TargetType, ...],
    ) -> None:
        """
        :param sql: SQL text.
        :param args: Declared Python types of the query parameters, in ordinal order.
        :param resultset: Declared Python types of the resultset columns, in column order.
        """

        super().__init__(args, resultset)
        self._sql = sql
692
-
693
-
694
class _SQL(Protocol):
    """
    Structural type returned by `SQLFactory.sql()`: a SQL statement that carries the declared Python types of its
    parameters and resultset columns.
    """
698
-
699
-
700
class _SQLImpl(_SQL):
    """
    Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).

    All methods except `execute` prepare the statement first. Preparation checks the declared Python types against the
    parameter and column types PostgreSQL reports. Results are converted to the declared target types and checked for
    `NULL` values in columns declared as required.
    """

    _sql: _SQLObject

    def __init__(self, sql: _SQLObject) -> None:
        self._sql = sql

    def __str__(self) -> str:
        return str(self._sql)

    def __repr__(self) -> str:
        return repr(self._sql)

    async def _prepare(self, connection: Connection) -> PreparedStatement:
        """
        Prepares the statement and verifies the declared Python types against the types PostgreSQL reports.

        :raises ValueError: If the number of declared parameter or column types differs from what the statement takes or returns.
        :raises TypeMismatchError: If a declared Python type cannot represent the corresponding PostgreSQL type.
        :raises EnumMismatchError: If a Python enumeration and a PostgreSQL enumeration have different values.
        """

        stmt = await connection.prepare(self._sql.query())
        verifier = _TypeVerifier(connection)

        params = stmt.get_parameters()
        placeholders = self._sql.parameter_data_types
        if len(params) != len(placeholders):
            raise ValueError(f"expected: {len(placeholders)} parameter(s) as declared in Python; got: {len(params)} parameter(s) in the PostgreSQL statement")
        for param, placeholder in zip(params, placeholders, strict=True):
            await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)

        attrs = stmt.get_attributes()
        data_types = self._sql.resultset_data_types
        if len(attrs) != len(data_types):
            raise ValueError(f"expected: {len(data_types)} column(s) as declared in Python; got: {len(attrs)} column(s) returned by the PostgreSQL statement")
        for attr, data_type in zip(attrs, data_types, strict=True):
            await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)

        return stmt

    async def execute(self, connection: Connection, *args: Any) -> None:
        """
        Runs the statement, discarding any result.

        NOTE(review): unlike the other methods, this does not go through `_prepare`, so the declared parameter types
        are not verified against the statement; confirm this is intentional.
        """

        await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))

    async def executemany(self, connection: Connection, args: Iterable[Sequence[Any]]) -> None:
        "Runs the statement once for each argument tuple in `args`, discarding any result."

        stmt = await self._prepare(connection)
        await stmt.executemany(self._sql.convert_arg_lists(args))

    def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
        "Converts fetched rows to the declared target types and checks required columns for `NULL`."

        resultset = self._sql.convert_rows(rows)
        self._sql.check_rows(resultset)
        return resultset

    async def fetch(self, connection: Connection, *args: Any) -> list[tuple[Any, ...]]:
        "Runs the statement and returns all rows as tuples of the declared target types."

        stmt = await self._prepare(connection)
        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
        return self._cast_fetch(rows)

    async def fetchmany(self, connection: Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
        "Runs the statement once for each argument tuple in `args` and returns the combined rows."

        stmt = await self._prepare(connection)
        rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
        return self._cast_fetch(rows)

    async def fetchrow(self, connection: Connection, *args: Any) -> tuple[Any, ...] | None:
        "Runs the statement and returns the first row, or `None` if the statement returns no rows."

        stmt = await self._prepare(connection)
        row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
        if row is None:
            return None
        resultset = self._sql.convert_row(row)
        self._sql.check_row(resultset)
        return resultset

    async def fetchval(self, connection: Connection, *args: Any) -> Any:
        "Runs the statement and returns the value asyncpg's `fetchval` yields, converted to the declared target type."

        stmt = await self._prepare(connection)
        value = await stmt.fetchval(*self._sql.convert_arg_list(args))
        result = self._sql.convert_value(value)
        self._sql.check_value(result)
        return result
764
-
765
-
766
# type variables for the parameters of a statement: `P1` is the first (or only) parameter, `PX` the remaining ones
P1 = TypeVar("P1")
PX = TypeVarTuple("PX")

# type variables for the resultset of a statement: `RT` is a whole row tuple, `R1` and `R2` are the first two
# columns, `RX` the remaining columns
RT = TypeVar("RT")
R1 = TypeVar("R1")
R2 = TypeVar("R2")
RX = TypeVarTuple("RX")
773
-
774
-
775
# Typed interfaces of the statement objects returned by `SQLFactory.sql()`, named by shape:
#   P0: the statement takes no parameters; PX: it takes one or more parameters (also adds `executemany`);
#   R1: it returns a single column (adds `fetchval`); RX: it returns two or more columns (rows typed as `RT`).
# NOTE(review): the region below is marked auto-generated; presumably it should be changed through its generator rather than by hand.
### START OF AUTO-GENERATED BLOCK FOR Protocol ###
class SQL_P0(Protocol):
    @abstractmethod
    async def execute(self, connection: Connection) -> None: ...


class SQL_R1_P0(SQL_P0, Protocol[R1]):
    @abstractmethod
    async def fetch(self, connection: Connection) -> list[tuple[R1]]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection) -> tuple[R1] | None: ...
    @abstractmethod
    async def fetchval(self, connection: Connection) -> R1: ...


class SQL_RX_P0(SQL_P0, Protocol[RT]):
    @abstractmethod
    async def fetch(self, connection: Connection) -> list[RT]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection) -> RT | None: ...


class SQL_PX(Protocol[Unpack[PX]]):
    @abstractmethod
    async def execute(self, connection: Connection, *args: Unpack[PX]) -> None: ...
    @abstractmethod
    async def executemany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> None: ...


class SQL_R1_PX(SQL_PX[Unpack[PX]], Protocol[R1, Unpack[PX]]):
    @abstractmethod
    async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[tuple[R1]]: ...
    @abstractmethod
    async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[tuple[R1]]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> tuple[R1] | None: ...
    @abstractmethod
    async def fetchval(self, connection: Connection, *args: Unpack[PX]) -> R1: ...


class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
    @abstractmethod
    async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[RT]: ...
    @abstractmethod
    async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[RT]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None: ...


### END OF AUTO-GENERATED BLOCK FOR Protocol ###
825
-
826
-
827
- class SQLFactory:
828
- """
829
- Creates type-safe SQL queries.
830
- """
831
-
832
- ### START OF AUTO-GENERATED BLOCK FOR sql ###
833
- @overload
834
- def sql(self, stmt: SQLExpression) -> SQL_P0: ...
835
- @overload
836
- def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
837
- @overload
838
- def sql(self, stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
839
- @overload
840
- def sql(self, stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
841
- @overload
842
- def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
843
- @overload
844
- def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
845
- @overload
846
- def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
847
- @overload
848
- def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
849
- @overload
850
- def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
851
- @overload
852
- def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
853
- @overload
854
- def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
855
- @overload
856
- def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
857
-
858
- ### END OF AUTO-GENERATED BLOCK FOR sql ###
859
-
860
- def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
861
- """
862
- Creates a SQL statement with associated type information.
863
-
864
- :param stmt: SQL statement as a literal string or template.
865
- :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
866
- :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
867
- :param arg: Type signature for a single input parameter (e.g. `int`).
868
- :param result: Type signature for a single result column (e.g. `UUID`).
869
- """
870
-
871
- input_data_types, output_data_types = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
872
-
873
- obj: _SQLObject
874
- if sys.version_info >= (3, 14):
875
- match stmt:
876
- case Template():
877
- obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
878
- case str():
879
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
880
- else:
881
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
882
-
883
- return _SQLImpl(obj)
884
-
885
- ### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
886
- @overload
887
- def unsafe_sql(self, stmt: str) -> SQL_P0: ...
888
- @overload
889
- def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
890
- @overload
891
- def unsafe_sql(self, stmt: str, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
892
- @overload
893
- def unsafe_sql(self, stmt: str, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
894
- @overload
895
- def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
896
- @overload
897
- def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
898
- @overload
899
- def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
900
- @overload
901
- def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
902
- @overload
903
- def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
904
- @overload
905
- def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
906
- @overload
907
- def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
908
- @overload
909
- def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
910
-
911
- ### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
912
-
913
- def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
914
- """
915
- Creates a SQL statement with associated type information from a string.
916
-
917
- This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
918
- a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.
919
-
920
- :param stmt: SQL statement as a string (or f-string).
921
- :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
922
- :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
923
- :param arg: Type signature for a single input parameter (e.g. `int`).
924
- :param result: Type signature for a single result column (e.g. `UUID`).
925
- """
926
-
927
- input_data_types, output_data_types = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
928
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
929
- return _SQLImpl(obj)
930
-
931
-
932
- def _sql_args_resultset(*, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
933
- "Parses an argument/resultset signature into input/output types."
934
-
935
- if args is not None and arg is not None:
936
- raise TypeError("expected: either `args` or `arg`; got: both")
937
- if resultset is not None and result is not None:
938
- raise TypeError("expected: either `resultset` or `result`; got: both")
939
-
940
- if args is not None:
941
- if get_origin(args) is not tuple:
942
- raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {type(args)}")
943
- input_data_types = get_args(args)
944
- elif arg is not None:
945
- input_data_types = (arg,)
946
- else:
947
- input_data_types = ()
948
-
949
- if resultset is not None:
950
- if get_origin(resultset) is not tuple:
951
- raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {type(resultset)}")
952
- output_data_types = get_args(resultset)
953
- elif result is not None:
954
- output_data_types = (result,)
955
- else:
956
- output_data_types = ()
957
-
958
- return input_data_types, output_data_types
959
-
960
-
961
- FACTORY: SQLFactory = SQLFactory()
962
-
963
- sql = FACTORY.sql
964
- unsafe_sql = FACTORY.unsafe_sql
1
+ """
2
+ Type-safe queries for asyncpg.
3
+
4
+ :see: https://github.com/hunyadi/asyncpg_typed
5
+ """
6
+
7
+ __version__ = "0.1.5"
8
+ __author__ = "Levente Hunyadi"
9
+ __copyright__ = "Copyright 2025-2026, Levente Hunyadi"
10
+ __license__ = "MIT"
11
+ __maintainer__ = "Levente Hunyadi"
12
+ __status__ = "Production"
13
+
14
+ import enum
15
+ import sys
16
+ import typing
17
+ from abc import abstractmethod
18
+ from collections.abc import Callable, Iterable, Sequence
19
+ from dataclasses import dataclass
20
+ from datetime import date, datetime, time, timedelta
21
+ from decimal import Decimal
22
+ from functools import reduce
23
+ from io import StringIO
24
+ from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
25
+ from types import UnionType
26
+ from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
27
+ from uuid import UUID
28
+
29
+ import asyncpg
30
+ from asyncpg.prepared_stmt import PreparedStatement
31
+
32
+ if sys.version_info < (3, 11):
33
+ from typing_extensions import LiteralString, TypeVarTuple, Unpack
34
+ else:
35
+ from typing import LiteralString, TypeVarTuple, Unpack
36
+
37
+ JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
38
+
39
+ RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
40
+
41
+ TargetType: TypeAlias = type[Any] | UnionType
42
+
43
+ Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
44
+
45
+
46
+ class CountMismatchError(TypeError):
47
+ "Raised when a prepared statement takes or returns a different number of parameters or columns than declared in Python."
48
+
49
+
50
+ class TypeMismatchError(TypeError):
51
+ "Raised when a prepared statement takes or returns a PostgreSQL type incompatible with the declared Python type."
52
+
53
+
54
+ class NameMismatchError(TypeError):
55
+ "Raised when the name of a result-set column differs from what is declared in Python."
56
+
57
+
58
+ class EnumMismatchError(TypeError):
59
+ "Raised when a prepared statement takes or returns a PostgreSQL enum type whose permitted set of values differs from what is declared in Python."
60
+
61
+
62
+ class NoneTypeError(TypeError):
63
+ "Raised when a column marked as required contains a `NULL` value."
64
+
65
+
66
+ if sys.version_info < (3, 11):
67
+
68
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
69
+ """
70
+ `True` if the specified type is an enumeration type.
71
+ """
72
+
73
+ # use an explicit isinstance(..., type) check to filter out special forms like generics
74
+ return isinstance(typ, type) and issubclass(typ, enum.Enum)
75
+
76
+
77
+ else:
78
+
79
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
80
+ """
81
+ `True` if the specified type is an enumeration type.
82
+ """
83
+
84
+ return isinstance(typ, enum.EnumType)
85
+
86
+
87
+ def is_union_type(tp: Any) -> bool:
88
+ """
89
+ `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
90
+ """
91
+
92
+ origin = get_origin(tp)
93
+ return origin is Union or origin is UnionType
94
+
95
+
96
+ def is_optional_type(tp: Any) -> bool:
97
+ """
98
+ `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
99
+ """
100
+
101
+ return is_union_type(tp) and any(a is type(None) for a in get_args(tp))
102
+
103
+
104
+ def is_standard_type(tp: Any) -> bool:
105
+ """
106
+ `True` if the type represents a built-in or a well-known standard type.
107
+ """
108
+
109
+ return tp.__module__ == "builtins" or tp.__module__ == UnionType.__module__
110
+
111
+
112
+ def is_json_type(tp: Any) -> bool:
113
+ """
114
+ `True` if the type represents an object de-serialized from a JSON string.
115
+ """
116
+
117
+ return tp in [JsonType, RequiredJsonType]
118
+
119
+
120
+ def is_inet_type(tp: Any) -> bool:
121
+ """
122
+ `True` if the type represents an IP address or network.
123
+ """
124
+
125
+ return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
126
+
127
+
128
+ def make_union_type(tpl: list[Any]) -> UnionType:
129
+ """
130
+ Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
131
+ """
132
+
133
+ if len(tpl) < 2:
134
+ raise ValueError("expected: at least two types to make a `UnionType`")
135
+
136
+ return reduce(lambda a, b: a | b, tpl)
137
+
138
+
139
+ def get_required_type(tp: Any) -> Any:
140
+ """
141
+ Removes `None` from an optional type (i.e. a union type that has `None` as a member).
142
+ """
143
+
144
+ if not is_optional_type(tp):
145
+ return tp
146
+
147
+ tpl = [a for a in get_args(tp) if a is not type(None)]
148
+ if len(tpl) > 1:
149
+ return make_union_type(tpl)
150
+ elif len(tpl) > 0:
151
+ return tpl[0]
152
+ else:
153
+ return type(None)
154
+
155
+
156
+ def _standard_json_decoder() -> Callable[[str], JsonType]:
157
+ import json
158
+
159
+ _json_decoder = json.JSONDecoder()
160
+ return _json_decoder.decode
161
+
162
+
163
+ def _json_decoder() -> Callable[[str], JsonType]:
164
+ if typing.TYPE_CHECKING:
165
+ return _standard_json_decoder()
166
+ else:
167
+ try:
168
+ import orjson
169
+
170
+ return orjson.loads
171
+ except ModuleNotFoundError:
172
+ return _standard_json_decoder()
173
+
174
+
175
+ JSON_DECODER = _json_decoder()
176
+
177
+
178
+ def _standard_json_encoder() -> Callable[[JsonType], str]:
179
+ import json
180
+
181
+ _json_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
182
+ return _json_encoder.encode
183
+
184
+
185
+ def _json_encoder() -> Callable[[JsonType], str]:
186
+ if typing.TYPE_CHECKING:
187
+ return _standard_json_encoder()
188
+ else:
189
+ try:
190
+ import orjson
191
+
192
+ def _wrap(value: JsonType) -> str:
193
+ return orjson.dumps(value).decode()
194
+
195
+ return _wrap
196
+ except ModuleNotFoundError:
197
+ return _standard_json_encoder()
198
+
199
+
200
+ JSON_ENCODER = _json_encoder()
201
+
202
+
203
+ def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
204
+ """
205
+ Returns a callable that takes a wire type and returns a target type.
206
+
207
+ A wire type is one of the types returned by asyncpg.
208
+ A target type is one of the types supported by the library.
209
+ """
210
+
211
+ if is_json_type(tp):
212
+ # asyncpg returns fields of type `json` and `jsonb` as `str`, which must be de-serialized
213
+ return JSON_DECODER
214
+ else:
215
+ # target data types that require conversion must have a single-argument `__init__` that takes an object of the source type
216
+ return tp
217
+
218
+
219
+ def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
220
+ """
221
+ Returns a callable that takes a source type and returns a wire type.
222
+
223
+ A source type is one of the types supported by the library.
224
+ A wire type is one of the types returned by asyncpg.
225
+ """
226
+
227
+ if is_json_type(tp):
228
+ # asyncpg expects fields of type `json` and `jsonb` as `str`, which must be serialized
229
+ return JSON_ENCODER
230
+ else:
231
+ # source data types that require conversion must have a single-argument `__init__` that takes an object of the source type
232
+ return tp
233
+
234
+
235
+ # maps PostgreSQL internal type names to compatible Python types
236
+ _NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
237
+ # boolean type
238
+ "bool": (bool,),
239
+ # numeric types
240
+ "int2": (int,),
241
+ "int4": (int,),
242
+ "int8": (int,),
243
+ "float4": (float,),
244
+ "float8": (float,),
245
+ "numeric": (Decimal,),
246
+ # date and time types
247
+ "date": (date,),
248
+ "time": (time,),
249
+ "timetz": (time,),
250
+ "timestamp": (datetime,),
251
+ "timestamptz": (datetime,),
252
+ "interval": (timedelta,),
253
+ # character sequence types
254
+ "bpchar": (str,),
255
+ "varchar": (str,),
256
+ "text": (str,),
257
+ # binary sequence types
258
+ "bytea": (bytes,),
259
+ # unique identifier type
260
+ "uuid": (UUID,),
261
+ # address types
262
+ "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
263
+ "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
264
+ "macaddr": (str,),
265
+ "macaddr8": (str,),
266
+ # JSON type
267
+ "json": (str, RequiredJsonType),
268
+ "jsonb": (str, RequiredJsonType),
269
+ # XML type
270
+ "xml": (str,),
271
+ # geometric types
272
+ "point": (asyncpg.Point,),
273
+ "line": (asyncpg.Line,),
274
+ "lseg": (asyncpg.LineSegment,),
275
+ "box": (asyncpg.Box,),
276
+ "path": (asyncpg.Path,),
277
+ "polygon": (asyncpg.Polygon,),
278
+ "circle": (asyncpg.Circle,),
279
+ # range types
280
+ "int4range": (asyncpg.Range[int],),
281
+ "int4multirange": (list[asyncpg.Range[int]],),
282
+ "int8range": (asyncpg.Range[int],),
283
+ "int8multirange": (list[asyncpg.Range[int]],),
284
+ "numrange": (asyncpg.Range[Decimal],),
285
+ "nummultirange": (list[asyncpg.Range[Decimal]],),
286
+ "tsrange": (asyncpg.Range[datetime],),
287
+ "tsmultirange": (list[asyncpg.Range[datetime]],),
288
+ "tstzrange": (asyncpg.Range[datetime],),
289
+ "tstzmultirange": (list[asyncpg.Range[datetime]],),
290
+ "daterange": (asyncpg.Range[date],),
291
+ "datemultirange": (list[asyncpg.Range[date]],),
292
+ }
293
+
294
+
295
+ def type_to_str(tp: Any) -> str:
296
+ "Emits a friendly name for a type."
297
+
298
+ if isinstance(tp, type):
299
+ return tp.__name__
300
+ else:
301
+ return str(tp)
302
+
303
+
304
+ class _TypeVerifier:
305
+ """
306
+ Verifies if the Python target type can represent the PostgreSQL source type.
307
+ """
308
+
309
+ _connection: Connection
310
+
311
+ def __init__(self, connection: Connection) -> None:
312
+ self._connection = connection
313
+
314
+ async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
315
+ """
316
+ Verifies if a Python enumeration type matches a PostgreSQL enumeration type.
317
+ """
318
+
319
+ for e in data_type:
320
+ if not isinstance(e.value, str):
321
+ raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")
322
+
323
+ py_values = set(e.value for e in data_type)
324
+
325
+ rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
326
+ db_values = set(row[0] for row in rows)
327
+
328
+ db_extra = db_values - py_values
329
+ if db_extra:
330
+ raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in db_extra)})")
331
+
332
+ py_extra = py_values - db_values
333
+ if py_extra:
334
+ raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in py_extra)})")
335
+
336
+ async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
337
+ """
338
+ Verifies if the Python target type can represent the PostgreSQL source type.
339
+ """
340
+
341
+ if pg_type.schema == "pg_catalog": # well-known PostgreSQL types
342
+ if is_enum_type(data_type):
343
+ if pg_type.name not in ["bpchar", "varchar", "text"]:
344
+ raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
345
+ elif pg_type.kind == "array" and get_origin(data_type) is list:
346
+ if not pg_type.name.endswith("[]"):
347
+ raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of array")
348
+
349
+ expected_types = _NAME_TO_TYPE.get(pg_type.name[:-2])
350
+ if expected_types is None:
351
+ raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
352
+ elif get_args(data_type)[0] not in expected_types:
353
+ if len(expected_types) == 1:
354
+ target = f"the Python type `{type_to_str(expected_types[0])}`"
355
+ else:
356
+ target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
357
+ raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` whose elements convert to {target}")
358
+ else:
359
+ expected_types = _NAME_TO_TYPE.get(pg_type.name)
360
+ if expected_types is None:
361
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
362
+ elif data_type not in expected_types:
363
+ if len(expected_types) == 1:
364
+ target = f"the Python type `{type_to_str(expected_types[0])}`"
365
+ else:
366
+ target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
367
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` which converts to {target}")
368
+ elif pg_type.kind == "composite": # PostgreSQL composite types
369
+ # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
370
+ pass
371
+ else: # custom PostgreSQL types
372
+ if is_enum_type(data_type):
373
+ await self._check_enum_type(pg_name, pg_type, data_type)
374
+ elif is_standard_type(data_type):
375
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
376
+ else:
377
+ # user-defined types registered with `conn.set_type_codec()`
378
+ pass
379
+
380
+
381
+ @dataclass(frozen=True)
382
+ class _Placeholder:
383
+ ordinal: int
384
+ data_type: TargetType
385
+
386
+ def __repr__(self) -> str:
387
+ return f"{self.__class__.__name__}({self.ordinal}, {self.data_type!r})"
388
+
389
+
390
+ class _ResultsetWrapper:
391
+ "Wraps result-set rows into a tuple or named tuple."
392
+
393
+ init: Callable[..., tuple[Any, ...]] | None
394
+ iterable: Callable[[Iterable[Any]], tuple[Any, ...]]
395
+
396
+ def __init__(self, init: Callable[..., tuple[Any, ...]] | None, iterable: Callable[[Iterable[Any]], tuple[Any, ...]]) -> None:
397
+ """
398
+ Initializes a result-set wrapper.
399
+
400
+ :param init: Initializer function that takes as many arguments as columns in the result-set.
401
+ :param iterable: Initializer function that takes an iterable over columns of a result-set row.
402
+ """
403
+
404
+ self.init = init
405
+ self.iterable = iterable
406
+
407
+
408
+ class _SQLObject:
409
+ """
410
+ Associates input and output type information with a SQL statement.
411
+ """
412
+
413
+ _parameter_data_types: tuple[_Placeholder, ...]
414
+ _resultset_data_types: tuple[TargetType, ...]
415
+ _resultset_column_names: tuple[str, ...] | None
416
+ _resultset_wrapper: _ResultsetWrapper
417
+ _parameter_cast: int
418
+ _parameter_converters: tuple[Callable[[Any], Any], ...]
419
+ _required: int
420
+ _resultset_cast: int
421
+ _resultset_converters: tuple[Callable[[Any], Any], ...]
422
+
423
+ @property
424
+ def parameter_data_types(self) -> tuple[_Placeholder, ...]:
425
+ "Expected inbound parameter data types."
426
+
427
+ return self._parameter_data_types
428
+
429
+ @property
430
+ def resultset_data_types(self) -> tuple[TargetType, ...]:
431
+ "Expected column data types in the result-set."
432
+
433
+ return self._resultset_data_types
434
+
435
+ @property
436
+ def resultset_column_names(self) -> tuple[str, ...] | None:
437
+ "Expected column names in the result-set."
438
+
439
+ return self._resultset_column_names
440
+
441
+ def __init__(
442
+ self,
443
+ *,
444
+ args: tuple[TargetType, ...],
445
+ resultset: tuple[TargetType, ...],
446
+ names: tuple[str, ...] | None,
447
+ wrapper: _ResultsetWrapper,
448
+ ) -> None:
449
+ self._parameter_data_types = tuple(_Placeholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(args, start=1))
450
+ self._resultset_data_types = tuple(get_required_type(data_type) for data_type in resultset)
451
+ self._resultset_column_names = names
452
+ self._resultset_wrapper = wrapper
453
+
454
+ # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
455
+ parameter_cast = 0
456
+ for index, placeholder in enumerate(self._parameter_data_types):
457
+ parameter_cast |= is_json_type(placeholder.data_type) << index
458
+ self._parameter_cast = parameter_cast
459
+
460
+ self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in self._parameter_data_types)
461
+
462
+ # create a bit-field of required types (1: required; 0: optional)
463
+ required = 0
464
+ for index, data_type in enumerate(resultset):
465
+ required |= (not is_optional_type(data_type)) << index
466
+ self._required = required
467
+
468
+ # create a bit-field of types that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
469
+ resultset_cast = 0
470
+ for index, data_type in enumerate(self._resultset_data_types):
471
+ resultset_cast |= (is_enum_type(data_type) or is_json_type(data_type) or is_inet_type(data_type)) << index
472
+ self._resultset_cast = resultset_cast
473
+
474
+ self._resultset_converters = tuple(get_output_converter_for(data_type) for data_type in self._resultset_data_types)
475
+
476
+ def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
477
+ """
478
+ Raises an error with the index of the first column value that is of a required type but has been assigned a value of `None`.
479
+ """
480
+
481
+ for col_index in range(len(row)):
482
+ if (self._required >> col_index & 1) and row[col_index] is None:
483
+ if row_index is not None:
484
+ row_col_spec = f"row #{row_index} and column #{col_index}"
485
+ else:
486
+ row_col_spec = f"column #{col_index}"
487
+ raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
488
+
489
+ def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
490
+ """
491
+ Verifies if declared types match actual value types in a resultset.
492
+ """
493
+
494
+ if not rows:
495
+ return
496
+
497
+ required = self._required
498
+ if not required:
499
+ return
500
+
501
+ match len(rows[0]):
502
+ case 1:
503
+ for r, row in enumerate(rows):
504
+ if required & (row[0] is None):
505
+ self._raise_required_is_none(row, r)
506
+ case 2:
507
+ for r, row in enumerate(rows):
508
+ a, b = row
509
+ if required & ((a is None) | (b is None) << 1):
510
+ self._raise_required_is_none(row, r)
511
+ case 3:
512
+ for r, row in enumerate(rows):
513
+ a, b, c = row
514
+ if required & ((a is None) | (b is None) << 1 | (c is None) << 2):
515
+ self._raise_required_is_none(row, r)
516
+ case 4:
517
+ for r, row in enumerate(rows):
518
+ a, b, c, d = row
519
+ if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3):
520
+ self._raise_required_is_none(row, r)
521
+ case 5:
522
+ for r, row in enumerate(rows):
523
+ a, b, c, d, e = row
524
+ if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4):
525
+ self._raise_required_is_none(row, r)
526
+ case 6:
527
+ for r, row in enumerate(rows):
528
+ a, b, c, d, e, f = row
529
+ if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5):
530
+ self._raise_required_is_none(row, r)
531
+ case 7:
532
+ for r, row in enumerate(rows):
533
+ a, b, c, d, e, f, g = row
534
+ if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6):
535
+ self._raise_required_is_none(row, r)
536
+ case 8:
537
+ for r, row in enumerate(rows):
538
+ a, b, c, d, e, f, g, h = row
539
+ if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6 | (h is None) << 7):
540
+ self._raise_required_is_none(row, r)
541
+ case _:
542
+ for r, row in enumerate(rows):
543
+ self._raise_required_is_none(row, r)
544
+
545
+ def check_row(self, row: tuple[Any, ...]) -> None:
546
+ """
547
+ Verifies if declared types match actual value types in a single row.
548
+ """
549
+
550
+ if self._required:
551
+ self._raise_required_is_none(row)
552
+
553
+ def check_column(self, column: list[Any]) -> None:
554
+ """
555
+ Verifies if the declared type matches the actual value type of a single-column resultset.
556
+ """
557
+
558
+ if self._required:
559
+ for i, value in enumerate(column):
560
+ if value is None:
561
+ raise NoneTypeError(f"expected: {self._resultset_data_types[0]} in row #{i}; got: NULL")
562
+
563
+ def check_value(self, value: Any) -> None:
564
+ """
565
+ Verifies if the declared type matches the actual value type.
566
+ """
567
+
568
+ if self._required and value is None:
569
+ raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
570
+
571
+ def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
572
+ """
573
+ Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.
574
+ """
575
+
576
+ cast = self._parameter_cast
577
+ if cast:
578
+ converters = self._parameter_converters
579
+ yield from (tuple((converters[i](value) if (value := arg[i]) is not None and cast >> i & 1 else value) for i in range(len(arg))) for arg in arg_lists)
580
+ else:
581
+ yield from arg_lists
582
+
583
+ def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
584
+ """
585
+ Converts Python query arguments to PostgreSQL parameters.
586
+ """
587
+
588
+ cast = self._parameter_cast
589
+ if cast:
590
+ converters = self._parameter_converters
591
+ return tuple((converters[i](value) if (value := arg_list[i]) is not None and cast >> i & 1 else value) for i in range(len(arg_list)))
592
+ else:
593
+ return tuple(value for value in arg_list)
594
+
595
+ def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
596
+ """
597
+ Converts columns in the PostgreSQL result-set to their corresponding Python target types.
598
+
599
+ :param rows: List of rows returned by PostgreSQL.
600
+ :returns: List of tuples with each tuple element having the configured Python target type.
601
+ """
602
+
603
+ if not rows:
604
+ return []
605
+
606
+ if self._resultset_cast:
607
+ return self._convert_rows_cast(rows)
608
+ else:
609
+ return self._convert_rows_pass(rows)
610
+
611
+ def _convert_rows_cast(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
612
+ init_wrapper = self._resultset_wrapper.init
613
+ iterable_wrapper = self._resultset_wrapper.iterable
614
+ cast = self._resultset_cast
615
+ cast_a = cast >> 0 & 1
616
+ cast_b = cast >> 1 & 1
617
+ cast_c = cast >> 2 & 1
618
+ cast_d = cast >> 3 & 1
619
+ cast_e = cast >> 4 & 1
620
+ cast_f = cast >> 5 & 1
621
+ cast_g = cast >> 6 & 1
622
+ columns = len(rows[0])
623
+ match columns:
624
+ case 1:
625
+ converter = self._resultset_converters[0]
626
+ if init_wrapper is not None:
627
+ return [init_wrapper(converter(value) if (value := row[0]) is not None else value) for row in rows]
628
+ else:
629
+ return [(converter(value) if (value := row[0]) is not None else value,) for row in rows]
630
+ ### START OF AUTO-GENERATED BLOCK FOR convert_cast ###
631
+ case 2:
632
+ conv_a, conv_b = self._resultset_converters
633
+ if init_wrapper is not None:
634
+ return [
635
+ init_wrapper(
636
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
637
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
638
+ )
639
+ for row in rows
640
+ ]
641
+ else:
642
+ return [
643
+ (
644
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
645
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
646
+ )
647
+ for row in rows
648
+ ]
649
+ case 3:
650
+ conv_a, conv_b, conv_c = self._resultset_converters
651
+ if init_wrapper is not None:
652
+ return [
653
+ init_wrapper(
654
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
655
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
656
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
657
+ )
658
+ for row in rows
659
+ ]
660
+ else:
661
+ return [
662
+ (
663
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
664
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
665
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
666
+ )
667
+ for row in rows
668
+ ]
669
+ case 4:
670
+ conv_a, conv_b, conv_c, conv_d = self._resultset_converters
671
+ if init_wrapper is not None:
672
+ return [
673
+ init_wrapper(
674
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
675
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
676
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
677
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
678
+ )
679
+ for row in rows
680
+ ]
681
+ else:
682
+ return [
683
+ (
684
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
685
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
686
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
687
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
688
+ )
689
+ for row in rows
690
+ ]
691
+ case 5:
692
+ conv_a, conv_b, conv_c, conv_d, conv_e = self._resultset_converters
693
+ if init_wrapper is not None:
694
+ return [
695
+ init_wrapper(
696
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
697
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
698
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
699
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
700
+ conv_e(value) if (value := row[4]) is not None and cast_e else value,
701
+ )
702
+ for row in rows
703
+ ]
704
+ else:
705
+ return [
706
+ (
707
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
708
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
709
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
710
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
711
+ conv_e(value) if (value := row[4]) is not None and cast_e else value,
712
+ )
713
+ for row in rows
714
+ ]
715
+ case 6:
716
+ conv_a, conv_b, conv_c, conv_d, conv_e, conv_f = self._resultset_converters
717
+ if init_wrapper is not None:
718
+ return [
719
+ init_wrapper(
720
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
721
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
722
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
723
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
724
+ conv_e(value) if (value := row[4]) is not None and cast_e else value,
725
+ conv_f(value) if (value := row[5]) is not None and cast_f else value,
726
+ )
727
+ for row in rows
728
+ ]
729
+ else:
730
+ return [
731
+ (
732
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
733
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
734
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
735
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
736
+ conv_e(value) if (value := row[4]) is not None and cast_e else value,
737
+ conv_f(value) if (value := row[5]) is not None and cast_f else value,
738
+ )
739
+ for row in rows
740
+ ]
741
+ case 7:
742
+ conv_a, conv_b, conv_c, conv_d, conv_e, conv_f, conv_g = self._resultset_converters
743
+ if init_wrapper is not None:
744
+ return [
745
+ init_wrapper(
746
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
747
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
748
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
749
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
750
+ conv_e(value) if (value := row[4]) is not None and cast_e else value,
751
+ conv_f(value) if (value := row[5]) is not None and cast_f else value,
752
+ conv_g(value) if (value := row[6]) is not None and cast_g else value,
753
+ )
754
+ for row in rows
755
+ ]
756
+ else:
757
+ return [
758
+ (
759
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
760
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
761
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
762
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
763
+ conv_e(value) if (value := row[4]) is not None and cast_e else value,
764
+ conv_f(value) if (value := row[5]) is not None and cast_f else value,
765
+ conv_g(value) if (value := row[6]) is not None and cast_g else value,
766
+ )
767
+ for row in rows
768
+ ]
769
+
770
+ ### END OF AUTO-GENERATED BLOCK FOR convert_cast ###
771
+ case _:
772
+ converters = self._resultset_converters
773
+ return [iterable_wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns)) for row in rows]
774
+
775
+ def _convert_rows_pass(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
776
+ init_wrapper = self._resultset_wrapper.init
777
+ iterable_wrapper = self._resultset_wrapper.iterable
778
+ match len(rows[0]):
779
+ case 1:
780
+ if init_wrapper is not None:
781
+ return [init_wrapper(row[0]) for row in rows]
782
+ else:
783
+ return [(row[0],) for row in rows]
784
+ ### START OF AUTO-GENERATED BLOCK FOR convert_pass ###
785
+ case 2:
786
+ if init_wrapper is not None:
787
+ return [init_wrapper(row[0], row[1]) for row in rows]
788
+ else:
789
+ return [(row[0], row[1]) for row in rows]
790
+ case 3:
791
+ if init_wrapper is not None:
792
+ return [init_wrapper(row[0], row[1], row[2]) for row in rows]
793
+ else:
794
+ return [(row[0], row[1], row[2]) for row in rows]
795
+ case 4:
796
+ if init_wrapper is not None:
797
+ return [init_wrapper(row[0], row[1], row[2], row[3]) for row in rows]
798
+ else:
799
+ return [(row[0], row[1], row[2], row[3]) for row in rows]
800
+ case 5:
801
+ if init_wrapper is not None:
802
+ return [init_wrapper(row[0], row[1], row[2], row[3], row[4]) for row in rows]
803
+ else:
804
+ return [(row[0], row[1], row[2], row[3], row[4]) for row in rows]
805
+ case 6:
806
+ if init_wrapper is not None:
807
+ return [init_wrapper(row[0], row[1], row[2], row[3], row[4], row[5]) for row in rows]
808
+ else:
809
+ return [(row[0], row[1], row[2], row[3], row[4], row[5]) for row in rows]
810
+ case 7:
811
+ if init_wrapper is not None:
812
+ return [init_wrapper(row[0], row[1], row[2], row[3], row[4], row[5], row[6]) for row in rows]
813
+ else:
814
+ return [(row[0], row[1], row[2], row[3], row[4], row[5], row[6]) for row in rows]
815
+
816
+ ### END OF AUTO-GENERATED BLOCK FOR convert_pass ###
817
+ case _:
818
+ return [iterable_wrapper(row.values()) for row in rows]
819
+
820
def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
    """
    Applies the configured Python target types to the columns of one result-set row.

    A column is passed through its converter only if its bit is set in the cast mask and its value is not ``None``;
    every other value is kept as-is. The row is assembled with the iterable wrapper (a plain tuple by default).

    :param row: A single row returned by PostgreSQL.
    :returns: A tuple with each tuple element having the configured Python target type.
    """

    build = self._resultset_wrapper.iterable
    mask = self._resultset_cast
    if not mask:
        # nothing to convert: hand the raw column values straight to the wrapper
        return build(row.values())

    converters = self._resultset_converters
    return build(
        (converters[index](cell) if (cell := row[index]) is not None and (mask >> index) & 1 else cell)
        for index in range(len(row))
    )
836
+
837
def convert_column(self, rows: list[asyncpg.Record]) -> list[Any]:
    """
    Converts a single column in the PostgreSQL result-set to its corresponding Python target type.

    Only the first column of each row is read. Its values are converted when bit 0 of the cast mask is set
    (``None`` values are never converted); otherwise they are returned unchanged.

    :param rows: List of rows returned by PostgreSQL.
    :returns: List of values having the configured Python target type.
    """

    # bit `i` of the cast mask flags column `i` for conversion; only column 0 is relevant here
    if not self._resultset_cast & 1:
        return [row[0] for row in rows]

    converter = self._resultset_converters[0]
    return [converter(value) if (value := row[0]) is not None else value for row in rows]
851
+
852
def convert_value(self, value: Any) -> Any:
    """
    Converts a single PostgreSQL value to its corresponding Python target type.

    The value is treated as belonging to the first result-set column. It is converted only if it is not ``None``
    and bit 0 of the cast mask is set; otherwise it is returned unchanged.

    :param value: A single value returned by PostgreSQL.
    :returns: A converted value having the configured Python target type.
    """

    # bit `i` of the cast mask flags column `i` for conversion; a single value corresponds to column 0
    if value is not None and self._resultset_cast & 1:
        return self._resultset_converters[0](value)
    return value
861
+
862
@abstractmethod
def query(self) -> str:
    """
    Produces the SQL text to send to PostgreSQL, with parameters written as ordinal placeholders (`$1`, `$2`, ...).

    Concrete subclasses must override this.
    """
868
+
869
def __repr__(self) -> str:
    # class name followed by the repr of the SQL text, e.g. `_SQLString('SELECT 1')`
    return "{}({!r})".format(self.__class__.__name__, self.query())
871
+
872
def __str__(self) -> str:
    # the string form of a statement is exactly its SQL text
    text = self.query()
    return text
874
+
875
+
876
if sys.version_info >= (3, 14):
    from string.templatelib import Interpolation, Template  # type: ignore[import-not-found]

    SQLExpression: TypeAlias = Template | LiteralString

    class _SQLTemplate(_SQLObject):
        """
        A SQL query specified with the Python t-string syntax.

        Each interpolation must evaluate to an integer `n` with `0 < n <= number of declared parameters`, and must use
        neither a conversion nor a format spec. It is rendered as `$k`, where `k` is the ordinal of the `n`-th declared
        parameter placeholder.
        """

        _strings: tuple[str, ...]
        _placeholders: tuple[_Placeholder, ...]

        def __init__(
            self,
            template: Template,
            *,
            args: tuple[TargetType, ...],
            resultset: tuple[TargetType, ...],
            names: tuple[str, ...] | None,
            wrapper: _ResultsetWrapper,
        ) -> None:
            """
            :raises TypeError: An interpolation applies a conversion or format spec, or is not an integer.
            :raises IndexError: An interpolation is outside the range of declared parameters (this includes any
                interpolation when no parameters are declared).
            """

            super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)

            for ip in template.interpolations:
                if ip.conversion is not None:
                    raise TypeError(f"interpolation `{ip.expression}` expected to apply no conversion")
                if ip.format_spec:
                    raise TypeError(f"interpolation `{ip.expression}` expected to apply no format spec")
                if not isinstance(ip.value, int):
                    raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")

            self._strings = template.strings

            parameter_count = len(self.parameter_data_types)

            def _to_placeholder(ip: Interpolation) -> _Placeholder:
                ordinal = int(ip.value)
                if not (0 < ordinal <= parameter_count):
                    raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {parameter_count}")
                return self.parameter_data_types[ordinal - 1]

            # Validate every interpolation eagerly. Without this, a template with interpolations but no declared
            # parameters would be accepted here and only fail later in `query()` with an opaque `zip` length error.
            self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)

        def query(self) -> str:
            buf = StringIO()
            # `strings` has exactly one more element than `interpolations`; the last string trails the final placeholder
            for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
                buf.write(s)
                buf.write(f"${p.ordinal}")
            buf.write(self._strings[-1])
            return buf.getvalue()

else:
    SQLExpression = LiteralString
932
+
933
+
934
class _SQLString(_SQLObject):
    """
    SQL query given as an ordinary string (for instance, an f-string), returned verbatim by `query`.
    """

    _sql: str

    def __init__(
        self,
        sql: str,
        *,
        args: tuple[TargetType, ...],
        resultset: tuple[TargetType, ...],
        names: tuple[str, ...] | None,
        wrapper: _ResultsetWrapper,
    ) -> None:
        # the base class handles all type information; this subclass only keeps the raw SQL text
        super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
        self._sql = sql

    def query(self) -> str:
        # no placeholder rewriting happens: the text is handed back exactly as supplied
        text: str = self._sql
        return text
955
+
956
+
957
class _SQL(Protocol):
    """
    Marker protocol for a SQL statement bundled with its type information.

    It declares no members of its own; the public `SQL_*` protocols named in the `sql` overloads describe the
    methods available on a statement.
    """
961
+
962
+
963
class _SQLImpl(_SQL):
    """
    Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).

    Every method except `execute` first prepares the statement with `_prepare`, which checks the declared Python
    types against the parameter and result-set types PostgreSQL reports. Arguments are converted before they are
    sent. Fetched rows, columns and values are converted and checked before they are returned.
    """

    _sql: _SQLObject

    def __init__(self, sql: _SQLObject) -> None:
        self._sql = sql

    def __str__(self) -> str:
        return str(self._sql)

    def __repr__(self) -> str:
        return repr(self._sql)

    async def _prepare(self, connection: Connection) -> PreparedStatement:
        """
        Prepares the query on `connection` and verifies it against the declared type signature.

        :param connection: Connection on which to prepare the statement.
        :returns: The prepared statement.
        :raises CountMismatchError: The number of parameters or result-set columns differs from the declaration.
        :raises NameMismatchError: A declared field name differs from the corresponding result-set column name.
        """

        stmt = await connection.prepare(self._sql.query())

        verifier = _TypeVerifier(connection)

        # read the statement metadata once and reuse it for every check below
        parameters = stmt.get_parameters()
        attributes = stmt.get_attributes()

        input_count = len(self._sql.parameter_data_types)
        parameter_count = len(parameters)
        if parameter_count != input_count:
            raise CountMismatchError(f"expected: PostgreSQL query to take {input_count} parameter(s); got: {parameter_count}")
        for param, placeholder in zip(parameters, self._sql.parameter_data_types, strict=True):
            await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)

        output_count = len(self._sql.resultset_data_types)
        column_count = len(attributes)
        if column_count != output_count:
            raise CountMismatchError(f"expected: PostgreSQL query to return {output_count} column(s) in result-set; got: {column_count}")
        if self._sql.resultset_column_names is not None:
            for index, (attr, name) in enumerate(zip(attributes, self._sql.resultset_column_names, strict=True)):
                if attr.name != name:
                    raise NameMismatchError(f"expected: Python field name `{name}` to match PostgreSQL result-set column name `{attr.name}` for index #{index}")
        for attr, data_type in zip(attributes, self._sql.resultset_data_types, strict=True):
            await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)

        return stmt

    async def execute(self, connection: Connection, *args: Any) -> None:
        # not routed through `_prepare`: the query goes to `connection.execute` directly, so no type verification runs
        await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))

    async def executemany(self, connection: Connection, args: Iterable[Sequence[Any]]) -> None:
        stmt = await self._prepare(connection)
        await stmt.executemany(self._sql.convert_arg_lists(args))

    def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
        """Converts fetched rows to their Python target types and checks the result."""

        resultset = self._sql.convert_rows(rows)
        self._sql.check_rows(resultset)
        return resultset

    async def fetch(self, connection: Connection, *args: Any) -> list[tuple[Any, ...]]:
        stmt = await self._prepare(connection)
        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
        return self._cast_fetch(rows)

    async def fetchmany(self, connection: Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
        stmt = await self._prepare(connection)
        rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
        return self._cast_fetch(rows)

    async def fetchrow(self, connection: Connection, *args: Any) -> tuple[Any, ...] | None:
        stmt = await self._prepare(connection)
        row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
        if row is None:
            return None
        resultset = self._sql.convert_row(row)
        self._sql.check_row(resultset)
        return resultset

    async def fetchcol(self, connection: Connection, *args: Any) -> list[Any]:
        stmt = await self._prepare(connection)
        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
        column = self._sql.convert_column(rows)
        self._sql.check_column(column)
        return column

    async def fetchval(self, connection: Connection, *args: Any) -> Any:
        stmt = await self._prepare(connection)
        value = await stmt.fetchval(*self._sql.convert_arg_list(args))
        result = self._sql.convert_value(value)
        self._sql.check_value(result)
        return result
1048
+
1049
+
1050
+ P1 = TypeVar("P1")
1051
+ PX = TypeVarTuple("PX")
1052
+
1053
+ RT = TypeVar("RT")
1054
+ R1 = TypeVar("R1")
1055
+ R2 = TypeVar("R2")
1056
+ RX = TypeVarTuple("RX")
1057
+
1058
+
1059
### START OF AUTO-GENERATED BLOCK FOR Protocol ###
class SQL_P0(Protocol):
    """
    Statement without input parameters and without a declared result-set.
    """

    @abstractmethod
    async def execute(self, connection: Connection) -> None:
        """Executes the statement."""


class SQL_R1_P0(SQL_P0, Protocol[R1]):
    """
    Statement without input parameters that returns a single column of type `R1`.
    """

    @abstractmethod
    async def fetch(self, connection: Connection) -> list[tuple[R1]]:
        """Returns all rows."""

    @abstractmethod
    async def fetchrow(self, connection: Connection) -> tuple[R1] | None:
        """Returns the first row, or `None` if there is none."""

    @abstractmethod
    async def fetchcol(self, connection: Connection) -> list[R1]:
        """Returns the single column of every row."""

    @abstractmethod
    async def fetchval(self, connection: Connection) -> R1:
        """Returns a single value."""


class SQL_RX_P0(SQL_P0, Protocol[RT]):
    """
    Statement without input parameters that returns rows of type `RT`.
    """

    @abstractmethod
    async def fetch(self, connection: Connection) -> list[RT]:
        """Returns all rows."""

    @abstractmethod
    async def fetchrow(self, connection: Connection) -> RT | None:
        """Returns the first row, or `None` if there is none."""


class SQL_PX(Protocol[Unpack[PX]]):
    """
    Statement taking input parameters of types `PX`, without a declared result-set.
    """

    @abstractmethod
    async def execute(self, connection: Connection, *args: Unpack[PX]) -> None:
        """Executes the statement once with the given arguments."""

    @abstractmethod
    async def executemany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> None:
        """Executes the statement once for each argument tuple."""


class SQL_R1_PX(SQL_PX[Unpack[PX]], Protocol[R1, Unpack[PX]]):
    """
    Statement taking input parameters of types `PX` that returns a single column of type `R1`.
    """

    @abstractmethod
    async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[tuple[R1]]:
        """Returns all rows."""

    @abstractmethod
    async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[tuple[R1]]:
        """Returns the rows produced by running the statement for each argument tuple."""

    @abstractmethod
    async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> tuple[R1] | None:
        """Returns the first row, or `None` if there is none."""

    @abstractmethod
    async def fetchcol(self, connection: Connection, *args: Unpack[PX]) -> list[R1]:
        """Returns the single column of every row."""

    @abstractmethod
    async def fetchval(self, connection: Connection, *args: Unpack[PX]) -> R1:
        """Returns a single value."""


class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
    """
    Statement taking input parameters of types `PX` that returns rows of type `RT`.
    """

    @abstractmethod
    async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[RT]:
        """Returns all rows."""

    @abstractmethod
    async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[RT]:
        """Returns the rows produced by running the statement for each argument tuple."""

    @abstractmethod
    async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None:
        """Returns the first row, or `None` if there is none."""


### END OF AUTO-GENERATED BLOCK FOR Protocol ###
1113
+
1114
# Row type of a multi-column result-set: any tuple type (named tuples included, as they subclass `tuple`).
RS = TypeVar(
    "RS",
    bound=tuple[Any, ...],
)
1115
+
1116
+
1117
+ class SQLFactory:
1118
+ """
1119
+ Creates type-safe SQL queries.
1120
+ """
1121
+
1122
+ ### START OF AUTO-GENERATED BLOCK FOR sql ###
1123
+ @overload
1124
+ def sql(self, stmt: SQLExpression) -> SQL_P0: ...
1125
+ @overload
1126
+ def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
1127
+ @overload
1128
+ def sql(self, stmt: SQLExpression, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
1129
+ @overload
1130
+ def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
1131
+ @overload
1132
+ def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
1133
+ @overload
1134
+ def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
1135
+ @overload
1136
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
1137
+ @overload
1138
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
1139
+ @overload
1140
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
1141
+
1142
+ ### END OF AUTO-GENERATED BLOCK FOR sql ###
1143
+
1144
+ def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
1145
+ """
1146
+ Creates a SQL statement with associated type information.
1147
+
1148
+ :param stmt: SQL statement as a literal string or template.
1149
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
1150
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
1151
+ :param arg: Type signature for a single input parameter (e.g. `int`).
1152
+ :param result: Type signature for a single result column (e.g. `UUID`).
1153
+ """
1154
+
1155
+ input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
1156
+
1157
+ obj: _SQLObject
1158
+ if sys.version_info >= (3, 14):
1159
+ match stmt:
1160
+ case Template():
1161
+ obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1162
+ case str():
1163
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1164
+ else:
1165
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1166
+
1167
+ return _SQLImpl(obj)
1168
+
1169
+ ### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
1170
+ @overload
1171
+ def unsafe_sql(self, stmt: str) -> SQL_P0: ...
1172
+ @overload
1173
+ def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
1174
+ @overload
1175
+ def unsafe_sql(self, stmt: str, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
1176
+ @overload
1177
+ def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
1178
+ @overload
1179
+ def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
1180
+ @overload
1181
+ def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
1182
+ @overload
1183
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
1184
+ @overload
1185
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
1186
+ @overload
1187
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
1188
+
1189
+ ### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
1190
+
1191
+ def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
1192
+ """
1193
+ Creates a SQL statement with associated type information from a string.
1194
+
1195
+ This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
1196
+ a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.
1197
+
1198
+ :param stmt: SQL statement as a string (or f-string).
1199
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
1200
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
1201
+ :param arg: Type signature for a single input parameter (e.g. `int`).
1202
+ :param result: Type signature for a single result column (e.g. `UUID`).
1203
+ """
1204
+
1205
+ input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
1206
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1207
+ return _SQLImpl(obj)
1208
+
1209
+
1210
def _sql_args_resultset(
    *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None
) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[str, ...] | None, _ResultsetWrapper]:
    """
    Parses an argument/resultset signature into input/output types.

    :param args: A `tuple[...]` type or a named tuple class describing multiple input parameters.
    :param resultset: A `tuple[...]` type or a named tuple class describing multiple result-set columns.
    :param arg: Type of a single input parameter.
    :param result: Type of a single result-set column.
    :returns: A tuple of (input types, output types, result-set column names or `None`, result-set wrapper).
    :raises TypeError: Both `args` and `arg` (or both `resultset` and `result`) are given, or a multi-type signature
        is not a fixed-length tuple type.
    """

    if args is not None and arg is not None:
        raise TypeError("expected: either `args` or `arg`; got: both")
    if resultset is not None and result is not None:
        raise TypeError("expected: either `resultset` or `result`; got: both")

    if args is not None:
        if hasattr(args, "_asdict") and hasattr(args, "_fields"):
            # named tuple: field types in declaration order
            input_data_types = tuple(args.__annotations__.values())
        else:
            # regular tuple
            if get_origin(args) is not tuple:
                raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {args}")
            input_data_types = get_args(args)
            if any(tp is Ellipsis for tp in input_data_types):
                # `tuple[T, ...]` has no fixed number of parameters to match against the query
                raise TypeError(f"expected: fixed-length tuple type for `args`; got: {args}")
    elif arg is not None:
        input_data_types = (arg,)
    else:
        input_data_types = ()

    if resultset is not None:
        if hasattr(resultset, "_asdict") and hasattr(resultset, "_fields") and hasattr(resultset, "_make"):
            # named tuple: field types and names in declaration order; rows are built via the class itself or `_make`
            output_data_types = tuple(resultset.__annotations__.values())
            names = tuple(resultset._fields)
            wrapper = _ResultsetWrapper(resultset, resultset._make)
        else:
            # regular tuple
            if get_origin(resultset) is not tuple:
                raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {resultset}")
            output_data_types = get_args(resultset)
            if any(tp is Ellipsis for tp in output_data_types):
                # `tuple[T, ...]` has no fixed number of columns to match against the result-set
                raise TypeError(f"expected: fixed-length tuple type for `resultset`; got: {resultset}")
            names = None
            wrapper = _ResultsetWrapper(None, tuple)
    else:
        if result is not None:
            output_data_types = (result,)
        else:
            output_data_types = ()
        names = None
        wrapper = _ResultsetWrapper(None, tuple)

    return input_data_types, output_data_types, names, wrapper
1256
+
1257
+
1258
# Shared default factory; its bound methods are re-exported as the module-level `sql` and `unsafe_sql` functions.
FACTORY: SQLFactory = SQLFactory()

sql, unsafe_sql = FACTORY.sql, FACTORY.unsafe_sql