iron-sql 0.4.2__tar.gz → 0.4.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: iron-sql
3
- Version: 0.4.2
3
+ Version: 0.4.4
4
4
  Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
5
5
  Keywords: postgresql,sql,sqlc,psycopg,codegen,async
6
6
  Author: Ilia Ablamonov
@@ -46,6 +46,7 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
46
46
  - **Query discovery.** `generate_sql_package` scans your codebase for calls like `<package>_sql("SELECT ...")`, runs `sqlc` for type analysis, and emits a typed module.
47
47
  - **Strong typing.** Generated dataclasses and method signatures flow through your IDE and type checker.
48
48
  - **Async runtime.** Built on `psycopg` v3 with pooled connections, context-based connection reuse, and transaction helpers.
49
+ - **Streaming.** `query_stream()` uses server-side cursors for memory-efficient iteration over large result sets.
49
50
  - **Safe by default.** Helper methods enforce expected row counts instead of returning silent `None`.
50
51
 
51
52
  ## Package Layout
@@ -98,7 +99,8 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
98
99
  - `ConnectionPool` opens lazily and reopens after `close()`, with `ContextVar`-based connection reuse for nested contexts.
99
100
  - `*_listen_session()` uses a dedicated pooled connection and doesn't reuse `ContextVar` transaction connections.
100
101
  - `query_single_row()` raises `NoRowsError`; `query_optional_row()` returns `None`. Both raise `TooManyRowsError` on 2+ rows.
101
- - JSONB params are sent with `pgjson.Jsonb`; JSON with `pgjson.Json`. Scalar row factories validate types at runtime.
102
+ - `query_stream()` returns an async context manager yielding an `AsyncIterator`; uses server-side cursors with automatic transaction management.
103
+ - JSONB params are sent with `psycopg.types.json.Jsonb`; JSON with `psycopg.types.json.Json`. Scalar row factories validate types at runtime.
102
104
  - `json_validated` decorator applies Pydantic model validation to dataclass fields on construction.
103
105
 
104
106
  ## Example
@@ -20,6 +20,7 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
20
20
  - **Query discovery.** `generate_sql_package` scans your codebase for calls like `<package>_sql("SELECT ...")`, runs `sqlc` for type analysis, and emits a typed module.
21
21
  - **Strong typing.** Generated dataclasses and method signatures flow through your IDE and type checker.
22
22
  - **Async runtime.** Built on `psycopg` v3 with pooled connections, context-based connection reuse, and transaction helpers.
23
+ - **Streaming.** `query_stream()` uses server-side cursors for memory-efficient iteration over large result sets.
23
24
  - **Safe by default.** Helper methods enforce expected row counts instead of returning silent `None`.
24
25
 
25
26
  ## Package Layout
@@ -72,7 +73,8 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
72
73
  - `ConnectionPool` opens lazily and reopens after `close()`, with `ContextVar`-based connection reuse for nested contexts.
73
74
  - `*_listen_session()` uses a dedicated pooled connection and doesn't reuse `ContextVar` transaction connections.
74
75
  - `query_single_row()` raises `NoRowsError`; `query_optional_row()` returns `None`. Both raise `TooManyRowsError` on 2+ rows.
75
- - JSONB params are sent with `pgjson.Jsonb`; JSON with `pgjson.Json`. Scalar row factories validate types at runtime.
76
+ - `query_stream()` returns an async context manager yielding an `AsyncIterator`; uses server-side cursors with automatic transaction management.
77
+ - JSONB params are sent with `psycopg.types.json.Jsonb`; JSON with `psycopg.types.json.Json`. Scalar row factories validate types at runtime.
76
78
  - `json_validated` decorator applies Pydantic model validation to dataclass fields on construction.
77
79
 
78
80
  ## Example
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "iron-sql"
3
- version = "0.4.2"
3
+ version = "0.4.4"
4
4
 
5
5
  description = "iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries"
6
6
  readme = "README.md"
@@ -3,6 +3,7 @@ import dataclasses
3
3
  import hashlib
4
4
  import importlib
5
5
  import logging
6
+ import re
6
7
  import warnings
7
8
  from collections import defaultdict
8
9
  from collections.abc import Callable
