iron-sql 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iron_sql-0.1.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Ilia Ablamonov
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,61 @@
1
+ Metadata-Version: 2.4
2
+ Name: iron-sql
3
+ Version: 0.1.1
4
+ Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
5
+ Keywords: postgresql,sql,sqlc,psycopg,codegen,async
6
+ Author: Ilia Ablamonov
7
+ Author-email: Ilia Ablamonov <ilia@flamefork.ru>
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Topic :: Software Development :: Libraries
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Requires-Dist: inflection>=0.5.1
16
+ Requires-Dist: psycopg>=3.2.12
17
+ Requires-Dist: psycopg-pool>=3.2.7
18
+ Requires-Dist: pydantic>=2.12.4
19
+ Requires-Python: >=3.13
20
+ Project-URL: Homepage, https://github.com/Flamefork/iron_sql
21
+ Project-URL: Issues, https://github.com/Flamefork/iron_sql/issues
22
+ Project-URL: Repository, https://github.com/Flamefork/iron_sql.git
23
+ Description-Content-Type: text/markdown
24
+
25
+ # iron_sql
26
+
27
+ iron_sql keeps SQL close to Python call sites while giving you typed, async query helpers. You write SQL once, keep it in version control, and get generated clients that match your schema without hand-written boilerplate.
28
+
29
+ ## Why use it
30
+ - SQL-first workflow: write queries where they are used; no ORM layer to fight.
31
+ - Strong typing: generated dataclasses and method signatures flow through your IDE and type checker.
32
+ - Async-ready: built on `psycopg` with pooled connections and transaction helpers.
33
+ - Safe-by-default: helper methods enforce expected row counts instead of returning silent `None`.
34
+
35
+ ## Quick start
36
+ 1. Install `iron_sql`, `psycopg`, `psycopg-pool`, `orjson`, and `pydantic`.
37
+ 2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure `/usr/local/bin/sqlc` is in PATH.
38
+ 3. Add a Postgres schema dump, for example `db/adept_schema.sql`.
39
+ 4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code, runs `sqlc`, and writes a module such as `adept/db/adept.py`.
40
+
41
+ ## Authoring queries
42
+ - Use the package helper for your DB, e.g. `adept_sql("select ...")`. The SQL string must be a literal so the generator can find it.
43
+ - Named parameters:
44
+ - Required: `@param`
45
+ - Optional: `@param?` (expands to `sqlc.narg('param')`)
46
+ - Positional placeholders (`$1`) stay as-is.
47
+ - Multi-column results can opt into a custom dataclass with `row_type="MyResult"`. Single-column queries return a scalar type; statements without results expose `execute()`.
48
+
49
+ ## Using generated clients
50
+ - `*_sql("...")` returns a query object with methods derived from the result shape:
51
+ - `execute()` when no rows are returned.
52
+ - `query_all_rows()`, `query_single_row()`, `query_optional_row()` for result sets.
53
+ - `*_connection()` yields a pooled `psycopg.AsyncConnection`; `*_transaction()` wraps it in a transaction context.
54
+ - JSONB params are sent with `pgjson.Jsonb`; scalar row factories validate types and raise when they do not match.
55
+
56
+ ## Adding another database package
57
+ Provide the schema file and DSN import string, then call `generate_sql_package()` with:
58
+ - `schema_path`: path to the schema SQL file.
59
+ - `package_full_name`: target module, e.g. `adept.db.analytics`.
60
+ - `dsn_import`: import path to a DSN string, e.g. `adept.config:CONFIG.analytics_db_url.value`.
61
+ - Optional `application_name`, `debug_path`, and `to_pascal_fn` if you need naming overrides or want to keep `sqlc` inputs for inspection.
@@ -0,0 +1,37 @@
1
+ # iron_sql
2
+
3
+ iron_sql keeps SQL close to Python call sites while giving you typed, async query helpers. You write SQL once, keep it in version control, and get generated clients that match your schema without hand-written boilerplate.
4
+
5
+ ## Why use it
6
+ - SQL-first workflow: write queries where they are used; no ORM layer to fight.
7
+ - Strong typing: generated dataclasses and method signatures flow through your IDE and type checker.
8
+ - Async-ready: built on `psycopg` with pooled connections and transaction helpers.
9
+ - Safe-by-default: helper methods enforce expected row counts instead of returning silent `None`.
10
+
11
+ ## Quick start
12
+ 1. Install `iron_sql`, `psycopg`, `psycopg-pool`, `orjson`, and `pydantic`.
13
+ 2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure `/usr/local/bin/sqlc` is in PATH.
14
+ 3. Add a Postgres schema dump, for example `db/adept_schema.sql`.
15
+ 4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code, runs `sqlc`, and writes a module such as `adept/db/adept.py`.
16
+
17
+ ## Authoring queries
18
+ - Use the package helper for your DB, e.g. `adept_sql("select ...")`. The SQL string must be a literal so the generator can find it.
19
+ - Named parameters:
20
+ - Required: `@param`
21
+ - Optional: `@param?` (expands to `sqlc.narg('param')`)
22
+ - Positional placeholders (`$1`) stay as-is.
23
+ - Multi-column results can opt into a custom dataclass with `row_type="MyResult"`. Single-column queries return a scalar type; statements without results expose `execute()`.
24
+
25
+ ## Using generated clients
26
+ - `*_sql("...")` returns a query object with methods derived from the result shape:
27
+ - `execute()` when no rows are returned.
28
+ - `query_all_rows()`, `query_single_row()`, `query_optional_row()` for result sets.
29
+ - `*_connection()` yields a pooled `psycopg.AsyncConnection`; `*_transaction()` wraps it in a transaction context.
30
+ - JSONB params are sent with `pgjson.Jsonb`; scalar row factories validate types and raise when they do not match.
31
+
32
+ ## Adding another database package
33
+ Provide the schema file and DSN import string, then call `generate_sql_package()` with:
34
+ - `schema_path`: path to the schema SQL file.
35
+ - `package_full_name`: target module, e.g. `adept.db.analytics`.
36
+ - `dsn_import`: import path to a DSN string, e.g. `adept.config:CONFIG.analytics_db_url.value`.
37
+ - Optional `application_name`, `debug_path`, and `to_pascal_fn` if you need naming overrides or want to keep `sqlc` inputs for inspection.
@@ -0,0 +1,117 @@
1
+ [project]
2
+ name = "iron-sql"
3
+ version = "0.1.1"
4
+
5
+ description = "iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries"
6
+ readme = "README.md"
7
+ authors = [{ name = "Ilia Ablamonov", email = "ilia@flamefork.ru" }]
8
+ license = "MIT"
9
+ license-files = ["LICENSE"]
10
+ keywords = ["postgresql", "sql", "sqlc", "psycopg", "codegen", "async"]
11
+ classifiers = [
12
+ "Development Status :: 3 - Alpha",
13
+ "Intended Audience :: Developers",
14
+ "Topic :: Software Development :: Libraries",
15
+ "Programming Language :: Python :: 3",
16
+ "Programming Language :: Python :: 3.13",
17
+ ]
18
+
19
+ requires-python = ">=3.13"
20
+ dependencies = [
21
+ "inflection>=0.5.1",
22
+ "psycopg>=3.2.12",
23
+ "psycopg-pool>=3.2.7",
24
+ "pydantic>=2.12.4",
25
+ ]
26
+
27
+ [project.urls]
28
+ Homepage = "https://github.com/Flamefork/iron_sql"
29
+ Repository = "https://github.com/Flamefork/iron_sql.git"
30
+ Issues = "https://github.com/Flamefork/iron_sql/issues"
31
+
32
+ [build-system]
33
+ requires = ["uv_build>=0.9.4,<0.10.0"]
34
+ build-backend = "uv_build"
35
+
36
+ [dependency-groups]
37
+ dev = [
38
+ "basedpyright>=1.31.7",
39
+ "pytest>=8.4.2",
40
+ "pytest-asyncio>=1.2.0",
41
+ "pytest-cov>=7.0.0",
42
+ "ruff>=0.14.1",
43
+ ]
44
+
45
+ [tool.pyright]
46
+ typeCheckingMode = "strict"
47
+ reportUnknownArgumentType = "none"
48
+ reportUnknownLambdaType = "none"
49
+ reportUnknownMemberType = "none"
50
+ reportUnknownParameterType = "none"
51
+ reportUnknownVariableType = "none"
52
+ reportMissingParameterType = "none"
53
+ reportMissingTypeArgument = "none"
54
+ reportMissingTypeStubs = "none"
55
+ deprecateTypingAliases = true
56
+ reportImportCycles = true
57
+ reportUnnecessaryTypeIgnoreComment = true
58
+ reportUnreachable = true
59
+ reportIgnoreCommentWithoutRule = true
60
+ reportImplicitRelativeImport = true
61
+
62
+ [tool.ruff]
63
+ target-version = "py313"
64
+
65
+ [tool.ruff.format]
66
+ preview = true
67
+
68
+ [tool.ruff.lint]
69
+ preview = true
70
+ select = ["ALL"]
71
+ ignore = [
72
+ "ANN",
73
+ "COM812",
74
+ "CPY",
75
+ "D",
76
+ "FIX",
77
+ "G004",
78
+ "ISC001",
79
+ "PLC1901",
80
+ "PLR0911",
81
+ "PLR0915",
82
+ "PLR6301",
83
+ "RUF001",
84
+ "RUF002",
85
+ "RUF003",
86
+ "TC006",
87
+ "TD",
88
+ ]
89
+
90
+ [tool.ruff.lint.per-file-ignores]
91
+ "{*_test.py,conftest.py}" = ["A002", "PLR2004", "S", "FBT"]
92
+
93
+ [tool.ruff.lint.isort]
94
+ force-single-line = true
95
+
96
+ [tool.ruff.lint.pylint]
97
+ max-args = 10
98
+
99
+ [tool.ruff.lint.flake8-tidy-imports]
100
+ ban-relative-imports = "all"
101
+
102
+ [tool.pytest.ini_options]
103
+ filterwarnings = ["error"]
104
+ addopts = ["--no-cov-on-fail", "--cov-report=term-missing:skip-covered"]
105
+ asyncio_mode = "auto"
106
+ asyncio_default_fixture_loop_scope = "function"
107
+
108
+ [tool.coverage.run]
109
+ branch = true
110
+ omit = ["*_test.py"]
111
+ data_file = ".coverage/db.sqlite"
112
+
113
+ [tool.coverage.html]
114
+ directory = ".coverage/htmlcov"
115
+
116
+ [tool.coverage.report]
117
+ exclude_also = ["if TYPE_CHECKING:", "@overload"]
@@ -0,0 +1,7 @@
1
+ """iron_gql: Typed GraphQL client generator for Python."""
2
+
3
+ from iron_sql.generator import generate_sql_package
4
+
5
+ __all__ = [
6
+ "generate_sql_package",
7
+ ]
@@ -0,0 +1,600 @@
1
+ import ast
2
+ import dataclasses
3
+ import hashlib
4
+ import importlib
5
+ import logging
6
+ from collections import defaultdict
7
+ from collections.abc import Callable
8
+ from collections.abc import Iterator
9
+ from dataclasses import dataclass
10
+ from operator import attrgetter
11
+ from pathlib import Path
12
+
13
+ import inflection
14
+ from pydantic import alias_generators
15
+
16
+ from iron_sql.sqlc import Catalog
17
+ from iron_sql.sqlc import Column
18
+ from iron_sql.sqlc import Query
19
+ from iron_sql.sqlc import run_sqlc
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @dataclass(kw_only=True, frozen=True)
25
+ class ColumnPySpec:
26
+ name: str
27
+ table: str
28
+ db_type: str
29
+ not_null: bool
30
+ is_array: bool
31
+ py_type: str
32
+
33
+
34
+ def generate_sql_package( # noqa: PLR0914
35
+ *,
36
+ schema_path: Path,
37
+ package_full_name: str,
38
+ dsn_import: str,
39
+ application_name: str | None = None,
40
+ to_pascal_fn=alias_generators.to_pascal,
41
+ debug_path: Path | None = None,
42
+ src_path: Path,
43
+ ) -> bool:
44
+ dsn_import_package, dsn_import_path = dsn_import.split(":")
45
+
46
+ package_name = package_full_name.split(".")[-1] # noqa: PLC0207
47
+ sql_fn_name = f"{package_name}_sql"
48
+
49
+ target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
50
+
51
+ queries = list(find_all_queries(src_path, sql_fn_name))
52
+ queries = list({q.name: q for q in queries}.values())
53
+
54
+ dsn = attrgetter(dsn_import_path)(importlib.import_module(dsn_import_package))
55
+
56
+ sqlc_res = run_sqlc(
57
+ src_path / schema_path,
58
+ [(q.name, q.stmt) for q in queries],
59
+ dsn=dsn,
60
+ debug_path=debug_path,
61
+ )
62
+
63
+ if sqlc_res.error:
64
+ logger.error("Error running SQLC:\n%s", sqlc_res.error)
65
+ return False
66
+
67
+ ordered_entities, result_types = map_entities(
68
+ package_name,
69
+ sqlc_res.queries,
70
+ sqlc_res.catalog,
71
+ sqlc_res.used_schemas(),
72
+ queries,
73
+ to_pascal_fn,
74
+ )
75
+
76
+ entities = [render_entity(e.name, e.column_specs) for e in ordered_entities]
77
+
78
+ query_classes = [
79
+ render_query_class(
80
+ q.name,
81
+ q.text,
82
+ package_name,
83
+ [
84
+ (
85
+ column_py_spec(p.column, sqlc_res.catalog, p.number),
86
+ p.column.is_named_param,
87
+ )
88
+ for p in q.params
89
+ ],
90
+ result_types[q.name],
91
+ len(q.columns),
92
+ )
93
+ for q in sqlc_res.queries
94
+ ]
95
+
96
+ query_overloads = [
97
+ render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
98
+ ]
99
+
100
+ query_cases = [render_query_case(q.name, q.stmt) for q in queries]
101
+
102
+ new_content = render_package(
103
+ dsn_import_package,
104
+ dsn_import_path,
105
+ package_name,
106
+ sql_fn_name,
107
+ sorted(entities),
108
+ sorted(query_classes),
109
+ sorted(query_overloads),
110
+ sorted(query_cases),
111
+ application_name,
112
+ )
113
+ changed = write_if_changed(target_package_path, new_content + "\n")
114
+ if changed:
115
+ logger.info(f"Generated SQL package {package_full_name}")
116
+ return changed
117
+
118
+
119
+ def render_package(
120
+ dsn_import_package: str,
121
+ dsn_import_path: str,
122
+ package_name: str,
123
+ sql_fn_name: str,
124
+ entities: list[str],
125
+ query_classes: list[str],
126
+ query_overloads: list[str],
127
+ query_cases: list[str],
128
+ application_name: str | None = None,
129
+ ):
130
+ return f"""
131
+
132
+ # Code generated by iron_sql, DO NOT EDIT.
133
+
134
+ # fmt: off
135
+ # pyright: reportUnusedImport=false
136
+ # ruff: noqa: A002
137
+ # ruff: noqa: ARG001
138
+ # ruff: noqa: C901
139
+ # ruff: noqa: E303
140
+ # ruff: noqa: E501
141
+ # ruff: noqa: F401
142
+ # ruff: noqa: FBT001
143
+ # ruff: noqa: I001
144
+ # ruff: noqa: N801
145
+ # ruff: noqa: PLR0912
146
+ # ruff: noqa: PLR0913
147
+ # ruff: noqa: PLR0917
148
+ # ruff: noqa: Q000
149
+ # ruff: noqa: RUF100
150
+
151
+ import datetime
152
+ import decimal
153
+ import uuid
154
+ from collections.abc import AsyncIterator
155
+ from collections.abc import Sequence
156
+ from contextlib import asynccontextmanager
157
+ from contextvars import ContextVar
158
+ from dataclasses import dataclass
159
+ from typing import Literal
160
+ from typing import overload
161
+
162
+ import psycopg
163
+ import psycopg.rows
164
+ from psycopg.types import json as pgjson
165
+
166
+ from iron_sql import runtime
167
+
168
+ from {dsn_import_package} import {dsn_import_path.split(".", maxsplit=1)[0]}
169
+
170
+ {package_name.upper()}_POOL = runtime.ConnectionPool(
171
+ {dsn_import_path},
172
+ name="{package_name}",
173
+ application_name="{application_name}",
174
+ )
175
+
176
+ _{package_name}_connection = ContextVar[psycopg.AsyncConnection | None](
177
+ "_{package_name}_connection",
178
+ default=None,
179
+ )
180
+
181
+
182
+ @asynccontextmanager
183
+ async def {package_name}_connection() -> AsyncIterator[psycopg.AsyncConnection]:
184
+ async with {package_name.upper()}_POOL.connection_in_context(
185
+ _{package_name}_connection
186
+ ) as conn:
187
+ yield conn
188
+
189
+
190
+ @asynccontextmanager
191
+ async def {package_name}_transaction() -> AsyncIterator[None]:
192
+ async with {package_name}_connection() as conn, conn.transaction():
193
+ yield
194
+
195
+
196
+ {"\n\n\n".join(entities)}
197
+
198
+
199
+ class Query:
200
+ pass
201
+
202
+
203
+ {"\n\n\n".join(query_classes)}
204
+
205
+
206
+ {"\n".join(query_overloads)}
207
+ @overload
208
+ def {sql_fn_name}(stmt: str) -> Query: ...
209
+
210
+
211
+ def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query:
212
+ {indent_block("\n".join(query_cases), " ")}
213
+ return Query()
214
+
215
+ """.strip()
216
+
217
+
218
def render_entity(
    name: str,
    columns: tuple[ColumnPySpec, ...],
) -> str:
    """Render a @dataclass entity definition with one typed field per column."""
    fields = "\n    ".join(f"{col.name}: {col.py_type}" for col in columns)
    return f"""

@dataclass(kw_only=True)
class {name}:
    {fields}

""".strip()
229
+
230
+
231
def deduplicate_params(
    params: list[tuple[ColumnPySpec, bool]],
) -> list[tuple[ColumnPySpec, bool]]:
    """Make parameter names unique by suffixing repeats with their occurrence count.

    The first occurrence keeps its name; later occurrences of the same name
    become ``name2``, ``name3``, and so on.
    """
    occurrences: defaultdict[str, int] = defaultdict(int)
    deduplicated: list[tuple[ColumnPySpec, bool]] = []
    for spec, is_named in params:
        occurrences[spec.name] += 1
        count = occurrences[spec.name]
        unique_name = f"{spec.name}{count}" if count > 1 else spec.name
        deduplicated.append((dataclasses.replace(spec, name=unique_name), is_named))
    return deduplicated
