asyncpg-typed 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. asyncpg_typed-0.1.1/MANIFEST.in +1 -0
  2. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/PKG-INFO +13 -10
  3. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/README.md +11 -8
  4. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed/__init__.py +126 -76
  5. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed.egg-info/PKG-INFO +13 -10
  6. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed.egg-info/SOURCES.txt +8 -5
  7. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/pyproject.toml +1 -1
  8. asyncpg_typed-0.1.1/tests/__init__.py +0 -0
  9. asyncpg_typed-0.1.1/tests/connection.py +19 -0
  10. {asyncpg_typed-0.1.0/test → asyncpg_typed-0.1.1/tests}/test_code.py +5 -6
  11. {asyncpg_typed-0.1.0/test → asyncpg_typed-0.1.1/tests}/test_data.py +93 -22
  12. {asyncpg_typed-0.1.0/test → asyncpg_typed-0.1.1/tests}/test_template.py +2 -13
  13. {asyncpg_typed-0.1.0/test → asyncpg_typed-0.1.1/tests}/test_vector.py +2 -12
  14. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/LICENSE +0 -0
  15. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed/py.typed +0 -0
  16. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed.egg-info/dependency_links.txt +0 -0
  17. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed.egg-info/requires.txt +0 -0
  18. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed.egg-info/top_level.txt +0 -0
  19. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/asyncpg_typed.egg-info/zip-safe +0 -0
  20. {asyncpg_typed-0.1.0 → asyncpg_typed-0.1.1}/setup.cfg +0 -0
  21. {asyncpg_typed-0.1.0/test → asyncpg_typed-0.1.1/tests}/test_type.py +0 -0
@@ -0,0 +1 @@
1
+ recursive-include tests *.py
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: asyncpg_typed
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: Type-safe queries for asyncpg
5
5
  Author-email: Levente Hunyadi <hunyadi@gmail.com>
6
6
  Maintainer-email: Levente Hunyadi <hunyadi@gmail.com>
7
7
  License-Expression: MIT
8
8
  Project-URL: Homepage, https://github.com/hunyadi/asyncpg_typed
9
9
  Project-URL: Source, https://github.com/hunyadi/asyncpg_typed
10
- Keywords: asyncpg,typed,database-client
10
+ Keywords: asyncpg,typed,database-client,postgres
11
11
  Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Intended Audience :: Developers
13
13
  Classifier: Operating System :: OS Independent
@@ -85,19 +85,21 @@ Instantiate a SQL object with the `sql` function:
85
85
 
86
86
  ```python
87
87
  def sql(
88
- stmt: str | string.templatelib.Template,
88
+ stmt: LiteralString | string.templatelib.Template,
89
89
  *,
90
- args: None | type[P1] | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
91
- resultset: None | type[R1] | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None
90
+ args: None | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
91
+ resultset: None | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None,
92
+ arg: None | type[P] = None,
93
+ result: None | type[R] = None,
92
94
  ) -> _SQL: ...
93
95
  ```
94
96
 
95
- The parameter `stmt` represents a SQL expression, either as a string (including an *f-string*) or a template (i.e. a *t-string*).
97
+ The parameter `stmt` represents a SQL expression, either as a literal string or a template (i.e. a *t-string*).
96
98
 
97
99
  If the expression is a string, it can have PostgreSQL parameter placeholders such as `$1`, `$2` or `$3`:
98
100
 
99
101
  ```python
100
- f"INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
102
+ "INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
101
103
  ```
102
104
 
103
105
  If the expression is a *t-string*, it can have replacement fields that evaluate to integers:
@@ -106,11 +108,10 @@ If the expression is a *t-string*, it can have replacement fields that evaluate
106
108
  t"INSERT INTO table_name (col_1, col_2, col_3) VALUES ({1}, {2}, {3});"
107
109
  ```
108
110
 
109
- The parameters `args` and `resultset` take a series type `P` or `R`, which may be any of the following:
111
+ The parameters `args` and `resultset` take a `tuple` of several types `Px` or `Rx`, each of which may be any of the following:
110
112
 
111
113
  * (required) simple type
112
114
  * optional simple type (`T | None`)
113
- * `tuple` of several (required or optional) simple types.
114
115
 
115
116
  Simple types include:
116
117
 
@@ -124,6 +125,7 @@ Simple types include:
124
125
  * `str`
125
126
  * `bytes`
126
127
  * `uuid.UUID`
128
+ * a user-defined class that derives from `StrEnum`
127
129
 
128
130
  Types are grouped together with `tuple`:
129
131
 
@@ -131,7 +133,7 @@ Types are grouped together with `tuple`:
131
133
  tuple[bool, int, str | None]
132
134
  ```
133
135
 
134
- Passing a simple type directly (e.g. `type[T]`) is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`).
136
+ The parameters `arg` and `result` take a single type `P` or `R`. Passing a simple type (e.g. `type[T]`) directly via `arg` and `result` is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`) via `args` and `resultset`.
135
137
 
136
138
  The number of types in `args` must correspond to the number of query parameters. (This is validated on calling `sql(...)` for the *t-string* syntax.) The number of types in `resultset` must correspond to the number of columns returned by the query.
137
139
 
@@ -159,6 +161,7 @@ Both `args` and `resultset` types must be compatible with their corresponding Po
159
161
  | `json` | `str` |
160
162
  | `jsonb` | `str` |
161
163
  | `uuid` | `UUID` |
164
+ | enumeration | `E: StrEnum` |
162
165
 
163
166
  ### Using a SQL object
164
167
 
@@ -47,19 +47,21 @@ Instantiate a SQL object with the `sql` function:
47
47
 
48
48
  ```python
49
49
  def sql(
50
- stmt: str | string.templatelib.Template,
50
+ stmt: LiteralString | string.templatelib.Template,
51
51
  *,
52
- args: None | type[P1] | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
53
- resultset: None | type[R1] | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None
52
+ args: None | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
53
+ resultset: None | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None,
54
+ arg: None | type[P] = None,
55
+ result: None | type[R] = None,
54
56
  ) -> _SQL: ...
55
57
  ```
56
58
 
