asyncpg-typed 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asyncpg_typed/__init__.py CHANGED
@@ -1,1169 +1,1261 @@
1
- """
2
- Type-safe queries for asyncpg.
3
-
4
- :see: https://github.com/hunyadi/asyncpg_typed
5
- """
6
-
7
- __version__ = "0.1.4"
8
- __author__ = "Levente Hunyadi"
9
- __copyright__ = "Copyright 2025-2026, Levente Hunyadi"
10
- __license__ = "MIT"
11
- __maintainer__ = "Levente Hunyadi"
12
- __status__ = "Production"
13
-
14
- import enum
15
- import sys
16
- import typing
17
- from abc import abstractmethod
18
- from collections.abc import Callable, Iterable, Sequence
19
- from dataclasses import dataclass
20
- from datetime import date, datetime, time, timedelta
21
- from decimal import Decimal
22
- from functools import reduce
23
- from io import StringIO
24
- from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
25
- from types import UnionType
26
- from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
27
- from uuid import UUID
28
-
29
- import asyncpg
30
- from asyncpg.prepared_stmt import PreparedStatement
31
-
32
- if sys.version_info < (3, 11):
33
- from typing_extensions import LiteralString, TypeVarTuple, Unpack
34
- else:
35
- from typing import LiteralString, TypeVarTuple, Unpack
36
-
37
- JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
38
-
39
- RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
40
-
41
- TargetType: TypeAlias = type[Any] | UnionType
42
-
43
- Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
44
-
45
-
46
- class CountMismatchError(TypeError):
47
- "Raised when a prepared statement takes or returns a different number of parameters or columns than declared in Python."
48
-
49
-
50
- class TypeMismatchError(TypeError):
51
- "Raised when a prepared statement takes or returns a PostgreSQL type incompatible with the declared Python type."
52
-
53
-
54
- class NameMismatchError(TypeError):
55
- "Raised when the name of a result-set column differs from what is declared in Python."
56
-
57
-
58
- class EnumMismatchError(TypeError):
59
- "Raised when a prepared statement takes or returns a PostgreSQL enum type whose permitted set of values differs from what is declared in Python."
60
-
61
-
62
- class NoneTypeError(TypeError):
63
- "Raised when a column marked as required contains a `NULL` value."
64
-
65
-
66
- if sys.version_info >= (3, 11):
67
-
68
- def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
69
- """
70
- `True` if the specified type is an enumeration type.
71
- """
72
-
73
- return isinstance(typ, enum.EnumType)
74
-
75
- else:
76
-
77
- def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
78
- """
79
- `True` if the specified type is an enumeration type.
80
- """
81
-
82
- # use an explicit isinstance(..., type) check to filter out special forms like generics
83
- return isinstance(typ, type) and issubclass(typ, enum.Enum)
84
-
85
-
86
- def is_union_type(tp: Any) -> bool:
87
- """
88
- `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
89
- """
90
-
91
- origin = get_origin(tp)
92
- return origin is Union or origin is UnionType
93
-
94
-
95
- def is_optional_type(tp: Any) -> bool:
96
- """
97
- `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
98
- """
99
-
100
- return is_union_type(tp) and any(a is type(None) for a in get_args(tp))
101
-
102
-
103
- def is_standard_type(tp: Any) -> bool:
104
- """
105
- `True` if the type represents a built-in or a well-known standard type.
106
- """
107
-
108
- return tp.__module__ == "builtins" or tp.__module__ == UnionType.__module__
109
-
110
-
111
- def is_json_type(tp: Any) -> bool:
112
- """
113
- `True` if the type represents an object de-serialized from a JSON string.
114
- """
115
-
116
- return tp in [JsonType, RequiredJsonType]
117
-
118
-
119
- def is_inet_type(tp: Any) -> bool:
120
- """
121
- `True` if the type represents an IP address or network.
122
- """
123
-
124
- return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
125
-
126
-
127
- def make_union_type(tpl: list[Any]) -> UnionType:
128
- """
129
- Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
130
- """
131
-
132
- if len(tpl) < 2:
133
- raise ValueError("expected: at least two types to make a `UnionType`")
134
-
135
- return reduce(lambda a, b: a | b, tpl)
136
-
137
-
138
- def get_required_type(tp: Any) -> Any:
139
- """
140
- Removes `None` from an optional type (i.e. a union type that has `None` as a member).
141
- """
142
-
143
- if not is_optional_type(tp):
144
- return tp
145
-
146
- tpl = [a for a in get_args(tp) if a is not type(None)]
147
- if len(tpl) > 1:
148
- return make_union_type(tpl)
149
- elif len(tpl) > 0:
150
- return tpl[0]
151
- else:
152
- return type(None)
153
-
154
-
155
- def _standard_json_decoder() -> Callable[[str], JsonType]:
156
- import json
157
-
158
- _json_decoder = json.JSONDecoder()
159
- return _json_decoder.decode
160
-
161
-
162
- def _json_decoder() -> Callable[[str], JsonType]:
163
- if typing.TYPE_CHECKING:
164
- return _standard_json_decoder()
165
- else:
166
- try:
167
- import orjson
168
-
169
- return orjson.loads
170
- except ModuleNotFoundError:
171
- return _standard_json_decoder()
172
-
173
-
174
- JSON_DECODER = _json_decoder()
175
-
176
-
177
- def _standard_json_encoder() -> Callable[[JsonType], str]:
178
- import json
179
-
180
- _json_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
181
- return _json_encoder.encode
182
-
183
-
184
- def _json_encoder() -> Callable[[JsonType], str]:
185
- if typing.TYPE_CHECKING:
186
- return _standard_json_encoder()
187
- else:
188
- try:
189
- import orjson
190
-
191
- def _wrap(value: JsonType) -> str:
192
- return orjson.dumps(value).decode()
193
-
194
- return _wrap
195
- except ModuleNotFoundError:
196
- return _standard_json_encoder()
197
-
198
-
199
- JSON_ENCODER = _json_encoder()
200
-
201
-
202
- def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
203
- """
204
- Returns a callable that takes a wire type and returns a target type.
205
-
206
- A wire type is one of the types returned by asyncpg.
207
- A target type is one of the types supported by the library.
208
- """
209
-
210
- if is_json_type(tp):
211
- # asyncpg returns fields of type `json` and `jsonb` as `str`, which must be de-serialized
212
- return JSON_DECODER
213
- else:
214
- # target data types that require conversion must have a single-argument `__init__` that takes an object of the source type
215
- return tp
216
-
217
-
218
- def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
219
- """
220
- Returns a callable that takes a source type and returns a wire type.
221
-
222
- A source type is one of the types supported by the library.
223
- A wire type is one of the types returned by asyncpg.
224
- """
225
-
226
- if is_json_type(tp):
227
- # asyncpg expects fields of type `json` and `jsonb` as `str`, which must be serialized
228
- return JSON_ENCODER
229
- else:
230
- # source data types that require conversion must have a single-argument `__init__` that takes an object of the source type
231
- return tp
232
-
233
-
234
- # maps PostgreSQL internal type names to compatible Python types
235
- _NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
236
- # boolean type
237
- "bool": (bool,),
238
- # numeric types
239
- "int2": (int,),
240
- "int4": (int,),
241
- "int8": (int,),
242
- "float4": (float,),
243
- "float8": (float,),
244
- "numeric": (Decimal,),
245
- # date and time types
246
- "date": (date,),
247
- "time": (time,),
248
- "timetz": (time,),
249
- "timestamp": (datetime,),
250
- "timestamptz": (datetime,),
251
- "interval": (timedelta,),
252
- # character sequence types
253
- "bpchar": (str,),
254
- "varchar": (str,),
255
- "text": (str,),
256
- # binary sequence types
257
- "bytea": (bytes,),
258
- # unique identifier type
259
- "uuid": (UUID,),
260
- # address types
261
- "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
262
- "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
263
- "macaddr": (str,),
264
- "macaddr8": (str,),
265
- # JSON type
266
- "json": (str, RequiredJsonType),
267
- "jsonb": (str, RequiredJsonType),
268
- # XML type
269
- "xml": (str,),
270
- # geometric types
271
- "point": (asyncpg.Point,),
272
- "line": (asyncpg.Line,),
273
- "lseg": (asyncpg.LineSegment,),
274
- "box": (asyncpg.Box,),
275
- "path": (asyncpg.Path,),
276
- "polygon": (asyncpg.Polygon,),
277
- "circle": (asyncpg.Circle,),
278
- # range types
279
- "int4range": (asyncpg.Range[int],),
280
- "int4multirange": (list[asyncpg.Range[int]],),
281
- "int8range": (asyncpg.Range[int],),
282
- "int8multirange": (list[asyncpg.Range[int]],),
283
- "numrange": (asyncpg.Range[Decimal],),
284
- "nummultirange": (list[asyncpg.Range[Decimal]],),
285
- "tsrange": (asyncpg.Range[datetime],),
286
- "tsmultirange": (list[asyncpg.Range[datetime]],),
287
- "tstzrange": (asyncpg.Range[datetime],),
288
- "tstzmultirange": (list[asyncpg.Range[datetime]],),
289
- "daterange": (asyncpg.Range[date],),
290
- "datemultirange": (list[asyncpg.Range[date]],),
291
- }
292
-
293
-
294
- def type_to_str(tp: Any) -> str:
295
- "Emits a friendly name for a type."
296
-
297
- if isinstance(tp, type):
298
- return tp.__name__
299
- else:
300
- return str(tp)
301
-
302
-
303
- class _TypeVerifier:
304
- """
305
- Verifies if the Python target type can represent the PostgreSQL source type.
306
- """
307
-
308
- _connection: Connection
309
-
310
- def __init__(self, connection: Connection) -> None:
311
- self._connection = connection
312
-
313
- async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
314
- """
315
- Verifies if a Python enumeration type matches a PostgreSQL enumeration type.
316
- """
317
-
318
- for e in data_type:
319
- if not isinstance(e.value, str):
320
- raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")
321
-
322
- py_values = set(e.value for e in data_type)
323
-
324
- rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
325
- db_values = set(row[0] for row in rows)
326
-
327
- db_extra = db_values - py_values
328
- if db_extra:
329
- raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in db_extra)})")
330
-
331
- py_extra = py_values - db_values
332
- if py_extra:
333
- raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in py_extra)})")
334
-
335
- async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
336
- """
337
- Verifies if the Python target type can represent the PostgreSQL source type.
338
- """
339
-
340
- if pg_type.schema == "pg_catalog": # well-known PostgreSQL types
341
- if is_enum_type(data_type):
342
- if pg_type.name not in ["bpchar", "varchar", "text"]:
343
- raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
344
- elif pg_type.kind == "array" and get_origin(data_type) is list:
345
- if not pg_type.name.endswith("[]"):
346
- raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of array")
347
-
348
- expected_types = _NAME_TO_TYPE.get(pg_type.name[:-2])
349
- if expected_types is None:
350
- raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
351
- elif get_args(data_type)[0] not in expected_types:
352
- if len(expected_types) == 1:
353
- target = f"the Python type `{type_to_str(expected_types[0])}`"
354
- else:
355
- target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
356
- raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` whose elements convert to {target}")
357
- else:
358
- expected_types = _NAME_TO_TYPE.get(pg_type.name)
359
- if expected_types is None:
360
- raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
361
- elif data_type not in expected_types:
362
- if len(expected_types) == 1:
363
- target = f"the Python type `{type_to_str(expected_types[0])}`"
364
- else:
365
- target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
366
- raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` which converts to {target}")
367
- elif pg_type.kind == "composite": # PostgreSQL composite types
368
- # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
369
- pass
370
- else: # custom PostgreSQL types
371
- if is_enum_type(data_type):
372
- await self._check_enum_type(pg_name, pg_type, data_type)
373
- elif is_standard_type(data_type):
374
- raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
375
- else:
376
- # user-defined types registered with `conn.set_type_codec()`
377
- pass
378
-
379
-
380
- @dataclass(frozen=True)
381
- class _Placeholder:
382
- ordinal: int
383
- data_type: TargetType
384
-
385
- def __repr__(self) -> str:
386
- return f"{self.__class__.__name__}({self.ordinal}, {self.data_type!r})"
387
-
388
-
389
- class _ResultsetWrapper:
390
- "Wraps result-set rows into a tuple or named tuple."
391
-
392
- init: Callable[..., tuple[Any, ...]] | None
393
- iterable: Callable[[Iterable[Any]], tuple[Any, ...]]
394
-
395
- def __init__(self, init: Callable[..., tuple[Any, ...]] | None, iterable: Callable[[Iterable[Any]], tuple[Any, ...]]) -> None:
396
- """
397
- Initializes a result-set wrapper.
398
-
399
- :param init: Initializer function that takes as many arguments as columns in the result-set.
400
- :param iterable: Initializer function that takes an iterable over columns of a result-set row.
401
- """
402
-
403
- self.init = init
404
- self.iterable = iterable
405
-
406
-
407
- class _SQLObject:
408
- """
409
- Associates input and output type information with a SQL statement.
410
- """
411
-
412
- _parameter_data_types: tuple[_Placeholder, ...]
413
- _resultset_data_types: tuple[TargetType, ...]
414
- _resultset_column_names: tuple[str, ...] | None
415
- _resultset_wrapper: _ResultsetWrapper
416
- _parameter_cast: int
417
- _parameter_converters: tuple[Callable[[Any], Any], ...]
418
- _required: int
419
- _resultset_cast: int
420
- _resultset_converters: tuple[Callable[[Any], Any], ...]
421
-
422
- @property
423
- def parameter_data_types(self) -> tuple[_Placeholder, ...]:
424
- "Expected inbound parameter data types."
425
-
426
- return self._parameter_data_types
427
-
428
- @property
429
- def resultset_data_types(self) -> tuple[TargetType, ...]:
430
- "Expected column data types in the result-set."
431
-
432
- return self._resultset_data_types
433
-
434
- @property
435
- def resultset_column_names(self) -> tuple[str, ...] | None:
436
- "Expected column names in the result-set."
437
-
438
- return self._resultset_column_names
439
-
440
- def __init__(
441
- self,
442
- *,
443
- args: tuple[TargetType, ...],
444
- resultset: tuple[TargetType, ...],
445
- names: tuple[str, ...] | None,
446
- wrapper: _ResultsetWrapper,
447
- ) -> None:
448
- self._parameter_data_types = tuple(_Placeholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(args, start=1))
449
- self._resultset_data_types = tuple(get_required_type(data_type) for data_type in resultset)
450
- self._resultset_column_names = names
451
- self._resultset_wrapper = wrapper
452
-
453
- # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
454
- parameter_cast = 0
455
- for index, placeholder in enumerate(self._parameter_data_types):
456
- parameter_cast |= is_json_type(placeholder.data_type) << index
457
- self._parameter_cast = parameter_cast
458
-
459
- self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in self._parameter_data_types)
460
-
461
- # create a bit-field of required types (1: required; 0: optional)
462
- required = 0
463
- for index, data_type in enumerate(resultset):
464
- required |= (not is_optional_type(data_type)) << index
465
- self._required = required
466
-
467
- # create a bit-field of types that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
468
- resultset_cast = 0
469
- for index, data_type in enumerate(self._resultset_data_types):
470
- resultset_cast |= (is_enum_type(data_type) or is_json_type(data_type) or is_inet_type(data_type)) << index
471
- self._resultset_cast = resultset_cast
472
-
473
- self._resultset_converters = tuple(get_output_converter_for(data_type) for data_type in self._resultset_data_types)
474
-
475
- def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
476
- """
477
- Raises an error with the index of the first column value that is of a required type but has been assigned a value of `None`.
478
- """
479
-
480
- for col_index in range(len(row)):
481
- if (self._required >> col_index & 1) and row[col_index] is None:
482
- if row_index is not None:
483
- row_col_spec = f"row #{row_index} and column #{col_index}"
484
- else:
485
- row_col_spec = f"column #{col_index}"
486
- raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
487
-
488
- def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
489
- """
490
- Verifies if declared types match actual value types in a resultset.
491
- """
492
-
493
- if not rows:
494
- return
495
-
496
- required = self._required
497
- if not required:
498
- return
499
-
500
- match len(rows[0]):
501
- case 1:
502
- for r, row in enumerate(rows):
503
- if required & (row[0] is None):
504
- self._raise_required_is_none(row, r)
505
- case 2:
506
- for r, row in enumerate(rows):
507
- a, b = row
508
- if required & ((a is None) | (b is None) << 1):
509
- self._raise_required_is_none(row, r)
510
- case 3:
511
- for r, row in enumerate(rows):
512
- a, b, c = row
513
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2):
514
- self._raise_required_is_none(row, r)
515
- case 4:
516
- for r, row in enumerate(rows):
517
- a, b, c, d = row
518
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3):
519
- self._raise_required_is_none(row, r)
520
- case 5:
521
- for r, row in enumerate(rows):
522
- a, b, c, d, e = row
523
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4):
524
- self._raise_required_is_none(row, r)
525
- case 6:
526
- for r, row in enumerate(rows):
527
- a, b, c, d, e, f = row
528
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5):
529
- self._raise_required_is_none(row, r)
530
- case 7:
531
- for r, row in enumerate(rows):
532
- a, b, c, d, e, f, g = row
533
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6):
534
- self._raise_required_is_none(row, r)
535
- case 8:
536
- for r, row in enumerate(rows):
537
- a, b, c, d, e, f, g, h = row
538
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6 | (h is None) << 7):
539
- self._raise_required_is_none(row, r)
540
- case _:
541
- for r, row in enumerate(rows):
542
- self._raise_required_is_none(row, r)
543
-
544
- def check_row(self, row: tuple[Any, ...]) -> None:
545
- """
546
- Verifies if declared types match actual value types in a single row.
547
- """
548
-
549
- required = self._required
550
- if not required:
551
- return
552
-
553
- match len(row):
554
- case 1:
555
- if required & (row[0] is None):
556
- self._raise_required_is_none(row)
557
- case 2:
558
- a, b = row
559
- if required & ((a is None) | (b is None) << 1):
560
- self._raise_required_is_none(row)
561
- case 3:
562
- a, b, c = row
563
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2):
564
- self._raise_required_is_none(row)
565
- case 4:
566
- a, b, c, d = row
567
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3):
568
- self._raise_required_is_none(row)
569
- case 5:
570
- a, b, c, d, e = row
571
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4):
572
- self._raise_required_is_none(row)
573
- case 6:
574
- a, b, c, d, e, f = row
575
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5):
576
- self._raise_required_is_none(row)
577
- case 7:
578
- a, b, c, d, e, f, g = row
579
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6):
580
- self._raise_required_is_none(row)
581
- case 8:
582
- a, b, c, d, e, f, g, h = row
583
- if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6 | (h is None) << 7):
584
- self._raise_required_is_none(row)
585
- case _:
586
- self._raise_required_is_none(row)
587
-
588
- def check_column(self, column: list[Any]) -> None:
589
- """
590
- Verifies if the declared type matches the actual value type of a single-column resultset.
591
- """
592
-
593
- if self._required:
594
- for i, value in enumerate(column):
595
- if value is None:
596
- raise NoneTypeError(f"expected: {self._resultset_data_types[0]} in row #{i}; got: NULL")
597
-
598
- def check_value(self, value: Any) -> None:
599
- """
600
- Verifies if the declared type matches the actual value type.
601
- """
602
-
603
- if self._required and value is None:
604
- raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
605
-
606
- def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
607
- """
608
- Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.
609
- """
610
-
611
- cast = self._parameter_cast
612
- if cast:
613
- converters = self._parameter_converters
614
- yield from (tuple((converters[i](value) if (value := arg[i]) is not None and cast >> i & 1 else value) for i in range(len(arg))) for arg in arg_lists)
615
- else:
616
- yield from arg_lists
617
-
618
- def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
619
- """
620
- Converts Python query arguments to PostgreSQL parameters.
621
- """
622
-
623
- cast = self._parameter_cast
624
- if cast:
625
- converters = self._parameter_converters
626
- return tuple((converters[i](value) if (value := arg_list[i]) is not None and cast >> i & 1 else value) for i in range(len(arg_list)))
627
- else:
628
- return tuple(value for value in arg_list)
629
-
630
- def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
631
- """
632
- Converts columns in the PostgreSQL result-set to their corresponding Python target types.
633
-
634
- :param rows: List of rows returned by PostgreSQL.
635
- :returns: List of tuples with each tuple element having the configured Python target type.
636
- """
637
-
638
- if not rows:
639
- return []
640
-
641
- init_wrapper = self._resultset_wrapper.init
642
- iterable_wrapper = self._resultset_wrapper.iterable
643
- cast = self._resultset_cast
644
- if not cast:
645
- return [iterable_wrapper(row.values()) for row in rows]
646
-
647
- columns = len(rows[0])
648
- match columns:
649
- case 1:
650
- converter = self._resultset_converters[0]
651
- if init_wrapper is not None:
652
- return [init_wrapper(converter(value) if (value := row[0]) is not None else value) for row in rows]
653
- else:
654
- return [(converter(value) if (value := row[0]) is not None else value,) for row in rows]
655
- case 2:
656
- conv_a, conv_b = self._resultset_converters
657
- cast_a = cast >> 0 & 1
658
- cast_b = cast >> 1 & 1
659
- if init_wrapper is not None:
660
- return [
661
- init_wrapper(
662
- conv_a(value) if (value := row[0]) is not None and cast_a else value,
663
- conv_b(value) if (value := row[1]) is not None and cast_b else value,
664
- )
665
- for row in rows
666
- ]
667
- else:
668
- return [
669
- (
670
- conv_a(value) if (value := row[0]) is not None and cast_a else value,
671
- conv_b(value) if (value := row[1]) is not None and cast_b else value,
672
- )
673
- for row in rows
674
- ]
675
- case 3:
676
- conv_a, conv_b, conv_c = self._resultset_converters
677
- cast_a = cast >> 0 & 1
678
- cast_b = cast >> 1 & 1
679
- cast_c = cast >> 2 & 1
680
- if init_wrapper is not None:
681
- return [
682
- init_wrapper(
683
- conv_a(value) if (value := row[0]) is not None and cast_a else value,
684
- conv_b(value) if (value := row[1]) is not None and cast_b else value,
685
- conv_c(value) if (value := row[2]) is not None and cast_c else value,
686
- )
687
- for row in rows
688
- ]
689
- else:
690
- return [
691
- (
692
- conv_a(value) if (value := row[0]) is not None and cast_a else value,
693
- conv_b(value) if (value := row[1]) is not None and cast_b else value,
694
- conv_c(value) if (value := row[2]) is not None and cast_c else value,
695
- )
696
- for row in rows
697
- ]
698
- case 4:
699
- conv_a, conv_b, conv_c, conv_d = self._resultset_converters
700
- cast_a = cast >> 0 & 1
701
- cast_b = cast >> 1 & 1
702
- cast_c = cast >> 2 & 1
703
- cast_d = cast >> 3 & 1
704
- if init_wrapper is not None:
705
- return [
706
- init_wrapper(
707
- conv_a(value) if (value := row[0]) is not None and cast_a else value,
708
- conv_b(value) if (value := row[1]) is not None and cast_b else value,
709
- conv_c(value) if (value := row[2]) is not None and cast_c else value,
710
- conv_d(value) if (value := row[3]) is not None and cast_d else value,
711
- )
712
- for row in rows
713
- ]
714
- else:
715
- return [
716
- (
717
- conv_a(value) if (value := row[0]) is not None and cast_a else value,
718
- conv_b(value) if (value := row[1]) is not None and cast_b else value,
719
- conv_c(value) if (value := row[2]) is not None and cast_c else value,
720
- conv_d(value) if (value := row[3]) is not None and cast_d else value,
721
- )
722
- for row in rows
723
- ]
724
- case _:
725
- converters = self._resultset_converters
726
- return [iterable_wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns)) for row in rows]
727
-
728
- def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
729
- """
730
- Converts columns in the PostgreSQL result-set to their corresponding Python target types.
731
-
732
- :param row: A single row returned by PostgreSQL.
733
- :returns: A tuple with each tuple element having the configured Python target type.
734
- """
735
-
736
- wrapper = self._resultset_wrapper.iterable
737
- cast = self._resultset_cast
738
- if cast:
739
- converters = self._resultset_converters
740
- columns = len(row)
741
- return wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns))
742
- else:
743
- return wrapper(row.values())
744
-
745
- def convert_column(self, rows: list[asyncpg.Record]) -> list[Any]:
746
- """
747
- Converts a single column in the PostgreSQL result-set to its corresponding Python target type.
748
-
749
- :param rows: List of rows returned by PostgreSQL.
750
- :returns: List of values having the configured Python target type.
751
- """
752
-
753
- cast = self._resultset_cast
754
- if cast:
755
- converter = self._resultset_converters[0]
756
- return [(converter(value) if (value := row[0]) is not None else value) for row in rows]
757
- else:
758
- return [row[0] for row in rows]
759
-
760
- def convert_value(self, value: Any) -> Any:
761
- """
762
- Converts a single PostgreSQL value to its corresponding Python target type.
763
-
764
- :param value: A single value returned by PostgreSQL.
765
- :returns: A converted value having the configured Python target type.
766
- """
767
-
768
- return self._resultset_converters[0](value) if value is not None and self._resultset_cast else value
769
-
770
- @abstractmethod
771
- def query(self) -> str:
772
- """
773
- Returns a SQL query string with PostgreSQL ordinal placeholders.
774
- """
775
- ...
776
-
777
- def __repr__(self) -> str:
778
- return f"{self.__class__.__name__}({self.query()!r})"
779
-
780
- def __str__(self) -> str:
781
- return self.query()
782
-
783
-
784
- if sys.version_info >= (3, 14):
785
- from string.templatelib import Interpolation, Template # type: ignore[import-not-found]
786
-
787
- SQLExpression: TypeAlias = Template | LiteralString
788
-
789
- class _SQLTemplate(_SQLObject):
790
- """
791
- A SQL query specified with the Python t-string syntax.
792
- """
793
-
794
- _strings: tuple[str, ...]
795
- _placeholders: tuple[_Placeholder, ...]
796
-
797
- def __init__(
798
- self,
799
- template: Template,
800
- *,
801
- args: tuple[TargetType, ...],
802
- resultset: tuple[TargetType, ...],
803
- names: tuple[str, ...] | None,
804
- wrapper: _ResultsetWrapper,
805
- ) -> None:
806
- super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
807
-
808
- for ip in template.interpolations:
809
- if ip.conversion is not None:
810
- raise TypeError(f"interpolation `{ip.expression}` expected to apply no conversion")
811
- if ip.format_spec:
812
- raise TypeError(f"interpolation `{ip.expression}` expected to apply no format spec")
813
- if not isinstance(ip.value, int):
814
- raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")
815
-
816
- self._strings = template.strings
817
-
818
- if len(self.parameter_data_types) > 0:
819
-
820
- def _to_placeholder(ip: Interpolation) -> _Placeholder:
821
- ordinal = int(ip.value)
822
- if not (0 < ordinal <= len(self.parameter_data_types)):
823
- raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {len(self.parameter_data_types)}")
824
- return self.parameter_data_types[int(ip.value) - 1]
825
-
826
- self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
827
- else:
828
- self._placeholders = ()
829
-
830
- def query(self) -> str:
831
- buf = StringIO()
832
- for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
833
- buf.write(s)
834
- buf.write(f"${p.ordinal}")
835
- buf.write(self._strings[-1])
836
- return buf.getvalue()
837
-
838
- else:
839
- SQLExpression = LiteralString
840
-
841
-
842
- class _SQLString(_SQLObject):
843
- """
844
- A SQL query specified as a plain string (e.g. f-string).
845
- """
846
-
847
- _sql: str
848
-
849
- def __init__(
850
- self,
851
- sql: str,
852
- *,
853
- args: tuple[TargetType, ...],
854
- resultset: tuple[TargetType, ...],
855
- names: tuple[str, ...] | None,
856
- wrapper: _ResultsetWrapper,
857
- ) -> None:
858
- super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
859
- self._sql = sql
860
-
861
- def query(self) -> str:
862
- return self._sql
863
-
864
-
865
- class _SQL(Protocol):
866
- """
867
- Represents a SQL statement with associated type information.
868
- """
869
-
870
-
871
- class _SQLImpl(_SQL):
872
- """
873
- Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).
874
- """
875
-
876
- _sql: _SQLObject
877
-
878
- def __init__(self, sql: _SQLObject) -> None:
879
- self._sql = sql
880
-
881
- def __str__(self) -> str:
882
- return str(self._sql)
883
-
884
- def __repr__(self) -> str:
885
- return repr(self._sql)
886
-
887
- async def _prepare(self, connection: Connection) -> PreparedStatement:
888
- stmt = await connection.prepare(self._sql.query())
889
-
890
- verifier = _TypeVerifier(connection)
891
-
892
- input_count = len(self._sql.parameter_data_types)
893
- parameter_count = len(stmt.get_parameters())
894
- if parameter_count != input_count:
895
- raise CountMismatchError(f"expected: PostgreSQL query to take {input_count} parameter(s); got: {parameter_count}")
896
- for param, placeholder in zip(stmt.get_parameters(), self._sql.parameter_data_types, strict=True):
897
- await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)
898
-
899
- output_count = len(self._sql.resultset_data_types)
900
- column_count = len(stmt.get_attributes())
901
- if column_count != output_count:
902
- raise CountMismatchError(f"expected: PostgreSQL query to return {output_count} column(s) in result-set; got: {column_count}")
903
- if self._sql.resultset_column_names is not None:
904
- for index, attr, name in zip(range(output_count), stmt.get_attributes(), self._sql.resultset_column_names, strict=True):
905
- if attr.name != name:
906
- raise NameMismatchError(f"expected: Python field name `{name}` to match PostgreSQL result-set column name `{attr.name}` for index #{index}")
907
- for attr, data_type in zip(stmt.get_attributes(), self._sql.resultset_data_types, strict=True):
908
- await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)
909
-
910
- return stmt
911
-
912
- async def execute(self, connection: asyncpg.Connection, *args: Any) -> None:
913
- await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))
914
-
915
- async def executemany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> None:
916
- stmt = await self._prepare(connection)
917
- await stmt.executemany(self._sql.convert_arg_lists(args))
918
-
919
- def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
920
- resultset = self._sql.convert_rows(rows)
921
- self._sql.check_rows(resultset)
922
- return resultset
923
-
924
- async def fetch(self, connection: asyncpg.Connection, *args: Any) -> list[tuple[Any, ...]]:
925
- stmt = await self._prepare(connection)
926
- rows = await stmt.fetch(*self._sql.convert_arg_list(args))
927
- return self._cast_fetch(rows)
928
-
929
- async def fetchmany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
930
- stmt = await self._prepare(connection)
931
- rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
932
- return self._cast_fetch(rows)
933
-
934
- async def fetchrow(self, connection: asyncpg.Connection, *args: Any) -> tuple[Any, ...] | None:
935
- stmt = await self._prepare(connection)
936
- row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
937
- if row is None:
938
- return None
939
- resultset = self._sql.convert_row(row)
940
- self._sql.check_row(resultset)
941
- return resultset
942
-
943
- async def fetchcol(self, connection: asyncpg.Connection, *args: Any) -> list[Any]:
944
- stmt = await self._prepare(connection)
945
- rows = await stmt.fetch(*self._sql.convert_arg_list(args))
946
- column = self._sql.convert_column(rows)
947
- self._sql.check_column(column)
948
- return column
949
-
950
- async def fetchval(self, connection: asyncpg.Connection, *args: Any) -> Any:
951
- stmt = await self._prepare(connection)
952
- value = await stmt.fetchval(*self._sql.convert_arg_list(args))
953
- result = self._sql.convert_value(value)
954
- self._sql.check_value(result)
955
- return result
956
-
957
-
958
- P1 = TypeVar("P1")
959
- PX = TypeVarTuple("PX")
960
-
961
- RT = TypeVar("RT")
962
- R1 = TypeVar("R1")
963
- R2 = TypeVar("R2")
964
- RX = TypeVarTuple("RX")
965
-
966
-
967
- ### START OF AUTO-GENERATED BLOCK FOR Protocol ###
968
- class SQL_P0(Protocol):
969
- @abstractmethod
970
- async def execute(self, connection: Connection) -> None: ...
971
-
972
-
973
- class SQL_R1_P0(SQL_P0, Protocol[R1]):
974
- @abstractmethod
975
- async def fetch(self, connection: Connection) -> list[tuple[R1]]: ...
976
- @abstractmethod
977
- async def fetchrow(self, connection: Connection) -> tuple[R1] | None: ...
978
- @abstractmethod
979
- async def fetchcol(self, connection: Connection) -> list[R1]: ...
980
- @abstractmethod
981
- async def fetchval(self, connection: Connection) -> R1: ...
982
-
983
-
984
- class SQL_RX_P0(SQL_P0, Protocol[RT]):
985
- @abstractmethod
986
- async def fetch(self, connection: Connection) -> list[RT]: ...
987
- @abstractmethod
988
- async def fetchrow(self, connection: Connection) -> RT | None: ...
989
-
990
-
991
- class SQL_PX(Protocol[Unpack[PX]]):
992
- @abstractmethod
993
- async def execute(self, connection: Connection, *args: Unpack[PX]) -> None: ...
994
- @abstractmethod
995
- async def executemany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> None: ...
996
-
997
-
998
- class SQL_R1_PX(SQL_PX[Unpack[PX]], Protocol[R1, Unpack[PX]]):
999
- @abstractmethod
1000
- async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[tuple[R1]]: ...
1001
- @abstractmethod
1002
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[tuple[R1]]: ...
1003
- @abstractmethod
1004
- async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> tuple[R1] | None: ...
1005
- @abstractmethod
1006
- async def fetchcol(self, connection: Connection, *args: Unpack[PX]) -> list[R1]: ...
1007
- @abstractmethod
1008
- async def fetchval(self, connection: Connection, *args: Unpack[PX]) -> R1: ...
1009
-
1010
-
1011
- class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
1012
- @abstractmethod
1013
- async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[RT]: ...
1014
- @abstractmethod
1015
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[RT]: ...
1016
- @abstractmethod
1017
- async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None: ...
1018
-
1019
-
1020
- ### END OF AUTO-GENERATED BLOCK FOR Protocol ###
1021
-
1022
- RS = TypeVar("RS", bound=tuple[Any, ...])
1023
-
1024
-
1025
- class SQLFactory:
1026
- """
1027
- Creates type-safe SQL queries.
1028
- """
1029
-
1030
- ### START OF AUTO-GENERATED BLOCK FOR sql ###
1031
- @overload
1032
- def sql(self, stmt: SQLExpression) -> SQL_P0: ...
1033
- @overload
1034
- def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
1035
- @overload
1036
- def sql(self, stmt: SQLExpression, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
1037
- @overload
1038
- def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
1039
- @overload
1040
- def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
1041
- @overload
1042
- def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
1043
- @overload
1044
- def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
1045
- @overload
1046
- def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
1047
- @overload
1048
- def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
1049
-
1050
- ### END OF AUTO-GENERATED BLOCK FOR sql ###
1051
-
1052
- def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
1053
- """
1054
- Creates a SQL statement with associated type information.
1055
-
1056
- :param stmt: SQL statement as a literal string or template.
1057
- :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
1058
- :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
1059
- :param arg: Type signature for a single input parameter (e.g. `int`).
1060
- :param result: Type signature for a single result column (e.g. `UUID`).
1061
- """
1062
-
1063
- input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
1064
-
1065
- obj: _SQLObject
1066
- if sys.version_info >= (3, 14):
1067
- match stmt:
1068
- case Template():
1069
- obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1070
- case str():
1071
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1072
- else:
1073
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1074
-
1075
- return _SQLImpl(obj)
1076
-
1077
- ### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
1078
- @overload
1079
- def unsafe_sql(self, stmt: str) -> SQL_P0: ...
1080
- @overload
1081
- def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
1082
- @overload
1083
- def unsafe_sql(self, stmt: str, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
1084
- @overload
1085
- def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
1086
- @overload
1087
- def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
1088
- @overload
1089
- def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
1090
- @overload
1091
- def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
1092
- @overload
1093
- def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
1094
- @overload
1095
- def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
1096
-
1097
- ### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
1098
-
1099
- def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
1100
- """
1101
- Creates a SQL statement with associated type information from a string.
1102
-
1103
- This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
1104
- a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.
1105
-
1106
- :param stmt: SQL statement as a string (or f-string).
1107
- :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
1108
- :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
1109
- :param arg: Type signature for a single input parameter (e.g. `int`).
1110
- :param result: Type signature for a single result column (e.g. `UUID`).
1111
- """
1112
-
1113
- input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
1114
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1115
- return _SQLImpl(obj)
1116
-
1117
-
1118
- def _sql_args_resultset(
1119
- *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None
1120
- ) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[str, ...] | None, _ResultsetWrapper]:
1121
- "Parses an argument/resultset signature into input/output types."
1122
-
1123
- if args is not None and arg is not None:
1124
- raise TypeError("expected: either `args` or `arg`; got: both")
1125
- if resultset is not None and result is not None:
1126
- raise TypeError("expected: either `resultset` or `result`; got: both")
1127
-
1128
- if args is not None:
1129
- if hasattr(args, "_asdict") and hasattr(args, "_fields"):
1130
- # named tuple
1131
- input_data_types = tuple(tp for tp in args.__annotations__.values())
1132
- else:
1133
- # regular tuple
1134
- if get_origin(args) is not tuple:
1135
- raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {args}")
1136
- input_data_types = get_args(args)
1137
- elif arg is not None:
1138
- input_data_types = (arg,)
1139
- else:
1140
- input_data_types = ()
1141
-
1142
- if resultset is not None:
1143
- if hasattr(resultset, "_asdict") and hasattr(resultset, "_fields") and hasattr(resultset, "_make"):
1144
- # named tuple
1145
- output_data_types = tuple(tp for tp in resultset.__annotations__.values())
1146
- names = tuple(f for f in resultset._fields)
1147
- wrapper = _ResultsetWrapper(resultset, resultset._make)
1148
- else:
1149
- # regular tuple
1150
- if get_origin(resultset) is not tuple:
1151
- raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {resultset}")
1152
- output_data_types = get_args(resultset)
1153
- names = None
1154
- wrapper = _ResultsetWrapper(None, tuple)
1155
- else:
1156
- if result is not None:
1157
- output_data_types = (result,)
1158
- else:
1159
- output_data_types = ()
1160
- names = None
1161
- wrapper = _ResultsetWrapper(None, tuple)
1162
-
1163
- return input_data_types, output_data_types, names, wrapper
1164
-
1165
-
1166
- FACTORY: SQLFactory = SQLFactory()
1167
-
1168
- sql = FACTORY.sql
1169
- unsafe_sql = FACTORY.unsafe_sql
1
+ """
2
+ Type-safe queries for asyncpg.
3
+
4
+ :see: https://github.com/hunyadi/asyncpg_typed
5
+ """
6
+
7
+ __version__ = "0.1.5"
8
+ __author__ = "Levente Hunyadi"
9
+ __copyright__ = "Copyright 2025-2026, Levente Hunyadi"
10
+ __license__ = "MIT"
11
+ __maintainer__ = "Levente Hunyadi"
12
+ __status__ = "Production"
13
+
14
+ import enum
15
+ import sys
16
+ import typing
17
+ from abc import abstractmethod
18
+ from collections.abc import Callable, Iterable, Sequence
19
+ from dataclasses import dataclass
20
+ from datetime import date, datetime, time, timedelta
21
+ from decimal import Decimal
22
+ from functools import reduce
23
+ from io import StringIO
24
+ from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
25
+ from types import UnionType
26
+ from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
27
+ from uuid import UUID
28
+
29
+ import asyncpg
30
+ from asyncpg.prepared_stmt import PreparedStatement
31
+
32
+ if sys.version_info < (3, 11):
33
+ from typing_extensions import LiteralString, TypeVarTuple, Unpack
34
+ else:
35
+ from typing import LiteralString, TypeVarTuple, Unpack
36
+
37
+ JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
38
+
39
+ RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
40
+
41
+ TargetType: TypeAlias = type[Any] | UnionType
42
+
43
+ Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
44
+
45
+
46
+ class CountMismatchError(TypeError):
47
+ "Raised when a prepared statement takes or returns a different number of parameters or columns than declared in Python."
48
+
49
+
50
+ class TypeMismatchError(TypeError):
51
+ "Raised when a prepared statement takes or returns a PostgreSQL type incompatible with the declared Python type."
52
+
53
+
54
+ class NameMismatchError(TypeError):
55
+ "Raised when the name of a result-set column differs from what is declared in Python."
56
+
57
+
58
+ class EnumMismatchError(TypeError):
59
+ "Raised when a prepared statement takes or returns a PostgreSQL enum type whose permitted set of values differs from what is declared in Python."
60
+
61
+
62
+ class NoneTypeError(TypeError):
63
+ "Raised when a column marked as required contains a `NULL` value."
64
+
65
+
66
+ if sys.version_info < (3, 11):
67
+
68
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
69
+ """
70
+ `True` if the specified type is an enumeration type.
71
+ """
72
+
73
+ # use an explicit isinstance(..., type) check to filter out special forms like generics
74
+ return isinstance(typ, type) and issubclass(typ, enum.Enum)
75
+
76
+
77
+ else:
78
+
79
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
80
+ """
81
+ `True` if the specified type is an enumeration type.
82
+ """
83
+
84
+ return isinstance(typ, enum.EnumType)
85
+
86
+
87
+ def is_union_type(tp: Any) -> bool:
88
+ """
89
+ `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
90
+ """
91
+
92
+ origin = get_origin(tp)
93
+ return origin is Union or origin is UnionType
94
+
95
+
96
+ def is_optional_type(tp: Any) -> bool:
97
+ """
98
+ `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
99
+ """
100
+
101
+ return is_union_type(tp) and any(a is type(None) for a in get_args(tp))
102
+
103
+
104
+ def is_standard_type(tp: Any) -> bool:
105
+ """
106
+ `True` if the type represents a built-in or a well-known standard type.
107
+ """
108
+
109
+ return tp.__module__ == "builtins" or tp.__module__ == UnionType.__module__
110
+
111
+
112
+ def is_json_type(tp: Any) -> bool:
113
+ """
114
+ `True` if the type represents an object de-serialized from a JSON string.
115
+ """
116
+
117
+ return tp in [JsonType, RequiredJsonType]
118
+
119
+
120
+ def is_inet_type(tp: Any) -> bool:
121
+ """
122
+ `True` if the type represents an IP address or network.
123
+ """
124
+
125
+ return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
126
+
127
+
128
+ def make_union_type(tpl: list[Any]) -> UnionType:
129
+ """
130
+ Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
131
+ """
132
+
133
+ if len(tpl) < 2:
134
+ raise ValueError("expected: at least two types to make a `UnionType`")
135
+
136
+ return reduce(lambda a, b: a | b, tpl)
137
+
138
+
139
+ def get_required_type(tp: Any) -> Any:
140
+ """
141
+ Removes `None` from an optional type (i.e. a union type that has `None` as a member).
142
+ """
143
+
144
+ if not is_optional_type(tp):
145
+ return tp
146
+
147
+ tpl = [a for a in get_args(tp) if a is not type(None)]
148
+ if len(tpl) > 1:
149
+ return make_union_type(tpl)
150
+ elif len(tpl) > 0:
151
+ return tpl[0]
152
+ else:
153
+ return type(None)
154
+
155
+
156
+ def _standard_json_decoder() -> Callable[[str], JsonType]:
157
+ import json
158
+
159
+ _json_decoder = json.JSONDecoder()
160
+ return _json_decoder.decode
161
+
162
+
163
+ def _json_decoder() -> Callable[[str], JsonType]:
164
+ if typing.TYPE_CHECKING:
165
+ return _standard_json_decoder()
166
+ else:
167
+ try:
168
+ import orjson
169
+
170
+ return orjson.loads
171
+ except ModuleNotFoundError:
172
+ return _standard_json_decoder()
173
+
174
+
175
+ JSON_DECODER = _json_decoder()
176
+
177
+
178
+ def _standard_json_encoder() -> Callable[[JsonType], str]:
179
+ import json
180
+
181
+ _json_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
182
+ return _json_encoder.encode
183
+
184
+
185
+ def _json_encoder() -> Callable[[JsonType], str]:
186
+ if typing.TYPE_CHECKING:
187
+ return _standard_json_encoder()
188
+ else:
189
+ try:
190
+ import orjson
191
+
192
+ def _wrap(value: JsonType) -> str:
193
+ return orjson.dumps(value).decode()
194
+
195
+ return _wrap
196
+ except ModuleNotFoundError:
197
+ return _standard_json_encoder()
198
+
199
+
200
+ JSON_ENCODER = _json_encoder()
201
+
202
+
203
+ def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
204
+ """
205
+ Returns a callable that takes a wire type and returns a target type.
206
+
207
+ A wire type is one of the types returned by asyncpg.
208
+ A target type is one of the types supported by the library.
209
+ """
210
+
211
+ if is_json_type(tp):
212
+ # asyncpg returns fields of type `json` and `jsonb` as `str`, which must be de-serialized
213
+ return JSON_DECODER
214
+ else:
215
+ # target data types that require conversion must have a single-argument `__init__` that takes an object of the source type
216
+ return tp
217
+
218
+
219
+ def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
220
+ """
221
+ Returns a callable that takes a source type and returns a wire type.
222
+
223
+ A source type is one of the types supported by the library.
224
+ A wire type is one of the types returned by asyncpg.
225
+ """
226
+
227
+ if is_json_type(tp):
228
+ # asyncpg expects fields of type `json` and `jsonb` as `str`, which must be serialized
229
+ return JSON_ENCODER
230
+ else:
231
+ # source data types that require conversion must have a single-argument `__init__` that takes an object of the source type
232
+ return tp
233
+
234
+
235
+ # maps PostgreSQL internal type names to compatible Python types
236
+ _NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
237
+ # boolean type
238
+ "bool": (bool,),
239
+ # numeric types
240
+ "int2": (int,),
241
+ "int4": (int,),
242
+ "int8": (int,),
243
+ "float4": (float,),
244
+ "float8": (float,),
245
+ "numeric": (Decimal,),
246
+ # date and time types
247
+ "date": (date,),
248
+ "time": (time,),
249
+ "timetz": (time,),
250
+ "timestamp": (datetime,),
251
+ "timestamptz": (datetime,),
252
+ "interval": (timedelta,),
253
+ # character sequence types
254
+ "bpchar": (str,),
255
+ "varchar": (str,),
256
+ "text": (str,),
257
+ # binary sequence types
258
+ "bytea": (bytes,),
259
+ # unique identifier type
260
+ "uuid": (UUID,),
261
+ # address types
262
+ "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
263
+ "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
264
+ "macaddr": (str,),
265
+ "macaddr8": (str,),
266
+ # JSON type
267
+ "json": (str, RequiredJsonType),
268
+ "jsonb": (str, RequiredJsonType),
269
+ # XML type
270
+ "xml": (str,),
271
+ # geometric types
272
+ "point": (asyncpg.Point,),
273
+ "line": (asyncpg.Line,),
274
+ "lseg": (asyncpg.LineSegment,),
275
+ "box": (asyncpg.Box,),
276
+ "path": (asyncpg.Path,),
277
+ "polygon": (asyncpg.Polygon,),
278
+ "circle": (asyncpg.Circle,),
279
+ # range types
280
+ "int4range": (asyncpg.Range[int],),
281
+ "int4multirange": (list[asyncpg.Range[int]],),
282
+ "int8range": (asyncpg.Range[int],),
283
+ "int8multirange": (list[asyncpg.Range[int]],),
284
+ "numrange": (asyncpg.Range[Decimal],),
285
+ "nummultirange": (list[asyncpg.Range[Decimal]],),
286
+ "tsrange": (asyncpg.Range[datetime],),
287
+ "tsmultirange": (list[asyncpg.Range[datetime]],),
288
+ "tstzrange": (asyncpg.Range[datetime],),
289
+ "tstzmultirange": (list[asyncpg.Range[datetime]],),
290
+ "daterange": (asyncpg.Range[date],),
291
+ "datemultirange": (list[asyncpg.Range[date]],),
292
+ }
293
+
294
+
295
+ def type_to_str(tp: Any) -> str:
296
+ "Emits a friendly name for a type."
297
+
298
+ if isinstance(tp, type):
299
+ return tp.__name__
300
+ else:
301
+ return str(tp)
302
+
303
+
304
+ class _TypeVerifier:
305
+ """
306
+ Verifies if the Python target type can represent the PostgreSQL source type.
307
+ """
308
+
309
+ _connection: Connection
310
+
311
+ def __init__(self, connection: Connection) -> None:
312
+ self._connection = connection
313
+
314
+ async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
315
+ """
316
+ Verifies if a Python enumeration type matches a PostgreSQL enumeration type.
317
+ """
318
+
319
+ for e in data_type:
320
+ if not isinstance(e.value, str):
321
+ raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")
322
+
323
+ py_values = set(e.value for e in data_type)
324
+
325
+ rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
326
+ db_values = set(row[0] for row in rows)
327
+
328
+ db_extra = db_values - py_values
329
+ if db_extra:
330
+ raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in db_extra)})")
331
+
332
+ py_extra = py_values - db_values
333
+ if py_extra:
334
+ raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in py_extra)})")
335
+
336
+ async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
337
+ """
338
+ Verifies if the Python target type can represent the PostgreSQL source type.
339
+ """
340
+
341
+ if pg_type.schema == "pg_catalog": # well-known PostgreSQL types
342
+ if is_enum_type(data_type):
343
+ if pg_type.name not in ["bpchar", "varchar", "text"]:
344
+ raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
345
+ elif pg_type.kind == "array" and get_origin(data_type) is list:
346
+ if not pg_type.name.endswith("[]"):
347
+ raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of array")
348
+
349
+ expected_types = _NAME_TO_TYPE.get(pg_type.name[:-2])
350
+ if expected_types is None:
351
+ raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
352
+ elif get_args(data_type)[0] not in expected_types:
353
+ if len(expected_types) == 1:
354
+ target = f"the Python type `{type_to_str(expected_types[0])}`"
355
+ else:
356
+ target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
357
+ raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` whose elements convert to {target}")
358
+ else:
359
+ expected_types = _NAME_TO_TYPE.get(pg_type.name)
360
+ if expected_types is None:
361
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
362
+ elif data_type not in expected_types:
363
+ if len(expected_types) == 1:
364
+ target = f"the Python type `{type_to_str(expected_types[0])}`"
365
+ else:
366
+ target = f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
367
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` which converts to {target}")
368
+ elif pg_type.kind == "composite": # PostgreSQL composite types
369
+ # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
370
+ pass
371
+ else: # custom PostgreSQL types
372
+ if is_enum_type(data_type):
373
+ await self._check_enum_type(pg_name, pg_type, data_type)
374
+ elif is_standard_type(data_type):
375
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
376
+ else:
377
+ # user-defined types registered with `conn.set_type_codec()`
378
+ pass
379
+
380
+
381
+ @dataclass(frozen=True)
382
+ class _Placeholder:
383
+ ordinal: int
384
+ data_type: TargetType
385
+
386
+ def __repr__(self) -> str:
387
+ return f"{self.__class__.__name__}({self.ordinal}, {self.data_type!r})"
388
+
389
+
390
+ class _ResultsetWrapper:
391
+ "Wraps result-set rows into a tuple or named tuple."
392
+
393
+ init: Callable[..., tuple[Any, ...]] | None
394
+ iterable: Callable[[Iterable[Any]], tuple[Any, ...]]
395
+
396
+ def __init__(self, init: Callable[..., tuple[Any, ...]] | None, iterable: Callable[[Iterable[Any]], tuple[Any, ...]]) -> None:
397
+ """
398
+ Initializes a result-set wrapper.
399
+
400
+ :param init: Initializer function that takes as many arguments as columns in the result-set.
401
+ :param iterable: Initializer function that takes an iterable over columns of a result-set row.
402
+ """
403
+
404
+ self.init = init
405
+ self.iterable = iterable
406
+
407
+
408
+ class _SQLObject:
409
+ """
410
+ Associates input and output type information with a SQL statement.
411
+ """
412
+
413
+ _parameter_data_types: tuple[_Placeholder, ...]
414
+ _resultset_data_types: tuple[TargetType, ...]
415
+ _resultset_column_names: tuple[str, ...] | None
416
+ _resultset_wrapper: _ResultsetWrapper
417
+ _parameter_cast: int
418
+ _parameter_converters: tuple[Callable[[Any], Any], ...]
419
+ _required: int
420
+ _resultset_cast: int
421
+ _resultset_converters: tuple[Callable[[Any], Any], ...]
422
+
423
+ @property
424
+ def parameter_data_types(self) -> tuple[_Placeholder, ...]:
425
+ "Expected inbound parameter data types."
426
+
427
+ return self._parameter_data_types
428
+
429
+ @property
430
+ def resultset_data_types(self) -> tuple[TargetType, ...]:
431
+ "Expected column data types in the result-set."
432
+
433
+ return self._resultset_data_types
434
+
435
+ @property
436
+ def resultset_column_names(self) -> tuple[str, ...] | None:
437
+ "Expected column names in the result-set."
438
+
439
+ return self._resultset_column_names
440
+
441
    def __init__(
        self,
        *,
        args: tuple[TargetType, ...],
        resultset: tuple[TargetType, ...],
        names: tuple[str, ...] | None,
        wrapper: _ResultsetWrapper,
    ) -> None:
        """
        Builds type metadata and conversion tables for a SQL statement.

        :param args: Python types of the query's input parameters, in ordinal order.
        :param resultset: Python types of the result-set columns, in column order.
        :param names: Expected result-set column names (e.g. from a named tuple), or `None` when unchecked.
        :param wrapper: Callables used to assemble each result row into the target row type.
        """

        # pair each 1-based ordinal with the parameter's non-optional (unwrapped) type
        self._parameter_data_types = tuple(_Placeholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(args, start=1))
        self._resultset_data_types = tuple(get_required_type(data_type) for data_type in resultset)
        self._resultset_column_names = names
        self._resultset_wrapper = wrapper

        # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
        parameter_cast = 0
        for index, placeholder in enumerate(self._parameter_data_types):
            parameter_cast |= is_json_type(placeholder.data_type) << index
        self._parameter_cast = parameter_cast

        self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in self._parameter_data_types)

        # create a bit-field of required types (1: required; 0: optional)
        # NOTE: computed from the original (possibly optional) types, not the unwrapped ones
        required = 0
        for index, data_type in enumerate(resultset):
            required |= (not is_optional_type(data_type)) << index
        self._required = required

        # create a bit-field of types that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
        resultset_cast = 0
        for index, data_type in enumerate(self._resultset_data_types):
            resultset_cast |= (is_enum_type(data_type) or is_json_type(data_type) or is_inet_type(data_type)) << index
        self._resultset_cast = resultset_cast

        self._resultset_converters = tuple(get_output_converter_for(data_type) for data_type in self._resultset_data_types)
