iron-sql 0.4.3__tar.gz → 0.4.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: iron-sql
3
- Version: 0.4.3
3
+ Version: 0.4.4
4
4
  Summary: iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries
5
5
  Keywords: postgresql,sql,sqlc,psycopg,codegen,async
6
6
  Author: Ilia Ablamonov
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "iron-sql"
3
- version = "0.4.3"
3
+ version = "0.4.4"
4
4
 
5
5
  description = "iron_sql generates typed async PostgreSQL clients and runtime helpers from schemas and SQL queries"
6
6
  readme = "README.md"
@@ -3,6 +3,7 @@ import dataclasses
3
3
  import hashlib
4
4
  import importlib
5
5
  import logging
6
+ import re
6
7
  import warnings
7
8
  from collections import defaultdict
8
9
  from collections.abc import Callable
@@ -186,6 +187,24 @@ def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
186
187
  }
187
188
 
188
189
 
190
+ def map_sqlc_error(
191
+ error: str,
192
+ block_starts: list[tuple[int, str]],
193
+ all_locations: dict[str, list[str]],
194
+ ) -> str:
195
+ def replace(m: re.Match[str]) -> str:
196
+ line = int(m.group(1))
197
+ name = next((n for start, n in reversed(block_starts) if start <= line), None)
198
+ if name is None:
199
+ return m.group(0)
200
+ locations = all_locations.get(name)
201
+ if not locations:
202
+ return m.group(0)
203
+ return f"{', '.join(locations)}:"
204
+
205
+ return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", replace, error)
206
+
207
+
189
208
  def generate_sql_package( # noqa: PLR0913, PLR0914
190
209
  *,
191
210
  schema_path: Path,
@@ -200,21 +219,14 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
200
219
  src_path: Path = Path(),
201
220
  tempdir_path: Path | None = None,
202
221
  ) -> bool:
203
- dsn_import_package, dsn_import_path = dsn_import.split(":")
204
-
205
- package_name = package_full_name.split(".")[-1] # noqa: PLC0207
222
+ package_name = package_full_name.rsplit(".", maxsplit=1)[-1]
206
223
  sql_fn_name = f"{package_name}_sql"
207
224
 
208
- target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
225
+ queries, all_locations = collect_queries(src_path, sql_fn_name)
209
226
 
210
- queries = list(find_all_queries(src_path, sql_fn_name))
211
- validate_stmt_has_single_row_type(queries)
212
- queries = list({q.name: q for q in queries}.values())
227
+ dsn, dsn_import_package, dsn_import_path = resolve_dsn(dsn_import)
213
228
 