57
- The parameter `stmt` represents a SQL expression, either as a string (including an *f-string*) or a template (i.e. a *t-string*).
59
+ The parameter `stmt` represents a SQL expression, either as a literal string or a template (i.e. a *t-string*).
58
60
 
59
61
  If the expression is a string, it can have PostgreSQL parameter placeholders such as `$1`, `$2` or `$3`:
60
62
 
61
63
  ```python
62
- f"INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
64
+ "INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
63
65
  ```
64
66
 
65
67
  If the expression is a *t-string*, it can have replacement fields that evaluate to integers:
@@ -68,11 +70,10 @@ If the expression is a *t-string*, it can have replacement fields that evaluate
68
70
  t"INSERT INTO table_name (col_1, col_2, col_3) VALUES ({1}, {2}, {3});"
69
71
  ```
70
72
 
71
- The parameters `args` and `resultset` take a series type `P` or `R`, which may be any of the following:
73
+ The parameters `args` and `resultset` take a `tuple` of several types `Px` or `Rx`, each of which may be any of the following:
72
74
 
73
75
  * (required) simple type
74
76
  * optional simple type (`T | None`)
75
- * `tuple` of several (required or optional) simple types.
76
77
 
77
78
  Simple types include:
78
79
 
@@ -86,6 +87,7 @@ Simple types include:
86
87
  * `str`
87
88
  * `bytes`
88
89
  * `uuid.UUID`
90
+ * a user-defined class that derives from `StrEnum`
89
91
 
90
92
  Types are grouped together with `tuple`:
91
93
 
@@ -93,7 +95,7 @@ Types are grouped together with `tuple`:
93
95
  tuple[bool, int, str | None]
94
96
  ```
95
97
 
96
- Passing a simple type directly (e.g. `type[T]`) is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`).
98
+ The parameters `arg` and `result` take a single type `P` or `R`. Passing a simple type (e.g. `type[T]`) directly via `arg` and `result` is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`) via `args` and `resultset`.
97
99
 
98
100
  The number of types in `args` must correspond to the number of query parameters. (This is validated on calling `sql(...)` for the *t-string* syntax.) The number of types in `resultset` must correspond to the number of columns returned by the query.
99
101
 
@@ -121,6 +123,7 @@ Both `args` and `resultset` types must be compatible with their corresponding Po
121
123
  | `json` | `str` |
122
124
  | `jsonb` | `str` |
123
125
  | `uuid` | `UUID` |
126
+ | enumeration | `E: StrEnum` |
124
127
 
125
128
  ### Using a SQL object
126
129
 
@@ -4,15 +4,15 @@ Type-safe queries for asyncpg.
4
4
  :see: https://github.com/hunyadi/asyncpg_typed
5
5
  """
6
6
 
7
- __version__ = "0.1.0"
7
+ __version__ = "0.1.1"
8
8
  __author__ = "Levente Hunyadi"
9
9
  __copyright__ = "Copyright 2025, Levente Hunyadi"
10
10
  __license__ = "MIT"
11
11
  __maintainer__ = "Levente Hunyadi"
12
12
  __status__ = "Production"
13
13
 
14
+ import enum
14
15
  import sys
15
- import typing
16
16
  from abc import abstractmethod
17
17
  from collections.abc import Iterable, Sequence
18
18
  from datetime import date, datetime, time
@@ -27,9 +27,9 @@ import asyncpg
27
27
  from asyncpg.prepared_stmt import PreparedStatement
28
28
 
29
29
  if sys.version_info < (3, 11):
30
- from typing_extensions import TypeVarTuple, Unpack
30
+ from typing_extensions import LiteralString, TypeVarTuple, Unpack
31
31
  else:
32
- from typing import TypeVarTuple, Unpack
32
+ from typing import LiteralString, TypeVarTuple, Unpack
33
33
 
34
34
  # list of supported data types
35
35
  DATA_TYPES: list[type[Any]] = [bool, int, float, Decimal, date, time, datetime, str, bytes, UUID]
@@ -38,9 +38,29 @@ DATA_TYPES: list[type[Any]] = [bool, int, float, Decimal, date, time, datetime,
38
38
  NUM_ARGS = 8
39
39
 
40
40
 
41
+ if sys.version_info >= (3, 11):
42
+
43
+ def is_enum_type(typ: object) -> bool:
44
+ """
45
+ `True` if the specified type is an enumeration type.
46
+ """
47
+
48
+ return isinstance(typ, enum.EnumType)
49
+
50
+ else:
51
+
52
+ def is_enum_type(typ: object) -> bool:
53
+ """
54
+ `True` if the specified type is an enumeration type.
55
+ """
56
+
57
+ # use an explicit isinstance(..., type) check to filter out special forms like generics
58
+ return isinstance(typ, type) and issubclass(typ, enum.Enum)
59
+
60
+
41
61
  def is_union_type(tp: Any) -> bool:
42
62
  """
43
- Returns `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
63
+ `True` if `tp` is a union type such as `A | B` or `Union[A, B]`.
44
64
  """
45
65
 
46
66
  origin = get_origin(tp)
@@ -49,7 +69,7 @@ def is_union_type(tp: Any) -> bool:
49
69
 
50
70
  def is_optional_type(tp: Any) -> bool:
51
71
  """
52
- Returns `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
72
+ `True` if `tp` is an optional type such as `T | None`, `Optional[T]` or `Union[T, None]`.
53
73
  """
54
74
 
55
75
  return is_union_type(tp) and any(a is type(None) for a in get_args(tp))
@@ -57,7 +77,7 @@ def is_optional_type(tp: Any) -> bool:
57
77
 
58
78
  def is_standard_type(tp: Any) -> bool:
59
79
  """
60
- Returns `True` if the type represents a built-in or a well-known standard type.
80
+ `True` if the type represents a built-in or a well-known standard type.
61
81
  """
62
82
 
63
83
  return tp.__module__ == "builtins" or tp.__module__ == UnionType.__module__
@@ -115,21 +135,20 @@ _name_to_type: dict[str, Any] = {
115
135
  }
116
136
 
117
137
 
118
- def check_data_type(name: str, data_type: type[Any]) -> bool:
138
+ def check_data_type(schema: str, name: str, data_type: type[Any]) -> bool:
119
139
  """
