asyncpg-typed 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asyncpg_typed/__init__.py CHANGED
@@ -4,7 +4,7 @@ Type-safe queries for asyncpg.
4
4
  :see: https://github.com/hunyadi/asyncpg_typed
5
5
  """
6
6
 
7
- __version__ = "0.1.1"
7
+ __version__ = "0.1.3"
8
8
  __author__ = "Levente Hunyadi"
9
9
  __copyright__ = "Copyright 2025, Levente Hunyadi"
10
10
  __license__ = "MIT"
@@ -13,14 +13,17 @@ __status__ = "Production"
13
13
 
14
14
  import enum
15
15
  import sys
16
+ import typing
16
17
  from abc import abstractmethod
17
- from collections.abc import Iterable, Sequence
18
- from datetime import date, datetime, time
18
+ from collections.abc import Callable, Iterable, Sequence
19
+ from dataclasses import dataclass
20
+ from datetime import date, datetime, time, timedelta
19
21
  from decimal import Decimal
20
22
  from functools import reduce
21
23
  from io import StringIO
24
+ from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
22
25
  from types import UnionType
23
- from typing import Any, Generic, TypeAlias, TypeVar, Union, get_args, get_origin, overload
26
+ from typing import Any, Protocol, TypeAlias, TypeGuard, TypeVar, Union, get_args, get_origin, overload
24
27
  from uuid import UUID
25
28
 
26
29
  import asyncpg
@@ -31,16 +34,30 @@ if sys.version_info < (3, 11):
31
34
  else:
32
35
  from typing import LiteralString, TypeVarTuple, Unpack
33
36
 
34
- # list of supported data types
35
- DATA_TYPES: list[type[Any]] = [bool, int, float, Decimal, date, time, datetime, str, bytes, UUID]
37
+ JsonType = None | bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
36
38
 
37
- # maximum number of inbound query parameters
38
- NUM_ARGS = 8
39
+ RequiredJsonType = bool | int | float | str | dict[str, "JsonType"] | list["JsonType"]
40
+
41
+ TargetType: TypeAlias = type[Any] | UnionType
42
+
43
+ Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
44
+
45
+
46
+ class TypeMismatchError(TypeError):
47
+ "Raised when a prepared statement takes or returns a PostgreSQL type incompatible with the declared Python type."
48
+
49
+
50
+ class EnumMismatchError(TypeError):
51
+ "Raised when a prepared statement takes or returns a PostgreSQL enum type whose permitted set of values differs from what is declared in Python."
52
+
53
+
54
+ class NoneTypeError(TypeError):
55
+ "Raised when a column marked as required contains a `NULL` value."
39
56
 
40
57
 
41
58
  if sys.version_info >= (3, 11):
42
59
 
43
- def is_enum_type(typ: object) -> bool:
60
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
44
61
  """
45
62
  `True` if the specified type is an enumeration type.
46
63
  """
@@ -49,7 +66,7 @@ if sys.version_info >= (3, 11):
49
66
 
50
67
  else:
51
68
 
52
- def is_enum_type(typ: object) -> bool:
69
+ def is_enum_type(typ: Any) -> TypeGuard[type[enum.Enum]]:
53
70
  """
54
71
  `True` if the specified type is an enumeration type.
55
72
  """
@@ -83,6 +100,22 @@ def is_standard_type(tp: Any) -> bool:
83
100
  return tp.__module__ == "builtins" or tp.__module__ == UnionType.__module__
84
101
 
85
102
 
103
+ def is_json_type(tp: Any) -> bool:
104
+ """
105
+ `True` if the type represents an object de-serialized from a JSON string.
106
+ """
107
+
108
+ return tp in [JsonType, RequiredJsonType]
109
+
110
+
111
+ def is_inet_type(tp: Any) -> bool:
112
+ """
113
+ `True` if the type represents an IP address or network.
114
+ """
115
+
116
+ return tp in [IPv4Address, IPv6Address, IPv4Network, IPv6Network]
117
+
118
+
86
119
  def make_union_type(tpl: list[Any]) -> UnionType:
87
120
  """
88
121
  Creates a `UnionType` (a.k.a. `A | B | C`) dynamically at run time.
@@ -111,53 +144,221 @@ def get_required_type(tp: Any) -> Any:
111
144
  return type(None)
112
145
 
113
146
 
