zeroshot-sql-decorators 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: zeroshot-sql-decorators
3
+ Version: 0.1.5
4
+ Summary: SQL-focused decorators and helpers for Zeroshot Python packages.
5
+ License-Expression: MIT
6
+ Requires-Dist: zeroshot-commons==0.1.5
7
+ Requires-Dist: sqlalchemy[asyncio]>=2.0
8
+ Requires-Dist: asyncpg>=0.30
9
+ Requires-Python: >=3.12
10
+ Description-Content-Type: text/markdown
11
+
12
+ # zeroshot-sql-decorators
13
+
14
+ SQL-focused decorators and helpers for Zeroshot Python packages.
@@ -0,0 +1,3 @@
1
+ # zeroshot-sql-decorators
2
+
3
+ SQL-focused decorators and helpers for Zeroshot Python packages.
@@ -0,0 +1,19 @@
1
+ [project]
2
+ name = "zeroshot-sql-decorators"
3
+ version = "0.1.5"
4
+ description = "SQL-focused decorators and helpers for Zeroshot Python packages."
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ license = "MIT"
8
+ dependencies = [
9
+ "zeroshot-commons==0.1.5",
10
+ "sqlalchemy[asyncio]>=2.0",
11
+ "asyncpg>=0.30",
12
+ ]
13
+
14
+ [tool.uv.sources]
15
+ zeroshot-commons = { workspace = true }
16
+
17
+ [build-system]
18
+ requires = ["uv_build>=0.11.8,<0.12"]
19
+ build-backend = "uv_build"
@@ -0,0 +1,39 @@
1
+ """SQL-focused decorators and helpers for Zeroshot Python packages."""
2
+
3
+ from .decorators import (
4
+ DaoBase,
5
+ TransactionalityBase,
6
+ dao,
7
+ sql_query,
8
+ sql_transaction,
9
+ stream_select,
10
+ with_transactionality,
11
+ )
12
+ from .stream_iterator import StreamIterator
13
+ from .types import (
14
+ ArrayResult,
15
+ BooleanResult,
16
+ NumberResult,
17
+ QueryOptions,
18
+ QueryType,
19
+ StreamSelectOptions,
20
+ StringResult,
21
+ )
22
+
23
+ __all__ = [
24
+ "ArrayResult",
25
+ "BooleanResult",
26
+ "DaoBase",
27
+ "NumberResult",
28
+ "QueryOptions",
29
+ "QueryType",
30
+ "StreamIterator",
31
+ "StreamSelectOptions",
32
+ "StringResult",
33
+ "TransactionalityBase",
34
+ "dao",
35
+ "sql_query",
36
+ "sql_transaction",
37
+ "stream_select",
38
+ "with_transactionality",
39
+ ]
@@ -0,0 +1,268 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import inspect
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from sqlalchemy import text
10
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
11
+
12
+ from .param_mapper import build_replacements, expand_in_clauses, extract_param_names
13
+ from .query import load_query
14
+ from .result_mapper import map_result
15
+ from .stream_iterator import StreamIterator
16
+ from .types import QueryOptions, StreamSelectOptions
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Base classes
23
+ # ---------------------------------------------------------------------------
24
+
25
+
26
class DaoBase:
    """Base class for Data Access Objects.

    Holds a reference to the SQLAlchemy ``AsyncEngine``.  The query
    decorators (``@sql_query``, ``@stream_select``) read it back via
    ``self._engine`` when executing.
    """

    def __init__(self, engine: AsyncEngine) -> None:
        # Engine shared by every decorated query method on this DAO.
        self._engine = engine
