asyncpg-typed 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asyncpg_typed/__init__.py CHANGED
@@ -4,9 +4,9 @@ Type-safe queries for asyncpg.
4
4
  :see: https://github.com/hunyadi/asyncpg_typed
5
5
  """
6
6
 
7
- __version__ = "0.1.2"
7
+ __version__ = "0.1.4"
8
8
  __author__ = "Levente Hunyadi"
9
- __copyright__ = "Copyright 2025, Levente Hunyadi"
9
+ __copyright__ = "Copyright 2025-2026, Levente Hunyadi"
10
10
  __license__ = "MIT"
11
11
  __maintainer__ = "Levente Hunyadi"
12
12
  __status__ = "Production"
@@ -16,12 +16,14 @@ import sys
16
16
  import typing
17
17
  from abc import abstractmethod
18
18
  from collections.abc import Callable, Iterable, Sequence
19
+ from dataclasses import dataclass
19
20
  from datetime import date, datetime, time, timedelta
20
21
  from decimal import Decimal
21
22
  from functools import reduce
22
23
  from io import StringIO
24
+ from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
23
25
  from types import UnionType
24
- from typing import Any, Protocol, TypeAlias, TypeVar, Union, get_args, get_origin, overload
26
+ from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
25
27
  from uuid import UUID
26
28
 
27
29
  import asyncpg
@@ -38,9 +40,32 @@ RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["Json
38
40
 
39
41
# Python type declared for a query parameter or a result-set column: a class, or a union type such as `A | B`
TargetType: TypeAlias = type[Any] | UnionType

# connection objects the library can issue queries on: a direct connection, or a connection proxy handed out by a pool
Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
44
+
45
+
46
class CountMismatchError(TypeError):
    """
    Raised when a prepared statement takes or returns a different number of parameters or columns than declared in Python.
    """


class TypeMismatchError(TypeError):
    """
    Raised when a prepared statement takes or returns a PostgreSQL type incompatible with the declared Python type.
    """


class NameMismatchError(TypeError):
    """
    Raised when the name of a result-set column differs from what is declared in Python.
    """


class EnumMismatchError(TypeError):
    """
    Raised when a prepared statement takes or returns a PostgreSQL enum type whose permitted set of values
    differs from what is declared in Python.
    """


class NoneTypeError(TypeError):
    """
    Raised when a column marked as required contains a `NULL` value.
    """
64
+
65
+
41
66
  if sys.version_info >= (3, 11):
42
67
 
43
- def is_enum_type(typ: object) -> bool:
68
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
44
69
  """
45
70
  `True` if the specified type is an enumeration type.
46
71
  """
@@ -49,7 +74,7 @@ if sys.version_info >= (3, 11):
49
74
 
50
75
  else:
51
76
 
52
- def is_enum_type(typ: object) -> bool:
77
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
53
78
  """
54
79
  `True` if the specified type is an enumeration type.
55
80
  """
@@ -91,6 +116,14 @@ def is_json_type(tp: Any) -> bool:
91
116
  return tp in [JsonType, RequiredJsonType]
92
117
 
93
118
 
119
+ def is_inet_type(tp: Any) -> bool:
120
+ """
121
+ `True` if the type represents an IP address or network.
122
+ """
123
+
124
+ return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
125
+
126
+
94
127
  def make_union_type(tpl: list[Any]) -> UnionType:
95
128
  """
96
129
  Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
@@ -119,25 +152,54 @@ def get_required_type(tp: Any) -> Any:
119
152
  return type(None)
120
153
 
121
154
 
122
def _standard_json_decoder() -> Callable[[str], JsonType]:
    """
    Builds a JSON decoder backed by the standard library.

    :returns: The `decode` method of a default `json.JSONDecoder`, which parses a JSON string into a Python value.
    """

    import json

    return json.JSONDecoder().decode


def _json_decoder() -> Callable[[str], JsonType]:
    """
    Selects the function used to de-serialize values of `json` and `jsonb` columns.

    `orjson.loads` is used when the optional `orjson` package can be imported; otherwise the standard library decoder is used.
    Static type checkers only see the standard library branch.
    """

    if typing.TYPE_CHECKING:
        return _standard_json_decoder()

    try:
        import orjson
    except ModuleNotFoundError:
        return _standard_json_decoder()
    return orjson.loads


# selected once, when the module is imported
JSON_DECODER = _json_decoder()
138
175
 
139
176
 
140
- def get_converter_for(tp: Any) -> Callable[[Any], Any]:
177
def _standard_json_encoder() -> Callable[[JsonType], str]:
    """
    Builds a compact JSON encoder backed by the standard library.

    Non-ASCII characters are emitted as-is, no whitespace is inserted between tokens, and `NaN` or infinite floats are rejected
    (`allow_nan=False`) because they are not valid JSON.
    """

    import json

    encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
    return encoder.encode


def _json_encoder() -> Callable[[JsonType], str]:
    """
    Selects the function used to serialize Python values bound to `json` and `jsonb` parameters.

    When the optional `orjson` package can be imported, its `bytes` output is decoded into `str`;
    otherwise the standard library encoder is used. Static type checkers only see the standard library branch.
    """

    if typing.TYPE_CHECKING:
        return _standard_json_encoder()

    try:
        import orjson
    except ModuleNotFoundError:
        return _standard_json_encoder()

    # NOTE(review): the two back-ends may disagree on non-finite floats (the standard library encoder is configured to raise) — confirm intended
    def _orjson_encode(value: JsonType) -> str:
        return orjson.dumps(value).decode()

    return _orjson_encode


