asyncpg-typed 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
asyncpg_typed/__init__.py CHANGED
@@ -4,15 +4,15 @@ Type-safe queries for asyncpg.
4
4
  :see: https://github.com/hunyadi/asyncpg_typed
5
5
  """
6
6
 
7
- __version__ = "0.1.0"
7
+ __version__ = "0.1.1"
8
8
  __author__ = "Levente Hunyadi"
9
9
  __copyright__ = "Copyright 2025, Levente Hunyadi"
10
10
  __license__ = "MIT"
11
11
  __maintainer__ = "Levente Hunyadi"
12
12
  __status__ = "Production"
13
13
 
14
+ import enum
14
15
  import sys
15
- import typing
16
16
  from abc import abstractmethod
17
17
  from collections.abc import Iterable, Sequence
18
18
  from datetime import date, datetime, time
@@ -27,9 +27,9 @@ import asyncpg
27
27
  from asyncpg.prepared_stmt import PreparedStatement
28
28
 
29
29
  if sys.version_info < (3, 11):
30
- from typing_extensions import TypeVarTuple, Unpack
30
+ from typing_extensions import LiteralString, TypeVarTuple, Unpack
31
31
  else:
32
- from typing import TypeVarTuple, Unpack
32
+ from typing import LiteralString, TypeVarTuple, Unpack
33
33
 
34
34
  # list of supported data types
35
35
  DATA_TYPES: list[type[Any]] = [bool, int, float, Decimal, date, time, datetime, str, bytes, UUID]
@@ -38,9 +38,29 @@ DATA_TYPES: list[type[Any]] = [bool, int, float, Decimal, date, time, datetime,
38
38
  NUM_ARGS = 8
39
39
 
40
40
 
41
+ if sys.version_info >= (3, 11):
42
+
43
+ def is_enum_type(typ: object) -> bool:
44
+ """
45
+ `True` if the specified type is an enumeration type.
46
+ """
47
+
48
+ return isinstance(typ, enum.EnumType)
49
+
50
+ else:
51
+
52
+ def is_enum_type(typ: object) -> bool:
53
+ """
54
+ `True` if the specified type is an enumeration type.
55
+ """
56
+
57
+ # use an explicit isinstance(..., type) check to filter out special forms like generics
58
+ return isinstance(typ, type) and issubclass(typ, enum.Enum)
59
+
60
+
41
61
  def is_union_type(tp: Any) -> bool:
42
62
  """
43
- Returns `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
63
+ `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
44
64
  """
45
65
 
46
66
  origin = get_origin(tp)
@@ -49,7 +69,7 @@ def is_union_type(tp: Any) -> bool:
49
69
 
50
70
  def is_optional_type(tp: Any) -> bool:
51
71
  """
52
- Returns `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
72
+ `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
53
73
  """
54
74
 
55
75
  return is_union_type(tp) and any(a is type(None) for a in get_args(tp))
@@ -57,7 +77,7 @@ def is_optional_type(tp: Any) -> bool:
57
77
 
58
78
  def is_standard_type(tp: Any) -> bool:
59
79
  """
60
- Returns `True` if the type represents a built-in or a well-known standard type.
80
+ `True` if the type represents a built-in or a well-known standard type.
61
81
  """
62
82
 
63
83
  return tp.__module__ == "builtins" or tp.__module__ == UnionType.__module__
@@ -115,21 +135,20 @@ _name_to_type: dict[str, Any] = {
115
135
  }
116
136
 
117
137
 
118
- def check_data_type(name: str, data_type: type[Any]) -> bool:
138
+ def check_data_type(schema: str, name: str, data_type: type[Any]) -> bool:
119
139
  """
120
140
  Verifies if the Python target type can represent the PostgreSQL source type.