120
140
  Verifies if the Python target type can represent the PostgreSQL source type.
121
141
  """
122
142
 
123
- expected_type = _name_to_type.get(name)
124
- required_type = get_required_type(data_type)
125
-
126
- if expected_type is not None:
127
- return expected_type == required_type
128
- if is_standard_type(required_type):
129
- return False
143
+ if schema == "pg_catalog":
144
+ expected_type = _name_to_type.get(name)
145
+ return expected_type == data_type
146
+ else:
147
+ if is_standard_type(data_type):
148
+ return False
130
149
 
131
- # user-defined type registered with `conn.set_type_codec()`
132
- return True
150
+ # user-defined type registered with `conn.set_type_codec()`
151
+ return True
133
152
 
134
153
 
135
154
  class _SQLPlaceholder:
@@ -152,35 +171,28 @@ class _SQLObject:
152
171
  parameter_data_types: tuple[_SQLPlaceholder, ...]
153
172
  resultset_data_types: tuple[type[Any], ...]
154
173
  required: int
174
+ cast: int
155
175
 
156
176
  def __init__(
157
177
  self,
158
- *,
159
- args: type[Any] | None = None,
160
- resultset: type[Any] | None = None,
178
+ input_data_types: tuple[type[Any], ...],
179
+ output_data_types: tuple[type[Any], ...],
161
180
  ) -> None:
162
- if args is not None:
163
- if get_origin(args) is tuple:
164
- self.parameter_data_types = tuple(_SQLPlaceholder(ordinal, arg) for ordinal, arg in enumerate(get_args(args), start=1))
165
- else:
166
- self.parameter_data_types = (_SQLPlaceholder(1, args),)
167
- else:
168
- self.parameter_data_types = ()
169
-
170
- if resultset is not None:
171
- if get_origin(resultset) is tuple:
172
- self.resultset_data_types = get_args(resultset)
173
- else:
174
- self.resultset_data_types = (resultset,)
175
- else:
176
- self.resultset_data_types = ()
181
+ self.parameter_data_types = tuple(_SQLPlaceholder(ordinal, get_required_type(arg)) for ordinal, arg in enumerate(input_data_types, start=1))
182
+ self.resultset_data_types = tuple(get_required_type(data_type) for data_type in output_data_types)
177
183
 
178
184
  # create a bit-field of required types (1: required; 0: optional)
179
185
  required = 0
180
- for index, data_type in enumerate(self.resultset_data_types):
186
+ for index, data_type in enumerate(output_data_types):
181
187
  required |= (not is_optional_type(data_type)) << index
182
188
  self.required = required
183
189
 
190
+ # create a bit-field of types that require cast/conversion (1: pass to __init__; 0: skip)
191
+ cast = 0
192
+ for index, data_type in enumerate(self.resultset_data_types):
193
+ cast |= is_enum_type(data_type) << index
194
+ self.cast = cast
195
+
184
196
  def _raise_required_is_none(self, row: tuple[Any, ...], row_index: int | None = None) -> None:
185
197
  """
186
198
  Raises an error with the index of the first column value that is of a required type but has been assigned a value of `None`.
@@ -319,7 +331,7 @@ class _SQLObject:
319
331
  if sys.version_info >= (3, 14):
320
332
  from string.templatelib import Interpolation, Template # type: ignore[import-not-found]
321
333
 
322
- SQLExpression: TypeAlias = Template | str
334
+ SQLExpression: TypeAlias = Template | LiteralString
323
335
 
324
336
  class _SQLTemplate(_SQLObject):