# selected once, when the module is imported
JSON_ENCODER = _json_encoder()
200
+
201
+
202
+ def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
141
203
  """
142
204
  Returns a callable that takes a wire type and returns a target type.
143
205
 
@@ -147,101 +209,268 @@ def get_converter_for(tp: Any) -> Callable[[Any], Any]:
147
209
 
148
210
  if is_json_type(tp):
149
211
  # asyncpg returns fields of type `json` and `jsonb` as `str`, which must be de-serialized
150
- return _json_converter
212
+ return JSON_DECODER
151
213
  else:
152
214
  # target data types that require conversion must have a single-argument `__init__` that takes an object of the source type
153
215
  return tp
154
216
 
155
217
 
218
def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
    """
    Returns a callable that maps a Python argument value to the value passed on to asyncpg.

    For JSON types, the callable is the JSON encoder, because asyncpg expects `json` and `jsonb` parameters as `str`.
    For any other type, the type itself is returned as the callable.
    """

    return JSON_ENCODER if is_json_type(tp) else tp
232
+
233
+
156
234
  # maps PostgreSQL internal type names to compatible Python types
157
# keys are PostgreSQL internal type names (`pg_type.typname`) in schema `pg_catalog`;
# values are the Python types a declared parameter or column of that type may use
_NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
    # boolean type
    "bool": (bool,),
    # numeric types
    "int2": (int,),
    "int4": (int,),
    "int8": (int,),
    "float4": (float,),
    "float8": (float,),
    "numeric": (Decimal,),
    # object identifier type (e.g. columns of system catalogs), which asyncpg decodes as `int`
    "oid": (int,),
    # date and time types
    "date": (date,),
    "time": (time,),
    "timetz": (time,),
    "timestamp": (datetime,),
    "timestamptz": (datetime,),
    "interval": (timedelta,),
    # character sequence types
    "bpchar": (str,),
    "varchar": (str,),
    "text": (str,),
    # identifier type used for object names in system catalogs (e.g. `pg_type.typname`), which asyncpg decodes as `str`
    "name": (str,),
    # binary sequence types
    "bytea": (bytes,),
    # unique identifier type
    "uuid": (UUID,),
    # address types
    "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
    "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
    "macaddr": (str,),
    "macaddr8": (str,),
    # JSON type
    "json": (str, RequiredJsonType),
    "jsonb": (str, RequiredJsonType),
    # XML type
    "xml": (str,),
    # geometric types
    "point": (asyncpg.Point,),
    "line": (asyncpg.Line,),
    "lseg": (asyncpg.LineSegment,),
    "box": (asyncpg.Box,),
    "path": (asyncpg.Path,),
    "polygon": (asyncpg.Polygon,),
    "circle": (asyncpg.Circle,),
    # range types
    "int4range": (asyncpg.Range[int],),
    "int4multirange": (list[asyncpg.Range[int]],),
    "int8range": (asyncpg.Range[int],),
    "int8multirange": (list[asyncpg.Range[int]],),
    "numrange": (asyncpg.Range[Decimal],),
    "nummultirange": (list[asyncpg.Range[Decimal]],),
    "tsrange": (asyncpg.Range[datetime],),
    "tsmultirange": (list[asyncpg.Range[datetime]],),
    "tstzrange": (asyncpg.Range[datetime],),
    "tstzmultirange": (list[asyncpg.Range[datetime]],),
    "daterange": (asyncpg.Range[date],),
    "datemultirange": (list[asyncpg.Range[date]],),
}
180
292
 
181
293
 
182
- def check_data_type(schema: str, name: str, data_type: TargetType) -> bool:
294
+ def type_to_str(tp: Any) -> str:
295
+ "Emits a friendly name for a type."
296
+
297
+ if isinstance(tp, type):
298
+ return tp.__name__
299
+ else:
300
+ return str(tp)
301
+
302
+
303
class _TypeVerifier:
    """
    Verifies if the Python target type can represent the PostgreSQL source type.

    Custom PostgreSQL enum types are checked against the database: their labels are fetched over the connection
    passed to the initializer and compared with the values of the declared Python enumeration.
    """

    _connection: Connection

    def __init__(self, connection: Connection) -> None:
        self._connection = connection

    @staticmethod
    def _describe_targets(expected_types: tuple[Any, ...]) -> str:
        "Lists the Python type(s) a PostgreSQL type converts to, for use in error messages."

        if len(expected_types) == 1:
            return f"the Python type `{type_to_str(expected_types[0])}`"
        else:
            return f"one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"

    async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
        """
        Verifies if a Python enumeration type matches a PostgreSQL enumeration type.

        :param pg_name: Description of the parameter or column being checked, used in error messages.
        :param pg_type: PostgreSQL enum type of the parameter or column.
        :param data_type: Python enumeration type declared for the parameter or column.
        :raises TypeMismatchError: A member of the Python enumeration has a value that is not a `str`.
        :raises EnumMismatchError: The labels of the PostgreSQL enum type differ from the values of the Python enumeration.
        """

        for e in data_type:
            if not isinstance(e.value, str):
                raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")

        py_values = {e.value for e in data_type}

        rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
        db_values = {row[0] for row in rows}

        # differences are sorted so that error messages do not depend on (hash-randomized) set iteration order
        db_extra = sorted(db_values - py_values)
        if db_extra:
            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in db_extra)}")

        py_extra = sorted(py_values - db_values)
        if py_extra:
            raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in py_extra)}")

    async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
        """
        Verifies if the Python target type can represent the PostgreSQL source type.

        :param pg_name: Description of the parameter or column being checked (e.g. `parameter $1`), used in error messages.
        :param pg_type: PostgreSQL type reported for the parameter or column.
        :param data_type: Python type declared for the parameter or column.
        :raises TypeMismatchError: The Python type cannot represent the PostgreSQL type.
        :raises EnumMismatchError: A Python enumeration does not match the labels of a custom PostgreSQL enum type.
        """

        if pg_type.schema == "pg_catalog":  # well-known PostgreSQL types
            if is_enum_type(data_type):
                # Python enumerations are only accepted for PostgreSQL character types here
                if pg_type.name not in ["bpchar", "varchar", "text"]:
                    raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
            elif pg_type.kind == "array" and get_origin(data_type) is list:
                if not pg_type.name.endswith("[]"):
                    raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of array")

                # element type name is the array type name without the trailing `[]`
                expected_types = _NAME_TO_TYPE.get(pg_type.name[:-2])
                if expected_types is None:
                    raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")

                # a list type without an element type (e.g. `typing.List`) matches no element type, and is reported as incompatible
                element_types = get_args(data_type)
                element_type = element_types[0] if element_types else None
                if element_type not in expected_types:
                    raise TypeMismatchError(f"expected: Python list type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` whose elements convert to {self._describe_targets(expected_types)}")
            else:
                expected_types = _NAME_TO_TYPE.get(pg_type.name)
                if expected_types is None:
                    raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
                elif data_type not in expected_types:
                    raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` which converts to {self._describe_targets(expected_types)}")
        elif pg_type.kind == "composite":  # PostgreSQL composite types
            # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
            pass
        else:  # custom PostgreSQL types
            if is_enum_type(data_type):
                await self._check_enum_type(pg_name, pg_type, data_type)
            elif is_standard_type(data_type):
                raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
            else:
                # user-defined types registered with `conn.set_type_codec()`
                pass
+
379
+
380
@dataclass(frozen=True)
class _Placeholder:
    """
    An immutable query parameter: its ordinal (as rendered in `$1`, `$2`, ...) and the Python type bound to it.
    """

    ordinal: int
    data_type: TargetType

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.ordinal}, {self.data_type!r})"
211
387
 
212
388
 
389
+ class _ResultsetWrapper:
390
+ "Wraps result-set rows into a tuple or named tuple."
391
+
392
+ init: Callable[..., tuple[Any, ...]] | None
393
+ iterable: Callable[[Iterable[Any]], tuple[Any, ...]]
394
+
395
+ def __init__(self, init: Callable[..., tuple[Any, ...]] | None, iterable: Callable[[Iterable[Any]], tuple[Any, ...]]) -> None:
396
+ """
397
+ Initializes a result-set wrapper.
398
+
399
+ :param init: Initializer function that takes as many arguments as columns in the result-set.
400
+ :param iterable: Initializer function that takes an iterable over columns of a result-set row.
401
+ """
402
+
403
+ self.init = init
404
+ self.iterable = iterable
405
+
406
+
213
407
  class _SQLObject:
214
408
  """
