iron-sql 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- iron_sql-0.2.5/LICENSE +21 -0
- iron_sql-0.2.5/PKG-INFO +64 -0
- iron_sql-0.2.5/README.md +40 -0
- iron_sql-0.2.5/pyproject.toml +129 -0
- iron_sql-0.2.5/src/iron_sql/__init__.py +7 -0
- iron_sql-0.2.5/src/iron_sql/generator.py +726 -0
- iron_sql-0.2.5/src/iron_sql/runtime.py +135 -0
- iron_sql-0.2.5/src/iron_sql/sqlc.py +222 -0
iron_sql-0.2.5/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Ilia Ablamonov
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
iron_sql-0.2.5/PKG-INFO
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: iron-sql
|
|
3
|
+
Version: 0.2.5
|
|
4
|
+
Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
|
|
5
|
+
Keywords: postgresql,sql,sqlc,psycopg,codegen,async
|
|
6
|
+
Author: Ilia Ablamonov
|
|
7
|
+
Author-email: Ilia Ablamonov <ilia@flamefork.ru>
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Requires-Dist: inflection>=0.5.1
|
|
16
|
+
Requires-Dist: psycopg>=3.3.2
|
|
17
|
+
Requires-Dist: psycopg-pool>=3.3.0
|
|
18
|
+
Requires-Dist: pydantic>=2.12.4
|
|
19
|
+
Requires-Python: >=3.13
|
|
20
|
+
Project-URL: Homepage, https://github.com/Flamefork/iron_sql
|
|
21
|
+
Project-URL: Repository, https://github.com/Flamefork/iron_sql.git
|
|
22
|
+
Project-URL: Issues, https://github.com/Flamefork/iron_sql/issues
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
# iron_sql
|
|
26
|
+
|
|
27
|
+
iron_sql keeps SQL close to Python call sites while giving you typed, async query helpers. You write SQL once, keep it in version control, and get generated clients that match your schema without hand-written boilerplate.
|
|
28
|
+
|
|
29
|
+
## Why use it
|
|
30
|
+
- SQL-first workflow: write queries where they are used; no ORM layer to fight.
|
|
31
|
+
- Strong typing: generated dataclasses and method signatures flow through your IDE and type checker.
|
|
32
|
+
- Async-ready: built on `psycopg` with pooled connections and transaction helpers.
|
|
33
|
+
- Safe-by-default: helper methods enforce expected row counts instead of returning silent `None`.
|
|
34
|
+
|
|
35
|
+
## Quick start
|
|
36
|
+
1. Install `iron_sql`, `psycopg`, `psycopg-pool`, and `pydantic`.
|
|
37
|
+
2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure it is available in your PATH.
|
|
38
|
+
3. Add a Postgres schema dump, for example `db/mydatabase_schema.sql`.
|
|
39
|
+
4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code (defaults to current directory), runs `sqlc`, and writes a module such as `myapp/db/mydatabase.py`.
|
|
40
|
+
|
|
41
|
+
## Authoring queries
|
|
42
|
+
- Use the package helper for your DB, e.g. `mydatabase_sql("select ...")`. The SQL string must be a literal so the generator can find it.
|
|
43
|
+
- Named parameters:
|
|
44
|
+
- Required: `@param`
|
|
45
|
+
- Optional: `@param?` (expands to `sqlc.narg('param')`)
|
|
46
|
+
- Positional placeholders (`$1`) stay as-is.
|
|
47
|
+
- Multi-column results can opt into a custom dataclass with `row_type="MyResult"`. Single-column queries return a scalar type; statements without results expose `execute()`.
|
|
48
|
+
|
|
49
|
+
## Using generated clients
|
|
50
|
+
- `*_sql("...")` returns a query object with methods derived from the result shape:
|
|
51
|
+
- `execute()` when no rows are returned.
|
|
52
|
+
- `query_all_rows()`, `query_single_row()`, `query_optional_row()` for result sets.
|
|
53
|
+
- `*_connection()` yields a pooled `psycopg.AsyncConnection`; `*_transaction()` wraps it in a transaction context.
|
|
54
|
+
- JSONB params are sent with `pgjson.Jsonb`; scalar row factories validate types and raise when they do not match.
|
|
55
|
+
|
|
56
|
+
## Adding another database package
|
|
57
|
+
Provide the schema file and DSN import string, then call `generate_sql_package()` with:
|
|
58
|
+
- `schema_path`: path to the schema SQL file (relative to `src_path`).
|
|
59
|
+
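- `package_full_name`: full dotted name of the target module, e.g. `myapp.db.mydatabase` (writes `myapp/db/mydatabase.py` and exposes the `mydatabase_sql` helper).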
|
|
60
|
+
- `dsn_import`: a `module:expression` string that yields the DSN, e.g. `myapp.config:CONFIG.db_url.get_value()`; the module is imported and the expression is evaluated in its namespace.
|
|
61
|
+
- `src_path`: optional base source path for scanning queries (defaults to the current directory).
|
|
62
|
+
- `sqlc_path`: optional path to the sqlc binary if not in PATH (e.g., `Path("/custom/bin/sqlc")`).
|
|
63
|
+
- `tempdir_path`: optional path for temporary file generation (useful for Docker mounts).
|
|
64
|
+
- Optional `application_name`, `debug_path`, `to_pascal_fn`, `to_snake_fn`, and `sqlc_command` if you need naming overrides, want to keep `sqlc` inputs for inspection, or need a custom command prefix to run `sqlc`.
|
iron_sql-0.2.5/README.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# iron_sql
|
|
2
|
+
|
|
3
|
+
iron_sql keeps SQL close to Python call sites while giving you typed, async query helpers. You write SQL once, keep it in version control, and get generated clients that match your schema without hand-written boilerplate.
|
|
4
|
+
|
|
5
|
+
## Why use it
|
|
6
|
+
- SQL-first workflow: write queries where they are used; no ORM layer to fight.
|
|
7
|
+
- Strong typing: generated dataclasses and method signatures flow through your IDE and type checker.
|
|
8
|
+
- Async-ready: built on `psycopg` with pooled connections and transaction helpers.
|
|
9
|
+
- Safe-by-default: helper methods enforce expected row counts instead of returning silent `None`.
|
|
10
|
+
|
|
11
|
+
## Quick start
|
|
12
|
+
1. Install `iron_sql`, `psycopg`, `psycopg-pool`, and `pydantic`.
|
|
13
|
+
2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure it is available in your PATH.
|
|
14
|
+
3. Add a Postgres schema dump, for example `db/mydatabase_schema.sql`.
|
|
15
|
+
4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code (defaults to current directory), runs `sqlc`, and writes a module such as `myapp/db/mydatabase.py`.
|
|
16
|
+
|
|
17
|
+
## Authoring queries
|
|
18
|
+
- Use the package helper for your DB, e.g. `mydatabase_sql("select ...")`. The SQL string must be a literal so the generator can find it.
|
|
19
|
+
- Named parameters:
|
|
20
|
+
- Required: `@param`
|
|
21
|
+
- Optional: `@param?` (expands to `sqlc.narg('param')`)
|
|
22
|
+
- Positional placeholders (`$1`) stay as-is.
|
|
23
|
+
- Multi-column results can opt into a custom dataclass with `row_type="MyResult"`. Single-column queries return a scalar type; statements without results expose `execute()`.
|
|
24
|
+
|
|
25
|
+
## Using generated clients
|
|
26
|
+
- `*_sql("...")` returns a query object with methods derived from the result shape:
|
|
27
|
+
- `execute()` when no rows are returned.
|
|
28
|
+
- `query_all_rows()`, `query_single_row()`, `query_optional_row()` for result sets.
|
|
29
|
+
- `*_connection()` yields a pooled `psycopg.AsyncConnection`; `*_transaction()` wraps it in a transaction context.
|
|
30
|
+
- JSONB params are sent with `pgjson.Jsonb`; scalar row factories validate types and raise when they do not match.
|
|
31
|
+
|
|
32
|
+
## Adding another database package
|
|
33
|
+
Provide the schema file and DSN import string, then call `generate_sql_package()` with:
|
|
34
|
+
- `schema_path`: path to the schema SQL file (relative to `src_path`).
|
|
35
|
+
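- `package_full_name`: full dotted name of the target module, e.g. `myapp.db.mydatabase` (writes `myapp/db/mydatabase.py` and exposes the `mydatabase_sql` helper).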
|
|
36
|
+
- `dsn_import`: a `module:expression` string that yields the DSN, e.g. `myapp.config:CONFIG.db_url.get_value()`; the module is imported and the expression is evaluated in its namespace.
|
|
37
|
+
- `src_path`: optional base source path for scanning queries (defaults to the current directory).
|
|
38
|
+
- `sqlc_path`: optional path to the sqlc binary if not in PATH (e.g., `Path("/custom/bin/sqlc")`).
|
|
39
|
+
- `tempdir_path`: optional path for temporary file generation (useful for Docker mounts).
|
|
40
|
+
- Optional `application_name`, `debug_path`, `to_pascal_fn`, `to_snake_fn`, and `sqlc_command` if you need naming overrides, want to keep `sqlc` inputs for inspection, or need a custom command prefix to run `sqlc`.
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "iron-sql"
|
|
3
|
+
version = "0.2.5"
|
|
4
|
+
|
|
5
|
+
description = "iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries"
|
|
6
|
+
readme = "README.md"
|
|
7
|
+
authors = [{ name = "Ilia Ablamonov", email = "ilia@flamefork.ru" }]
|
|
8
|
+
license = "MIT"
|
|
9
|
+
license-files = ["LICENSE"]
|
|
10
|
+
keywords = ["postgresql", "sql", "sqlc", "psycopg", "codegen", "async"]
|
|
11
|
+
classifiers = [
|
|
12
|
+
"Development Status :: 3 - Alpha",
|
|
13
|
+
"Intended Audience :: Developers",
|
|
14
|
+
"Topic :: Software Development :: Libraries",
|
|
15
|
+
"Programming Language :: Python :: 3",
|
|
16
|
+
"Programming Language :: Python :: 3.13",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
requires-python = ">=3.13"
|
|
20
|
+
dependencies = [
|
|
21
|
+
"inflection>=0.5.1",
|
|
22
|
+
"psycopg>=3.3.2",
|
|
23
|
+
"psycopg-pool>=3.3.0",
|
|
24
|
+
"pydantic>=2.12.4",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[project.urls]
|
|
28
|
+
Homepage = "https://github.com/Flamefork/iron_sql"
|
|
29
|
+
Repository = "https://github.com/Flamefork/iron_sql.git"
|
|
30
|
+
Issues = "https://github.com/Flamefork/iron_sql/issues"
|
|
31
|
+
|
|
32
|
+
[build-system]
|
|
33
|
+
requires = ["uv_build>=0.9.4,<0.10.0"]
|
|
34
|
+
build-backend = "uv_build"
|
|
35
|
+
|
|
36
|
+
[dependency-groups]
|
|
37
|
+
dev = [
|
|
38
|
+
"basedpyright>=1.31.7",
|
|
39
|
+
"psycopg[binary]>=3.3.2",
|
|
40
|
+
"pytest>=8.4.2",
|
|
41
|
+
"pytest-asyncio>=1.2.0",
|
|
42
|
+
"pytest-cov>=7.0.0",
|
|
43
|
+
"pytest-randomly>=4.0.1",
|
|
44
|
+
"ruff>=0.14.1",
|
|
45
|
+
"testcontainers>=4",
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
[tool.pyright]
|
|
49
|
+
typeCheckingMode = "strict"
|
|
50
|
+
reportUnknownArgumentType = "none"
|
|
51
|
+
reportUnknownLambdaType = "none"
|
|
52
|
+
reportUnknownMemberType = "none"
|
|
53
|
+
reportUnknownParameterType = "none"
|
|
54
|
+
reportUnknownVariableType = "none"
|
|
55
|
+
reportMissingParameterType = "none"
|
|
56
|
+
reportMissingTypeArgument = "none"
|
|
57
|
+
reportMissingTypeStubs = "none"
|
|
58
|
+
deprecateTypingAliases = true
|
|
59
|
+
reportImportCycles = true
|
|
60
|
+
reportUnnecessaryTypeIgnoreComment = true
|
|
61
|
+
reportUnreachable = true
|
|
62
|
+
reportIgnoreCommentWithoutRule = true
|
|
63
|
+
reportImplicitRelativeImport = true
|
|
64
|
+
|
|
65
|
+
[tool.ruff]
|
|
66
|
+
target-version = "py313"
|
|
67
|
+
|
|
68
|
+
[tool.ruff.format]
|
|
69
|
+
preview = true
|
|
70
|
+
|
|
71
|
+
[tool.ruff.lint]
|
|
72
|
+
preview = true
|
|
73
|
+
select = ["ALL"]
|
|
74
|
+
ignore = [
|
|
75
|
+
"ANN",
|
|
76
|
+
"COM812",
|
|
77
|
+
"CPY",
|
|
78
|
+
"D",
|
|
79
|
+
"FIX",
|
|
80
|
+
"G004",
|
|
81
|
+
"ISC001",
|
|
82
|
+
"PLC1901",
|
|
83
|
+
"PLR0911",
|
|
84
|
+
"PLR0915",
|
|
85
|
+
"PLR6301",
|
|
86
|
+
"RUF001",
|
|
87
|
+
"RUF002",
|
|
88
|
+
"RUF003",
|
|
89
|
+
"TC006",
|
|
90
|
+
"TD",
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
[tool.ruff.lint.per-file-ignores]
|
|
94
|
+
"test_*.py" = ["A002", "PLR2004", "S", "FBT"]
|
|
95
|
+
|
|
96
|
+
[tool.ruff.lint.isort]
|
|
97
|
+
force-single-line = true
|
|
98
|
+
|
|
99
|
+
[tool.ruff.lint.pylint]
|
|
100
|
+
max-args = 10
|
|
101
|
+
|
|
102
|
+
[tool.ruff.lint.flake8-tidy-imports]
|
|
103
|
+
ban-relative-imports = "all"
|
|
104
|
+
|
|
105
|
+
[tool.pytest.ini_options]
|
|
106
|
+
strict = true
|
|
107
|
+
testpaths = ["tests"]
|
|
108
|
+
filterwarnings = [
|
|
109
|
+
"error",
|
|
110
|
+
"ignore:.*wait_container_is_ready:DeprecationWarning",
|
|
111
|
+
]
|
|
112
|
+
addopts = [
|
|
113
|
+
"--import-mode=importlib",
|
|
114
|
+
"--no-cov-on-fail",
|
|
115
|
+
"--cov-report=term-missing:skip-covered",
|
|
116
|
+
]
|
|
117
|
+
asyncio_mode = "auto"
|
|
118
|
+
asyncio_default_fixture_loop_scope = "function"
|
|
119
|
+
|
|
120
|
+
[tool.coverage.run]
|
|
121
|
+
branch = true
|
|
122
|
+
omit = ["tests/*.py", "testdb.py"]
|
|
123
|
+
data_file = ".coverage/db.sqlite"
|
|
124
|
+
|
|
125
|
+
[tool.coverage.html]
|
|
126
|
+
directory = ".coverage/htmlcov"
|
|
127
|
+
|
|
128
|
+
[tool.coverage.report]
|
|
129
|
+
exclude_also = ["@overload"]
|
|
@@ -0,0 +1,726 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import dataclasses
|
|
3
|
+
import hashlib
|
|
4
|
+
import importlib
|
|
5
|
+
import logging
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from collections.abc import Iterator
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import inflection
|
|
13
|
+
from pydantic import alias_generators
|
|
14
|
+
|
|
15
|
+
from iron_sql.sqlc import Catalog
|
|
16
|
+
from iron_sql.sqlc import Column
|
|
17
|
+
from iron_sql.sqlc import Enum
|
|
18
|
+
from iron_sql.sqlc import Query
|
|
19
|
+
from iron_sql.sqlc import SQLCResult
|
|
20
|
+
from iron_sql.sqlc import run_sqlc
|
|
21
|
+
|
|
22
|
+
# Module-level logger: generate_sql_package reports sqlc failures (error level)
# and successful package regeneration (info level) through it.
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(kw_only=True, frozen=True)
class ColumnPySpec:
    """Python-side description of one query column or parameter, used for rendering.

    Instances are frozen and therefore hashable, which lets tuples of specs be
    used as dictionary keys when matching query results to table entities.
    Field order is significant: ``repr`` of a tuple of specs is hashed to name
    anonymous result dataclasses, so reordering fields renames generated classes.
    """

    # Identifier used as the dataclass field / method parameter name.
    name: str
    # Source table of the column; presumably empty for computed columns and
    # parameters — TODO confirm in column_py_spec.
    table: str
    # Database type name as reported by sqlc (e.g. "json", "jsonb").
    db_type: str
    # Whether the column / parameter is non-nullable.
    not_null: bool
    # Whether the database type is an array type.
    is_array: bool
    # Python type annotation text emitted into the generated module.
    py_type: str
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
    """Return ``(schema_name, enum_name)`` for every enum type a query touches.

    Result columns and parameter columns of every query in ``sqlc_res`` are
    inspected. A column counts when the schema its type resolves to (via
    ``catalog.schema_by_ref``) declares an enum with that type name.
    """
    result_cols = [col for query in sqlc_res.queries for col in query.columns]
    param_cols = [param.column for query in sqlc_res.queries for param in query.params]

    found: set[tuple[str, str]] = set()
    for col in result_cols + param_cols:
        type_name = col.type.name
        owner = sqlc_res.catalog.schema_by_ref(col.type)
        if owner.has_enum(type_name):
            found.add((owner.name, type_name))
    return found
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def generate_sql_package(  # noqa: PLR0913, PLR0914
    *,
    schema_path: Path,
    package_full_name: str,
    dsn_import: str,
    application_name: str | None = None,
    to_pascal_fn=alias_generators.to_pascal,
    to_snake_fn=alias_generators.to_snake,
    debug_path: Path | None = None,
    src_path: Path = Path(),
    sqlc_path: Path | None = None,
    tempdir_path: Path | None = None,
    sqlc_command: list[str] | None = None,
) -> bool:
    """Generate a typed SQL package from schema and queries.

    Scans ``src_path`` for ``<package_name>_sql(...)`` calls, runs sqlc on the
    schema plus the collected statements, renders the module and writes it to
    ``<src_path>/<package_full_name with "." replaced by "/">.py``.

    Args:
        schema_path: Path to the Postgres schema SQL file (relative to src_path)
        package_full_name: Target module name (e.g., "myapp.mydatabase"); its
            last component names the ``<name>_sql`` helper
        dsn_import: "module:expression" string (e.g.,
            "myapp.config:CONFIG.db_url"). The module is imported and the
            expression is evaluated in its namespace to obtain the DSN. Only
            the first ":" separates module from expression, so the expression
            itself may contain colons.
        application_name: Optional application name for connection pool
        to_pascal_fn: Function to convert names to PascalCase (default:
            pydantic's to_pascal)
        to_snake_fn: Function to convert names to snake_case (default:
            pydantic's to_snake); used for enums, query parameters and
            result/table entities alike
        debug_path: Optional path to save sqlc inputs for inspection
        src_path: Base source path for scanning queries (default: Path())
        sqlc_path: Optional path to sqlc binary if not in PATH
        tempdir_path: Optional path for temporary file generation
        sqlc_command: Optional command prefix to run sqlc

    Returns:
        True if the package was generated or modified, False otherwise
        (unchanged content, or sqlc reported an error).

    Raises:
        ValueError: If ``dsn_import`` is not of the form "module:expression".
    """
    # Split on the FIRST colon only: module paths never contain ":", but the
    # DSN expression may (e.g. a subscript like CONFIG["db:main"]).
    dsn_import_package, separator, dsn_import_path = dsn_import.partition(":")
    if not (separator and dsn_import_package and dsn_import_path):
        msg = (
            "dsn_import must have the form 'package.module:expression', "
            f"got {dsn_import!r}"
        )
        raise ValueError(msg)

    package_name = package_full_name.rpartition(".")[2]
    sql_fn_name = f"{package_name}_sql"

    target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"

    # A query's name is derived from its statement hash and row_type, so
    # identical calls found at several call sites collapse into one entry.
    queries = list({q.name: q for q in find_all_queries(src_path, sql_fn_name)}.values())

    # NOTE: eval() runs arbitrary code. dsn_import must come from trusted
    # project configuration, never from external input.
    dsn_package = importlib.import_module(dsn_import_package)
    dsn = eval(dsn_import_path, vars(dsn_package))  # noqa: S307

    sqlc_res = run_sqlc(
        src_path / schema_path,
        [(q.name, q.stmt) for q in queries],
        dsn=dsn,
        debug_path=debug_path,
        sqlc_path=sqlc_path,
        tempdir_path=tempdir_path,
        sqlc_command=sqlc_command,
    )

    if sqlc_res.error:
        logger.error("Error running SQLC:\n%s", sqlc_res.error)
        return False

    # Forward to_snake_fn so entity column specs compute enum type names the
    # same way render_enum_class and the query parameter specs below do.
    ordered_entities, result_types = map_entities(
        package_name,
        sqlc_res.queries,
        sqlc_res.catalog,
        sqlc_res.used_schemas(),
        queries,
        to_pascal_fn,
        to_snake_fn=to_snake_fn,
    )

    entities = [render_entity(e.name, e.column_specs) for e in ordered_entities]

    used_enums = _collect_used_enums(sqlc_res)

    # Only enums actually referenced by some query column/param are emitted.
    enums = [
        render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
        for schema in sqlc_res.catalog.schemas
        for e in schema.enums
        if (schema.name, e.name) in used_enums
    ]

    query_classes = [
        render_query_class(
            q.name,
            q.text,
            package_name,
            [
                (
                    column_py_spec(
                        p.column,
                        sqlc_res.catalog,
                        package_name,
                        to_pascal_fn,
                        to_snake_fn,
                        p.number,
                    ),
                    p.column.is_named_param,
                )
                for p in q.params
            ],
            result_types[q.name],
            len(q.columns),
        )
        for q in sqlc_res.queries
    ]

    query_overloads = [
        render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
    ]

    query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]

    # Sorting every section keeps the generated file deterministic, so
    # write_if_changed only reports real changes.
    new_content = render_package(
        dsn_import_package,
        dsn_import_path,
        package_name,
        sql_fn_name,
        sorted(entities),
        sorted(enums),
        sorted(query_classes),
        sorted(query_overloads),
        sorted(query_dict_entries),
        application_name,
    )
    changed = write_if_changed(target_package_path, new_content + "\n")
    if changed:
        logger.info("Generated SQL package %s", package_full_name)
    return changed
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def render_package(
|
|
179
|
+
dsn_import_package: str,
|
|
180
|
+
dsn_import_path: str,
|
|
181
|
+
package_name: str,
|
|
182
|
+
sql_fn_name: str,
|
|
183
|
+
entities: list[str],
|
|
184
|
+
enums: list[str],
|
|
185
|
+
query_classes: list[str],
|
|
186
|
+
query_overloads: list[str],
|
|
187
|
+
query_dict_entries: list[str],
|
|
188
|
+
application_name: str | None = None,
|
|
189
|
+
):
|
|
190
|
+
return f"""
|
|
191
|
+
|
|
192
|
+
# Code generated by iron_sql, DO NOT EDIT.
|
|
193
|
+
|
|
194
|
+
# fmt: off
|
|
195
|
+
# pyright: reportUnusedImport=false
|
|
196
|
+
# ruff: noqa: A002
|
|
197
|
+
# ruff: noqa: ARG001
|
|
198
|
+
# ruff: noqa: C901
|
|
199
|
+
# ruff: noqa: E303
|
|
200
|
+
# ruff: noqa: E501
|
|
201
|
+
# ruff: noqa: F401
|
|
202
|
+
# ruff: noqa: FBT001
|
|
203
|
+
# ruff: noqa: I001
|
|
204
|
+
# ruff: noqa: N801
|
|
205
|
+
# ruff: noqa: PLR0912
|
|
206
|
+
# ruff: noqa: PLR0913
|
|
207
|
+
# ruff: noqa: PLR0917
|
|
208
|
+
# ruff: noqa: Q000
|
|
209
|
+
# ruff: noqa: RUF100
|
|
210
|
+
|
|
211
|
+
import datetime
|
|
212
|
+
import decimal
|
|
213
|
+
import uuid
|
|
214
|
+
from collections.abc import AsyncIterator
|
|
215
|
+
from collections.abc import Sequence
|
|
216
|
+
from contextlib import asynccontextmanager
|
|
217
|
+
from contextvars import ContextVar
|
|
218
|
+
from dataclasses import dataclass
|
|
219
|
+
from enum import StrEnum
|
|
220
|
+
from typing import Literal
|
|
221
|
+
from typing import overload
|
|
222
|
+
|
|
223
|
+
import psycopg
|
|
224
|
+
import psycopg.rows
|
|
225
|
+
from psycopg.types import json as pgjson
|
|
226
|
+
|
|
227
|
+
from iron_sql import runtime
|
|
228
|
+
|
|
229
|
+
from {dsn_import_package} import {dsn_import_path.split(".", maxsplit=1)[0]}
|
|
230
|
+
|
|
231
|
+
{package_name.upper()}_POOL = runtime.ConnectionPool(
|
|
232
|
+
{dsn_import_path},
|
|
233
|
+
name="{package_name}",
|
|
234
|
+
application_name={application_name!r},
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
_{package_name}_connection = ContextVar[psycopg.AsyncConnection | None](
|
|
238
|
+
"_{package_name}_connection",
|
|
239
|
+
default=None,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@asynccontextmanager
|
|
244
|
+
async def {package_name}_connection() -> AsyncIterator[psycopg.AsyncConnection]:
|
|
245
|
+
async with {package_name.upper()}_POOL.connection_in_context(
|
|
246
|
+
_{package_name}_connection
|
|
247
|
+
) as conn:
|
|
248
|
+
yield conn
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
@asynccontextmanager
|
|
252
|
+
async def {package_name}_transaction() -> AsyncIterator[None]:
|
|
253
|
+
async with {package_name}_connection() as conn, conn.transaction():
|
|
254
|
+
yield
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
{"\n\n\n".join(enums)}
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
{"\n\n\n".join(entities)}
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class Query:
|
|
264
|
+
pass
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
{"\n\n\n".join(query_classes)}
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
_QUERIES: dict[str, type[Query]] = {{
|
|
271
|
+
{(",\n ").join(query_dict_entries)}
|
|
272
|
+
}}
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
{"\n".join(query_overloads)}
|
|
276
|
+
@overload
|
|
277
|
+
def {sql_fn_name}(stmt: str) -> Query: ...
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query:
|
|
281
|
+
if stmt in _QUERIES:
|
|
282
|
+
return _QUERIES[stmt]()
|
|
283
|
+
msg = f"Unknown statement: {{stmt!r}}"
|
|
284
|
+
raise KeyError(msg)
|
|
285
|
+
|
|
286
|
+
""".strip()
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def render_enum_class(
|
|
290
|
+
enum: Enum,
|
|
291
|
+
package_name: str,
|
|
292
|
+
to_pascal_fn: Callable[[str], str],
|
|
293
|
+
to_snake_fn: Callable[[str], str],
|
|
294
|
+
) -> str:
|
|
295
|
+
class_name = to_pascal_fn(f"{package_name}_{to_snake_fn(enum.name)}")
|
|
296
|
+
members = []
|
|
297
|
+
seen_names: dict[str, int] = {}
|
|
298
|
+
|
|
299
|
+
for val in enum.vals:
|
|
300
|
+
name = to_snake_fn(val).upper()
|
|
301
|
+
name = "".join(c if c.isalnum() else "_" for c in name)
|
|
302
|
+
name = name.strip("_") or "EMPTY"
|
|
303
|
+
if name[0].isdigit():
|
|
304
|
+
name = "NUM" + name
|
|
305
|
+
if name in seen_names:
|
|
306
|
+
seen_names[name] += 1
|
|
307
|
+
name = f"{name}_{seen_names[name]}"
|
|
308
|
+
else:
|
|
309
|
+
seen_names[name] = 1
|
|
310
|
+
members.append(f'{name} = "{val}"')
|
|
311
|
+
|
|
312
|
+
return f"""
|
|
313
|
+
|
|
314
|
+
class {class_name}(StrEnum):
|
|
315
|
+
{indent_block("\n".join(members), " ")}
|
|
316
|
+
|
|
317
|
+
""".strip()
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def render_entity(
    name: str,
    columns: tuple[ColumnPySpec, ...],
) -> str:
    """Render a ``@dataclass(kw_only=True)`` class named *name*.

    Each column contributes one ``<column name>: <python type>`` field line,
    in the given order.
    """
    header = f"@dataclass(kw_only=True)\nclass {name}:"
    field_lines = [f"    {column.name}: {column.py_type}" for column in columns]
    return "\n".join([header, *field_lines]).strip()
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def deduplicate_params(
    params: list[tuple[ColumnPySpec, bool]],
) -> list[tuple[ColumnPySpec, bool]]:
    """Make parameter names unique by numbering repeats.

    The first occurrence of a name is kept as-is. The n-th repeat (n >= 2) is
    renamed ``f"{name}{n}"``, giving ``id``, ``id2``, ``id3``, ... Every returned
    column is a fresh copy made with ``dataclasses.replace``. The ``is_named``
    flag is carried along unchanged.
    """
    occurrences: defaultdict[str, int] = defaultdict(int)

    def _unique(column: ColumnPySpec) -> ColumnPySpec:
        occurrences[column.name] += 1
        count = occurrences[column.name]
        renamed = column.name if count == 1 else f"{column.name}{count}"
        return dataclasses.replace(column, name=renamed)

    return [(_unique(column), is_named) for column, is_named in params]
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def serialized_arg(column: ColumnPySpec) -> str:
    """Return the Python expression that passes *column*'s value to psycopg.

    ``jsonb`` values are wrapped in ``pgjson.Jsonb``. The wrapping is guarded by
    a ``None`` check when the column is nullable. Every other type is passed
    through by name.

    Raises:
        TypeError: For non-null ``json`` columns and for ``jsonb`` arrays.
    """
    arg = column.name
    db_type = column.db_type

    if db_type == "json" and column.not_null:
        msg = "Unsupported column type: json"
        raise TypeError(msg)
    # NOTE(review): nullable ``json`` falls through unwrapped below — confirm
    # psycopg can adapt the value as passed.

    if db_type != "jsonb":
        return arg
    if column.is_array:
        msg = "Unsupported column type: jsonb[]"
        raise TypeError(msg)
    if column.not_null:
        return f"pgjson.Jsonb({arg})"
    return f"pgjson.Jsonb({arg}) if {arg} is not None else None"
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def render_query_class(
    query_name: str,
    stmt: str,
    package_name: str,
    query_params: list[tuple[ColumnPySpec, bool]],
    result: str,
    columns_num: int,
) -> str:
    """Render the ``Query`` subclass that executes one SQL statement.

    Args:
        query_name: Name of the generated class.
        stmt: SQL text; embedded in the generated code via ``repr``.
        package_name: Used to reference the generated ``<package>_connection()``.
        query_params: ``(column spec, is_named_param)`` pairs in positional
            order; repeated names are renamed by ``deduplicate_params``.
        result: Python type of one row. A trailing ``" | None"`` marks a
            nullable single-column result.
        columns_num: Number of result columns. 0 yields an ``execute()``
            method; otherwise ``query_all_rows``, ``query_single_row`` and
            ``query_optional_row`` are generated.

    Returns:
        The class source without surrounding blank lines.
    """
    params = deduplicate_params(query_params)

    # Tuple expression handed to cursor.execute(); "None" when there are no params.
    serialized = [serialized_arg(spec) for spec, _ in params]
    if not serialized:
        params_arg = "None"
    elif len(serialized) == 1:
        params_arg = f"({serialized[0]},)"  # trailing comma keeps a 1-tuple
    else:
        params_arg = "(" + ", ".join(serialized) + ")"

    # Positional parameters come first; a "*" makes every parameter from the
    # first named (@param) one onward keyword-only.
    signature = [f"{spec.name}: {spec.py_type}" for spec, _ in params]
    named_positions = [idx for idx, (_, is_named) in enumerate(params) if is_named]
    if named_positions:
        signature.insert(named_positions[0], "*")
    fn_params = ", ".join(["self", *signature])

    nullable = result.endswith(" | None")
    base_result = result.removesuffix(" | None")
    if columns_num == 0:
        row_factory = "psycopg.rows.scalar_row"
    elif columns_num == 1:
        row_factory = f"runtime.typed_scalar_row({base_result}, not_null={not nullable})"
    else:
        row_factory = f"psycopg.rows.class_row({result})"

    if columns_num > 0:
        # (method name, return annotation, fetch expression)
        fetchers = [
            ("query_all_rows", f"list[{result}]", "await cur.fetchall()"),
            ("query_single_row", result, "runtime.get_one_row(await cur.fetchmany(2))"),
            (
                "query_optional_row",
                f"{base_result} | None",
                "runtime.get_one_row_or_none(await cur.fetchmany(2))",
            ),
        ]
        methods = "\n\n".join(
            f"async def {method}({fn_params}) -> {returns}:\n"
            f"    async with self._execute({params_arg}) as cur:\n"
            f"        return {fetch}"
            for method, returns, fetch in fetchers
        )
    else:
        methods = (
            f"async def execute({fn_params}) -> None:\n"
            f"    async with self._execute({params_arg}):\n"
            "        pass"
        )

    return f"""

class {query_name}(Query):
    @asynccontextmanager
    async def _execute(self, params) -> AsyncIterator[psycopg.AsyncRawCursor[{result}]]:
        stmt = {stmt!r}
        async with (
            {package_name}_connection() as conn,
            psycopg.AsyncRawCursor(conn, row_factory={row_factory}) as cur,
        ):
            await cur.execute(stmt, params)
            yield cur

{indent_block(methods, "    ")}

""".strip()
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def render_query_overload(
|
|
448
|
+
sql_fn_name: str, query_name: str, stmt: str, row_type: str | None
|
|
449
|
+
) -> str:
|
|
450
|
+
result_arg = ""
|
|
451
|
+
if row_type:
|
|
452
|
+
result_arg = f", row_type: Literal[{row_type!r}]"
|
|
453
|
+
|
|
454
|
+
return f"""
|
|
455
|
+
|
|
456
|
+
@overload
|
|
457
|
+
def {sql_fn_name}(stmt: Literal[{stmt!r}]{result_arg}) -> {query_name}: ...
|
|
458
|
+
|
|
459
|
+
""".strip()
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def render_query_dict_entry(query_name: str, stmt: str) -> str:
    """Render one ``_QUERIES`` entry: the statement literal mapped to its class name."""
    key_literal = repr(stmt)
    return ": ".join((key_literal, query_name))
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
@dataclass(kw_only=True)
|
|
467
|
+
class CodeQuery:
|
|
468
|
+
stmt: str
|
|
469
|
+
row_type: str | None
|
|
470
|
+
file: Path
|
|
471
|
+
lineno: int
|
|
472
|
+
|
|
473
|
+
@property
|
|
474
|
+
def name(self) -> str:
|
|
475
|
+
md5_hash = hashlib.md5(self.stmt.encode(), usedforsecurity=False).hexdigest()
|
|
476
|
+
return f"Query_{md5_hash}{'_' + self.row_type if self.row_type else ''}"
|
|
477
|
+
|
|
478
|
+
@property
|
|
479
|
+
def location(self) -> str:
|
|
480
|
+
return f"{self.file}:{self.lineno}"
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@dataclass(kw_only=True)
class SQLEntity:
    """A dataclass to be generated for a table row or a multi-column query result."""

    package_name: str
    set_name: str | None  # explicit row_type name chosen by the user, if any
    table_name: str | None  # source table for table entities, else None
    columns: list[Column]
    catalog: Catalog = dataclasses.field(repr=False)
    to_pascal_fn: Callable[[str], str]
    to_snake_fn: Callable[[str], str] = inflection.underscore

    @property
    def name(self) -> str:
        """Class name for the entity.

        Resolution order: the explicit ``set_name``, then the PascalCased
        ``<package>_<singular table name>``, and finally
        ``QueryResult_<md5 of the column specs' repr>``.
        """
        if self.set_name:
            return self.set_name
        if not self.table_name:
            fingerprint = repr(self.column_specs).encode()
            digest = hashlib.md5(fingerprint, usedforsecurity=False).hexdigest()
            return f"QueryResult_{digest}"
        singular_table = inflection.singularize(self.table_name)
        return self.to_pascal_fn(f"{self.package_name}_{singular_table}")

    @property
    def column_specs(self) -> tuple[ColumnPySpec, ...]:
        """Python specs for every column, recomputed from ``columns`` on each access."""
        specs = [
            column_py_spec(
                column,
                self.catalog,
                self.package_name,
                self.to_pascal_fn,
                self.to_snake_fn,
            )
            for column in self.columns
        ]
        return tuple(specs)
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def map_entities(
|
|
516
|
+
package_name: str,
|
|
517
|
+
queries_from_sqlc: list[Query],
|
|
518
|
+
catalog: Catalog,
|
|
519
|
+
used_schemas: list[str],
|
|
520
|
+
queries_from_code: list[CodeQuery],
|
|
521
|
+
to_pascal_fn: Callable[[str], str],
|
|
522
|
+
to_snake_fn: Callable[[str], str] = inflection.underscore,
|
|
523
|
+
):
|
|
524
|
+
row_types = {q.name: q.row_type for q in queries_from_code}
|
|
525
|
+
|
|
526
|
+
table_entities = [
|
|
527
|
+
SQLEntity(
|
|
528
|
+
package_name=package_name,
|
|
529
|
+
set_name=None,
|
|
530
|
+
table_name=t.rel.name,
|
|
531
|
+
columns=t.columns,
|
|
532
|
+
catalog=catalog,
|
|
533
|
+
to_pascal_fn=to_pascal_fn,
|
|
534
|
+
to_snake_fn=to_snake_fn,
|
|
535
|
+
)
|
|
536
|
+
for sch in used_schemas
|
|
537
|
+
for t in catalog.schema_by_name(sch).tables
|
|
538
|
+
]
|
|
539
|
+
specs_to_entities = {e.column_specs: e for e in table_entities}
|
|
540
|
+
|
|
541
|
+
for q in queries_from_sqlc:
|
|
542
|
+
if row_types[q.name] and not q.columns:
|
|
543
|
+
msg = f"Query has row_type={row_types[q.name]} but no result"
|
|
544
|
+
raise ValueError(msg)
|
|
545
|
+
if row_types[q.name] and len(q.columns) == 1:
|
|
546
|
+
msg = f"Query has row_type={row_types[q.name]} but only one column"
|
|
547
|
+
raise ValueError(msg)
|
|
548
|
+
|
|
549
|
+
query_result_entities = {
|
|
550
|
+
q.name: SQLEntity(
|
|
551
|
+
package_name=package_name,
|
|
552
|
+
set_name=row_types[q.name],
|
|
553
|
+
table_name=None,
|
|
554
|
+
columns=q.columns,
|
|
555
|
+
catalog=catalog,
|
|
556
|
+
to_pascal_fn=to_pascal_fn,
|
|
557
|
+
to_snake_fn=to_snake_fn,
|
|
558
|
+
)
|
|
559
|
+
for q in queries_from_sqlc
|
|
560
|
+
if len(q.columns) > 1
|
|
561
|
+
}
|
|
562
|
+
|
|
563
|
+
unique_entities = {
|
|
564
|
+
e.column_specs: specs_to_entities.get(e.column_specs, e)
|
|
565
|
+
for e in query_result_entities.values()
|
|
566
|
+
}
|
|
567
|
+
ordered_entities = sorted(
|
|
568
|
+
unique_entities.values(),
|
|
569
|
+
key=lambda e: (e.table_name is None, e.table_name or ""),
|
|
570
|
+
)
|
|
571
|
+
|
|
572
|
+
result_types = {}
|
|
573
|
+
for q in queries_from_sqlc:
|
|
574
|
+
if len(q.columns) == 0:
|
|
575
|
+
result_types[q.name] = "None"
|
|
576
|
+
elif len(q.columns) == 1:
|
|
577
|
+
result_types[q.name] = column_py_spec(
|
|
578
|
+
q.columns[0], catalog, package_name, to_pascal_fn, to_snake_fn
|
|
579
|
+
).py_type
|
|
580
|
+
else:
|
|
581
|
+
column_spec = query_result_entities[q.name].column_specs
|
|
582
|
+
result_types[q.name] = unique_entities[column_spec].name
|
|
583
|
+
|
|
584
|
+
return ordered_entities, result_types
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def column_py_spec( # noqa: C901, PLR0912
|
|
588
|
+
column: Column,
|
|
589
|
+
catalog: Catalog,
|
|
590
|
+
package_name: str,
|
|
591
|
+
to_pascal_fn: Callable[[str], str],
|
|
592
|
+
to_snake_fn: Callable[[str], str] = inflection.underscore,
|
|
593
|
+
number: int = 0,
|
|
594
|
+
) -> ColumnPySpec:
|
|
595
|
+
db_type = column.type.name.removeprefix("pg_catalog.")
|
|
596
|
+
match db_type:
|
|
597
|
+
case "bool" | "boolean":
|
|
598
|
+
py_type = "bool"
|
|
599
|
+
case (
|
|
600
|
+
"int2"
|
|
601
|
+
| "int4"
|
|
602
|
+
| "int8"
|
|
603
|
+
| "smallint"
|
|
604
|
+
| "integer"
|
|
605
|
+
| "bigint"
|
|
606
|
+
| "serial"
|
|
607
|
+
| "bigserial"
|
|
608
|
+
):
|
|
609
|
+
py_type = "int"
|
|
610
|
+
case "float4" | "float8":
|
|
611
|
+
py_type = "float"
|
|
612
|
+
case "numeric":
|
|
613
|
+
py_type = "decimal.Decimal"
|
|
614
|
+
case "varchar" | "text":
|
|
615
|
+
py_type = "str"
|
|
616
|
+
case "bytea":
|
|
617
|
+
py_type = "bytes"
|
|
618
|
+
case "json" | "jsonb":
|
|
619
|
+
py_type = "object"
|
|
620
|
+
case "date":
|
|
621
|
+
py_type = "datetime.date"
|
|
622
|
+
case "time" | "timetz":
|
|
623
|
+
py_type = "datetime.time"
|
|
624
|
+
case "timestamp" | "timestamptz":
|
|
625
|
+
py_type = "datetime.datetime"
|
|
626
|
+
case "uuid":
|
|
627
|
+
py_type = "uuid.UUID"
|
|
628
|
+
case "any" | "anyelement":
|
|
629
|
+
py_type = "object"
|
|
630
|
+
case enum if catalog.schema_by_ref(column.type).has_enum(enum):
|
|
631
|
+
py_type = (
|
|
632
|
+
to_pascal_fn(f"{package_name}_{to_snake_fn(enum)}")
|
|
633
|
+
if package_name
|
|
634
|
+
else "str"
|
|
635
|
+
)
|
|
636
|
+
case _:
|
|
637
|
+
logger.warning(f"Unknown SQL type: {column.type.name} ({column.name})")
|
|
638
|
+
py_type = "object"
|
|
639
|
+
|
|
640
|
+
if column.is_array:
|
|
641
|
+
py_type = f"Sequence[{py_type}]"
|
|
642
|
+
|
|
643
|
+
if not column.not_null:
|
|
644
|
+
py_type += " | None"
|
|
645
|
+
|
|
646
|
+
return ColumnPySpec(
|
|
647
|
+
name=column.name or f"param_{number}",
|
|
648
|
+
table=column.table.name if column.table else "unknown",
|
|
649
|
+
db_type=db_type,
|
|
650
|
+
not_null=column.not_null,
|
|
651
|
+
is_array=column.is_array,
|
|
652
|
+
py_type=py_type,
|
|
653
|
+
)
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
def find_fn_calls(
|
|
657
|
+
root_path: Path, fn_name: str
|
|
658
|
+
) -> Iterator[tuple[Path, int, ast.Call]]:
|
|
659
|
+
for path in root_path.glob("**/*.py"):
|
|
660
|
+
content = path.read_text(encoding="utf-8")
|
|
661
|
+
if fn_name not in content:
|
|
662
|
+
continue
|
|
663
|
+
for node in ast.walk(ast.parse(content, filename=str(path))):
|
|
664
|
+
match node:
|
|
665
|
+
case ast.Call(func=ast.Name(id=id)) if id == fn_name:
|
|
666
|
+
yield path, node.lineno, node
|
|
667
|
+
case _:
|
|
668
|
+
pass
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
    """Yield a CodeQuery for every ``sql_fn_name(...)`` call under *src_path*.

    Argument rules:
    - Each call must have exactly one positional argument, a string literal
      holding the SQL statement.
    - Keyword arguments must be string literals, checked up to and including
      ``row_type=``.
    - The ``row_type`` value becomes the query's row type.

    Raises:
        TypeError: If a call's positional arguments are not a single string
            literal, or a checked keyword argument is not a string literal.
    """
    for file, lineno, node in find_fn_calls(src_path, sql_fn_name):
        relative_path = file.relative_to(src_path)

        # Check the count before indexing: a call with no positional
        # arguments (e.g. ``sql(stmt="...")``) must produce the TypeError
        # below, not an IndexError.
        stmt_arg = node.args[0] if len(node.args) == 1 else None
        if not isinstance(stmt_arg, ast.Constant) or not isinstance(
            stmt_arg.value, str
        ):
            msg = (
                f"Invalid positional arguments for {sql_fn_name} "
                f"at {relative_path}:{lineno}, "
                "expected a single string literal"
            )
            raise TypeError(msg)

        stmt = stmt_arg.value

        row_type = None
        for kw in node.keywords:
            # ``**mapping`` splats appear with ``kw.arg is None`` and a
            # non-constant value, so they are rejected here too.
            if not isinstance(kw.value, ast.Constant) or not isinstance(
                kw.value.value, str
            ):
                msg = (
                    f"Invalid keyword argument {kw.arg} for {sql_fn_name} "
                    f"at {relative_path}:{lineno}, expected a string literal"
                )
                raise TypeError(msg)
            if kw.arg == "row_type":
                row_type = kw.value.value
                break

        yield CodeQuery(
            stmt=stmt,
            row_type=row_type,
            file=relative_path,
            lineno=lineno,
        )
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
def indent_block(block: str, indent: str) -> str:
    """Prefix every non-blank line after the first with *indent*.

    The first line is left untouched (it is expected to continue an already
    indented line). Whitespace-only lines are kept as they are.
    """
    lines = block.split("\n")
    out = lines[:1]
    for line in lines[1:]:
        out.append(indent + line if line.strip() else line)
    return "\n".join(out)
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
def write_if_changed(path: Path, new_content: str) -> bool:
    """Write *new_content* to *path* unless the file already holds exactly that text.

    Missing parent directories are created. Text is read and written as UTF-8.

    Returns:
        True if the file was written, False if it was already up to date.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists() and path.read_text(encoding="utf-8") == new_content:
        return False
    path.write_text(new_content, encoding="utf-8")
    path.touch()
    return True
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from collections.abc import AsyncIterator
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from contextvars import ContextVar
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any
|
|
7
|
+
from typing import Literal
|
|
8
|
+
from typing import Self
|
|
9
|
+
from typing import overload
|
|
10
|
+
|
|
11
|
+
import psycopg
|
|
12
|
+
import psycopg.rows
|
|
13
|
+
import psycopg_pool
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class NoRowsError(Exception):
    """Raised by get_one_row when a query returned no rows."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TooManyRowsError(Exception):
    """Raised when a query expected to return at most one row returned several."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ConnectionPool:
    """Async psycopg connection pool that opens lazily and is reusable after close().

    The underlying ``psycopg_pool.AsyncConnectionPool`` is built with
    ``open=False`` and opened on first use. Connections use
    ``autocommit=True`` and the configured ``application_name``.
    """

    def __init__(
        self,
        conninfo: str,
        *,
        name: str | None = None,
        application_name: str | None = None,
    ) -> None:
        self.conninfo = conninfo
        self.name = name
        self.application_name = application_name
        self._init_psycopg_pool()

    async def close(self) -> None:
        """Close the current pool, then replace it with a fresh unopened one."""
        current = self.psycopg_pool
        await current.close()
        self._init_psycopg_pool()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    async def await_connections(self) -> None:
        """Open the pool and wait for it to finish its initial connections."""
        await self.psycopg_pool.open(wait=True)

    async def check(self) -> None:
        """Open the pool if needed, then run psycopg_pool's connection check."""
        await self.psycopg_pool.open()
        await self.psycopg_pool.check()

    @asynccontextmanager
    async def connection(self) -> AsyncIterator[psycopg.AsyncConnection]:
        """Borrow a pooled connection, opening the pool first if needed."""
        await self.psycopg_pool.open()
        async with self.psycopg_pool.connection() as borrowed:
            yield borrowed

    def _init_psycopg_pool(self) -> None:
        # Build a new, not-yet-opened pool from the stored settings.
        connect_kwargs = {
            "application_name": self.application_name,
            # https://www.psycopg.org/psycopg3/docs/basic/transactions.html#autocommit-transactions
            "autocommit": True,
        }
        self.psycopg_pool = psycopg_pool.AsyncConnectionPool(
            self.conninfo,
            open=False,
            name=self.name,
            kwargs=connect_kwargs,
        )

    @asynccontextmanager
    async def connection_in_context(
        self, context_var: ContextVar[psycopg.AsyncConnection | None]
    ) -> AsyncIterator[psycopg.AsyncConnection]:
        """Reuse the connection held in *context_var*, or borrow one for this block.

        A borrowed connection is stored in *context_var* while the block runs
        and the variable is reset afterwards.
        """
        existing = context_var.get()
        if existing is None:
            async with self.connection() as conn:
                token = context_var.set(conn)
                try:
                    yield conn
                finally:
                    context_var.reset(token)
        else:
            yield existing
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_one_row[T](rows: list[T]) -> T:
|
|
89
|
+
if len(rows) == 0:
|
|
90
|
+
raise NoRowsError
|
|
91
|
+
if len(rows) > 1:
|
|
92
|
+
raise TooManyRowsError
|
|
93
|
+
return rows[0]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def get_one_row_or_none[T](rows: list[T]) -> T | None:
|
|
97
|
+
if len(rows) == 0:
|
|
98
|
+
return None
|
|
99
|
+
if len(rows) > 1:
|
|
100
|
+
raise TooManyRowsError
|
|
101
|
+
return rows[0]
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@overload
|
|
105
|
+
def typed_scalar_row[T](
|
|
106
|
+
typ: type[T], *, not_null: Literal[True]
|
|
107
|
+
) -> psycopg.rows.BaseRowFactory[T]: ...
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@overload
|
|
111
|
+
def typed_scalar_row[T](
|
|
112
|
+
typ: type[T], *, not_null: Literal[False]
|
|
113
|
+
) -> psycopg.rows.BaseRowFactory[T | None]: ...
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def typed_scalar_row[T](
|
|
117
|
+
typ: type[T], *, not_null: bool
|
|
118
|
+
) -> psycopg.rows.BaseRowFactory[T | None]:
|
|
119
|
+
def typed_scalar_row_(cursor) -> psycopg.rows.RowMaker[T | None]:
|
|
120
|
+
scalar_row_ = psycopg.rows.scalar_row(cursor)
|
|
121
|
+
|
|
122
|
+
def typed_scalar_row__(values: Sequence[Any]) -> T | None:
|
|
123
|
+
val = scalar_row_(values)
|
|
124
|
+
if not not_null and val is None:
|
|
125
|
+
return None
|
|
126
|
+
if not isinstance(val, typ):
|
|
127
|
+
if issubclass(typ, Enum):
|
|
128
|
+
return typ(val)
|
|
129
|
+
msg = f"Expected scalar of type {typ}, got {type(val)}"
|
|
130
|
+
raise TypeError(msg)
|
|
131
|
+
return val
|
|
132
|
+
|
|
133
|
+
return typed_scalar_row__
|
|
134
|
+
|
|
135
|
+
return typed_scalar_row_
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
import shutil
|
|
4
|
+
import subprocess # noqa: S404
|
|
5
|
+
import tempfile
|
|
6
|
+
import textwrap
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import pydantic
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CatalogReference(pydantic.BaseModel):
    """Reference to a named catalog object (table or type) in sqlc's JSON output."""

    catalog: str
    # The JSON key is ``schema``; an empty string is treated elsewhere in
    # this module as "the catalog's default schema".
    schema_name: str = pydantic.Field(..., alias="schema")
    name: str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Column(pydantic.BaseModel):
    """A column as described in sqlc's JSON output (result column or parameter).

    The code generator reads ``name``, ``type``, ``table``, ``not_null`` and
    ``is_array``. The remaining fields mirror sqlc's output so that
    validation of ``out.json`` succeeds.
    """

    # May be empty; the generator then names the column ``param_<number>``.
    name: str
    # False makes the generator append ``| None`` to the Python type.
    not_null: bool
    # True makes the generator wrap the Python type in ``Sequence[...]``.
    is_array: bool
    comment: str
    length: int
    is_named_param: bool
    is_func_call: bool
    scope: str
    # Table the column comes from, if sqlc could attribute one.
    table: CatalogReference | None
    table_alias: str
    # Database type; its name may carry a ``pg_catalog.`` prefix.
    type: CatalogReference
    is_sqlc_slice: bool
    embed_table: None
    original_name: str
    unsigned: bool
    array_dims: int
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Table(pydantic.BaseModel):
    """A table in a catalog schema: its reference, columns and comment."""

    rel: CatalogReference
    columns: list[Column]
    comment: str
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Enum(pydantic.BaseModel):
    """An enum type declared in a catalog schema.

    NOTE: shares its name with the stdlib ``enum.Enum``; take care when both
    are imported into the same module.
    """

    name: str
    # Enum values as reported by sqlc.
    vals: list[str]
    comment: str
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class CompositeType(pydantic.BaseModel):
    """A composite type declared in a catalog schema (name and comment only)."""

    name: str
    comment: str
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Schema(pydantic.BaseModel):
    """A schema in sqlc's catalog with its tables, enums and composite types."""

    comment: str
    name: str
    tables: list[Table]
    enums: list[Enum]
    composite_types: list[CompositeType]

    def has_enum(self, name: str) -> bool:
        """Return True if this schema declares an enum called *name*."""
        for enum_def in self.enums:
            if enum_def.name == name:
                return True
        return False
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class Catalog(pydantic.BaseModel):
    """sqlc's catalog: all known schemas plus the default schema name."""

    default_schema: str
    name: str
    schemas: list[Schema]

    def schema_by_name(self, name: str) -> Schema:
        """Return the schema called *name*.

        Raises:
            ValueError: If no schema has that name.
        """
        found = next((s for s in self.schemas if s.name == name), None)
        if found is None:
            msg = f"Schema not found: {name}"
            raise ValueError(msg)
        return found

    def schema_by_ref(self, ref: CatalogReference) -> Schema:
        """Return the schema *ref* points to; an empty schema name means the default."""
        target = ref.schema_name if ref.schema_name else self.default_schema
        return self.schema_by_name(target)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class QueryParameter(pydantic.BaseModel):
    """A query parameter as reported by sqlc."""

    # Parameter position as reported by sqlc (presumably the ``$N`` index — TODO confirm).
    number: int
    column: Column
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class Query(pydantic.BaseModel):
    """One analyzed query from sqlc's JSON output."""

    text: str
    # The name written in the ``-- name:`` annotation by run_sqlc.
    name: str
    # The sqlc command annotation; run_sqlc always writes ``:exec``.
    cmd: str
    # Result columns; an empty list means the query returns nothing.
    columns: list[Column]
    params: list[QueryParameter]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class SQLCResult(pydantic.BaseModel):
    """Parsed ``out.json`` from sqlc's JSON generator, plus an optional error.

    When sqlc fails, ``error`` holds its diagnostics and the catalog and
    query list are empty.
    """

    error: str | None = None
    catalog: Catalog
    queries: list[Query]

    def used_schemas(self) -> list[str]:
        """Return the catalog schemas referenced by query result columns, sorted.

        An empty schema name on a column's table is attributed to the
        catalog's default schema. Names not present in the catalog are
        dropped. The result is sorted because the names are collected in a
        ``set``, and ``str`` set iteration order changes between interpreter
        runs under hash randomization. An unsorted result would make
        downstream code generation nondeterministic.
        """
        referenced = {
            c.table.schema_name
            for q in self.queries
            for c in q.columns
            if c.table is not None
        }
        if "" in referenced:
            referenced.discard("")
            referenced.add(self.catalog.default_schema)
        known = {s.name for s in self.catalog.schemas}
        return sorted(s for s in referenced if s in known)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _resolve_sqlc_command(
|
|
114
|
+
sqlc_path: Path | None,
|
|
115
|
+
sqlc_command: list[str] | None,
|
|
116
|
+
) -> list[str]:
|
|
117
|
+
if sqlc_command is not None:
|
|
118
|
+
if sqlc_path is not None:
|
|
119
|
+
msg = "sqlc_command and sqlc_path are mutually exclusive"
|
|
120
|
+
raise ValueError(msg)
|
|
121
|
+
if not sqlc_command:
|
|
122
|
+
msg = "sqlc_command must not be empty"
|
|
123
|
+
raise ValueError(msg)
|
|
124
|
+
return sqlc_command
|
|
125
|
+
|
|
126
|
+
if sqlc_path is None:
|
|
127
|
+
discovered_path = shutil.which("sqlc")
|
|
128
|
+
if discovered_path is None:
|
|
129
|
+
msg = "sqlc not found in PATH"
|
|
130
|
+
raise FileNotFoundError(msg)
|
|
131
|
+
sqlc_path = Path(discovered_path)
|
|
132
|
+
if not sqlc_path.exists():
|
|
133
|
+
msg = f"sqlc not found at {sqlc_path}"
|
|
134
|
+
raise FileNotFoundError(msg)
|
|
135
|
+
|
|
136
|
+
return [str(sqlc_path)]
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def run_sqlc(
    schema_path: Path,
    queries: list[tuple[str, str]],
    *,
    dsn: str | None,
    debug_path: Path | None = None,
    sqlc_path: Path | None = None,
    tempdir_path: Path | None = None,
    sqlc_command: list[str] | None = None,
) -> SQLCResult:
    """Analyze *queries* against *schema_path* with ``sqlc generate`` (JSON output).

    A temporary sqlc project is built containing:
    - ``queries.sql``: one ``-- name: <name> :exec`` entry per query, with
      each statement passed through preprocess_sql;
    - a ``schema.sql`` symlink to *schema_path*;
    - a ``sqlc.json`` config using the JSON generator.

    sqlc's ``out.json`` is parsed into an SQLCResult.

    Args:
        schema_path: Schema file given to sqlc.
        queries: ``(name, stmt)`` pairs; pairs sharing a name are collapsed,
            the last statement for a name winning.
        dsn: Optional database URI, written as ``database.uri`` in the config.
        debug_path: If set, the generated inputs (and ``out.json`` when
            produced) are copied here; a stale ``out.json`` is removed.
        sqlc_path: Path of the sqlc executable (exclusive with *sqlc_command*).
        tempdir_path: Parent directory for the temporary project.
        sqlc_command: Full argv prefix used to invoke sqlc.

    Returns:
        The parsed result. If sqlc produced no ``out.json``, the result has an
        empty catalog, no queries and a non-empty ``error``. The error is
        stderr, else stdout, else a message with the exit code.

    Raises:
        ValueError: If *schema_path* does not exist or the sqlc options conflict.
        FileNotFoundError: If the sqlc executable cannot be found.
    """
    if not schema_path.exists():
        msg = f"Schema file not found: {schema_path}"
        raise ValueError(msg)

    if not queries:
        return SQLCResult(
            catalog=Catalog(default_schema="", name="", schemas=[]),
            queries=[],
        )

    queries = list({q[0]: q for q in queries}.values())
    cmd_prefix = _resolve_sqlc_command(sqlc_path, sqlc_command)

    with tempfile.TemporaryDirectory(
        dir=str(tempdir_path) if tempdir_path else None
    ) as tempdir:
        queries_path = Path(tempdir) / "queries.sql"
        queries_path.write_text(
            "\n\n".join(
                f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
                for name, stmt in queries
            ),
            encoding="utf-8",
        )

        (Path(tempdir) / "schema.sql").symlink_to(schema_path.absolute())

        config_path = Path(tempdir) / "sqlc.json"
        sqlc_config = {
            "version": "2",
            "sql": [
                {
                    "schema": "schema.sql",
                    "queries": ["queries.sql"],
                    "engine": "postgresql",
                    "database": {"uri": dsn} if dsn else None,
                    "gen": {"json": {"out": ".", "filename": "out.json"}},
                }
            ],
        }
        config_path.write_text(json.dumps(sqlc_config, indent=2), encoding="utf-8")

        cmd = [*cmd_prefix, "generate", "--file", str(config_path.resolve())]

        sqlc_run_result = subprocess.run(  # noqa: S603
            cmd,
            capture_output=True,
            check=False,
        )

        json_out_path = Path(tempdir) / "out.json"

        if debug_path:
            debug_path.absolute().mkdir(parents=True, exist_ok=True)
            shutil.copy(queries_path, debug_path)
            shutil.copy(schema_path, debug_path / "schema.sql")
            shutil.copy(config_path, debug_path)
            if json_out_path.exists():
                shutil.copy(json_out_path, debug_path)
            elif (debug_path / "out.json").exists():
                (debug_path / "out.json").unlink()

        if not json_out_path.exists():
            # Never report a failed run with an empty (falsy) error: fall back
            # from stderr to stdout to the exit code. Decode leniently so
            # undecodable bytes cannot mask the original failure.
            error = (
                sqlc_run_result.stderr.decode(errors="replace").strip()
                or sqlc_run_result.stdout.decode(errors="replace").strip()
                or (
                    f"sqlc exited with code {sqlc_run_result.returncode} "
                    "without producing output"
                )
            )
            return SQLCResult(
                error=error,
                catalog=Catalog(default_schema="", name="", schemas=[]),
                queries=[],
            )
        return SQLCResult.model_validate_json(json_out_path.read_text(encoding="utf-8"))
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def preprocess_sql(stmt: str) -> str:
    """Prepare a SQL literal from source code for sqlc.

    Every ``@name?`` placeholder is rewritten to ``sqlc.narg('name')``, sqlc's
    nullable named-argument macro. The text is then dedented and stripped.
    """
    nullable_param = re.compile(r"@(\w+)\?")
    rewritten = nullable_param.sub(r"sqlc.narg('\1')", stmt)
    return textwrap.dedent(rewritten).strip()
|