325
337
  """
@@ -333,10 +345,10 @@ if sys.version_info >= (3, 14):
333
345
  self,
334
346
  template: Template,
335
347
  *,
336
- args: type[Any] | None = None,
337
- resultset: type[Any] | None = None,
348
+ args: tuple[type[Any], ...],
349
+ resultset: tuple[type[Any], ...],
338
350
  ) -> None:
339
- super().__init__(args=args, resultset=resultset)
351
+ super().__init__(args, resultset)
340
352
 
341
353
  for ip in template.interpolations:
342
354
  if ip.conversion is not None:
@@ -348,7 +360,7 @@ if sys.version_info >= (3, 14):
348
360
 
349
361
  self.strings = template.strings
350
362
 
351
- if args is not None:
363
+ if len(self.parameter_data_types) > 0:
352
364
 
353
365
  def _to_placeholder(ip: Interpolation) -> _SQLPlaceholder:
354
366
  ordinal = int(ip.value)
@@ -369,7 +381,7 @@ if sys.version_info >= (3, 14):
369
381
  return buf.getvalue()
370
382
 
371
383
  else:
372
- SQLExpression = str
384
+ SQLExpression = LiteralString
373
385
 
374
386
 
375
387
  class _SQLString(_SQLObject):
@@ -383,10 +395,10 @@ class _SQLString(_SQLObject):
383
395
  self,
384
396
  sql: str,
385
397
  *,
386
- args: type[Any] | None = None,
387
- resultset: type[Any] | None = None,
398
+ args: tuple[type[Any], ...],
399
+ resultset: tuple[type[Any], ...],
388
400
  ) -> None:
389
- super().__init__(args=args, resultset=resultset)
401
+ super().__init__(args, resultset)
390
402
  self.sql = sql
391
403
 
392
404
  def query(self) -> str:
@@ -422,7 +434,7 @@ class _SQLImpl(_SQL):
422
434
  stmt = await connection.prepare(self.sql.query())
423
435
 
424
436
  for attr, data_type in zip(stmt.get_attributes(), self.sql.resultset_data_types, strict=True):
425
- if not check_data_type(attr.type.name, data_type):
437
+ if not check_data_type(attr.type.schema, attr.type.name, data_type):
426
438
  raise TypeError(f"expected: {data_type} in column `{attr.name}`; got: `{attr.type.kind}` of `{attr.type.name}`")
427
439
 
428
440
  return stmt
@@ -434,40 +446,51 @@ class _SQLImpl(_SQL):
434
446
  stmt = await self._prepare(connection)
435
447
  await stmt.executemany(args)
436
448
 
449
+ def _cast_fetch(self, rows: list[asyncpg.Record]) -> list[tuple[Any, ...]]:
450
+ cast = self.sql.cast
451
+ if cast:
452
+ data_types = self.sql.resultset_data_types
453
+ resultset = [tuple((data_types[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row))) for row in rows]
454
+ else:
455
+ resultset = [tuple(value for value in row) for row in rows]
456
+ self.sql.check_rows(resultset)
457
+ return resultset
458
+
437
459
  async def fetch(self, connection: asyncpg.Connection, *args: Any) -> list[tuple[Any, ...]]:
438
460
  stmt = await self._prepare(connection)
439
461
  rows = await stmt.fetch(*args)
440
- resultset = [tuple(value for value in row) for row in rows]
441
- self.sql.check_rows(resultset)
442
- return resultset
462
+ return self._cast_fetch(rows)
443
463
 
444
464
  async def fetchmany(self, connection: asyncpg.Connection, args: Iterable[Sequence[Any]]) -> list[tuple[Any, ...]]:
445
465
  stmt = await self._prepare(connection)
446
- rows = await stmt.fetchmany(args) # type: ignore[arg-type, call-arg] # pyright: ignore[reportCallIssue]
447
- rows = typing.cast(list[asyncpg.Record], rows)
448
- resultset = [tuple(value for value in row) for row in rows]
449
- self.sql.check_rows(resultset)
450
- return resultset
466
+ rows = await stmt.fetchmany(args)
467
+ return self._cast_fetch(rows)
451
468
 
452
469
  async def fetchrow(self, connection: asyncpg.Connection, *args: Any) -> tuple[Any, ...] | None:
453
470
  stmt = await self._prepare(connection)
454
471
  row = await stmt.fetchrow(*args)
455
472
  if row is None:
456
473
  return None
457
- resultset = tuple(value for value in row)
474
+ cast = self.sql.cast
475
+ if cast:
476
+ data_types = self.sql.resultset_data_types
477
+ resultset = tuple((data_types[i](value) if (value := row[i]) is not None and cast >> i & 1 else value) for i in range(len(row)))
478
+ else:
479
+ resultset = tuple(value for value in row)
458
480
  self.sql.check_row(resultset)
459
481
  return resultset
460
482
 
461
483
  async def fetchval(self, connection: asyncpg.Connection, *args: Any) -> Any:
462
484
  stmt = await self._prepare(connection)
463
485
  value = await stmt.fetchval(*args)
464
- self.sql.check_value(value)
465
- return value
486
+ result = self.sql.resultset_data_types[0](value) if value is not None and self.sql.cast else value
487
+ self.sql.check_value(result)
488
+ return result
466
489
 
467
490
 
468
491
  ### START OF AUTO-GENERATED BLOCK ###
469
492
 
470
- PS = TypeVar("PS", bool, bool | None, int, int | None, float, float | None, Decimal, Decimal | None, date, date | None, time, time | None, datetime, datetime | None, str, str | None, bytes, bytes | None, UUID, UUID | None)
493
+ PS = TypeVar("PS")
471
494
  P1 = TypeVar("P1")
472
495
  P2 = TypeVar("P2")
473
496
  P3 = TypeVar("P3")
@@ -476,7 +499,7 @@ P5 = TypeVar("P5")
476
499
  P6 = TypeVar("P6")
477
500
  P7 = TypeVar("P7")
478
501
  P8 = TypeVar("P8")
479
- RS = TypeVar("RS", bool, bool | None, int, int | None, float, float | None, Decimal, Decimal | None, date, date | None, time, time | None, datetime, datetime | None, str, str | None, bytes, bytes | None, UUID, UUID | None)
502
+ RS = TypeVar("RS")
480
503
  R1 = TypeVar("R1")
481
504
  R2 = TypeVar("R2")
482
505
  RX = TypeVarTuple("RX")
@@ -722,27 +745,27 @@ class SQL_P8_RX(Generic[P1, P2, P3, P4, P5, P6, P7, P8, R1, R2, Unpack[RX]], SQL
722
745
  @overload
723
746
  def sql(stmt: SQLExpression) -> SQL_P0: ...
724
747
  @overload
725
- def sql(stmt: SQLExpression, *, resultset: type[RS]) -> SQL_P0_RS[RS]: ...
748
+ def sql(stmt: SQLExpression, *, result: type[RS]) -> SQL_P0_RS[RS]: ...
726
749
  @overload
727
750
  def sql(stmt: SQLExpression, *, resultset: type[tuple[R1]]) -> SQL_P0_RS[R1]: ...
728
751
  @overload
729
752
  def sql(stmt: SQLExpression, *, resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P0_RX[R1, R2, Unpack[RX]]: ...
730
753
  @overload
731
- def sql(stmt: SQLExpression, *, args: type[PS]) -> SQL_P1[PS]: ...
754
+ def sql(stmt: SQLExpression, *, arg: type[PS]) -> SQL_P1[PS]: ...
732
755
  @overload
733
756
  def sql(stmt: SQLExpression, *, args: type[tuple[P1]]) -> SQL_P1[P1]: ...
734
757
  @overload
735
- def sql(stmt: SQLExpression, *, args: type[PS], resultset: type[RS]) -> SQL_P1_RS[PS, RS]: ...
758
+ def sql(stmt: SQLExpression, *, arg: type[PS], result: type[RS]) -> SQL_P1_RS[PS, RS]: ...
736
759
  @overload
737
760
  def sql(stmt: SQLExpression, *, args: type[tuple[P1]], resultset: type[tuple[R1]]) -> SQL_P1_RS[P1, R1]: ...
738
761
  @overload
739
- def sql(stmt: SQLExpression, *, args: type[PS], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[PS, R1, R2, Unpack[RX]]: ...
762
+ def sql(stmt: SQLExpression, *, arg: type[PS], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[PS, R1, R2, Unpack[RX]]: ...
740
763
  @overload
741
764
  def sql(stmt: SQLExpression, *, args: type[tuple[P1]], resultset: type[tuple[R1, R2, Unpack[RX]]]) -> SQL_P1_RX[P1, R1, R2, Unpack[RX]]: ...
742
765
  @overload
743
766
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]]) -> SQL_P2[P1, P2]: ...
744
767
  @overload
745
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[RS]) -> SQL_P2_RS[P1, P2, RS]: ...
768
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], result: type[RS]) -> SQL_P2_RS[P1, P2, RS]: ...
746
769
  @overload
747
770
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[tuple[R1]]) -> SQL_P2_RS[P1, P2, R1]: ...
748
771
  @overload
@@ -750,7 +773,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2]], resultset: type[tuple
750
773
  @overload
751
774
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]]) -> SQL_P3[P1, P2, P3]: ...
752
775
  @overload
753
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[RS]) -> SQL_P3_RS[P1, P2, P3, RS]: ...
776
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], result: type[RS]) -> SQL_P3_RS[P1, P2, P3, RS]: ...
754
777
  @overload
755
778
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[tuple[R1]]) -> SQL_P3_RS[P1, P2, P3, R1]: ...
756
779
  @overload
@@ -758,7 +781,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3]], resultset: type[t
758
781
  @overload
759
782
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]]) -> SQL_P4[P1, P2, P3, P4]: ...
760
783
  @overload
761
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: type[RS]) -> SQL_P4_RS[P1, P2, P3, P4, RS]: ...
784
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], result: type[RS]) -> SQL_P4_RS[P1, P2, P3, P4, RS]: ...
762
785
  @overload
763
786
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: type[tuple[R1]]) -> SQL_P4_RS[P1, P2, P3, P4, R1]: ...
764
787
  @overload
@@ -766,7 +789,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4]], resultset: ty
766
789
  @overload
767
790
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]]) -> SQL_P5[P1, P2, P3, P4, P5]: ...
768
791
  @overload
769
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset: type[RS]) -> SQL_P5_RS[P1, P2, P3, P4, P5, RS]: ...
792
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], result: type[RS]) -> SQL_P5_RS[P1, P2, P3, P4, P5, RS]: ...
770
793
  @overload
771
794
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset: type[tuple[R1]]) -> SQL_P5_RS[P1, P2, P3, P4, P5, R1]: ...
772
795
  @overload
@@ -774,7 +797,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5]], resultset
774
797
  @overload
775
798
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]]) -> SQL_P6[P1, P2, P3, P4, P5, P6]: ...
776
799
  @overload
777
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resultset: type[RS]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, RS]: ...
800
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], result: type[RS]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, RS]: ...
778
801
  @overload
779
802
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resultset: type[tuple[R1]]) -> SQL_P6_RS[P1, P2, P3, P4, P5, P6, R1]: ...
780
803
  @overload
@@ -782,7 +805,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6]], resul
782
805
  @overload
783
806
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]]) -> SQL_P7[P1, P2, P3, P4, P5, P6, P7]: ...
784
807
  @overload
785
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], resultset: type[RS]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, RS]: ...
808
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], result: type[RS]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, RS]: ...
786
809
  @overload
787
810
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], resultset: type[tuple[R1]]) -> SQL_P7_RS[P1, P2, P3, P4, P5, P6, P7, R1]: ...
788
811
  @overload
@@ -790,7 +813,7 @@ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7]], r
790
813
  @overload
791
814
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]]) -> SQL_P8[P1, P2, P3, P4, P5, P6, P7, P8]: ...
792
815
  @overload
793
- def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], resultset: type[RS]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, RS]: ...
816
+ def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], result: type[RS]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, RS]: ...
794
817
  @overload
795
818
  def sql(stmt: SQLExpression, *, args: type[tuple[P1, P2, P3, P4, P5, P6, P7, P8]], resultset: type[tuple[R1]]) -> SQL_P8_RS[P1, P2, P3, P4, P5, P6, P7, P8, R1]: ...
796
819
  @overload
@@ -805,23 +828,50 @@ def sql(
805
828
  *,
806
829
  args: type[Any] | None = None,
807
830
  resultset: type[Any] | None = None,
831
+ arg: type[Any] | None = None,
832
+ result: type[Any] | None = None,
808
833
  ) -> _SQL:
809
834
  """