215
409
  Associates input and output type information with a SQL statement.
216
410
  """
217
411
 
218
- parameter_data_types: tuple[_SQLPlaceholder, ...]
219
- resultset_data_types: tuple[TargetType, ...]
220
- required: int
221
- cast: int
222
- converters: tuple[Callable[[Any], Any], ...]
412
+ _parameter_data_types: tuple[_Placeholder, ...]
413
+ _resultset_data_types: tuple[TargetType, ...]
414
+ _resultset_column_names: tuple[str, ...] | None
415
+ _resultset_wrapper: _ResultsetWrapper
416
+ _parameter_cast: int
417
+ _parameter_converters: tuple[Callable[[Any], Any], ...]
418
+ _required: int
419
+ _resultset_cast: int
420
+ _resultset_converters: tuple[Callable[[Any], Any], ...]
421
+
422
+ @property
423
+ def parameter_data_types(self) -> tuple[_Placeholder, ...]:
424
+ "Expected inbound parameter data types."
425
+
426
+ return self._parameter_data_types
427
+
428
+ @property
429
+ def resultset_data_types(self) -> tuple[TargetType, ...]:
430
+ "Expected column data types in the result-set."
431
+
432
+ return self._resultset_data_types
433
+
434
+ @property
435
+ def resultset_column_names(self) -> tuple[str, ...] | None:
436
+ "Expected column names in the result-set."
437
+
438
+ return self._resultset_column_names
223
439
 
224
440
  def __init__(
225
441
  self,
226
- input_data_types: tuple[TargetType, ...],
227
- output_data_types: tuple[TargetType, ...],
442
+ *,
443
+ args: tuple[TargetType, ...],
444
+ resultset: tuple[TargetType, ...],
445
+ names: tuple[str, ...] | None,
446
+ wrapper: _ResultsetWrapper,
228
447
  ) -> None:
229
- self.parameter_data_types = tuple(_SQLPlaceholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(input_data_types, start=1))
230
- self.resultset_data_types = tuple(get_required_type(data_type) for data_type in output_data_types)
448
+ self._parameter_data_types = tuple(_Placeholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(args, start=1))
449
+ self._resultset_data_types = tuple(get_required_type(data_type) for data_type in resultset)
450
+ self._resultset_column_names = names
451
+ self._resultset_wrapper = wrapper
452
+
453
+ # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
454
+ parameter_cast = 0
455
+ for index, placeholder in enumerate(self._parameter_data_types):
456
+ parameter_cast |= is_json_type(placeholder.data_type) << index
457
+ self._parameter_cast = parameter_cast
458
+
459
+ self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in self._parameter_data_types)
231
460
 
232
461
  # create a bit-field of required types (1: required; 0: optional)
233
462
  required = 0
234
- for index, data_type in enumerate(output_data_types):
463
+ for index, data_type in enumerate(resultset):
235
464
  required |= (not is_optional_type(data_type)) << index
236
- self.required = required
465
+ self._required = required
237
466
 
238
- # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
239
- cast = 0
240
- for index, data_type in enumerate(self.resultset_data_types):
241
- cast |= (is_enum_type(data_type) or is_json_type(data_type)) << index
242
- self.cast = cast
467
+ # create a bit-field of types that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
468
+ resultset_cast = 0
469
+ for index, data_type in enumerate(self._resultset_data_types):
470
+ resultset_cast |= (is_enum_type(data_type) or is_json_type(data_type) or is_inet_type(data_type)) << index
471
+ self._resultset_cast = resultset_cast
243
472
 
244
- self.converters = tuple(get_converter_for(data_type) for data_type in self.resultset_data_types)
473
+ self._resultset_converters = tuple(get_output_converter_for(data_type) for data_type in self._resultset_data_types)
245
474
 
246
475
  def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
247
476
  """