246
+
247
+
248
def serialized_arg(column: ColumnPySpec) -> str:
    """Render the expression that passes *column* as a psycopg query argument.

    JSONB values are wrapped in ``pgjson.Jsonb`` (guarding None for nullable
    columns). Plain ``json`` NOT NULL columns and ``jsonb[]`` arrays are
    rejected because the generated code cannot adapt them.

    Raises:
        TypeError: for unsupported ``json`` / ``jsonb[]`` column shapes.
    """
    if column.db_type == "json" and column.not_null:
        msg = "Unsupported column type: json"
        raise TypeError(msg)
    if column.db_type == "jsonb":
        if column.is_array:
            msg = "Unsupported column type: jsonb[]"
            raise TypeError(msg)
        if column.not_null:
            return f"pgjson.Jsonb({column.name})"
        return f"pgjson.Jsonb({column.name}) if {column.name} is not None else None"
    return column.name
262
+
263
+
264
def render_query_class(
    query_name: str,
    stmt: str,
    package_name: str,
    query_params: list[tuple[ColumnPySpec, bool]],
    result: str,
    columns_num: int,
) -> str:
    """Render the Query subclass for one SQL statement.

    Emits an ``_execute`` async context manager plus either ``execute()``
    (statement without result columns) or the ``query_*`` fetch helpers.
    """
    query_params = deduplicate_params(query_params)
    specs = [spec for spec, _ in query_params]

    # Expression for cursor.execute(): psycopg expects None or a tuple,
    # so a single parameter needs the trailing comma.
    if not specs:
        params_arg = "None"
    elif len(specs) == 1:
        params_arg = f"({serialized_arg(specs[0])},)"
    else:
        params_arg = f"({', '.join(serialized_arg(spec) for spec in specs)})"

    fn_params = [f"{spec.name}: {spec.py_type}" for spec, _ in query_params]
    # Named (@param) arguments become keyword-only: insert a bare "*" right
    # before the first of them.
    first_named = next(
        (i for i, (_, is_named_param) in enumerate(query_params) if is_named_param), -1
    )
    if first_named >= 0:
        fn_params.insert(first_named, "*")
    fn_params.insert(0, "self")

    base_result = result.removesuffix(" | None")

    if columns_num == 0:
        row_factory = "psycopg.rows.scalar_row"
    elif columns_num == 1:
        # Scalar results go through the runtime's type-checked factory,
        # tracking whether NULL is a legal value.
        if result.endswith(" | None"):
            row_factory = f"runtime.typed_scalar_row({base_result}, not_null=False)"
        else:
            row_factory = f"runtime.typed_scalar_row({base_result}, not_null=True)"
    else:
        row_factory = f"psycopg.rows.class_row({result})"

    if columns_num > 0:
        methods = f"""

async def query_all_rows({", ".join(fn_params)}) -> list[{result}]:
    async with self._execute({params_arg}) as cur:
        return await cur.fetchall()

async def query_single_row({", ".join(fn_params)}) -> {result}:
    async with self._execute({params_arg}) as cur:
        return runtime.get_one_row(await cur.fetchall())

async def query_optional_row({", ".join(fn_params)}) -> {base_result} | None:
    async with self._execute({params_arg}) as cur:
        return runtime.get_one_row_or_none(await cur.fetchall())

""".strip()
    else:
        methods = f"""

async def execute({", ".join(fn_params)}) -> None:
    async with self._execute({params_arg}):
        pass

""".strip()

    return f"""

class {query_name}(Query):
    @asynccontextmanager
    async def _execute(self, params) -> AsyncIterator[psycopg.AsyncRawCursor[{result}]]:
        stmt = {stmt!r}
        async with (
            {package_name}_connection() as conn,
            psycopg.AsyncRawCursor(conn, row_factory={row_factory}) as cur,
        ):
            await cur.execute(stmt, params)
            yield cur

    {indent_block(methods, "    ")}

""".strip()
343
+
344
+
345
+ def render_query_overload(
346
+ sql_fn_name: str, query_name: str, stmt: str, row_type: str | None
347
+ ) -> str:
348
+ result_arg = ""
349
+ if row_type:
350
+ result_arg = f", row_type: Literal[{row_type!r}]"
351
+
352
+ return f"""
353
+
354
+ @overload
355
+ def {sql_fn_name}(stmt: Literal[{stmt!r}]{result_arg}) -> {query_name}: ...
356
+
357
+ """.strip()
358
+
359
+
360
def render_query_case(query_name: str, stmt: str) -> str:
    """Render the dispatch branch mapping a literal statement to its Query class."""
    return f"""

if stmt == {stmt!r}:
    return {query_name}()

""".strip()
367
+
368
+
369
+ @dataclass(kw_only=True)
370
+ class CodeQuery:
371
+ stmt: str
372
+ row_type: str | None
373
+ file: Path
374
+ lineno: int
375
+
376
+ @property
377
+ def name(self) -> str:
378
+ md5_hash = hashlib.md5(self.stmt.encode(), usedforsecurity=False).hexdigest()
379
+ return f"Query_{md5_hash}{'_' + self.row_type if self.row_type else ''}"
380
+
381
+ @property
382
+ def location(self) -> str:
383
+ return f"{self.file}:{self.lineno}"
384
+
385
+
386
@dataclass(kw_only=True)
class SQLEntity:
    """A result dataclass to generate: a table row or an ad-hoc query result."""

    package_name: str
    set_name: str | None  # explicit row_type="..." override, if any
    table_name: str | None  # source table when the result matches a table row
    columns: list[Column]
    catalog: Catalog = dataclasses.field(repr=False)
    to_pascal_fn: Callable[[str], str]

    @property
    def name(self) -> str:
        """Class name: explicit override > PascalCase singular table name > content hash."""
        if self.set_name:
            return self.set_name
        if self.table_name:
            singular = inflection.singularize(self.table_name)
            return self.to_pascal_fn(f"{self.package_name}_{singular}")
        digest = hashlib.md5(
            repr(self.column_specs).encode(), usedforsecurity=False
        ).hexdigest()
        return f"QueryResult_{digest}"

    @property
    def column_specs(self) -> tuple[ColumnPySpec, ...]:
        """Columns mapped to their Python-side specs (used for dedup and rendering)."""
        return tuple(column_py_spec(col, self.catalog) for col in self.columns)