114
- # maps PostgreSQL internal type names to Python types
115
- _name_to_type: dict[str, Any] = {
116
- "bool": bool,
117
- "int2": int,
118
- "int4": int,
119
- "int8": int,
120
- "float4": float,
121
- "float8": float,
122
- "numeric": Decimal,
123
- "date": date,
124
- "time": time,
125
- "timetz": time,
126
- "timestamp": datetime,
127
- "timestamptz": datetime,
128
- "bpchar": str,
129
- "varchar": str,
130
- "text": str,
131
- "bytea": bytes,
132
- "json": str,
133
- "jsonb": str,
134
- "uuid": UUID,
147
+ def _standard_json_decoder() -> Callable[[str], JsonType]:
148
+ import json
149
+
150
+ _json_decoder = json.JSONDecoder()
151
+ return _json_decoder.decode
152
+
153
+
154
+ def _json_decoder() -> Callable[[str], JsonType]:
155
+ if typing.TYPE_CHECKING:
156
+ return _standard_json_decoder()
157
+ else:
158
+ try:
159
+ import orjson
160
+
161
+ return orjson.loads
162
+ except ModuleNotFoundError:
163
+ return _standard_json_decoder()
164
+
165
+
166
+ JSON_DECODER = _json_decoder()
167
+
168
+
169
+ def _standard_json_encoder() -> Callable[[JsonType], str]:
170
+ import json
171
+
172
+ _json_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"), allow_nan=False)
173
+ return _json_encoder.encode
174
+
175
+
176
+ def _json_encoder() -> Callable[[JsonType], str]:
177
+ if typing.TYPE_CHECKING:
178
+ return _standard_json_encoder()
179
+ else:
180
+ try:
181
+ import orjson
182
+
183
+ def _wrap(value: JsonType) -> str:
184
+ return orjson.dumps(value).decode()
185
+
186
+ return _wrap
187
+ except ModuleNotFoundError:
188
+ return _standard_json_encoder()
189
+
190
+
191
+ JSON_ENCODER = _json_encoder()
192
+
193
+
194
+ def get_output_converter_for(tp: Any) -> Callable[[Any], Any]:
195
+ """
196
+ Returns a callable that takes a wire type and returns a target type.
197
+
198
+ A wire type is one of the types returned by asyncpg.
199
+ A target type is one of the types supported by the library.
200
+ """
201
+
202
+ if is_json_type(tp):
203
+ # asyncpg returns fields of type `json` and `jsonb` as `str`, which must be de-serialized
204
+ return JSON_DECODER
205
+ else:
206
+ # target data types that require conversion must have a single-argument `__init__` that takes an object of the source type
207
+ return tp
208
+
209
+
210
+ def get_input_converter_for(tp: Any) -> Callable[[Any], Any]:
211
+ """
212
+ Returns a callable that takes a source type and returns a wire type.
213
+
214
+ A source type is one of the types supported by the library.
215
+ A wire type is one of the types returned by asyncpg.
216
+ """
217
+
218
+ if is_json_type(tp):
219
+ # asyncpg expects fields of type `json` and `jsonb` as `str`, which must be serialized
220
+ return JSON_ENCODER
221
+ else:
222
+ # source data types that require conversion must have a single-argument `__init__` that takes an object of the source type
223
+ return tp
224
+
225
+
226
+ # maps PostgreSQL internal type names to compatible Python types
227
+ _NAME_TO_TYPE: dict[str, tuple[Any, ...]] = {
228
+ # boolean type
229
+ "bool": (bool,),
230
+ # numeric types
231
+ "int2": (int,),
232
+ "int4": (int,),
233
+ "int8": (int,),
234
+ "float4": (float,),
235
+ "float8": (float,),
236
+ "numeric": (Decimal,),
237
+ # date and time types
238
+ "date": (date,),
239
+ "time": (time,),
240
+ "timetz": (time,),
241
+ "timestamp": (datetime,),
242
+ "timestamptz": (datetime,),
243
+ "interval": (timedelta,),
244
+ # character sequence types
245
+ "bpchar": (str,),
246
+ "varchar": (str,),
247
+ "text": (str,),
248
+ # binary sequence types
249
+ "bytea": (bytes,),
250
+ # unique identifier type
251
+ "uuid": (UUID,),
252
+ # address types
253
+ "cidr": (IPv4Network, IPv6Network, IPv4Network | IPv6Network),
254
+ "inet": (IPv4Network, IPv6Network, IPv4Network | IPv6Network, IPv4Address, IPv6Address, IPv4Address | IPv6Address),
255
+ "macaddr": (str,),
256
+ "macaddr8": (str,),
257
+ # JSON type
258
+ "json": (str, RequiredJsonType),
259
+ "jsonb": (str, RequiredJsonType),
260
+ # XML type
261
+ "xml": (str,),
262
+ # geometric types
263
+ "point": (asyncpg.Point,),
264
+ "line": (asyncpg.Line,),
265
+ "lseg": (asyncpg.LineSegment,),
266
+ "box": (asyncpg.Box,),
267
+ "path": (asyncpg.Path,),
268
+ "polygon": (asyncpg.Polygon,),
269
+ "circle": (asyncpg.Circle,),
270
+ # range types
271
+ "int4range": (asyncpg.Range[int],),
272
+ "int4multirange": (list[asyncpg.Range[int]],),
273
+ "int8range": (asyncpg.Range[int],),
274
+ "int8multirange": (list[asyncpg.Range[int]],),
275
+ "numrange": (asyncpg.Range[Decimal],),
276
+ "nummultirange": (list[asyncpg.Range[Decimal]],),
277
+ "tsrange": (asyncpg.Range[datetime],),
278
+ "tsmultirange": (list[asyncpg.Range[datetime]],),
279
+ "tstzrange": (asyncpg.Range[datetime],),
280
+ "tstzmultirange": (list[asyncpg.Range[datetime]],),
281
+ "daterange": (asyncpg.Range[date],),
282
+ "datemultirange": (list[asyncpg.Range[date]],),
135
283
  }
136
284
 
137
285
 
138
- def check_data_type(schema: str, name: str, data_type: type[Any]) -> bool:
286
+ def type_to_str(tp: Any) -> str:
287
+ "Emits a friendly name for a type."
288
+
289
+ if isinstance(tp, type):
290
+ return tp.__name__
291
+ else:
292
+ return str(tp)
293
+
294
+
295
+ class _TypeVerifier:
139
296
  """
140
297
  Verifies if the Python target type can represent the PostgreSQL source type.
141
298
  """
142
299
 
143
- if schema == "pg_catalog":
144
- expected_type = _name_to_type.get(name)
145
- return expected_type == data_type
146
- else:
147
- if is_standard_type(data_type):
148
- return False
300
+ _connection: Connection
301
+
302
+ def __init__(self, connection: Connection) -> None:
303
+ self._connection = connection
304
+
305
+ async def _check_enum_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: type[enum.Enum]) -> None:
306
+ """
307
+ Verifies if a Python enumeration type matches a PostgreSQL enumeration type.
308
+ """
309
+
310
+ for e in data_type:
311
+ if not isinstance(e.value, str):
312
+ raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` with `str` values; got: `{type_to_str(type(e.value))}` for enum field `{e.name}`")
313
+
314
+ py_values = set(e.value for e in data_type)
149
315
 
150
- # user-defined type registered with `conn.set_type_codec()`
151
- return True
316
+ rows = await self._connection.fetch("SELECT enumlabel FROM pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder;", pg_type.oid)
317
+ db_values = set(row[0] for row in rows)
152
318
 
319
+ db_extra = db_values - py_values
320
+ if db_extra:
321
+ raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; missing value(s): {', '.join(f'`{val}`' for val in db_extra)})")
322
+
323
+ py_extra = py_values - db_values
324
+ if py_extra:
325
+ raise EnumMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` to match values of PostgreSQL enum type `{pg_type.name}` for {pg_name}; got extra value(s): {', '.join(f'`{val}`' for val in py_extra)})")
326
+
327
+ async def check_data_type(self, pg_name: str, pg_type: asyncpg.Type, data_type: TargetType) -> None:
328
+ """
329
+ Verifies if the Python target type can represent the PostgreSQL source type.
330
+ """
331
+
332
+ if pg_type.schema == "pg_catalog": # well-known PostgreSQL types
333
+ if is_enum_type(data_type):
334
+ if pg_type.name not in ["bpchar", "varchar", "text"]:
335
+ raise TypeMismatchError(f"expected: Python enum type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}` instead of `char`, `varchar` or `text`")
336
+ else:
337
+ expected_types = _NAME_TO_TYPE.get(pg_type.name)
338
+ if expected_types is None:
339
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: unrecognized PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
340
+ elif data_type not in expected_types:
341
+ raise TypeMismatchError(
342
+ f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; "
343
+ f"got: incompatible PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`, which converts to one of the Python types {', '.join(f'`{type_to_str(tp)}`' for tp in expected_types)}"
344
+ )
345
+ elif pg_type.kind == "composite": # PostgreSQL composite types
346
+ # user-defined composite types registered with `conn.set_type_codec()` typically using `format="tuple"`
347
+ pass
348
+ else: # custom PostgreSQL types
349
+ if is_enum_type(data_type):
350
+ await self._check_enum_type(pg_name, pg_type, data_type)
351
+ elif is_standard_type(data_type):
352
+ raise TypeMismatchError(f"expected: Python type `{type_to_str(data_type)}` for {pg_name}; got: PostgreSQL type `{pg_type.kind}` of `{pg_type.name}`")
353
+ else:
354
+ # user-defined types registered with `conn.set_type_codec()`
355
+ pass
153
356
 