+ def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
477
+ """
478
+ Raises an error with the index of the first column value that is of a required type but has been assigned a value of `None`.
479
+ """
480
+
481
+ for col_index in range(len(row)):
482
+ if (self._required >> col_index & 1) and row[col_index] is None:
483
+ if row_index is not None:
484
+ row_col_spec = f"row #{row_index} and column #{col_index}"
485
+ else:
486
+ row_col_spec = f"column #{col_index}"
487
+ raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
488
+
489
    def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
        """
        Verifies if declared types match actual value types in a resultset.

        :param rows: Converted result-set rows to validate.
        :raises NoneTypeError: A column declared as required (non-optional) holds `NULL`.
        """

        if not rows:
            return

        required = self._required
        if not required:
            # every column is optional; NULL is acceptable anywhere
            return

        # unrolled loops for common column counts: each row is folded into a bit-field of
        # NULL positions, and a single AND against the required-column bit-field decides
        # whether the (slower) error-reporting path needs to run for that row
        match len(rows[0]):
            case 1:
                for r, row in enumerate(rows):
                    if required & (row[0] is None):
                        self._raise_required_is_none(row, r)
            case 2:
                for r, row in enumerate(rows):
                    a, b = row
                    if required & ((a is None) | (b is None) << 1):
                        self._raise_required_is_none(row, r)
            case 3:
                for r, row in enumerate(rows):
                    a, b, c = row
                    if required & ((a is None) | (b is None) << 1 | (c is None) << 2):
                        self._raise_required_is_none(row, r)
            case 4:
                for r, row in enumerate(rows):
                    a, b, c, d = row
                    if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3):
                        self._raise_required_is_none(row, r)
            case 5:
                for r, row in enumerate(rows):
                    a, b, c, d, e = row
                    if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4):
                        self._raise_required_is_none(row, r)
            case 6:
                for r, row in enumerate(rows):
                    a, b, c, d, e, f = row
                    if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5):
                        self._raise_required_is_none(row, r)
            case 7:
                for r, row in enumerate(rows):
                    a, b, c, d, e, f, g = row
                    if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6):
                        self._raise_required_is_none(row, r)
            case 8:
                for r, row in enumerate(rows):
                    a, b, c, d, e, f, g, h = row
                    if required & ((a is None) | (b is None) << 1 | (c is None) << 2 | (d is None) << 3 | (e is None) << 4 | (f is None) << 5 | (g is None) << 6 | (h is None) << 7):
                        self._raise_required_is_none(row, r)
            case _:
                # generic path for wider rows: scan every row column by column
                for r, row in enumerate(rows):
                    self._raise_required_is_none(row, r)