@@ -249,12 +478,12 @@ class _SQLObject:
249
478
  """
250
479
 
251
480
  for col_index in range(len(row)):
252
- if (self.required >> col_index & 1) and row[col_index] is None:
481
+ if (self._required >> col_index & 1) and row[col_index] is None:
253
482
  if row_index is not None:
254
483
  row_col_spec = f"row #{row_index} and column #{col_index}"
255
484
  else:
256
485
  row_col_spec = f"column #{col_index}"
257
- raise TypeError(f"expected: {self.resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
486
+ raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
258
487
 
259
488
  def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
260
489
  """
@@ -264,7 +493,7 @@ class _SQLObject:
264
493
  if not rows:
265
494
  return
266
495
 
267
- required = self.required
496
+ required = self._required
268
497
  if not required:
269
498
  return
270
499
 
@@ -317,7 +546,7 @@ class _SQLObject:
317
546
  Verifies if declared types match actual value types in a single row.
318
547
  """
319
548
 
320
- required = self.required
549
+ required = self._required
321
550
  if not required:
322
551
  return
323
552
 
@@ -356,13 +585,187 @@ class _SQLObject:
356
585
  case _:
357
586
  self._raise_required_is_none(row)
358
587
 
588
+ def check_column(self, column: list[Any]) -> None:
589
+ """
590
+ Verifies if the declared type matches the actual value type of a single-column resultset.
591
+ """
592
+
593
+ if self._required:
594
+ for i, value in enumerate(column):
595
+ if value is None:
596
+ raise NoneTypeError(f"expected: {self._resultset_data_types[0]} in row #{i}; got: NULL")
597
+
359
598
  def check_value(self, value: Any) -> None:
360
599
  """
361
600
  Verifies if the declared type matches the actual value type.