34
+
35
+
36
class TransactionalityBase:
    """Base class for repositories that coordinate transactions across DAOs.

    ``@sql_transaction`` methods read ``self._engine`` to build the session
    that carries the transaction.
    """

    def __init__(self, engine: AsyncEngine) -> None:
        # Engine from which @sql_transaction creates its session factory.
        self._engine = engine
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Class decorators
45
+ # ---------------------------------------------------------------------------
46
+
47
+
48
def dao(*, query_directory: Path | str | None = None) -> Any:
    """Class decorator marking a DAO, optionally recording a query-file directory.

    When given, *query_directory* is normalised to ``Path`` and stored on the
    class as ``_query_directory`` for the query decorators to pick up; when
    omitted the attribute is set to ``None``.
    """
    resolved_dir = Path(query_directory) if query_directory else None

    def decorator(cls: type) -> type:
        cls._query_directory = resolved_dir  # type: ignore[attr-defined]
        return cls

    return decorator
56
+
57
+
58
def with_transactionality() -> Any:
    """Class decorator marking a repository as transaction-aware.

    Mirrors the TS ``@WithTransactionality()``; the decorator is currently a
    pure pass-through marker.
    """

    def passthrough(cls: type) -> type:
        return cls

    return passthrough
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # @sql_query
69
+ # ---------------------------------------------------------------------------
70
+
71
+
72
def sql_query(options: QueryOptions) -> Any:
    """Method decorator that executes a SQL query and maps the result.

    The decorated method is a stub — its body is never called.  Parameter
    values are taken from the call site, bound to the query's ``:name``
    placeholders, and the resulting rows are mapped according to *options*.

    An optional ``session: AsyncSession | None = None`` parameter lets the
    method participate in an externally managed transaction; without one the
    query runs on a fresh connection that is committed after execution.
    """

    def decorator(fn: Any) -> Any:
        # Resolved once at decoration time; the signature cannot change later.
        param_names = extract_param_names(fn)

        @functools.wraps(fn)
        async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            engine: AsyncEngine = self._engine

            # Query-file directory registered by @dao (may be absent).
            query_dir: Path | None = getattr(self.__class__, "_query_directory", None)

            sql = load_query(
                inline_query=options.query,
                file_path=options.file,
                query_directory=query_dir,
                class_name=self.__class__.__name__,
                method_name=fn.__name__,
            )

            # Merge positional + keyword args (excluding self) and bind them.
            all_args = _merge_args(fn, args, kwargs)
            replacements = build_replacements(param_names, all_args)

            # Rewrite IN (:param) clauses for list-valued parameters.
            expanded_sql, expanded_params = expand_in_clauses(sql, replacements)

            # An explicit session means "participate in the caller's transaction".
            session = _extract_session(fn, args, kwargs)

            # NOTE: result.returns_rows is the documented way to detect
            # statements without a result set (e.g. UPDATE/DELETE with no
            # RETURNING clause).  The previous broad ``except Exception``
            # also swallowed genuine row-mapping errors.
            if session is not None:
                result = await session.execute(text(expanded_sql), expanded_params)
                rows: list[Any] = (
                    [dict(r) for r in result.mappings().all()] if result.returns_rows else []
                )
            else:
                async with engine.connect() as conn:
                    result = await conn.execute(text(expanded_sql), expanded_params)
                    rows = (
                        [dict(r) for r in result.mappings().all()] if result.returns_rows else []
                    )
                    # Autocommit semantics for standalone statements.
                    await conn.commit()

            return map_result(rows, options.query_type, options.clazz, options.return_list)

        return wrapper

    return decorator
129
+
130
+
131
+ # ---------------------------------------------------------------------------
132
+ # @stream_select
133
+ # ---------------------------------------------------------------------------
134
+
135
+
136
def stream_select(options: StreamSelectOptions) -> Any:
    """Method decorator that returns a ``StreamIterator`` for lazy batch iteration.

    The wrapper does no I/O itself: it resolves the SQL and the parameter
    bindings, then hands both to a ``StreamIterator`` which pages through the
    results on demand.
    """

    def decorator(fn: Any) -> Any:
        names = extract_param_names(fn)

        @functools.wraps(fn)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> StreamIterator[Any]:
            directory: Path | None = getattr(self.__class__, "_query_directory", None)
            statement = load_query(
                inline_query=options.query,
                file_path=options.file,
                query_directory=directory,
                class_name=self.__class__.__name__,
                method_name=fn.__name__,
            )
            bound = build_replacements(names, _merge_args(fn, args, kwargs))
            return StreamIterator(
                engine=self._engine,
                sql=statement,
                replacements=bound,
                clazz=options.clazz,
                batch_size=options.batch_size,
            )

        return wrapper

    return decorator