121
141
  """
122
142
 
123
- expected_type = _name_to_type.get(name)
124
- required_type = get_required_type(data_type)
125
-
126
- if expected_type is not None:
127
- return expected_type == required_type
128
- if is_standard_type(required_type):
129
- return False
143
+ if schema == "pg_catalog":
144
+ expected_type = _name_to_type.get(name)
145
+ return expected_type == data_type
146
+ else:
147
+ if is_standard_type(data_type):
148
+ return False
130
149
 
131
- # user-defined type registered with `conn.set_type_codec()`
132
- return True
150
+ # user-defined type registered with `conn.set_type_codec()`
151
+ return True
133
152
 
134
153
 
135
154
  class _SQLPlaceholder:
@@ -152,35 +171,28 @@ class _SQLObject:
152
171
  parameter_data_types: tuple[_SQLPlaceholder, ...]
153
172
  resultset_data_types: tuple[type[Any], ...]
154
173
  required: int
174
+ cast: int
155
175
 
156
176
  def __init__(
157
177
  self,
158
- *,
159
- args: type[Any] | None = None,
160
- resultset: type[Any] | None = None,
178
+ input_data_types: tuple[type[Any], ...],
179
+ output_data_types: tuple[type[Any], ...],
161
180
  ) -> None:
162
- if args is not None:
163
- if get_origin(args) is tuple:
164
- self.parameter_data_types = tuple(_SQLPlaceholder(ordinal, arg) for ordinal, arg in enumerate(get_args(args), start=1))
165
- else:
166
- self.parameter_data_types = (_SQLPlaceholder(1, args),)
167
- else:
168
- self.parameter_data_types = ()
169
-
170
- if resultset is not None:
171
- if get_origin(resultset) is tuple:
172
- self.resultset_data_types = get_args(resultset)
173
- else:
174
- self.resultset_data_types = (resultset,)
175
- else:
176
- self.resultset_data_types = ()
181
+ self.parameter_data_types = tuple(_SQLPlaceholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(input_data_types, start=1))
182
+ self.resultset_data_types = tuple(get_required_type(data_type) for data_type in output_data_types)
177
183
 
178
184
  # create a bit-field of required types (1: required; 0: optional)
179
185
  required = 0
180
- for index, data_type in enumerate(self.resultset_data_types):
186
+ for index, data_type in enumerate(output_data_types):
181
187
  required |= (not is_optional_type(data_type)) << index
182
188
  self.required = required
183
189
 
190
+ # create a bit-field of types that require cast/conversion (1: pass to __init__; 0: skip)
191
+ cast = 0
192
+ for index, data_type in enumerate(self.resultset_data_types):
193
+ cast |= is_enum_type(data_type) << index
194
+ self.cast = cast
195
+
184
196
  def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
185
197
  """
186
198
  Raises an error with the index of the first column value that is of a required type but has been assigned a value of `None`.
@@ -319,7 +331,7 @@ class _SQLObject:
319
331
  if sys.version_info >= (3, 14):
320
332
  from string.templatelib import Interpolation, Template # type: ignore[import-not-found]
321
333
 
322
- SQLExpression: TypeAlias = Template | str
334
+ SQLExpression: TypeAlias = Template | LiteralString
323
335
 
324
336
  class _SQLTemplate(_SQLObject):
