iron-sql 0.5.2__tar.gz → 0.5.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: iron-sql
3
- Version: 0.5.2
3
+ Version: 0.5.4
4
4
  Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
5
5
  Keywords: postgresql,sql,sqlc,psycopg,codegen,async
6
6
  Author: Ilia Ablamonov
@@ -16,7 +16,7 @@ Requires-Dist: psycopg>=3.3.2
16
16
  Requires-Dist: psycopg-pool>=3.3.0
17
17
  Requires-Dist: pydantic>=2.12.4
18
18
  Requires-Dist: inflection>=0.5.1 ; extra == 'codegen'
19
- Requires-Dist: sqlc>=1.30.0.post18 ; extra == 'codegen'
19
+ Requires-Dist: sqlc>=1.30.0.post19 ; extra == 'codegen'
20
20
  Requires-Python: >=3.13
21
21
  Project-URL: Homepage, https://github.com/Flamefork/iron_sql
22
22
  Project-URL: Repository, https://github.com/Flamefork/iron_sql.git
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "iron-sql"
3
- version = "0.5.2"
3
+ version = "0.5.4"
4
4
 
5
5
  description = "iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries"
6
6
  readme = "README.md"
@@ -26,7 +26,7 @@ dependencies = [
26
26
  [project.optional-dependencies]
27
27
  codegen = [
28
28
  "inflection>=0.5.1",
29
- "sqlc>=1.30.0.post18",
29
+ "sqlc>=1.30.0.post19",
30
30
  ]
31
31
 
32
32
  [project.urls]
@@ -53,14 +53,6 @@ dev = [
53
53
 
54
54
  [tool.pyright]
55
55
  typeCheckingMode = "strict"
56
- reportUnknownArgumentType = "none"
57
- reportUnknownLambdaType = "none"
58
- reportUnknownMemberType = "none"
59
- reportUnknownParameterType = "none"
60
- reportUnknownVariableType = "none"
61
- reportMissingParameterType = "none"
62
- reportMissingTypeArgument = "none"
63
- reportMissingTypeStubs = "none"
64
56
  deprecateTypingAliases = true
65
57
  reportImportCycles = true
66
58
  reportUnnecessaryTypeIgnoreComment = true
@@ -31,9 +31,18 @@ class ColumnSpec:
31
31
  name: str
32
32
  table: str
33
33
  py_type: str
34
+ element_py_type: str | None = None
34
35
  json_type: str | None = None
35
36
 
36
37
 
38
+ _JSON_PARAM_DUMPERS = {
39
+ "json": "runtime.dump_json_value",
40
+ "jsonb": "runtime.dump_json_value",
41
+ "text": "runtime.dump_json_text",
42
+ "varchar": "runtime.dump_json_text",
43
+ }
44
+
45
+
37
46
  @dataclass(kw_only=True, frozen=True)
38
47
  class ParamSpec:
39
48
  name: str
@@ -51,19 +60,23 @@ class ParamSpec:
51
60
 
52
61
  @property
53
62
  def serialized_expr(self) -> str:
63
+ expr = self.name
64
+ wraps_value = False
54
65
  if self.json_type:
55
- return f"runtime.serialize_json_param({self.json_type}, {self.name}, {self.db_type!r})" # noqa: E501
56
-
66
+ dump_fn = _JSON_PARAM_DUMPERS[self.db_type]
67
+ expr = f"{dump_fn}({self.json_type}, {self.name})"
68
+ wraps_value = True
57
69
  match self.db_type:
58
70
  case "json":
59
- expr = f"psycopg.types.json.Json({self.name})"
71
+ expr = f"psycopg.types.json.Json({expr})"
72
+ wraps_value = True
60
73
  case "jsonb":
61
- expr = f"psycopg.types.json.Jsonb({self.name})"
74
+ expr = f"psycopg.types.json.Jsonb({expr})"
75
+ wraps_value = True
62
76
  case _:
63
- return self.name
64
-
65
- if not self.not_null:
66
- return f"{expr} if {self.name} is not None else None"
77
+ pass
78
+ if wraps_value and not self.not_null:
79
+ expr = f"{expr} if {self.name} is not None else None"
67
80
  return expr
68
81
 
69
82
 
@@ -159,16 +172,17 @@ class TypeResolver:
159
172
  json_column_type_overrides: dict[tuple[str, str], str]
160
173
 
161
174
  def column_spec(self, column: Column) -> ColumnSpec:
162
- _, py_type, json_type = self._resolve(column)
175
+ _, py_type, element_py_type, json_type = self._resolve(column)
163
176
  return ColumnSpec(
164
177
  name=column.name,
165
178
  table=column.table.name if column.table else "unknown",
166
179
  py_type=py_type,
180
+ element_py_type=element_py_type,
167
181
  json_type=json_type,
168
182
  )
169
183
 
170
184
  def param_spec(self, column: Column, name: str, *, is_named: bool) -> ParamSpec:
171
- db_type, py_type, json_type = self._resolve(column)
185
+ db_type, py_type, _, json_type = self._resolve(column)
172
186
  return ParamSpec(
173
187
  name=name,
174
188
  py_type=py_type,
@@ -179,7 +193,7 @@ class TypeResolver:
179
193
  json_type=json_type,
180
194
  )
181
195
 
182
- def _resolve(self, column: Column) -> tuple[str, str, str | None]:
196
+ def _resolve(self, column: Column) -> tuple[str, str, str | None, str | None]:
183
197
  db_type = column.type.name.removeprefix("pg_catalog.")
184
198
 
185
199
  json_type = None
@@ -210,13 +224,15 @@ class TypeResolver:
210
224
  )
211
225
  py_type = "object"
212
226
 
227
+ element_py_type = None
213
228
  if column.is_array:
229
+ element_py_type = py_type
214
230
  py_type = f"Sequence[{py_type}]"
215
231
 
216
232
  if not column.not_null:
217
233
  py_type += " | None"
218
234
 
219
- return db_type, py_type, json_type
235
+ return db_type, py_type, element_py_type, json_type
220
236
 
221
237
 
222
238
  def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
@@ -316,11 +332,24 @@ def generate_sql_module( # noqa: PLR0913, PLR0914
316
332
 
317
333
  used_enums = collect_used_enums(sqlc_res)
318
334
 
319
- enums = sorted(
320
- render_enum_class(e, module_name, to_pascal_fn, to_snake_fn)
335
+ enum_specs = [
336
+ (schema, e)
321
337
  for schema in sqlc_res.catalog.schemas
322
338
  for e in schema.enums
323
339
  if (schema.name, e.name) in used_enums
340
+ ]
341
+
342
+ enums = sorted(
343
+ render_enum_class(e, module_name, to_pascal_fn, to_snake_fn)
344
+ for _, e in enum_specs
345
+ )
346
+
347
+ enum_registry = sorted(
348
+ (
349
+ f"{schema.name}.{e.name}",
350
+ enum_class_name(e.name, module_name, to_pascal_fn, to_snake_fn),
351
+ )
352
+ for schema, e in enum_specs
324
353
  )
325
354
 
326
355
  query_classes = render_query_classes(
@@ -341,6 +370,7 @@ def generate_sql_module( # noqa: PLR0913, PLR0914
341
370
  sql_fn_name,
342
371
  entities,
343
372
  enums,
373
+ enum_registry,
344
374
  query_classes,
345
375
  query_overloads,
346
376
  query_dict_entries,
@@ -390,12 +420,7 @@ def render_query_classes(
390
420
  for p in q.params
391
421
  ],
392
422
  query_result_types[q.name],
393
- len(q.columns),
394
- (
395
- resolver.column_spec(q.columns[0]).json_type
396
- if len(q.columns) == 1
397
- else None
398
- ),
423
+ tuple(resolver.column_spec(column) for column in q.columns),
399
424
  query_locations_by_name[q.name],
400
425
  )
401
426
  for q in sorted(sqlc_queries, key=lambda q: query_order[q.name])
@@ -408,7 +433,7 @@ def resolve_json_model_overrides(
408
433
  if not overrides:
409
434
  return "", {}
410
435
 
411
- json_compatible_types = {"json", "jsonb", "text", "varchar"}
436
+ json_compatible_types = set(_JSON_PARAM_DUMPERS)
412
437
  col_types = {
413
438
  (table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
414
439
  for schema in catalog.schemas
@@ -465,6 +490,7 @@ def render_module( # noqa: PLR0913, PLR0917
465
490
  sql_fn_name: str,
466
491
  entities: list[str],
467
492
  enums: list[str],
493
+ enum_registry: list[tuple[str, str]],
468
494
  query_classes: list[str],
469
495
  query_overloads: list[str],
470
496
  query_dict_entries: list[str],
@@ -489,6 +515,20 @@ def render_module( # noqa: PLR0913, PLR0917
489
515
 
490
516
  imports_block = "\n".join(imports)
491
517
 
518
+ pre_pool_blocks: list[str] = []
519
+ if enums:
520
+ pre_pool_blocks.append("\n\n\n".join(enums))
521
+ if enum_registry:
522
+ registry_entries = ",\n ".join(
523
+ f'("{pg_name}", {class_name})' for pg_name, class_name in enum_registry
524
+ )
525
+ registry_type = "list[tuple[str, type[StrEnum]]]"
526
+ pre_pool_blocks.append(
527
+ f"ENUM_TYPES: {registry_type} = [\n {registry_entries},\n]"
528
+ )
529
+ pool_args.append("enum_types=ENUM_TYPES")
530
+ pre_pool_section = "".join(f"{block}\n\n\n" for block in pre_pool_blocks)
531
+
492
532
  pool_args_str = ",\n ".join(pool_args)
493
533
 
494
534
  return f"""
@@ -525,7 +565,7 @@ from iron_sql import runtime
525
565
  {imports_block}
526
566
 
527
567
 
528
- {module_name.upper()}_POOL = runtime.ConnectionPool(
568
+ {pre_pool_section}{module_name.upper()}_POOL = runtime.ConnectionPool(
529
569
  {pool_args_str},
530
570
  )
531
571
 
@@ -561,9 +601,6 @@ async def {module_name}_notify(channel: str, payload: str = "") -> None:
561
601
  await runtime.notify(conn, channel, payload)
562
602
 
563
603
 
564
- {"\n\n\n".join(enums)}
565
-
566
-
567
604
  {"\n\n\n".join(entities)}
568
605
 
569
606
 
@@ -593,14 +630,23 @@ def {sql_fn_name}(sql: str, row_type: str | None = None) -> Query[Any]:
593
630
  """.strip() # noqa: E501
594
631
 
595
632
 
633
+ def enum_class_name(
634
+ enum_name: str,
635
+ module_name: str,
636
+ to_pascal_fn: Callable[[str], str],
637
+ to_snake_fn: Callable[[str], str],
638
+ ) -> str:
639
+ return to_pascal_fn(f"{module_name}_{to_snake_fn(enum_name)}")
640
+
641
+
596
642
  def render_enum_class(
597
643
  enum: Enum,
598
644
  module_name: str,
599
645
  to_pascal_fn: Callable[[str], str],
600
646
  to_snake_fn: Callable[[str], str],
601
647
  ) -> str:
602
- class_name = to_pascal_fn(f"{module_name}_{to_snake_fn(enum.name)}")
603
- members = []
648
+ class_name = enum_class_name(enum.name, module_name, to_pascal_fn, to_snake_fn)
649
+ members: list[str] = []
604
650
  seen_names: dict[str, int] = {}
605
651
 
606
652
  for val in enum.vals:
@@ -642,7 +688,7 @@ class {name}:
642
688
 
643
689
 
644
690
  def deduplicate_params(params: list[ParamSpec]) -> list[ParamSpec]:
645
- seen = defaultdict(int)
691
+ seen: defaultdict[str, int] = defaultdict(int)
646
692
  result: list[ParamSpec] = []
647
693
  for param in params:
648
694
  seen[param.name] += 1
@@ -658,8 +704,7 @@ def render_query_class(
658
704
  sql: str,
659
705
  query_params: list[ParamSpec],
660
706
  result: str,
661
- columns_num: int,
662
- scalar_json_type: str | None,
707
+ result_columns: tuple[ColumnSpec, ...],
663
708
  locations: list[str],
664
709
  ) -> str:
665
710
  query_params = deduplicate_params(query_params)
@@ -681,24 +726,9 @@ def render_query_class(
681
726
  query_fn_params.insert(0, "self")
682
727
 
683
728
  base_result = result.removesuffix(" | None")
729
+ row_factory = render_row_factory(result, result_columns)
684
730
 
685
- if columns_num == 0:
686
- row_factory = "psycopg.rows.scalar_row"
687
- elif columns_num == 1:
688
- not_null_str = "True" if not result.endswith(" | None") else "False"
689
- validate_arg = (
690
- f", validate=lambda _v: runtime.validate_json_field({scalar_json_type}, _v)"
691
- if scalar_json_type
692
- else ""
693
- )
694
- row_factory = (
695
- f"runtime.typed_scalar_row"
696
- f"({base_result}, not_null={not_null_str}{validate_arg})"
697
- )
698
- else:
699
- row_factory = f"psycopg.rows.class_row({result})"
700
-
701
- if columns_num > 0:
731
+ if result_columns:
702
732
  methods = f"""
703
733
 
704
734
  async def query_all_rows({", ".join(query_fn_params)}) -> list[{result}]:
@@ -738,6 +768,34 @@ class {query_name}(Query[{result}]):
738
768
  """.strip()
739
769
 
740
770
 
771
+ def render_row_factory(result: str, columns: tuple[ColumnSpec, ...]) -> str:
772
+ match columns:
773
+ case ():
774
+ return "psycopg.rows.scalar_row"
775
+ case (column,):
776
+ return render_scalar_row_factory(result, column)
777
+ case _:
778
+ return f"psycopg.rows.class_row({result})"
779
+
780
+
781
+ def render_scalar_row_factory(result: str, column: ColumnSpec) -> str:
782
+ base_result = result.removesuffix(" | None")
783
+ not_null = "True" if not result.endswith(" | None") else "False"
784
+
785
+ if column.element_py_type is not None:
786
+ return f"runtime.typed_array_row({column.element_py_type}, not_null={not_null})"
787
+
788
+ validate_arg = (
789
+ f", validate=lambda _v: runtime.validate_json_field({column.json_type}, _v)"
790
+ if column.json_type
791
+ else ""
792
+ )
793
+ if " | " in base_result and not validate_arg:
794
+ return f"runtime.typed_value_row(not_null={not_null})"
795
+
796
+ return f"runtime.typed_scalar_row({base_result}, not_null={not_null}{validate_arg})"
797
+
798
+
741
799
  def render_query_overload(
742
800
  sql_fn_name: str, query_name: str, sql: str, row_type: str | None
743
801
  ) -> str:
@@ -846,7 +904,7 @@ def build_entities(
846
904
  key=lambda e: (e.table_name is None, e.table_name or ""),
847
905
  )
848
906
 
849
- query_result_types = {}
907
+ query_result_types: dict[str, str] = {}
850
908
  for q in queries_from_sqlc:
851
909
  if len(q.columns) == 0:
852
910
  query_result_types[q.name] = "None"
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  import contextlib
3
+ import functools
3
4
  import itertools
4
5
  import types
5
6
  from collections.abc import AsyncGenerator
@@ -8,29 +9,28 @@ from collections.abc import Callable
8
9
  from collections.abc import Sequence
9
10
  from contextlib import asynccontextmanager
10
11
  from contextvars import ContextVar
11
- from enum import Enum
12
+ from enum import StrEnum
12
13
  from typing import Any
13
14
  from typing import ClassVar
14
15
  from typing import Literal
15
16
  from typing import Self
16
17
  from typing import TypedDict
18
+ from typing import TypeGuard
17
19
  from typing import overload
18
20
 
19
21
  import psycopg
20
22
  import psycopg.abc
21
23
  import psycopg.rows
22
24
  import psycopg.sql
23
- import psycopg.types.json
25
+ import psycopg.types.enum
24
26
  import psycopg_pool
27
+ from psycopg._cursor_base import BaseCursor
25
28
  from pydantic import TypeAdapter
26
29
 
27
- _adapter_cache: dict[object, TypeAdapter[object]] = {}
28
30
 
29
-
30
- def get_adapter(typ: object) -> TypeAdapter[object]:
31
- if typ not in _adapter_cache:
32
- _adapter_cache[typ] = TypeAdapter(typ)
33
- return _adapter_cache[typ]
31
+ @functools.cache
32
+ def get_adapter(typ: object) -> TypeAdapter[Any]:
33
+ return TypeAdapter(typ)
34
34
 
35
35
 
36
36
  class NoRowsError(Exception):
@@ -43,7 +43,7 @@ class TooManyRowsError(Exception):
43
43
 
44
44
  @asynccontextmanager
45
45
  async def listen(
46
- conn: psycopg.AsyncConnection, channel: str
46
+ conn: psycopg.AsyncConnection[Any], channel: str
47
47
  ) -> AsyncGenerator[AsyncGenerator[str]]:
48
48
  _validate_channel(channel)
49
49
  if await _has_active_listen_subscriptions(conn):
@@ -65,7 +65,9 @@ async def listen(
65
65
  await execute_unlisten(conn, channel)
66
66
 
67
67
 
68
- async def notify(conn: psycopg.AsyncConnection, channel: str, payload: str) -> None:
68
+ async def notify(
69
+ conn: psycopg.AsyncConnection[Any], channel: str, payload: str
70
+ ) -> None:
69
71
  _validate_channel(channel)
70
72
  await conn.execute(
71
73
  psycopg.sql.SQL("NOTIFY {}, {}").format(
@@ -75,21 +77,21 @@ async def notify(conn: psycopg.AsyncConnection, channel: str, payload: str) -> N
75
77
  )
76
78
 
77
79
 
78
- async def execute_listen(conn: psycopg.AsyncConnection, channel: str) -> None:
80
+ async def execute_listen(conn: psycopg.AsyncConnection[Any], channel: str) -> None:
79
81
  _validate_channel(channel)
80
82
  await conn.execute(
81
83
  psycopg.sql.SQL("LISTEN {}").format(psycopg.sql.Identifier(channel))
82
84
  )
83
85
 
84
86
 
85
- async def execute_unlisten(conn: psycopg.AsyncConnection, channel: str) -> None:
87
+ async def execute_unlisten(conn: psycopg.AsyncConnection[Any], channel: str) -> None:
86
88
  _validate_channel(channel)
87
89
  await conn.execute(
88
90
  psycopg.sql.SQL("UNLISTEN {}").format(psycopg.sql.Identifier(channel))
89
91
  )
90
92
 
91
93
 
92
- async def _has_active_listen_subscriptions(conn: psycopg.AsyncConnection) -> bool:
94
+ async def _has_active_listen_subscriptions(conn: psycopg.AsyncConnection[Any]) -> bool:
93
95
  async with conn.cursor() as cur:
94
96
  await cur.execute("SELECT EXISTS (SELECT FROM pg_listening_channels())")
95
97
  row = await cur.fetchone()
@@ -113,7 +115,9 @@ def _next_cursor_name() -> str:
113
115
 
114
116
 
115
117
  @asynccontextmanager
116
- async def _ensure_transaction(conn: psycopg.AsyncConnection) -> AsyncGenerator[None]:
118
+ async def _ensure_transaction(
119
+ conn: psycopg.AsyncConnection[Any],
120
+ ) -> AsyncGenerator[None]:
117
121
  match conn.info.transaction_status:
118
122
  case psycopg.pq.TransactionStatus.IDLE:
119
123
  async with conn.transaction():
@@ -129,10 +133,10 @@ class Query[T]:
129
133
  _stmt: ClassVar[psycopg.sql.SQL]
130
134
  _row_factory: psycopg.rows.BaseRowFactory[T]
131
135
  _connection_factory: Callable[
132
- [], contextlib.AbstractAsyncContextManager[psycopg.AsyncConnection]
136
+ [], contextlib.AbstractAsyncContextManager[psycopg.AsyncConnection[Any]]
133
137
  ]
134
138
 
135
- def with_connection(self, connection: psycopg.AsyncConnection) -> Self:
139
+ def with_connection(self, connection: psycopg.AsyncConnection[Any]) -> Self:
136
140
  q = self.__class__()
137
141
  q._connection_factory = lambda: contextlib.nullcontext(connection) # noqa: SLF001
138
142
  return q
@@ -179,6 +183,35 @@ class PoolOptions(TypedDict, total=False):
179
183
  reconnect_failed: Callable[[psycopg_pool.AsyncConnectionPool[Any]], Awaitable[None]]
180
184
 
181
185
 
186
+ async def register_enums(
187
+ conn: psycopg.AsyncConnection[Any],
188
+ enum_types: Sequence[tuple[str, type[StrEnum]]],
189
+ ) -> None:
190
+ for pg_name, enum_cls in enum_types:
191
+ info = await psycopg.types.enum.EnumInfo.fetch(conn, pg_name)
192
+ if info is None:
193
+ msg = f"Enum type {pg_name!r} not found in database"
194
+ raise RuntimeError(msg)
195
+ psycopg.types.enum.register_enum(
196
+ info,
197
+ conn,
198
+ enum_cls,
199
+ mapping=[(member, member.value) for member in enum_cls],
200
+ )
201
+
202
+
203
+ def _enum_configure(
204
+ enum_types: Sequence[tuple[str, type[StrEnum]]],
205
+ user_configure: Callable[[psycopg.AsyncConnection[Any]], Awaitable[None]] | None,
206
+ ) -> Callable[[psycopg.AsyncConnection[Any]], Awaitable[None]]:
207
+ async def configure(conn: psycopg.AsyncConnection[Any]) -> None:
208
+ await register_enums(conn, enum_types)
209
+ if user_configure is not None:
210
+ await user_configure(conn)
211
+
212
+ return configure
213
+
214
+
182
215
  class ConnectionPool:
183
216
  def __init__(
184
217
  self,
@@ -187,11 +220,13 @@ class ConnectionPool:
187
220
  name: str | None = None,
188
221
  application_name: str | None = None,
189
222
  pool_options: PoolOptions | None = None,
223
+ enum_types: Sequence[tuple[str, type[StrEnum]]] = (),
190
224
  ) -> None:
191
225
  self.conninfo = conninfo
192
226
  self.name = name
193
227
  self.application_name = application_name
194
228
  self.pool_options = pool_options or {}
229
+ self.enum_types = enum_types
195
230
  self._init_psycopg_pool()
196
231
 
197
232
  async def close(self) -> None:
@@ -217,7 +252,7 @@ class ConnectionPool:
217
252
  await self.psycopg_pool.check()
218
253
 
219
254
  @asynccontextmanager
220
- async def connection(self) -> AsyncGenerator[psycopg.AsyncConnection]:
255
+ async def connection(self) -> AsyncGenerator[psycopg.AsyncConnection[Any]]:
221
256
  task = asyncio.current_task()
222
257
  cancelling_before = 0 if task is None else task.cancelling()
223
258
  await self.psycopg_pool.open()
@@ -232,6 +267,10 @@ class ConnectionPool:
232
267
  forwarded: dict[str, Any] = {
233
268
  k: v for k, v in self.pool_options.items() if k != "kwargs"
234
269
  }
270
+ if self.enum_types:
271
+ forwarded["configure"] = _enum_configure(
272
+ self.enum_types, forwarded.get("configure")
273
+ )
235
274
  conn_kwargs = {
236
275
  **user_kwargs,
237
276
  # https://www.psycopg.org/psycopg3/docs/basic/transactions.html#autocommit-transactions
@@ -249,8 +288,8 @@ class ConnectionPool:
249
288
 
250
289
  @asynccontextmanager
251
290
  async def connection_in_context(
252
- self, context_var: ContextVar[psycopg.AsyncConnection | None]
253
- ) -> AsyncGenerator[psycopg.AsyncConnection]:
291
+ self, context_var: ContextVar[psycopg.AsyncConnection[Any] | None]
292
+ ) -> AsyncGenerator[psycopg.AsyncConnection[Any]]:
254
293
  conn = context_var.get()
255
294
  if conn is not None:
256
295
  yield conn
@@ -263,42 +302,40 @@ class ConnectionPool:
263
302
  context_var.reset(token)
264
303
 
265
304
 
266
- def validate_json_field(typ: object, value: object) -> object:
267
- if value is None:
268
- return None
305
+ def validate_json_field[T](typ: type[T], value: object) -> T:
269
306
  adapter = get_adapter(typ)
270
307
  if isinstance(value, str | bytes):
271
308
  return adapter.validate_json(value)
272
309
  return adapter.validate_python(value)
273
310
 
274
311
 
275
- def json_validated(**json_fields: object):
276
- def decorator[T](cls: type[T]) -> type[T]:
277
- original = getattr(cls, "__post_init__", None)
312
+ def json_validated[T](**json_fields: object) -> Callable[[type[T]], type[T]]:
313
+ def decorator(cls: type[T]) -> type[T]:
314
+ original_post_init = getattr(cls, "__post_init__", None)
278
315
 
279
316
  def __post_init__(self: object) -> None: # noqa: N807
280
- if original is not None:
281
- original(self)
317
+ if original_post_init is not None:
318
+ original_post_init(self)
282
319
  for name, typ in json_fields.items():
283
- setattr(self, name, validate_json_field(typ, getattr(self, name)))
320
+ current = getattr(self, name)
321
+ if current is None:
322
+ continue
323
+ setattr(self, name, validate_json_field(typ, current)) # pyright: ignore[reportArgumentType]
284
324
 
285
- cls.__post_init__ = __post_init__ # type: ignore[attr-defined]
325
+ setattr(cls, "__post_init__", __post_init__) # noqa: B010
286
326
  return cls
287
327
 
288
328
  return decorator
289
329
 
290
330
 
291
- def serialize_json_param(typ: object, value: object, db_type: str) -> object:
292
- if value is None:
293
- return None
331
+ def dump_json_value(typ: object, value: object) -> object:
294
332
  adapter = get_adapter(typ)
295
- match db_type:
296
- case "json":
297
- return psycopg.types.json.Json(adapter.dump_python(value, mode="json"))
298
- case "jsonb":
299
- return psycopg.types.json.Jsonb(adapter.dump_python(value, mode="json"))
300
- case _:
301
- return adapter.dump_json(value).decode()
333
+ return adapter.dump_python(value, mode="json")
334
+
335
+
336
+ def dump_json_text(typ: object, value: object) -> str:
337
+ adapter = get_adapter(typ)
338
+ return adapter.dump_json(value).decode()
302
339
 
303
340
 
304
341
  def get_one_row[T](rows: list[T]) -> T:
@@ -338,7 +375,9 @@ def typed_scalar_row[T](
338
375
  def typed_scalar_row[T](
339
376
  typ: type[T], *, not_null: bool, validate: Callable[[object], T] | None = None
340
377
  ) -> psycopg.rows.BaseRowFactory[T | None]:
341
- def typed_scalar_row_(cursor) -> psycopg.rows.RowMaker[T | None]:
378
+ def typed_scalar_row_(
379
+ cursor: BaseCursor[Any, Any],
380
+ ) -> psycopg.rows.RowMaker[T | None]:
342
381
  scalar_row_ = psycopg.rows.scalar_row(cursor)
343
382
 
344
383
  def typed_scalar_row__(values: Sequence[Any]) -> T | None:
@@ -350,13 +389,98 @@ def typed_scalar_row[T](
350
389
  return None
351
390
  if validate:
352
391
  return validate(val)
353
- if not isinstance(val, typ):
354
- if issubclass(typ, Enum):
355
- return typ(val)
356
- msg = f"Expected scalar of type {typ}, got {type(val)}"
357
- raise TypeError(msg)
358
- return val
392
+ return _check_scalar_type(val, typ)
359
393
 
360
394
  return typed_scalar_row__
361
395
 
362
396
  return typed_scalar_row_
397
+
398
+
399
+ @overload
400
+ def typed_value_row[T](
401
+ *,
402
+ not_null: Literal[True],
403
+ ) -> psycopg.rows.BaseRowFactory[T]: ...
404
+
405
+
406
+ @overload
407
+ def typed_value_row[T](
408
+ *,
409
+ not_null: Literal[False],
410
+ ) -> psycopg.rows.BaseRowFactory[T | None]: ...
411
+
412
+
413
+ def typed_value_row[T](*, not_null: bool) -> psycopg.rows.BaseRowFactory[T | None]:
414
+ def typed_value_row_(
415
+ cursor: BaseCursor[Any, Any],
416
+ ) -> psycopg.rows.RowMaker[T | None]:
417
+ scalar_row_ = psycopg.rows.scalar_row(cursor)
418
+
419
+ def typed_value_row__(values: Sequence[Any]) -> T | None:
420
+ val = scalar_row_(values)
421
+ if val is None:
422
+ if not_null:
423
+ msg = "Expected non-null value, got None"
424
+ raise TypeError(msg)
425
+ return None
426
+ return val
427
+
428
+ return typed_value_row__
429
+
430
+ return typed_value_row_
431
+
432
+
433
+ @overload
434
+ def typed_array_row[T](
435
+ elem_typ: type[T],
436
+ *,
437
+ not_null: Literal[True],
438
+ ) -> psycopg.rows.BaseRowFactory[list[T]]: ...
439
+
440
+
441
+ @overload
442
+ def typed_array_row[T](
443
+ elem_typ: type[T],
444
+ *,
445
+ not_null: Literal[False],
446
+ ) -> psycopg.rows.BaseRowFactory[list[T] | None]: ...
447
+
448
+
449
+ def typed_array_row[T](
450
+ elem_typ: type[T], *, not_null: bool
451
+ ) -> psycopg.rows.BaseRowFactory[list[T] | None]:
452
+ def typed_array_row_(
453
+ cursor: BaseCursor[Any, Any],
454
+ ) -> psycopg.rows.RowMaker[list[T] | None]:
455
+ scalar_row_ = psycopg.rows.scalar_row(cursor)
456
+
457
+ def typed_array_row__(values: Sequence[Any]) -> list[T] | None:
458
+ val = scalar_row_(values)
459
+ if val is None:
460
+ if not_null:
461
+ msg = "Expected non-null value, got None"
462
+ raise TypeError(msg)
463
+ return None
464
+ if not _is_object_list(val):
465
+ msg = f"Expected scalar of type list[{elem_typ}], got {type(val)}"
466
+ raise TypeError(msg)
467
+ return [_check_scalar_type(v, elem_typ) for v in val]
468
+
469
+ return typed_array_row__
470
+
471
+ return typed_array_row_
472
+
473
+
474
+ def _check_scalar_type[T](val: object, typ: type[T]) -> T:
475
+ if _is_instance(val, typ):
476
+ return val
477
+ msg = f"Expected scalar of type {typ}, got {type(val)}"
478
+ raise TypeError(msg)
479
+
480
+
481
+ def _is_instance[T](val: object, typ: type[T]) -> TypeGuard[T]:
482
+ return isinstance(val, typ)
483
+
484
+
485
+ def _is_object_list(val: object) -> TypeGuard[list[object]]:
486
+ return isinstance(val, list)
File without changes
File without changes
File without changes