iron-sql 0.4.3__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {iron_sql-0.4.3 → iron_sql-0.4.4}/PKG-INFO +1 -1
- {iron_sql-0.4.3 → iron_sql-0.4.4}/pyproject.toml +1 -1
- {iron_sql-0.4.3 → iron_sql-0.4.4}/src/iron_sql/codegen/generator.py +166 -98
- {iron_sql-0.4.3 → iron_sql-0.4.4}/src/iron_sql/codegen/sqlc.py +15 -11
- {iron_sql-0.4.3 → iron_sql-0.4.4}/LICENSE +0 -0
- {iron_sql-0.4.3 → iron_sql-0.4.4}/README.md +0 -0
- {iron_sql-0.4.3 → iron_sql-0.4.4}/src/iron_sql/__init__.py +0 -0
- {iron_sql-0.4.3 → iron_sql-0.4.4}/src/iron_sql/codegen/__init__.py +0 -0
- {iron_sql-0.4.3 → iron_sql-0.4.4}/src/iron_sql/codegen/util.py +0 -0
- {iron_sql-0.4.3 → iron_sql-0.4.4}/src/iron_sql/py.typed +0 -0
- {iron_sql-0.4.3 → iron_sql-0.4.4}/src/iron_sql/runtime.py +0 -0
|
@@ -3,6 +3,7 @@ import dataclasses
|
|
|
3
3
|
import hashlib
|
|
4
4
|
import importlib
|
|
5
5
|
import logging
|
|
6
|
+
import re
|
|
6
7
|
import warnings
|
|
7
8
|
from collections import defaultdict
|
|
8
9
|
from collections.abc import Callable
|
|
@@ -186,6 +187,24 @@ def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
|
|
|
186
187
|
}
|
|
187
188
|
|
|
188
189
|
|
|
190
|
+
def map_sqlc_error(
|
|
191
|
+
error: str,
|
|
192
|
+
block_starts: list[tuple[int, str]],
|
|
193
|
+
all_locations: dict[str, list[str]],
|
|
194
|
+
) -> str:
|
|
195
|
+
def replace(m: re.Match[str]) -> str:
|
|
196
|
+
line = int(m.group(1))
|
|
197
|
+
name = next((n for start, n in reversed(block_starts) if start <= line), None)
|
|
198
|
+
if name is None:
|
|
199
|
+
return m.group(0)
|
|
200
|
+
locations = all_locations.get(name)
|
|
201
|
+
if not locations:
|
|
202
|
+
return m.group(0)
|
|
203
|
+
return f"{', '.join(locations)}:"
|
|
204
|
+
|
|
205
|
+
return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", replace, error)
|
|
206
|
+
|
|
207
|
+
|
|
189
208
|
def generate_sql_package( # noqa: PLR0913, PLR0914
|
|
190
209
|
*,
|
|
191
210
|
schema_path: Path,
|
|
@@ -200,21 +219,14 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
|
|
|
200
219
|
src_path: Path = Path(),
|
|
201
220
|
tempdir_path: Path | None = None,
|
|
202
221
|
) -> bool:
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
package_name = package_full_name.split(".")[-1] # noqa: PLC0207
|
|
222
|
+
package_name = package_full_name.rsplit(".", maxsplit=1)[-1]
|
|
206
223
|
sql_fn_name = f"{package_name}_sql"
|
|
207
224
|
|
|
208
|
-
|
|
225
|
+
queries, all_locations = collect_queries(src_path, sql_fn_name)
|
|
209
226
|
|
|
210
|
-
|
|
211
|
-
validate_stmt_has_single_row_type(queries)
|
|
212
|
-
queries = list({q.name: q for q in queries}.values())
|
|
227
|
+
dsn, dsn_import_package, dsn_import_path = resolve_dsn(dsn_import)
|
|
213
228
|
|
|
214
|
-
|
|
215
|
-
dsn = eval(dsn_import_path, vars(dsn_package)) # noqa: S307
|
|
216
|
-
|
|
217
|
-
sqlc_res = run_sqlc(
|
|
229
|
+
sqlc_res, block_starts = run_sqlc(
|
|
218
230
|
src_path / schema_path,
|
|
219
231
|
[(q.name, q.stmt) for q in queries],
|
|
220
232
|
dsn=dsn,
|
|
@@ -223,63 +235,13 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
|
|
|
223
235
|
)
|
|
224
236
|
|
|
225
237
|
if sqlc_res.error:
|
|
226
|
-
|
|
238
|
+
mapped = map_sqlc_error(sqlc_res.error, block_starts, all_locations)
|
|
239
|
+
logger.error(f"Error running SQLC:\n{mapped}")
|
|
227
240
|
return False
|
|
228
241
|
|
|
229
|
-
json_import_block =
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
if json_model_overrides:
|
|
233
|
-
json_compatible_types = {"json", "jsonb", "text", "varchar"}
|
|
234
|
-
col_types = {
|
|
235
|
-
(table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
|
|
236
|
-
for schema in sqlc_res.catalog.schemas
|
|
237
|
-
for table in schema.tables
|
|
238
|
-
for column in table.columns
|
|
239
|
-
}
|
|
240
|
-
tables = {table for table, _ in col_types}
|
|
241
|
-
|
|
242
|
-
parsed: dict[tuple[str, str], tuple[str, str]] = {}
|
|
243
|
-
for key, import_path in json_model_overrides.items():
|
|
244
|
-
table_name, sep, col_name = key.partition(".")
|
|
245
|
-
if not sep:
|
|
246
|
-
msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
|
|
247
|
-
raise ValueError(msg)
|
|
248
|
-
if table_name not in tables:
|
|
249
|
-
msg = f"json_model_overrides: table {table_name!r} not found in catalog"
|
|
250
|
-
raise ValueError(msg)
|
|
251
|
-
if (table_name, col_name) not in col_types:
|
|
252
|
-
msg = (
|
|
253
|
-
f"json_model_overrides: column {col_name!r} "
|
|
254
|
-
f"not found in table {table_name!r}"
|
|
255
|
-
)
|
|
256
|
-
raise ValueError(msg)
|
|
257
|
-
|
|
258
|
-
db_type = col_types[table_name, col_name]
|
|
259
|
-
if db_type not in json_compatible_types:
|
|
260
|
-
msg = (
|
|
261
|
-
f"json_model_overrides: column "
|
|
262
|
-
f"{table_name}.{col_name} has type "
|
|
263
|
-
f"{db_type!r}, expected one of "
|
|
264
|
-
f"{json_compatible_types}"
|
|
265
|
-
)
|
|
266
|
-
raise ValueError(msg)
|
|
267
|
-
|
|
268
|
-
module_path, sep, class_name = import_path.partition(":")
|
|
269
|
-
if not sep:
|
|
270
|
-
msg = (
|
|
271
|
-
"json_model_overrides value must be "
|
|
272
|
-
f"'module:Class', got: {import_path!r}"
|
|
273
|
-
)
|
|
274
|
-
raise ValueError(msg)
|
|
275
|
-
|
|
276
|
-
parsed[table_name, col_name] = (module_path, class_name)
|
|
277
|
-
|
|
278
|
-
modules = sorted({module for module, _ in parsed.values()})
|
|
279
|
-
json_import_block = "\n" + "\n".join(f"import {m}" for m in modules)
|
|
280
|
-
json_col_overrides = {
|
|
281
|
-
key: f"{module}.{cls}" for key, (module, cls) in parsed.items()
|
|
282
|
-
}
|
|
242
|
+
json_import_block, json_col_overrides = resolve_json_model_overrides(
|
|
243
|
+
json_model_overrides or {}, sqlc_res.catalog
|
|
244
|
+
)
|
|
283
245
|
|
|
284
246
|
resolver = TypeResolver(
|
|
285
247
|
catalog=sqlc_res.catalog,
|
|
@@ -297,18 +259,79 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
|
|
|
297
259
|
resolver,
|
|
298
260
|
)
|
|
299
261
|
|
|
300
|
-
entities =
|
|
262
|
+
entities = sorted(render_entity(e.name, e.column_specs) for e in ordered_entities)
|
|
301
263
|
|
|
302
264
|
used_enums = collect_used_enums(sqlc_res)
|
|
303
265
|
|
|
304
|
-
enums =
|
|
266
|
+
enums = sorted(
|
|
305
267
|
render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
|
|
306
268
|
for schema in sqlc_res.catalog.schemas
|
|
307
269
|
for e in schema.enums
|
|
308
270
|
if (schema.name, e.name) in used_enums
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
query_classes = render_query_classes(
|
|
274
|
+
sqlc_res.queries, queries, resolver, result_types, all_locations
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
query_overloads = [
|
|
278
|
+
render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
|
|
309
279
|
]
|
|
310
280
|
|
|
311
|
-
|
|
281
|
+
query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
|
|
282
|
+
|
|
283
|
+
target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
|
|
284
|
+
|
|
285
|
+
new_content = render_package(
|
|
286
|
+
dsn_import_package,
|
|
287
|
+
dsn_import_path,
|
|
288
|
+
package_name,
|
|
289
|
+
sql_fn_name,
|
|
290
|
+
entities,
|
|
291
|
+
enums,
|
|
292
|
+
query_classes,
|
|
293
|
+
query_overloads,
|
|
294
|
+
query_dict_entries,
|
|
295
|
+
application_name,
|
|
296
|
+
json_import_block,
|
|
297
|
+
)
|
|
298
|
+
changed = write_if_changed(target_package_path, new_content + "\n")
|
|
299
|
+
if changed:
|
|
300
|
+
logger.info(f"Generated SQL package {package_full_name}")
|
|
301
|
+
return changed
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def collect_queries(
|
|
305
|
+
src_path: Path, sql_fn_name: str
|
|
306
|
+
) -> tuple[list["CodeQuery"], defaultdict[str, list[str]]]:
|
|
307
|
+
raw = list(find_all_queries(src_path, sql_fn_name))
|
|
308
|
+
validate_stmt_has_single_row_type(raw)
|
|
309
|
+
all_locations: defaultdict[str, list[str]] = defaultdict(list)
|
|
310
|
+
first_occurrence: dict[str, CodeQuery] = {}
|
|
311
|
+
for q in raw:
|
|
312
|
+
all_locations[q.name].append(q.location)
|
|
313
|
+
if q.name not in first_occurrence:
|
|
314
|
+
first_occurrence[q.name] = q
|
|
315
|
+
queries = sorted(first_occurrence.values(), key=lambda q: (q.file, q.lineno))
|
|
316
|
+
return queries, all_locations
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def resolve_dsn(dsn_import: str) -> tuple[str, str, str]:
|
|
320
|
+
package_name, attr_path = dsn_import.split(":")
|
|
321
|
+
mod = importlib.import_module(package_name)
|
|
322
|
+
dsn: str = eval(attr_path, vars(mod)) # noqa: S307
|
|
323
|
+
return dsn, package_name, attr_path
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def render_query_classes(
|
|
327
|
+
sqlc_queries: tuple[Query, ...],
|
|
328
|
+
queries: list["CodeQuery"],
|
|
329
|
+
resolver: TypeResolver,
|
|
330
|
+
result_types: dict[str, str],
|
|
331
|
+
all_locations: defaultdict[str, list[str]],
|
|
332
|
+
) -> list[str]:
|
|
333
|
+
query_order = {q.name: i for i, q in enumerate(queries)}
|
|
334
|
+
return [
|
|
312
335
|
render_query_class(
|
|
313
336
|
q.name,
|
|
314
337
|
q.text,
|
|
@@ -327,33 +350,67 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
|
|
|
327
350
|
if len(q.columns) == 1
|
|
328
351
|
else None
|
|
329
352
|
),
|
|
353
|
+
all_locations[q.name],
|
|
330
354
|
)
|
|
331
|
-
for q in
|
|
355
|
+
for q in sorted(sqlc_queries, key=lambda q: query_order[q.name])
|
|
332
356
|
]
|
|
333
357
|
|
|
334
|
-
query_overloads = [
|
|
335
|
-
render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
|
|
336
|
-
]
|
|
337
358
|
|
|
338
|
-
|
|
359
|
+
def resolve_json_model_overrides(
|
|
360
|
+
overrides: dict[str, str], catalog: Catalog
|
|
361
|
+
) -> tuple[str, dict[tuple[str, str], str]]:
|
|
362
|
+
if not overrides:
|
|
363
|
+
return "", {}
|
|
339
364
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
365
|
+
json_compatible_types = {"json", "jsonb", "text", "varchar"}
|
|
366
|
+
col_types = {
|
|
367
|
+
(table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
|
|
368
|
+
for schema in catalog.schemas
|
|
369
|
+
for table in schema.tables
|
|
370
|
+
for column in table.columns
|
|
371
|
+
}
|
|
372
|
+
tables = {table for table, _ in col_types}
|
|
373
|
+
|
|
374
|
+
parsed: dict[tuple[str, str], tuple[str, str]] = {}
|
|
375
|
+
for key, import_path in overrides.items():
|
|
376
|
+
table_name, sep, col_name = key.partition(".")
|
|
377
|
+
if not sep:
|
|
378
|
+
msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
|
|
379
|
+
raise ValueError(msg)
|
|
380
|
+
if table_name not in tables:
|
|
381
|
+
msg = f"json_model_overrides: table {table_name!r} not found in catalog"
|
|
382
|
+
raise ValueError(msg)
|
|
383
|
+
if (table_name, col_name) not in col_types:
|
|
384
|
+
msg = (
|
|
385
|
+
f"json_model_overrides: column {col_name!r} "
|
|
386
|
+
f"not found in table {table_name!r}"
|
|
387
|
+
)
|
|
388
|
+
raise ValueError(msg)
|
|
389
|
+
|
|
390
|
+
db_type = col_types[table_name, col_name]
|
|
391
|
+
if db_type not in json_compatible_types:
|
|
392
|
+
msg = (
|
|
393
|
+
f"json_model_overrides: column "
|
|
394
|
+
f"{table_name}.{col_name} has type "
|
|
395
|
+
f"{db_type!r}, expected one of "
|
|
396
|
+
f"{json_compatible_types}"
|
|
397
|
+
)
|
|
398
|
+
raise ValueError(msg)
|
|
399
|
+
|
|
400
|
+
module_path, sep, class_name = import_path.partition(":")
|
|
401
|
+
if not sep:
|
|
402
|
+
msg = (
|
|
403
|
+
"json_model_overrides value must be "
|
|
404
|
+
f"'module:Class', got: {import_path!r}"
|
|
405
|
+
)
|
|
406
|
+
raise ValueError(msg)
|
|
407
|
+
|
|
408
|
+
parsed[table_name, col_name] = (module_path, class_name)
|
|
409
|
+
|
|
410
|
+
modules = sorted({module for module, _ in parsed.values()})
|
|
411
|
+
import_block = "\n" + "\n".join(f"import {m}" for m in modules)
|
|
412
|
+
col_overrides = {key: f"{module}.{cls}" for key, (module, cls) in parsed.items()}
|
|
413
|
+
return import_block, col_overrides
|
|
357
414
|
|
|
358
415
|
|
|
359
416
|
def render_package( # noqa: PLR0913, PLR0917
|
|
@@ -562,7 +619,8 @@ def render_query_class(
|
|
|
562
619
|
query_params: list[ParamSpec],
|
|
563
620
|
result: str,
|
|
564
621
|
columns_num: int,
|
|
565
|
-
scalar_json_type: str | None
|
|
622
|
+
scalar_json_type: str | None,
|
|
623
|
+
locations: list[str],
|
|
566
624
|
) -> str:
|
|
567
625
|
query_params = deduplicate_params(query_params)
|
|
568
626
|
|
|
@@ -631,6 +689,7 @@ async def execute({", ".join(query_fn_params)}) -> None:
|
|
|
631
689
|
return f"""
|
|
632
690
|
|
|
633
691
|
class {query_name}(Query[{result}]):
|
|
692
|
+
# See: {", ".join(locations)}
|
|
634
693
|
_stmt = psycopg.sql.SQL({stmt!r})
|
|
635
694
|
_row_factory = staticmethod({row_factory})
|
|
636
695
|
|
|
@@ -767,7 +826,12 @@ def find_fn_calls(
|
|
|
767
826
|
content = path.read_text(encoding="utf-8")
|
|
768
827
|
if fn_name not in content:
|
|
769
828
|
continue
|
|
770
|
-
|
|
829
|
+
try:
|
|
830
|
+
tree = ast.parse(content, filename=str(path))
|
|
831
|
+
except SyntaxError as exc:
|
|
832
|
+
msg = f"Failed to parse {path}: {exc.msg} (line {exc.lineno})"
|
|
833
|
+
raise SyntaxError(msg) from exc
|
|
834
|
+
for node in ast.walk(tree):
|
|
771
835
|
match node:
|
|
772
836
|
case ast.Call(func=ast.Name(id=id)) if id == fn_name:
|
|
773
837
|
yield path, node.lineno, node
|
|
@@ -817,11 +881,15 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
|
|
|
817
881
|
|
|
818
882
|
|
|
819
883
|
def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None:
|
|
820
|
-
|
|
884
|
+
first_by_stmt: dict[str, CodeQuery] = {}
|
|
821
885
|
for query in queries:
|
|
822
|
-
if query.stmt in
|
|
823
|
-
|
|
824
|
-
|
|
886
|
+
if query.stmt in first_by_stmt:
|
|
887
|
+
first = first_by_stmt[query.stmt]
|
|
888
|
+
if query.row_type != first.row_type:
|
|
889
|
+
msg = (
|
|
890
|
+
f"row_type conflict: {first.location} has {first.row_type!r},"
|
|
891
|
+
f" {query.location} has {query.row_type!r}"
|
|
892
|
+
)
|
|
825
893
|
raise ValueError(msg)
|
|
826
894
|
else:
|
|
827
|
-
|
|
895
|
+
first_by_stmt[query.stmt] = query
|
|
@@ -139,7 +139,7 @@ def run_sqlc(
|
|
|
139
139
|
dsn: str | None,
|
|
140
140
|
debug_path: Path | None = None,
|
|
141
141
|
tempdir_path: Path | None = None,
|
|
142
|
-
) -> SQLCResult:
|
|
142
|
+
) -> tuple[SQLCResult, list[tuple[int, str]]]:
|
|
143
143
|
if not schema_path.exists():
|
|
144
144
|
msg = f"Schema file not found: {schema_path}"
|
|
145
145
|
raise ValueError(msg)
|
|
@@ -148,7 +148,7 @@ def run_sqlc(
|
|
|
148
148
|
return SQLCResult(
|
|
149
149
|
catalog=Catalog(default_schema="", name="", schemas=()),
|
|
150
150
|
queries=(),
|
|
151
|
-
)
|
|
151
|
+
), []
|
|
152
152
|
|
|
153
153
|
queries = list({q[0]: q for q in queries}.values())
|
|
154
154
|
|
|
@@ -156,13 +156,15 @@ def run_sqlc(
|
|
|
156
156
|
dir=str(tempdir_path) if tempdir_path else None
|
|
157
157
|
) as tempdir:
|
|
158
158
|
queries_path = Path(tempdir) / "queries.sql"
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
|
|
159
|
+
block_starts: list[tuple[int, str]] = []
|
|
160
|
+
blocks: list[str] = []
|
|
161
|
+
current_line = 1
|
|
162
|
+
for name, stmt in queries:
|
|
163
|
+
block = f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
|
|
164
|
+
block_starts.append((current_line, name))
|
|
165
|
+
current_line += block.count("\n") + 2
|
|
166
|
+
blocks.append(block)
|
|
167
|
+
queries_path.write_text("\n\n".join(blocks), encoding="utf-8")
|
|
166
168
|
|
|
167
169
|
(Path(tempdir) / "schema.sql").symlink_to(schema_path.absolute())
|
|
168
170
|
|
|
@@ -206,8 +208,10 @@ def run_sqlc(
|
|
|
206
208
|
error=sqlc_run_result.stderr.decode().strip(),
|
|
207
209
|
catalog=Catalog(default_schema="", name="", schemas=()),
|
|
208
210
|
queries=(),
|
|
209
|
-
)
|
|
210
|
-
return SQLCResult.model_validate_json(
|
|
211
|
+
), block_starts
|
|
212
|
+
return SQLCResult.model_validate_json(
|
|
213
|
+
json_out_path.read_text(encoding="utf-8")
|
|
214
|
+
), block_starts
|
|
211
215
|
|
|
212
216
|
|
|
213
217
|
def preprocess_sql(stmt: str) -> str:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|