+ def check_row(self, row: tuple[Any, ...]) -> None:
546
+ """
547
+ Verifies if declared types match actual value types in a single row.
548
+ """
549
+
550
+ if self._required:
551
+ self._raise_required_is_none(row)
552
+
553
+ def check_column(self, column: list[Any]) -> None:
554
+ """
555
+ Verifies if the declared type matches the actual value type of a single-column resultset.
556
+ """
557
+
558
+ if self._required:
559
+ for i, value in enumerate(column):
560
+ if value is None:
561
+ raise NoneTypeError(f"expected: {self._resultset_data_types[0]} in row #{i}; got: NULL")
562
+
563
+ def check_value(self, value: Any) -> None:
564
+ """
565
+ Verifies if the declared type matches the actual value type.
566
+ """
567
+
568
+ if self._required and value is None:
569
+ raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
570
+
571
+ def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
572
+ """
573
+ Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.
574
+ """
575
+
576
+ cast = self._parameter_cast
577
+ if cast:
578
+ converters = self._parameter_converters
579
+ yield from (tuple((converters[i](value) if (value := arg[i]) is not None and cast >> i & 1 else value) for i in range(len(arg))) for arg in arg_lists)
580
+ else:
581
+ yield from arg_lists
582
+
583
+ def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
584
+ """
585
+ Converts Python query arguments to PostgreSQL parameters.
586
+ """
587
+
588
+ cast = self._parameter_cast
589
+ if cast:
590
+ converters = self._parameter_converters
591
+ return tuple((converters[i](value) if (value := arg_list[i]) is not None and cast >> i & 1 else value) for i in range(len(arg_list)))
592
+ else:
593
+ return tuple(value for value in arg_list)
594
+
595
+ def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
596
+ """
597
+ Converts columns in the PostgreSQL result-set to their corresponding Python target types.
598
+
599
+ :param rows: List of rows returned by PostgreSQL.
600
+ :returns: List of tuples with each tuple element having the configured Python target type.
601
+ """
602
+
603
+ if not rows:
604
+ return []
605
+
606
+ if self._resultset_cast:
607
+ return self._convert_rows_cast(rows)
608
+ else:
609
+ return self._convert_rows_pass(rows)
610
+
611
    def _convert_rows_cast(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
        """
        Converts rows when at least one column requires de-serialization.

        Loops are unrolled for common column counts (2-7); other counts take the generic path.
        """

        init_wrapper = self._resultset_wrapper.init
        iterable_wrapper = self._resultset_wrapper.iterable
        cast = self._resultset_cast
        # pre-extract per-column conversion flags so the row loops test plain locals
        cast_a = cast >> 0 & 1
        cast_b = cast >> 1 & 1
        cast_c = cast >> 2 & 1
        cast_d = cast >> 3 & 1
        cast_e = cast >> 4 & 1
        cast_f = cast >> 5 & 1
        cast_g = cast >> 6 & 1
        columns = len(rows[0])
        match columns:
            case 1:
                # single column: a non-zero cast bit-field implies the only column needs conversion
                converter = self._resultset_converters[0]
                if init_wrapper is not None:
                    return [init_wrapper(converter(value) if (value := row[0]) is not None else value) for row in rows]
                else:
                    return [(converter(value) if (value := row[0]) is not None else value,) for row in rows]
            ### START OF AUTO-GENERATED BLOCK FOR convert_cast ###
            case 2:
                conv_a, conv_b = self._resultset_converters
                if init_wrapper is not None:
                    return [
                        init_wrapper(
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                        )
                        for row in rows
                    ]
                else:
                    return [
                        (
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                        )
                        for row in rows
                    ]
            case 3:
                conv_a, conv_b, conv_c = self._resultset_converters
                if init_wrapper is not None:
                    return [
                        init_wrapper(
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                        )
                        for row in rows
                    ]
                else:
                    return [
                        (
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                        )
                        for row in rows
                    ]
            case 4:
                conv_a, conv_b, conv_c, conv_d = self._resultset_converters
                if init_wrapper is not None:
                    return [
                        init_wrapper(
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                        )
                        for row in rows
                    ]
                else:
                    return [
                        (
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                        )
                        for row in rows
                    ]
            case 5:
                conv_a, conv_b, conv_c, conv_d, conv_e = self._resultset_converters
                if init_wrapper is not None:
                    return [
                        init_wrapper(
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                            conv_e(value) if (value := row[4]) is not None and cast_e else value,
                        )
                        for row in rows
                    ]
                else:
                    return [
                        (
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                            conv_e(value) if (value := row[4]) is not None and cast_e else value,
                        )
                        for row in rows
                    ]
            case 6:
                conv_a, conv_b, conv_c, conv_d, conv_e, conv_f = self._resultset_converters
                if init_wrapper is not None:
                    return [
                        init_wrapper(
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                            conv_e(value) if (value := row[4]) is not None and cast_e else value,
                            conv_f(value) if (value := row[5]) is not None and cast_f else value,
                        )
                        for row in rows
                    ]
                else:
                    return [
                        (
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                            conv_e(value) if (value := row[4]) is not None and cast_e else value,
                            conv_f(value) if (value := row[5]) is not None and cast_f else value,
                        )
                        for row in rows
                    ]
            case 7:
                conv_a, conv_b, conv_c, conv_d, conv_e, conv_f, conv_g = self._resultset_converters
                if init_wrapper is not None:
                    return [
                        init_wrapper(
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                            conv_e(value) if (value := row[4]) is not None and cast_e else value,
                            conv_f(value) if (value := row[5]) is not None and cast_f else value,
                            conv_g(value) if (value := row[6]) is not None and cast_g else value,
                        )
                        for row in rows
                    ]
                else:
                    return [
                        (
                            conv_a(value) if (value := row[0]) is not None and cast_a else value,
                            conv_b(value) if (value := row[1]) is not None and cast_b else value,
                            conv_c(value) if (value := row[2]) is not None and cast_c else value,
                            conv_d(value) if (value := row[3]) is not None and cast_d else value,
                            conv_e(value) if (value := row[4]) is not None and cast_e else value,
                            conv_f(value) if (value := row[5]) is not None and cast_f else value,
                            conv_g(value) if (value := row[6]) is not None and cast_g else value,
                        )
                        for row in rows
                    ]

            ### END OF AUTO-GENERATED BLOCK FOR convert_cast ###
            case _:
                # generic path: index-driven comprehension testing the cast bit-field per column
                converters = self._resultset_converters
                return [iterable_wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns)) for row in rows]
    def _convert_rows_pass(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
        """
        Copies rows without per-value conversion, applying only the configured row wrapper.

        Loops are unrolled for common column counts (1-7); other counts take the generic path.
        """

        init_wrapper = self._resultset_wrapper.init
        iterable_wrapper = self._resultset_wrapper.iterable
        match len(rows[0]):
            case 1:
                if init_wrapper is not None:
                    return [init_wrapper(row[0]) for row in rows]
                else:
                    return [(row[0],) for row in rows]
            ### START OF AUTO-GENERATED BLOCK FOR convert_pass ###
            case 2:
                if init_wrapper is not None:
                    return [init_wrapper(row[0], row[1]) for row in rows]
                else:
                    return [(row[0], row[1]) for row in rows]
            case 3:
                if init_wrapper is not None:
                    return [init_wrapper(row[0], row[1], row[2]) for row in rows]
                else:
                    return [(row[0], row[1], row[2]) for row in rows]
            case 4:
                if init_wrapper is not None:
                    return [init_wrapper(row[0], row[1], row[2], row[3]) for row in rows]
                else:
                    return [(row[0], row[1], row[2], row[3]) for row in rows]
            case 5:
                if init_wrapper is not None:
                    return [init_wrapper(row[0], row[1], row[2], row[3], row[4]) for row in rows]
                else:
                    return [(row[0], row[1], row[2], row[3], row[4]) for row in rows]
            case 6:
                if init_wrapper is not None:
                    return [init_wrapper(row[0], row[1], row[2], row[3], row[4], row[5]) for row in rows]
                else:
                    return [(row[0], row[1], row[2], row[3], row[4], row[5]) for row in rows]
            case 7:
                if init_wrapper is not None:
                    return [init_wrapper(row[0], row[1], row[2], row[3], row[4], row[5], row[6]) for row in rows]
                else:
                    return [(row[0], row[1], row[2], row[3], row[4], row[5], row[6]) for row in rows]

            ### END OF AUTO-GENERATED BLOCK FOR convert_pass ###
            case _:
                # generic path: materialize the record's values through the iterable wrapper
                return [iterable_wrapper(row.values()) for row in rows]
+ def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
821
+ """
822
+ Converts columns in the PostgreSQL result-set to their corresponding Python target types.
823
+
824
+ :param row: A single row returned by PostgreSQL.
825
+ :returns: A tuple with each tuple element having the configured Python target type.
826
+ """
827
+
828
+ wrapper = self._resultset_wrapper.iterable
829
+ cast = self._resultset_cast
830
+ if cast:
831
+ converters = self._resultset_converters
832
+ columns = len(row)
833
+ return wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns))
834
+ else:
835
+ return wrapper(row.values())
836
+
837
+ def convert_column(self, rows: list[asyncpg.Record]) -> list[Any]:
838
+ """
839
+ Converts a single column in the PostgreSQL result-set to its corresponding Python target type.
840
+
841
+ :param rows: List of rows returned by PostgreSQL.
842
+ :returns: List of values having the configured Python target type.
843
+ """
844
+
845
+ cast = self._resultset_cast
846
+ if cast:
847
+ converter = self._resultset_converters[0]
848
+ return [(converter(value) if (value := row[0]) is not None else value) for row in rows]
849
+ else:
850
+ return [row[0] for row in rows]
851
+
852
+ def convert_value(self, value: Any) -> Any:
853
+ """
854
+ Converts a single PostgreSQL value to its corresponding Python target type.
855
+
856
+ :param value: A single value returned by PostgreSQL.
857
+ :returns: A converted value having the configured Python target type.
858
+ """
859
+
860
+ return self._resultset_converters[0](value) if value is not None and self._resultset_cast else value
861
+
862
    @abstractmethod
    def query(self) -> str:
        """
        Returns a SQL query string with PostgreSQL ordinal placeholders.

        :returns: SQL text in which parameters appear as `$1`, `$2`, etc.
        """
        ...