169
+
170
+
171
+ # ---------------------------------------------------------------------------
172
+ # @sql_transaction
173
+ # ---------------------------------------------------------------------------
174
+
175
+
176
def sql_transaction(
    *,
    isolation_level: str | None = None,
) -> Any:
    """Method decorator that wraps the call in a database transaction.

    If a ``session`` kwarg/arg is already provided and has an active
    transaction, the method participates in that transaction.
    Otherwise a new transaction is created, committed on success,
    and rolled back on error.

    Args:
        isolation_level: optional SQLAlchemy isolation-level name applied to
            the new transaction's connection (e.g. ``"SERIALIZABLE"``).
            Ignored when participating in an existing transaction.
    """

    def decorator(fn: Any) -> Any:
        @functools.wraps(fn)
        async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            engine: AsyncEngine = self._engine

            existing_session = _extract_session(fn, args, kwargs)

            if existing_session is not None and existing_session.in_transaction():
                # Already in a transaction — just call through
                return await fn(self, *args, **kwargs)

            if existing_session is not None:
                # A non-transactional session is almost certainly a caller bug.
                raise ValueError(
                    "A session was provided but has no active transaction. "
                    "Do not pass a session unless it is already transactional."
                )

            # Create a new session + transaction
            exec_opts: dict[str, Any] = {}
            if isolation_level:
                exec_opts["isolation_level"] = isolation_level

            session_factory = async_sessionmaker(
                engine,
                # Keep attribute access valid after the commit at block exit.
                expire_on_commit=False,
            )
            # begin() commits on clean exit and rolls back on exception.
            async with session_factory.begin() as session:
                if exec_opts:
                    # Apply isolation options to the connection before any SQL runs.
                    await session.connection(execution_options=exec_opts)
                new_kwargs = dict(kwargs)
                # Inject the transactional session so nested decorated calls join it.
                new_kwargs["session"] = session
                return await fn(self, *args, **new_kwargs)

        return wrapper

    return decorator
