iron_sql-0.2.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iron_sql-0.2.5/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Ilia Ablamonov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
iron_sql-0.2.5/PKG-INFO ADDED
@@ -0,0 +1,64 @@
+ Metadata-Version: 2.4
+ Name: iron-sql
+ Version: 0.2.5
+ Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
+ Keywords: postgresql,sql,sqlc,psycopg,codegen,async
+ Author: Ilia Ablamonov
+ Author-email: Ilia Ablamonov <ilia@flamefork.ru>
+ License-Expression: MIT
+ License-File: LICENSE
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Dist: inflection>=0.5.1
+ Requires-Dist: psycopg>=3.3.2
+ Requires-Dist: psycopg-pool>=3.3.0
+ Requires-Dist: pydantic>=2.12.4
+ Requires-Python: >=3.13
+ Project-URL: Homepage, https://github.com/Flamefork/iron_sql
+ Project-URL: Repository, https://github.com/Flamefork/iron_sql.git
+ Project-URL: Issues, https://github.com/Flamefork/iron_sql/issues
+ Description-Content-Type: text/markdown
+
+ # iron_sql
+
+ iron_sql keeps SQL close to Python call sites while giving you typed, async query helpers. You write SQL once, keep it in version control, and get generated clients that match your schema without hand-written boilerplate.
+
+ ## Why use it
+ - SQL-first workflow: write queries where they are used; no ORM layer to fight.
+ - Strong typing: generated dataclasses and method signatures flow through your IDE and type checker.
+ - Async-ready: built on `psycopg` with pooled connections and transaction helpers.
+ - Safe-by-default: helper methods enforce expected row counts instead of silently returning `None`.
+
+ ## Quick start
+ 1. Install `iron_sql`, `psycopg`, `psycopg-pool`, and `pydantic`.
+ 2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure it is available in your PATH.
+ 3. Add a Postgres schema dump, for example `db/mydatabase_schema.sql`.
+ 4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code (defaults to the current directory), runs `sqlc`, and writes a module such as `myapp/db/mydatabase.py`.
+
+ ## Authoring queries
+ - Use the package helper for your DB, e.g. `mydatabase_sql("select ...")`. The SQL string must be a literal so the generator can find it.
+ - Named parameters:
+   - Required: `@param`
+   - Optional: `@param?` (expands to `sqlc.narg('param')`)
+ - Positional placeholders (`$1`) stay as-is.
+ - Multi-column results can opt into a custom dataclass with `row_type="MyResult"`. Single-column queries return a scalar type; statements without results expose `execute()`.
+
+ ## Using generated clients
+ - `*_sql("...")` returns a query object with methods derived from the result shape:
+   - `execute()` when no rows are returned.
+   - `query_all_rows()`, `query_single_row()`, `query_optional_row()` for result sets.
+ - `*_connection()` yields a pooled `psycopg.AsyncConnection`; `*_transaction()` wraps it in a transaction context.
+ - JSONB params are sent with `pgjson.Jsonb`; scalar row factories validate types and raise when they do not match.
+
+ ## Adding another database package
+ Provide the schema file and DSN import string, then call `generate_sql_package()` with:
+ - `schema_path`: path to the schema SQL file (relative to `src_path`).
+ - `package_full_name`: target module, e.g. `myapp.db`.
+ - `dsn_import`: import path to a DSN string, e.g. `myapp.config:CONFIG.db_url.get_value()`.
+ - `src_path`: optional base source path for scanning queries (defaults to the current directory).
+ - `sqlc_path`: optional path to the sqlc binary if not in PATH (e.g., `Path("/custom/bin/sqlc")`).
+ - `tempdir_path`: optional path for temporary file generation (useful for Docker mounts).
+ - Optional `application_name`, `debug_path`, and `to_pascal_fn` if you need naming overrides or want to keep `sqlc` inputs for inspection.
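
The Quick start above boils down to a one-file script. A minimal sketch, assuming a project named `myapp` with an importable `myapp.config` exposing `CONFIG.db_url` and `sqlc` on PATH; all names here are illustrative, not part of the package:

```python
# Hypothetical scripts/generate_db.py: regenerates the typed client.
from pathlib import Path

from iron_sql import generate_sql_package

if __name__ == "__main__":
    changed = generate_sql_package(
        schema_path=Path("db/mydatabase_schema.sql"),  # resolved against src_path
        package_full_name="myapp.db.mydatabase",  # writes myapp/db/mydatabase.py
        dsn_import="myapp.config:CONFIG.db_url",  # "module:expression" form
    )
    print("regenerated" if changed else "up to date")
```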
iron_sql-0.2.5/README.md ADDED
@@ -0,0 +1,40 @@
+ # iron_sql
+
+ iron_sql keeps SQL close to Python call sites while giving you typed, async query helpers. You write SQL once, keep it in version control, and get generated clients that match your schema without hand-written boilerplate.
+
+ ## Why use it
+ - SQL-first workflow: write queries where they are used; no ORM layer to fight.
+ - Strong typing: generated dataclasses and method signatures flow through your IDE and type checker.
+ - Async-ready: built on `psycopg` with pooled connections and transaction helpers.
+ - Safe-by-default: helper methods enforce expected row counts instead of silently returning `None`.
+
+ ## Quick start
+ 1. Install `iron_sql`, `psycopg`, `psycopg-pool`, and `pydantic`.
+ 2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure it is available in your PATH.
+ 3. Add a Postgres schema dump, for example `db/mydatabase_schema.sql`.
+ 4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code (defaults to the current directory), runs `sqlc`, and writes a module such as `myapp/db/mydatabase.py`.
+
+ ## Authoring queries
+ - Use the package helper for your DB, e.g. `mydatabase_sql("select ...")`. The SQL string must be a literal so the generator can find it.
+ - Named parameters:
+   - Required: `@param`
+   - Optional: `@param?` (expands to `sqlc.narg('param')`)
+ - Positional placeholders (`$1`) stay as-is.
+ - Multi-column results can opt into a custom dataclass with `row_type="MyResult"`. Single-column queries return a scalar type; statements without results expose `execute()`.
+
+ ## Using generated clients
+ - `*_sql("...")` returns a query object with methods derived from the result shape:
+   - `execute()` when no rows are returned.
+   - `query_all_rows()`, `query_single_row()`, `query_optional_row()` for result sets.
+ - `*_connection()` yields a pooled `psycopg.AsyncConnection`; `*_transaction()` wraps it in a transaction context.
+ - JSONB params are sent with `pgjson.Jsonb`; scalar row factories validate types and raise when they do not match.
+
+ ## Adding another database package
+ Provide the schema file and DSN import string, then call `generate_sql_package()` with:
+ - `schema_path`: path to the schema SQL file (relative to `src_path`).
+ - `package_full_name`: target module, e.g. `myapp.db`.
+ - `dsn_import`: import path to a DSN string, e.g. `myapp.config:CONFIG.db_url.get_value()`.
+ - `src_path`: optional base source path for scanning queries (defaults to the current directory).
+ - `sqlc_path`: optional path to the sqlc binary if not in PATH (e.g., `Path("/custom/bin/sqlc")`).
+ - `tempdir_path`: optional path for temporary file generation (useful for Docker mounts).
+ - Optional `application_name`, `debug_path`, and `to_pascal_fn` if you need naming overrides or want to keep `sqlc` inputs for inspection.
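
A sketch of the authoring and client flow the README describes, assuming a module generated as above; `mydatabase_sql`, `mydatabase_transaction`, and the `users` table are illustrative, while the method names (`execute`, `query_single_row`) follow the generated-client contract laid out here:

```python
# Hypothetical call sites scanned by the generator (string literals only).
from myapp.db.mydatabase import mydatabase_sql
from myapp.db.mydatabase import mydatabase_transaction


async def rename_user(user_id: int, name: str | None) -> None:
    async with mydatabase_transaction():
        # @id is required; @name? is optional and expands to sqlc.narg('name').
        await mydatabase_sql(
            "update users set name = @name? where id = @id"
        ).execute(id=user_id, name=name)


async def user_email(user_id: int) -> str:
    # Single-column result: a scalar, enforced to exactly one row.
    return await mydatabase_sql(
        "select email from users where id = @id"
    ).query_single_row(id=user_id)
```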
iron_sql-0.2.5/pyproject.toml ADDED
@@ -0,0 +1,129 @@
+ [project]
+ name = "iron-sql"
+ version = "0.2.5"
+
+ description = "iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries"
+ readme = "README.md"
+ authors = [{ name = "Ilia Ablamonov", email = "ilia@flamefork.ru" }]
+ license = "MIT"
+ license-files = ["LICENSE"]
+ keywords = ["postgresql", "sql", "sqlc", "psycopg", "codegen", "async"]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Developers",
+     "Topic :: Software Development :: Libraries",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.13",
+ ]
+
+ requires-python = ">=3.13"
+ dependencies = [
+     "inflection>=0.5.1",
+     "psycopg>=3.3.2",
+     "psycopg-pool>=3.3.0",
+     "pydantic>=2.12.4",
+ ]
+
+ [project.urls]
+ Homepage = "https://github.com/Flamefork/iron_sql"
+ Repository = "https://github.com/Flamefork/iron_sql.git"
+ Issues = "https://github.com/Flamefork/iron_sql/issues"
+
+ [build-system]
+ requires = ["uv_build>=0.9.4,<0.10.0"]
+ build-backend = "uv_build"
+
+ [dependency-groups]
+ dev = [
+     "basedpyright>=1.31.7",
+     "psycopg[binary]>=3.3.2",
+     "pytest>=8.4.2",
+     "pytest-asyncio>=1.2.0",
+     "pytest-cov>=7.0.0",
+     "pytest-randomly>=4.0.1",
+     "ruff>=0.14.1",
+     "testcontainers>=4",
+ ]
+
+ [tool.pyright]
+ typeCheckingMode = "strict"
+ reportUnknownArgumentType = "none"
+ reportUnknownLambdaType = "none"
+ reportUnknownMemberType = "none"
+ reportUnknownParameterType = "none"
+ reportUnknownVariableType = "none"
+ reportMissingParameterType = "none"
+ reportMissingTypeArgument = "none"
+ reportMissingTypeStubs = "none"
+ deprecateTypingAliases = true
+ reportImportCycles = true
+ reportUnnecessaryTypeIgnoreComment = true
+ reportUnreachable = true
+ reportIgnoreCommentWithoutRule = true
+ reportImplicitRelativeImport = true
+
+ [tool.ruff]
+ target-version = "py313"
+
+ [tool.ruff.format]
+ preview = true
+
+ [tool.ruff.lint]
+ preview = true
+ select = ["ALL"]
+ ignore = [
+     "ANN",
+     "COM812",
+     "CPY",
+     "D",
+     "FIX",
+     "G004",
+     "ISC001",
+     "PLC1901",
+     "PLR0911",
+     "PLR0915",
+     "PLR6301",
+     "RUF001",
+     "RUF002",
+     "RUF003",
+     "TC006",
+     "TD",
+ ]
+
+ [tool.ruff.lint.per-file-ignores]
+ "test_*.py" = ["A002", "PLR2004", "S", "FBT"]
+
+ [tool.ruff.lint.isort]
+ force-single-line = true
+
+ [tool.ruff.lint.pylint]
+ max-args = 10
+
+ [tool.ruff.lint.flake8-tidy-imports]
+ ban-relative-imports = "all"
+
+ [tool.pytest.ini_options]
+ strict = true
+ testpaths = ["tests"]
+ filterwarnings = [
+     "error",
+     "ignore:.*wait_container_is_ready:DeprecationWarning",
+ ]
+ addopts = [
+     "--import-mode=importlib",
+     "--no-cov-on-fail",
+     "--cov-report=term-missing:skip-covered",
+ ]
+ asyncio_mode = "auto"
+ asyncio_default_fixture_loop_scope = "function"
+
+ [tool.coverage.run]
+ branch = true
+ omit = ["tests/*.py", "testdb.py"]
+ data_file = ".coverage/db.sqlite"
+
+ [tool.coverage.html]
+ directory = ".coverage/htmlcov"
+
+ [tool.coverage.report]
+ exclude_also = ["@overload"]
iron_sql-0.2.5/src/iron_sql/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """iron_sql: Typed SQL client generator for Python."""
+
+ from iron_sql.generator import generate_sql_package
+
+ __all__ = [
+     "generate_sql_package",
+ ]
iron_sql-0.2.5/src/iron_sql/generator.py ADDED
@@ -0,0 +1,726 @@
+ import ast
+ import dataclasses
+ import hashlib
+ import importlib
+ import logging
+ from collections import defaultdict
+ from collections.abc import Callable
+ from collections.abc import Iterator
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import inflection
+ from pydantic import alias_generators
+
+ from iron_sql.sqlc import Catalog
+ from iron_sql.sqlc import Column
+ from iron_sql.sqlc import Enum
+ from iron_sql.sqlc import Query
+ from iron_sql.sqlc import SQLCResult
+ from iron_sql.sqlc import run_sqlc
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(kw_only=True, frozen=True)
+ class ColumnPySpec:
+     name: str
+     table: str
+     db_type: str
+     not_null: bool
+     is_array: bool
+     py_type: str
+
+
+ def _collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
+     return {
+         (schema.name, col.type.name)
+         for col in (
+             *(c for q in sqlc_res.queries for c in q.columns),
+             *(p.column for q in sqlc_res.queries for p in q.params),
+         )
+         for schema in (sqlc_res.catalog.schema_by_ref(col.type),)
+         if schema.has_enum(col.type.name)
+     }
+
+
+ def generate_sql_package(  # noqa: PLR0913, PLR0914
+     *,
+     schema_path: Path,
+     package_full_name: str,
+     dsn_import: str,
+     application_name: str | None = None,
+     to_pascal_fn=alias_generators.to_pascal,
+     to_snake_fn=alias_generators.to_snake,
+     debug_path: Path | None = None,
+     src_path: Path = Path(),
+     sqlc_path: Path | None = None,
+     tempdir_path: Path | None = None,
+     sqlc_command: list[str] | None = None,
+ ) -> bool:
+     """Generate a typed SQL package from schema and queries.
+
+     Args:
+         schema_path: Path to the Postgres schema SQL file (relative to src_path)
+         package_full_name: Target module name (e.g., "myapp.mydatabase")
+         dsn_import: Import path to DSN string (e.g.,
+             "myapp.config:CONFIG.db_url")
+         application_name: Optional application name for connection pool
+         to_pascal_fn: Function to convert names to PascalCase (default:
+             pydantic's to_pascal)
+         to_snake_fn: Function to convert names to snake_case (default:
+             pydantic's to_snake)
+         debug_path: Optional path to save sqlc inputs for inspection
+         src_path: Base source path for scanning queries (default: Path())
+         sqlc_path: Optional path to sqlc binary if not in PATH
+         tempdir_path: Optional path for temporary file generation
+         sqlc_command: Optional command prefix to run sqlc
+
+     Returns:
+         True if the package was generated or modified, False otherwise
+     """
+     dsn_import_package, dsn_import_path = dsn_import.split(":")
+
+     package_name = package_full_name.split(".")[-1]  # noqa: PLC0207
+     sql_fn_name = f"{package_name}_sql"
+
+     target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
+
+     queries = list(find_all_queries(src_path, sql_fn_name))
+     queries = list({q.name: q for q in queries}.values())
+
+     dsn_package = importlib.import_module(dsn_import_package)
+     dsn = eval(dsn_import_path, vars(dsn_package))  # noqa: S307
+
+     sqlc_res = run_sqlc(
+         src_path / schema_path,
+         [(q.name, q.stmt) for q in queries],
+         dsn=dsn,
+         debug_path=debug_path,
+         sqlc_path=sqlc_path,
+         tempdir_path=tempdir_path,
+         sqlc_command=sqlc_command,
+     )
+
+     if sqlc_res.error:
+         logger.error("Error running SQLC:\n%s", sqlc_res.error)
+         return False
+
+     ordered_entities, result_types = map_entities(
+         package_name,
+         sqlc_res.queries,
+         sqlc_res.catalog,
+         sqlc_res.used_schemas(),
+         queries,
+         to_pascal_fn,
+     )
+
+     entities = [render_entity(e.name, e.column_specs) for e in ordered_entities]
+
+     used_enums = _collect_used_enums(sqlc_res)
+
+     enums = [
+         render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
+         for schema in sqlc_res.catalog.schemas
+         for e in schema.enums
+         if (schema.name, e.name) in used_enums
+     ]
+
+     query_classes = [
+         render_query_class(
+             q.name,
+             q.text,
+             package_name,
+             [
+                 (
+                     column_py_spec(
+                         p.column,
+                         sqlc_res.catalog,
+                         package_name,
+                         to_pascal_fn,
+                         to_snake_fn,
+                         p.number,
+                     ),
+                     p.column.is_named_param,
+                 )
+                 for p in q.params
+             ],
+             result_types[q.name],
+             len(q.columns),
+         )
+         for q in sqlc_res.queries
+     ]
+
+     query_overloads = [
+         render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
+     ]
+
+     query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
+
+     new_content = render_package(
+         dsn_import_package,
+         dsn_import_path,
+         package_name,
+         sql_fn_name,
+         sorted(entities),
+         sorted(enums),
+         sorted(query_classes),
+         sorted(query_overloads),
+         sorted(query_dict_entries),
+         application_name,
+     )
+     changed = write_if_changed(target_package_path, new_content + "\n")
+     if changed:
+         logger.info(f"Generated SQL package {package_full_name}")
+     return changed
+
+
+ def render_package(
+     dsn_import_package: str,
+     dsn_import_path: str,
+     package_name: str,
+     sql_fn_name: str,
+     entities: list[str],
+     enums: list[str],
+     query_classes: list[str],
+     query_overloads: list[str],
+     query_dict_entries: list[str],
+     application_name: str | None = None,
+ ):
+     return f"""
+
+ # Code generated by iron_sql, DO NOT EDIT.
+
+ # fmt: off
+ # pyright: reportUnusedImport=false
+ # ruff: noqa: A002
+ # ruff: noqa: ARG001
+ # ruff: noqa: C901
+ # ruff: noqa: E303
+ # ruff: noqa: E501
+ # ruff: noqa: F401
+ # ruff: noqa: FBT001
+ # ruff: noqa: I001
+ # ruff: noqa: N801
+ # ruff: noqa: PLR0912
+ # ruff: noqa: PLR0913
+ # ruff: noqa: PLR0917
+ # ruff: noqa: Q000
+ # ruff: noqa: RUF100
+
+ import datetime
+ import decimal
+ import uuid
+ from collections.abc import AsyncIterator
+ from collections.abc import Sequence
+ from contextlib import asynccontextmanager
+ from contextvars import ContextVar
+ from dataclasses import dataclass
+ from enum import StrEnum
+ from typing import Literal
+ from typing import overload
+
+ import psycopg
+ import psycopg.rows
+ from psycopg.types import json as pgjson
+
+ from iron_sql import runtime
+
+ from {dsn_import_package} import {dsn_import_path.split(".", maxsplit=1)[0]}
+
+ {package_name.upper()}_POOL = runtime.ConnectionPool(
+     {dsn_import_path},
+     name="{package_name}",
+     application_name={application_name!r},
+ )
+
+ _{package_name}_connection = ContextVar[psycopg.AsyncConnection | None](
+     "_{package_name}_connection",
+     default=None,
+ )
+
+
+ @asynccontextmanager
+ async def {package_name}_connection() -> AsyncIterator[psycopg.AsyncConnection]:
+     async with {package_name.upper()}_POOL.connection_in_context(
+         _{package_name}_connection
+     ) as conn:
+         yield conn
+
+
+ @asynccontextmanager
+ async def {package_name}_transaction() -> AsyncIterator[None]:
+     async with {package_name}_connection() as conn, conn.transaction():
+         yield
+
+
+ {"\n\n\n".join(enums)}
+
+
+ {"\n\n\n".join(entities)}
+
+
+ class Query:
+     pass
+
+
+ {"\n\n\n".join(query_classes)}
+
+
+ _QUERIES: dict[str, type[Query]] = {{
+     {(",\n    ").join(query_dict_entries)}
+ }}
+
+
+ {"\n".join(query_overloads)}
+ @overload
+ def {sql_fn_name}(stmt: str) -> Query: ...
+
+
+ def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query:
+     if stmt in _QUERIES:
+         return _QUERIES[stmt]()
+     msg = f"Unknown statement: {{stmt!r}}"
+     raise KeyError(msg)
+
+     """.strip()
+
+
+ def render_enum_class(
+     enum: Enum,
+     package_name: str,
+     to_pascal_fn: Callable[[str], str],
+     to_snake_fn: Callable[[str], str],
+ ) -> str:
+     class_name = to_pascal_fn(f"{package_name}_{to_snake_fn(enum.name)}")
+     members = []
+     seen_names: dict[str, int] = {}
+
+     for val in enum.vals:
+         name = to_snake_fn(val).upper()
+         name = "".join(c if c.isalnum() else "_" for c in name)
+         name = name.strip("_") or "EMPTY"
+         if name[0].isdigit():
+             name = "NUM" + name
+         if name in seen_names:
+             seen_names[name] += 1
+             name = f"{name}_{seen_names[name]}"
+         else:
+             seen_names[name] = 1
+         members.append(f'{name} = "{val}"')
+
+     return f"""
+
+ class {class_name}(StrEnum):
+     {indent_block("\n".join(members), "    ")}
+
+     """.strip()
+
+
+ def render_entity(
+     name: str,
+     columns: tuple[ColumnPySpec, ...],
+ ) -> str:
+     return f"""
+
+ @dataclass(kw_only=True)
+ class {name}:
+     {"\n    ".join(f"{c.name}: {c.py_type}" for c in columns)}
+
+     """.strip()
+
+
+ def deduplicate_params(
+     params: list[tuple[ColumnPySpec, bool]],
+ ) -> list[tuple[ColumnPySpec, bool]]:
+     seen = defaultdict(int)
+     result: list[tuple[ColumnPySpec, bool]] = []
+     for column, is_named in params:
+         seen[column.name] += 1
+         new_name = (
+             f"{column.name}{seen[column.name]}"
+             if seen[column.name] > 1
+             else column.name
+         )
+         new_column = dataclasses.replace(column, name=new_name)
+         result.append((new_column, is_named))
+     return result
+
+
+ def serialized_arg(column: ColumnPySpec) -> str:
+     match column:
+         case ColumnPySpec(db_type="json", not_null=True):
+             msg = "Unsupported column type: json"
+             raise TypeError(msg)
+         case ColumnPySpec(db_type="jsonb", is_array=True):
+             msg = "Unsupported column type: jsonb[]"
+             raise TypeError(msg)
+         case ColumnPySpec(db_type="jsonb", not_null=True, name=name):
+             return f"pgjson.Jsonb({name})"
+         case ColumnPySpec(db_type="jsonb", not_null=False, name=name):
+             return f"pgjson.Jsonb({name}) if {name} is not None else None"
+         case ColumnPySpec(name=name):
+             return name
+
+
+ def render_query_class(
+     query_name: str,
+     stmt: str,
+     package_name: str,
+     query_params: list[tuple[ColumnPySpec, bool]],
+     result: str,
+     columns_num: int,
+ ) -> str:
+     query_params = deduplicate_params(query_params)
+
+     match [column for column, _ in query_params]:
+         case []:
+             params_arg = "None"
+         case [column]:
+             params_arg = f"({serialized_arg(column)},)"
+         case columns:
+             params_arg = f"({', '.join(serialized_arg(column) for column in columns)})"
+
+     query_fn_params = [f"{column.name}: {column.py_type}" for column, _ in query_params]
+     first_named_param_idx = next(
+         (i for i, (_, is_named_param) in enumerate(query_params) if is_named_param), -1
+     )
+     if first_named_param_idx >= 0:
+         query_fn_params.insert(first_named_param_idx, "*")
+     query_fn_params.insert(0, "self")
+
+     base_result = result.removesuffix(" | None")
+
+     if columns_num == 0:
+         row_factory = "psycopg.rows.scalar_row"
+     elif columns_num == 1:
+         if result.endswith(" | None"):
+             row_factory = f"runtime.typed_scalar_row({base_result}, not_null=False)"
+         else:
+             row_factory = f"runtime.typed_scalar_row({base_result}, not_null=True)"
+     else:
+         row_factory = f"psycopg.rows.class_row({result})"
+
+     if columns_num > 0:
+         methods = f"""
+
+ async def query_all_rows({", ".join(query_fn_params)}) -> list[{result}]:
+     async with self._execute({params_arg}) as cur:
+         return await cur.fetchall()
+
+ async def query_single_row({", ".join(query_fn_params)}) -> {result}:
+     async with self._execute({params_arg}) as cur:
+         return runtime.get_one_row(await cur.fetchmany(2))
+
+ async def query_optional_row({", ".join(query_fn_params)}) -> {base_result} | None:
+     async with self._execute({params_arg}) as cur:
+         return runtime.get_one_row_or_none(await cur.fetchmany(2))
+
+         """.strip()
+     else:
+         methods = f"""
+
+ async def execute({", ".join(query_fn_params)}) -> None:
+     async with self._execute({params_arg}):
+         pass
+
+         """.strip()
+
+     return f"""
+
+ class {query_name}(Query):
+     @asynccontextmanager
+     async def _execute(self, params) -> AsyncIterator[psycopg.AsyncRawCursor[{result}]]:
+         stmt = {stmt!r}
+         async with (
+             {package_name}_connection() as conn,
+             psycopg.AsyncRawCursor(conn, row_factory={row_factory}) as cur,
+         ):
+             await cur.execute(stmt, params)
+             yield cur
+
+     {indent_block(methods, "    ")}
+
+     """.strip()
+
+
+ def render_query_overload(
+     sql_fn_name: str, query_name: str, stmt: str, row_type: str | None
+ ) -> str:
+     result_arg = ""
+     if row_type:
+         result_arg = f", row_type: Literal[{row_type!r}]"
+
+     return f"""
+
+ @overload
+ def {sql_fn_name}(stmt: Literal[{stmt!r}]{result_arg}) -> {query_name}: ...
+
+     """.strip()
+
+
+ def render_query_dict_entry(query_name: str, stmt: str) -> str:
+     return f"{stmt!r}: {query_name}"
+
+
+ @dataclass(kw_only=True)
+ class CodeQuery:
+     stmt: str
+     row_type: str | None
+     file: Path
+     lineno: int
+
+     @property
+     def name(self) -> str:
+         md5_hash = hashlib.md5(self.stmt.encode(), usedforsecurity=False).hexdigest()
+         return f"Query_{md5_hash}{'_' + self.row_type if self.row_type else ''}"
+
+     @property
+     def location(self) -> str:
+         return f"{self.file}:{self.lineno}"
+
+
+ @dataclass(kw_only=True)
+ class SQLEntity:
+     package_name: str
+     set_name: str | None
+     table_name: str | None
+     columns: list[Column]
+     catalog: Catalog = dataclasses.field(repr=False)
+     to_pascal_fn: Callable[[str], str]
+     to_snake_fn: Callable[[str], str] = inflection.underscore
+
+     @property
+     def name(self) -> str:
+         if self.set_name:
+             return self.set_name
+         if self.table_name:
+             return self.to_pascal_fn(
+                 f"{self.package_name}_{inflection.singularize(self.table_name)}"
+             )
+         hash_base = repr(self.column_specs)
+         md5_hash = hashlib.md5(hash_base.encode(), usedforsecurity=False).hexdigest()
+         return f"QueryResult_{md5_hash}"
+
+     @property
+     def column_specs(self) -> tuple[ColumnPySpec, ...]:
+         return tuple(
+             column_py_spec(
+                 c, self.catalog, self.package_name, self.to_pascal_fn, self.to_snake_fn
+             )
+             for c in self.columns
+         )
+
+
+ def map_entities(
+     package_name: str,
+     queries_from_sqlc: list[Query],
+     catalog: Catalog,
+     used_schemas: list[str],
+     queries_from_code: list[CodeQuery],
+     to_pascal_fn: Callable[[str], str],
+     to_snake_fn: Callable[[str], str] = inflection.underscore,
+ ):
+     row_types = {q.name: q.row_type for q in queries_from_code}
+
+     table_entities = [
+         SQLEntity(
+             package_name=package_name,
+             set_name=None,
+             table_name=t.rel.name,
+             columns=t.columns,
+             catalog=catalog,
+             to_pascal_fn=to_pascal_fn,
+             to_snake_fn=to_snake_fn,
+         )
+         for sch in used_schemas
+         for t in catalog.schema_by_name(sch).tables
+     ]
+     specs_to_entities = {e.column_specs: e for e in table_entities}
+
+     for q in queries_from_sqlc:
+         if row_types[q.name] and not q.columns:
+             msg = f"Query has row_type={row_types[q.name]} but no result"
+             raise ValueError(msg)
+         if row_types[q.name] and len(q.columns) == 1:
+             msg = f"Query has row_type={row_types[q.name]} but only one column"
+             raise ValueError(msg)
+
+     query_result_entities = {
+         q.name: SQLEntity(
+             package_name=package_name,
+             set_name=row_types[q.name],
+             table_name=None,
+             columns=q.columns,
+             catalog=catalog,
+             to_pascal_fn=to_pascal_fn,
+             to_snake_fn=to_snake_fn,
+         )
+         for q in queries_from_sqlc
+         if len(q.columns) > 1
+     }
+
+     unique_entities = {
+         e.column_specs: specs_to_entities.get(e.column_specs, e)
+         for e in query_result_entities.values()
+     }
+     ordered_entities = sorted(
+         unique_entities.values(),
+         key=lambda e: (e.table_name is None, e.table_name or ""),
+     )
+
+     result_types = {}
+     for q in queries_from_sqlc:
+         if len(q.columns) == 0:
+             result_types[q.name] = "None"
+         elif len(q.columns) == 1:
+             result_types[q.name] = column_py_spec(
+                 q.columns[0], catalog, package_name, to_pascal_fn, to_snake_fn
+             ).py_type
+         else:
+             column_spec = query_result_entities[q.name].column_specs
+             result_types[q.name] = unique_entities[column_spec].name
+
+     return ordered_entities, result_types
+
+
+ def column_py_spec(  # noqa: C901, PLR0912
+     column: Column,
+     catalog: Catalog,
+     package_name: str,
+     to_pascal_fn: Callable[[str], str],
+     to_snake_fn: Callable[[str], str] = inflection.underscore,
+     number: int = 0,
+ ) -> ColumnPySpec:
+     db_type = column.type.name.removeprefix("pg_catalog.")
+     match db_type:
+         case "bool" | "boolean":
+             py_type = "bool"
+         case (
+             "int2"
+             | "int4"
+             | "int8"
+             | "smallint"
+             | "integer"
+             | "bigint"
+             | "serial"
+             | "bigserial"
+         ):
+             py_type = "int"
+         case "float4" | "float8":
+             py_type = "float"
+         case "numeric":
+             py_type = "decimal.Decimal"
+         case "varchar" | "text":
+             py_type = "str"
+         case "bytea":
+             py_type = "bytes"
+         case "json" | "jsonb":
+             py_type = "object"
+         case "date":
+             py_type = "datetime.date"
+         case "time" | "timetz":
+             py_type = "datetime.time"
+         case "timestamp" | "timestamptz":
+             py_type = "datetime.datetime"
+         case "uuid":
+             py_type = "uuid.UUID"
+         case "any" | "anyelement":
+             py_type = "object"
+         case enum if catalog.schema_by_ref(column.type).has_enum(enum):
+             py_type = (
+                 to_pascal_fn(f"{package_name}_{to_snake_fn(enum)}")
+                 if package_name
+                 else "str"
+             )
+         case _:
+             logger.warning(f"Unknown SQL type: {column.type.name} ({column.name})")
+             py_type = "object"
+
+     if column.is_array:
+         py_type = f"Sequence[{py_type}]"
+
+     if not column.not_null:
+         py_type += " | None"
+
+     return ColumnPySpec(
+         name=column.name or f"param_{number}",
+         table=column.table.name if column.table else "unknown",
+         db_type=db_type,
+         not_null=column.not_null,
+         is_array=column.is_array,
+         py_type=py_type,
+     )
+
+
+ def find_fn_calls(
+     root_path: Path, fn_name: str
+ ) -> Iterator[tuple[Path, int, ast.Call]]:
+     for path in root_path.glob("**/*.py"):
+         content = path.read_text(encoding="utf-8")
+         if fn_name not in content:
+             continue
+         for node in ast.walk(ast.parse(content, filename=str(path))):
+             match node:
+                 case ast.Call(func=ast.Name(id=id)) if id == fn_name:
+                     yield path, node.lineno, node
+                 case _:
+                     pass
+
+
+ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
+     for file, lineno, node in find_fn_calls(src_path, sql_fn_name):
+         relative_path = file.relative_to(src_path)
+
+         # Validate the arity before indexing into node.args, so a zero-arg
+         # call reports a TypeError instead of crashing with IndexError.
+         if (
+             len(node.args) != 1
+             or not isinstance(stmt_arg := node.args[0], ast.Constant)
+             or not isinstance(stmt_arg.value, str)
+         ):
+             msg = (
+                 f"Invalid positional arguments for {sql_fn_name} "
+                 f"at {relative_path}:{lineno}, "
+                 "expected a single string literal"
+             )
+             raise TypeError(msg)
+
+         stmt = stmt_arg.value
+
+         row_type = None
+         for kw in node.keywords:
+             if not isinstance(kw.value, ast.Constant) or not isinstance(
+                 kw.value.value, str
+             ):
+                 msg = (
+                     f"Invalid keyword argument {kw.arg} for {sql_fn_name} "
+                     f"at {relative_path}:{lineno}, expected a string literal"
+                 )
+                 raise TypeError(msg)
+             if kw.arg == "row_type":
+                 row_type = kw.value.value
+                 break
+
+         yield CodeQuery(
+             stmt=stmt,
+             row_type=row_type,
+             file=relative_path,
+             lineno=lineno,
+         )
+
+
+ def indent_block(block: str, indent: str) -> str:
+     return "\n".join(
+         indent + line if i > 0 and line.strip() else line
+         for i, line in enumerate(block.split("\n"))
+     )
+
+
+ def write_if_changed(path: Path, new_content: str) -> bool:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     existing_content = path.read_text(encoding="utf-8") if path.exists() else None
+     if existing_content == new_content:
+         return False
+     path.write_text(new_content, encoding="utf-8")
+     path.touch()
+     return True
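
One behavior of `find_all_queries` worth calling out: discovery runs over the `ast` at generation time, so only literal statements are accepted. A sketch, with `mydatabase_sql` standing in for a hypothetical generated helper:

```python
from myapp.db.mydatabase import mydatabase_sql  # hypothetical generated module

# OK: a single string literal, discovered and typed at generation time.
mydatabase_sql("select count(*) from users")

# Rejected: generate_sql_package raises TypeError ("expected a single
# string literal") because the argument is a variable, not a literal.
stmt = "select count(*) from users"
mydatabase_sql(stmt)
```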
iron_sql-0.2.5/src/iron_sql/runtime.py ADDED
@@ -0,0 +1,135 @@
+ from collections.abc import AsyncIterator
+ from collections.abc import Sequence
+ from contextlib import asynccontextmanager
+ from contextvars import ContextVar
+ from enum import Enum
+ from typing import Any
+ from typing import Literal
+ from typing import Self
+ from typing import overload
+
+ import psycopg
+ import psycopg.rows
+ import psycopg_pool
+
+
+ class NoRowsError(Exception):
+     pass
+
+
+ class TooManyRowsError(Exception):
+     pass
+
+
+ class ConnectionPool:
+     def __init__(
+         self,
+         conninfo: str,
+         *,
+         name: str | None = None,
+         application_name: str | None = None,
+     ) -> None:
+         self.conninfo = conninfo
+         self.name = name
+         self.application_name = application_name
+         self._init_psycopg_pool()
+
+     async def close(self) -> None:
+         await self.psycopg_pool.close()
+         self._init_psycopg_pool()
+
+     async def __aenter__(self) -> Self:
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+         await self.close()
+
+     async def await_connections(self) -> None:
+         await self.psycopg_pool.open(wait=True)
+
+     async def check(self) -> None:
+         await self.psycopg_pool.open()
+         await self.psycopg_pool.check()
+
+     @asynccontextmanager
+     async def connection(self) -> AsyncIterator[psycopg.AsyncConnection]:
+         await self.psycopg_pool.open()
+         async with self.psycopg_pool.connection() as conn:
+             yield conn
+
+     def _init_psycopg_pool(self) -> None:
+         self.psycopg_pool = psycopg_pool.AsyncConnectionPool(
+             self.conninfo,
+             open=False,
+             name=self.name,
+             kwargs={
+                 "application_name": self.application_name,
+                 # https://www.psycopg.org/psycopg3/docs/basic/transactions.html#autocommit-transactions
+                 "autocommit": True,
+             },
+         )
+
+     @asynccontextmanager
+     async def connection_in_context(
+         self, context_var: ContextVar[psycopg.AsyncConnection | None]
+     ) -> AsyncIterator[psycopg.AsyncConnection]:
+         conn = context_var.get()
+         if conn is not None:
+             yield conn
+             return
+         async with self.connection() as conn:
+             token = context_var.set(conn)
+             try:
+                 yield conn
+             finally:
+                 context_var.reset(token)
+
+
+ def get_one_row[T](rows: list[T]) -> T:
+     if len(rows) == 0:
+         raise NoRowsError
+     if len(rows) > 1:
+         raise TooManyRowsError
+     return rows[0]
+
+
+ def get_one_row_or_none[T](rows: list[T]) -> T | None:
+     if len(rows) == 0:
+         return None
+     if len(rows) > 1:
+         raise TooManyRowsError
+     return rows[0]
+
+
+ @overload
+ def typed_scalar_row[T](
+     typ: type[T], *, not_null: Literal[True]
+ ) -> psycopg.rows.BaseRowFactory[T]: ...
+
+
+ @overload
+ def typed_scalar_row[T](
+     typ: type[T], *, not_null: Literal[False]
+ ) -> psycopg.rows.BaseRowFactory[T | None]: ...
+
+
+ def typed_scalar_row[T](
+     typ: type[T], *, not_null: bool
+ ) -> psycopg.rows.BaseRowFactory[T | None]:
+     def typed_scalar_row_(cursor) -> psycopg.rows.RowMaker[T | None]:
+         scalar_row_ = psycopg.rows.scalar_row(cursor)
+
+         def typed_scalar_row__(values: Sequence[Any]) -> T | None:
+             val = scalar_row_(values)
+             if not not_null and val is None:
+                 return None
+             if not isinstance(val, typ):
+                 if issubclass(typ, Enum):
+                     return typ(val)
+                 msg = f"Expected scalar of type {typ}, got {type(val)}"
+                 raise TypeError(msg)
+             return val
+
+         return typed_scalar_row__
+
+     return typed_scalar_row_
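
`ConnectionPool.connection_in_context` is what makes the generated `*_connection()` and `*_transaction()` helpers reentrant: a nested scope reuses whatever connection the context var already holds instead of checking out a second one. A small sketch with a placeholder DSN (running it needs a reachable Postgres):

```python
import asyncio
from contextvars import ContextVar

import psycopg

from iron_sql import runtime

POOL = runtime.ConnectionPool("postgresql://localhost/app", name="app")
_conn = ContextVar[psycopg.AsyncConnection | None]("_conn", default=None)


async def main() -> None:
    async with POOL.connection_in_context(_conn) as outer:
        async with POOL.connection_in_context(_conn) as inner:
            assert inner is outer  # the inner scope reuses the outer connection


asyncio.run(main())
```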
iron_sql-0.2.5/src/iron_sql/sqlc.py ADDED
@@ -0,0 +1,222 @@
+ import json
+ import re
+ import shutil
+ import subprocess  # noqa: S404
+ import tempfile
+ import textwrap
+ from pathlib import Path
+
+ import pydantic
+
+
+ class CatalogReference(pydantic.BaseModel):
+     catalog: str
+     schema_name: str = pydantic.Field(..., alias="schema")
+     name: str
+
+
+ class Column(pydantic.BaseModel):
+     name: str
+     not_null: bool
+     is_array: bool
+     comment: str
+     length: int
+     is_named_param: bool
+     is_func_call: bool
+     scope: str
+     table: CatalogReference | None
+     table_alias: str
+     type: CatalogReference
+     is_sqlc_slice: bool
+     embed_table: None
+     original_name: str
+     unsigned: bool
+     array_dims: int
+
+
+ class Table(pydantic.BaseModel):
+     rel: CatalogReference
+     columns: list[Column]
+     comment: str
+
+
+ class Enum(pydantic.BaseModel):
+     name: str
+     vals: list[str]
+     comment: str
+
+
+ class CompositeType(pydantic.BaseModel):
+     name: str
+     comment: str
+
+
+ class Schema(pydantic.BaseModel):
+     comment: str
+     name: str
+     tables: list[Table]
+     enums: list[Enum]
+     composite_types: list[CompositeType]
+
+     def has_enum(self, name: str) -> bool:
+         return any(e.name == name for e in self.enums)
+
+
+ class Catalog(pydantic.BaseModel):
+     default_schema: str
+     name: str
+     schemas: list[Schema]
+
+     def schema_by_name(self, name: str) -> Schema:
+         for schema in self.schemas:
+             if schema.name == name:
+                 return schema
+         msg = f"Schema not found: {name}"
+         raise ValueError(msg)
+
+     def schema_by_ref(self, ref: CatalogReference) -> Schema:
+         return self.schema_by_name(ref.schema_name or self.default_schema)
+
+
+ class QueryParameter(pydantic.BaseModel):
+     number: int
+     column: Column
+
+
+ class Query(pydantic.BaseModel):
+     text: str
+     name: str
+     cmd: str
+     columns: list[Column]
+     params: list[QueryParameter]
+
+
+ class SQLCResult(pydantic.BaseModel):
+     error: str | None = None
+     catalog: Catalog
+     queries: list[Query]
+
+     def used_schemas(self) -> list[str]:
+         result = {
+             c.table.schema_name
+             for q in self.queries
+             for c in q.columns
+             if c.table is not None
+         }
+         if "" in result:
+             result.remove("")
+             result.add(self.catalog.default_schema)
+         catalog_schema_names = {s.name for s in self.catalog.schemas}
+         return [s for s in result if s in catalog_schema_names]
+
+
+ def _resolve_sqlc_command(
+     sqlc_path: Path | None,
+     sqlc_command: list[str] | None,
+ ) -> list[str]:
+     if sqlc_command is not None:
+         if sqlc_path is not None:
+             msg = "sqlc_command and sqlc_path are mutually exclusive"
+             raise ValueError(msg)
+         if not sqlc_command:
+             msg = "sqlc_command must not be empty"
+             raise ValueError(msg)
+         return sqlc_command
+
+     if sqlc_path is None:
+         discovered_path = shutil.which("sqlc")
+         if discovered_path is None:
+             msg = "sqlc not found in PATH"
+             raise FileNotFoundError(msg)
+         sqlc_path = Path(discovered_path)
+     if not sqlc_path.exists():
+         msg = f"sqlc not found at {sqlc_path}"
+         raise FileNotFoundError(msg)
+
+     return [str(sqlc_path)]
+
+
+ def run_sqlc(
+     schema_path: Path,
+     queries: list[tuple[str, str]],
+     *,
+     dsn: str | None,
+     debug_path: Path | None = None,
+     sqlc_path: Path | None = None,
+     tempdir_path: Path | None = None,
+     sqlc_command: list[str] | None = None,
+ ) -> SQLCResult:
+     if not schema_path.exists():
+         msg = f"Schema file not found: {schema_path}"
+         raise ValueError(msg)
+
+     if not queries:
+         return SQLCResult(
+             catalog=Catalog(default_schema="", name="", schemas=[]),
+             queries=[],
+         )
+
+     queries = list({q[0]: q for q in queries}.values())
+     cmd_prefix = _resolve_sqlc_command(sqlc_path, sqlc_command)
+
+     with tempfile.TemporaryDirectory(
+         dir=str(tempdir_path) if tempdir_path else None
+     ) as tempdir:
+         queries_path = Path(tempdir) / "queries.sql"
+         queries_path.write_text(
+             "\n\n".join(
+                 f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
+                 for name, stmt in queries
+             ),
+             encoding="utf-8",
+         )
+
+         (Path(tempdir) / "schema.sql").symlink_to(schema_path.absolute())
+
+         config_path = Path(tempdir) / "sqlc.json"
+         sqlc_config = {
+             "version": "2",
+             "sql": [
+                 {
+                     "schema": "schema.sql",
+                     "queries": ["queries.sql"],
+                     "engine": "postgresql",
+                     "database": {"uri": dsn} if dsn else None,
+                     "gen": {"json": {"out": ".", "filename": "out.json"}},
+                 }
+             ],
+         }
+         config_path.write_text(json.dumps(sqlc_config, indent=2), encoding="utf-8")
+
+         cmd = [*cmd_prefix, "generate", "--file", str(config_path.resolve())]
+
+         sqlc_run_result = subprocess.run(  # noqa: S603
+             cmd,
+             capture_output=True,
+             check=False,
+         )
+
+         json_out_path = Path(tempdir) / "out.json"
+
+         if debug_path:
+             debug_path.absolute().mkdir(parents=True, exist_ok=True)
+             shutil.copy(queries_path, debug_path)
+             shutil.copy(schema_path, debug_path / "schema.sql")
+             shutil.copy(config_path, debug_path)
+             if json_out_path.exists():
+                 shutil.copy(json_out_path, debug_path)
+             elif (debug_path / "out.json").exists():
+                 (debug_path / "out.json").unlink()
+
+         if not json_out_path.exists():
+             return SQLCResult(
+                 error=sqlc_run_result.stderr.decode().strip(),
+                 catalog=Catalog(default_schema="", name="", schemas=[]),
+                 queries=[],
+             )
+         return SQLCResult.model_validate_json(json_out_path.read_text(encoding="utf-8"))
+
+
+ def preprocess_sql(stmt: str) -> str:
+     stmt = re.sub(r"@(\w+)\?", r"sqlc.narg('\1')", stmt)
+     return textwrap.dedent(stmt).strip()
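
`preprocess_sql` is the only rewriting applied to statements before they reach `sqlc`: optional `@param?` markers become `sqlc.narg(...)`, while required `@param` names pass through for sqlc to resolve. For example:

```python
from iron_sql.sqlc import preprocess_sql

print(preprocess_sql("select * from users where id = @id and name = @name?"))
# select * from users where id = @id and name = sqlc.narg('name')
```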