+ def __repr__(self) -> str:
870
+ return f"{self.__class__.__name__}({self.query()!r})"
871
+
872
+ def __str__(self) -> str:
873
+ return self.query()
874
+
875
+
876
if sys.version_info >= (3, 14):
    from string.templatelib import Interpolation, Template  # type: ignore[import-not-found]

    # on Python 3.14+, a SQL statement may be a t-string template or a literal string
    SQLExpression: TypeAlias = Template | LiteralString

    class _SQLTemplate(_SQLObject):
        """
        A SQL query specified with the Python t-string syntax.
        """

        # template text fragments surrounding the interpolations
        _strings: tuple[str, ...]
        # parameter placeholders, one per interpolation, in template order
        _placeholders: tuple[_Placeholder, ...]

        def __init__(
            self,
            template: Template,
            *,
            args: tuple[TargetType, ...],
            resultset: tuple[TargetType, ...],
            names: tuple[str, ...] | None,
            wrapper: _ResultsetWrapper,
        ) -> None:
            """
            :param template: A t-string whose interpolations are 1-based parameter ordinals.
            :raises TypeError: An interpolation applies a conversion/format spec, or does not evaluate to an integer.
            :raises IndexError: An interpolation ordinal is outside the declared parameter range.
            """

            super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)

            # interpolations must be bare integer ordinals with no conversion or format spec
            for ip in template.interpolations:
                if ip.conversion is not None:
                    raise TypeError(f"interpolation `{ip.expression}` expected to apply no conversion")
                if ip.format_spec:
                    raise TypeError(f"interpolation `{ip.expression}` expected to apply no format spec")
                if not isinstance(ip.value, int):
                    raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")

            self._strings = template.strings

            if len(self.parameter_data_types) > 0:

                def _to_placeholder(ip: Interpolation) -> _Placeholder:
                    # map a 1-based ordinal to its declared parameter placeholder
                    ordinal = int(ip.value)
                    if not (0 < ordinal <= len(self.parameter_data_types)):
                        raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {len(self.parameter_data_types)}")
                    return self.parameter_data_types[int(ip.value) - 1]

                self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
            else:
                self._placeholders = ()

        def query(self) -> str:
            """
            Returns the SQL text with each interpolation replaced by its `$n` placeholder.
            """

            buf = StringIO()
            # a template with k interpolations has k+1 string fragments; the final fragment has no placeholder
            for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
                buf.write(s)
                buf.write(f"${p.ordinal}")
            buf.write(self._strings[-1])
            return buf.getvalue()