357
+
358
+ @dataclass(frozen=True)
154
359
  class _SQLPlaceholder:
155
360
  ordinal: int
156
- data_type: type[Any]
157
-
158
- def __init__(self, ordinal: int, data_type: type[Any]) -> None:
159
- self.ordinal = ordinal
160
- self.data_type = data_type
361
+ data_type: TargetType
161
362
 
162
363
  def __repr__(self) -> str:
163
364
  return f"{self.__class__.__name__}({self.ordinal}, {self.data_type!r})"
@@ -168,30 +369,51 @@ class _SQLObject:
168
369
  Associates input and output type information with a SQL statement.
169
370
  """
170
371
 
171
- parameter_data_types: tuple[_SQLPlaceholder, ...]
172
- resultset_data_types: tuple[type[Any], ...]
173
- required: int
174
- cast: int
372
+ _parameter_data_types: tuple[_SQLPlaceholder, ...]
373
+ _resultset_data_types: tuple[TargetType, ...]
374
+ _parameter_cast: int
375
+ _parameter_converters: tuple[Callable[[Any], Any], ...]
376
+ _required: int
377
+ _resultset_cast: int
378
+ _resultset_converters: tuple[Callable[[Any], Any], ...]
379
+
380
+ @property
381
+ def parameter_data_types(self) -> tuple[_SQLPlaceholder, ...]:
382
+ return self._parameter_data_types
383
+
384
+ @property
385
+ def resultset_data_types(self) -> tuple[TargetType, ...]:
386
+ return self._resultset_data_types
175
387
 
176
388
  def __init__(
177
389
  self,
178
- input_data_types: tuple[type[Any], ...],
179
- output_data_types: tuple[type[Any], ...],
390
+ input_data_types: tuple[TargetType, ...],
391
+ output_data_types: tuple[TargetType, ...],
180
392
  ) -> None:
181
- self.parameter_data_types = tuple(_SQLPlaceholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(input_data_types, start=1))
182
- self.resultset_data_types = tuple(get_required_type(data_type) for data_type in output_data_types)
393
+ self._parameter_data_types = tuple(_SQLPlaceholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(input_data_types, start=1))
394
+ self._resultset_data_types = tuple(get_required_type(data_type) for data_type in output_data_types)
395
+
396
+ # create a bit-field of types that require cast or serialization (1: apply conversion; 0: forward value as-is)
397
+ parameter_cast = 0
398
+ for index, placeholder in enumerate(self._parameter_data_types):
399
+ parameter_cast |= is_json_type(placeholder.data_type) << index
400
+ self._parameter_cast = parameter_cast
401
+
402
+ self._parameter_converters = tuple(get_input_converter_for(placeholder.data_type) for placeholder in self._parameter_data_types)
183
403
 
184
404
  # create a bit-field of required types (1: required; 0: optional)
185
405
  required = 0
186
406
  for index, data_type in enumerate(output_data_types):
187
407
  required |= (not is_optional_type(data_type)) << index
188
- self.required = required
408
+ self._required = required
409
+
410
+ # create a bit-field of types that require cast or de-serialization (1: apply conversion; 0: forward value as-is)
411
+ resultset_cast = 0
412
+ for index, data_type in enumerate(self._resultset_data_types):
413
+ resultset_cast |= (is_enum_type(data_type) or is_json_type(data_type) or is_inet_type(data_type)) << index
414
+ self._resultset_cast = resultset_cast
189
415
 
190
- # create a bit-field of types that require cast/conversion (1: pass to __init__; 0: skip)
191
- cast = 0
192
- for index, data_type in enumerate(self.resultset_data_types):
193
- cast |= is_enum_type(data_type) << index
194
- self.cast = cast
416
+ self._resultset_converters = tuple(get_output_converter_for(data_type) for data_type in self._resultset_data_types)
195
417
 
196
418
  def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
197
419
  """
@@ -199,12 +421,12 @@ class _SQLObject:
199
421
  """
200
422
 
201
423
  for col_index in range(len(row)):
202
- if (self.required >> col_index & 1) and row[col_index] is None:
424
+ if (self._required >> col_index & 1) and row[col_index] is None:
203
425
  if row_index is not None:
204
426
  row_col_spec = f"row #{row_index} and column #{col_index}"
205
427
  else:
206
428
  row_col_spec = f"column #{col_index}"
207
- raise TypeError(f"expected: {self.resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
429
+ raise NoneTypeError(f"expected: {self._resultset_data_types[col_index]} in {row_col_spec}; got: NULL")
208
430
 
209
431
  def check_rows(self, rows: list[tuple[Any, ...]]) -> None:
210
432
  """
@@ -214,7 +436,7 @@ class _SQLObject:
214
436
  if not rows:
215
437
  return
216
438
 
217
- required = self.required
439
+ required = self._required
218
440
  if not required:
219
441
  return
220
442
 
@@ -267,7 +489,7 @@ class _SQLObject:
267
489
  Verifies if declared types match actual value types in a single row.
268
490
  """
269
491
 
270
- required = self.required
492
+ required = self._required
271
493
  if not required:
272
494
  return
273
495
 
@@ -311,8 +533,72 @@ class _SQLObject:
311
533
  Verifies if the declared type matches the actual value type.
312
534
  """
313
535
 