@@ -55,9 +56,9 @@ class ParamSpec:
55
56
 
56
57
  match self.db_type:
57
58
  case "json":
58
- expr = f"pgjson.Json({self.name})"
59
+ expr = f"psycopg.types.json.Json({self.name})"
59
60
  case "jsonb":
60
- expr = f"pgjson.Jsonb({self.name})"
61
+ expr = f"psycopg.types.json.Jsonb({self.name})"
61
62
  case _:
62
63
  return self.name
63
64
 
@@ -186,6 +187,24 @@ def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
186
187
  }
187
188
 
188
189
 
190
+ def map_sqlc_error(
191
+ error: str,
192
+ block_starts: list[tuple[int, str]],
193
+ all_locations: dict[str, list[str]],
194
+ ) -> str:
195
+ def replace(m: re.Match[str]) -> str:
196
+ line = int(m.group(1))
197
+ name = next((n for start, n in reversed(block_starts) if start <= line), None)
198
+ if name is None:
199
+ return m.group(0)
200
+ locations = all_locations.get(name)
201
+ if not locations:
202
+ return m.group(0)
203
+ return f"{', '.join(locations)}:"
204
+
205
+ return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", replace, error)
206
+
207
+
189
208
  def generate_sql_package( # noqa: PLR0913, PLR0914
190
209
  *,
191
210
  schema_path: Path,
@@ -200,21 +219,14 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
200
219
  src_path: Path = Path(),
201
220
  tempdir_path: Path | None = None,
202
221
  ) -> bool:
203
- dsn_import_package, dsn_import_path = dsn_import.split(":")
204
-
205
- package_name = package_full_name.split(".")[-1] # noqa: PLC0207
222
+ package_name = package_full_name.rsplit(".", maxsplit=1)[-1]
206
223
  sql_fn_name = f"{package_name}_sql"
207
224
 
208
- target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
209
-
210
- queries = list(find_all_queries(src_path, sql_fn_name))
211
- validate_stmt_has_single_row_type(queries)
212
- queries = list({q.name: q for q in queries}.values())
225
+ queries, all_locations = collect_queries(src_path, sql_fn_name)
213
226
 
214
- dsn_package = importlib.import_module(dsn_import_package)
215
- dsn = eval(dsn_import_path, vars(dsn_package)) # noqa: S307
227
+ dsn, dsn_import_package, dsn_import_path = resolve_dsn(dsn_import)
216
228
 