else:
    # t-strings are unavailable before Python 3.14; only literal strings are accepted
    SQLExpression = LiteralString
class _SQLString(_SQLObject):
    """
    A SQL query specified as a plain string (e.g. f-string).
    """

    # SQL text, expected to already contain PostgreSQL ordinal placeholders (`$1`, `$2`, ...)
    _sql: str

    def __init__(
        self,
        sql: str,
        *,
        args: tuple[TargetType, ...],
        resultset: tuple[TargetType, ...],
        names: tuple[str, ...] | None,
        wrapper: _ResultsetWrapper,
    ) -> None:
        """
        :param sql: SQL statement text, stored verbatim.
        :param args: Python types of the query's input parameters, in ordinal order.
        :param resultset: Python types of the result-set columns, in column order.
        :param names: Expected result-set column names, or `None` when unchecked.
        :param wrapper: Callables used to assemble each result row into the target row type.
        """

        super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
        self._sql = sql

    def query(self) -> str:
        """
        Returns the SQL text exactly as supplied by the caller.
        """

        return self._sql
class _SQL(Protocol):
    """
    Represents a SQL statement with associated type information.

    Opaque handle returned by :meth:`SQLFactory.sql` and :meth:`SQLFactory.unsafe_sql`;
    the concrete implementation is :class:`_SQLImpl`.
    """
