iron-sql 0.4.2__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {iron_sql-0.4.2 → iron_sql-0.4.4}/PKG-INFO +4 -2
- {iron_sql-0.4.2 → iron_sql-0.4.4}/README.md +3 -1
- {iron_sql-0.4.2 → iron_sql-0.4.4}/pyproject.toml +1 -1
- {iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/codegen/generator.py +207 -121
- {iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/codegen/sqlc.py +15 -11
- {iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/runtime.py +34 -9
- {iron_sql-0.4.2 → iron_sql-0.4.4}/LICENSE +0 -0
- {iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/__init__.py +0 -0
- {iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/codegen/__init__.py +0 -0
- {iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/codegen/util.py +0 -0
- {iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/py.typed +0 -0
{iron_sql-0.4.2 → iron_sql-0.4.4}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: iron-sql
-Version: 0.4.2
+Version: 0.4.4
 Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
 Keywords: postgresql,sql,sqlc,psycopg,codegen,async
 Author: Ilia Ablamonov
```
```diff
@@ -46,6 +46,7 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
 - **Query discovery.** `generate_sql_package` scans your codebase for calls like `<package>_sql("SELECT ...")`, runs `sqlc` for type analysis, and emits a typed module.
 - **Strong typing.** Generated dataclasses and method signatures flow through your IDE and type checker.
 - **Async runtime.** Built on `psycopg` v3 with pooled connections, context-based connection reuse, and transaction helpers.
+- **Streaming.** `query_stream()` uses server-side cursors for memory-efficient iteration over large result sets.
 - **Safe by default.** Helper methods enforce expected row counts instead of returning silent `None`.
 
 ## Package Layout
```
```diff
@@ -98,7 +99,8 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
 - `ConnectionPool` opens lazily and reopens after `close()`, with `ContextVar`-based connection reuse for nested contexts.
 - `*_listen_session()` uses a dedicated pooled connection and doesn't reuse `ContextVar` transaction connections.
 - `query_single_row()` raises `NoRowsError`; `query_optional_row()` returns `None`. Both raise `TooManyRowsError` on 2+ rows.
-- …
+- `query_stream()` returns an async context manager yielding an `AsyncIterator`; uses server-side cursors with automatic transaction management.
+- JSONB params are sent with `psycopg.types.json.Jsonb`; JSON with `psycopg.types.json.Json`. Scalar row factories validate types at runtime.
 - `json_validated` decorator applies Pydantic model validation to dataclass fields on construction.
 
 ## Example
```
{iron_sql-0.4.2 → iron_sql-0.4.4}/README.md

```diff
@@ -20,6 +20,7 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
 - **Query discovery.** `generate_sql_package` scans your codebase for calls like `<package>_sql("SELECT ...")`, runs `sqlc` for type analysis, and emits a typed module.
 - **Strong typing.** Generated dataclasses and method signatures flow through your IDE and type checker.
 - **Async runtime.** Built on `psycopg` v3 with pooled connections, context-based connection reuse, and transaction helpers.
+- **Streaming.** `query_stream()` uses server-side cursors for memory-efficient iteration over large result sets.
 - **Safe by default.** Helper methods enforce expected row counts instead of returning silent `None`.
 
 ## Package Layout
```
```diff
@@ -72,7 +73,8 @@ The `sqlc` binary is bundled automatically via the `sqlc` Python package.
 - `ConnectionPool` opens lazily and reopens after `close()`, with `ContextVar`-based connection reuse for nested contexts.
 - `*_listen_session()` uses a dedicated pooled connection and doesn't reuse `ContextVar` transaction connections.
 - `query_single_row()` raises `NoRowsError`; `query_optional_row()` returns `None`. Both raise `TooManyRowsError` on 2+ rows.
-- …
+- `query_stream()` returns an async context manager yielding an `AsyncIterator`; uses server-side cursors with automatic transaction management.
+- JSONB params are sent with `psycopg.types.json.Jsonb`; JSON with `psycopg.types.json.Json`. Scalar row factories validate types at runtime.
 - `json_validated` decorator applies Pydantic model validation to dataclass fields on construction.
 
 ## Example
```
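Taken together, the README additions describe the new streaming helper alongside the existing row-count-checked helpers. A minimal usage sketch, assuming a generated package `app.db` with a `db_sql` query function; the module, query, and parameter names are hypothetical, not part of this package's documented API:

```python
# Hypothetical caller of a generated module; `app.db`, `db_sql`, and
# `min_age` are illustrative names only.
import asyncio

from app import db  # assumed: module generated by generate_sql_package


async def main() -> None:
    q = db.db_sql("SELECT id, name FROM users WHERE age > %(min_age)s", "UserRow")

    rows = await q.query_all_rows(min_age=18)       # list of generated dataclasses
    one = await q.query_single_row(min_age=18)      # raises NoRowsError / TooManyRowsError
    maybe = await q.query_optional_row(min_age=18)  # None instead of NoRowsError

    # query_stream(): an async context manager yielding an AsyncIterator,
    # backed by a server-side cursor inside a transaction.
    async with q.query_stream(min_age=18) as stream:
        async for row in stream:
            print(row)


asyncio.run(main())
```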
{iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/codegen/generator.py

```diff
@@ -3,6 +3,7 @@ import dataclasses
 import hashlib
 import importlib
 import logging
+import re
 import warnings
 from collections import defaultdict
 from collections.abc import Callable
```
```diff
@@ -55,9 +56,9 @@ class ParamSpec:
 
         match self.db_type:
             case "json":
-                expr = f"…
+                expr = f"psycopg.types.json.Json({self.name})"
             case "jsonb":
-                expr = f"…
+                expr = f"psycopg.types.json.Jsonb({self.name})"
             case _:
                 return self.name
 
```
```diff
@@ -186,6 +187,24 @@ def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
     }
 
 
+def map_sqlc_error(
+    error: str,
+    block_starts: list[tuple[int, str]],
+    all_locations: dict[str, list[str]],
+) -> str:
+    def replace(m: re.Match[str]) -> str:
+        line = int(m.group(1))
+        name = next((n for start, n in reversed(block_starts) if start <= line), None)
+        if name is None:
+            return m.group(0)
+        locations = all_locations.get(name)
+        if not locations:
+            return m.group(0)
+        return f"{', '.join(locations)}:"
+
+    return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", replace, error)
+
+
 def generate_sql_package(  # noqa: PLR0913, PLR0914
     *,
     schema_path: Path,
```
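The new `map_sqlc_error` rewrites `queries.sql:<line>` positions in sqlc's stderr back to the Python call sites that contributed each query block. A self-contained worked example; the block starts and locations are toy data:

```python
import re


def map_sqlc_error(
    error: str,
    block_starts: list[tuple[int, str]],
    all_locations: dict[str, list[str]],
) -> str:
    def replace(m: re.Match[str]) -> str:
        line = int(m.group(1))
        # Last block starting at or before the reported line owns it.
        name = next((n for start, n in reversed(block_starts) if start <= line), None)
        if name is None:
            return m.group(0)
        locations = all_locations.get(name)
        if not locations:
            return m.group(0)
        return f"{', '.join(locations)}:"

    return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", replace, error)


# "queries.sql:4:8" falls inside the ListUsers block (starts at line 4),
# so the error is re-pointed at the call site that produced that query.
print(
    map_sqlc_error(
        'queries.sql:4:8: column "nam" does not exist',
        block_starts=[(1, "GetUser"), (4, "ListUsers")],
        all_locations={"ListUsers": ["app/users.py:42"]},
    )
)
# -> app/users.py:42: column "nam" does not exist
```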
```diff
@@ -200,21 +219,14 @@ def generate_sql_package(  # noqa: PLR0913, PLR0914
     src_path: Path = Path(),
     tempdir_path: Path | None = None,
 ) -> bool:
-…
-…
-    package_name = package_full_name.split(".")[-1]  # noqa: PLC0207
+    package_name = package_full_name.rsplit(".", maxsplit=1)[-1]
     sql_fn_name = f"{package_name}_sql"
 
-…
-…
-    queries = list(find_all_queries(src_path, sql_fn_name))
-    validate_stmt_has_single_row_type(queries)
-    queries = list({q.name: q for q in queries}.values())
+    queries, all_locations = collect_queries(src_path, sql_fn_name)
 
-…
-    dsn = eval(dsn_import_path, vars(dsn_package))  # noqa: S307
+    dsn, dsn_import_package, dsn_import_path = resolve_dsn(dsn_import)
 
-    sqlc_res = run_sqlc(
+    sqlc_res, block_starts = run_sqlc(
         src_path / schema_path,
         [(q.name, q.stmt) for q in queries],
         dsn=dsn,
```
```diff
@@ -223,63 +235,13 @@ def generate_sql_package(  # noqa: PLR0913, PLR0914
     )
 
     if sqlc_res.error:
-…
+        mapped = map_sqlc_error(sqlc_res.error, block_starts, all_locations)
+        logger.error(f"Error running SQLC:\n{mapped}")
         return False
 
-    json_import_block = …
-…
-…
-    if json_model_overrides:
-        json_compatible_types = {"json", "jsonb", "text", "varchar"}
-        col_types = {
-            (table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
-            for schema in sqlc_res.catalog.schemas
-            for table in schema.tables
-            for column in table.columns
-        }
-        tables = {table for table, _ in col_types}
-
-        parsed: dict[tuple[str, str], tuple[str, str]] = {}
-        for key, import_path in json_model_overrides.items():
-            table_name, sep, col_name = key.partition(".")
-            if not sep:
-                msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
-                raise ValueError(msg)
-            if table_name not in tables:
-                msg = f"json_model_overrides: table {table_name!r} not found in catalog"
-                raise ValueError(msg)
-            if (table_name, col_name) not in col_types:
-                msg = (
-                    f"json_model_overrides: column {col_name!r} "
-                    f"not found in table {table_name!r}"
-                )
-                raise ValueError(msg)
-
-            db_type = col_types[table_name, col_name]
-            if db_type not in json_compatible_types:
-                msg = (
-                    f"json_model_overrides: column "
-                    f"{table_name}.{col_name} has type "
-                    f"{db_type!r}, expected one of "
-                    f"{json_compatible_types}"
-                )
-                raise ValueError(msg)
-
-            module_path, sep, class_name = import_path.partition(":")
-            if not sep:
-                msg = (
-                    "json_model_overrides value must be "
-                    f"'module:Class', got: {import_path!r}"
-                )
-                raise ValueError(msg)
-
-            parsed[table_name, col_name] = (module_path, class_name)
-
-        modules = sorted({module for module, _ in parsed.values()})
-        json_import_block = "\n" + "\n".join(f"import {m}" for m in modules)
-        json_col_overrides = {
-            key: f"{module}.{cls}" for key, (module, cls) in parsed.items()
-        }
+    json_import_block, json_col_overrides = resolve_json_model_overrides(
+        json_model_overrides or {}, sqlc_res.catalog
+    )
 
     resolver = TypeResolver(
         catalog=sqlc_res.catalog,
```
```diff
@@ -297,22 +259,82 @@ def generate_sql_package(  # noqa: PLR0913, PLR0914
         resolver,
     )
 
-    entities = …
+    entities = sorted(render_entity(e.name, e.column_specs) for e in ordered_entities)
 
     used_enums = collect_used_enums(sqlc_res)
 
-    enums = …
+    enums = sorted(
         render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
         for schema in sqlc_res.catalog.schemas
         for e in schema.enums
         if (schema.name, e.name) in used_enums
+    )
+
+    query_classes = render_query_classes(
+        sqlc_res.queries, queries, resolver, result_types, all_locations
+    )
+
+    query_overloads = [
+        render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
     ]
 
-…
+    query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
+
+    target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
+
+    new_content = render_package(
+        dsn_import_package,
+        dsn_import_path,
+        package_name,
+        sql_fn_name,
+        entities,
+        enums,
+        query_classes,
+        query_overloads,
+        query_dict_entries,
+        application_name,
+        json_import_block,
+    )
+    changed = write_if_changed(target_package_path, new_content + "\n")
+    if changed:
+        logger.info(f"Generated SQL package {package_full_name}")
+    return changed
+
+
+def collect_queries(
+    src_path: Path, sql_fn_name: str
+) -> tuple[list["CodeQuery"], defaultdict[str, list[str]]]:
+    raw = list(find_all_queries(src_path, sql_fn_name))
+    validate_stmt_has_single_row_type(raw)
+    all_locations: defaultdict[str, list[str]] = defaultdict(list)
+    first_occurrence: dict[str, CodeQuery] = {}
+    for q in raw:
+        all_locations[q.name].append(q.location)
+        if q.name not in first_occurrence:
+            first_occurrence[q.name] = q
+    queries = sorted(first_occurrence.values(), key=lambda q: (q.file, q.lineno))
+    return queries, all_locations
+
+
+def resolve_dsn(dsn_import: str) -> tuple[str, str, str]:
+    package_name, attr_path = dsn_import.split(":")
+    mod = importlib.import_module(package_name)
+    dsn: str = eval(attr_path, vars(mod))  # noqa: S307
+    return dsn, package_name, attr_path
+
+
+def render_query_classes(
+    sqlc_queries: tuple[Query, ...],
+    queries: list["CodeQuery"],
+    resolver: TypeResolver,
+    result_types: dict[str, str],
+    all_locations: defaultdict[str, list[str]],
+) -> list[str]:
+    query_order = {q.name: i for i, q in enumerate(queries)}
+    return [
         render_query_class(
             q.name,
             q.text,
-            package_name,
             [
                 resolver.param_spec(
                     p.column,
```
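`generate_sql_package` is now mostly orchestration, and the extracted helpers are small enough to exercise directly. `resolve_dsn`, for example, splits a `module:attr` spec, imports the module, and evaluates the attribute path in its namespace. A runnable toy using a stdlib module as a stand-in; a real caller would pass something like `"myapp.config:DSN"`, which is a hypothetical name:

```python
import importlib


def resolve_dsn(dsn_import: str) -> tuple[str, str, str]:
    package_name, attr_path = dsn_import.split(":")
    mod = importlib.import_module(package_name)
    dsn: str = eval(attr_path, vars(mod))  # noqa: S307
    return dsn, package_name, attr_path


# Toy stand-in: evaluates `digits` in the namespace of the stdlib
# `string` module instead of a real config module.
dsn, pkg, attr = resolve_dsn("string:digits")
print(dsn)   # 0123456789
print(pkg)   # string
print(attr)  # digits
```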
```diff
@@ -328,33 +350,67 @@ def generate_sql_package(  # noqa: PLR0913, PLR0914
                     if len(q.columns) == 1
                     else None
                 ),
+                all_locations[q.name],
             )
-        for q in …
+        for q in sorted(sqlc_queries, key=lambda q: query_order[q.name])
     ]
 
-    query_overloads = [
-        render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
-    ]
 
-…
+def resolve_json_model_overrides(
+    overrides: dict[str, str], catalog: Catalog
+) -> tuple[str, dict[tuple[str, str], str]]:
+    if not overrides:
+        return "", {}
 
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+    json_compatible_types = {"json", "jsonb", "text", "varchar"}
+    col_types = {
+        (table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
+        for schema in catalog.schemas
+        for table in schema.tables
+        for column in table.columns
+    }
+    tables = {table for table, _ in col_types}
+
+    parsed: dict[tuple[str, str], tuple[str, str]] = {}
+    for key, import_path in overrides.items():
+        table_name, sep, col_name = key.partition(".")
+        if not sep:
+            msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
+            raise ValueError(msg)
+        if table_name not in tables:
+            msg = f"json_model_overrides: table {table_name!r} not found in catalog"
+            raise ValueError(msg)
+        if (table_name, col_name) not in col_types:
+            msg = (
+                f"json_model_overrides: column {col_name!r} "
+                f"not found in table {table_name!r}"
+            )
+            raise ValueError(msg)
+
+        db_type = col_types[table_name, col_name]
+        if db_type not in json_compatible_types:
+            msg = (
+                f"json_model_overrides: column "
+                f"{table_name}.{col_name} has type "
+                f"{db_type!r}, expected one of "
+                f"{json_compatible_types}"
+            )
+            raise ValueError(msg)
+
+        module_path, sep, class_name = import_path.partition(":")
+        if not sep:
+            msg = (
+                "json_model_overrides value must be "
+                f"'module:Class', got: {import_path!r}"
+            )
+            raise ValueError(msg)
+
+        parsed[table_name, col_name] = (module_path, class_name)
+
+    modules = sorted({module for module, _ in parsed.values()})
+    import_block = "\n" + "\n".join(f"import {m}" for m in modules)
+    col_overrides = {key: f"{module}.{cls}" for key, (module, cls) in parsed.items()}
+    return import_block, col_overrides
 
 
 def render_package(  # noqa: PLR0913, PLR0917
```
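The relocated `resolve_json_model_overrides` keeps the same validation rules: keys must look like `table.column` and values like `module:Class`. The `str.partition` calls make a missing separator detectable, since the middle element of the result comes back empty. A stdlib-only illustration; the table and module names are invented:

```python
# Override keys split on ".", values on ":"; an empty `sep` means the
# separator was absent and the spec is malformed.
table, sep, column = "users.profile".partition(".")
assert (table, sep, column) == ("users", ".", "profile")

module_path, sep, class_name = "myapp.models:Profile".partition(":")
assert (module_path, class_name) == ("myapp.models", "Profile")

bad_key, sep, _ = "users".partition(".")
assert sep == ""  # triggers: "json_model_overrides key must be 'table.column'"
```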
```diff
@@ -385,16 +441,20 @@ import uuid
 from collections.abc import AsyncGenerator
 from collections.abc import AsyncIterator
 from collections.abc import Sequence
+from contextlib import AbstractAsyncContextManager
 from contextlib import asynccontextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass
 from enum import StrEnum
+from typing import ClassVar
 from typing import Literal
 from typing import overload
 
 import psycopg
+import psycopg.abc
 import psycopg.rows
-…
+import psycopg.sql
+import psycopg.types.json
 
 from iron_sql import runtime
 
```
```diff
@@ -447,8 +507,28 @@ async def {package_name}_notify(channel: str, payload: str = "") -> None:
 {"\n\n\n".join(entities)}
 
 
-class Query:
-…
+class Query[T]:
+    _stmt: ClassVar[psycopg.sql.SQL]
+    _row_factory: psycopg.rows.BaseRowFactory[T]
+
+    @asynccontextmanager
+    async def _client_cursor(self, params: psycopg.abc.Params | None):
+        async with (
+            {package_name}_connection() as conn,
+            psycopg.AsyncRawCursor(conn, row_factory=self._row_factory) as cur,
+        ):
+            await cur.execute(self._stmt, params)
+            yield cur
+
+    @asynccontextmanager
+    async def _server_cursor(self, params: psycopg.abc.Params | None):
+        async with (
+            {package_name}_connection() as conn,
+            runtime.ensure_transaction(conn),
+            psycopg.AsyncRawServerCursor(conn, row_factory=self._row_factory, name=runtime.next_cursor_name()) as cur,
+        ):
+            await cur.execute(self._stmt, params)
+            yield cur
 
 
 {"\n\n\n".join(query_classes)}
```
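The generated base class is now generic over the row type (PEP 695 syntax, so the emitted module requires Python 3.12+), and each subclass only carries data: the statement and a row factory, the latter wrapped in `staticmethod` so it is not bound as a method when looked up on `self`. A stripped-down toy with the same shape, independent of psycopg; all names here are invented:

```python
from collections.abc import Callable
from typing import ClassVar


class Query[T]:  # PEP 695 generic class, as in the generated base
    _stmt: ClassVar[str]
    _row_factory: Callable[[tuple], T]

    def parse(self, raw: tuple) -> T:
        # staticmethod on the subclass keeps the factory unbound,
        # so this call does not implicitly receive `self`.
        return self._row_factory(raw)


class GetUser(Query[tuple[int, str]]):
    _stmt = "SELECT id, name FROM users WHERE id = $1"
    _row_factory = staticmethod(lambda raw: (int(raw[0]), str(raw[1])))


print(GetUser().parse(("1", "Ada")))  # (1, 'Ada')
```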
```diff
@@ -470,7 +550,7 @@ def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query:
     msg = f"Unknown statement: {{stmt!r}}"
     raise KeyError(msg)
 
-""".strip()
+""".strip()  # noqa: E501
 
 
 def render_enum_class(
```
```diff
@@ -536,11 +616,11 @@ def deduplicate_params(params: list[ParamSpec]) -> list[ParamSpec]:
 def render_query_class(
     query_name: str,
     stmt: str,
-    package_name: str,
     query_params: list[ParamSpec],
     result: str,
     columns_num: int,
-    scalar_json_type: str | None
+    scalar_json_type: str | None,
+    locations: list[str],
 ) -> str:
     query_params = deduplicate_params(query_params)
 
```
```diff
@@ -582,39 +662,36 @@ def render_query_class(
         methods = f"""
 
 async def query_all_rows({", ".join(query_fn_params)}) -> list[{result}]:
-    async with self.…
+    async with self._client_cursor({params_arg}) as cur:
         return await cur.fetchall()
 
 async def query_single_row({", ".join(query_fn_params)}) -> {result}:
-    async with self.…
+    async with self._client_cursor({params_arg}) as cur:
         return runtime.get_one_row(await cur.fetchmany(2))
 
 async def query_optional_row({", ".join(query_fn_params)}) -> {base_result} | None:
-    async with self.…
+    async with self._client_cursor({params_arg}) as cur:
         return runtime.get_one_row_or_none(await cur.fetchmany(2))
 
-""".strip()
+def query_stream({", ".join(query_fn_params)}) -> AbstractAsyncContextManager[AsyncIterator[{result}]]:
+    return self._server_cursor({params_arg})
+
+""".strip()  # noqa: E501
     else:
         methods = f"""
 
 async def execute({", ".join(query_fn_params)}) -> None:
-    async with self.…
+    async with self._client_cursor({params_arg}):
         pass
 
 """.strip()
 
     return f"""
 
-class {query_name}(Query):
-…
-…
-…
-    async with (
-        {package_name}_connection() as conn,
-        psycopg.AsyncRawCursor(conn, row_factory={row_factory}) as cur,
-    ):
-        await cur.execute(stmt, params)
-        yield cur
+class {query_name}(Query[{result}]):
+    # See: {", ".join(locations)}
+    _stmt = psycopg.sql.SQL({stmt!r})
+    _row_factory = staticmethod({row_factory})
 
 {indent_block(methods, "    ")}
 
```
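Note that the generated `query_stream` is deliberately not `async def`: it returns the `_server_cursor` context manager directly, so callers write `async with ... as rows` without an extra `await`. A self-contained illustration of that pattern with a fake cursor in place of the real server-side one:

```python
import asyncio
from collections.abc import AsyncIterator
from contextlib import AbstractAsyncContextManager, asynccontextmanager


@asynccontextmanager
async def _fake_server_cursor() -> AsyncIterator[AsyncIterator[int]]:
    async def rows() -> AsyncIterator[int]:
        for i in range(3):
            yield i

    yield rows()


def query_stream() -> AbstractAsyncContextManager[AsyncIterator[int]]:
    # Plain function: hands back the context manager without awaiting anything.
    return _fake_server_cursor()


async def main() -> None:
    async with query_stream() as stream:
        async for row in stream:
            print(row)  # 0, 1, 2


asyncio.run(main())
```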
```diff
@@ -749,7 +826,12 @@ def find_fn_calls(
         content = path.read_text(encoding="utf-8")
         if fn_name not in content:
             continue
-…
+        try:
+            tree = ast.parse(content, filename=str(path))
+        except SyntaxError as exc:
+            msg = f"Failed to parse {path}: {exc.msg} (line {exc.lineno})"
+            raise SyntaxError(msg) from exc
+        for node in ast.walk(tree):
             match node:
                 case ast.Call(func=ast.Name(id=id)) if id == fn_name:
                     yield path, node.lineno, node
```
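Query discovery itself is unchanged: walk the AST and structurally match direct calls to the query function. The new `try/except` only adds the offending file name to parse failures. A runnable reduction of the matching step; the `db_sql` name is illustrative:

```python
import ast

source = '''
rows = db_sql("SELECT 1")
other_call()
'''

tree = ast.parse(source, filename="example.py")
for node in ast.walk(tree):
    match node:
        # Structural pattern match on the AST: a call whose callee is a
        # bare name equal to the target function name.
        case ast.Call(func=ast.Name(id=name)) if name == "db_sql":
            print(node.lineno)  # 2
```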
```diff
@@ -799,11 +881,15 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
 
 
 def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None:
-…
+    first_by_stmt: dict[str, CodeQuery] = {}
     for query in queries:
-        if query.stmt in …
-…
-…
+        if query.stmt in first_by_stmt:
+            first = first_by_stmt[query.stmt]
+            if query.row_type != first.row_type:
+                msg = (
+                    f"row_type conflict: {first.location} has {first.row_type!r},"
+                    f" {query.location} has {query.row_type!r}"
+                )
                 raise ValueError(msg)
         else:
-…
+            first_by_stmt[query.stmt] = query
```
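The rewritten validator also produces a better message: instead of only flagging a conflict, it names both locations. A self-contained reduction with a pared-down `CodeQuery`; the real dataclass has more fields:

```python
from dataclasses import dataclass


@dataclass
class CodeQuery:  # pared-down stand-in for the real dataclass
    stmt: str
    row_type: str | None
    location: str


queries = [
    CodeQuery("SELECT 1", "int", "a.py:10"),
    CodeQuery("SELECT 1", "str", "b.py:20"),  # same stmt, conflicting row_type
]

first_by_stmt: dict[str, CodeQuery] = {}
try:
    for query in queries:
        if query.stmt in first_by_stmt:
            first = first_by_stmt[query.stmt]
            if query.row_type != first.row_type:
                msg = (
                    f"row_type conflict: {first.location} has {first.row_type!r},"
                    f" {query.location} has {query.row_type!r}"
                )
                raise ValueError(msg)
        else:
            first_by_stmt[query.stmt] = query
except ValueError as exc:
    print(exc)  # row_type conflict: a.py:10 has 'int', b.py:20 has 'str'
```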
{iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/codegen/sqlc.py

```diff
@@ -139,7 +139,7 @@ def run_sqlc(
     dsn: str | None,
     debug_path: Path | None = None,
     tempdir_path: Path | None = None,
-) -> SQLCResult:
+) -> tuple[SQLCResult, list[tuple[int, str]]]:
    if not schema_path.exists():
        msg = f"Schema file not found: {schema_path}"
        raise ValueError(msg)
```
```diff
@@ -148,7 +148,7 @@ def run_sqlc(
         return SQLCResult(
             catalog=Catalog(default_schema="", name="", schemas=()),
             queries=(),
-        )
+        ), []
 
     queries = list({q[0]: q for q in queries}.values())
 
```
```diff
@@ -156,13 +156,15 @@ def run_sqlc(
         dir=str(tempdir_path) if tempdir_path else None
     ) as tempdir:
         queries_path = Path(tempdir) / "queries.sql"
-…
-…
-…
-…
-        )
-…
-…
+        block_starts: list[tuple[int, str]] = []
+        blocks: list[str] = []
+        current_line = 1
+        for name, stmt in queries:
+            block = f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
+            block_starts.append((current_line, name))
+            current_line += block.count("\n") + 2
+            blocks.append(block)
+        queries_path.write_text("\n\n".join(blocks), encoding="utf-8")
 
         (Path(tempdir) / "schema.sql").symlink_to(schema_path.absolute())
 
```
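The new bookkeeping records the starting line of each `-- name:` block as it is written, which is exactly what `map_sqlc_error` later searches. Each block occupies `block.count("\n") + 1` lines plus one blank separator line, hence the `+ 2`. A runnable check of the arithmetic, with `preprocess_sql` replaced by an identity stand-in:

```python
def preprocess_sql(stmt: str) -> str:  # identity stand-in for the real helper
    return stmt


queries = [("GetUser", "SELECT 1"), ("ListUsers", "SELECT 2\nFROM t")]

block_starts: list[tuple[int, str]] = []
blocks: list[str] = []
current_line = 1
for name, stmt in queries:
    block = f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
    block_starts.append((current_line, name))
    # lines in this block = newlines + 1, plus 1 for the blank separator
    current_line += block.count("\n") + 2
    blocks.append(block)

print(block_starts)  # [(1, 'GetUser'), (4, 'ListUsers')]
print("\n\n".join(blocks))
```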
```diff
@@ -206,8 +208,10 @@ def run_sqlc(
             error=sqlc_run_result.stderr.decode().strip(),
             catalog=Catalog(default_schema="", name="", schemas=()),
             queries=(),
-        )
-    return SQLCResult.model_validate_json(…
+        ), block_starts
+    return SQLCResult.model_validate_json(
+        json_out_path.read_text(encoding="utf-8")
+    ), block_starts
 
 
 def preprocess_sql(stmt: str) -> str:
```
{iron_sql-0.4.2 → iron_sql-0.4.4}/src/iron_sql/runtime.py

```diff
@@ -1,4 +1,5 @@
 import contextlib
+import itertools
 import types
 from collections.abc import AsyncGenerator
 from collections.abc import AsyncIterator
```
```diff
@@ -14,9 +15,9 @@ from typing import overload
 
 import psycopg
 import psycopg.rows
+import psycopg.sql
+import psycopg.types.json
 import psycopg_pool
-from psycopg import sql
-from psycopg.types import json as pgjson
 from pydantic import TypeAdapter
 
 _adapter_cache: dict[object, TypeAdapter[object]] = {}
```
```diff
@@ -63,21 +64,25 @@ async def listen(
 async def notify(conn: psycopg.AsyncConnection, channel: str, payload: str) -> None:
     _validate_channel(channel)
     await conn.execute(
-        sql.SQL("NOTIFY {}, {}").format(
-            sql.Identifier(channel),
-            sql.Literal(payload),
+        psycopg.sql.SQL("NOTIFY {}, {}").format(
+            psycopg.sql.Identifier(channel),
+            psycopg.sql.Literal(payload),
         )
     )
 
 
 async def execute_listen(conn: psycopg.AsyncConnection, channel: str) -> None:
     _validate_channel(channel)
-    await conn.execute(…
+    await conn.execute(
+        psycopg.sql.SQL("LISTEN {}").format(psycopg.sql.Identifier(channel))
+    )
 
 
 async def execute_unlisten(conn: psycopg.AsyncConnection, channel: str) -> None:
     _validate_channel(channel)
-    await conn.execute(…
+    await conn.execute(
+        psycopg.sql.SQL("UNLISTEN {}").format(psycopg.sql.Identifier(channel))
+    )
 
 
 async def _has_active_listen_subscriptions(conn: psycopg.AsyncConnection) -> bool:
```
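Functionally these hunks only switch from the `sql`/`pgjson` aliases to fully qualified imports; the composition pattern is unchanged: channel names go through `psycopg.sql.Identifier`, so they are quoted as identifiers rather than interpolated as text. A hedged sketch of the same pattern outside the runtime; the DSN and channel name are placeholders, and it needs a reachable PostgreSQL:

```python
import asyncio

import psycopg
import psycopg.sql


async def main() -> None:
    # "dbname=app" is a placeholder DSN.
    async with await psycopg.AsyncConnection.connect("dbname=app", autocommit=True) as conn:
        stmt = psycopg.sql.SQL("NOTIFY {}, {}").format(
            psycopg.sql.Identifier("events"),  # quoted as an identifier
            psycopg.sql.Literal("hello"),      # quoted as a literal
        )
        await conn.execute(stmt)


asyncio.run(main())
```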
```diff
@@ -96,6 +101,26 @@ def _validate_channel(name: str) -> None:
         raise ValueError(msg)
 
 
+_cursor_seq = itertools.count()
+
+
+def next_cursor_name() -> str:
+    return f"_c{next(_cursor_seq)}"
+
+
+@asynccontextmanager
+async def ensure_transaction(conn: psycopg.AsyncConnection) -> AsyncIterator[None]:
+    match conn.info.transaction_status:
+        case psycopg.pq.TransactionStatus.IDLE:
+            async with conn.transaction():
+                yield
+        case psycopg.pq.TransactionStatus.INTRANS:
+            yield
+        case status:
+            msg = f"Cannot use server-side cursor: connection is in {status.name} state"
+            raise psycopg.InterfaceError(msg)
+
+
 class ConnectionPool:
     def __init__(
         self,
```
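`ensure_transaction` exists because a PostgreSQL server-side cursor declared without `HOLD` only lives inside a transaction: the helper opens one when the connection is idle, reuses an already-open one, and refuses anything else (for example an aborted transaction). The cursor names come from a plain process-wide counter; the naming scheme is copied below and is trivially runnable:

```python
import itertools

# Same scheme as runtime.next_cursor_name(): a process-wide counter
# gives each server-side cursor a unique DECLARE name.
_cursor_seq = itertools.count()


def next_cursor_name() -> str:
    return f"_c{next(_cursor_seq)}"


print(next_cursor_name())  # _c0
print(next_cursor_name())  # _c1
```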
```diff
@@ -196,9 +221,9 @@ def serialize_json_param(typ: object, value: object, db_type: str) -> object:
     adapter = get_adapter(typ)
     match db_type:
         case "json":
-            return …
+            return psycopg.types.json.Json(adapter.dump_python(value, mode="json"))
         case "jsonb":
-            return …
+            return psycopg.types.json.Jsonb(adapter.dump_python(value, mode="json"))
         case _:
             return adapter.dump_json(value).decode()
 
```
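The serialization path pairs Pydantic's `TypeAdapter` (to get JSON-safe Python objects) with psycopg's `Json`/`Jsonb` wrappers (to pick the adapter on the wire). A runnable illustration of that pipeline for a `jsonb` parameter; it requires pydantic v2 and psycopg installed:

```python
import datetime

import psycopg.types.json
from pydantic import TypeAdapter

adapter = TypeAdapter(list[datetime.date])
payload = adapter.dump_python([datetime.date(2024, 1, 1)], mode="json")
print(payload)  # ['2024-01-01']: plain JSON-safe objects

# Wrapping tells psycopg to send the value as jsonb rather than guessing.
param = psycopg.types.json.Jsonb(payload)
print(param)
```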
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|