410
+
411
+
412
def map_entities(
    package_name: str,
    queries_from_sqlc: list[Query],
    catalog: Catalog,
    used_schemas: list[str],
    queries_from_code: list[CodeQuery],
    to_pascal_fn: Callable[[str], str],
):
    """Map sqlc query results onto entity dataclasses and per-query result types.

    A query result whose column layout exactly matches a table reuses that
    table's entity; everything else gets an ad-hoc entity named either by its
    row_type override or a content hash.

    Returns:
        (ordered_entities, result_types): entities to render, and a mapping
        from query name to its rendered Python result type.

    Raises:
        ValueError: when a query declares row_type but yields zero or one column.
    """
    row_types = {q.name: q.row_type for q in queries_from_code}

    # One candidate entity per table in every schema the queries touch.
    table_entities = [
        SQLEntity(
            package_name=package_name,
            set_name=None,
            table_name=table.rel.name,
            columns=table.columns,
            catalog=catalog,
            to_pascal_fn=to_pascal_fn,
        )
        for schema_name in used_schemas
        for table in catalog.schema_by_name(schema_name).tables
    ]
    specs_to_entities = {entity.column_specs: entity for entity in table_entities}

    # row_type only makes sense for multi-column results.
    for q in queries_from_sqlc:
        if row_types[q.name] and not q.columns:
            msg = f"Query has row_type={row_types[q.name]} but no result"
            raise ValueError(msg)
        if row_types[q.name] and len(q.columns) == 1:
            msg = f"Query has row_type={row_types[q.name]} but only one column"
            raise ValueError(msg)

    query_result_entities = {
        q.name: SQLEntity(
            package_name=package_name,
            set_name=row_types[q.name],
            table_name=None,
            columns=q.columns,
            catalog=catalog,
            to_pascal_fn=to_pascal_fn,
        )
        for q in queries_from_sqlc
        if len(q.columns) > 1
    }

    # Collapse ad-hoc entities onto table entities with an identical layout.
    unique_entities = {
        entity.column_specs: specs_to_entities.get(entity.column_specs, entity)
        for entity in query_result_entities.values()
    }
    # Table-backed entities first (alphabetical), then the hash-named ones.
    ordered_entities = sorted(
        unique_entities.values(),
        key=lambda entity: (entity.table_name is None, entity.table_name or ""),
    )

    result_types = {}
    for q in queries_from_sqlc:
        if len(q.columns) == 0:
            result_types[q.name] = "None"
        elif len(q.columns) == 1:
            result_types[q.name] = column_py_spec(q.columns[0], catalog).py_type
        else:
            specs = query_result_entities[q.name].column_specs
            result_types[q.name] = unique_entities[specs].name

    return ordered_entities, result_types