810
835
  Creates a SQL statement with associated type information.
811
836
 
812
- :param stmt: SQL statement as a string or template.
813
- :param args: Type signature for input parameters. Use the type for a single parameter (e.g. `int`) or `tuple[...]` for multiple parameters.
814
- :param resultset: Type signature for output data. Use the type for a single parameter (e.g. `int`) or `tuple[...]` for multiple parameters.
837
+ :param stmt: SQL statement as a literal string or template.
838
+ :param args: Type signature for multiple input parameters (e.g. `tuple[bool, int, str]`).
839
+ :param resultset: Type signature for multiple resultset columns (e.g. `tuple[datetime, Decimal, str]`).
840
+ :param arg: Type signature for a single input parameter (e.g. `int`).
841
+ :param result: Type signature for a single result column (e.g. `UUID`).
815
842
  """
816
843
 
844
+ if args is not None and arg is not None:
845
+ raise TypeError("expected: either `args` or `arg`; got: both")
846
+ if resultset is not None and result is not None:
847
+ raise TypeError("expected: either `resultset` or `result`; got: both")
848
+
849
+ if args is not None:
850
+ if get_origin(args) is not tuple:
851
+ raise TypeError(f"expected: `type[tuple[T, ...]]` for `args`; got: {type(args)}")
852
+ input_data_types = get_args(args)
853
+ elif arg is not None:
854
+ input_data_types = (arg,)
855
+ else:
856
+ input_data_types = ()
857
+
858
+ if resultset is not None:
859
+ if get_origin(resultset) is not tuple:
860
+ raise TypeError(f"expected: `type[tuple[T, ...]]` for `resultset`; got: {type(resultset)}")
861
+ output_data_types = get_args(resultset)
862
+ elif result is not None:
863
+ output_data_types = (result,)
864
+ else:
865
+ output_data_types = ()
866
+
817
867
  if sys.version_info >= (3, 14):
818
868
  obj: _SQLObject
819
869
  match stmt:
820
870
  case Template():
821
- obj = _SQLTemplate(stmt, args=args, resultset=resultset)
871
+ obj = _SQLTemplate(stmt, args=input_data_types, resultset=output_data_types)
822
872
  case str():
823
- obj = _SQLString(stmt, args=args, resultset=resultset)
873
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
824
874
  else:
825
- obj = _SQLString(stmt, args=args, resultset=resultset)
875
+ obj = _SQLString(stmt, args=input_data_types, resultset=output_data_types)
826
876
 
827
877
  return _SQLImpl(obj)
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: asyncpg_typed
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: Type-safe queries for asyncpg
5
5
  Author-email: Levente Hunyadi <hunyadi@gmail.com>
6
6
  Maintainer-email: Levente Hunyadi <hunyadi@gmail.com>
7
7
  License-Expression: MIT
8
8
  Project-URL: Homepage, https://github.com/hunyadi/asyncpg_typed
9
9
  Project-URL: Source, https://github.com/hunyadi/asyncpg_typed
10
- Keywords: asyncpg,typed,database-client
10
+ Keywords: asyncpg,typed,database-client,postgres
11
11
  Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Intended Audience :: Developers
13
13
  Classifier: Operating System :: OS Independent
@@ -85,19 +85,21 @@ Instantiate a SQL object with the `sql` function:
85
85
 
86
86
  ```python
87
87
  def sql(
88
- stmt: str | string.templatelib.Template,
88
+ stmt: LiteralString | string.templatelib.Template,
89
89
  *,
90
- args: None | type[P1] | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
91
- resultset: None | type[R1] | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None
90
+ args: None | type[tuple[P1, P2]] | type[tuple[P1, P2, P3]] | ... = None,
91
+ resultset: None | type[tuple[R1, R2]] | type[tuple[R1, R2, R3]] | ... = None,
92
+ arg: None | type[P] = None,
93
+ result: None | type[R] = None,
92
94
  ) -> _SQL: ...
93
95
  ```
94
96
 
95
- The parameter `stmt` represents a SQL expression, either as a string (including an *f-string*) or a template (i.e. a *t-string*).
97
+ The parameter `stmt` represents a SQL expression, either as a literal string or a template (i.e. a *t-string*).
96
98
 
97
99
  If the expression is a string, it can have PostgreSQL parameter placeholders such as `$1`, `$2` or `$3`:
98
100
 
99
101
  ```python
100
- f"INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
102
+ "INSERT INTO table_name (col_1, col_2, col_3) VALUES ($1, $2, $3);"
101
103
  ```
102
104
 
103
105
  If the expression is a *t-string*, it can have replacement fields that evaluate to integers:
@@ -106,11 +108,10 @@ If the expression is a *t-string*, it can have replacement fields that evaluate
106
108
  t"INSERT INTO table_name (col_1, col_2, col_3) VALUES ({1}, {2}, {3});"
107
109
  ```
108
110
 
109
- The parameters `args` and `resultset` take a series type `P` or `R`, which may be any of the following:
111
+ The parameters `args` and `resultset` take a `tuple` of several types `Px` or `Rx`, each of which may be any of the following:
110
112
 
111
113
  * (required) simple type
112
114
  * optional simple type (`T | None`)
113
- * `tuple` of several (required or optional) simple types.
114
115
 
115
116
  Simple types include:
116
117
 
@@ -124,6 +125,7 @@ Simple types include:
124
125
  * `str`
125
126
  * `bytes`
126
127
  * `uuid.UUID`
128
+ * a user-defined class that derives from `StrEnum`
127
129
 
128
130
  Types are grouped together with `tuple`:
129
131
 
@@ -131,7 +133,7 @@ Types are grouped together with `tuple`:
131
133
  tuple[bool, int, str | None]
132
134
  ```
133
135
 
134
- Passing a simple type directly (e.g. `type[T]`) is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`).
136
+ The parameters `arg` and `result` take a single type `P` or `R`. Passing a simple type (e.g. `type[T]`) directly via `arg` and `result` is for convenience, and is equivalent to passing a one-element tuple of the same simple type (i.e. `type[tuple[T]]`) via `args` and `resultset`.
135
137
 
136
138
  The number of types in `args` must correspond to the number of query parameters. (This is validated on calling `sql(...)` for the *t-string* syntax.) The number of types in `resultset` must correspond to the number of columns returned by the query.
137
139
 
@@ -159,6 +161,7 @@ Both `args` and `resultset` types must be compatible with their corresponding Po
159
161
  | `json` | `str` |
160
162
  | `jsonb` | `str` |
161
163
  | `uuid` | `UUID` |
164
+ | enumeration | `E: StrEnum` |
162
165
 
163
166
  ### Using a SQL object
164
167
 
@@ -1,4 +1,5 @@
1
1
  LICENSE
2
+ MANIFEST.in
2
3
  README.md
3
4
  pyproject.toml
4
5
  asyncpg_typed/__init__.py
@@ -9,8 +10,10 @@ asyncpg_typed.egg-info/dependency_links.txt
9
10
  asyncpg_typed.egg-info/requires.txt
10
11
  asyncpg_typed.egg-info/top_level.txt
11
12
  asyncpg_typed.egg-info/zip-safe
12
- test/test_code.py
13
- test/test_data.py
14
- test/test_template.py
15
- test/test_type.py
16
- test/test_vector.py
13
+ tests/__init__.py
14
+ tests/connection.py
15
+ tests/test_code.py
16
+ tests/test_data.py
17
+ tests/test_template.py
18
+ tests/test_type.py
19
+ tests/test_vector.py
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
  name = "asyncpg_typed"
7
7
  description = "Type-safe queries for asyncpg"
8
8
  readme = { file = "README.md", content-type = "text/markdown" }
9
- keywords = ["asyncpg", "typed", "database-client"]
9
+ keywords = ["asyncpg", "typed", "database-client", "postgres"]
10
10
  license = "MIT"
11
11
  authors = [
12
12
  { name = "Levente Hunyadi", email = "hunyadi@gmail.com" }
File without changes
@@ -0,0 +1,19 @@
1
+ """
2
+ Type-safe queries for asyncpg.
3
+
4
+ :see: https://github.com/hunyadi/asyncpg_typed
5
+ """
6
+
7
+ from collections.abc import AsyncIterator
8
+ from contextlib import asynccontextmanager
9
+
10
+ import asyncpg
11
+
12
+
13
+ @asynccontextmanager
14
+ async def get_connection() -> AsyncIterator[asyncpg.Connection]:
15
+ conn = await asyncpg.connect(host="localhost", port=5432, user="postgres", password="postgres")
16
+ try:
17
+ yield conn
18
+ finally:
19
+ await conn.close()
@@ -10,7 +10,7 @@ from io import StringIO
10
10
  from pathlib import Path
11
11
  from typing import TextIO
12
12
 
13
- from asyncpg_typed import DATA_TYPES, NUM_ARGS
13
+ from asyncpg_typed import NUM_ARGS
14
14
 
15
15
 
16
16
  def _args_and_results(p: int, r: int, s: bool = False) -> str:
@@ -83,11 +83,11 @@ def _param_spec(p: int, r: int, s: bool = False) -> str:
83
83
  if (s and p > 1) or (not s and p > 0):
84
84
  params.append(f"args: type[tuple{_args(p)}]")
85
85
  elif s and p == 1:
86
- params.append("args: type[PS]")
86
+ params.append("arg: type[PS]")
87
87
  if (s and r > 1) or (not s and r > 0):
88
88
  params.append(f"resultset: type[tuple{_results(r)}]")
89
89
  elif s and r == 1:
90
- params.append("resultset: type[RS]")
90
+ params.append("result: type[RS]")
91
91
  if len(params) > 1:
92
92
  return f", {', '.join(params)}" if params else ""
93
93
  else:
@@ -109,13 +109,12 @@ def _write_function(out: TextIO, p: int, r: int, s: bool) -> None:
109
109
 
110
110
 
111
111
  def write_code(out: TextIO) -> None:
112
- data_types_list = ", ".join(f"{data_type.__name__}, {data_type.__name__} | None" for data_type in DATA_TYPES)
113
- print(f'PS = TypeVar("PS", {data_types_list})', file=out)
112
+ print('PS = TypeVar("PS")', file=out)
114
113
 
115
114
  for p in range(1, NUM_ARGS + 1):
116
115
  print(f'P{p} = TypeVar("P{p}")', file=out)
117
116
 
118
- print(f'RS = TypeVar("RS", {data_types_list})', file=out)
117
+ print('RS = TypeVar("RS")', file=out)
119
118
  print('R1 = TypeVar("R1")', file=out)
120
119
  print('R2 = TypeVar("R2")', file=out)
121
120
  print('RX = TypeVarTuple("RX")', file=out)
@@ -4,9 +4,9 @@ Type-safe queries for asyncpg.
4
4
  :see: https://github.com/hunyadi/asyncpg_typed
5
5
  """