362
601
  """
363
602
 
364
- if self.required and value is None:
365
- raise TypeError(f"expected: {self.resultset_data_types[0]}; got: NULL")
603
+ if self._required and value is None:
604
+ raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
605
+
606
+ def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
607
+ """
608
+ Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.
609
+ """
610
+
611
+ cast = self._parameter_cast
612
+ if cast:
613
+ converters = self._parameter_converters
614
+ yield from (tuple((converters[i](value) if (value := arg[i]) is not None and cast >> i & 1 else value) for i in range(len(arg))) for arg in arg_lists)
615
+ else:
616
+ yield from arg_lists
617
+
618
+ def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
619
+ """
620
+ Converts Python query arguments to PostgreSQL parameters.
621
+ """
622
+
623
+ cast = self._parameter_cast
624
+ if cast:
625
+ converters = self._parameter_converters
626
+ return tuple((converters[i](value) if (value := arg_list[i]) is not None and cast >> i & 1 else value) for i in range(len(arg_list)))
627
+ else:
628
+ return tuple(value for value in arg_list)
629
+
630
+ def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
631
+ """
632
+ Converts columns in the PostgreSQL result-set to their corresponding Python target types.
633
+
634
+ :param rows: List of rows returned by PostgreSQL.
635
+ :returns: List of tuples with each tuple element having the configured Python target type.
636
+ """
637
+
638
+ if not rows:
639
+ return []
640
+
641
+ init_wrapper = self._resultset_wrapper.init
642
+ iterable_wrapper = self._resultset_wrapper.iterable
643
+ cast = self._resultset_cast
644
+ if not cast:
645
+ return [iterable_wrapper(row.values()) for row in rows]
646
+
647
+ columns = len(rows[0])
648
+ match columns:
649
+ case 1:
650
+ converter = self._resultset_converters[0]
651
+ if init_wrapper is not None:
652
+ return [init_wrapper(converter(value) if (value := row[0]) is not None else value) for row in rows]
653
+ else:
654
+ return [(converter(value) if (value := row[0]) is not None else value,) for row in rows]
655
+ case 2:
656
+ conv_a, conv_b = self._resultset_converters
657
+ cast_a = cast >> 0 & 1
658
+ cast_b = cast >> 1 & 1
659
+ if init_wrapper is not None:
660
+ return [
661
+ init_wrapper(
662
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
663
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
664
+ )
665
+ for row in rows
666
+ ]
667
+ else:
668
+ return [
669
+ (
670
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
671
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
672
+ )
673
+ for row in rows
674
+ ]
675
+ case 3:
676
+ conv_a, conv_b, conv_c = self._resultset_converters
677
+ cast_a = cast >> 0 & 1
678
+ cast_b = cast >> 1 & 1
679
+ cast_c = cast >> 2 & 1
680
+ if init_wrapper is not None:
681
+ return [
682
+ init_wrapper(
683
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
684
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
685
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
686
+ )
687
+ for row in rows
688
+ ]
689
+ else:
690
+ return [
691
+ (
692
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
693
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
694
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
695
+ )
696
+ for row in rows
697
+ ]
698
+ case 4:
699
+ conv_a, conv_b, conv_c, conv_d = self._resultset_converters
700
+ cast_a = cast >> 0 & 1
701
+ cast_b = cast >> 1 & 1
702
+ cast_c = cast >> 2 & 1
703
+ cast_d = cast >> 3 & 1
704
+ if init_wrapper is not None:
705
+ return [
706
+ init_wrapper(
707
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
708
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
709
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
710
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
711
+ )
712
+ for row in rows
713
+ ]
714
+ else:
715
+ return [
716
+ (
717
+ conv_a(value) if (value := row[0]) is not None and cast_a else value,
718
+ conv_b(value) if (value := row[1]) is not None and cast_b else value,
719
+ conv_c(value) if (value := row[2]) is not None and cast_c else value,
720
+ conv_d(value) if (value := row[3]) is not None and cast_d else value,
721
+ )
722
+ for row in rows
723
+ ]
724
+ case _:
725
+ converters = self._resultset_converters
726
+ return [iterable_wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns)) for row in rows]
727
+
728
+ def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
729
+ """
730
+ Converts columns in the PostgreSQL result-set to their corresponding Python target types.
731
+
732
+ :param row: A single row returned by PostgreSQL.
733
+ :returns: A tuple with each tuple element having the configured Python target type.
734
+ """
735
+
736
+ wrapper = self._resultset_wrapper.iterable
737
+ cast = self._resultset_cast
738
+ if cast:
739
+ converters = self._resultset_converters
740
+ columns = len(row)
741
+ return wrapper((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(columns))
742
+ else:
743
+ return wrapper(row.values())
744
+
745
+ def convert_column(self, rows: list[asyncpg.Record]) -> list[Any]:
746
+ """
747
+ Converts a single column in the PostgreSQL result-set to its corresponding Python target type.
748
+
749
+ :param rows: List of rows returned by PostgreSQL.
750
+ :returns: List of values having the configured Python target type.
751
+ """
752
+
753
+ cast = self._resultset_cast
754
+ if cast:
755
+ converter = self._resultset_converters[0]
756
+ return [(converter(value) if (value := row[0]) is not None else value) for row in rows]
757
+ else:
758
+ return [row[0] for row in rows]
759
+
760
+ def convert_value(self, value: Any) -> Any:
761
+ """
762
+ Converts a single PostgreSQL value to its corresponding Python target type.
763
+
764
+ :param value: A single value returned by PostgreSQL.
765
+ :returns: A converted value having the configured Python target type.
766
+ """
767
+
768
+ return self._resultset_converters[0](value) if value is not None and self._resultset_cast else value
366
769
 
367
770
  @abstractmethod
368
771
  def query(self) -> str:
@@ -388,8 +791,8 @@ if sys.version_info >= (3, 14):
388
791
  A SQL query specified with the Python t-string syntax.
389
792
  """
390
793
 
391
- strings: tuple[str, ...]
392
- placeholders: tuple[_SQLPlaceholder, ...]
794
+ _strings: tuple[str, ...]
795
+ _placeholders: tuple[_Placeholder, ...]
393
796
 
394
797
  def __init__(
395
798
  self,
@@ -397,8 +800,10 @@ if sys.version_info >= (3, 14):
397
800
  *,
398
801
  args: tuple[TargetType, ...],
399
802
  resultset: tuple[TargetType, ...],
803
+ names: tuple[str, ...] | None,
804
+ wrapper: _ResultsetWrapper,
400
805
  ) -> None:
401
- super().__init__(args, resultset)
806
+ super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
402
807
 
403
808
  for ip in template.interpolations:
404
809
  if ip.conversion is not None:
@@ -408,26 +813,26 @@ if sys.version_info >= (3, 14):
408
813
  if not isinstance(ip.value, int):
409
814
  raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")
410
815
 
411
- self.strings = template.strings
816
+ self._strings = template.strings
412
817
 
413
818
  if len(self.parameter_data_types) > 0:
414
819
 
415
- def _to_placeholder(ip: Interpolation) -> _SQLPlaceholder:
820
+ def _to_placeholder(ip: Interpolation) -> _Placeholder:
416
821
  ordinal = int(ip.value)