477
+
478
+
479
def column_py_spec(
    column: Column, catalog: Catalog, number: int = 0
) -> ColumnPySpec:
    """Map a sqlc column (or parameter) to its Python-side spec.

    *number* names otherwise-anonymous positional parameters (``param_<n>``).
    Unknown database types are logged and fall back to ``object``.
    """
    db_type = column.type.name.removeprefix("pg_catalog.")

    scalar_types = {
        "bool": "bool",
        "boolean": "bool",
        "int2": "int",
        "int4": "int",
        "int8": "int",
        "smallint": "int",
        "integer": "int",
        "bigint": "int",
        "float4": "float",
        "float8": "float",
        "numeric": "decimal.Decimal",
        "varchar": "str",
        "text": "str",
        "bytea": "bytes",
        "json": "object",
        "jsonb": "object",
        "date": "datetime.date",
        "time": "datetime.time",
        "timetz": "datetime.time",
        "timestamp": "datetime.datetime",
        "timestamptz": "datetime.datetime",
        "uuid": "uuid.UUID",
        "any": "object",
        "anyelement": "object",
    }

    if db_type in scalar_types:
        py_type = scalar_types[db_type]
    elif catalog.schema_by_ref(column.table).has_enum(db_type):
        # Postgres enums are surfaced as plain strings.
        py_type = "str"
    else:
        logger.warning(f"Unknown SQL type: {column.type.name} ({column.name})")
        py_type = "object"

    if column.is_array:
        py_type = f"Sequence[{py_type}]"

    if not column.not_null:
        py_type += " | None"

    return ColumnPySpec(
        name=column.name or f"param_{number}",
        table=column.table.name if column.table else "unknown",
        db_type=db_type,
        not_null=column.not_null,
        is_array=column.is_array,
        py_type=py_type,
    )