6
6
 
7
+ import enum
8
+ import sys
7
9
  import unittest
8
- from collections.abc import AsyncIterator
9
- from contextlib import asynccontextmanager
10
10
  from datetime import date, datetime, time, timedelta, timezone
11
11
  from decimal import Decimal
12
12
  from random import randint, sample
@@ -14,29 +14,31 @@ from types import UnionType
14
14
  from typing import Any
15
15
  from uuid import UUID, uuid4
16
16
 
17
- import asyncpg
18
-
19
17
  from asyncpg_typed import sql
18
+ from tests.connection import get_connection
20
19
 
21
20
 
22
21
  class RollbackException(RuntimeError):
23
22
  pass
24
23
 
25
24
 
26
- @asynccontextmanager
27
- async def get_connection() -> AsyncIterator[asyncpg.Connection]:
28
- conn = await asyncpg.connect(host="localhost", port=5432, user="postgres", password="postgres")
29
- try:
30
- yield conn
31
- finally:
32
- await conn.close()
25
+ if sys.version_info < (3, 11):
26
+
27
+ class State(str, enum.Enum):
28
+ ACTIVE = "active"
29
+ INACTIVE = "inactive"
30
+ else:
31
+
32
+ class State(enum.StrEnum):
33
+ ACTIVE = "active"
34
+ INACTIVE = "inactive"
33
35
 