214
- dsn_package = importlib.import_module(dsn_import_package)
215
- dsn = eval(dsn_import_path, vars(dsn_package)) # noqa: S307
216
-
217
- sqlc_res = run_sqlc(
229
+ sqlc_res, block_starts = run_sqlc(
218
230
  src_path / schema_path,
219
231
  [(q.name, q.stmt) for q in queries],
220
232
  dsn=dsn,
@@ -223,63 +235,13 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
223
235
  )
224
236
 
225
237
  if sqlc_res.error:
226
- logger.error("Error running SQLC:\n%s", sqlc_res.error)
238
+ mapped = map_sqlc_error(sqlc_res.error, block_starts, all_locations)
239
+ logger.error(f"Error running SQLC:\n{mapped}")
227
240
  return False
228
241
 
229
- json_import_block = ""
230
- json_col_overrides: dict[tuple[str, str], str] = {}
231
-
232
- if json_model_overrides:
233
- json_compatible_types = {"json", "jsonb", "text", "varchar"}
234
- col_types = {
235
- (table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
236
- for schema in sqlc_res.catalog.schemas
237
- for table in schema.tables
238
- for column in table.columns
239
- }
240
- tables = {table for table, _ in col_types}
241
-
242
- parsed: dict[tuple[str, str], tuple[str, str]] = {}
243
- for key, import_path in json_model_overrides.items():
244
- table_name, sep, col_name = key.partition(".")
245
- if not sep:
246
- msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
247
- raise ValueError(msg)
248
- if table_name not in tables:
249
- msg = f"json_model_overrides: table {table_name!r} not found in catalog"
250
- raise ValueError(msg)
251
- if (table_name, col_name) not in col_types:
252
- msg = (
253
- f"json_model_overrides: column {col_name!r} "
254
- f"not found in table {table_name!r}"
255
- )
256
- raise ValueError(msg)
257
-
258
- db_type = col_types[table_name, col_name]
259
- if db_type not in json_compatible_types:
260
- msg = (
261
- f"json_model_overrides: column "
262
- f"{table_name}.{col_name} has type "
263
- f"{db_type!r}, expected one of "
264
- f"{json_compatible_types}"
265
- )
266
- raise ValueError(msg)
267
-
268
- module_path, sep, class_name = import_path.partition(":")
269
- if not sep:
270
- msg = (
271
- "json_model_overrides value must be "
272
- f"'module:Class', got: {import_path!r}"
273
- )
274
- raise ValueError(msg)
275
-
276
- parsed[table_name, col_name] = (module_path, class_name)
277
-
278
- modules = sorted({module for module, _ in parsed.values()})
279
- json_import_block = "\n" + "\n".join(f"import {m}" for m in modules)
280
- json_col_overrides = {
281
- key: f"{module}.{cls}" for key, (module, cls) in parsed.items()
282
- }
242
+ json_import_block, json_col_overrides = resolve_json_model_overrides(
243
+ json_model_overrides or {}, sqlc_res.catalog
244
+ )
283
245
 
284
246
  resolver = TypeResolver(
285
247
  catalog=sqlc_res.catalog,
@@ -297,18 +259,79 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
297
259
  resolver,
298
260
  )
299
261
 
300
- entities = [render_entity(e.name, e.column_specs) for e in ordered_entities]
262
+ entities = sorted(render_entity(e.name, e.column_specs) for e in ordered_entities)
301
263
 
302
264
  used_enums = collect_used_enums(sqlc_res)
303
265
 
304
- enums = [
266
+ enums = sorted(
305
267
  render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
306
268
  for schema in sqlc_res.catalog.schemas
307
269
  for e in schema.enums
308
270
  if (schema.name, e.name) in used_enums
271
+ )
272
+
273
+ query_classes = render_query_classes(
274
+ sqlc_res.queries, queries, resolver, result_types, all_locations
275
+ )
276
+
277
+ query_overloads = [
278
+ render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
309
279
  ]
310
280
 
311
- query_classes = [
281
+ query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
282
+
283
+ target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
284
+
285
+ new_content = render_package(
286
+ dsn_import_package,
287
+ dsn_import_path,
288
+ package_name,
289
+ sql_fn_name,
290
+ entities,
291
+ enums,
292
+ query_classes,
293
+ query_overloads,
294
+ query_dict_entries,
295
+ application_name,
296
+ json_import_block,
297
+ )
298
+ changed = write_if_changed(target_package_path, new_content + "\n")
299
+ if changed:
300
+ logger.info(f"Generated SQL package {package_full_name}")
301
+ return changed
302
+
303
+
304
+ def collect_queries(
305
+ src_path: Path, sql_fn_name: str
306
+ ) -> tuple[list["CodeQuery"], defaultdict[str, list[str]]]:
307
+ raw = list(find_all_queries(src_path, sql_fn_name))
308
+ validate_stmt_has_single_row_type(raw)
309
+ all_locations: defaultdict[str, list[str]] = defaultdict(list)
310
+ first_occurrence: dict[str, CodeQuery] = {}
311
+ for q in raw:
312
+ all_locations[q.name].append(q.location)
313
+ if q.name not in first_occurrence:
314
+ first_occurrence[q.name] = q
315
+ queries = sorted(first_occurrence.values(), key=lambda q: (q.file, q.lineno))
316
+ return queries, all_locations
317
+
318
+
319
+ def resolve_dsn(dsn_import: str) -> tuple[str, str, str]:
320
+ package_name, attr_path = dsn_import.split(":")
321
+ mod = importlib.import_module(package_name)
322
+ dsn: str = eval(attr_path, vars(mod)) # noqa: S307
323
+ return dsn, package_name, attr_path
324
+
325
+
326
+ def render_query_classes(
327
+ sqlc_queries: tuple[Query, ...],
328
+ queries: list["CodeQuery"],
329
+ resolver: TypeResolver,
330
+ result_types: dict[str, str],
331
+ all_locations: defaultdict[str, list[str]],
332
+ ) -> list[str]:
333
+ query_order = {q.name: i for i, q in enumerate(queries)}
334
+ return [
312
335
  render_query_class(
313
336
  q.name,
314
337
  q.text,
@@ -327,33 +350,67 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
327
350
  if len(q.columns) == 1
328
351
  else None
329
352
  ),
353
+ all_locations[q.name],
330
354
  )
331
- for q in sqlc_res.queries
355
+ for q in sorted(sqlc_queries, key=lambda q: query_order[q.name])
332
356
  ]
333
357
 
334
- query_overloads = [
335
- render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
336
- ]
337
358
 
338
- query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
359
+ def resolve_json_model_overrides(
360
+ overrides: dict[str, str], catalog: Catalog
361
+ ) -> tuple[str, dict[tuple[str, str], str]]:
362
+ if not overrides:
363
+ return "", {}
339
364
 
340
- new_content = render_package(
341
- dsn_import_package,
342
- dsn_import_path,
343
- package_name,
344
- sql_fn_name,
345
- sorted(entities),
346
- sorted(enums),
347
- sorted(query_classes),
348
- sorted(query_overloads),
349
- sorted(query_dict_entries),
350
- application_name,
351
- json_import_block,
352
- )
353
- changed = write_if_changed(target_package_path, new_content + "\n")
354
- if changed:
355
- logger.info(f"Generated SQL package {package_full_name}")
356
- return changed
365
+ json_compatible_types = {"json", "jsonb", "text", "varchar"}
366
+ col_types = {
367
+ (table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
368
+ for schema in catalog.schemas
369
+ for table in schema.tables
370
+ for column in table.columns
371
+ }
372
+ tables = {table for table, _ in col_types}
373
+
374
+ parsed: dict[tuple[str, str], tuple[str, str]] = {}
375
+ for key, import_path in overrides.items():
376
+ table_name, sep, col_name = key.partition(".")
377
+ if not sep:
378
+ msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
379
+ raise ValueError(msg)
380
+ if table_name not in tables:
381
+ msg = f"json_model_overrides: table {table_name!r} not found in catalog"
382
+ raise ValueError(msg)
383
+ if (table_name, col_name) not in col_types:
384
+ msg = (
385
+ f"json_model_overrides: column {col_name!r} "
386
+ f"not found in table {table_name!r}"
387
+ )
388
+ raise ValueError(msg)
389
+
390
+ db_type = col_types[table_name, col_name]
391
+ if db_type not in json_compatible_types:
392
+ msg = (
393
+ f"json_model_overrides: column "
394
+ f"{table_name}.{col_name} has type "
395
+ f"{db_type!r}, expected one of "
396
+ f"{json_compatible_types}"
397
+ )
398
+ raise ValueError(msg)
399
+
400
+ module_path, sep, class_name = import_path.partition(":")
401
+ if not sep:
402
+ msg = (
403
+ "json_model_overrides value must be "
404
+ f"'module:Class', got: {import_path!r}"
405
+ )
406
+ raise ValueError(msg)
407
+
408
+ parsed[table_name, col_name] = (module_path, class_name)
409
+
410
+ modules = sorted({module for module, _ in parsed.values()})
411
+ import_block = "\n" + "\n".join(f"import {m}" for m in modules)
412
+ col_overrides = {key: f"{module}.{cls}" for key, (module, cls) in parsed.items()}
413
+ return import_block, col_overrides
357
414
 
358
415
 
359
416
  def render_package( # noqa: PLR0913, PLR0917
@@ -562,7 +619,8 @@ def render_query_class(
562
619
  query_params: list[ParamSpec],
563
620
  result: str,
564
621
  columns_num: int,
565
- scalar_json_type: str | None = None,
622
+ scalar_json_type: str | None,
623
+ locations: list[str],
566
624
  ) -> str:
567
625
  query_params = deduplicate_params(query_params)
568
626
 
@@ -631,6 +689,7 @@ async def execute({", ".join(query_fn_params)}) -> None:
631
689
  return f"""
632
690
 
633
691
  class {query_name}(Query[{result}]):
692
+ # See: {", ".join(locations)}
634
693
  _stmt = psycopg.sql.SQL({stmt!r})
635
694
  _row_factory = staticmethod({row_factory})
636
695
 
@@ -767,7 +826,12 @@ def find_fn_calls(
767
826
  content = path.read_text(encoding="utf-8")
768
827
  if fn_name not in content:
769
828
  continue
770
- for node in ast.walk(ast.parse(content, filename=str(path))):
829
+ try:
830
+ tree = ast.parse(content, filename=str(path))
831
+ except SyntaxError as exc:
832
+ msg = f"Failed to parse {path}: {exc.msg} (line {exc.lineno})"
833
+ raise SyntaxError(msg) from exc
834
+ for node in ast.walk(tree):
771
835
  match node:
772
836
  case ast.Call(func=ast.Name(id=id)) if id == fn_name:
773
837
  yield path, node.lineno, node
@@ -817,11 +881,15 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
817
881
 
818
882
 
819
883
  def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None:
820
- row_type_by_stmt: dict[str, str | None] = {}
884
+ first_by_stmt: dict[str, CodeQuery] = {}
821
885
  for query in queries:
822
- if query.stmt in row_type_by_stmt:
823
- if query.row_type != row_type_by_stmt[query.stmt]:
824
- msg = f"row_type conflict (existing={row_type_by_stmt[query.stmt]!r})"
886
+ if query.stmt in first_by_stmt:
887
+ first = first_by_stmt[query.stmt]
888
+ if query.row_type != first.row_type:
889
+ msg = (
890
+ f"row_type conflict: {first.location} has {first.row_type!r},"
891
+ f" {query.location} has {query.row_type!r}"
892
+ )
825
893
  raise ValueError(msg)
826
894
  else:
827
- row_type_by_stmt[query.stmt] = query.row_type
895
+ first_by_stmt[query.stmt] = query
@@ -139,7 +139,7 @@ def run_sqlc(
139
139
  dsn: str | None,
140
140
  debug_path: Path | None = None,
141
141
  tempdir_path: Path | None = None,
142
- ) -> SQLCResult:
142
+ ) -> tuple[SQLCResult, list[tuple[int, str]]]:
143
143
  if not schema_path.exists():
144
144
  msg = f"Schema file not found: {schema_path}"
145
145
  raise ValueError(msg)
@@ -148,7 +148,7 @@ def run_sqlc(
148
148
  return SQLCResult(
149
149
  catalog=Catalog(default_schema="", name="", schemas=()),
150
150
  queries=(),
151
- )
151
+ ), []
152
152
 
153
153
  queries = list({q[0]: q for q in queries}.values())
154
154
 
@@ -156,13 +156,15 @@ def run_sqlc(
156
156
  dir=str(tempdir_path) if tempdir_path else None
157
157
  ) as tempdir:
158
158
  queries_path = Path(tempdir) / "queries.sql"
159
- queries_path.write_text(
160
- "\n\n".join(
161
- f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
162
- for name, stmt in queries
163
- ),
164
- encoding="utf-8",
165
- )
159
+ block_starts: list[tuple[int, str]] = []
160
+ blocks: list[str] = []
161
+ current_line = 1
162
+ for name, stmt in queries:
163
+ block = f"-- name: {name} :exec\n{preprocess_sql(stmt)};"
164
+ block_starts.append((current_line, name))
165
+ current_line += block.count("\n") + 2
166
+ blocks.append(block)
167
+ queries_path.write_text("\n\n".join(blocks), encoding="utf-8")
166
168
 
167
169
  (Path(tempdir) / "schema.sql").symlink_to(schema_path.absolute())
168
170
 
@@ -206,8 +208,10 @@ def run_sqlc(
206
208
  error=sqlc_run_result.stderr.decode().strip(),
207
209
  catalog=Catalog(default_schema="", name="", schemas=()),
208
210
  queries=(),
209
- )
210
- return SQLCResult.model_validate_json(json_out_path.read_text(encoding="utf-8"))
211
+ ), block_starts
212
+ return SQLCResult.model_validate_json(
213
+ json_out_path.read_text(encoding="utf-8")
214
+ ), block_starts
211
215
 
212
216
 
213
217
  def preprocess_sql(stmt: str) -> str:
File without changes
File without changes
File without changes