class _SQLImpl(_SQL):
    """
    Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).
    """

    # typed SQL statement carrying parameter/result-set metadata and converters
    _sql: _SQLObject

    def __init__(self, sql: _SQLObject) -> None:
        self._sql = sql

    def __str__(self) -> str:
        return str(self._sql)

    def __repr__(self) -> str:
        return repr(self._sql)

    async def _prepare(self, connection: Connection) -> PreparedStatement:
        """
        Prepares the statement and verifies it against server metadata.

        Checks parameter count and types, result-set column count and types, and (when
        declared) result-set column names.

        :raises CountMismatchError: Parameter or column count differs from the declaration.
        :raises NameMismatchError: A result-set column name differs from the declared field name.
        """

        stmt = await connection.prepare(self._sql.query())

        verifier = _TypeVerifier(connection)

        input_count = len(self._sql.parameter_data_types)
        parameter_count = len(stmt.get_parameters())
        if parameter_count != input_count:
            raise CountMismatchError(f"expected: PostgreSQL query to take {input_count} parameter(s); got: {parameter_count}")
        for param, placeholder in zip(stmt.get_parameters(), self._sql.parameter_data_types, strict=True):
            await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)

        output_count = len(self._sql.resultset_data_types)
        column_count = len(stmt.get_attributes())
        if column_count != output_count:
            raise CountMismatchError(f"expected: PostgreSQL query to return {output_count} column(s) in result-set; got: {column_count}")
        if self._sql.resultset_column_names is not None:
            for index, attr, name in zip(range(output_count), stmt.get_attributes(), self._sql.resultset_column_names, strict=True):
                if attr.name != name:
                    raise NameMismatchError(f"expected: Python field name `{name}` to match PostgreSQL result-set column name `{attr.name}` for index #{index}")
        for attr, data_type in zip(stmt.get_attributes(), self._sql.resultset_data_types, strict=True):
            await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)

        return stmt

    async def execute(self, connection: asyncpg.Connection, *args: Any) -> None:
        """
        Executes the statement, discarding any result.

        NOTE: unlike the fetch methods, this path does not prepare the statement and thus
        skips the prepare-time parameter/column verification.
        """

        await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))

    async def executemany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> None:
        """
        Executes the prepared statement once per argument tuple, discarding any results.
        """

        stmt = await self._prepare(connection)
        await stmt.executemany(self._sql.convert_arg_lists(args))

    def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
        """
        Converts fetched records to target types and validates required (non-NULL) columns.
        """

        resultset = self._sql.convert_rows(rows)
        self._sql.check_rows(resultset)
        return resultset

    async def fetch(self, connection: asyncpg.Connection, *args: Any) -> list[tuple[Any, ...]]:
        """
        Runs the statement and returns all rows, converted and validated.
        """

        stmt = await self._prepare(connection)
        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
        return self._cast_fetch(rows)

    async def fetchmany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
        """
        Runs the statement once per argument tuple and returns all rows, converted and validated.
        """

        stmt = await self._prepare(connection)
        rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
        return self._cast_fetch(rows)

    async def fetchrow(self, connection: asyncpg.Connection, *args: Any) -> tuple[Any, ...] | None:
        """
        Runs the statement and returns the first row (converted and validated), or `None` when there is no row.
        """

        stmt = await self._prepare(connection)
        row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
        if row is None:
            return None
        resultset = self._sql.convert_row(row)
        self._sql.check_row(resultset)
        return resultset

    async def fetchcol(self, connection: asyncpg.Connection, *args: Any) -> list[Any]:
        """
        Runs the statement and returns the first column of all rows, converted and validated.
        """

        stmt = await self._prepare(connection)
        rows = await stmt.fetch(*self._sql.convert_arg_list(args))
        column = self._sql.convert_column(rows)
        self._sql.check_column(column)
        return column

    async def fetchval(self, connection: asyncpg.Connection, *args: Any) -> Any:
        """
        Runs the statement and returns a single value, converted and validated.
        """

        stmt = await self._prepare(connection)
        value = await stmt.fetchval(*self._sql.convert_arg_list(args))
        result = self._sql.convert_value(value)
        self._sql.check_value(result)
        return result