314
- if self.required and value is None:
315
- raise TypeError(f"expected: {self.resultset_data_types[0]}; got: NULL")
536
+ if self._required and value is None:
537
+ raise NoneTypeError(f"expected: {self._resultset_data_types[0]}; got: NULL")
538
+
539
+ def convert_arg_lists(self, arg_lists: Iterable[Sequence[Any]]) -> Iterable[Sequence[Any]]:
540
+ """
541
+ Converts a list of Python query argument tuples to a list of PostgreSQL parameter tuples.
542
+ """
543
+
544
+ cast = self._parameter_cast
545
+ if cast:
546
+ converters = self._parameter_converters
547
+ yield from (tuple((converters[i](value) if (value := arg[i]) is not None and cast >> i & 1 else value) for i in range(len(arg))) for arg in arg_lists)
548
+ else:
549
+ yield from arg_lists
550
+
551
+ def convert_arg_list(self, arg_list: Sequence[Any]) -> Sequence[Any]:
552
+ """
553
+ Converts Python query arguments to PostgreSQL parameters.
554
+ """
555
+
556
+ cast = self._parameter_cast
557
+ if cast:
558
+ converters = self._parameter_converters
559
+ return tuple((converters[i](value) if (value := arg_list[i]) is not None and cast >> i & 1 else value) for i in range(len(arg_list)))
560
+ else:
561
+ return tuple(value for value in arg_list)
562
+
563
+ def convert_rows(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
564
+ """
565
+ Converts columns in the PostgreSQL result-set to their corresponding Python target types.
566
+
567
+ :param rows: List of rows returned by PostgreSQL.
568
+ :returns: List of tuples with each tuple element having the configured Python target type.
569
+ """
570
+
571
+ cast = self._resultset_cast
572
+ if cast:
573
+ converters = self._resultset_converters
574
+ return [tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row))) for row in rows]
575
+ else:
576
+ return [tuple(value for value in row) for row in rows]
577
+
578
+ def convert_row(self, row: asyncpg.Record) -> tuple[Any, ...]:
579
+ """
580
+ Converts columns in the PostgreSQL result-set to their corresponding Python target types.
581
+
582
+ :param row: A single row returned by PostgreSQL.
583
+ :returns: A tuple with each tuple element having the configured Python target type.
584
+ """
585
+
586
+ cast = self._resultset_cast
587
+ if cast:
588
+ converters = self._resultset_converters
589
+ return tuple((converters[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row)))
590
+ else:
591
+ return tuple(value for value in row)
592
+
593
+ def convert_value(self, value: Any) -> Any:
594
+ """
595
+ Converts a single PostgreSQL value to its corresponding Python target type.
596
+
597
+ :param value: A single value returned by PostgreSQL.
598
+ :returns: A converted value having the configured Python target type.
599
+ """
600
+
601
+ return self._resultset_converters[0](value) if value is not None and self._resultset_cast else value
316
602
 
317
603
  @abstractmethod
318
604
  def query(self) -> str:
@@ -338,15 +624,15 @@ if sys.version_info >= (3, 14):
338
624
  A SQL query specified with the Python t-string syntax.
339
625
  """
340
626
 
341
- strings: tuple[str, ...]
342
- placeholders: tuple[_SQLPlaceholder, ...]
627
+ _strings: tuple[str, ...]
628
+ _placeholders: tuple[_SQLPlaceholder, ...]
343
629
 
344
630
  def __init__(
345
631
  self,
346
632
  template: Template,
347
633
  *,
348
- args: tuple[type[Any], ...],
349
- resultset: tuple[type[Any], ...],
634
+ args: tuple[TargetType, ...],
635
+ resultset: tuple[TargetType, ...],
350
636
  ) -> None:
351
637
  super().__init__(args, resultset)
352
638
 
@@ -358,7 +644,7 @@ if sys.version_info >= (3, 14):
358
644
  if not isinstance(ip.value, int):
359
645
  raise TypeError(f"interpolation `{ip.expression}` expected to evaluate to an integer")
360
646
 
361
- self.strings = template.strings
647
+ self._strings = template.strings
362
648
 
363
649
  if len(self.parameter_data_types) > 0:
364
650
 
@@ -368,16 +654,16 @@ if sys.version_info >= (3, 14):
368
654
  raise IndexError(f"interpolation `{ip.expression}` is an ordinal out of range; expected: 0 < value <= {len(self.parameter_data_types)}")
369
655
  return self.parameter_data_types[int(ip.value) - 1]
370
656
 
371
- self.placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
657
+ self._placeholders = tuple(_to_placeholder(ip) for ip in template.interpolations)
372
658
  else:
373
- self.placeholders = ()
659
+ self._placeholders = ()
374
660
 
375
661
  def query(self) -> str:
376
662
  buf = StringIO()
377
- for s, p in zip(self.strings[:-1], self.placeholders, strict=True):
663
+ for s, p in zip(self._strings[:-1], self._placeholders, strict=True):
378
664
  buf.write(s)
379
665
  buf.write(f"${p.ordinal}")
380
- buf.write(self.strings[-1])
666
+ buf.write(self._strings[-1])
381
667
  return buf.getvalue()
382
668
 
383
669
  else:
@@ -389,128 +675,110 @@ class _SQLString(_SQLObject):
389
675
  A SQL query specified as a plain string (e.g. f-string).
390
676
  """
391
677
 
392
- sql: str
678
+ _sql: str
393
679
 
394
680
  def __init__(
395
681
  self,
396
682
  sql: str,
397
683
  *,
398
- args: tuple[type[Any], ...],
399
- resultset: tuple[type[Any], ...],
684
+ args: tuple[TargetType, ...],
685
+ resultset: tuple[TargetType, ...],
400
686
  ) -> None:
401
687
  super().__init__(args, resultset)
402
- self.sql = sql
688
+ self._sql = sql
403
689
 
404
690
  def query(self) -> str:
405
- return self.sql
691
+ return self._sql
406
692
 
407
693
 
408
- class _SQL:
694
+ class _SQL(Protocol):
409
695
  """
410
696
  Represents a SQL statement with associated type information.
411
697
  """
412
698
 
413
699
 
414
- Connection: TypeAlias = asyncpg.Connection | asyncpg.pool.PoolConnectionProxy
415
-
416
-
417
700
  class _SQLImpl(_SQL):
418
701
  """
419
702
  Forwards input data to an `asyncpg.PreparedStatement`, and validates output data (if necessary).
420
703
  """
421
704
 
422
- sql: _SQLObject
705
+ _sql: _SQLObject
423
706
 
424
707
  def __init__(self, sql: _SQLObject) -> None:
425
- self.sql = sql
708
+ self._sql = sql
426
709
 
427
710
  def __str__(self) -> str:
428
- return str(self.sql)
711
+ return str(self._sql)
429
712
 
430
713
  def __repr__(self) -> str:
431
- return repr(self.sql)
714
+ return repr(self._sql)
432
715
 
433
716
  async def _prepare(self, connection: Connection) -> PreparedStatement:
434
- stmt = await connection.prepare(self.sql.query())
717
+ stmt = await connection.prepare(self._sql.query())
435
718
 
436
- for attr, data_type in zip(stmt.get_attributes(), self.sql.resultset_data_types, strict=True):
437
- if not check_data_type(attr.type.schema, attr.type.name, data_type):
438
- raise TypeError(f"expected: {data_type} in column `{attr.name}`; got: `{attr.type.kind}` of `{attr.type.name}`")
719
+ verifier = _TypeVerifier(connection)
720
+ for param, placeholder in zip(stmt.get_parameters(), self._sql.parameter_data_types, strict=True):
721
+ await verifier.check_data_type(f"parameter ${placeholder.ordinal}", param, placeholder.data_type)
722
+ for attr, data_type in zip(stmt.get_attributes(), self._sql.resultset_data_types, strict=True):
723
+ await verifier.check_data_type(f"column `{attr.name}`", attr.type, data_type)
439
724
 