417
822
  if not (0 < ordinal <= len(self.parameter_data_types)):
418
823
  raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {len(self.parameter_data_types)}")
419
824
  return self.parameter_data_types[int(ip.value) - 1]
420
825
 
421
- self.placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
826
+ self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
422
827
  else:
423
- self.placeholders = ()
828
+ self._placeholders = ()
424
829
 
425
830
  def query(self) -> str:
426
831
  buf = StringIO()
427
- for s, p in zip(self.strings[:-1], self.placeholders, strict=True):
832
+ for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
428
833
  buf.write(s)
429
834
  buf.write(f"${p.ordinal}")
430
- buf.write(self.strings[-1])
835
+ buf.write(self._strings[-1])
431
836
  return buf.getvalue()
432
837
 
433
838
  else:
@@ -439,7 +844,7 @@ class _SQLString(_SQLObject):
439
844
  A SQL query specified as a plain string (e.g. f-string).
440
845
  """
441
846
 
442
- sql: str
847
+ _sql: str
443
848
 
444
849
  def __init__(
445
850
  self,
@@ -447,12 +852,14 @@ class _SQLString(_SQLObject):
447
852
  *,
448
853
  args: tuple[TargetType, ...],
449
854
  resultset: tuple[TargetType, ...],
855
+ names: tuple[str, ...] | None,
856
+ wrapper: _ResultsetWrapper,
450
857
  ) -> None:
451
- super().__init__(args, resultset)
452
- self.sql = sql
858
+ super().__init__(args=args, resultset=resultset, names=names, wrapper=wrapper)
859
+ self._sql = sql
453
860
 
454
861
  def query(self) -> str:
455
- return self.sql
862
+ return self._sql
456
863
 
457
864
 
458
865
  class _SQL(Protocol):
@@ -461,80 +868,90 @@ class _SQL(Protocol):
461
868
  """
462
869
 
463
870
 
464
- Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
465
-
466
-
467
871
  class _SQLImpl(_SQL):
468
872
  """
469
873
  Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).
470
874
  """
471
875
 
472
- sql: _SQLObject
876
+ _sql: _SQLObject
473
877
 
474
878
  def __init__(self, sql: _SQLObject) -> None:
475
- self.sql = sql
879
+ self._sql = sql
476
880
 
477
881
  def __str__(self) -> str:
478
- return str(self.sql)
882
+ return str(self._sql)
479
883
 
480
884
  def __repr__(self) -> str:
481
- return repr(self.sql)
885
+ return repr(self._sql)
482
886
 
483
887
  async def _prepare(self, connection: Connection) -> PreparedStatement:
484
- stmt = await connection.prepare(self.sql.query())
485
-
486
- for attr, data_type in zip(stmt.get_attributes(), self.sql.resultset_data_types, strict=True):
487
- if not check_data_type(attr.type.schema, attr.type.name, data_type):
488
- raise TypeError(f"expected: {data_type} in column `{attr.name}`; got: `{attr.type.kind}` of `{attr.type.name}`")
888
+ stmt = await connection.prepare(self._sql.query())
889
+
890
+ verifier = _TypeVerifier(connection)
891
+
892
+ input_count = len(self._sql.parameter_data_types)
893
+ parameter_count = len(stmt.get_parameters())
894
+ if parameter_count != input_count:
895
+ raise CountMismatchError(f"expected: PostgreSQL query to take {input_count} parameter(s); got: {parameter_count}")
896
+ for param, placeholder in zip(stmt.get_parameters(), self._sql.parameter_data_types, strict=True):
897
+ await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)
898
+
899
+ output_count = len(self._sql.resultset_data_types)
900
+ column_count = len(stmt.get_attributes())
901
+ if column_count != output_count:
902
+ raise CountMismatchError(f"expected: PostgreSQL query to return {output_count} column(s) in result-set; got: {column_count}")
903
+ if self._sql.resultset_column_names is not None:
904
+ for index, attr, name in zip(range(output_count), stmt.get_attributes(), self._sql.resultset_column_names, strict=True):
905
+ if attr.name != name:
906
+ raise NameMismatchError(f"expected: Python field name `{name}` to match PostgreSQL result-set column name `{attr.name}` for index #{index}")
907
+ for attr, data_type in zip(stmt.get_attributes(), self._sql.resultset_data_types, strict=True):
908
+ await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)
489
909
 
490
910
  return stmt
491
911
 
492
912
  async def execute(self, connection: asyncpg.Connection, *args: Any) -> None:
493
- await connection.execute(self.sql.query(), *args)
913
+ await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))
494
914
 
495
915
  async def executemany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> None:
496
916
  stmt = await self._prepare(connection)
497
- await stmt.executemany(args)
917
+ await stmt.executemany(self._sql.convert_arg_lists(args))
498
918
 
499
919
  def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
500
- cast = self.sql.cast
501
- if cast:
502
- converters = self.sql.converters
503
- resultset = [tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row))) for row in rows]
504
- else:
505
- resultset = [tuple(value for value in row) for row in rows]
506
- self.sql.check_rows(resultset)
920
+ resultset = self._sql.convert_rows(rows)
921
+ self._sql.check_rows(resultset)
507
922
  return resultset
508
923
 
509
924
  async def fetch(self, connection: asyncpg.Connection, *args: Any) -> list[tuple[Any, ...]]:
510
925
  stmt = await self._prepare(connection)
511
- rows = await stmt.fetch(*args)
926
+ rows = await stmt.fetch(*self._sql.convert_arg_list(args))
512
927
  return self._cast_fetch(rows)
513
928
 
514
929
  async def fetchmany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
515
930
  stmt = await self._prepare(connection)
516
- rows = await stmt.fetchmany(args)
931
+ rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
517
932
  return self._cast_fetch(rows)
518
933
 
519
934
  async def fetchrow(self, connection: asyncpg.Connection, *args: Any) -> tuple[Any, ...] | None:
520
935
  stmt = await self._prepare(connection)
521
- row = await stmt.fetchrow(*args)
936
+ row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
522
937
  if row is None:
523
938
  return None
524
- cast = self.sql.cast
525
- if cast:
526
- converters = self.sql.converters
527
- resultset = tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row)))
528
- else:
529
- resultset = tuple(value for value in row)
530
- self.sql.check_row(resultset)
939
+ resultset = self._sql.convert_row(row)
940
+ self._sql.check_row(resultset)
531
941
  return resultset
532
942
 
943
+ async def fetchcol(self, connection: asyncpg.Connection, *args: Any) -> list[Any]:
944
+ stmt = await self._prepare(connection)
945
+ rows = await stmt.fetch(*self._sql.convert_arg_list(args))
946
+ column = self._sql.convert_column(rows)
947
+ self._sql.check_column(column)
948
+ return column
949
+
533
950
  async def fetchval(self, connection: asyncpg.Connection, *args: Any) -> Any:
534
951
  stmt = await self._prepare(connection)
535
- value = await stmt.fetchval(*args)
536
- result = self.sql.converters[0](value) if value is not None and self.sql.cast else value
537
- self.sql.check_value(result)
952
+ value = await stmt.fetchval(*self._sql.convert_arg_list(args))
953
+ result = self._sql.convert_value(value)
954
+ self._sql.check_value(result)
538
955
  return result
539
956
 
540
957
 
@@ -547,9 +964,7 @@ R2 = TypeVar("R2")
547
964
  RX = TypeVarTuple("RX")
548
965
 
549
966
 
550
- ### START OF AUTO-GENERATED BLOCK ###
551
-
552
-
967
+ ### START OF AUTO-GENERATED BLOCK FOR Protocol ###
553
968
  class SQL_P0(Protocol):
554
969
  @abstractmethod
555
970
  async def execute(self, connection: Connection) -> None: ...
@@ -561,6 +976,8 @@ class SQL_R1_P0(SQL_P0, Protocol[R1]):
561
976
  @abstractmethod
562
977
  async def fetchrow(self, connection: Connection) -> tuple[R1] | None: ...
563
978
  @abstractmethod
979
+ async def fetchcol(self, connection: Connection) -> list[R1]: ...
980
+ @abstractmethod
564
981
  async def fetchval(self, connection: Connection) -> R1: ...
565
982
 
566
983
 
@@ -586,6 +1003,8 @@ class SQL_R1_PX(SQL_PX[Unpack[PX]], Protocol[R1, Unpack[PX]]):
586
1003
  @abstractmethod
587
1004
  async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> tuple[R1] | None: ...
588
1005
  @abstractmethod
1006
+ async def fetchcol(self, connection: Connection, *args: Unpack[PX]) -> list[R1]: ...
1007
+ @abstractmethod
589
1008
  async def fetchval(self, connection: Connection, *args: Unpack[PX]) -> R1: ...
590
1009
 
591
1010
 
@@ -598,84 +1017,153 @@ class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
598
1017
  async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None: ...
599
1018
 
600
1019
 