+ P1 = TypeVar("P1")
1051
+ PX = TypeVarTuple("PX")
1052
+
1053
+ RT = TypeVar("RT")
1054
+ R1 = TypeVar("R1")
1055
+ R2 = TypeVar("R2")
1056
+ RX = TypeVarTuple("RX")
1057
+
1058
+
1059
### START OF AUTO-GENERATED BLOCK FOR Protocol ###
class SQL_P0(Protocol):
    "SQL statement that takes no input parameters; exposes `execute` only."

    @abstractmethod
    async def execute(self, connection: Connection) -> None: ...


class SQL_R1_P0(SQL_P0, Protocol[R1]):
    "SQL statement with no input parameters and a single result column of type `R1`."

    @abstractmethod
    async def fetch(self, connection: Connection) -> list[tuple[R1]]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection) -> tuple[R1] | None: ...
    @abstractmethod
    async def fetchcol(self, connection: Connection) -> list[R1]: ...
    @abstractmethod
    async def fetchval(self, connection: Connection) -> R1: ...


class SQL_RX_P0(SQL_P0, Protocol[RT]):
    "SQL statement with no input parameters and result rows of type `RT`."

    @abstractmethod
    async def fetch(self, connection: Connection) -> list[RT]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection) -> RT | None: ...