528
+
529
+
530
def find_fn_calls(
    root_path: Path, fn_name: str
) -> Iterator[tuple[Path, int, ast.Call]]:
    """Yield (path, lineno, call node) for every bare-name call to *fn_name*.

    Files that don't even mention the name textually are skipped before
    parsing, keeping the scan cheap on large trees.
    """
    for path in root_path.glob("**/*.py"):
        source = path.read_text(encoding="utf-8")
        if fn_name not in source:
            continue  # cheap pre-filter before the AST walk
        for node in ast.walk(ast.parse(source, filename=str(path))):
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id == fn_name
            ):
                yield path, node.lineno, node
543
+
544
+
545
def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
    """Yield a CodeQuery for every *sql_fn_name* call site under *src_path*.

    Each call must pass the SQL statement as a single string literal (so the
    generator can recover it statically) and may pass row_type="..." as a
    string-literal keyword argument.

    Raises:
        TypeError: when a call site's arguments are not string literals.
    """
    for file, lineno, node in find_fn_calls(src_path, sql_fn_name):
        relative_path = file.relative_to(src_path)

        # Check the arity BEFORE indexing node.args: the original indexed
        # node.args[0] first, so a zero-argument call raised IndexError
        # instead of the explanatory TypeError below.
        if (
            len(node.args) != 1
            or not isinstance(node.args[0], ast.Constant)
            or not isinstance(node.args[0].value, str)
        ):
            msg = (
                f"Invalid positional arguments for {sql_fn_name} "
                f"at {relative_path}:{lineno}, "
                "expected a single string literal"
            )
            raise TypeError(msg)

        stmt = node.args[0].value

        row_type = None
        for kw in node.keywords:
            if not isinstance(kw.value, ast.Constant) or not isinstance(
                kw.value.value, str
            ):
                msg = (
                    f"Invalid keyword argument {kw.arg} for {sql_fn_name} "
                    f"at {relative_path}:{lineno}, expected a string literal"
                )
                raise TypeError(msg)
            if kw.arg == "row_type":
                row_type = kw.value.value
                break

        yield CodeQuery(
            stmt=stmt,
            row_type=row_type,
            file=relative_path,
            lineno=lineno,
        )
