iron-sql 0.5.2__tar.gz → 0.5.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {iron_sql-0.5.2 → iron_sql-0.5.4}/PKG-INFO +2 -2
- {iron_sql-0.5.2 → iron_sql-0.5.4}/pyproject.toml +2 -10
- {iron_sql-0.5.2 → iron_sql-0.5.4}/src/iron_sql/codegen/generator.py +106 -48
- {iron_sql-0.5.2 → iron_sql-0.5.4}/src/iron_sql/runtime.py +170 -46
- {iron_sql-0.5.2 → iron_sql-0.5.4}/LICENSE +0 -0
- {iron_sql-0.5.2 → iron_sql-0.5.4}/README.md +0 -0
- {iron_sql-0.5.2 → iron_sql-0.5.4}/src/iron_sql/__init__.py +0 -0
- {iron_sql-0.5.2 → iron_sql-0.5.4}/src/iron_sql/codegen/__init__.py +0 -0
- {iron_sql-0.5.2 → iron_sql-0.5.4}/src/iron_sql/codegen/sqlc.py +0 -0
- {iron_sql-0.5.2 → iron_sql-0.5.4}/src/iron_sql/codegen/util.py +0 -0
- {iron_sql-0.5.2 → iron_sql-0.5.4}/src/iron_sql/py.typed +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: iron-sql
|
|
3
|
-
Version: 0.5.2
|
|
3
|
+
Version: 0.5.4
|
|
4
4
|
Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
|
|
5
5
|
Keywords: postgresql,sql,sqlc,psycopg,codegen,async
|
|
6
6
|
Author: Ilia Ablamonov
|
|
@@ -16,7 +16,7 @@ Requires-Dist: psycopg>=3.3.2
|
|
|
16
16
|
Requires-Dist: psycopg-pool>=3.3.0
|
|
17
17
|
Requires-Dist: pydantic>=2.12.4
|
|
18
18
|
Requires-Dist: inflection>=0.5.1 ; extra == 'codegen'
|
|
19
|
-
Requires-Dist: sqlc>=1.30.0.
|
|
19
|
+
Requires-Dist: sqlc>=1.30.0.post19 ; extra == 'codegen'
|
|
20
20
|
Requires-Python: >=3.13
|
|
21
21
|
Project-URL: Homepage, https://github.com/Flamefork/iron_sql
|
|
22
22
|
Project-URL: Repository, https://github.com/Flamefork/iron_sql.git
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "iron-sql"
|
|
3
|
-
version = "0.5.2"
|
|
3
|
+
version = "0.5.4"
|
|
4
4
|
|
|
5
5
|
description = "iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries"
|
|
6
6
|
readme = "README.md"
|
|
@@ -26,7 +26,7 @@ dependencies = [
|
|
|
26
26
|
[project.optional-dependencies]
|
|
27
27
|
codegen = [
|
|
28
28
|
"inflection>=0.5.1",
|
|
29
|
-
"sqlc>=1.30.0.
|
|
29
|
+
"sqlc>=1.30.0.post19",
|
|
30
30
|
]
|
|
31
31
|
|
|
32
32
|
[project.urls]
|
|
@@ -53,14 +53,6 @@ dev = [
|
|
|
53
53
|
|
|
54
54
|
[tool.pyright]
|
|
55
55
|
typeCheckingMode = "strict"
|
|
56
|
-
reportUnknownArgumentType = "none"
|
|
57
|
-
reportUnknownLambdaType = "none"
|
|
58
|
-
reportUnknownMemberType = "none"
|
|
59
|
-
reportUnknownParameterType = "none"
|
|
60
|
-
reportUnknownVariableType = "none"
|
|
61
|
-
reportMissingParameterType = "none"
|
|
62
|
-
reportMissingTypeArgument = "none"
|
|
63
|
-
reportMissingTypeStubs = "none"
|
|
64
56
|
deprecateTypingAliases = true
|
|
65
57
|
reportImportCycles = true
|
|
66
58
|
reportUnnecessaryTypeIgnoreComment = true
|
|
@@ -31,9 +31,18 @@ class ColumnSpec:
|
|
|
31
31
|
name: str
|
|
32
32
|
table: str
|
|
33
33
|
py_type: str
|
|
34
|
+
element_py_type: str | None = None
|
|
34
35
|
json_type: str | None = None
|
|
35
36
|
|
|
36
37
|
|
|
38
|
+
_JSON_PARAM_DUMPERS = {
|
|
39
|
+
"json": "runtime.dump_json_value",
|
|
40
|
+
"jsonb": "runtime.dump_json_value",
|
|
41
|
+
"text": "runtime.dump_json_text",
|
|
42
|
+
"varchar": "runtime.dump_json_text",
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
|
|
37
46
|
@dataclass(kw_only=True, frozen=True)
|
|
38
47
|
class ParamSpec:
|
|
39
48
|
name: str
|
|
@@ -51,19 +60,23 @@ class ParamSpec:
|
|
|
51
60
|
|
|
52
61
|
@property
|
|
53
62
|
def serialized_expr(self) -> str:
|
|
63
|
+
expr = self.name
|
|
64
|
+
wraps_value = False
|
|
54
65
|
if self.json_type:
|
|
55
|
-
|
|
56
|
-
|
|
66
|
+
dump_fn = _JSON_PARAM_DUMPERS[self.db_type]
|
|
67
|
+
expr = f"{dump_fn}({self.json_type}, {self.name})"
|
|
68
|
+
wraps_value = True
|
|
57
69
|
match self.db_type:
|
|
58
70
|
case "json":
|
|
59
|
-
expr = f"psycopg.types.json.Json({
|
|
71
|
+
expr = f"psycopg.types.json.Json({expr})"
|
|
72
|
+
wraps_value = True
|
|
60
73
|
case "jsonb":
|
|
61
|
-
expr = f"psycopg.types.json.Jsonb({
|
|
74
|
+
expr = f"psycopg.types.json.Jsonb({expr})"
|
|
75
|
+
wraps_value = True
|
|
62
76
|
case _:
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
return f"{expr} if {self.name} is not None else None"
|
|
77
|
+
pass
|
|
78
|
+
if wraps_value and not self.not_null:
|
|
79
|
+
expr = f"{expr} if {self.name} is not None else None"
|
|
67
80
|
return expr
|
|
68
81
|
|
|
69
82
|
|
|
@@ -159,16 +172,17 @@ class TypeResolver:
|
|
|
159
172
|
json_column_type_overrides: dict[tuple[str, str], str]
|
|
160
173
|
|
|
161
174
|
def column_spec(self, column: Column) -> ColumnSpec:
|
|
162
|
-
_, py_type, json_type = self._resolve(column)
|
|
175
|
+
_, py_type, element_py_type, json_type = self._resolve(column)
|
|
163
176
|
return ColumnSpec(
|
|
164
177
|
name=column.name,
|
|
165
178
|
table=column.table.name if column.table else "unknown",
|
|
166
179
|
py_type=py_type,
|
|
180
|
+
element_py_type=element_py_type,
|
|
167
181
|
json_type=json_type,
|
|
168
182
|
)
|
|
169
183
|
|
|
170
184
|
def param_spec(self, column: Column, name: str, *, is_named: bool) -> ParamSpec:
|
|
171
|
-
db_type, py_type, json_type = self._resolve(column)
|
|
185
|
+
db_type, py_type, _, json_type = self._resolve(column)
|
|
172
186
|
return ParamSpec(
|
|
173
187
|
name=name,
|
|
174
188
|
py_type=py_type,
|
|
@@ -179,7 +193,7 @@ class TypeResolver:
|
|
|
179
193
|
json_type=json_type,
|
|
180
194
|
)
|
|
181
195
|
|
|
182
|
-
def _resolve(self, column: Column) -> tuple[str, str, str | None]:
|
|
196
|
+
def _resolve(self, column: Column) -> tuple[str, str, str | None, str | None]:
|
|
183
197
|
db_type = column.type.name.removeprefix("pg_catalog.")
|
|
184
198
|
|
|
185
199
|
json_type = None
|
|
@@ -210,13 +224,15 @@ class TypeResolver:
|
|
|
210
224
|
)
|
|
211
225
|
py_type = "object"
|
|
212
226
|
|
|
227
|
+
element_py_type = None
|
|
213
228
|
if column.is_array:
|
|
229
|
+
element_py_type = py_type
|
|
214
230
|
py_type = f"Sequence[{py_type}]"
|
|
215
231
|
|
|
216
232
|
if not column.not_null:
|
|
217
233
|
py_type += " | None"
|
|
218
234
|
|
|
219
|
-
return db_type, py_type, json_type
|
|
235
|
+
return db_type, py_type, element_py_type, json_type
|
|
220
236
|
|
|
221
237
|
|
|
222
238
|
def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
|
|
@@ -316,11 +332,24 @@ def generate_sql_module( # noqa: PLR0913, PLR0914
|
|
|
316
332
|
|
|
317
333
|
used_enums = collect_used_enums(sqlc_res)
|
|
318
334
|
|
|
319
|
-
|
|
320
|
-
|
|
335
|
+
enum_specs = [
|
|
336
|
+
(schema, e)
|
|
321
337
|
for schema in sqlc_res.catalog.schemas
|
|
322
338
|
for e in schema.enums
|
|
323
339
|
if (schema.name, e.name) in used_enums
|
|
340
|
+
]
|
|
341
|
+
|
|
342
|
+
enums = sorted(
|
|
343
|
+
render_enum_class(e, module_name, to_pascal_fn, to_snake_fn)
|
|
344
|
+
for _, e in enum_specs
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
enum_registry = sorted(
|
|
348
|
+
(
|
|
349
|
+
f"{schema.name}.{e.name}",
|
|
350
|
+
enum_class_name(e.name, module_name, to_pascal_fn, to_snake_fn),
|
|
351
|
+
)
|
|
352
|
+
for schema, e in enum_specs
|
|
324
353
|
)
|
|
325
354
|
|
|
326
355
|
query_classes = render_query_classes(
|
|
@@ -341,6 +370,7 @@ def generate_sql_module( # noqa: PLR0913, PLR0914
|
|
|
341
370
|
sql_fn_name,
|
|
342
371
|
entities,
|
|
343
372
|
enums,
|
|
373
|
+
enum_registry,
|
|
344
374
|
query_classes,
|
|
345
375
|
query_overloads,
|
|
346
376
|
query_dict_entries,
|
|
@@ -390,12 +420,7 @@ def render_query_classes(
|
|
|
390
420
|
for p in q.params
|
|
391
421
|
],
|
|
392
422
|
query_result_types[q.name],
|
|
393
|
-
|
|
394
|
-
(
|
|
395
|
-
resolver.column_spec(q.columns[0]).json_type
|
|
396
|
-
if len(q.columns) == 1
|
|
397
|
-
else None
|
|
398
|
-
),
|
|
423
|
+
tuple(resolver.column_spec(column) for column in q.columns),
|
|
399
424
|
query_locations_by_name[q.name],
|
|
400
425
|
)
|
|
401
426
|
for q in sorted(sqlc_queries, key=lambda q: query_order[q.name])
|
|
@@ -408,7 +433,7 @@ def resolve_json_model_overrides(
|
|
|
408
433
|
if not overrides:
|
|
409
434
|
return "", {}
|
|
410
435
|
|
|
411
|
-
json_compatible_types =
|
|
436
|
+
json_compatible_types = set(_JSON_PARAM_DUMPERS)
|
|
412
437
|
col_types = {
|
|
413
438
|
(table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
|
|
414
439
|
for schema in catalog.schemas
|
|
@@ -465,6 +490,7 @@ def render_module( # noqa: PLR0913, PLR0917
|
|
|
465
490
|
sql_fn_name: str,
|
|
466
491
|
entities: list[str],
|
|
467
492
|
enums: list[str],
|
|
493
|
+
enum_registry: list[tuple[str, str]],
|
|
468
494
|
query_classes: list[str],
|
|
469
495
|
query_overloads: list[str],
|
|
470
496
|
query_dict_entries: list[str],
|
|
@@ -489,6 +515,20 @@ def render_module( # noqa: PLR0913, PLR0917
|
|
|
489
515
|
|
|
490
516
|
imports_block = "\n".join(imports)
|
|
491
517
|
|
|
518
|
+
pre_pool_blocks: list[str] = []
|
|
519
|
+
if enums:
|
|
520
|
+
pre_pool_blocks.append("\n\n\n".join(enums))
|
|
521
|
+
if enum_registry:
|
|
522
|
+
registry_entries = ",\n ".join(
|
|
523
|
+
f'("{pg_name}", {class_name})' for pg_name, class_name in enum_registry
|
|
524
|
+
)
|
|
525
|
+
registry_type = "list[tuple[str, type[StrEnum]]]"
|
|
526
|
+
pre_pool_blocks.append(
|
|
527
|
+
f"ENUM_TYPES: {registry_type} = [\n {registry_entries},\n]"
|
|
528
|
+
)
|
|
529
|
+
pool_args.append("enum_types=ENUM_TYPES")
|
|
530
|
+
pre_pool_section = "".join(f"{block}\n\n\n" for block in pre_pool_blocks)
|
|
531
|
+
|
|
492
532
|
pool_args_str = ",\n ".join(pool_args)
|
|
493
533
|
|
|
494
534
|
return f"""
|
|
@@ -525,7 +565,7 @@ from iron_sql import runtime
|
|
|
525
565
|
{imports_block}
|
|
526
566
|
|
|
527
567
|
|
|
528
|
-
{module_name.upper()}_POOL = runtime.ConnectionPool(
|
|
568
|
+
{pre_pool_section}{module_name.upper()}_POOL = runtime.ConnectionPool(
|
|
529
569
|
{pool_args_str},
|
|
530
570
|
)
|
|
531
571
|
|
|
@@ -561,9 +601,6 @@ async def {module_name}_notify(channel: str, payload: str = "") -> None:
|
|
|
561
601
|
await runtime.notify(conn, channel, payload)
|
|
562
602
|
|
|
563
603
|
|
|
564
|
-
{"\n\n\n".join(enums)}
|
|
565
|
-
|
|
566
|
-
|
|
567
604
|
{"\n\n\n".join(entities)}
|
|
568
605
|
|
|
569
606
|
|
|
@@ -593,14 +630,23 @@ def {sql_fn_name}(sql: str, row_type: str | None = None) -> Query[Any]:
|
|
|
593
630
|
""".strip() # noqa: E501
|
|
594
631
|
|
|
595
632
|
|
|
633
|
+
def enum_class_name(
|
|
634
|
+
enum_name: str,
|
|
635
|
+
module_name: str,
|
|
636
|
+
to_pascal_fn: Callable[[str], str],
|
|
637
|
+
to_snake_fn: Callable[[str], str],
|
|
638
|
+
) -> str:
|
|
639
|
+
return to_pascal_fn(f"{module_name}_{to_snake_fn(enum_name)}")
|
|
640
|
+
|
|
641
|
+
|
|
596
642
|
def render_enum_class(
|
|
597
643
|
enum: Enum,
|
|
598
644
|
module_name: str,
|
|
599
645
|
to_pascal_fn: Callable[[str], str],
|
|
600
646
|
to_snake_fn: Callable[[str], str],
|
|
601
647
|
) -> str:
|
|
602
|
-
class_name =
|
|
603
|
-
members = []
|
|
648
|
+
class_name = enum_class_name(enum.name, module_name, to_pascal_fn, to_snake_fn)
|
|
649
|
+
members: list[str] = []
|
|
604
650
|
seen_names: dict[str, int] = {}
|
|
605
651
|
|
|
606
652
|
for val in enum.vals:
|
|
@@ -642,7 +688,7 @@ class {name}:
|
|
|
642
688
|
|
|
643
689
|
|
|
644
690
|
def deduplicate_params(params: list[ParamSpec]) -> list[ParamSpec]:
|
|
645
|
-
seen = defaultdict(int)
|
|
691
|
+
seen: defaultdict[str, int] = defaultdict(int)
|
|
646
692
|
result: list[ParamSpec] = []
|
|
647
693
|
for param in params:
|
|
648
694
|
seen[param.name] += 1
|
|
@@ -658,8 +704,7 @@ def render_query_class(
|
|
|
658
704
|
sql: str,
|
|
659
705
|
query_params: list[ParamSpec],
|
|
660
706
|
result: str,
|
|
661
|
-
|
|
662
|
-
scalar_json_type: str | None,
|
|
707
|
+
result_columns: tuple[ColumnSpec, ...],
|
|
663
708
|
locations: list[str],
|
|
664
709
|
) -> str:
|
|
665
710
|
query_params = deduplicate_params(query_params)
|
|
@@ -681,24 +726,9 @@ def render_query_class(
|
|
|
681
726
|
query_fn_params.insert(0, "self")
|
|
682
727
|
|
|
683
728
|
base_result = result.removesuffix(" | None")
|
|
729
|
+
row_factory = render_row_factory(result, result_columns)
|
|
684
730
|
|
|
685
|
-
if
|
|
686
|
-
row_factory = "psycopg.rows.scalar_row"
|
|
687
|
-
elif columns_num == 1:
|
|
688
|
-
not_null_str = "True" if not result.endswith(" | None") else "False"
|
|
689
|
-
validate_arg = (
|
|
690
|
-
f", validate=lambda _v: runtime.validate_json_field({scalar_json_type}, _v)"
|
|
691
|
-
if scalar_json_type
|
|
692
|
-
else ""
|
|
693
|
-
)
|
|
694
|
-
row_factory = (
|
|
695
|
-
f"runtime.typed_scalar_row"
|
|
696
|
-
f"({base_result}, not_null={not_null_str}{validate_arg})"
|
|
697
|
-
)
|
|
698
|
-
else:
|
|
699
|
-
row_factory = f"psycopg.rows.class_row({result})"
|
|
700
|
-
|
|
701
|
-
if columns_num > 0:
|
|
731
|
+
if result_columns:
|
|
702
732
|
methods = f"""
|
|
703
733
|
|
|
704
734
|
async def query_all_rows({", ".join(query_fn_params)}) -> list[{result}]:
|
|
@@ -738,6 +768,34 @@ class {query_name}(Query[{result}]):
|
|
|
738
768
|
""".strip()
|
|
739
769
|
|
|
740
770
|
|
|
771
|
+
def render_row_factory(result: str, columns: tuple[ColumnSpec, ...]) -> str:
|
|
772
|
+
match columns:
|
|
773
|
+
case ():
|
|
774
|
+
return "psycopg.rows.scalar_row"
|
|
775
|
+
case (column,):
|
|
776
|
+
return render_scalar_row_factory(result, column)
|
|
777
|
+
case _:
|
|
778
|
+
return f"psycopg.rows.class_row({result})"
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def render_scalar_row_factory(result: str, column: ColumnSpec) -> str:
|
|
782
|
+
base_result = result.removesuffix(" | None")
|
|
783
|
+
not_null = "True" if not result.endswith(" | None") else "False"
|
|
784
|
+
|
|
785
|
+
if column.element_py_type is not None:
|
|
786
|
+
return f"runtime.typed_array_row({column.element_py_type}, not_null={not_null})"
|
|
787
|
+
|
|
788
|
+
validate_arg = (
|
|
789
|
+
f", validate=lambda _v: runtime.validate_json_field({column.json_type}, _v)"
|
|
790
|
+
if column.json_type
|
|
791
|
+
else ""
|
|
792
|
+
)
|
|
793
|
+
if " | " in base_result and not validate_arg:
|
|
794
|
+
return f"runtime.typed_value_row(not_null={not_null})"
|
|
795
|
+
|
|
796
|
+
return f"runtime.typed_scalar_row({base_result}, not_null={not_null}{validate_arg})"
|
|
797
|
+
|
|
798
|
+
|
|
741
799
|
def render_query_overload(
|
|
742
800
|
sql_fn_name: str, query_name: str, sql: str, row_type: str | None
|
|
743
801
|
) -> str:
|
|
@@ -846,7 +904,7 @@ def build_entities(
|
|
|
846
904
|
key=lambda e: (e.table_name is None, e.table_name or ""),
|
|
847
905
|
)
|
|
848
906
|
|
|
849
|
-
query_result_types = {}
|
|
907
|
+
query_result_types: dict[str, str] = {}
|
|
850
908
|
for q in queries_from_sqlc:
|
|
851
909
|
if len(q.columns) == 0:
|
|
852
910
|
query_result_types[q.name] = "None"
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import contextlib
|
|
3
|
+
import functools
|
|
3
4
|
import itertools
|
|
4
5
|
import types
|
|
5
6
|
from collections.abc import AsyncGenerator
|
|
@@ -8,29 +9,28 @@ from collections.abc import Callable
|
|
|
8
9
|
from collections.abc import Sequence
|
|
9
10
|
from contextlib import asynccontextmanager
|
|
10
11
|
from contextvars import ContextVar
|
|
11
|
-
from enum import Enum
|
|
12
|
+
from enum import StrEnum
|
|
12
13
|
from typing import Any
|
|
13
14
|
from typing import ClassVar
|
|
14
15
|
from typing import Literal
|
|
15
16
|
from typing import Self
|
|
16
17
|
from typing import TypedDict
|
|
18
|
+
from typing import TypeGuard
|
|
17
19
|
from typing import overload
|
|
18
20
|
|
|
19
21
|
import psycopg
|
|
20
22
|
import psycopg.abc
|
|
21
23
|
import psycopg.rows
|
|
22
24
|
import psycopg.sql
|
|
23
|
-
import psycopg.types.
|
|
25
|
+
import psycopg.types.enum
|
|
24
26
|
import psycopg_pool
|
|
27
|
+
from psycopg._cursor_base import BaseCursor
|
|
25
28
|
from pydantic import TypeAdapter
|
|
26
29
|
|
|
27
|
-
_adapter_cache: dict[object, TypeAdapter[object]] = {}
|
|
28
30
|
|
|
29
|
-
|
|
30
|
-
def get_adapter(typ: object) -> TypeAdapter[
|
|
31
|
-
|
|
32
|
-
_adapter_cache[typ] = TypeAdapter(typ)
|
|
33
|
-
return _adapter_cache[typ]
|
|
31
|
+
@functools.cache
|
|
32
|
+
def get_adapter(typ: object) -> TypeAdapter[Any]:
|
|
33
|
+
return TypeAdapter(typ)
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
class NoRowsError(Exception):
|
|
@@ -43,7 +43,7 @@ class TooManyRowsError(Exception):
|
|
|
43
43
|
|
|
44
44
|
@asynccontextmanager
|
|
45
45
|
async def listen(
|
|
46
|
-
conn: psycopg.AsyncConnection, channel: str
|
|
46
|
+
conn: psycopg.AsyncConnection[Any], channel: str
|
|
47
47
|
) -> AsyncGenerator[AsyncGenerator[str]]:
|
|
48
48
|
_validate_channel(channel)
|
|
49
49
|
if await _has_active_listen_subscriptions(conn):
|
|
@@ -65,7 +65,9 @@ async def listen(
|
|
|
65
65
|
await execute_unlisten(conn, channel)
|
|
66
66
|
|
|
67
67
|
|
|
68
|
-
async def notify(conn: psycopg.AsyncConnection, channel: str, payload: str) -> None:
|
|
68
|
+
async def notify(
|
|
69
|
+
conn: psycopg.AsyncConnection[Any], channel: str, payload: str
|
|
70
|
+
) -> None:
|
|
69
71
|
_validate_channel(channel)
|
|
70
72
|
await conn.execute(
|
|
71
73
|
psycopg.sql.SQL("NOTIFY {}, {}").format(
|
|
@@ -75,21 +77,21 @@ async def notify(conn: psycopg.AsyncConnection, channel: str, payload: str) -> N
|
|
|
75
77
|
)
|
|
76
78
|
|
|
77
79
|
|
|
78
|
-
async def execute_listen(conn: psycopg.AsyncConnection, channel: str) -> None:
|
|
80
|
+
async def execute_listen(conn: psycopg.AsyncConnection[Any], channel: str) -> None:
|
|
79
81
|
_validate_channel(channel)
|
|
80
82
|
await conn.execute(
|
|
81
83
|
psycopg.sql.SQL("LISTEN {}").format(psycopg.sql.Identifier(channel))
|
|
82
84
|
)
|
|
83
85
|
|
|
84
86
|
|
|
85
|
-
async def execute_unlisten(conn: psycopg.AsyncConnection, channel: str) -> None:
|
|
87
|
+
async def execute_unlisten(conn: psycopg.AsyncConnection[Any], channel: str) -> None:
|
|
86
88
|
_validate_channel(channel)
|
|
87
89
|
await conn.execute(
|
|
88
90
|
psycopg.sql.SQL("UNLISTEN {}").format(psycopg.sql.Identifier(channel))
|
|
89
91
|
)
|
|
90
92
|
|
|
91
93
|
|
|
92
|
-
async def _has_active_listen_subscriptions(conn: psycopg.AsyncConnection) -> bool:
|
|
94
|
+
async def _has_active_listen_subscriptions(conn: psycopg.AsyncConnection[Any]) -> bool:
|
|
93
95
|
async with conn.cursor() as cur:
|
|
94
96
|
await cur.execute("SELECT EXISTS (SELECT FROM pg_listening_channels())")
|
|
95
97
|
row = await cur.fetchone()
|
|
@@ -113,7 +115,9 @@ def _next_cursor_name() -> str:
|
|
|
113
115
|
|
|
114
116
|
|
|
115
117
|
@asynccontextmanager
|
|
116
|
-
async def _ensure_transaction(
|
|
118
|
+
async def _ensure_transaction(
|
|
119
|
+
conn: psycopg.AsyncConnection[Any],
|
|
120
|
+
) -> AsyncGenerator[None]:
|
|
117
121
|
match conn.info.transaction_status:
|
|
118
122
|
case psycopg.pq.TransactionStatus.IDLE:
|
|
119
123
|
async with conn.transaction():
|
|
@@ -129,10 +133,10 @@ class Query[T]:
|
|
|
129
133
|
_stmt: ClassVar[psycopg.sql.SQL]
|
|
130
134
|
_row_factory: psycopg.rows.BaseRowFactory[T]
|
|
131
135
|
_connection_factory: Callable[
|
|
132
|
-
[], contextlib.AbstractAsyncContextManager[psycopg.AsyncConnection]
|
|
136
|
+
[], contextlib.AbstractAsyncContextManager[psycopg.AsyncConnection[Any]]
|
|
133
137
|
]
|
|
134
138
|
|
|
135
|
-
def with_connection(self, connection: psycopg.AsyncConnection) -> Self:
|
|
139
|
+
def with_connection(self, connection: psycopg.AsyncConnection[Any]) -> Self:
|
|
136
140
|
q = self.__class__()
|
|
137
141
|
q._connection_factory = lambda: contextlib.nullcontext(connection) # noqa: SLF001
|
|
138
142
|
return q
|
|
@@ -179,6 +183,35 @@ class PoolOptions(TypedDict, total=False):
|
|
|
179
183
|
reconnect_failed: Callable[[psycopg_pool.AsyncConnectionPool[Any]], Awaitable[None]]
|
|
180
184
|
|
|
181
185
|
|
|
186
|
+
async def register_enums(
|
|
187
|
+
conn: psycopg.AsyncConnection[Any],
|
|
188
|
+
enum_types: Sequence[tuple[str, type[StrEnum]]],
|
|
189
|
+
) -> None:
|
|
190
|
+
for pg_name, enum_cls in enum_types:
|
|
191
|
+
info = await psycopg.types.enum.EnumInfo.fetch(conn, pg_name)
|
|
192
|
+
if info is None:
|
|
193
|
+
msg = f"Enum type {pg_name!r} not found in database"
|
|
194
|
+
raise RuntimeError(msg)
|
|
195
|
+
psycopg.types.enum.register_enum(
|
|
196
|
+
info,
|
|
197
|
+
conn,
|
|
198
|
+
enum_cls,
|
|
199
|
+
mapping=[(member, member.value) for member in enum_cls],
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _enum_configure(
|
|
204
|
+
enum_types: Sequence[tuple[str, type[StrEnum]]],
|
|
205
|
+
user_configure: Callable[[psycopg.AsyncConnection[Any]], Awaitable[None]] | None,
|
|
206
|
+
) -> Callable[[psycopg.AsyncConnection[Any]], Awaitable[None]]:
|
|
207
|
+
async def configure(conn: psycopg.AsyncConnection[Any]) -> None:
|
|
208
|
+
await register_enums(conn, enum_types)
|
|
209
|
+
if user_configure is not None:
|
|
210
|
+
await user_configure(conn)
|
|
211
|
+
|
|
212
|
+
return configure
|
|
213
|
+
|
|
214
|
+
|
|
182
215
|
class ConnectionPool:
|
|
183
216
|
def __init__(
|
|
184
217
|
self,
|
|
@@ -187,11 +220,13 @@ class ConnectionPool:
|
|
|
187
220
|
name: str | None = None,
|
|
188
221
|
application_name: str | None = None,
|
|
189
222
|
pool_options: PoolOptions | None = None,
|
|
223
|
+
enum_types: Sequence[tuple[str, type[StrEnum]]] = (),
|
|
190
224
|
) -> None:
|
|
191
225
|
self.conninfo = conninfo
|
|
192
226
|
self.name = name
|
|
193
227
|
self.application_name = application_name
|
|
194
228
|
self.pool_options = pool_options or {}
|
|
229
|
+
self.enum_types = enum_types
|
|
195
230
|
self._init_psycopg_pool()
|
|
196
231
|
|
|
197
232
|
async def close(self) -> None:
|
|
@@ -217,7 +252,7 @@ class ConnectionPool:
|
|
|
217
252
|
await self.psycopg_pool.check()
|
|
218
253
|
|
|
219
254
|
@asynccontextmanager
|
|
220
|
-
async def connection(self) -> AsyncGenerator[psycopg.AsyncConnection]:
|
|
255
|
+
async def connection(self) -> AsyncGenerator[psycopg.AsyncConnection[Any]]:
|
|
221
256
|
task = asyncio.current_task()
|
|
222
257
|
cancelling_before = 0 if task is None else task.cancelling()
|
|
223
258
|
await self.psycopg_pool.open()
|
|
@@ -232,6 +267,10 @@ class ConnectionPool:
|
|
|
232
267
|
forwarded: dict[str, Any] = {
|
|
233
268
|
k: v for k, v in self.pool_options.items() if k != "kwargs"
|
|
234
269
|
}
|
|
270
|
+
if self.enum_types:
|
|
271
|
+
forwarded["configure"] = _enum_configure(
|
|
272
|
+
self.enum_types, forwarded.get("configure")
|
|
273
|
+
)
|
|
235
274
|
conn_kwargs = {
|
|
236
275
|
**user_kwargs,
|
|
237
276
|
# https://www.psycopg.org/psycopg3/docs/basic/transactions.html#autocommit-transactions
|
|
@@ -249,8 +288,8 @@ class ConnectionPool:
|
|
|
249
288
|
|
|
250
289
|
@asynccontextmanager
|
|
251
290
|
async def connection_in_context(
|
|
252
|
-
self, context_var: ContextVar[psycopg.AsyncConnection | None]
|
|
253
|
-
) -> AsyncGenerator[psycopg.AsyncConnection]:
|
|
291
|
+
self, context_var: ContextVar[psycopg.AsyncConnection[Any] | None]
|
|
292
|
+
) -> AsyncGenerator[psycopg.AsyncConnection[Any]]:
|
|
254
293
|
conn = context_var.get()
|
|
255
294
|
if conn is not None:
|
|
256
295
|
yield conn
|
|
@@ -263,42 +302,40 @@ class ConnectionPool:
|
|
|
263
302
|
context_var.reset(token)
|
|
264
303
|
|
|
265
304
|
|
|
266
|
-
def validate_json_field(typ:
|
|
267
|
-
if value is None:
|
|
268
|
-
return None
|
|
305
|
+
def validate_json_field[T](typ: type[T], value: object) -> T:
|
|
269
306
|
adapter = get_adapter(typ)
|
|
270
307
|
if isinstance(value, str | bytes):
|
|
271
308
|
return adapter.validate_json(value)
|
|
272
309
|
return adapter.validate_python(value)
|
|
273
310
|
|
|
274
311
|
|
|
275
|
-
def json_validated(**json_fields: object):
|
|
276
|
-
def decorator
|
|
277
|
-
|
|
312
|
+
def json_validated[T](**json_fields: object) -> Callable[[type[T]], type[T]]:
|
|
313
|
+
def decorator(cls: type[T]) -> type[T]:
|
|
314
|
+
original_post_init = getattr(cls, "__post_init__", None)
|
|
278
315
|
|
|
279
316
|
def __post_init__(self: object) -> None: # noqa: N807
|
|
280
|
-
if
|
|
281
|
-
|
|
317
|
+
if original_post_init is not None:
|
|
318
|
+
original_post_init(self)
|
|
282
319
|
for name, typ in json_fields.items():
|
|
283
|
-
|
|
320
|
+
current = getattr(self, name)
|
|
321
|
+
if current is None:
|
|
322
|
+
continue
|
|
323
|
+
setattr(self, name, validate_json_field(typ, current)) # pyright: ignore[reportArgumentType]
|
|
284
324
|
|
|
285
|
-
cls
|
|
325
|
+
setattr(cls, "__post_init__", __post_init__) # noqa: B010
|
|
286
326
|
return cls
|
|
287
327
|
|
|
288
328
|
return decorator
|
|
289
329
|
|
|
290
330
|
|
|
291
|
-
def
|
|
292
|
-
if value is None:
|
|
293
|
-
return None
|
|
331
|
+
def dump_json_value(typ: object, value: object) -> object:
|
|
294
332
|
adapter = get_adapter(typ)
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
return adapter.dump_json(value).decode()
|
|
333
|
+
return adapter.dump_python(value, mode="json")
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def dump_json_text(typ: object, value: object) -> str:
|
|
337
|
+
adapter = get_adapter(typ)
|
|
338
|
+
return adapter.dump_json(value).decode()
|
|
302
339
|
|
|
303
340
|
|
|
304
341
|
def get_one_row[T](rows: list[T]) -> T:
|
|
@@ -338,7 +375,9 @@ def typed_scalar_row[T](
|
|
|
338
375
|
def typed_scalar_row[T](
|
|
339
376
|
typ: type[T], *, not_null: bool, validate: Callable[[object], T] | None = None
|
|
340
377
|
) -> psycopg.rows.BaseRowFactory[T | None]:
|
|
341
|
-
def typed_scalar_row_(
|
|
378
|
+
def typed_scalar_row_(
|
|
379
|
+
cursor: BaseCursor[Any, Any],
|
|
380
|
+
) -> psycopg.rows.RowMaker[T | None]:
|
|
342
381
|
scalar_row_ = psycopg.rows.scalar_row(cursor)
|
|
343
382
|
|
|
344
383
|
def typed_scalar_row__(values: Sequence[Any]) -> T | None:
|
|
@@ -350,13 +389,98 @@ def typed_scalar_row[T](
|
|
|
350
389
|
return None
|
|
351
390
|
if validate:
|
|
352
391
|
return validate(val)
|
|
353
|
-
|
|
354
|
-
if issubclass(typ, Enum):
|
|
355
|
-
return typ(val)
|
|
356
|
-
msg = f"Expected scalar of type {typ}, got {type(val)}"
|
|
357
|
-
raise TypeError(msg)
|
|
358
|
-
return val
|
|
392
|
+
return _check_scalar_type(val, typ)
|
|
359
393
|
|
|
360
394
|
return typed_scalar_row__
|
|
361
395
|
|
|
362
396
|
return typed_scalar_row_
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
@overload
|
|
400
|
+
def typed_value_row[T](
|
|
401
|
+
*,
|
|
402
|
+
not_null: Literal[True],
|
|
403
|
+
) -> psycopg.rows.BaseRowFactory[T]: ...
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
@overload
|
|
407
|
+
def typed_value_row[T](
|
|
408
|
+
*,
|
|
409
|
+
not_null: Literal[False],
|
|
410
|
+
) -> psycopg.rows.BaseRowFactory[T | None]: ...
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def typed_value_row[T](*, not_null: bool) -> psycopg.rows.BaseRowFactory[T | None]:
|
|
414
|
+
def typed_value_row_(
|
|
415
|
+
cursor: BaseCursor[Any, Any],
|
|
416
|
+
) -> psycopg.rows.RowMaker[T | None]:
|
|
417
|
+
scalar_row_ = psycopg.rows.scalar_row(cursor)
|
|
418
|
+
|
|
419
|
+
def typed_value_row__(values: Sequence[Any]) -> T | None:
|
|
420
|
+
val = scalar_row_(values)
|
|
421
|
+
if val is None:
|
|
422
|
+
if not_null:
|
|
423
|
+
msg = "Expected non-null value, got None"
|
|
424
|
+
raise TypeError(msg)
|
|
425
|
+
return None
|
|
426
|
+
return val
|
|
427
|
+
|
|
428
|
+
return typed_value_row__
|
|
429
|
+
|
|
430
|
+
return typed_value_row_
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
@overload
|
|
434
|
+
def typed_array_row[T](
|
|
435
|
+
elem_typ: type[T],
|
|
436
|
+
*,
|
|
437
|
+
not_null: Literal[True],
|
|
438
|
+
) -> psycopg.rows.BaseRowFactory[list[T]]: ...
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
@overload
|
|
442
|
+
def typed_array_row[T](
|
|
443
|
+
elem_typ: type[T],
|
|
444
|
+
*,
|
|
445
|
+
not_null: Literal[False],
|
|
446
|
+
) -> psycopg.rows.BaseRowFactory[list[T] | None]: ...
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def typed_array_row[T](
|
|
450
|
+
elem_typ: type[T], *, not_null: bool
|
|
451
|
+
) -> psycopg.rows.BaseRowFactory[list[T] | None]:
|
|
452
|
+
def typed_array_row_(
|
|
453
|
+
cursor: BaseCursor[Any, Any],
|
|
454
|
+
) -> psycopg.rows.RowMaker[list[T] | None]:
|
|
455
|
+
scalar_row_ = psycopg.rows.scalar_row(cursor)
|
|
456
|
+
|
|
457
|
+
def typed_array_row__(values: Sequence[Any]) -> list[T] | None:
|
|
458
|
+
val = scalar_row_(values)
|
|
459
|
+
if val is None:
|
|
460
|
+
if not_null:
|
|
461
|
+
msg = "Expected non-null value, got None"
|
|
462
|
+
raise TypeError(msg)
|
|
463
|
+
return None
|
|
464
|
+
if not _is_object_list(val):
|
|
465
|
+
msg = f"Expected scalar of type list[{elem_typ}], got {type(val)}"
|
|
466
|
+
raise TypeError(msg)
|
|
467
|
+
return [_check_scalar_type(v, elem_typ) for v in val]
|
|
468
|
+
|
|
469
|
+
return typed_array_row__
|
|
470
|
+
|
|
471
|
+
return typed_array_row_
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def _check_scalar_type[T](val: object, typ: type[T]) -> T:
|
|
475
|
+
if _is_instance(val, typ):
|
|
476
|
+
return val
|
|
477
|
+
msg = f"Expected scalar of type {typ}, got {type(val)}"
|
|
478
|
+
raise TypeError(msg)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def _is_instance[T](val: object, typ: type[T]) -> TypeGuard[T]:
|
|
482
|
+
return isinstance(val, typ)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def _is_object_list(val: object) -> TypeGuard[list[object]]:
|
|
486
|
+
return isinstance(val, list)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|