440
725
  return stmt
441
726
 
442
727
  async def execute(self, connection: asyncpg.Connection, *args: Any) -> None:
443
- await connection.execute(self.sql.query(), *args)
728
+ await connection.execute(self._sql.query(), *self._sql.convert_arg_list(args))
444
729
 
445
730
  async def executemany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> None:
446
731
  stmt = await self._prepare(connection)
447
- await stmt.executemany(args)
732
+ await stmt.executemany(self._sql.convert_arg_lists(args))
448
733
 
449
734
  def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
450
- cast = self.sql.cast
451
- if cast:
452
- data_types = self.sql.resultset_data_types
453
- resultset = [tuple((data_types[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row))) for row in rows]
454
- else:
455
- resultset = [tuple(value for value in row) for row in rows]
456
- self.sql.check_rows(resultset)
735
+ resultset = self._sql.convert_rows(rows)
736
+ self._sql.check_rows(resultset)
457
737
  return resultset
458
738
 
459
739
  async def fetch(self, connection: asyncpg.Connection, *args: Any) -> list[tuple[Any, ...]]:
460
740
  stmt = await self._prepare(connection)
461
- rows = await stmt.fetch(*args)
741
+ rows = await stmt.fetch(*self._sql.convert_arg_list(args))
462
742
  return self._cast_fetch(rows)
463
743
 
464
744
  async def fetchmany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
465
745
  stmt = await self._prepare(connection)
466
- rows = await stmt.fetchmany(args)
746
+ rows = await stmt.fetchmany(self._sql.convert_arg_lists(args))
467
747
  return self._cast_fetch(rows)
468
748
 
469
749
  async def fetchrow(self, connection: asyncpg.Connection, *args: Any) -> tuple[Any, ...] | None:
470
750
  stmt = await self._prepare(connection)
471
- row = await stmt.fetchrow(*args)
751
+ row = await stmt.fetchrow(*self._sql.convert_arg_list(args))
472
752
  if row is None:
473
753
  return None
474
- cast = self.sql.cast
475
- if cast:
476
- data_types = self.sql.resultset_data_types
477
- resultset = tuple((data_types[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row)))
478
- else:
479
- resultset = tuple(value for value in row)
480
- self.sql.check_row(resultset)
754
+ resultset = self._sql.convert_row(row)
755
+ self._sql.check_row(resultset)
481
756
  return resultset
482
757
 
483
758
  async def fetchval(self, connection: asyncpg.Connection, *args: Any) -> Any:
484
759
  stmt = await self._prepare(connection)
485
- value = await stmt.fetchval(*args)
486
- result = self.sql.resultset_data_types[0](value) if value is not None and self.sql.cast else value
487
- self.sql.check_value(result)
760
+ value = await stmt.fetchval(*self._sql.convert_arg_list(args))
761
+ result = self._sql.convert_value(value)
762
+ self._sql.check_value(result)
488
763
  return result
489
764
 
490
765
 
491
- ### START OF AUTO-GENERATED BLOCK ###
492
-
493
- PS = TypeVar("PS")
494
766
  P1 = TypeVar("P1")
495
- P2 = TypeVar("P2")
496
- P3 = TypeVar("P3")
497
- P4 = TypeVar("P4")
498
- P5 = TypeVar("P5")
499
- P6 = TypeVar("P6")
500
- P7 = TypeVar("P7")
501
- P8 = TypeVar("P8")
502
- RS = TypeVar("RS")
767
+ PX = TypeVarTuple("PX")
768
+
769
+ RT = TypeVar("RT")
503
770
  R1 = TypeVar("R1")
504
771
  R2 = TypeVar("R2")
505
772
  RX = TypeVarTuple("RX")
506
773
 
507
774
 
508
- class SQL_P0(_SQL):
775
+ ### START OF AUTO-GENERATED BLOCK FOR Protocol ###
776
+ class SQL_P0(Protocol):
509
777
  @abstractmethod
510
778
  async def execute(self, connection: Connection) -> None: ...
511
779
 
512
780
 
513
- class SQL_P0_RS(Generic[R1], SQL_P0):
781
+ class SQL_R1_P0(SQL_P0, Protocol[R1]):
514
782
  @abstractmethod
515
783
  async def fetch(self, connection: Connection) -> list[tuple[R1]]: ...
516
784
  @abstractmethod
@@ -519,327 +787,150 @@ class SQL_P0_RS(Generic[R1], SQL_P0):
519
787
  async def fetchval(self, connection: Connection) -> R1: ...
520
788
 
521
789
 