584
+
585
+
586
def indent_block(block: str, indent: str) -> str:
    """Prefix every line after the first with *indent*, leaving blank lines alone."""
    lines = block.split("\n")
    out = [lines[0]]
    for line in lines[1:]:
        out.append(indent + line if line.strip() else line)
    return "\n".join(out)
591
+
592
+
593
def write_if_changed(path: Path, new_content: str) -> bool:
    """Write *new_content* to *path* only when it differs from what's on disk.

    Parent directories are created as needed. Leaving an unchanged file
    untouched preserves its mtime, so build tools don't see spurious edits.

    Returns:
        True when the file was (re)written, False when already up to date.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    existing_content = path.read_text(encoding="utf-8") if path.exists() else None
    if existing_content == new_content:
        return False
    # write_text creates the file and updates its mtime itself; the extra
    # path.touch() the original performed afterwards was redundant.
    path.write_text(new_content, encoding="utf-8")
    return True
@@ -0,0 +1,2 @@
1
def test_nothing():
    """Placeholder so the pytest suite always has at least one passing test."""
@@ -0,0 +1,132 @@
1
+ from collections.abc import AsyncIterator
2
+ from collections.abc import Sequence
3
+ from contextlib import asynccontextmanager
4
+ from contextvars import ContextVar
5
+ from typing import Any
6
+ from typing import Literal
7
+ from typing import Self
8
+ from typing import overload
9
+
10
+ import psycopg
11
+ import psycopg.rows
12
+ import psycopg_pool
13
+
14
+
15
class NoRowsError(Exception):
    """Raised when a query expected exactly one row but returned none."""
17
+
18
+
19
class TooManyRowsError(Exception):
    """Raised when a query expected at most one row but returned several."""
21
+
22
+
23
class ConnectionPool:
    """Lazily-opened async psycopg connection pool with context-var scoping.

    The underlying psycopg_pool.AsyncConnectionPool is created closed and
    opened on first use; close() rebuilds a fresh (still closed) pool so the
    object can be reused after shutdown.
    """

    def __init__(
        self,
        conninfo: str,
        *,
        name: str | None = None,
        application_name: str | None = None,
    ) -> None:
        self.conninfo = conninfo
        self.name = name
        self.application_name = application_name
        self._init_psycopg_pool()

    async def close(self) -> None:
        """Close the pool, then build a fresh closed replacement."""
        await self.psycopg_pool.close()
        self._init_psycopg_pool()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    async def await_connections(self) -> None:
        """Open the pool and block until its minimum connections are ready."""
        await self.psycopg_pool.open(wait=True)

    async def check(self) -> None:
        """Open the pool and verify its connections are alive."""
        await self.psycopg_pool.open()
        await self.psycopg_pool.check()

    @asynccontextmanager
    async def connection(self) -> AsyncIterator[psycopg.AsyncConnection]:
        """Borrow a pooled connection, opening the pool on first use."""
        await self.psycopg_pool.open()
        async with self.psycopg_pool.connection() as conn:
            yield conn

    def _init_psycopg_pool(self) -> None:
        # Created closed; opened lazily by whichever caller needs it first.
        self.psycopg_pool = psycopg_pool.AsyncConnectionPool(
            self.conninfo,
            open=False,
            name=self.name,
            kwargs={
                "application_name": self.application_name,
                # https://www.psycopg.org/psycopg3/docs/basic/transactions.html#autocommit-transactions
                "autocommit": True,
            },
        )

    @asynccontextmanager
    async def connection_in_context(
        self, context_var: ContextVar[psycopg.AsyncConnection | None]
    ) -> AsyncIterator[psycopg.AsyncConnection]:
        """Reuse the connection stored in *context_var*, or borrow and publish one.

        A borrowed connection is published through the context var for the
        duration of the block so nested calls share the same connection
        (and therefore the same transaction).
        """
        existing = context_var.get()
        if existing is not None:
            yield existing
            return
        async with self.connection() as conn:
            token = context_var.set(conn)
            try:
                yield conn
            finally:
                context_var.reset(token)
85
+
86
+
87
+ def get_one_row[T](rows: list[T]) -> T:
88
+ if len(rows) == 0:
89
+ raise NoRowsError
90
+ if len(rows) > 1:
91
+ raise TooManyRowsError
92
+ return rows[0]
93
+
94
+
95
+ def get_one_row_or_none[T](rows: list[T]) -> T | None:
96
+ if len(rows) == 0:
97
+ return None
98
+ if len(rows) > 1:
99
+ raise TooManyRowsError
100
+ return rows[0]
101
+
102
+
103
+ @overload
104
+ def typed_scalar_row[T](
105
+ typ: type[T], *, not_null: Literal[True]
106
+ ) -> psycopg.rows.BaseRowFactory[T]: ...
107
+
108
+
109
+ @overload
110
+ def typed_scalar_row[T](
111
+ typ: type[T], *, not_null: Literal[False]
112
+ ) -> psycopg.rows.BaseRowFactory[T | None]: ...
113
+
114
+
115
+ def typed_scalar_row[T](
116
+ typ: type[T], *, not_null: bool
117
+ ) -> psycopg.rows.BaseRowFactory[T | None]:
118
+ def typed_scalar_row_(cursor) -> psycopg.rows.RowMaker[T | None]:
119
+ scalar_row_ = psycopg.rows.scalar_row(cursor)
120
+
121
+ def typed_scalar_row__(values: Sequence[Any]) -> T | None:
122
+ val = scalar_row_(values)
123
+ if not not_null and val is None:
124
+ return None
125
+ if not isinstance(val, typ):
126
+ msg = f"Expected scalar of type {typ}, got {type(val)}"
127
+ raise TypeError(msg)
128
+ return val
129
+
130
+ return typed_scalar_row__
131
+
132
+ return typed_scalar_row_
@@ -0,0 +1,195 @@
1
+ import json
2
+ import re
3
+ import shutil
4
+ import subprocess # noqa: S404
5
+ import tempfile
6
+ import textwrap
7
+ from pathlib import Path
8
+
9
+ import pydantic
10
+
11
# Template for wrapping a raw statement as an sqlc ":exec" query.
# NOTE(review): run_sqlc builds the "-- name: ... :exec" header inline with
# an f-string and does not reference this template — presumably kept for
# reference; confirm before removing.
SQLC_QUERY_TPL = """
-- name: ${name} :exec
${stmt};
"""
15
+
16
+
17
class CatalogReference(pydantic.BaseModel):
    """A (catalog, schema, name) reference into the sqlc catalog.

    Deserialized from sqlc's JSON output; the JSON key ``schema`` is mapped
    to ``schema_name`` because ``schema`` collides with pydantic's
    ``BaseModel.schema`` attribute.
    """

    catalog: str
    schema_name: str = pydantic.Field(..., alias="schema")
    name: str
22
+
23
class Column(pydantic.BaseModel):
    """A column description deserialized verbatim from sqlc's JSON output.

    Field names appear to mirror sqlc's own column message — TODO confirm
    against the sqlc JSON codegen schema before relying on individual
    fields' semantics.
    """

    name: str
    not_null: bool
    is_array: bool
    comment: str
    length: int
    is_named_param: bool
    is_func_call: bool
    scope: str
    # Owning table reference; None when the column is not attributed to a
    # table (see SQLCResult.used_schemas, which skips such columns).
    table: CatalogReference | None
    table_alias: str
    # The column's data type reference.
    type: CatalogReference
    is_sqlc_slice: bool
    embed_table: None
    original_name: str
    unsigned: bool
    array_dims: int
40
+
41
+
42
class Table(pydantic.BaseModel):
    """A table from the sqlc catalog: its reference and its columns."""

    # Fully qualified reference (catalog/schema/name) to the relation.
    rel: CatalogReference
    columns: list[Column]
    comment: str
46
+
47
+
48
class Enum(pydantic.BaseModel):
    """A PostgreSQL enum type from the sqlc catalog.

    NOTE: the name shadows ``enum.Enum``; it is kept to match sqlc's JSON
    object name.
    """

    name: str
    # Enum labels in declaration order.
    vals: list[str]
    comment: str
52
+
53
+
54
class CompositeType(pydantic.BaseModel):
    """A PostgreSQL composite type from the sqlc catalog (name only)."""

    name: str
    comment: str
57
+
58
+
59
class Schema(pydantic.BaseModel):
    """A database schema from the sqlc catalog."""

    comment: str
    name: str
    tables: list[Table]
    enums: list[Enum]
    composite_types: list[CompositeType]

    def has_enum(self, name: str) -> bool:
        """Report whether an enum called *name* is declared in this schema."""
        for enum in self.enums:
            if enum.name == name:
                return True
        return False
68
+
69
+
70
class Catalog(pydantic.BaseModel):
    """The full sqlc catalog: all schemas plus the default schema name."""

    default_schema: str
    name: str
    schemas: list[Schema]

    def schema_by_name(self, name: str) -> Schema:
        """Return the schema called *name*.

        Raises:
            ValueError: if no schema with that name exists.
        """
        found = next((s for s in self.schemas if s.name == name), None)
        if found is None:
            msg = f"Schema not found: {name}"
            raise ValueError(msg)
        return found

    def schema_by_ref(self, ref: CatalogReference | None) -> Schema:
        """Resolve *ref* to a schema, defaulting to the catalog's default.

        A None ref, or a ref with an empty schema name, resolves to
        ``default_schema``.
        """
        target = ref.schema_name if ref and ref.schema_name else self.default_schema
        return self.schema_by_name(target)
87
+
88
+
89
class QueryParameter(pydantic.BaseModel):
    """A positional query parameter and the column metadata sqlc inferred."""

    # 1-based parameter position within the query.
    number: int
    column: Column
92
+
93
+
94
class Query(pydantic.BaseModel):
    """A single analyzed query from sqlc's output."""

    # The (preprocessed) SQL text of the query.
    text: str
    name: str
    # sqlc command annotation; run_sqlc always emits queries as ":exec".
    cmd: str
    # Result columns inferred by sqlc.
    columns: list[Column]
    params: list[QueryParameter]
