zeroshot-sql-decorators 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zeroshot_sql_decorators-0.1.5/PKG-INFO +14 -0
- zeroshot_sql_decorators-0.1.5/README.md +3 -0
- zeroshot_sql_decorators-0.1.5/pyproject.toml +19 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/__init__.py +39 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/decorators.py +268 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/param_mapper.py +167 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/py.typed +1 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/query.py +54 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/result_mapper.py +122 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/stream_iterator.py +52 -0
- zeroshot_sql_decorators-0.1.5/src/zeroshot_sql_decorators/types.py +56 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: zeroshot-sql-decorators
|
|
3
|
+
Version: 0.1.5
|
|
4
|
+
Summary: SQL-focused decorators and helpers for Zeroshot Python packages.
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Requires-Dist: zeroshot-commons==0.1.5
|
|
7
|
+
Requires-Dist: sqlalchemy[asyncio]>=2.0
|
|
8
|
+
Requires-Dist: asyncpg>=0.30
|
|
9
|
+
Requires-Python: >=3.12
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# zeroshot-sql-decorators
|
|
13
|
+
|
|
14
|
+
SQL-focused decorators and helpers for Zeroshot Python packages.
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "zeroshot-sql-decorators"
|
|
3
|
+
version = "0.1.5"
|
|
4
|
+
description = "SQL-focused decorators and helpers for Zeroshot Python packages."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
license = "MIT"
|
|
8
|
+
dependencies = [
|
|
9
|
+
"zeroshot-commons==0.1.5",
|
|
10
|
+
"sqlalchemy[asyncio]>=2.0",
|
|
11
|
+
"asyncpg>=0.30",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[tool.uv.sources]
|
|
15
|
+
zeroshot-commons = { workspace = true }
|
|
16
|
+
|
|
17
|
+
[build-system]
|
|
18
|
+
requires = ["uv_build>=0.11.8,<0.12"]
|
|
19
|
+
build-backend = "uv_build"
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""SQL-focused decorators and helpers for Zeroshot Python packages."""
|
|
2
|
+
|
|
3
|
+
from .decorators import (
|
|
4
|
+
DaoBase,
|
|
5
|
+
TransactionalityBase,
|
|
6
|
+
dao,
|
|
7
|
+
sql_query,
|
|
8
|
+
sql_transaction,
|
|
9
|
+
stream_select,
|
|
10
|
+
with_transactionality,
|
|
11
|
+
)
|
|
12
|
+
from .stream_iterator import StreamIterator
|
|
13
|
+
from .types import (
|
|
14
|
+
ArrayResult,
|
|
15
|
+
BooleanResult,
|
|
16
|
+
NumberResult,
|
|
17
|
+
QueryOptions,
|
|
18
|
+
QueryType,
|
|
19
|
+
StreamSelectOptions,
|
|
20
|
+
StringResult,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"ArrayResult",
|
|
25
|
+
"BooleanResult",
|
|
26
|
+
"DaoBase",
|
|
27
|
+
"NumberResult",
|
|
28
|
+
"QueryOptions",
|
|
29
|
+
"QueryType",
|
|
30
|
+
"StreamIterator",
|
|
31
|
+
"StreamSelectOptions",
|
|
32
|
+
"StringResult",
|
|
33
|
+
"TransactionalityBase",
|
|
34
|
+
"dao",
|
|
35
|
+
"sql_query",
|
|
36
|
+
"sql_transaction",
|
|
37
|
+
"stream_select",
|
|
38
|
+
"with_transactionality",
|
|
39
|
+
]
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import inspect
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from sqlalchemy import text
|
|
10
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
11
|
+
|
|
12
|
+
from .param_mapper import build_replacements, expand_in_clauses, extract_param_names
|
|
13
|
+
from .query import load_query
|
|
14
|
+
from .result_mapper import map_result
|
|
15
|
+
from .stream_iterator import StreamIterator
|
|
16
|
+
from .types import QueryOptions, StreamSelectOptions
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# ---------------------------------------------------------------------------
|
|
22
|
+
# Base classes
|
|
23
|
+
# ---------------------------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DaoBase:
    """Base class for Data Access Objects.

    Holds a reference to the SQLAlchemy ``AsyncEngine``.  The ``@sql_query``
    and ``@stream_select`` method decorators read ``self._engine`` at call
    time, so every decorated DAO must inherit from this class (or otherwise
    expose a compatible ``_engine`` attribute).
    """

    def __init__(self, engine: AsyncEngine) -> None:
        # Stored (not copied) so all DAOs share the application's engine/pool.
        self._engine = engine
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TransactionalityBase:
    """Base class for repositories that coordinate transactions across DAOs.

    Like :class:`DaoBase`, it only stores the engine; ``@sql_transaction``
    reads ``self._engine`` to open new sessions.
    """

    def __init__(self, engine: AsyncEngine) -> None:
        # Same engine the coordinated DAOs use, so sessions share one pool.
        self._engine = engine
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# ---------------------------------------------------------------------------
|
|
44
|
+
# Class decorators
|
|
45
|
+
# ---------------------------------------------------------------------------
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def dao(*, query_directory: Path | str | None = None) -> Any:
    """Mark a class as a DAO with an optional query file directory."""

    # Normalise once at decoration-definition time; the result is shared by
    # every class this particular decorator instance is applied to.
    resolved_dir = Path(query_directory) if query_directory else None

    def decorator(cls: type) -> type:
        # Stash the directory on the class; @sql_query / @stream_select
        # read it back via getattr(self.__class__, "_query_directory", None).
        cls._query_directory = resolved_dir  # type: ignore[attr-defined]
        return cls

    return decorator
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def with_transactionality() -> Any:
    """Mark a class as transaction-aware (mirrors TS ``@WithTransactionality()``).

    Currently a pure no-op marker: the class is returned unchanged.  Kept so
    the Python API stays symmetrical with the TypeScript implementation.
    """

    def decorator(cls: type) -> type:
        return cls

    return decorator
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# ---------------------------------------------------------------------------
|
|
68
|
+
# @sql_query
|
|
69
|
+
# ---------------------------------------------------------------------------
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def sql_query(options: QueryOptions) -> Any:
    """Method decorator that executes a SQL query and maps the result.

    The decorated method becomes a stub — its body is never called.
    Parameters are extracted from the function signature and bound to the
    query.  An optional ``session: AsyncSession | None = None`` parameter
    enables transaction participation: when a session is passed, the query
    runs on it and no commit is issued here (the transaction owner commits).

    Args:
        options: query source (inline/file), statement type, and result
            mapping configuration.
    """

    def decorator(fn: Any) -> Any:
        # Captured once at decoration time; the signature never changes.
        param_names = extract_param_names(fn)

        @functools.wraps(fn)
        async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            engine: AsyncEngine = self._engine

            # Resolve query directory (set by the @dao class decorator, if any).
            query_dir: Path | None = getattr(self.__class__, "_query_directory", None)

            sql = load_query(
                inline_query=options.query,
                file_path=options.file,
                query_directory=query_dir,
                class_name=self.__class__.__name__,
                method_name=fn.__name__,
            )

            # Build all args (positional + keyword), excluding self.
            all_args = _merge_args(fn, args, kwargs)
            replacements = build_replacements(param_names, all_args)

            # Expand IN clauses for list parameters.
            expanded_sql, expanded_params = expand_in_clauses(sql, replacements)

            # Check for an explicit session argument: when present, the query
            # joins the caller's transaction and we must NOT commit here.
            session = _extract_session(fn, args, kwargs)

            if session is not None:
                result = await session.execute(text(expanded_sql), expanded_params)
                rows = _rows_or_empty(result)
            else:
                async with engine.connect() as conn:
                    result = await conn.execute(text(expanded_sql), expanded_params)
                    rows = _rows_or_empty(result)
                    await conn.commit()

            return map_result(rows, options.query_type, options.clazz, options.return_list)

        return wrapper

    return decorator


def _rows_or_empty(result: Any) -> list[Any]:
    """Fetch all rows as dicts; statements that return no rows yield ``[]``.

    Non-row-returning statements (e.g. INSERT/UPDATE without RETURNING) make
    ``.mappings().all()`` raise — previously this try/except was duplicated
    inline in both execution branches; treat it as "no rows" either way.
    """
    try:
        return [dict(r) for r in result.mappings().all()]
    except Exception:
        return []
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# ---------------------------------------------------------------------------
|
|
132
|
+
# @stream_select
|
|
133
|
+
# ---------------------------------------------------------------------------
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def stream_select(options: StreamSelectOptions) -> Any:
    """Method decorator that returns a ``StreamIterator`` for lazy batch iteration."""

    def decorator(fn: Any) -> Any:
        param_names = extract_param_names(fn)

        @functools.wraps(fn)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> StreamIterator[Any]:
            # Resolve the SQL text exactly like @sql_query does.
            sql = load_query(
                inline_query=options.query,
                file_path=options.file,
                query_directory=getattr(self.__class__, "_query_directory", None),
                class_name=self.__class__.__name__,
                method_name=fn.__name__,
            )

            bound = build_replacements(param_names, _merge_args(fn, args, kwargs))

            # Nothing executes yet — the iterator issues batched queries lazily
            # as it is consumed.
            return StreamIterator(
                engine=self._engine,
                sql=sql,
                replacements=bound,
                clazz=options.clazz,
                batch_size=options.batch_size,
            )

        return wrapper

    return decorator
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ---------------------------------------------------------------------------
|
|
172
|
+
# @sql_transaction
|
|
173
|
+
# ---------------------------------------------------------------------------
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def sql_transaction(
    *,
    isolation_level: str | None = None,
) -> Any:
    """Method decorator that wraps the call in a database transaction.

    If a ``session`` kwarg/arg is already provided and has an active
    transaction, the method participates in that transaction.
    Otherwise a new transaction is created, committed on success,
    and rolled back on error.

    Args:
        isolation_level: optional SQLAlchemy isolation level name applied to
            the new transaction's connection (e.g. ``"SERIALIZABLE"``).
    """

    def decorator(fn: Any) -> Any:
        @functools.wraps(fn)
        async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            engine: AsyncEngine = self._engine

            existing_session = _extract_session(fn, args, kwargs)

            if existing_session is not None and existing_session.in_transaction():
                # Already in a transaction — just call through
                return await fn(self, *args, **kwargs)

            if existing_session is not None:
                raise ValueError(
                    "A session was provided but has no active transaction. "
                    "Do not pass a session unless it is already transactional."
                )

            # Create a new session + transaction
            exec_opts: dict[str, Any] = {}
            if isolation_level:
                exec_opts["isolation_level"] = isolation_level

            session_factory = async_sessionmaker(
                engine,
                expire_on_commit=False,
            )
            # begin() commits on normal exit and rolls back if the body raises.
            async with session_factory.begin() as session:
                if exec_opts:
                    # NOTE(review): execution options are applied *after*
                    # begin(); SQLAlchemy normally fixes the isolation level
                    # at connection checkout, so setting it mid-transaction
                    # may be silently ignored — TODO confirm against the
                    # asyncpg dialect.
                    await session.connection(execution_options=exec_opts)
                # Hand the fresh session to the wrapped method explicitly so
                # its own @sql_query calls participate in this transaction.
                new_kwargs = dict(kwargs)
                new_kwargs["session"] = session
                return await fn(self, *args, **new_kwargs)

        return wrapper

    return decorator
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
# ---------------------------------------------------------------------------
|
|
227
|
+
# Helpers
|
|
228
|
+
# ---------------------------------------------------------------------------
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _merge_args(fn: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, ...]:
|
|
232
|
+
"""Merge positional and keyword args into a positional tuple,
|
|
233
|
+
respecting the original function signature order (excluding ``self``)."""
|
|
234
|
+
sig = inspect.signature(fn)
|
|
235
|
+
params = [p for p in sig.parameters.values() if p.name != "self"]
|
|
236
|
+
|
|
237
|
+
merged: list[Any] = list(args)
|
|
238
|
+
|
|
239
|
+
# Fill in remaining params from kwargs
|
|
240
|
+
for i in range(len(args), len(params)):
|
|
241
|
+
p = params[i]
|
|
242
|
+
if p.name in kwargs:
|
|
243
|
+
merged.append(kwargs[p.name])
|
|
244
|
+
elif p.default is not inspect.Parameter.empty:
|
|
245
|
+
merged.append(p.default)
|
|
246
|
+
|
|
247
|
+
return tuple(merged)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _extract_session(
|
|
251
|
+
fn: Any,
|
|
252
|
+
args: tuple[Any, ...],
|
|
253
|
+
kwargs: dict[str, Any],
|
|
254
|
+
) -> AsyncSession | None:
|
|
255
|
+
"""Pull out the ``session`` argument if present."""
|
|
256
|
+
if "session" in kwargs:
|
|
257
|
+
return kwargs["session"]
|
|
258
|
+
|
|
259
|
+
sig = inspect.signature(fn)
|
|
260
|
+
params = [p for p in sig.parameters.values() if p.name != "self"]
|
|
261
|
+
|
|
262
|
+
for i, p in enumerate(params):
|
|
263
|
+
if p.name == "session" and i < len(args):
|
|
264
|
+
val = args[i]
|
|
265
|
+
if isinstance(val, AsyncSession):
|
|
266
|
+
return val
|
|
267
|
+
|
|
268
|
+
return None
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import inspect
|
|
5
|
+
import re
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def extract_param_names(func: Any) -> list[str]:
    """Get parameter names from a function, excluding ``self`` and ``session``.

    Included: any positional parameter, plus keyword-only parameters that
    declare a default.  ``*args`` / ``**kwargs`` catch-alls are excluded.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.POSITIONAL_ONLY,
    )
    names: list[str] = []
    for name, param in inspect.signature(func).parameters.items():
        if name in ("self", "session"):
            continue
        if param.default is not inspect.Parameter.empty or param.kind in positional_kinds:
            names.append(name)
    return names
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _is_complex(value: Any) -> bool:
|
|
27
|
+
"""Return True if *value* is a dataclass instance or has a ``__dict__``
|
|
28
|
+
that is not a basic built-in type."""
|
|
29
|
+
if value is None or isinstance(value, (str, int, float, bool, bytes, list, tuple)):
|
|
30
|
+
return False
|
|
31
|
+
return dataclasses.is_dataclass(value) and not isinstance(value, type)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _extract_fields(obj: Any) -> dict[str, Any]:
|
|
35
|
+
"""Pull fields from a dataclass instance, including property getters."""
|
|
36
|
+
result: dict[str, Any] = {}
|
|
37
|
+
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
|
|
38
|
+
for f in dataclasses.fields(obj):
|
|
39
|
+
result[f.name] = getattr(obj, f.name)
|
|
40
|
+
# Also pick up @property descriptors on the class
|
|
41
|
+
for name in dir(type(obj)):
|
|
42
|
+
if name.startswith("_"):
|
|
43
|
+
continue
|
|
44
|
+
attr = getattr(type(obj), name, None)
|
|
45
|
+
if isinstance(attr, property):
|
|
46
|
+
result[name] = attr.fget(obj) # type: ignore[arg-type]
|
|
47
|
+
return result
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def build_replacements(
    param_names: list[str],
    args: tuple[Any, ...],
) -> dict[str, Any]:
    """Map function arguments to SQL ``:param`` replacements.

    Rules (mirroring the TS implementation):
    * Filter out any ``AsyncSession`` values (transaction arg).
    * Single complex arg → spread its fields as parameters.
    * Multiple args → map each name to its value.
    * ``None`` values stay as ``None`` (SQL ``NULL``).
    """
    # Pair names with values, dropping any session object that slipped in.
    pairs = [
        (name, value)
        for name, value in zip(param_names, args, strict=False)
        if not isinstance(value, AsyncSession)
    ]

    # A lone dataclass argument is flattened into its fields.
    if len(pairs) == 1 and _is_complex(pairs[0][1]):
        return _replace_none(_extract_fields(pairs[0][1]))

    replacements: dict[str, Any] = {}
    for name, value in pairs:
        if _is_complex(value):
            replacements.update(_extract_fields(value))
        else:
            replacements[name] = value

    return _replace_none(replacements)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _replace_none(d: dict[str, Any]) -> dict[str, Any]:
|
|
83
|
+
"""Ensure ``None`` values remain as-is (SQLAlchemy handles NULL binding)."""
|
|
84
|
+
return d
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# ---------------------------------------------------------------------------
|
|
88
|
+
# IN-clause expansion
|
|
89
|
+
# ---------------------------------------------------------------------------
|
|
90
|
+
|
|
91
|
+
_IN_PARAM_RE = re.compile(r":(\w+)")
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _param_is_in_clause(sql: str, key: str) -> bool:
|
|
95
|
+
"""Check if ``:key`` appears inside an ``IN (...)`` context in the SQL."""
|
|
96
|
+
pattern = re.compile(rf"IN\s*\([^)]*:{re.escape(key)}(?!\w)[^)]*\)", re.IGNORECASE)
|
|
97
|
+
return pattern.search(sql) is not None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _param_is_null_check(sql: str, key: str) -> bool:
|
|
101
|
+
"""Check if ``:key IS NULL`` appears in the SQL."""
|
|
102
|
+
pattern = re.compile(rf":{re.escape(key)}\s+IS\s+NULL", re.IGNORECASE)
|
|
103
|
+
return pattern.search(sql) is not None
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def expand_in_clauses(sql: str, replacements: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
|
107
|
+
"""Expand list-valued parameters for ``IN (:param)`` clauses.
|
|
108
|
+
|
|
109
|
+
Given ``IN (:ids)`` and ``ids=[1, 2, 3]``, rewrites to
|
|
110
|
+
``IN (:ids_0, :ids_1, :ids_2)`` with individual entries in replacements.
|
|
111
|
+
|
|
112
|
+
List parameters that appear outside of IN clauses (e.g. PostgreSQL array
|
|
113
|
+
values) are left as-is.
|
|
114
|
+
|
|
115
|
+
Also handles the ``(:param IS NULL OR ... IN (:param))`` pattern.
|
|
116
|
+
"""
|
|
117
|
+
expanded: dict[str, Any] = {}
|
|
118
|
+
new_sql = sql
|
|
119
|
+
|
|
120
|
+
for key, value in list(replacements.items()):
|
|
121
|
+
has_in = _param_is_in_clause(new_sql, key)
|
|
122
|
+
has_is_null = _param_is_null_check(new_sql, key)
|
|
123
|
+
|
|
124
|
+
if value is None and has_is_null:
|
|
125
|
+
# None means "skip this filter"
|
|
126
|
+
_is_null_pat = re.compile(rf":{re.escape(key)}\s+IS\s+NULL", re.IGNORECASE)
|
|
127
|
+
new_sql = _is_null_pat.sub("TRUE", new_sql)
|
|
128
|
+
if has_in:
|
|
129
|
+
_in_pat = re.compile(rf"IN\s*\(\s*:{re.escape(key)}\s*\)", re.IGNORECASE)
|
|
130
|
+
new_sql = _in_pat.sub("IN (NULL)", new_sql)
|
|
131
|
+
continue
|
|
132
|
+
|
|
133
|
+
if not isinstance(value, (list, tuple)):
|
|
134
|
+
expanded[key] = value
|
|
135
|
+
continue
|
|
136
|
+
|
|
137
|
+
# List value but NOT used in an IN clause — it's a PostgreSQL array param
|
|
138
|
+
if not has_in:
|
|
139
|
+
expanded[key] = value
|
|
140
|
+
continue
|
|
141
|
+
|
|
142
|
+
# --- list/tuple in an IN clause ---
|
|
143
|
+
|
|
144
|
+
# Handle ":key IS NULL" patterns
|
|
145
|
+
if has_is_null:
|
|
146
|
+
_is_null_pat = re.compile(rf":{re.escape(key)}\s+IS\s+NULL", re.IGNORECASE)
|
|
147
|
+
new_sql = _is_null_pat.sub("FALSE", new_sql)
|
|
148
|
+
|
|
149
|
+
if len(value) == 0:
|
|
150
|
+
_in_pat = re.compile(rf"IN\s*\(\s*:{re.escape(key)}\s*\)", re.IGNORECASE)
|
|
151
|
+
new_sql = _in_pat.sub("IN (NULL)", new_sql)
|
|
152
|
+
continue
|
|
153
|
+
|
|
154
|
+
# Build individual param names
|
|
155
|
+
param_names = [f"{key}_{i}" for i in range(len(value))]
|
|
156
|
+
placeholders = ", ".join(f":{name}" for name in param_names)
|
|
157
|
+
|
|
158
|
+
# Only replace :key references inside IN clauses
|
|
159
|
+
_in_content_pat = re.compile(
|
|
160
|
+
rf"(IN\s*\([^)]*):{re.escape(key)}(?!\w)([^)]*\))", re.IGNORECASE
|
|
161
|
+
)
|
|
162
|
+
new_sql = _in_content_pat.sub(rf"\g<1>{placeholders}\g<2>", new_sql)
|
|
163
|
+
|
|
164
|
+
for pname, val in zip(param_names, value, strict=False):
|
|
165
|
+
expanded[pname] = val
|
|
166
|
+
|
|
167
|
+
return new_sql, expanded
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
_query_cache: dict[str, str] = {}
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True, slots=True)
|
|
10
|
+
class BatchingOptions:
|
|
11
|
+
limit: int
|
|
12
|
+
offset: int
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def load_query(
|
|
16
|
+
*,
|
|
17
|
+
inline_query: str | None,
|
|
18
|
+
file_path: str | None,
|
|
19
|
+
query_directory: Path | None,
|
|
20
|
+
class_name: str,
|
|
21
|
+
method_name: str,
|
|
22
|
+
) -> str:
|
|
23
|
+
cache_key = f"{class_name}.{method_name}"
|
|
24
|
+
cached = _query_cache.get(cache_key)
|
|
25
|
+
if cached is not None:
|
|
26
|
+
return cached
|
|
27
|
+
|
|
28
|
+
if inline_query is not None:
|
|
29
|
+
_query_cache[cache_key] = inline_query
|
|
30
|
+
return inline_query
|
|
31
|
+
|
|
32
|
+
if file_path is not None:
|
|
33
|
+
resolved = Path(file_path)
|
|
34
|
+
if not resolved.is_absolute() and query_directory is not None:
|
|
35
|
+
resolved = query_directory / file_path
|
|
36
|
+
elif query_directory is not None:
|
|
37
|
+
resolved = query_directory / f"{method_name}.sql"
|
|
38
|
+
else:
|
|
39
|
+
raise ValueError(
|
|
40
|
+
f"{class_name}.{method_name}: no query, file, or query_directory specified"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
if not resolved.exists():
|
|
44
|
+
raise FileNotFoundError(f"{class_name}.{method_name}: SQL file not found: {resolved}")
|
|
45
|
+
|
|
46
|
+
sql = resolved.read_text()
|
|
47
|
+
_query_cache[cache_key] = sql
|
|
48
|
+
return sql
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def apply_batching(sql: str, batching: BatchingOptions | None) -> str:
    """Append a LIMIT/OFFSET clause to *sql*; no-op when *batching* is None.

    Any trailing whitespace and semicolons are stripped first so the clause
    attaches to a syntactically valid statement tail.
    """
    if batching is None:
        return sql
    base = sql.rstrip().rstrip(";")
    return f"{base}\nLIMIT {batching.limit} OFFSET {batching.offset}"
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import json
|
|
5
|
+
from collections.abc import Mapping, Sequence
|
|
6
|
+
from decimal import Decimal
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from .types import QueryType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def map_result(
    rows: Sequence[Mapping[str, Any]],
    query_type: QueryType,
    clazz: type[Any] | None,
    return_list: bool,
) -> Any:
    """Map raw database rows to the appropriate return shape.

    Statement kinds with no meaningful result map to ``None``; SELECT maps
    rows, and INSERT/UPDATE map their RETURNING rows.
    """
    void_types = (
        QueryType.UPSERT,
        QueryType.DELETE,
        QueryType.BULK_UPDATE,
        QueryType.BULK_DELETE,
        QueryType.RAW,
    )
    if query_type in void_types:
        return None

    if query_type == QueryType.SELECT:
        return _map_select(rows, clazz, return_list)

    if query_type in (QueryType.INSERT, QueryType.UPDATE):
        return _map_returning(rows, clazz, return_list)

    return None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _map_select(
|
|
38
|
+
rows: Sequence[Mapping[str, Any]],
|
|
39
|
+
clazz: type[Any] | None,
|
|
40
|
+
return_list: bool,
|
|
41
|
+
) -> Any:
|
|
42
|
+
if return_list:
|
|
43
|
+
return [to_instance(dict(r), clazz) for r in rows]
|
|
44
|
+
|
|
45
|
+
if len(rows) == 0:
|
|
46
|
+
return None
|
|
47
|
+
|
|
48
|
+
if len(rows) > 1:
|
|
49
|
+
raise ValueError(
|
|
50
|
+
f"Expected 0 or 1 rows but got {len(rows)}. "
|
|
51
|
+
"Set return_list=True in QueryOptions to allow multiple rows."
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
return to_instance(dict(rows[0]), clazz)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _map_returning(
|
|
58
|
+
rows: Sequence[Mapping[str, Any]],
|
|
59
|
+
clazz: type[Any] | None,
|
|
60
|
+
return_list: bool,
|
|
61
|
+
) -> Any:
|
|
62
|
+
if len(rows) == 0:
|
|
63
|
+
return None
|
|
64
|
+
|
|
65
|
+
if return_list:
|
|
66
|
+
return [to_instance(dict(r), clazz) for r in rows]
|
|
67
|
+
|
|
68
|
+
return to_instance(dict(rows[0]), clazz)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def to_instance(row: dict[str, Any], clazz: type[Any] | None) -> Any:
|
|
72
|
+
"""Convert a row dict to a class instance or cleaned dict."""
|
|
73
|
+
cleaned = _clean_row(row)
|
|
74
|
+
|
|
75
|
+
if clazz is None:
|
|
76
|
+
return cleaned
|
|
77
|
+
|
|
78
|
+
if dataclasses.is_dataclass(clazz):
|
|
79
|
+
field_names = {f.name for f in dataclasses.fields(clazz)}
|
|
80
|
+
filtered = {k: v for k, v in cleaned.items() if k in field_names}
|
|
81
|
+
|
|
82
|
+
# Handle nested dataclass fields (e.g. from json_agg)
|
|
83
|
+
for f in dataclasses.fields(clazz):
|
|
84
|
+
if f.name in filtered and filtered[f.name] is not None:
|
|
85
|
+
val = filtered[f.name]
|
|
86
|
+
# Check if the field type annotation hints at a list of dataclasses
|
|
87
|
+
origin = getattr(f.type, "__origin__", None)
|
|
88
|
+
if origin is list and isinstance(val, list):
|
|
89
|
+
args = getattr(f.type, "__args__", ())
|
|
90
|
+
if args and dataclasses.is_dataclass(args[0]):
|
|
91
|
+
inner_cls = args[0]
|
|
92
|
+
filtered[f.name] = [
|
|
93
|
+
to_instance(item, inner_cls) if isinstance(item, dict) else item
|
|
94
|
+
for item in val
|
|
95
|
+
]
|
|
96
|
+
|
|
97
|
+
return clazz(**filtered)
|
|
98
|
+
|
|
99
|
+
# Fallback: try to construct the class directly
|
|
100
|
+
return clazz(**cleaned)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _clean_row(row: dict[str, Any]) -> dict[str, Any]:
|
|
104
|
+
"""Normalise database values: Decimal→float/int, JSON strings→objects."""
|
|
105
|
+
cleaned: dict[str, Any] = {}
|
|
106
|
+
for key, value in row.items():
|
|
107
|
+
if value is None:
|
|
108
|
+
cleaned[key] = None
|
|
109
|
+
elif isinstance(value, Decimal):
|
|
110
|
+
# Match TS behaviour: NUMERIC → float, BIGINT → int
|
|
111
|
+
if value == int(value):
|
|
112
|
+
cleaned[key] = int(value)
|
|
113
|
+
else:
|
|
114
|
+
cleaned[key] = float(value)
|
|
115
|
+
elif isinstance(value, str) and value.startswith("["):
|
|
116
|
+
try:
|
|
117
|
+
cleaned[key] = json.loads(value)
|
|
118
|
+
except (json.JSONDecodeError, ValueError):
|
|
119
|
+
cleaned[key] = value
|
|
120
|
+
else:
|
|
121
|
+
cleaned[key] = value
|
|
122
|
+
return cleaned
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import text
|
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
8
|
+
|
|
9
|
+
from .param_mapper import expand_in_clauses
|
|
10
|
+
from .query import BatchingOptions, apply_batching
|
|
11
|
+
from .result_mapper import to_instance
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class StreamIterator[T]:
    """Lazily streams query results in batches, yielding one row at a time.

    Each batch is fetched by appending ``LIMIT/OFFSET`` to the base query on
    a short-lived connection; iteration stops when a batch comes back
    smaller than ``batch_size``.

    NOTE(review): each batch runs on its own connection, so concurrent
    writes can shift or duplicate rows between batches — there is no
    snapshot across batches.  Fine for mostly-static data; confirm for
    callers that need a consistent view.
    """

    def __init__(
        self,
        engine: AsyncEngine,
        sql: str,
        replacements: dict[str, Any],
        clazz: type[T] | None = None,
        batch_size: int = 1000,
    ) -> None:
        self._engine = engine
        # Base query WITHOUT limit/offset; batching is appended per fetch.
        self._sql = sql
        self._replacements = replacements
        # Target dataclass for row mapping; None yields plain dicts.
        self._clazz = clazz
        self._batch_size = batch_size

    async def _generate(self) -> AsyncIterator[T]:
        offset = 0
        while True:
            batching = BatchingOptions(limit=self._batch_size, offset=offset)
            batch_sql = apply_batching(self._sql, batching)
            expanded_sql, expanded_params = expand_in_clauses(batch_sql, self._replacements)

            # Connection is opened and closed per batch; rows are fully
            # materialised before yielding, so no connection is held open
            # while the consumer processes a batch.
            async with self._engine.connect() as conn:
                result = await conn.execute(text(expanded_sql), expanded_params)
                rows = result.mappings().all()

            for row in rows:
                instance = to_instance(dict(row), self._clazz)
                yield instance

            # A short batch means the result set is exhausted.
            if len(rows) < self._batch_size:
                break

            offset += self._batch_size

    def __aiter__(self) -> AsyncIterator[T]:
        # A fresh generator per iteration, so the iterator is re-iterable.
        return self._generate()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import enum
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class QueryType(enum.Enum):
    """Kind of SQL statement; drives result mapping (see ``result_mapper``).

    SELECT maps rows; INSERT/UPDATE map their RETURNING rows; the remaining
    kinds (UPSERT, DELETE, BULK_*, RAW) always map to ``None``.
    """

    SELECT = "SELECT"
    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"
    UPSERT = "UPSERT"
    BULK_UPDATE = "BULK_UPDATE"
    BULK_DELETE = "BULK_DELETE"
    RAW = "RAW"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True)
class BooleanResult:
    """Single boolean scalar; pair with a query whose column is aliased ``result``
    (rows are mapped to dataclass fields by name — see ``result_mapper``)."""

    result: bool


@dataclass(frozen=True, slots=True)
class StringResult:
    """Single string scalar; query column must be aliased ``result``."""

    result: str


@dataclass(frozen=True, slots=True)
class NumberResult:
    """Single numeric scalar; query column must be aliased ``result``."""

    result: int | float


@dataclass(frozen=True, slots=True)
class ArrayResult:
    """Single string-array value; query column must be aliased ``result``."""

    result: list[str]


# Convenience alias: "the class object of T", used in decorator options.
type Clazz[T] = type[T]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
class QueryOptions:
    """Configuration for the ``@sql_query`` decorator.

    Exactly one query source should be given: ``query`` (inline SQL),
    ``file`` (path to a .sql file), or neither — in which case the SQL is
    loaded from ``<query_directory>/<method_name>.sql`` (see ``load_query``).
    """

    # Statement kind; controls how rows are mapped (see QueryType).
    query_type: QueryType
    # Dataclass to instantiate per row; None returns plain dicts.
    clazz: type[Any] | None = None
    # True → return a list; False → a single row (error if more than one).
    return_list: bool = False
    # Inline SQL text.
    query: str | None = None
    # Path to a .sql file, resolved against the DAO's query_directory.
    file: str | None = None


@dataclass(frozen=True)
class StreamSelectOptions:
    """Configuration for the ``@stream_select`` decorator."""

    # Dataclass to instantiate per row; None yields plain dicts.
    clazz: type[Any] | None = None
    # Rows fetched per LIMIT/OFFSET batch.
    batch_size: int = 1000
    # Inline SQL text.
    query: str | None = None
    # Path to a .sql file, resolved against the DAO's query_directory.
    file: str | None = None
|