class SQL_PX(Protocol[Unpack[PX]]):
    "SQL statement with input parameters of types `PX`; exposes `execute` and `executemany`."

    @abstractmethod
    async def execute(self, connection: Connection, *args: Unpack[PX]) -> None: ...
    @abstractmethod
    async def executemany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> None: ...


class SQL_R1_PX(SQL_PX[Unpack[PX]], Protocol[R1, Unpack[PX]]):
    "SQL statement with input parameters `PX` and a single result column of type `R1`."

    @abstractmethod
    async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[tuple[R1]]: ...
    @abstractmethod
    async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[tuple[R1]]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> tuple[R1] | None: ...
    @abstractmethod
    async def fetchcol(self, connection: Connection, *args: Unpack[PX]) -> list[R1]: ...
    @abstractmethod
    async def fetchval(self, connection: Connection, *args: Unpack[PX]) -> R1: ...


class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
    "SQL statement with input parameters `PX` and result rows of type `RT`."

    @abstractmethod
    async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[RT]: ...
    @abstractmethod
    async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[RT]: ...
    @abstractmethod
    async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None: ...


### END OF AUTO-GENERATED BLOCK FOR Protocol ###
RS = TypeVar("RS", bound=tuple[Any, ...])  # result row tuple type for the `resultset` keyword
class SQLFactory:
    """
    Creates type-safe SQL queries.
    """

    # overloads map the presence of the `arg`/`args` and `result`/`resultset` keywords
    # to precisely-typed SQL protocol objects
    ### START OF AUTO-GENERATED BLOCK FOR sql ###
    @overload
    def sql(self, stmt: SQLExpression) -> SQL_P0: ...
    @overload
    def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
    @overload
    def sql(self, stmt: SQLExpression, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
    @overload
    def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
    @overload
    def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
    @overload
    def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
    @overload
    def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
    @overload
    def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
    @overload
    def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...

    ### END OF AUTO-GENERATED BLOCK FOR sql ###

    def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
        """
        Creates a SQL statement with associated type information.

        :param stmt: SQL statement as a literal string or template.
        :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
        :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
        :param arg: Type signature for a single input parameter (e.g. `int`).
        :param result: Type signature for a single result column (e.g. `UUID`).
        """

        input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)

        obj: _SQLObject
        if sys.version_info >= (3, 14):
            # a t-string template carries its own interpolation placeholders; a plain string is taken verbatim
            match stmt:
                case Template():
                    obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
                case str():
                    obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
        else:
            # before Python 3.14, `stmt` can only be a plain string
            obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)

        return _SQLImpl(obj)

    ### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
    @overload
    def unsafe_sql(self, stmt: str) -> SQL_P0: ...
    @overload
    def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
    @overload
    def unsafe_sql(self, stmt: str, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
    @overload
    def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
    @overload
    def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
    @overload
    def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
    @overload
    def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
    @overload
    def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
    @overload
    def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...

    ### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###

    def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
        """
        Creates a SQL statement with associated type information from a string.

        This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
        a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.

        :param stmt: SQL statement as a string (or f-string).
        :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
        :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
        :param arg: Type signature for a single input parameter (e.g. `int`).
        :param result: Type signature for a single result column (e.g. `UUID`).
        """

        input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
        obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
        return _SQLImpl(obj)
+
1210
def _sql_args_resultset(
    *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None
) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[str, ...] | None, _ResultsetWrapper]:
    """
    Parses an argument/resultset signature into input/output types.

    :param args: Type signature for multiple input parameters, a plain or named tuple type.
    :param resultset: Type signature for multiple resultset columns, a plain or named tuple type.
    :param arg: Type signature for a single input parameter.
    :param result: Type signature for a single result column.
    :returns: A tuple of (input types, output types, expected column names or `None`, resultset wrapper).
    :raises TypeError: Mutually exclusive keywords are combined, or a signature is not a tuple type.
    """

    if args is not None and arg is not None:
        raise TypeError("expected: either `args` or `arg`; got: both")
    if resultset is not None and result is not None:
        raise TypeError("expected: either `resultset` or `result`; got: both")

    input_data_types: tuple[Any, ...]
    if args is not None:
        if hasattr(args, "_asdict") and hasattr(args, "_fields"):
            # named tuple: field annotations supply the parameter types in declaration order
            input_data_types = tuple(args.__annotations__.values())
        elif get_origin(args) is tuple:
            # regular tuple: type arguments supply the parameter types
            input_data_types = get_args(args)
        else:
            raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {args}")
    elif arg is not None:
        input_data_types = (arg,)
    else:
        input_data_types = ()

    names: tuple[str, ...] | None
    if resultset is not None:
        if hasattr(resultset, "_asdict") and hasattr(resultset, "_fields") and hasattr(resultset, "_make"):
            # named tuple: columns are checked by name and rows are materialized via `_make`
            output_data_types = tuple(resultset.__annotations__.values())
            names = tuple(resultset._fields)
            wrapper = _ResultsetWrapper(resultset, resultset._make)
        elif get_origin(resultset) is tuple:
            # regular tuple: columns are positional and rows stay plain tuples
            output_data_types = get_args(resultset)
            names = None
            wrapper = _ResultsetWrapper(None, tuple)
        else:
            raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {resultset}")
    else:
        output_data_types = (result,) if result is not None else ()
        names = None
        wrapper = _ResultsetWrapper(None, tuple)

    return input_data_types, output_data_types, names, wrapper
# default factory backing the module-level convenience functions below
FACTORY: SQLFactory = SQLFactory()

# module-level aliases so callers can write `sql(...)` / `unsafe_sql(...)` directly
sql = FACTORY.sql
unsafe_sql = FACTORY.unsafe_sql