325
337
  """
@@ -333,10 +345,10 @@ if sys.version_info >= (3, 14):
333
345
  self,
334
346
  template: Template,
335
347
  *,
336
- args: type[Any] | None = None,
337
- resultset: type[Any] | None = None,
348
+ args: tuple[type[Any], ...],
349
+ resultset: tuple[type[Any], ...],
338
350
  ) -> None:
339
- super().__init__(args=args, resultset=resultset)
351
+ super().__init__(args, resultset)
340
352
 
341
353
  for ip in template.interpolations:
342
354
  if ip.conversion is not None:
@@ -348,7 +360,7 @@ if sys.version_info >= (3, 14):
348
360
 
349
361
  self.strings = template.strings
350
362
 
351
- if args is not None:
363
+ if len(self.parameter_data_types) > 0:
352
364
 
353
365
  def _to_placeholder(ip: Interpolation) -> _SQLPlaceholder:
354
366
  ordinal = int(ip.value)
@@ -369,7 +381,7 @@ if sys.version_info >= (3, 14):
369
381
  return buf.getvalue()
370
382
 
371
383
  else:
372
- SQLExpression = str
384
+ SQLExpression = LiteralString
373
385
 
374
386
 
375
387
  class _SQLString(_SQLObject):
@@ -383,10 +395,10 @@ class _SQLString(_SQLObject):
383
395
  self,
384
396
  sql: str,
385
397
  *,
386
- args: type[Any] | None = None,
387
- resultset: type[Any] | None = None,
398
+ args: tuple[type[Any], ...],
399
+ resultset: tuple[type[Any], ...],
388
400
  ) -> None:
389
- super().__init__(args=args, resultset=resultset)
401
+ super().__init__(args, resultset)
390
402
  self.sql = sql
391
403
 
392
404
  def query(self) -> str:
@@ -422,7 +434,7 @@ class _SQLImpl(_SQL):
422
434
  stmt = await connection.prepare(self.sql.query())
423
435
 
424
436
  for attr, data_type in zip(stmt.get_attributes(), self.sql.resultset_data_types, strict=True):
425
- if not check_data_type(attr.type.name, data_type):
437
+ if not check_data_type(attr.type.schema, attr.type.name, data_type):
426
438
  raise TypeError(f"expected: {data_type} in column `{attr.name}`; got: `{attr.type.kind}` of `{attr.type.name}`")
427
439
 
428
440
  return stmt
@@ -434,40 +446,51 @@ class _SQLImpl(_SQL):
434
446
  stmt = await self._prepare(connection)
435
447
  await stmt.executemany(args)
436
448
 
449
+ def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
450
+ cast = self.sql.cast
451
+ if cast:
452
+ data_types = self.sql.resultset_data_types
453
+ resultset = [tuple((data_types[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row))) for row in rows]
454
+ else:
455
+ resultset = [tuple(value for value in row) for row in rows]
456
+ self.sql.check_rows(resultset)
457
+ return resultset
458
+
437
459
  async def fetch(self, connection: asyncpg.Connection, *args: Any) -> list[tuple[Any, ...]]:
438
460
  stmt = await self._prepare(connection)
439
461
  rows = await stmt.fetch(*args)
440
- resultset = [tuple(value for value in row) for row in rows]
441
- self.sql.check_rows(resultset)
442
- return resultset
462
+ return self._cast_fetch(rows)
443
463
 
444
464
  async def fetchmany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
445
465
  stmt = await self._prepare(connection)
446
- rows = await stmt.fetchmany(args) # type: ignore[arg-type, call-arg] # pyright: ignore[reportCallIssue]
447
- rows = typing.cast(list[asyncpg.Record], rows)
448
- resultset = [tuple(value for value in row) for row in rows]
449
- self.sql.check_rows(resultset)
450
- return resultset
466
+ rows = await stmt.fetchmany(args)
467
+ return self._cast_fetch(rows)
451
468
 
452
469
  async def fetchrow(self, connection: asyncpg.Connection, *args: Any) -> tuple[Any, ...] | None:
453
470
  stmt = await self._prepare(connection)
454
471
  row = await stmt.fetchrow(*args)
455
472
  if row is None:
456
473
  return None
457
- resultset = tuple(value for value in row)
474
+ cast = self.sql.cast
475
+ if cast:
476
+ data_types = self.sql.resultset_data_types
477
+ resultset = tuple((data_types[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row)))
478
+ else:
479
+ resultset = tuple(value for value in row)
458
480
  self.sql.check_row(resultset)
459
481
  return resultset
460
482
 
461
483
  async def fetchval(self, connection: asyncpg.Connection, *args: Any) -> Any:
462
484
  stmt = await self._prepare(connection)
463
485
  value = await stmt.fetchval(*args)
464
- self.sql.check_value(value)
465
- return value
486
+ result = self.sql.resultset_data_types[0](value) if value is not None and self.sql.cast else value
487
+ self.sql.check_value(result)
488
+ return result
466
489
 
467
490
 
468
491
  ### START OF AUTO-GENERATED BLOCK ###
469
492
 
470
- PS = TypeVar("PS", bool, bool | None, int, int | None, float, float | None, Decimal, Decimal | None, date, date | None, time, time | None, datetime, datetime | None, str, str | None, bytes, bytes | None, UUID, UUID | None)
493
+ PS = TypeVar("PS")
471
494
  P1 = TypeVar("P1")
472
495
  P2 = TypeVar("P2")
473
496
  P3 = TypeVar("P3")
@@ -476,7 +499,7 @@ P5 = TypeVar("P5")
476
499
  P6 = TypeVar("P6")
477
500
  P7 = TypeVar("P7")
478
501
  P8 = TypeVar("P8")
479
- RS = TypeVar("RS", bool, bool | None, int, int | None, float, float | None, Decimal, Decimal | None, date, date | None, time, time | None, datetime, datetime | None, str, str | None, bytes, bytes | None, UUID, UUID | None)
502
+ RS = TypeVar("RS")
480
503
  R1 = TypeVar("R1")
481
504
  R2 = TypeVar("R2")
482
505
  RX = TypeVarTuple("RX")
@@ -722,27 +745,27 @@ class SQL_P8_RX(Generic[P1, P2, P3, P4, P5, P6, P7, P8, R1, R2, Unpack[RX]], SQL
722
745
  @overload
723
746
  def sql(stmt: SQLExpression) -> SQL_P0: ...
724
747
  @overload
725
- def sql(stmt: SQLExpression, *, resultset: type[RS]) -> SQL_P0_RS[RS]: ...
748
+ def sql(stmt: SQLExpression, *, result: type[RS]) -> SQL_P0_RS[RS]: ...
726
749
  @overload
727
750
  def sql(stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_P0_RS[R1]: ...
728
751
  @overload
729
752
  def sql(stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P0_RX[R1, R2, Unpack[RX]]: ...
730
753
  @overload
731
- def sql(stmt: SQLExpression, *, args: type[PS]) -> SQL_P1[PS]: ...
754
+ def sql(stmt: SQLExpression, *, arg: type[PS]) -> SQL_P1[PS]: ...
732
755
  @overload
733
756
  def sql(stmt: SQLExpression, *, args: type[tuple[P1]]) -> SQL_P1[P1]: ...
734
757
  @overload
735
- def sql(stmt: SQLExpression, *, args: type[PS], resultset: type[RS]) -> SQL_P1_RS[PS, RS]: ...
758
+ def sql(stmt: SQLExpression, *, arg: type[PS], result: type[RS]) -> SQL_P1_RS[PS, RS]: ...
736
759
  @overload
737
760
  def sql(stmt: SQLExpression, *, args: type[tuple[P1]], resultset: type[tuple[R1]]) -> SQL_P1_RS[P1, R1]: ...
738
761
  @overload
739
- def sql(stmt: SQLExpression, *, args: type[PS], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[PS, R1, R2, Unpack[RX]]: ...
762
+ def sql(stmt: SQLExpression, *, arg: type[PS], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[PS, R1, R2, Unpack[RX]]: ...
740
763
  @overload
741
764
  def sql(stmt: SQLExpression, *, args: type[tuple[P1]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[P1, R1, R2, Unpack[RX]]: ...
742
765
  @overload
743
766
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]]) -> SQL_P2[P1, P2]: ...
744
767
  @overload
745
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[RS]) -> SQL_P2_RS[P1, P2, RS]: ...
768
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], result: type[RS]) -> SQL_P2_RS[P1, P2, RS]: ...
746
769
  @overload
747
770
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[tuple[R1]]) -> SQL_P2_RS[P1, P2, R1]: ...
748
771
  @overload
@@ -750,7 +773,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[tuple
750
773
  @overload
751
774
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]]) -> SQL_P3[P1, P2, P3]: ...
752
775
  @overload
753
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[RS]) -> SQL_P3_RS[P1, P2, P3, RS]: ...
776
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], result: type[RS]) -> SQL_P3_RS[P1, P2, P3, RS]: ...
754
777
  @overload
755
778
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[tuple[R1]]) -> SQL_P3_RS[P1, P2, P3, R1]: ...
756
779
  @overload
@@ -758,7 +781,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[t
758
781
  @overload
759
782
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]]) -> SQL_P4[P1, P2, P3, P4]: ...
760
783
  @overload
761
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: type[RS]) -> SQL_P4_RS[P1, P2, P3, P4, RS]: ...
784
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], result: type[RS]) -> SQL_P4_RS[P1, P2, P3, P4, RS]: ...
762
785
  @overload
763
786
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: type[tuple[R1]]) -> SQL_P4_RS[P1, P2, P3, P4, R1]: ...
764
787
  @overload
@@ -766,7 +789,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: ty
766
789
  @overload
767
790
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]]) -> SQL_P5[P1, P2, P3, P4, P5]: ...
768
791
  @overload
769
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset: type[RS]) -> SQL_P5_RS[P1, P2, P3, P4, P5, RS]: ...
792
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], result: type[RS]) -> SQL_P5_RS[P1, P2, P3, P4, P5, RS]: ...
770
793
  @overload
771
794
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset: type[tuple[R1]]) -> SQL_P5_RS[P1, P2, P3, P4, P5, R1]: ...
772
795
  @overload
@@ -774,7 +797,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset
774
797
  @overload
775
798
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]]) -> SQL_P6[P1, P2, P3, P4, P5, P6]: ...
776
799
  @overload
777
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resultset: type[RS]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, RS]: ...
800
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], result: type[RS]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, RS]: ...
778
801
  @overload
779
802
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resultset: type[tuple[R1]]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, R1]: ...
780
803
  @overload
@@ -782,7 +805,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resul
782
805
  @overload
783
806
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]]) -> SQL_P7[P1, P2, P3, P4, P5, P6, P7]: ...
784
807
  @overload
785
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], resultset: type[RS]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, RS]: ...
808
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], result: type[RS]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, RS]: ...
786
809
  @overload
787
810
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], resultset: type[tuple[R1]]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, R1]: ...
788
811
  @overload
@@ -790,7 +813,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], r
790
813
  @overload
791
814
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]]) -> SQL_P8[P1, P2, P3, P4, P5, P6, P7, P8]: ...
792
815
  @overload
793
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], resultset: type[RS]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, RS]: ...
816
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], result: type[RS]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, RS]: ...
794
817
  @overload
795
818
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], resultset: type[tuple[R1]]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, R1]: ...
796
819
  @overload
@@ -805,23 +828,50 @@ def sql(
805
828
  *,
806
829
  args: type[Any] | None = None,
807
830
  resultset: type[Any] | None = None,
831
+ arg: type[Any] | None = None,
832
+ result: type[Any] | None = None,
808
833
  ) -> _SQL:
809
834
  """