522
- class SQL_P0_RX(Generic[R1, R2, Unpack[RX]], SQL_P0):
523
- @abstractmethod
524
- async def fetch(self, connection: Connection) -> list[tuple[R1, R2, Unpack[RX]]]: ...
525
- @abstractmethod
526
- async def fetchrow(self, connection: Connection) -> tuple[R1, R2, Unpack[RX]] | None: ...
527
-
528
-
529
- class SQL_P1(Generic[P1], _SQL):
530
- @abstractmethod
531
- async def execute(self, connection: Connection, arg1: P1) -> None: ...
532
- @abstractmethod
533
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1]]) -> None: ...
534
-
535
-
536
- class SQL_P1_RS(Generic[P1, R1], SQL_P1[P1]):
537
- @abstractmethod
538
- async def fetch(self, connection: Connection, arg1: P1) -> list[tuple[R1]]: ...
539
- @abstractmethod
540
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1]]) -> list[tuple[R1]]: ...
541
- @abstractmethod
542
- async def fetchrow(self, connection: Connection, arg1: P1) -> tuple[R1] | None: ...
543
- @abstractmethod
544
- async def fetchval(self, connection: Connection, arg1: P1) -> R1: ...
545
-
546
-
547
- class SQL_P1_RX(Generic[P1, R1, R2, Unpack[RX]], SQL_P1[P1]):
548
- @abstractmethod
549
- async def fetch(self, connection: Connection, arg1: P1) -> list[tuple[R1, R2, Unpack[RX]]]: ...
550
- @abstractmethod
551
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
552
- @abstractmethod
553
- async def fetchrow(self, connection: Connection, arg1: P1) -> tuple[R1, R2, Unpack[RX]] | None: ...
554
-
555
-
556
- class SQL_P2(Generic[P1, P2], _SQL):
557
- @abstractmethod
558
- async def execute(self, connection: Connection, arg1: P1, arg2: P2) -> None: ...
559
- @abstractmethod
560
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1, P2]]) -> None: ...
561
-
562
-
563
- class SQL_P2_RS(Generic[P1, P2, R1], SQL_P2[P1, P2]):
564
- @abstractmethod
565
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2) -> list[tuple[R1]]: ...
566
- @abstractmethod
567
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2]]) -> list[tuple[R1]]: ...
568
- @abstractmethod
569
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2) -> tuple[R1] | None: ...
570
- @abstractmethod
571
- async def fetchval(self, connection: Connection, arg1: P1, arg2: P2) -> R1: ...
572
-
573
-
574
- class SQL_P2_RX(Generic[P1, P2, R1, R2, Unpack[RX]], SQL_P2[P1, P2]):
575
- @abstractmethod
576
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2) -> list[tuple[R1, R2, Unpack[RX]]]: ...
577
- @abstractmethod
578
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
579
- @abstractmethod
580
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2) -> tuple[R1, R2, Unpack[RX]] | None: ...
581
-
582
-
583
- class SQL_P3(Generic[P1, P2, P3], _SQL):
584
- @abstractmethod
585
- async def execute(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3) -> None: ...
586
- @abstractmethod
587
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3]]) -> None: ...
588
-
589
-
590
- class SQL_P3_RS(Generic[P1, P2, P3, R1], SQL_P3[P1, P2, P3]):
591
- @abstractmethod
592
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3) -> list[tuple[R1]]: ...
593
- @abstractmethod
594
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3]]) -> list[tuple[R1]]: ...
595
- @abstractmethod
596
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3) -> tuple[R1] | None: ...
597
- @abstractmethod
598
- async def fetchval(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3) -> R1: ...
599
-
600
-
601
- class SQL_P3_RX(Generic[P1, P2, P3, R1, R2, Unpack[RX]], SQL_P3[P1, P2, P3]):
602
- @abstractmethod
603
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3) -> list[tuple[R1, R2, Unpack[RX]]]: ...
604
- @abstractmethod
605
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
606
- @abstractmethod
607
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3) -> tuple[R1, R2, Unpack[RX]] | None: ...
608
-
609
-
610
- class SQL_P4(Generic[P1, P2, P3, P4], _SQL):
611
- @abstractmethod
612
- async def execute(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4) -> None: ...
613
- @abstractmethod
614
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4]]) -> None: ...
615
-
616
-
617
- class SQL_P4_RS(Generic[P1, P2, P3, P4, R1], SQL_P4[P1, P2, P3, P4]):
618
- @abstractmethod
619
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4) -> list[tuple[R1]]: ...
620
- @abstractmethod
621
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4]]) -> list[tuple[R1]]: ...
622
- @abstractmethod
623
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4) -> tuple[R1] | None: ...
624
- @abstractmethod
625
- async def fetchval(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4) -> R1: ...
626
-
627
-
628
- class SQL_P4_RX(Generic[P1, P2, P3, P4, R1, R2, Unpack[RX]], SQL_P4[P1, P2, P3, P4]):
629
- @abstractmethod
630
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4) -> list[tuple[R1, R2, Unpack[RX]]]: ...
790
+ class SQL_RX_P0(SQL_P0, Protocol[RT]):
631
791
  @abstractmethod
632
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
792
+ async def fetch(self, connection: Connection) -> list[RT]: ...
633
793
  @abstractmethod
634
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4) -> tuple[R1, R2, Unpack[RX]] | None: ...
794
+ async def fetchrow(self, connection: Connection) -> RT | None: ...
635
795
 
636
796
 
637
- class SQL_P5(Generic[P1, P2, P3, P4, P5], _SQL):
797
+ class SQL_PX(Protocol[Unpack[PX]]):
638
798
  @abstractmethod
639
- async def execute(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5) -> None: ...
799
+ async def execute(self, connection: Connection, *args: Unpack[PX]) -> None: ...
640
800
  @abstractmethod
641
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5]]) -> None: ...
801
+ async def executemany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> None: ...
642
802
 
643
803
 
644
- class SQL_P5_RS(Generic[P1, P2, P3, P4, P5, R1], SQL_P5[P1, P2, P3, P4, P5]):
804
+ class SQL_R1_PX(SQL_PX[Unpack[PX]], Protocol[R1, Unpack[PX]]):
645
805
  @abstractmethod
646
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5) -> list[tuple[R1]]: ...
806
+ async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[tuple[R1]]: ...
647
807
  @abstractmethod
648
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5]]) -> list[tuple[R1]]: ...
808
+ async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[tuple[R1]]: ...
649
809
  @abstractmethod
650
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5) -> tuple[R1] | None: ...
810
+ async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> tuple[R1] | None: ...
651
811
  @abstractmethod
652
- async def fetchval(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5) -> R1: ...
812
+ async def fetchval(self, connection: Connection, *args: Unpack[PX]) -> R1: ...
653
813
 
654
814
 
655
- class SQL_P5_RX(Generic[P1, P2, P3, P4, P5, R1, R2, Unpack[RX]], SQL_P5[P1, P2, P3, P4, P5]):
815
+ class SQL_RX_PX(SQL_PX[Unpack[PX]], Protocol[RT, Unpack[PX]]):
656
816
  @abstractmethod
657
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5) -> list[tuple[R1, R2, Unpack[RX]]]: ...
817
+ async def fetch(self, connection: Connection, *args: Unpack[PX]) -> list[RT]: ...
658
818
  @abstractmethod
659
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
819
+ async def fetchmany(self, connection: Connection, args: Iterable[tuple[Unpack[PX]]]) -> list[RT]: ...
660
820
  @abstractmethod
661
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5) -> tuple[R1, R2, Unpack[RX]] | None: ...
821
+ async def fetchrow(self, connection: Connection, *args: Unpack[PX]) -> RT | None: ...
662
822
 
663
823
 
664
- class SQL_P6(Generic[P1, P2, P3, P4, P5, P6], _SQL):
665
- @abstractmethod
666
- async def execute(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6) -> None: ...
667
- @abstractmethod
668
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6]]) -> None: ...
824
+ ### END OF AUTO-GENERATED BLOCK FOR Protocol ###
669
825
 
670
826
 