601
- @overload
602
- def sql(stmt: SQLExpression) -> SQL_P0: ...
603
- @overload
604
- def sql(stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
605
- @overload
606
- def sql(stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
607
- @overload
608
- def sql(stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
609
- @overload
610
- def sql(stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
611
- @overload
612
- def sql(stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
613
- @overload
614
- def sql(stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
615
- @overload
616
- def sql(stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
617
- @overload
618
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
619
- @overload
620
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
621
- @overload
622
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
623
- @overload
624
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
625
-
626
-
627
- ### END OF AUTO-GENERATED BLOCK ###
628
-
629
-
630
- def sql(
631
- stmt: SQLExpression,
632
- *,
633
- args: type[Any] | None = None,
634
- resultset: type[Any] | None = None,
635
- arg: type[Any] | None = None,
636
- result: type[Any] | None = None,
637
- ) -> _SQL:
638
- """
639
- Creates a SQL statement with associated type information.
1020
+ ### END OF AUTO-GENERATED BLOCK FOR Protocol ###
1021
+
1022
+ RS = TypeVar("RS", bound=tuple[Any, ...])
640
1023
 
641
- :param stmt: SQL statement as a literal string or template.
642
- :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
643
- :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
644
- :param arg: Type signature for a single input parameter (e.g. `int`).
645
- :param result: Type signature for a single result column (e.g. `UUID`).
1024
+
1025
+ class SQLFactory:
1026
+ """
1027
+ Creates type-safe SQL queries.
646
1028
  """
647
1029
 
1030
+ ### START OF AUTO-GENERATED BLOCK FOR sql ###
1031
+ @overload
1032
+ def sql(self, stmt: SQLExpression) -> SQL_P0: ...
1033
+ @overload
1034
+ def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
1035
+ @overload
1036
+ def sql(self, stmt: SQLExpression, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
1037
+ @overload
1038
+ def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
1039
+ @overload
1040
+ def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
1041
+ @overload
1042
+ def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
1043
+ @overload
1044
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
1045
+ @overload
1046
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
1047
+ @overload
1048
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
1049
+
1050
+ ### END OF AUTO-GENERATED BLOCK FOR sql ###
1051
+
1052
+ def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
1053
+ """
1054
+ Creates a SQL statement with associated type information.
1055
+
1056
+ :param stmt: SQL statement as a literal string or template.
1057
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
1058
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
1059
+ :param arg: Type signature for a single input parameter (e.g. `int`).
1060
+ :param result: Type signature for a single result column (e.g. `UUID`).
1061
+ """
1062
+
1063
+ input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
1064
+
1065
+ obj: _SQLObject
1066
+ if sys.version_info >= (3, 14):
1067
+ match stmt:
1068
+ case Template():
1069
+ obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1070
+ case str():
1071
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1072
+ else:
1073
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1074
+
1075
+ return _SQLImpl(obj)
1076
+
1077
+ ### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
1078
+ @overload
1079
+ def unsafe_sql(self, stmt: str) -> SQL_P0: ...
1080
+ @overload
1081
+ def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
1082
+ @overload
1083
+ def unsafe_sql(self, stmt: str, *, resultset: type[RS]) -> SQL_RX_P0[RS]: ...
1084
+ @overload
1085
+ def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
1086
+ @overload
1087
+ def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
1088
+ @overload
1089
+ def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[RS]) -> SQL_RX_PX[RS, P1]: ...
1090
+ @overload
1091
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
1092
+ @overload
1093
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
1094
+ @overload
1095
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[RS]) -> SQL_RX_PX[RS, P1, Unpack[PX]]: ...
1096
+
1097
+ ### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
1098
+
1099
+ def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
1100
+ """
1101
+ Creates a SQL statement with associated type information from a string.
1102
+
1103
+ This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
1104
+ a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.
1105
+
1106
+ :param stmt: SQL statement as a string (or f-string).
1107
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
1108
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
1109
+ :param arg: Type signature for a single input parameter (e.g. `int`).
1110
+ :param result: Type signature for a single result column (e.g. `UUID`).
1111
+ """
1112
+
1113
+ input_data_types, output_data_types, names, wrapper = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
1114
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types, names=names, wrapper=wrapper)
1115
+ return _SQLImpl(obj)
1116
+
1117
+
1118
+ def _sql_args_resultset(
1119
+ *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None
1120
+ ) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[str, ...] | None, _ResultsetWrapper]:
1121
+ "Parses an argument/resultset signature into input/output types."
1122
+
648
1123
  if args is not None and arg is not None:
649
1124
  raise TypeError("expected: either `args` or `arg`; got: both")
650
1125
  if resultset is not None and result is not None:
651
1126
  raise TypeError("expected: either `resultset` or `result`; got: both")
652
1127
 
653
1128
  if args is not None:
654
- if get_origin(args) is not tuple:
655
- raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {type(args)}")
656
- input_data_types = get_args(args)
1129
+ if hasattr(args, "_asdict") and hasattr(args, "_fields"):
1130
+ # named tuple
1131
+ input_data_types = tuple(tp for tp in args.__annotations__.values())
1132
+ else:
1133
+ # regular tuple
1134
+ if get_origin(args) is not tuple:
1135
+ raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {args}")
1136
+ input_data_types = get_args(args)
657
1137
  elif arg is not None:
658
1138
  input_data_types = (arg,)
659
1139
  else:
660
1140
  input_data_types = ()
661
1141
 
662
1142
  if resultset is not None:
663
- if get_origin(resultset) is not tuple:
664
- raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {type(resultset)}")
665
- output_data_types = get_args(resultset)
666
- elif result is not None:
667
- output_data_types = (result,)
1143
+ if hasattr(resultset, "_asdict") and hasattr(resultset, "_fields") and hasattr(resultset, "_make"):
1144
+ # named tuple
1145
+ output_data_types = tuple(tp for tp in resultset.__annotations__.values())
1146
+ names = tuple(f for f in resultset._fields)
1147
+ wrapper = _ResultsetWrapper(resultset, resultset._make)
1148
+ else:
1149
+ # regular tuple
1150
+ if get_origin(resultset) is not tuple:
1151
+ raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {resultset}")
1152
+ output_data_types = get_args(resultset)
1153
+ names = None
1154
+ wrapper = _ResultsetWrapper(None, tuple)
668
1155
  else:
669
- output_data_types = ()
1156
+ if result is not None:
1157
+ output_data_types = (result,)
1158
+ else:
1159
+ output_data_types = ()
1160
+ names = None
1161
+ wrapper = _ResultsetWrapper(None, tuple)
670
1162
 
671
- if sys.version_info >= (3, 14):
672
- obj: _SQLObject
673
- match stmt:
674
- case Template():
675
- obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
676
- case str():
677
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
678
- else:
679
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
1163
+ return input_data_types, output_data_types, names, wrapper
1164
+
1165
+
1166
+ FACTORY: SQLFactory = SQLFactory()
680
1167
 
681
- return _SQLImpl(obj)
1168
+ sql = FACTORY.sql
1169
+ unsafe_sql = FACTORY.unsafe_sql