810
835
  Creates a SQL statement with associated type information.
811
836
 
812
- :param stmt: SQL statement as a string or template.
813
- :param args: Type signature for input parameters. Use the type for a single parameter (e.g. `int`) or `tuple[...]` for multiple parameters.
814
- :param resultset: Type signature for output data. Use the type for a single parameter (e.g. `int`) or `tuple[...]` for multiple parameters.
837
+ :param stmt: SQL statement as a literal string or template.
838
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
839
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
840
+ :param arg: Type signature for a single input parameter (e.g. `int`).
841
+ :param result: Type signature for a single result column (e.g. `UUID`).
815
842
  """
816
843
 
844
+ if args is not None and arg is not None:
845
+ raise TypeError("expected: either `args` or `arg`; got: both")
846
+ if resultset is not None and result is not None:
847
+ raise TypeError("expected: either `resultset` or `result`; got: both")
848
+
849
+ if args is not None:
850
+ if get_origin(args) is not tuple:
851
+ raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {type(args)}")
852
+ input_data_types = get_args(args)
853
+ elif arg is not None:
854
+ input_data_types = (arg,)
855
+ else:
856
+ input_data_types = ()
857
+
858
+ if resultset is not None:
859
+ if get_origin(resultset) is not tuple:
860
+ raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {type(resultset)}")
861
+ output_data_types = get_args(resultset)
862
+ elif result is not None:
863
+ output_data_types = (result,)
864
+ else:
865
+ output_data_types = ()
866
+
817
867
  if sys.version_info >= (3, 14):
818
868
  obj: _SQLObject
819
869
  match stmt:
820
870
  case Template():
821
- obj = _SQLTemplate(stmt, args=args, resultset=resultset)
871
+ obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
822
872
  case str():
823
- obj = _SQLString(stmt, args=args, resultset=resultset)
873
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
824
874
  else:
825
- obj = _SQLString(stmt, args=args, resultset=resultset)
875
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
826
876
 
827
877
  return _SQLImpl(obj)
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: asyncpg_typed
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: Type-safe queries for asyncpg
5
5
  Author-email: Levente Hunyadi <hunyadi@gmail.com>
6
6
  Maintainer-email: Levente Hunyadi <hunyadi@gmail.com>
7
7
  License-Expression: MIT
8
8
  Project-URL: Homepage, https://github.com/hunyadi/asyncpg_typed
9
9
  Project-URL: Source, https://github.com/hunyadi/asyncpg_typed
10
- Keywords: asyncpg,typed,database-client
10
+ Keywords: asyncpg,typed,database-client,postgres
11
11
  Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Intended Audience :: Developers
13
13
  Classifier: Operating System :: OS Independent
@@ -85,19 +85,21 @@ Instantiate a SQL object with the `sql` function:
85
85
 
86
86
  ```python
87
87
  def sql(
88
- stmt: str | string.templatelib.Template,
88
+ stmt: LiteralString | string.templatelib.Template,
89
89
  *,
90
- args: None | type[P1] | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
91
- resultset: None | type[R1] | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None
90
+ args: None | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
91
+ resultset: None | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None,
92
+ arg: None | type[P] = None,
93
+ result: None | type[R] = None,
92
94
  ) -> _SQL: ...
93
95
  ```
94
96
 
95
- The parameter `stmt` represents a SQL expression, either as a string (including an *f-string*) or a template (i.e. a *t-string*).
97
+ The parameter `stmt` represents a SQL expression, either as a literal string or a template (i.e. a *t-string*).
96
98
 
97
99
  If the expression is a string, it can have PostgreSQL parameter placeholders such as `$1`, `$2` or `$3`:
98
100
 
99
101
  ```python
100
- f"INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
102
+ "INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
101
103
  ```
102
104
 
103
105
  If the expression is a *t-string*, it can have replacement fields that evaluate to integers:
@@ -106,11 +108,10 @@ If the expression is a *t-string*, it can have replacement fields that evaluate
106
108
  t"INSERT INTO table_name (col_1, col_2, col_3) VALUES ({1}, {2}, {3});"
107
109
  ```
108
110
 
109
- The parameters `args` and `resultset` take a series type `P` or `R`, which may be any of the following:
111
+ The parameters `args` and `resultset` take a `tuple` of several types `Px` or `Rx`, each of which may be any of the following:
110
112
 
111
113
  * (required) simple type
112
114
  * optional simple type (`T | None`)
113
- * `tuple` of several (required or optional) simple types.
114
115
 
115
116
  Simple types include:
116
117
 
@@ -124,6 +125,7 @@ Simple types include:
124
125
  * `str`
125
126
  * `bytes`
126
127
  * `uuid.UUID`
128
+ * a user-defined class that derives from `StrEnum`
127
129
 
128
130
  Types are grouped together with `tuple`:
129
131
 
@@ -131,7 +133,7 @@ Types are grouped together with `tuple`:
131
133
  tuple[bool, int, str | None]
132
134
  ```
133
135
 
134
- Passing a simple type directly (e.g. `type[T]`) is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`).
136
+ The parameters `arg` and `result` take a single type `P` or `R`. Passing a simple type (e.g. `type[T]`) directly via `arg` and `result` is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`) via `args` and `resultset`.
135
137
 
136
138
  The number of types in `args` must correspond to the number of query parameters. (This is validated on calling `sql(...)` for the *t-string* syntax.) The number of types in `resultset` must correspond to the number of columns returned by the query.
137
139
 
@@ -159,6 +161,7 @@ Both `args` and `resultset` types must be compatible with their corresponding Po
159
161
  | `json` | `str` |
160
162
  | `jsonb` | `str` |
161
163
  | `uuid` | `UUID` |
164
+ | enumeration | `E: StrEnum` |
162
165
 
163
166
  ### Using a SQL object
164
167
 
@@ -0,0 +1,8 @@
1
+ asyncpg_typed/__init__.py,sha256=2c4xyDjjR-yVrIlcgL-rqLJlWB7_JJR76eRxPDTqAPY,38154
2
+ asyncpg_typed/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ asyncpg_typed-0.1.1.dist-info/licenses/LICENSE,sha256=rx4jD36wX8TyLZaR2HEOJ6TphFPjKUqoCSSYWzwWNRk,1093
4
+ asyncpg_typed-0.1.1.dist-info/METADATA,sha256=6RqDzYtI9FnIbKAjHHQdhnOQhBfM3pK1IHUOYXNf9yU,8652
5
+ asyncpg_typed-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ asyncpg_typed-0.1.1.dist-info/top_level.txt,sha256=T0X1nWnXRTi5a5oTErGy572ORDbM9UV9wfhRXWLsaoY,14
7
+ asyncpg_typed-0.1.1.dist-info/zip-safe,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
8
+ asyncpg_typed-0.1.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- asyncpg_typed/__init__.py,sha256=6F8tV2H1ayXFptYmsXLEi3puqKT-U904qamONeCJXUA,36489
2
- asyncpg_typed/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- asyncpg_typed-0.1.0.dist-info/licenses/LICENSE,sha256=rx4jD36wX8TyLZaR2HEOJ6TphFPjKUqoCSSYWzwWNRk,1093
4
- asyncpg_typed-0.1.0.dist-info/METADATA,sha256=ti6ld6HyUOodNUCmbNru0xiUu5mNHM_Z2TiDvQm4CNA,8429
5
- asyncpg_typed-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- asyncpg_typed-0.1.0.dist-info/top_level.txt,sha256=T0X1nWnXRTi5a5oTErGy572ORDbM9UV9wfhRXWLsaoY,14
7
- asyncpg_typed-0.1.0.dist-info/zip-safe,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
8
- asyncpg_typed-0.1.0.dist-info/RECORD,,