224
+
225
+
226
+ # ---------------------------------------------------------------------------
227
+ # Helpers
228
+ # ---------------------------------------------------------------------------
229
+
230
+
231
+ def _merge_args(fn: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, ...]:
232
+ """Merge positional and keyword args into a positional tuple,
233
+ respecting the original function signature order (excluding ``self``)."""
234
+ sig = inspect.signature(fn)
235
+ params = [p for p in sig.parameters.values() if p.name != "self"]
236
+
237
+ merged: list[Any] = list(args)
238
+
239
+ # Fill in remaining params from kwargs
240
+ for i in range(len(args), len(params)):
241
+ p = params[i]
242
+ if p.name in kwargs:
243
+ merged.append(kwargs[p.name])
244
+ elif p.default is not inspect.Parameter.empty:
245
+ merged.append(p.default)
246
+
247
+ return tuple(merged)
248
+
249
+
250
+ def _extract_session(
251
+ fn: Any,
252
+ args: tuple[Any, ...],
253
+ kwargs: dict[str, Any],
254
+ ) -> AsyncSession | None:
255
+ """Pull out the ``session`` argument if present."""
256
+ if "session" in kwargs:
257
+ return kwargs["session"]
258
+
259
+ sig = inspect.signature(fn)
260
+ params = [p for p in sig.parameters.values() if p.name != "self"]
261
+
262
+ for i, p in enumerate(params):
263
+ if p.name == "session" and i < len(args):
264
+ val = args[i]
265
+ if isinstance(val, AsyncSession):
266
+ return val
267
+
268
+ return None
@@ -0,0 +1,167 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import inspect
5
+ import re
6
+ from typing import Any
7
+
8
+ from sqlalchemy.ext.asyncio import AsyncSession
9
+
10
+
11
def extract_param_names(func: Any) -> list[str]:
    """Return *func*'s parameter names, excluding ``self`` and ``session``.

    A parameter is kept when it is positional or carries a default value —
    i.e. the arguments that can be bound to SQL ``:name`` placeholders.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.POSITIONAL_ONLY,
    )
    names: list[str] = []
    for name, param in inspect.signature(func).parameters.items():
        if name in ("self", "session"):
            continue
        if param.default is not inspect.Parameter.empty or param.kind in positional_kinds:
            names.append(name)
    return names
24
+
25
+
26
+ def _is_complex(value: Any) -> bool:
27
+ """Return True if *value* is a dataclass instance or has a ``__dict__``
28
+ that is not a basic built-in type."""
29
+ if value is None or isinstance(value, (str, int, float, bool, bytes, list, tuple)):
30
+ return False
31
+ return dataclasses.is_dataclass(value) and not isinstance(value, type)
32
+
33
+
34
+ def _extract_fields(obj: Any) -> dict[str, Any]:
35
+ """Pull fields from a dataclass instance, including property getters."""
36
+ result: dict[str, Any] = {}
37
+ if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
38
+ for f in dataclasses.fields(obj):
39
+ result[f.name] = getattr(obj, f.name)
40
+ # Also pick up @property descriptors on the class
41
+ for name in dir(type(obj)):
42
+ if name.startswith("_"):
43
+ continue
44
+ attr = getattr(type(obj), name, None)
45
+ if isinstance(attr, property):
46
+ result[name] = attr.fget(obj) # type: ignore[arg-type]
47
+ return result
48
+
49
+
50
def build_replacements(
    param_names: list[str],
    args: tuple[Any, ...],
) -> dict[str, Any]:
    """Map call arguments to SQL ``:param`` replacement values.

    Rules (mirroring the TS implementation):
      * ``AsyncSession`` values (the transaction argument) are dropped.
      * A single complex (dataclass) argument is spread into its fields.
      * Otherwise each name maps to its value; complex values are flattened.
      * ``None`` values pass through and bind as SQL ``NULL``.
    """
    bound = [
        (name, value)
        for name, value in zip(param_names, args, strict=False)
        if not isinstance(value, AsyncSession)
    ]

    # A lone dataclass argument contributes its fields as the parameters.
    if len(bound) == 1 and _is_complex(bound[0][1]):
        return _replace_none(_extract_fields(bound[0][1]))

    mapping: dict[str, Any] = {}
    for name, value in bound:
        if _is_complex(value):
            mapping.update(_extract_fields(value))
        else:
            mapping[name] = value

    return _replace_none(mapping)
80
+
81
+
82
def _replace_none(d: dict[str, Any]) -> dict[str, Any]:
    """Identity pass-through for replacement dicts.

    ``None`` values are deliberately left untouched — SQLAlchemy binds them
    as SQL ``NULL``.  Kept as a named hook for parity with the mirrored TS
    implementation.
    """
    return d
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # IN-clause expansion
89
+ # ---------------------------------------------------------------------------
90
+
91
+ _IN_PARAM_RE = re.compile(r":(\w+)")
92
+
93
+
94
+ def _param_is_in_clause(sql: str, key: str) -> bool:
95
+ """Check if ``:key`` appears inside an ``IN (...)`` context in the SQL."""
96
+ pattern = re.compile(rf"IN\s*\([^)]*:{re.escape(key)}(?!\w)[^)]*\)", re.IGNORECASE)
97
+ return pattern.search(sql) is not None
98
+
99
+
100
+ def _param_is_null_check(sql: str, key: str) -> bool:
101
+ """Check if ``:key IS NULL`` appears in the SQL."""
102
+ pattern = re.compile(rf":{re.escape(key)}\s+IS\s+NULL", re.IGNORECASE)
103
+ return pattern.search(sql) is not None
104
+
105
+
106
+ def expand_in_clauses(sql: str, replacements: dict[str, Any]) -> tuple[str, dict[str, Any]]:
107
+ """Expand list-valued parameters for ``IN (:param)`` clauses.
108
+
109
+ Given ``IN (:ids)`` and ``ids=[1, 2, 3]``, rewrites to
110
+ ``IN (:ids_0, :ids_1, :ids_2)`` with individual entries in replacements.
111
+
112
+ List parameters that appear outside of IN clauses (e.g. PostgreSQL array
113
+ values) are left as-is.
114
+
115
+ Also handles the ``(:param IS NULL OR ... IN (:param))`` pattern.
116
+ """
117
+ expanded: dict[str, Any] = {}
118
+ new_sql = sql
119
+
120
+ for key, value in list(replacements.items()):
121
+ has_in = _param_is_in_clause(new_sql, key)
122
+ has_is_null = _param_is_null_check(new_sql, key)
123
+
124
+ if value is None and has_is_null:
125
+ # None means "skip this filter"
126
+ _is_null_pat = re.compile(rf":{re.escape(key)}\s+IS\s+NULL", re.IGNORECASE)
127
+ new_sql = _is_null_pat.sub("TRUE", new_sql)
128
+ if has_in:
129
+ _in_pat = re.compile(rf"IN\s*\(\s*:{re.escape(key)}\s*\)", re.IGNORECASE)
130
+ new_sql = _in_pat.sub("IN (NULL)", new_sql)
131
+ continue
132
+
133
+ if not isinstance(value, (list, tuple)):
134
+ expanded[key] = value
135
+ continue
136
+
137
+ # List value but NOT used in an IN clause — it's a PostgreSQL array param
138
+ if not has_in:
139
+ expanded[key] = value
140
+ continue
141
+
142
+ # --- list/tuple in an IN clause ---
143
+
144
+ # Handle ":key IS NULL" patterns
145
+ if has_is_null:
146
+ _is_null_pat = re.compile(rf":{re.escape(key)}\s+IS\s+NULL", re.IGNORECASE)
147
+ new_sql = _is_null_pat.sub("FALSE", new_sql)
148
+
149
+ if len(value) == 0:
150
+ _in_pat = re.compile(rf"IN\s*\(\s*:{re.escape(key)}\s*\)", re.IGNORECASE)
151
+ new_sql = _in_pat.sub("IN (NULL)", new_sql)
152
+ continue
153
+
154
+ # Build individual param names
155
+ param_names = [f"{key}_{i}" for i in range(len(value))]
156
+ placeholders = ", ".join(f":{name}" for name in param_names)
157
+
158
+ # Only replace :key references inside IN clauses
159
+ _in_content_pat = re.compile(
160
+ rf"(IN\s*\([^)]*):{re.escape(key)}(?!\w)([^)]*\))", re.IGNORECASE
161
+ )
162
+ new_sql = _in_content_pat.sub(rf"\g<1>{placeholders}\g<2>", new_sql)
163
+
164
+ for pname, val in zip(param_names, value, strict=False):
165
+ expanded[pname] = val
166
+
167
+ return new_sql, expanded
@@ -0,0 +1,54 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+ _query_cache: dict[str, str] = {}
7
+
8
+
9
@dataclass(frozen=True, slots=True)
class BatchingOptions:
    """LIMIT/OFFSET window appended to a streaming SELECT batch."""

    # Maximum number of rows in the batch (LIMIT).
    limit: int
    # Number of rows to skip before the batch starts (OFFSET).
    offset: int
13
+
14
+
15
def load_query(
    *,
    inline_query: str | None,
    file_path: str | None,
    query_directory: Path | None,
    class_name: str,
    method_name: str,
) -> str:
    """Resolve the SQL text for ``class_name.method_name``.

    Precedence: an inline query string, then an explicit file path (relative
    paths resolve against *query_directory* when given), then the
    conventional ``<query_directory>/<method_name>.sql``.  Resolved text is
    cached per class/method pair for the life of the process.
    """
    cache_key = f"{class_name}.{method_name}"
    cached = _query_cache.get(cache_key)
    if cached is not None:
        return cached

    if inline_query is not None:
        _query_cache[cache_key] = inline_query
        return inline_query

    if file_path is None and query_directory is None:
        raise ValueError(
            f"{class_name}.{method_name}: no query, file, or query_directory specified"
        )

    if file_path is not None:
        resolved = Path(file_path)
        if not resolved.is_absolute() and query_directory is not None:
            resolved = query_directory / file_path
    else:
        # query_directory is guaranteed non-None by the check above.
        resolved = query_directory / f"{method_name}.sql"

    if not resolved.exists():
        raise FileNotFoundError(f"{class_name}.{method_name}: SQL file not found: {resolved}")

    sql = resolved.read_text()
    _query_cache[cache_key] = sql
    return sql
49
+
50
+
51
def apply_batching(sql: str, batching: BatchingOptions | None) -> str:
    """Append a LIMIT/OFFSET window to *sql*; no-op when *batching* is None."""
    if batching is None:
        return sql
    # Strip trailing whitespace and a terminating semicolon before appending.
    base = sql.rstrip().rstrip(";")
    return f"{base}\nLIMIT {batching.limit} OFFSET {batching.offset}"
@@ -0,0 +1,122 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import json
5
+ from collections.abc import Mapping, Sequence
6
+ from decimal import Decimal
7
+ from typing import Any
8
+
9
+ from .types import QueryType
10
+
11
+
12
def map_result(
    rows: Sequence[Mapping[str, Any]],
    query_type: QueryType,
    clazz: type[Any] | None,
    return_list: bool,
) -> Any:
    """Map raw database rows to the decorator's declared return shape.

    Write-style statements yield ``None``; SELECT maps via ``_map_select``;
    INSERT/UPDATE map their RETURNING rows via ``_map_returning``.
    """
    no_result_types = (
        QueryType.UPSERT,
        QueryType.DELETE,
        QueryType.BULK_UPDATE,
        QueryType.BULK_DELETE,
        QueryType.RAW,
    )
    if query_type in no_result_types:
        return None

    if query_type == QueryType.SELECT:
        return _map_select(rows, clazz, return_list)

    if query_type in (QueryType.INSERT, QueryType.UPDATE):
        return _map_returning(rows, clazz, return_list)

    return None
35
+
36
+
37
+ def _map_select(
38
+ rows: Sequence[Mapping[str, Any]],
39
+ clazz: type[Any] | None,
40
+ return_list: bool,
41
+ ) -> Any:
42
+ if return_list:
43
+ return [to_instance(dict(r), clazz) for r in rows]
44
+
45
+ if len(rows) == 0:
46
+ return None
47
+
48
+ if len(rows) > 1:
49
+ raise ValueError(
50
+ f"Expected 0 or 1 rows but got {len(rows)}. "
51
+ "Set return_list=True in QueryOptions to allow multiple rows."
52
+ )
53
+
54
+ return to_instance(dict(rows[0]), clazz)
55
+
56
+
57
+ def _map_returning(
58
+ rows: Sequence[Mapping[str, Any]],
59
+ clazz: type[Any] | None,
60
+ return_list: bool,
61
+ ) -> Any:
62
+ if len(rows) == 0:
63
+ return None
64
+
65
+ if return_list:
66
+ return [to_instance(dict(r), clazz) for r in rows]
67
+
68
+ return to_instance(dict(rows[0]), clazz)
69
+
70
+
71
def to_instance(row: dict[str, Any], clazz: type[Any] | None) -> Any:
    """Convert a row dict into *clazz* (or a cleaned dict when *clazz* is None).

    Dataclass targets are filtered to their declared fields, and list-typed
    fields whose element type is itself a dataclass (e.g. rows produced by
    ``json_agg``) are converted recursively.
    """
    cleaned = _clean_row(row)

    if clazz is None:
        return cleaned

    if dataclasses.is_dataclass(clazz):
        # BUG FIX: with ``from __future__ import annotations`` (used across
        # this package) ``f.type`` is a *string*, which has no ``__origin__``,
        # so the previous getattr-based check silently never detected
        # list[...] fields.  typing.get_type_hints resolves such annotations.
        from typing import get_args, get_origin, get_type_hints

        try:
            hints = get_type_hints(clazz)
        except Exception:
            # Unresolvable forward references: fall back to raw annotations.
            hints = {f.name: f.type for f in dataclasses.fields(clazz)}

        field_names = {f.name for f in dataclasses.fields(clazz)}
        filtered = {k: v for k, v in cleaned.items() if k in field_names}

        # Recursively build nested dataclasses for list[...] fields.
        for f in dataclasses.fields(clazz):
            val = filtered.get(f.name)
            if val is None or not isinstance(val, list):
                continue
            annotation = hints.get(f.name, f.type)
            if get_origin(annotation) is not list:
                continue
            type_args = get_args(annotation)
            if type_args and dataclasses.is_dataclass(type_args[0]):
                inner_cls = type_args[0]
                filtered[f.name] = [
                    to_instance(item, inner_cls) if isinstance(item, dict) else item
                    for item in val
                ]

        return clazz(**filtered)

    # Fallback: try keyword construction of the (non-dataclass) class.
    return clazz(**cleaned)


def _clean_row(row: dict[str, Any]) -> dict[str, Any]:
    """Normalise database values: Decimal → int/float, JSON array strings → lists.

    Mirrors the TS driver behaviour: integral NUMERIC/BIGINT values become
    ``int``, fractional ones ``float``; strings starting with ``[`` are
    speculatively JSON-decoded and left untouched when decoding fails.
    """
    cleaned: dict[str, Any] = {}
    for key, value in row.items():
        if value is None:
            cleaned[key] = None
        elif isinstance(value, Decimal):
            # NOTE(review): Decimal("NaN") would raise here — assumed not to
            # occur in practice; confirm against the schemas in use.
            cleaned[key] = int(value) if value == int(value) else float(value)
        elif isinstance(value, str) and value.startswith("["):
            try:
                cleaned[key] = json.loads(value)
            except (json.JSONDecodeError, ValueError):
                cleaned[key] = value
        else:
            cleaned[key] = value
    return cleaned
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import AsyncIterator
4
+ from typing import Any
5
+
6
+ from sqlalchemy import text
7
+ from sqlalchemy.ext.asyncio import AsyncEngine
8
+
9
+ from .param_mapper import expand_in_clauses
10
+ from .query import BatchingOptions, apply_batching
11
+ from .result_mapper import to_instance
12
+
13
+
14
class StreamIterator[T]:
    """Lazily streams query results in batches, yielding one row at a time.

    Each batch issues a fresh LIMIT/OFFSET query on its own connection, so no
    connection is held open between batches.  NOTE(review): OFFSET pagination
    assumes the underlying result set is stable while iterating; concurrent
    writes may cause skipped or repeated rows — confirm this is acceptable.
    """

    def __init__(
        self,
        engine: AsyncEngine,
        sql: str,
        replacements: dict[str, Any],
        clazz: type[T] | None = None,
        batch_size: int = 1000,
    ) -> None:
        self._engine = engine  # engine used to open one connection per batch
        self._sql = sql  # base SELECT, without LIMIT/OFFSET
        self._replacements = replacements  # raw :param replacements (pre IN-expansion)
        self._clazz = clazz  # optional mapping target for each row
        self._batch_size = batch_size  # rows fetched per round trip

    async def _generate(self) -> AsyncIterator[T]:
        # Page through the query until a short batch signals exhaustion.
        offset = 0
        while True:
            batching = BatchingOptions(limit=self._batch_size, offset=offset)
            batch_sql = apply_batching(self._sql, batching)
            # IN-clause expansion is re-applied per batch on the windowed SQL.
            expanded_sql, expanded_params = expand_in_clauses(batch_sql, self._replacements)

            async with self._engine.connect() as conn:
                result = await conn.execute(text(expanded_sql), expanded_params)
                rows = result.mappings().all()

            for row in rows:
                instance = to_instance(dict(row), self._clazz)
                yield instance

            # A batch smaller than batch_size means the result set is exhausted.
            if len(rows) < self._batch_size:
                break

            offset += self._batch_size

    def __aiter__(self) -> AsyncIterator[T]:
        # Fresh generator per iteration: each consumer restarts from offset 0.
        return self._generate()
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+
3
+ import enum
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+
8
class QueryType(enum.Enum):
    """Kind of SQL statement a decorated method executes.

    Drives result mapping: SELECT maps rows, INSERT/UPDATE map their
    RETURNING rows, and the remaining kinds discard any result.
    """

    SELECT = "SELECT"
    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"
    UPSERT = "UPSERT"
    BULK_UPDATE = "BULK_UPDATE"
    BULK_DELETE = "BULK_DELETE"
    RAW = "RAW"
17
+
18
+
19
@dataclass(frozen=True, slots=True)
class BooleanResult:
    """Target class for queries selecting one boolean column aliased ``result``."""

    result: bool


@dataclass(frozen=True, slots=True)
class StringResult:
    """Target class for queries selecting one string column aliased ``result``."""

    result: str


@dataclass(frozen=True, slots=True)
class NumberResult:
    """Target class for queries selecting one numeric column aliased ``result``."""

    result: int | float


@dataclass(frozen=True, slots=True)
class ArrayResult:
    """Target class for queries selecting one string-array column aliased ``result``."""

    result: list[str]
37
+
38
+
39
# Generic alias for "the class object of T".  Not referenced by this
# package's own modules — presumably part of the public API for consumers;
# NOTE(review): confirm external use before removing.
type Clazz[T] = type[T]
40
+
41
+
42
@dataclass(frozen=True)
class QueryOptions:
    """Configuration for ``@sql_query``.

    Query-source precedence: inline ``query`` first, then ``file``, then the
    DAO's query directory (``<method>.sql``).
    """

    # Statement kind; controls how result rows are mapped.
    query_type: QueryType
    # Optional target class each row is converted into (dict when None).
    clazz: type[Any] | None = None
    # When True a SELECT may return many rows as a list; otherwise 0-or-1.
    return_list: bool = False
    # Inline SQL text (takes precedence over file lookup).
    query: str | None = None
    # Path to a .sql file; relative paths resolve against the DAO's query directory.
    file: str | None = None
49
+
50
+
51
@dataclass(frozen=True)
class StreamSelectOptions:
    """Configuration for ``@stream_select``.

    Query-source resolution mirrors ``QueryOptions``: inline ``query`` first,
    then ``file``, then ``<method>.sql`` in the DAO's query directory.
    """

    # Optional target class each streamed row is converted into.
    clazz: type[Any] | None = None
    # Rows fetched per LIMIT/OFFSET batch.
    batch_size: int = 1000
    # Inline SQL text.
    query: str | None = None
    # Path to a .sql file.
    file: str | None = None