34
36
 
35
37
  class TestDataTypes(unittest.IsolatedAsyncioTestCase):
36
38
  async def test_numeric_types(self) -> None:
37
39
  create_sql = sql(
38
40
  """
39
- ---sql
41
+ --sql
40
42
  CREATE TEMPORARY TABLE numeric_types(
41
43
  id bigint GENERATED ALWAYS AS IDENTITY,
42
44
  boolean_value boolean NOT NULL,
@@ -80,7 +82,7 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
80
82
  async def test_datetime_types(self) -> None:
81
83
  create_sql = sql(
82
84
  """
83
- ---sql
85
+ --sql
84
86
  CREATE TEMPORARY TABLE datetime_types(
85
87
  id bigint GENERATED ALWAYS AS IDENTITY,
86
88
  date_value date NOT NULL,
@@ -127,7 +129,7 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
127
129
  async def test_sequence_types(self) -> None:
128
130
  create_sql = sql(
129
131
  """
130
- ---sql
132
+ --sql
131
133
  CREATE TEMPORARY TABLE sequence_types(
132
134
  id bigint GENERATED ALWAYS AS IDENTITY,
133
135
  bytes_value bytea NOT NULL,
@@ -167,7 +169,7 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
167
169
  async def test_composite_type(self) -> None:
168
170
  create_sql = sql(
169
171
  """
170
- ---sql
172
+ --sql
171
173
  CREATE TEMPORARY TABLE composite_types(
172
174
  id bigint GENERATED ALWAYS AS IDENTITY,
173
175
  uuid_value uuid NOT NULL,
@@ -204,10 +206,79 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
204
206
  await insert_sql.executemany(conn, [record1, record2])
205
207
  self.assertEqual(await select_sql.fetch(conn), [record1, record2])
206
208
 
209
+ async def test_enum_type(self) -> None:
210
+ create_sql = sql(
211
+ """
212
+ --sql
213
+ DO $$ BEGIN
214
+ CREATE TYPE state AS ENUM ('active', 'inactive');
215
+ EXCEPTION
216
+ WHEN duplicate_object THEN null;
217
+ END $$;
218
+
219
+ --sql
220
+ CREATE TEMPORARY TABLE enum_types(
221
+ id bigint GENERATED ALWAYS AS IDENTITY,
222
+ enum_value state NOT NULL,
223
+ CONSTRAINT pk_sample_data PRIMARY KEY (id)
224
+ );
225
+ """
226
+ )
227
+
228
+ insert_sql = sql(
229
+ """
230
+ --sql
231
+ INSERT INTO enum_types (enum_value)
232
+ VALUES ($1);
233
+ """,
234
+ arg=State,
235
+ )
236
+
237
+ select_sql = sql(
238
+ """
239
+ --sql
240
+ SELECT enum_value, enum_value
241
+ FROM enum_types
242
+ ORDER BY id;
243
+ """,
244
+ resultset=tuple[State, State | None],
245
+ )
246
+
247
+ value_sql = sql(
248
+ """
249
+ --sql
250
+ SELECT enum_value
251
+ FROM enum_types
252
+ ORDER BY id;
253
+ """,
254
+ result=State,
255
+ )
256
+
257
+ async with get_connection() as conn:
258
+ await create_sql.execute(conn)
259
+ await insert_sql.executemany(conn, [(State.ACTIVE,), (State.INACTIVE,)])
260
+
261
+ rows = await select_sql.fetch(conn)
262
+ for row in rows:
263
+ for column in row:
264
+ self.assertIsInstance(column, State)
265
+ self.assertEqual(rows, [(State.ACTIVE, State.ACTIVE), (State.INACTIVE, State.INACTIVE)])
266
+
267
+ record = await select_sql.fetchrow(conn)
268
+ self.assertIsNotNone(record)
269
+ if record:
270
+ for column in record:
271
+ self.assertIsInstance(column, State)
272
+ self.assertEqual(record, (State.ACTIVE, State.ACTIVE))
273
+
274
+ value = await value_sql.fetchval(conn)
275
+ self.assertIsInstance(value, State)
276
+ self.assertEqual(value, State.ACTIVE)
277
+
207
278
  async def test_sql(self) -> None:
208
279
  create_sql = sql(
209
280
  """
210
- ---sql
281
+ --sql
211
282
  CREATE TEMPORARY TABLE sample_data(
212
283
  id bigint GENERATED ALWAYS AS IDENTITY,
213
284
  boolean_value bool NOT NULL,
@@ -258,7 +329,7 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
258
329
  RETURNING id;
259
330
  """,
260
331
  args=tuple[bool, int, str | None],
261
- resultset=int,
332
+ result=int,
262
333
  )
263
334
 
264
335
  count_sql = sql(
@@ -266,7 +337,7 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
266
337
  --sql
267
338
  SELECT COUNT(*) FROM sample_data;
268
339
  """,
269
- resultset=int,
340
+ result=int,
270
341
  )
271
342
 
272
343
  count_where_sql = sql(
@@ -274,8 +345,8 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
274
345
  --sql
275
346
  SELECT COUNT(*) FROM sample_data WHERE integer_value > $1;
276
347
  """,
277
- args=int,
278
- resultset=int,
348
+ arg=int,
349
+ result=int,
279
350
  )
280
351
 
281
352
  async with get_connection() as conn:
@@ -341,12 +412,12 @@ class TestDataTypes(unittest.IsolatedAsyncioTestCase):
341
412
  SELECT
342
413
  {nullif(0, index)}, {nullif(1, index)}, {nullif(2, index)}, {nullif(3, index)},
343
414
  {nullif(4, index)}, {nullif(5, index)}, {nullif(6, index)}, {nullif(7, index)};
344
- """,
415
+ """, # pyright: ignore[reportArgumentType]
345
416
  args=tuple[int, int, int, int, int, int, int, int],
346
417
  resultset=tuple[tuple(params)], # type: ignore[misc]
347
418
  ) # type: ignore[call-overload]
348
419
 
349
- rows = await passthrough_sql.fetch(conn, *args)
420
+ rows = await passthrough_sql.fetch(conn, *args) # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
350
421
  resultset: list[int | None] = [i for i in args]
351
422
  resultset[index] = None
352
423
  self.assertEqual(rows, [tuple(resultset)])
@@ -8,31 +8,20 @@ Type-safe queries for asyncpg.
8
8
 
9
9
  import sys
10
10
  import unittest
11
- from contextlib import asynccontextmanager
12
-
13
- import asyncpg
14
11
 
15
12
  from asyncpg_typed import sql
13
+ from tests.connection import get_connection
16
14
 
17
15
  if sys.version_info >= (3, 14):
18
16
  from string.templatelib import Interpolation, Template
19
17
 
20
18
 
21
- @asynccontextmanager
22
- async def get_connection():
23
- conn = await asyncpg.connect(host="localhost", port=5432, user="postgres", password="postgres")
24
- try:
25
- yield conn
26
- finally:
27
- await conn.close()
28
-
29
-
30
19
  @unittest.skipUnless(sys.version_info >= (3, 14), "requires Python 3.14 or later")
31
20
  class TestTemplate(unittest.IsolatedAsyncioTestCase):
32
21
  async def test_sql(self) -> None:
33
22
  create_sql = sql(
34
23
  """
35
- ---sql
24
+ --sql
36
25
  CREATE TEMPORARY TABLE sample_data(
37
26
  id bigint GENERATED ALWAYS AS IDENTITY,
38
27
  boolean_value bool NOT NULL,
@@ -5,33 +5,23 @@ Type-safe queries for asyncpg.
5
5
  """
6
6
 
7
7
  import unittest
8
- from contextlib import asynccontextmanager
9
8
  from random import random
10
9
 
11
- import asyncpg
12
10
  from asyncpg_vector import HalfVector, Vector, register_vector
13
11
 
14
12
  from asyncpg_typed import sql
13
+ from tests.connection import get_connection
15
14
 
16
15
 
17
16
  class RollbackException(RuntimeError):
18
17
  pass
19
18
 
20
19
 
21
- @asynccontextmanager
22
- async def get_connection():
23
- conn = await asyncpg.connect(host="localhost", port=5432, user="postgres", password="postgres")
24
- try:
25
- yield conn
26
- finally:
27
- await conn.close()
28
-
29
-
30
20
  class TestConnection(unittest.IsolatedAsyncioTestCase):
31
21
  async def test_vector_type(self) -> None:
32
22
  create_sql = sql(
33
23
  """
34
- ---sql
24
+ --sql
35
25
  CREATE EXTENSION IF NOT EXISTS vector;
36
26
 
37
27
  --sql
File without changes
File without changes