100
+
101
+
102
class SQLCResult(pydantic.BaseModel):
    """Parsed result of an sqlc generate run."""

    # Holds sqlc's stderr text when generation failed; None on success.
    error: str | None = None
    catalog: Catalog
    queries: list[Query]

    def used_schemas(self) -> list[str]:
        """List the schema names referenced by any query result column.

        Columns with no table attribution are skipped; columns whose table
        has an empty schema name count toward the catalog's default schema.
        """
        names: set[str] = set()
        for query in self.queries:
            for column in query.columns:
                if column.table is not None:
                    names.add(column.table.schema_name)
        if "" in names:
            # An empty schema name means "the default schema".
            names.discard("")
            names.add(self.catalog.default_schema)
        return list(names)
118
+
119
+
120
def run_sqlc(
    schema_path: Path,
    queries: list[tuple[str, str]],
    *,
    dsn: str | None,
    debug_path: Path | None = None,
) -> SQLCResult:
    """Analyze *queries* against *schema_path* by invoking ``sqlc generate``.

    Writes the schema, the queries (each wrapped as an ``:exec`` query), and
    an sqlc config into a temporary directory, runs sqlc's JSON codegen, and
    parses the resulting ``out.json``.

    Args:
        schema_path: Path to the SQL schema file.
        queries: ``(name, statement)`` pairs; duplicate names are
            de-duplicated keeping the last occurrence.
        dsn: Optional database URI for sqlc's managed-database analysis.
        debug_path: When set, input/output artifacts are copied there for
            inspection.

    Returns:
        The parsed SQLCResult. When sqlc fails to produce output, a result
        with ``error`` set to sqlc's stderr (and an empty catalog) is
        returned instead of raising.

    Raises:
        ValueError: if *schema_path* does not exist.
    """
    if not schema_path.exists():
        msg = f"Schema file not found: {schema_path}"
        raise ValueError(msg)

    if not queries:
        # Nothing to analyze: avoid spawning sqlc at all.
        return SQLCResult(
            catalog=Catalog(default_schema="", name="", schemas=[]),
            queries=[],
        )

    # De-duplicate by query name, keeping the last occurrence of each.
    queries = list({q[0]: q for q in queries}.values())

    with tempfile.TemporaryDirectory() as tempdir:
        queries_path = Path(tempdir) / "queries.sql"
        queries_path.write_text(
            "\n\n".join(
                f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
                for name, stmt in queries
            ),
            encoding="utf-8",
        )

        # Symlink rather than copy so sqlc reads the schema file in place.
        (Path(tempdir) / "schema.sql").symlink_to(schema_path.absolute())

        config_path = Path(tempdir) / "sqlc.json"
        sql_entry: dict = {
            "schema": "schema.sql",
            "queries": ["queries.sql"],
            "engine": "postgresql",
            "gen": {"json": {"out": ".", "filename": "out.json"}},
        }
        if dsn:
            # Only emit the "database" entry when a DSN was given, instead
            # of writing an explicit null into the config.
            sql_entry["database"] = {"uri": dsn}
        sqlc_config = {"version": "2", "sql": [sql_entry]}
        config_path.write_text(json.dumps(sqlc_config, indent=2), encoding="utf-8")

        # Resolve sqlc from PATH; fall back to the previously hard-coded
        # location so existing deployments keep working.
        sqlc_bin = shutil.which("sqlc") or "/usr/local/bin/sqlc"
        sqlc_run_result = subprocess.run(  # noqa: S603
            [sqlc_bin, "generate", "--file", str(config_path)],
            capture_output=True,
            check=False,
        )

        json_out_path = Path(tempdir) / "out.json"

        if debug_path:
            debug_path.absolute().mkdir(parents=True, exist_ok=True)
            shutil.copy(queries_path, debug_path)
            shutil.copy(schema_path, debug_path / "schema.sql")
            shutil.copy(config_path, debug_path)
            if json_out_path.exists():
                shutil.copy(json_out_path, debug_path)
            elif (debug_path / "out.json").exists():
                # Remove stale output from a previous run so the debug dir
                # reflects this (failed) run.
                (debug_path / "out.json").unlink()

        if not json_out_path.exists():
            # sqlc failed: surface its stderr via the result's error field.
            return SQLCResult(
                error=sqlc_run_result.stderr.decode().strip(),
                catalog=Catalog(default_schema="", name="", schemas=[]),
                queries=[],
            )
        return SQLCResult.model_validate_json(json_out_path.read_text(encoding="utf-8"))
191
+
192
+
193
def preprocess_sql(stmt: str) -> str:
    """Normalize a raw SQL statement for sqlc consumption.

    Rewrites ``@name?`` placeholders into ``sqlc.narg('name')`` (nullable
    named parameters), then strips common leading indentation and any
    surrounding whitespace.
    """
    nullable_param = re.compile(r"@(\w+)\?")
    rewritten = nullable_param.sub(r"sqlc.narg('\1')", stmt)
    return textwrap.dedent(rewritten).strip()