671
- class SQL_P6_RS(Generic[P1, P2, P3, P4, P5, P6, R1], SQL_P6[P1, P2, P3, P4, P5, P6]):
672
- @abstractmethod
673
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6) -> list[tuple[R1]]: ...
674
- @abstractmethod
675
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6]]) -> list[tuple[R1]]: ...
676
- @abstractmethod
677
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6) -> tuple[R1] | None: ...
678
- @abstractmethod
679
- async def fetchval(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6) -> R1: ...
680
-
681
-
682
- class SQL_P6_RX(Generic[P1, P2, P3, P4, P5, P6, R1, R2, Unpack[RX]], SQL_P6[P1, P2, P3, P4, P5, P6]):
683
- @abstractmethod
684
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6) -> list[tuple[R1, R2, Unpack[RX]]]: ...
685
- @abstractmethod
686
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
687
- @abstractmethod
688
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6) -> tuple[R1, R2, Unpack[RX]] | None: ...
689
-
690
-
691
- class SQL_P7(Generic[P1, P2, P3, P4, P5, P6, P7], _SQL):
692
- @abstractmethod
693
- async def execute(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7) -> None: ...
694
- @abstractmethod
695
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6, P7]]) -> None: ...
696
-
697
-
698
- class SQL_P7_RS(Generic[P1, P2, P3, P4, P5, P6, P7, R1], SQL_P7[P1, P2, P3, P4, P5, P6, P7]):
699
- @abstractmethod
700
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7) -> list[tuple[R1]]: ...
701
- @abstractmethod
702
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6, P7]]) -> list[tuple[R1]]: ...
703
- @abstractmethod
704
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7) -> tuple[R1] | None: ...
705
- @abstractmethod
706
- async def fetchval(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7) -> R1: ...
827
+ class SQLFactory:
828
+ """
829
+ Creates type-safe SQL queries.
830
+ """
707
831
 