217
- sqlc_res = run_sqlc(
229
+ sqlc_res, block_starts = run_sqlc(
218
230
  src_path / schema_path,
219
231
  [(q.name, q.stmt) for q in queries],
220
232
  dsn=dsn,
@@ -223,63 +235,13 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
223
235
  )
224
236
 
225
237
  if sqlc_res.error:
226
- logger.error("Error running SQLC:\n%s", sqlc_res.error)
238
+ mapped = map_sqlc_error(sqlc_res.error, block_starts, all_locations)
239
+ logger.error(f"Error running SQLC:\n{mapped}")
227
240
  return False
228
241
 
229
- json_import_block = ""
230
- json_col_overrides: dict[tuple[str, str], str] = {}
231
-
232
- if json_model_overrides:
233
- json_compatible_types = {"json", "jsonb", "text", "varchar"}
234
- col_types = {
235
- (table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
236
- for schema in sqlc_res.catalog.schemas
237
- for table in schema.tables
238
- for column in table.columns
239
- }
240
- tables = {table for table, _ in col_types}
241
-
242
- parsed: dict[tuple[str, str], tuple[str, str]] = {}
243
- for key, import_path in json_model_overrides.items():
244
- table_name, sep, col_name = key.partition(".")
245
- if not sep:
246
- msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
247
- raise ValueError(msg)
248
- if table_name not in tables:
249
- msg = f"json_model_overrides: table {table_name!r} not found in catalog"
250
- raise ValueError(msg)
251
- if (table_name, col_name) not in col_types:
252
- msg = (
253
- f"json_model_overrides: column {col_name!r} "
254
- f"not found in table {table_name!r}"
255
- )
256
- raise ValueError(msg)
257
-
258
- db_type = col_types[table_name, col_name]
259
- if db_type not in json_compatible_types:
260
- msg = (
261
- f"json_model_overrides: column "
262
- f"{table_name}.{col_name} has type "
263
- f"{db_type!r}, expected one of "
264
- f"{json_compatible_types}"
265
- )
266
- raise ValueError(msg)
267
-
268
- module_path, sep, class_name = import_path.partition(":")
269
- if not sep:
270
- msg = (
271
- "json_model_overrides value must be "
272
- f"'module:Class', got: {import_path!r}"
273
- )
274
- raise ValueError(msg)
275
-
276
- parsed[table_name, col_name] = (module_path, class_name)
277
-
278
- modules = sorted({module for module, _ in parsed.values()})
279
- json_import_block = "\n" + "\n".join(f"import {m}" for m in modules)
280
- json_col_overrides = {
281
- key: f"{module}.{cls}" for key, (module, cls) in parsed.items()
282
- }
242
+ json_import_block, json_col_overrides = resolve_json_model_overrides(
243
+ json_model_overrides or {}, sqlc_res.catalog
244
+ )
283
245
 
284
246
  resolver = TypeResolver(
285
247
  catalog=sqlc_res.catalog,
@@ -297,22 +259,82 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
297
259
  resolver,
298
260
  )
299
261
 
300
- entities = [render_entity(e.name, e.column_specs) for e in ordered_entities]
262
+ entities = sorted(render_entity(e.name, e.column_specs) for e in ordered_entities)
301
263
 
302
264
  used_enums = collect_used_enums(sqlc_res)
303
265
 
304
- enums = [
266
+ enums = sorted(
305
267
  render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
306
268
  for schema in sqlc_res.catalog.schemas
307
269
  for e in schema.enums
308
270
  if (schema.name, e.name) in used_enums
271
+ )
272
+
273
+ query_classes = render_query_classes(
274
+ sqlc_res.queries, queries, resolver, result_types, all_locations
275
+ )
276
+
277
+ query_overloads = [
278
+ render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
309
279
  ]
310
280
 
311
- query_classes = [
281
+ query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
282
+
283
+ target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
284
+
285
+ new_content = render_package(
286
+ dsn_import_package,
287
+ dsn_import_path,
288
+ package_name,
289
+ sql_fn_name,
290
+ entities,
291
+ enums,
292
+ query_classes,
293
+ query_overloads,
294
+ query_dict_entries,
295
+ application_name,
296
+ json_import_block,
297
+ )
298
+ changed = write_if_changed(target_package_path, new_content + "\n")
299
+ if changed:
300
+ logger.info(f"Generated SQL package {package_full_name}")
301
+ return changed
302
+
303
+
304
+ def collect_queries(
305
+ src_path: Path, sql_fn_name: str
306
+ ) -> tuple[list["CodeQuery"], defaultdict[str, list[str]]]:
307
+ raw = list(find_all_queries(src_path, sql_fn_name))
308
+ validate_stmt_has_single_row_type(raw)
309
+ all_locations: defaultdict[str, list[str]] = defaultdict(list)
310
+ first_occurrence: dict[str, CodeQuery] = {}
311
+ for q in raw:
312
+ all_locations[q.name].append(q.location)
313
+ if q.name not in first_occurrence:
314
+ first_occurrence[q.name] = q
315
+ queries = sorted(first_occurrence.values(), key=lambda q: (q.file, q.lineno))
316
+ return queries, all_locations
317
+
318
+
319
+ def resolve_dsn(dsn_import: str) -> tuple[str, str, str]:
320
+ package_name, attr_path = dsn_import.split(":")
321
+ mod = importlib.import_module(package_name)
322
+ dsn: str = eval(attr_path, vars(mod)) # noqa: S307
323
+ return dsn, package_name, attr_path
324
+
325
+
326
+ def render_query_classes(
327
+ sqlc_queries: tuple[Query, ...],
328
+ queries: list["CodeQuery"],
329
+ resolver: TypeResolver,
330
+ result_types: dict[str, str],
331
+ all_locations: defaultdict[str, list[str]],
332
+ ) -> list[str]:
333
+ query_order = {q.name: i for i, q in enumerate(queries)}
334
+ return [
312
335
  render_query_class(
313
336
  q.name,
314
337
  q.text,
315
- package_name,
316
338
  [
317
339
  resolver.param_spec(
318
340
  p.column,
@@ -328,33 +350,67 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
328
350
  if len(q.columns) == 1
329
351
  else None
330
352
  ),
353
+ all_locations[q.name],
331
354
  )
332
- for q in sqlc_res.queries
355
+ for q in sorted(sqlc_queries, key=lambda q: query_order[q.name])
333
356
  ]
334
357
 
335
- query_overloads = [
336
- render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
337
- ]
338
358
 
339
- query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
359
+ def resolve_json_model_overrides(
360
+ overrides: dict[str, str], catalog: Catalog
361
+ ) -> tuple[str, dict[tuple[str, str], str]]:
362
+ if not overrides:
363
+ return "", {}
340
364
 
341
- new_content = render_package(
342
- dsn_import_package,
343
- dsn_import_path,
344
- package_name,
345
- sql_fn_name,
346
- sorted(entities),
347
- sorted(enums),
348
- sorted(query_classes),
349
- sorted(query_overloads),
350
- sorted(query_dict_entries),
351
- application_name,
352
- json_import_block,
353
- )
354
- changed = write_if_changed(target_package_path, new_content + "\n")
355
- if changed:
356
- logger.info(f"Generated SQL package {package_full_name}")
357
- return changed
365
+ json_compatible_types = {"json", "jsonb", "text", "varchar"}
366
+ col_types = {
367
+ (table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
368
+ for schema in catalog.schemas
369
+ for table in schema.tables
370
+ for column in table.columns
371
+ }
372
+ tables = {table for table, _ in col_types}
373
+
374
+ parsed: dict[tuple[str, str], tuple[str, str]] = {}
375
+ for key, import_path in overrides.items():
376
+ table_name, sep, col_name = key.partition(".")
377
+ if not sep:
378
+ msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
379
+ raise ValueError(msg)
380
+ if table_name not in tables:
381
+ msg = f"json_model_overrides: table {table_name!r} not found in catalog"
382
+ raise ValueError(msg)
383
+ if (table_name, col_name) not in col_types:
384
+ msg = (
385
+ f"json_model_overrides: column {col_name!r} "
386
+ f"not found in table {table_name!r}"
387
+ )
388
+ raise ValueError(msg)
389
+
390
+ db_type = col_types[table_name, col_name]
391
+ if db_type not in json_compatible_types:
392
+ msg = (
393
+ f"json_model_overrides: column "
394
+ f"{table_name}.{col_name} has type "
395
+ f"{db_type!r}, expected one of "
396
+ f"{json_compatible_types}"
397
+ )
398
+ raise ValueError(msg)
399
+
400
+ module_path, sep, class_name = import_path.partition(":")
401
+ if not sep:
402
+ msg = (
403
+ "json_model_overrides value must be "
404
+ f"'module:Class', got: {import_path!r}"
405
+ )
406
+ raise ValueError(msg)
407
+
408
+ parsed[table_name, col_name] = (module_path, class_name)
409
+
410
+ modules = sorted({module for module, _ in parsed.values()})
411
+ import_block = "\n" + "\n".join(f"import {m}" for m in modules)
412
+ col_overrides = {key: f"{module}.{cls}" for key, (module, cls) in parsed.items()}
413
+ return import_block, col_overrides
358
414
 
359
415
 
360
416
  def render_package( # noqa: PLR0913, PLR0917
@@ -385,16 +441,20 @@ import uuid
385
441
  from collections.abc import AsyncGenerator
386
442
  from collections.abc import AsyncIterator
387
443
  from collections.abc import Sequence
444
+ from contextlib import AbstractAsyncContextManager
388
445
  from contextlib import asynccontextmanager
389
446
  from contextvars import ContextVar
390
447
  from dataclasses import dataclass
391
448
  from enum import StrEnum
449
+ from typing import ClassVar
392
450
  from typing import Literal
393
451
  from typing import overload
394
452
 
395
453
  import psycopg
454
+ import psycopg.abc
396
455
  import psycopg.rows
397
- from psycopg.types import json as pgjson
456
+ import psycopg.sql
457
+ import psycopg.types.json
398
458
 
399
459
  from iron_sql import runtime
400
460
 
@@ -447,8 +507,28 @@ async def {package_name}_notify(channel: str, payload: str = "") -> None:
447
507
  {"\n\n\n".join(entities)}
448
508
 
449
509
 
450
- class Query:
451
- pass
510
+ class Query[T]:
511
+ _stmt: ClassVar[psycopg.sql.SQL]
512
+ _row_factory: psycopg.rows.BaseRowFactory[T]
513
+
514
+ @asynccontextmanager
515
+ async def _client_cursor(self, params: psycopg.abc.Params | None):
516
+ async with (
517
+ {package_name}_connection() as conn,
518
+ psycopg.AsyncRawCursor(conn, row_factory=self._row_factory) as cur,
519
+ ):
520
+ await cur.execute(self._stmt, params)
521
+ yield cur
522
+
523
+ @asynccontextmanager
524
+ async def _server_cursor(self, params: psycopg.abc.Params | None):
525
+ async with (
526
+ {package_name}_connection() as conn,
527
+ runtime.ensure_transaction(conn),
528
+ psycopg.AsyncRawServerCursor(conn, row_factory=self._row_factory, name=runtime.next_cursor_name()) as cur,
529
+ ):
530
+ await cur.execute(self._stmt, params)
531
+ yield cur
452
532
 
453
533
 
454
534
  {"\n\n\n".join(query_classes)}
@@ -470,7 +550,7 @@ def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query:
470
550
  msg = f"Unknown statement: {{stmt!r}}"
471
551
  raise KeyError(msg)
472
552
 
473
- """.strip()
553
+ """.strip() # noqa: E501
474
554
 
475
555
 
476
556
  def render_enum_class(
@@ -536,11 +616,11 @@ def deduplicate_params(params: list[ParamSpec]) -> list[ParamSpec]:
536
616
  def render_query_class(
537
617
  query_name: str,
538
618
  stmt: str,
539
- package_name: str,
540
619
  query_params: list[ParamSpec],
541
620
  result: str,
542
621
  columns_num: int,
543
- scalar_json_type: str | None = None,
622
+ scalar_json_type: str | None,
623
+ locations: list[str],
544
624
  ) -> str:
545
625
  query_params = deduplicate_params(query_params)
546
626
 
@@ -582,39 +662,36 @@ def render_query_class(
582
662
  methods = f"""
583
663
 
584
664
  async def query_all_rows({", ".join(query_fn_params)}) -> list[{result}]:
585
- async with self._execute({params_arg}) as cur:
665
+ async with self._client_cursor({params_arg}) as cur:
586
666
  return await cur.fetchall()
587
667
 
588
668
  async def query_single_row({", ".join(query_fn_params)}) -> {result}:
589
- async with self._execute({params_arg}) as cur:
669
+ async with self._client_cursor({params_arg}) as cur:
590
670
  return runtime.get_one_row(await cur.fetchmany(2))
591
671
 
592
672
  async def query_optional_row({", ".join(query_fn_params)}) -> {base_result} | None:
593
- async with self._execute({params_arg}) as cur:
673
+ async with self._client_cursor({params_arg}) as cur:
594
674
  return runtime.get_one_row_or_none(await cur.fetchmany(2))
595
675
 
596
- """.strip()
676
+ def query_stream({", ".join(query_fn_params)}) -> AbstractAsyncContextManager[AsyncIterator[{result}]]:
677
+ return self._server_cursor({params_arg})
678
+
679
+ """.strip() # noqa: E501
597
680
  else:
598
681
  methods = f"""
599
682
 
600
683
  async def execute({", ".join(query_fn_params)}) -> None:
601
- async with self._execute({params_arg}):
684
+ async with self._client_cursor({params_arg}):
602
685
  pass
603
686
 
604
687
  """.strip()
605
688
 
606
689
  return f"""
607
690
 
608
- class {query_name}(Query):
609
- @asynccontextmanager
610
- async def _execute(self, params) -> AsyncIterator[psycopg.AsyncRawCursor[{result}]]:
611
- stmt = {stmt!r}
612
- async with (
613
- {package_name}_connection() as conn,
614
- psycopg.AsyncRawCursor(conn, row_factory={row_factory}) as cur,
615
- ):
616
- await cur.execute(stmt, params)
617
- yield cur
691
+ class {query_name}(Query[{result}]):
692
+ # See: {", ".join(locations)}
693
+ _stmt = psycopg.sql.SQL({stmt!r})
694
+ _row_factory = staticmethod({row_factory})
618
695
 
619
696
  {indent_block(methods, " ")}
620
697
 
@@ -749,7 +826,12 @@ def find_fn_calls(
749
826
  content = path.read_text(encoding="utf-8")
750
827
  if fn_name not in content:
751
828
  continue
752
- for node in ast.walk(ast.parse(content, filename=str(path))):
829
+ try:
830
+ tree = ast.parse(content, filename=str(path))
831
+ except SyntaxError as exc:
832
+ msg = f"Failed to parse {path}: {exc.msg} (line {exc.lineno})"
833
+ raise SyntaxError(msg) from exc
834
+ for node in ast.walk(tree):
753
835
  match node:
754
836
  case ast.Call(func=ast.Name(id=id)) if id == fn_name:
755
837
  yield path, node.lineno, node
@@ -799,11 +881,15 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
799
881
 
800
882
 
801
883
  def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None:
802
- row_type_by_stmt: dict[str, str | None] = {}
884
+ first_by_stmt: dict[str, CodeQuery] = {}
803
885
  for query in queries:
804
- if query.stmt in row_type_by_stmt:
805
- if query.row_type != row_type_by_stmt[query.stmt]:
806
- msg = f"row_type conflict (existing={row_type_by_stmt[query.stmt]!r})"
886
+ if query.stmt in first_by_stmt:
887
+ first = first_by_stmt[query.stmt]
888
+ if query.row_type != first.row_type:
889
+ msg = (
890
+ f"row_type conflict: {first.location} has {first.row_type!r},"
891
+ f" {query.location} has {query.row_type!r}"
892
+ )
807
893
  raise ValueError(msg)
808
894
  else:
809
- row_type_by_stmt[query.stmt] = query.row_type
895
+ first_by_stmt[query.stmt] = query
@@ -139,7 +139,7 @@ def run_sqlc(
139
139
  dsn: str | None,
140
140
  debug_path: Path | None = None,
141
141
  tempdir_path: Path | None = None,
142
- ) -> SQLCResult:
142
+ ) -> tuple[SQLCResult, list[tuple[int, str]]]:
143
143
  if not schema_path.exists():
144
144
  msg = f"Schema file not found: {schema_path}"
145
145
  raise ValueError(msg)
@@ -148,7 +148,7 @@ def run_sqlc(
148
148
  return SQLCResult(
149
149
  catalog=Catalog(default_schema="", name="", schemas=()),
150
150
  queries=(),
151
- )
151
+ ), []
152
152
 
153
153
  queries = list({q[0]: q for q in queries}.values())
154
154
 
@@ -156,13 +156,15 @@ def run_sqlc(
156
156
  dir=str(tempdir_path) if tempdir_path else None
157
157
  ) as tempdir:
158
158
  queries_path = Path(tempdir) / "queries.sql"
159
- queries_path.write_text(
160
- "\n\n".join(
161
- f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
162
- for name, stmt in queries
163
- ),
164
- encoding="utf-8",
165
- )
159
+ block_starts: list[tuple[int, str]] = []
160
+ blocks: list[str] = []
161
+ current_line = 1
162
+ for name, stmt in queries:
163
+ block = f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
164
+ block_starts.append((current_line, name))
165
+ current_line += block.count("\n") + 2
166
+ blocks.append(block)
167
+ queries_path.write_text("\n\n".join(blocks), encoding="utf-8")
166
168
 
167
169
  (Path(tempdir) / "schema.sql").symlink_to(schema_path.absolute())
168
170
 
@@ -206,8 +208,10 @@ def run_sqlc(
206
208
  error=sqlc_run_result.stderr.decode().strip(),
207
209
  catalog=Catalog(default_schema="", name="", schemas=()),
208
210
  queries=(),
209
- )
210
- return SQLCResult.model_validate_json(json_out_path.read_text(encoding="utf-8"))
211
+ ), block_starts
212
+ return SQLCResult.model_validate_json(
213
+ json_out_path.read_text(encoding="utf-8")
214
+ ), block_starts
211
215
 
212
216
 
213
217
  def preprocess_sql(stmt: str) -> str:
@@ -1,4 +1,5 @@
1
1
  import contextlib
2
+ import itertools
2
3
  import types
3
4
  from collections.abc import AsyncGenerator
4
5
  from collections.abc import AsyncIterator
@@ -14,9 +15,9 @@ from typing import overload
14
15
 
15
16
  import psycopg
16
17
  import psycopg.rows
18
+ import psycopg.sql
19
+ import psycopg.types.json
17
20
  import psycopg_pool
18
- from psycopg import sql
19
- from psycopg.types import json as pgjson
20
21
  from pydantic import TypeAdapter
21
22
 
22
23
  _adapter_cache: dict[object, TypeAdapter[object]] = {}
@@ -63,21 +64,25 @@ async def listen(
63
64
  async def notify(conn: psycopg.AsyncConnection, channel: str, payload: str) -> None:
64
65
  _validate_channel(channel)
65
66
  await conn.execute(
66
- sql.SQL("NOTIFY {}, {}").format(
67
- sql.Identifier(channel),
68
- sql.Literal(payload),
67
+ psycopg.sql.SQL("NOTIFY {}, {}").format(
68
+ psycopg.sql.Identifier(channel),
69
+ psycopg.sql.Literal(payload),
69
70
  )
70
71
  )
71
72
 
72
73
 
73
74
  async def execute_listen(conn: psycopg.AsyncConnection, channel: str) -> None:
74
75
  _validate_channel(channel)
75
- await conn.execute(sql.SQL("LISTEN {}").format(sql.Identifier(channel)))
76
+ await conn.execute(
77
+ psycopg.sql.SQL("LISTEN {}").format(psycopg.sql.Identifier(channel))
78
+ )
76
79
 
77
80
 
78
81
  async def execute_unlisten(conn: psycopg.AsyncConnection, channel: str) -> None:
79
82
  _validate_channel(channel)
80
- await conn.execute(sql.SQL("UNLISTEN {}").format(sql.Identifier(channel)))
83
+ await conn.execute(
84
+ psycopg.sql.SQL("UNLISTEN {}").format(psycopg.sql.Identifier(channel))
85
+ )
81
86
 
82
87
 
83
88
  async def _has_active_listen_subscriptions(conn: psycopg.AsyncConnection) -> bool:
@@ -96,6 +101,26 @@ def _validate_channel(name: str) -> None:
96
101
  raise ValueError(msg)
97
102
 
98
103
 
104
+ _cursor_seq = itertools.count()
105
+
106
+
107
+ def next_cursor_name() -> str:
108
+ return f"_c{next(_cursor_seq)}"
109
+
110
+
111
+ @asynccontextmanager
112
+ async def ensure_transaction(conn: psycopg.AsyncConnection) -> AsyncIterator[None]:
113
+ match conn.info.transaction_status:
114
+ case psycopg.pq.TransactionStatus.IDLE:
115
+ async with conn.transaction():
116
+ yield
117
+ case psycopg.pq.TransactionStatus.INTRANS:
118
+ yield
119
+ case status:
120
+ msg = f"Cannot use server-side cursor: connection is in {status.name} state"
121
+ raise psycopg.InterfaceError(msg)
122
+
123
+
99
124
  class ConnectionPool:
100
125
  def __init__(
101
126
  self,
@@ -196,9 +221,9 @@ def serialize_json_param(typ: object, value: object, db_type: str) -> object:
196
221
  adapter = get_adapter(typ)
197
222
  match db_type:
198
223
  case "json":
199
- return pgjson.Json(adapter.dump_python(value, mode="json"))
224
+ return psycopg.types.json.Json(adapter.dump_python(value, mode="json"))
200
225
  case "jsonb":
201
- return pgjson.Jsonb(adapter.dump_python(value, mode="json"))
226
+ return psycopg.types.json.Jsonb(adapter.dump_python(value, mode="json"))
202
227
  case _:
203
228
  return adapter.dump_json(value).decode()
204
229
 
File without changes
File without changes