832
+ ### START OF AUTO-GENERATED BLOCK FOR sql ###
833
+ @overload
834
+ def sql(self, stmt: SQLExpression) -> SQL_P0: ...
835
+ @overload
836
+ def sql(self, stmt: SQLExpression, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
837
+ @overload
838
+ def sql(self, stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
839
+ @overload
840
+ def sql(self, stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
841
+ @overload
842
+ def sql(self, stmt: SQLExpression, *, arg: type[P1]) -> SQL_PX[P1]: ...
843
+ @overload
844
+ def sql(self, stmt: SQLExpression, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
845
+ @overload
846
+ def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
847
+ @overload
848
+ def sql(self, stmt: SQLExpression, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
849
+ @overload
850
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
851
+ @overload
852
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
853
+ @overload
854
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
855
+ @overload
856
+ def sql(self, stmt: SQLExpression, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
857
+
858
+ ### END OF AUTO-GENERATED BLOCK FOR sql ###
859
+
860
+ def sql(self, stmt: SQLExpression, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
861
+ """
862
+ Creates a SQL statement with associated type information.
708
863
 
709
- class SQL_P7_RX(Generic[P1, P2, P3, P4, P5, P6, P7, R1, R2, Unpack[RX]], SQL_P7[P1, P2, P3, P4, P5, P6, P7]):
710
- @abstractmethod
711
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7) -> list[tuple[R1, R2, Unpack[RX]]]: ...
712
- @abstractmethod
713
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6, P7]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
714
- @abstractmethod
715
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7) -> tuple[R1, R2, Unpack[RX]] | None: ...
864
+ :param stmt: SQL statement as a literal string or template.
865
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
866
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
867
+ :param arg: Type signature for a single input parameter (e.g. `int`).
868
+ :param result: Type signature for a single result column (e.g. `UUID`).
869
+ """
716
870
 
871
+ input_data_types, output_data_types = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
717
872
 
718
- class SQL_P8(Generic[P1, P2, P3, P4, P5, P6, P7, P8], _SQL):
719
- @abstractmethod
720
- async def execute(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7, arg8: P8) -> None: ...
721
- @abstractmethod
722
- async def executemany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6, P7, P8]]) -> None: ...
873
+ obj: _SQLObject
874
+ if sys.version_info >= (3, 14):
875
+ match stmt:
876
+ case Template():
877
+ obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
878
+ case str():
879
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
880
+ else:
881
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
882
+
883
+ return _SQLImpl(obj)
884
+
885
+ ### START OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
886
+ @overload
887
+ def unsafe_sql(self, stmt: str) -> SQL_P0: ...
888
+ @overload
889
+ def unsafe_sql(self, stmt: str, *, result: type[R1]) -> SQL_R1_P0[R1]: ...
890
+ @overload
891
+ def unsafe_sql(self, stmt: str, *, resultset: type[tuple[R1]]) -> SQL_R1_P0[R1]: ...
892
+ @overload
893
+ def unsafe_sql(self, stmt: str, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_P0[tuple[R1, R2, Unpack[RX]]]: ...
894
+ @overload
895
+ def unsafe_sql(self, stmt: str, *, arg: type[P1]) -> SQL_PX[P1]: ...
896
+ @overload
897
+ def unsafe_sql(self, stmt: str, *, arg: type[P1], result: type[R1]) -> SQL_R1_PX[R1, P1]: ...
898
+ @overload
899
+ def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1]: ...
900
+ @overload
901
+ def unsafe_sql(self, stmt: str, *, arg: type[P1], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1]: ...
902
+ @overload
903
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]]) -> SQL_PX[P1, Unpack[PX]]: ...
904
+ @overload
905
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], result: type[R1]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
906
+ @overload
907
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1]]) -> SQL_R1_PX[R1, P1, Unpack[PX]]: ...
908
+ @overload
909
+ def unsafe_sql(self, stmt: str, *, args: type[tuple[P1, Unpack[PX]]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_RX_PX[tuple[R1, R2, Unpack[RX]], P1, Unpack[PX]]: ...
910
+
911
+ ### END OF AUTO-GENERATED BLOCK FOR unsafe_sql ###
912
+
913
+ def unsafe_sql(self, stmt: str, *, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> _SQL:
914
+ """
915
+ Creates a SQL statement with associated type information from a string.
723
916
 
917
+ This offers an alternative to the function :func:`sql` when we want to prevent the type checker from enforcing
918
+ a string literal, e.g. when we want to embed a variable as the table name to dynamically create a SQL statement.
724
919
 
725
- class SQL_P8_RS(Generic[P1, P2, P3, P4, P5, P6, P7, P8, R1], SQL_P8[P1, P2, P3, P4, P5, P6, P7, P8]):
726
- @abstractmethod
727
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7, arg8: P8) -> list[tuple[R1]]: ...
728
- @abstractmethod
729
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6, P7, P8]]) -> list[tuple[R1]]: ...
730
- @abstractmethod
731
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7, arg8: P8) -> tuple[R1] | None: ...
732
- @abstractmethod
733
- async def fetchval(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7, arg8: P8) -> R1: ...
920
+ :param stmt: SQL statement as a string (or f-string).
921
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
922
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
923
+ :param arg: Type signature for a single input parameter (e.g. `int`).
924
+ :param result: Type signature for a single result column (e.g. `UUID`).
925
+ """
734
926
 
927
+ input_data_types, output_data_types = _sql_args_resultset(args=args, resultset=resultset, arg=arg, result=result)
928
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
929
+ return _SQLImpl(obj)
735
930
 
736
- class SQL_P8_RX(Generic[P1, P2, P3, P4, P5, P6, P7, P8, R1, R2, Unpack[RX]], SQL_P8[P1, P2, P3, P4, P5, P6, P7, P8]):
737
- @abstractmethod
738
- async def fetch(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7, arg8: P8) -> list[tuple[R1, R2, Unpack[RX]]]: ...
739
- @abstractmethod
740
- async def fetchmany(self, connection: Connection, args: Iterable[tuple[P1, P2, P3, P4, P5, P6, P7, P8]]) -> list[tuple[R1, R2, Unpack[RX]]]: ...
741
- @abstractmethod
742
- async def fetchrow(self, connection: Connection, arg1: P1, arg2: P2, arg3: P3, arg4: P4, arg5: P5, arg6: P6, arg7: P7, arg8: P8) -> tuple[R1, R2, Unpack[RX]] | None: ...
743
-
744
-
745
- @overload
746
- def sql(stmt: SQLExpression) -> SQL_P0: ...
747
- @overload
748
- def sql(stmt: SQLExpression, *, result: type[RS]) -> SQL_P0_RS[RS]: ...
749
- @overload
750
- def sql(stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_P0_RS[R1]: ...
751
- @overload
752
- def sql(stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P0_RX[R1, R2, Unpack[RX]]: ...
753
- @overload
754
- def sql(stmt: SQLExpression, *, arg: type[PS]) -> SQL_P1[PS]: ...
755
- @overload
756
- def sql(stmt: SQLExpression, *, args: type[tuple[P1]]) -> SQL_P1[P1]: ...
757
- @overload
758
- def sql(stmt: SQLExpression, *, arg: type[PS], result: type[RS]) -> SQL_P1_RS[PS, RS]: ...
759
- @overload
760
- def sql(stmt: SQLExpression, *, args: type[tuple[P1]], resultset: type[tuple[R1]]) -> SQL_P1_RS[P1, R1]: ...
761
- @overload
762
- def sql(stmt: SQLExpression, *, arg: type[PS], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[PS, R1, R2, Unpack[RX]]: ...
763
- @overload
764
- def sql(stmt: SQLExpression, *, args: type[tuple[P1]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[P1, R1, R2, Unpack[RX]]: ...
765
- @overload
766
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]]) -> SQL_P2[P1, P2]: ...
767
- @overload
768
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], result: type[RS]) -> SQL_P2_RS[P1, P2, RS]: ...
769
- @overload
770
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[tuple[R1]]) -> SQL_P2_RS[P1, P2, R1]: ...
771
- @overload
772
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P2_RX[P1, P2, R1, R2, Unpack[RX]]: ...
773
- @overload
774
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]]) -> SQL_P3[P1, P2, P3]: ...
775
- @overload
776
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], result: type[RS]) -> SQL_P3_RS[P1, P2, P3, RS]: ...
777
- @overload
778
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[tuple[R1]]) -> SQL_P3_RS[P1, P2, P3, R1]: ...
779
- @overload
780
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P3_RX[P1, P2, P3, R1, R2, Unpack[RX]]: ...
781
- @overload
782
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]]) -> SQL_P4[P1, P2, P3, P4]: ...
783
- @overload
784
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], result: type[RS]) -> SQL_P4_RS[P1, P2, P3, P4, RS]: ...
785
- @overload
786
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: type[tuple[R1]]) -> SQL_P4_RS[P1, P2, P3, P4, R1]: ...
787
- @overload
788
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P4_RX[P1, P2, P3, P4, R1, R2, Unpack[RX]]: ...
789
- @overload
790
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]]) -> SQL_P5[P1, P2, P3, P4, P5]: ...
791
- @overload
792
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], result: type[RS]) -> SQL_P5_RS[P1, P2, P3, P4, P5, RS]: ...
793
- @overload
794
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset: type[tuple[R1]]) -> SQL_P5_RS[P1, P2, P3, P4, P5, R1]: ...
795
- @overload
796
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P5_RX[P1, P2, P3, P4, P5, R1, R2, Unpack[RX]]: ...
797
- @overload
798
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]]) -> SQL_P6[P1, P2, P3, P4, P5, P6]: ...
799
- @overload
800
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], result: type[RS]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, RS]: ...
801
- @overload
802
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resultset: type[tuple[R1]]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, R1]: ...
803
- @overload
804
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P6_RX[P1, P2, P3, P4, P5, P6, R1, R2, Unpack[RX]]: ...
805
- @overload
806
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]]) -> SQL_P7[P1, P2, P3, P4, P5, P6, P7]: ...
807
- @overload
808
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], result: type[RS]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, RS]: ...
809
- @overload
810
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], resultset: type[tuple[R1]]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, R1]: ...
811
- @overload
812
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P7_RX[P1, P2, P3, P4, P5, P6, P7, R1, R2, Unpack[RX]]: ...
813
- @overload
814
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]]) -> SQL_P8[P1, P2, P3, P4, P5, P6, P7, P8]: ...
815
- @overload
816
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], result: type[RS]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, RS]: ...
817
- @overload
818
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], resultset: type[tuple[R1]]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, R1]: ...
819
- @overload
820
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P8_RX[P1, P2, P3, P4, P5, P6, P7, P8, R1, R2, Unpack[RX]]: ...
821
-
822
-
823
- ### END OF AUTO-GENERATED BLOCK ###
824
-
825
-
826
- def sql(
827
- stmt: SQLExpression,
828
- *,
829
- args: type[Any] | None = None,
830
- resultset: type[Any] | None = None,
831
- arg: type[Any] | None = None,
832
- result: type[Any] | None = None,
833
- ) -> _SQL:
834
- """
835
- Creates a SQL statement with associated type information.
836
931
 
837
- :param stmt: SQL statement as a literal string or template.
838
- :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
839
- :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
840
- :param arg: Type signature for a single input parameter (e.g. `int`).
841
- :param result: Type signature for a single result column (e.g. `UUID`).
842
- """
932
+ def _sql_args_resultset(*, args: type[Any] | None = None, resultset: type[Any] | None = None, arg: type[Any] | None = None, result: type[Any] | None = None) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
933
+ "Parses an argument/resultset signature into input/output types."
843
934
 
844
935
  if args is not None and arg is not None:
845
936
  raise TypeError("expected: either `args` or `arg`; got: both")
@@ -864,14 +955,10 @@ def sql(
864
955
  else:
865
956
  output_data_types = ()
866
957
 
867
- if sys.version_info >= (3, 14):
868
- obj: _SQLObject
869
- match stmt:
870
- case Template():
871
- obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
872
- case str():
873
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
874
- else:
875
- obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
958
+ return input_data_types, output_data_types
959
+
960
+
961
+ FACTORY: SQLFactory = SQLFactory()
876
962
 
877
- return _SQLImpl(obj)
963
+ sql = FACTORY.sql
964
+ unsafe_sql = FACTORY.unsafe_sql