sqlphilosophy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlphilosophy/VERSION +1 -0
- sqlphilosophy/__init__.py +3 -0
- sqlphilosophy/aio/__init__.py +3 -0
- sqlphilosophy/aio/protocols.py +26 -0
- sqlphilosophy/aio/query.py +396 -0
- sqlphilosophy/aio/repository.py +400 -0
- sqlphilosophy/audit/__init__.py +3 -0
- sqlphilosophy/audit/context.py +37 -0
- sqlphilosophy/audit/fields.py +24 -0
- sqlphilosophy/audit/listener.py +99 -0
- sqlphilosophy/audit/model.py +59 -0
- sqlphilosophy/py.typed +0 -0
- sqlphilosophy/sorting.py +97 -0
- sqlphilosophy/sql.py +532 -0
- sqlphilosophy/sync/__init__.py +3 -0
- sqlphilosophy/sync/protocols.py +26 -0
- sqlphilosophy/sync/query.py +392 -0
- sqlphilosophy/sync/repository.py +360 -0
- sqlphilosophy/types.py +61 -0
- sqlphilosophy-0.1.0.dist-info/METADATA +134 -0
- sqlphilosophy-0.1.0.dist-info/RECORD +24 -0
- sqlphilosophy-0.1.0.dist-info/WHEEL +5 -0
- sqlphilosophy-0.1.0.dist-info/licenses/LICENSE +21 -0
- sqlphilosophy-0.1.0.dist-info/top_level.txt +1 -0
sqlphilosophy/sorting.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""List pagination and sort resolution for repository queries."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
# Client-facing sort direction; ``SortConfig.resolve_spec`` only honours these two lowercase values.
SortDirection = Literal["asc", "desc"]
|
|
10
|
+
# Client sort request: API column name -> direction. Only the first entry is used
# (see ``ListQuery`` and ``SortConfig.resolve_spec``).
OrderByMap = dict[str, SortDirection]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
class SortSpec:
    """One resolved sort instruction: a column name paired with a direction.

    Being a frozen dataclass, instances are immutable, compare by value and are hashable.
    """

    # API-level column name (matched against a ``SortConfig`` allowlist/columns map).
    column: str
    # "asc" or "desc".
    direction: SortDirection
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class ListQuery:
    """Immutable page request: an offset/limit window plus an optional client sort.

    Only the first ``order_by`` entry is meaningful; later entries are ignored downstream.
    """

    offset: int  # number of rows to skip
    limit: int  # maximum number of rows to return
    order_by: OrderByMap | None = None  # raw client sort request (first entry wins)

    @classmethod
    def from_page(
        cls,
        *,
        page: int,
        size: int,
        order_by: OrderByMap | None = None,
    ) -> ListQuery:
        """Translate a 1-based ``page`` number and page ``size`` into offset/limit.

        Raises:
            ValueError: if ``page`` or ``size`` is less than 1 (``page`` is checked first).
        """
        for label, amount in (("page", page), ("size", size)):
            if amount < 1:
                raise ValueError(f"{label} must be >= 1")
        rows_to_skip = (page - 1) * size
        return cls(offset=rows_to_skip, limit=size, order_by=order_by)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Custom sort hook: receives the resolved ``SortSpec`` and returns one ORDER BY clause
# or a tuple of clauses (``SortConfig.order_clauses`` passes tuples through unchanged).
SortResolver = Callable[[SortSpec], object]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SortConfig:
    """Allowed sort columns for a list endpoint.

    Provide either:

    * ``columns`` — map of API column name → ``{asc, desc}`` SQL/ORM expressions, or
    * ``columns`` + ``literal_sql=True`` — map of string SQL fragments, or
    * ``resolver`` — custom ``SortSpec → order clause(s)`` function.

    A client ``order_by`` is honoured only when its first entry names an allowlisted
    column, uses ``"asc"``/``"desc"``, and (without a ``resolver``) ``columns`` has an
    expression for that exact column/direction. Anything else falls back to
    ``default``, so client input can never trigger a ``KeyError``.
    """

    def __init__(
        self,
        *,
        default: SortSpec,
        columns: Mapping[str, Mapping[str, object]] | None = None,
        allowlist: frozenset[str] | None = None,
        literal_sql: bool = False,
        resolver: SortResolver | None = None,
    ) -> None:
        """Store the sort configuration.

        Args:
            default: Spec used when the client sends no usable ``order_by``.
            columns: API column name -> ``{"asc": expr, "desc": expr}``.
            allowlist: Columns clients may sort by; defaults to the keys of ``columns``.
            literal_sql: Treat ``columns`` values as SQL fragment strings.
            resolver: Custom ``SortSpec`` -> clause(s) function; takes precedence
                over ``columns``.

        Raises:
            ValueError: if neither ``columns`` nor ``resolver`` is given.
        """
        if resolver is None and columns is None:
            raise ValueError("SortConfig requires columns or resolver")
        self._default = default
        self._columns = columns or {}
        self._literal_sql = literal_sql
        self._resolver = resolver
        self._allowlist = allowlist if allowlist is not None else frozenset(self._columns)

    def _can_resolve(self, column: str, direction: str) -> bool:
        """Return True when ``order_expression`` can build a clause for this pair."""
        if self._resolver is not None:
            # A custom resolver decides for itself how to handle allowlisted columns.
            return True
        # Guards against an allowlist wider than ``columns`` or a missing direction key.
        return direction in self._columns.get(column, {})

    def resolve_spec(self, order_by: OrderByMap | None) -> SortSpec:
        """Pick the effective sort from the client's ``order_by`` (first entry only).

        Returns ``default`` when ``order_by`` is empty, the direction is not
        ``"asc"``/``"desc"``, the column is not allowlisted, or no expression is
        configured for that column/direction.
        """
        if order_by:
            column, direction = next(iter(order_by.items()))
            if (
                direction in ("asc", "desc")
                and column in self._allowlist
                and self._can_resolve(column, direction)
            ):
                return SortSpec(column, direction)
        return self._default

    def order_expression(self, order_by: OrderByMap | None) -> object:
        """Return a single ORDER BY expression or a tuple of clauses.

        Raises:
            TypeError: in ``literal_sql`` mode when the configured spec is not a string.
        """
        spec = self.resolve_spec(order_by)
        if self._resolver is not None:
            return self._resolver(spec)
        if self._literal_sql:
            # Local import: sqlphilosophy.sql imports this module at load time.
            from sqlphilosophy.sql import literal_order_expr

            raw = self._columns[spec.column][spec.direction]
            if not isinstance(raw, str):
                raise TypeError("literal_sql SortConfig requires string column specs")
            return literal_order_expr(raw)
        return self._columns[spec.column][spec.direction]

    def order_clauses(self, order_by: OrderByMap | None) -> tuple[object, ...]:
        """Like ``order_expression`` but always returns a tuple of clauses."""
        expr = self.order_expression(order_by)
        if isinstance(expr, tuple):
            return expr
        return (expr,)
|
sqlphilosophy/sql.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
1
|
+
"""SQLAlchemy query helpers — ORM-first, Core for performance paths."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from collections.abc import Iterable
from collections.abc import Sequence
from datetime import date
from datetime import datetime
from decimal import Decimal
from typing import Any
from typing import cast
from typing import Mapping
from typing import TypeVar
from uuid import UUID
|
|
13
|
+
from sqlalchemy import and_
|
|
14
|
+
from sqlalchemy import bindparam
|
|
15
|
+
from sqlalchemy import delete
|
|
16
|
+
from sqlalchemy import desc
|
|
17
|
+
from sqlalchemy import func
|
|
18
|
+
from sqlalchemy import inspect as sa_inspect
|
|
19
|
+
from sqlalchemy import literal_column
|
|
20
|
+
from sqlalchemy import select
|
|
21
|
+
from sqlalchemy import update
|
|
22
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
23
|
+
from sqlalchemy.orm import Session
|
|
24
|
+
from sqlalchemy.sql import column
|
|
25
|
+
from sqlalchemy.sql import table
|
|
26
|
+
from sqlalchemy.sql.elements import BindParameter
|
|
27
|
+
from sqlphilosophy.audit.model import AuditMixin
|
|
28
|
+
from sqlphilosophy.sorting import OrderByMap
|
|
29
|
+
from sqlphilosophy.sorting import SortConfig
|
|
30
|
+
from sqlphilosophy.types import ApiObject
|
|
31
|
+
from sqlphilosophy.types import JSONObject
|
|
32
|
+
from sqlphilosophy.types import JSONValue
|
|
33
|
+
from sqlphilosophy.types import PrimaryKey
|
|
34
|
+
from sqlphilosophy.types import RowMapping
|
|
35
|
+
from sqlphilosophy.types import RowValue
|
|
36
|
+
from sqlphilosophy.types import SqlFilter
|
|
37
|
+
from sqlphilosophy.types import SqlFilters
|
|
38
|
+
from sqlphilosophy.types import SqlOrderColumn
|
|
39
|
+
from sqlphilosophy.types import SqlTable
|
|
40
|
+
|
|
41
|
+
# Type variable for ORM model classes accepted by the ``*_model`` helpers in this module.
_ModelT = TypeVar("_ModelT", bound=DeclarativeBase)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def sql_table(table_name: str, *column_names: str) -> SqlTable:
    """Build an ad-hoc Core ``table()`` exposing only the named columns.

    Lightweight Core table — prefer ORM models unless you need Core performance.
    Columns are created without an explicit type; no metadata or constraints are attached.
    """
    declared = (column(name) for name in column_names)
    return table(table_name, *declared)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_column_value(entity: object) -> ApiObject:
    """Return ``{attribute_key: value}`` for each mapped column attribute of ``entity``.

    Only the class mapper's ``column_attrs`` are read, so relationships are not included.
    """
    mapper = sa_inspect(entity.__class__).mapper
    values: ApiObject = {}
    for attr in mapper.column_attrs:
        values[attr.key] = getattr(entity, attr.key)
    return values
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def row_mapping(row: object) -> RowMapping:
    """Normalize a SQLAlchemy Row, ORM entity, or plain mapping into a str-keyed dict.

    - ``None`` becomes ``{}``.
    - Objects exposing ``_mapping`` (e.g. SQLAlchemy ``Row``) are read through it.
    - ORM entity instances are flattened with ``get_column_value``.
    - Anything else is treated as a ``Mapping``.

    Values that are ORM entities (have ``__mapper__``) are expanded and their columns are
    merged into the result; colliding keys are overwritten by whichever comes later.
    Keys with a ``.key`` attribute use ``str(key.key)``; other keys are ``str()``-ed.
    """
    if row is None:
        return {}
    if not hasattr(row, "_mapping"):
        state = sa_inspect(row, raiseerr=False)
        if state is not None and hasattr(state, "mapper"):
            return cast(RowMapping, get_column_value(row))
        source = dict(cast(Mapping[str, RowValue], row))
    else:
        source = dict(row._mapping)

    normalized: ApiObject = {}
    for name, value in source.items():
        if hasattr(value, "__mapper__"):
            normalized.update(get_column_value(value))
            continue
        normalized[str(name.key) if hasattr(name, "key") else str(name)] = value
    return cast(RowMapping, normalized)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def row_mapping_opt(row: object | None) -> RowMapping | None:
    """Like ``row_mapping`` but keeps ``None`` as ``None`` instead of returning ``{}``."""
    return None if row is None else row_mapping(row)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def row_int(row: RowMapping, key: str) -> int:
    """Read ``row[key]`` as an ``int``.

    ``int`` values pass through. ``float`` and ``Decimal`` are converted with ``int()``
    (truncated toward zero); ``Decimal`` is what SQLAlchemy ``Numeric`` columns return.
    Numeric strings are parsed with ``int()``.

    Raises:
        KeyError: if ``key`` is missing from a dict row.
        TypeError: for ``bool`` or any other unsupported type.
        ValueError / OverflowError: from ``int()`` for unparsable strings or
            non-finite floats/decimals.
    """
    val = row[key]
    if isinstance(val, bool):
        # bool subclasses int; reject it explicitly so True/False never pass as 1/0.
        raise TypeError(f"expected int for {key!r}, got bool")
    if isinstance(val, int):
        return val
    if isinstance(val, (float, Decimal, str)):
        return int(val)
    raise TypeError(f"expected int for {key!r}, got {type(val).__name__}")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def row_opt_int(row: RowMapping, key: str) -> int | None:
    """Read ``row.get(key)`` as ``int | None``.

    A missing key or ``None`` value yields ``None``. Otherwise the value is converted
    as in ``row_int``: ``float``/``Decimal`` via ``int()`` (truncation), numeric strings
    parsed, ``bool`` rejected. ``Decimal`` is what SQLAlchemy ``Numeric`` columns return.

    Raises:
        TypeError: for ``bool`` or any other unsupported type.
        ValueError / OverflowError: from ``int()`` for unparsable strings or
            non-finite floats/decimals.
    """
    val = row.get(key)
    if val is None:
        return None
    if isinstance(val, bool):
        # bool subclasses int; reject it explicitly so True/False never pass as 1/0.
        raise TypeError(f"expected int | None for {key!r}, got bool")
    if isinstance(val, int):
        return val
    if isinstance(val, (float, Decimal, str)):
        return int(val)
    raise TypeError(f"expected int | None for {key!r}, got {type(val).__name__}")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def row_str(row: RowMapping, key: str) -> str:
    """Read ``row[key]`` as a ``str``.

    ``str`` passes through. ``int``/``float``/``bool``/``Decimal``/``UUID``/``date``/
    ``datetime`` are converted with ``str()``; ``Decimal`` is what SQLAlchemy ``Numeric``
    columns return. ``bytes`` are decoded as UTF-8.

    Raises:
        KeyError: if ``key`` is missing from a dict row.
        TypeError: for any other type (including ``None``).
        UnicodeDecodeError: if ``bytes`` are not valid UTF-8.
    """
    val = row[key]
    if isinstance(val, str):
        return val
    if isinstance(val, (int, float, bool, Decimal, UUID, date, datetime)):
        return str(val)
    if isinstance(val, bytes):
        return val.decode()
    raise TypeError(f"expected str for {key!r}, got {type(val).__name__}")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def row_opt_str(row: RowMapping, key: str) -> str | None:
    """Read ``row.get(key)`` as ``str | None``.

    A missing key or ``None`` value yields ``None``. Otherwise the conversion is the
    same as ``row_str``: scalars (including ``Decimal`` from SQLAlchemy ``Numeric``
    columns) via ``str()``, ``bytes`` decoded as UTF-8.

    Raises:
        TypeError: for any other type.
        UnicodeDecodeError: if ``bytes`` are not valid UTF-8.
    """
    val = row.get(key)
    if val is None:
        return None
    if isinstance(val, str):
        return val
    if isinstance(val, (int, float, bool, Decimal, UUID, date, datetime)):
        return str(val)
    if isinstance(val, bytes):
        return val.decode()
    raise TypeError(f"expected str | None for {key!r}, got {type(val).__name__}")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def row_bool(row: RowMapping, key: str) -> bool:
    """Read ``row[key]`` strictly as a ``bool``; integers such as 0/1 are rejected.

    Raises:
        KeyError: if ``key`` is missing from a dict row.
        TypeError: if the value is not a ``bool``.
    """
    val = row[key]
    if not isinstance(val, bool):
        raise TypeError(f"expected bool for {key!r}, got {type(val).__name__}")
    return val
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def row_opt_bool(row: RowMapping, key: str) -> bool | None:
    """Read ``row.get(key)`` strictly as ``bool | None``; a missing key yields ``None``.

    Raises:
        TypeError: if the value is neither ``None`` nor a ``bool``.
    """
    val = row.get(key)
    if val is None or isinstance(val, bool):
        return val
    raise TypeError(f"expected bool | None for {key!r}, got {type(val).__name__}")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def row_float(row: RowMapping, key: str) -> float:
    """Read ``row[key]`` as a ``float``.

    ``float`` passes through; ``int`` and ``Decimal`` are converted with ``float()``.
    ``Decimal`` is what SQLAlchemy ``Numeric`` columns return by default. Strings are
    not parsed.

    Raises:
        KeyError: if ``key`` is missing from a dict row.
        TypeError: for ``bool`` or any other unsupported type.
    """
    val = row[key]
    if isinstance(val, bool):
        # bool subclasses int; reject it so True/False never pass as 1.0/0.0.
        raise TypeError(f"expected float for {key!r}, got bool")
    if isinstance(val, float):
        return val
    if isinstance(val, (int, Decimal)):
        return float(val)
    raise TypeError(f"expected float for {key!r}, got {type(val).__name__}")
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def row_opt_float(row: RowMapping, key: str) -> float | None:
    """Read ``row.get(key)`` as ``float | None``.

    A missing key or ``None`` value yields ``None``. ``float`` passes through;
    ``int`` and ``Decimal`` (SQLAlchemy ``Numeric`` default) are converted with ``float()``.

    Raises:
        TypeError: for ``bool`` or any other unsupported type.
    """
    val = row.get(key)
    if val is None:
        return None
    if isinstance(val, bool):
        # bool subclasses int; reject it so True/False never pass as 1.0/0.0.
        raise TypeError(f"expected float | None for {key!r}, got bool")
    if isinstance(val, float):
        return val
    if isinstance(val, (int, Decimal)):
        return float(val)
    raise TypeError(f"expected float | None for {key!r}, got {type(val).__name__}")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def api_int(obj: Mapping[str, RowValue], key: str, default: int = 0) -> int:
    """Leniently read ``obj[key]`` as an ``int``, returning ``default`` when it cannot.

    - Missing key / ``None`` -> ``default``.
    - ``bool`` -> 0/1; ``int`` passes through.
    - ``float`` and ``Decimal`` -> ``int()`` (truncated toward zero).
    - ``str`` -> ``int()`` parse.
    - Unparsable strings, NaN/infinite floats or decimals, and any other type -> ``default``.

    Never raises for bad values.
    """
    val = obj.get(key)
    if val is None:
        return default
    if isinstance(val, bool):
        return int(val)
    if isinstance(val, int):
        return val
    if isinstance(val, (float, Decimal, str)):
        try:
            return int(val)
        except (ValueError, OverflowError):
            # int() raises ValueError for NaN / bad text and OverflowError for infinity.
            return default
    return default
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def api_float(obj: Mapping[str, RowValue], key: str, default: float = 0.0) -> float:
    """Leniently read ``obj[key]`` as a ``float``, returning ``default`` when it cannot.

    - Missing key / ``None`` -> ``default``.
    - ``float`` passes through.
    - ``bool``, ``int``, ``Decimal`` and ``str`` are converted with ``float()``.
    - Conversion failures (bad text, ints too large for a float) and any other type
      -> ``default``.

    Never raises for bad values.
    """
    val = obj.get(key)
    if val is None:
        return default
    if isinstance(val, float):
        return val
    if isinstance(val, (bool, int, Decimal, str)):
        try:
            return float(val)
        except (ValueError, OverflowError):
            # ValueError: unparsable text / signaling NaN; OverflowError: int beyond float range.
            return default
    return default
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def row_json(row: RowMapping, key: str) -> JSONValue:
    """Read ``row[key]`` as a JSON value.

    ``None``, ``str``, ``int``/``bool``, ``float`` and ``list`` values pass through
    unchecked. ``dict`` values must have only ``str`` keys (checked one level deep).

    Raises:
        KeyError: if ``key`` is missing from a dict row.
        TypeError: for non-str dict keys or any non-JSON type.
    """
    val = row[key]
    if val is None or isinstance(val, (str, int, float, list)):
        return cast(JSONValue, val)
    if isinstance(val, dict):
        if any(not isinstance(k, str) for k in val):
            raise TypeError(f"expected JSON object keys to be str for {key!r}")
        return cast(JSONValue, val)
    raise TypeError(f"expected JSON value for {key!r}, got {type(val).__name__}")
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def row_json_object(row: RowMapping, key: str) -> JSONObject:
    """Read ``row[key]`` as a JSON object: a ``dict`` whose keys are all ``str``.

    Raises:
        KeyError: if ``key`` is missing from a dict row.
        TypeError: if the value is not a dict or has a non-str key.
    """
    val = row[key]
    if not isinstance(val, dict):
        raise TypeError(f"expected JSON object for {key!r}, got {type(val).__name__}")
    if any(not isinstance(k, str) for k in val):
        raise TypeError(f"expected JSON object keys to be str for {key!r}")
    return cast(JSONObject, val)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def row_opt_json_object(row: RowMapping, key: str) -> JSONObject | None:
    """Read ``row.get(key)`` as a JSON object or ``None`` (missing key -> ``None``).

    Raises:
        TypeError: if the value is neither ``None`` nor a dict with only ``str`` keys.
    """
    val = row.get(key)
    if val is None:
        return None
    if not isinstance(val, dict):
        raise TypeError(f"expected JSON object | None for {key!r}, got {type(val).__name__}")
    if any(not isinstance(k, str) for k in val):
        raise TypeError(f"expected JSON object keys to be str for {key!r}")
    return cast(JSONObject, val)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def row_uuid(row: RowMapping, key: str) -> UUID:
    """Read ``row[key]`` as a ``UUID``; strings are parsed with ``UUID(str)``.

    Raises:
        KeyError: if ``key`` is missing from a dict row.
        ValueError: if a string is not a valid UUID.
        TypeError: for any type other than ``UUID`` or ``str``.
    """
    val = row[key]
    if isinstance(val, str):
        return UUID(val)
    if isinstance(val, UUID):
        return val
    raise TypeError(f"expected UUID for {key!r}, got {type(val).__name__}")
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def row_opt_uuid(row: RowMapping, key: str) -> UUID | None:
    """Read ``row.get(key)`` as ``UUID | None``; strings are parsed, missing key -> ``None``.

    Raises:
        ValueError: if a string is not a valid UUID.
        TypeError: for any type other than ``None``, ``UUID`` or ``str``.
    """
    val = row.get(key)
    if val is None or isinstance(val, UUID):
        return val
    if isinstance(val, str):
        return UUID(val)
    raise TypeError(f"expected UUID | None for {key!r}, got {type(val).__name__}")
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def rows_mapping(rows: Iterable[object]) -> list[RowMapping]:
    """Normalize every row with ``row_mapping`` and return them as a new list."""
    return list(map(row_mapping, rows))
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def apply_mappings_page(
    session: Session,
    stmt: Any,
    *,
    limit: int,
    offset: int,
    params: RowMapping | None = None,
) -> list[RowMapping]:
    """Execute ``stmt`` with limit/offset; return normalized row mappings.

    Raises:
        ValueError: if ``limit`` or ``offset`` is negative (checked before executing;
            ``limit`` is checked first).
    """
    for label, amount in (("limit", limit), ("offset", offset)):
        if amount < 0:
            raise ValueError(f"{label} must be >= 0")
    result = session.execute(stmt.limit(limit).offset(offset), params or {})
    mapped = result.mappings()
    # Iterate directly if the mappings result has no ``.all()``.
    return rows_mapping(mapped.all() if hasattr(mapped, "all") else mapped)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_sort_column(
    sort: SortConfig,
    order_by: OrderByMap | None = None,
) -> object:
    """Return the ORDER BY expression ``sort`` resolves for the optional client ``order_by``.

    Thin delegate to ``SortConfig.order_expression``.
    """
    resolved = sort.order_expression(order_by)
    return resolved
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def expanding_in_param(
    name: str,
    values: Sequence[PrimaryKey],
) -> tuple[object, dict[str, list[str]]]:
    """Return ``(bindparam(..., expanding=True), {name: [str(v), ...]})`` for ``IN`` clauses.

    Every value is converted with ``str()`` before binding.
    NOTE(review): string binds compared against a non-text key column may need an
    explicit cast on strict backends — confirm for integer primary keys.
    """
    param: BindParameter[Any] = bindparam(name, expanding=True)
    as_text = list(map(str, values))
    return param, {name: as_text}
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def partial_update_model(
    session: Session,
    model: type[_ModelT],
    pk_value: PrimaryKey,
    fields: RowMapping,
    writable: frozenset[str],
    *,
    pk_attr: str = "id",
    touch_updated_on: bool = False,
    extra_values: RowMapping | None = None,
) -> int:
    """Partial UPDATE on an ORM mapped class; ``fields`` keys must pass ``writable``.

    Keys of ``fields`` not in ``writable`` are dropped; ``extra_values`` are merged on
    top without that filter. Returns 0 without touching the database when nothing
    remains to write.

    * ``AuditMixin`` models: the row is loaded with ``session.get`` (by primary key),
      attributes are assigned, and the session is flushed. ``pk_attr`` and
      ``touch_updated_on`` are not used on this path. Returns 1, or 0 if no row.
    * Other models: one ``UPDATE ... WHERE <pk_attr> = pk_value`` is executed;
      ``touch_updated_on`` adds ``updated_on = now()``. Returns the driver rowcount
      (0 when unavailable).
    """
    changes: RowMapping = {name: v for name, v in fields.items() if name in writable}
    if extra_values:
        changes = {**changes, **extra_values}
    if not changes:
        return 0

    if issubclass(model, AuditMixin):
        # Load-and-assign through the ORM rather than a bulk UPDATE; presumably so
        # AuditMixin hooks observe the change — confirm against the audit listener.
        entity = session.get(model, pk_value)
        if entity is None:
            return 0
        for attr_name, new_value in changes.items():
            setattr(entity, attr_name, new_value)
        session.flush()
        return 1

    if touch_updated_on:
        changes = cast(RowMapping, {**dict(changes), "updated_on": cast(RowValue, func.now())})
    key_match = getattr(model, pk_attr) == pk_value
    outcome = session.execute(update(model).where(key_match).values(**changes))
    return int(outcome.rowcount or 0)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def partial_update(
    session: Session,
    table_name: str,
    pk_value: PrimaryKey,
    fields: RowMapping,
    writable: frozenset[str],
    *,
    pk_column: str = "id",
    touch_updated_on: bool = False,
    extra_values: RowMapping | None = None,
) -> int:
    """Core partial UPDATE — use ``partial_update_model`` when an ORM class exists.

    Keys of ``fields`` not in ``writable`` are dropped; ``extra_values`` are merged on
    top unfiltered; ``touch_updated_on`` adds ``updated_on = now()``. An ad-hoc table
    holding only the PK and the written columns is built with ``sql_table``.

    Returns:
        0 when nothing remains to write (no SQL executed), otherwise the driver
        rowcount (0 when unavailable).
    """
    pending: RowMapping = {name: v for name, v in fields.items() if name in writable}
    if extra_values:
        pending = {**pending, **extra_values}
    if not pending:
        return 0
    if touch_updated_on:
        pending = cast(RowMapping, {**dict(pending), "updated_on": cast(RowValue, func.now())})
    target = sql_table(table_name, pk_column, *pending.keys())
    stmt = update(target).where(target.c[pk_column] == pk_value).values(**pending)
    return int(session.execute(stmt).rowcount or 0)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def apply_writable_update(
    session: Session,
    model: type[DeclarativeBase],
    pk_value: PrimaryKey,
    values: RowMapping,
    writable: frozenset[str],
    *,
    pk_attr: str = "id",
) -> None:
    """UPDATE one row with only the ``writable`` keys of ``values``.

    No statement is executed when no key survives the filter.
    """
    updates = {name: v for name, v in values.items() if name in writable}
    if updates:
        key_column = getattr(model, pk_attr)
        session.execute(update(model).where(key_column == pk_value).values(**updates))
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def delete_by_ids(
    session: Session,
    table_name: str,
    ids: list[object],
    *,
    pk_column: str = "id",
) -> int:
    """``DELETE FROM <table_name> WHERE <pk_column> IN (...)`` via an ad-hoc Core table.

    Returns 0 without executing SQL for an empty ``ids``; otherwise the driver
    rowcount (0 when unavailable).
    """
    if not ids:
        return 0
    target = sql_table(table_name, pk_column)
    outcome = session.execute(delete(target).where(target.c[pk_column].in_(ids)))
    return int(outcome.rowcount or 0)
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def delete_by_ids_model(
    session: Session,
    model: type[_ModelT],
    ids: list[object],
    *,
    pk_attr: str = "id",
) -> int:
    """ORM-class variant of ``delete_by_ids``: ``DELETE ... WHERE <pk_attr> IN (...)``.

    Returns 0 without executing SQL for an empty ``ids``; otherwise the driver
    rowcount (0 when unavailable).
    """
    if not ids:
        return 0
    key_column = getattr(model, pk_attr)
    outcome = session.execute(delete(model).where(key_column.in_(ids)))
    return int(outcome.rowcount or 0)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def col_eq(col_sql: str, param_name: str, value: object) -> tuple[SqlFilter, ApiObject]:
    """Equality filter ``<col_sql> = :param_name`` together with its bind params.

    ``col_sql`` is inserted verbatim as raw SQL — only pass trusted, hard-coded fragments.
    """
    criterion = literal_column(col_sql) == bindparam(param_name)
    params = cast(ApiObject, {param_name: cast(RowValue, value)})
    return criterion, params
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# Escape character for LIKE patterns; "/" avoids backslash quoting differences between dialects.
_LIKE_ESCAPE = "/"


def _escape_like(text: str) -> str:
    """Escape LIKE metacharacters (and the escape character itself) so they match literally."""
    return (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def col_icontains(
    col_sql: str,
    param_name: str,
    raw: object,
) -> tuple[SqlFilter, ApiObject] | None:
    """Case-insensitive substring filter ``lower(<col_sql>) LIKE :param_name``.

    ``raw`` is stringified and stripped. ``None`` or blank input means "no filter"
    and returns ``None``. ``%`` and ``_`` typed by the user are escaped so they match
    literally instead of acting as wildcards. ``col_sql`` is raw SQL and must come
    from trusted code, never from the client.
    """
    if raw is None:
        # Previously str(None) turned this into a search for the text "none".
        return None
    text_value = str(raw).strip()
    if not text_value:
        return None
    pattern = f"%{_escape_like(text_value.lower())}%"
    crit = func.lower(literal_column(col_sql)).like(bindparam(param_name), escape=_LIKE_ESCAPE)
    return crit, {param_name: pattern}
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def col_range(
    col_sql: str,
    param_name: str,
    operator: str,
    value: object,
) -> tuple[SqlFilter, ApiObject]:
    """Range filter ``<col_sql> >= :param_name`` or ``<col_sql> <= :param_name``.

    ``col_sql`` is raw SQL — only pass trusted fragments.

    Raises:
        ValueError: for any ``operator`` other than ``">="`` or ``"<="``.
    """
    col: SqlOrderColumn = literal_column(col_sql)
    if operator not in (">=", "<="):
        raise ValueError(f"unsupported operator: {operator}")
    bound = bindparam(param_name)
    crit = col >= bound if operator == ">=" else col <= bound
    return crit, cast(ApiObject, {param_name: cast(RowValue, value)})
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def merge_criteria(
    *parts: tuple[SqlFilters, ApiObject] | None,
) -> tuple[SqlFilters, ApiObject]:
    """Combine ``(criteria, params)`` pairs, skipping ``None`` parts.

    Criteria are concatenated in order; params are merged left to right, so later
    parts win on duplicate keys.
    """
    present = [part for part in parts if part is not None]
    criteria: SqlFilters = [crit for crits, _ in present for crit in crits]
    params: ApiObject = {}
    for _, part_params in present:
        params.update(part_params)
    return criteria, params
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def combine_and(*criteria: SqlFilter | None) -> SqlFilter | None:
    """AND together the non-``None`` criteria; return ``None`` when none remain."""
    present = tuple(c for c in criteria if c is not None)
    return and_(*present) if present else None
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def order_by_allowlist(
    order_key: str,
    ordering_map: Mapping[str, str],
    *,
    allowlist: frozenset[str],
) -> SqlOrderColumn:
    """Turn an allowlisted ``order_key`` into an ORDER BY via its SQL fragment.

    Raises:
        ValueError: if ``order_key`` is not in ``allowlist``.
        KeyError: if an allowlisted key has no entry in ``ordering_map``.
    """
    if order_key in allowlist:
        return literal_order_expr(ordering_map[order_key])
    raise ValueError(f"invalid order key: {order_key}")
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def literal_order_expr(spec: str) -> SqlOrderColumn:
    """Build ORDER BY from a SQL fragment such as ``a.started_at DESC``.

    Only a final ``DESC`` word (case-insensitive) after the last single space is
    recognised and wrapped with ``desc()``. Anything else, including an explicit
    ``ASC``, is emitted verbatim via ``literal_column(spec)``. ``spec`` is raw SQL —
    use only trusted, allowlisted fragments.
    """
    head, sep, tail = spec.rpartition(" ")
    if sep and tail.upper() == "DESC":
        return desc(literal_column(head))
    return literal_column(spec)
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def order_expr_from_sort(
    column: str,
    direction: str,
    *,
    columns: Mapping[str, Mapping[str, str]],
) -> SqlOrderColumn:
    """Build an ORDER BY expression from ``(column, asc|desc)`` and a column spec map.

    Raises:
        KeyError: if ``column`` or ``direction`` is missing from ``columns``.
    """
    fragment = columns[column][direction]
    return literal_order_expr(fragment)
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def count_from_subquery(session: Session, subq: object) -> int:
    """Count rows from a subquery (path C helper for aggregate count wrappers).

    Runs ``SELECT count(*) FROM <subq>``; a NULL result is reported as 0.
    """
    stmt = select(func.count()).select_from(subq)
    total = session.execute(stmt).scalar_one()
    return int(total or 0)
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
def count_from_table(
    session: Session,
    tbl: SqlTable,
    criteria: SqlFilters,
    params: RowMapping,
) -> int:
    """Run ``SELECT count(*) FROM tbl [WHERE <criteria ANDed>]`` with ``params`` bound."""
    where = combine_and(*criteria)
    stmt = select(func.count()).select_from(tbl)
    if where is not None:
        stmt = stmt.where(where)
    return int(session.execute(stmt, params).scalar_one())
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def select_page_from_table(
    session: Session,
    tbl: SqlTable,
    criteria: SqlFilters,
    params: RowMapping,
    *,
    order_by: SqlOrderColumn,
    limit: int,
    offset: int,
) -> list[object]:
    """SELECT one page of rows from a Core table.

    Applies ``criteria`` (ANDed), ``order_by``, ``limit`` and ``offset``, binds
    ``params``, and returns the rows from ``.mappings().all()``.

    Raises:
        ValueError: if ``limit`` or ``offset`` is negative. This is checked before any
            SQL is built, consistent with ``apply_mappings_page``.
    """
    if limit < 0:
        raise ValueError("limit must be >= 0")
    if offset < 0:
        raise ValueError("offset must be >= 0")
    stmt = select(tbl)
    combined = combine_and(*criteria)
    if combined is not None:
        stmt = stmt.where(combined)
    stmt = stmt.order_by(order_by).limit(limit).offset(offset)
    return list(session.execute(stmt, params).mappings().all())
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Portable repository factory protocol (no Phobos or app imports)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
from typing import TypeVar
|
|
6
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
7
|
+
from sqlphilosophy.sync.query import StatementQueryBuilder
|
|
8
|
+
|
|
9
|
+
# Model type variable: any SQLAlchemy declarative model class.
T = TypeVar("T", bound=DeclarativeBase)
|
|
10
|
+
# Repository type variable: the class passed to ``get_repository`` and the instance it returns.
R = TypeVar("R")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RepositoryFactory(Protocol):
    """Structural interface for a session-scoped source of query builders and repositories.

    Any object providing these three methods satisfies the protocol; no inheritance
    is required.
    """

    def create_statement(self, model: type[T]) -> StatementQueryBuilder[T]:
        """Build a fluent read-side statement builder bound to ``model``."""
        ...

    def get_repository(self, repo_class: type[R]) -> R:
        """Hand back the (cached) typed entity repository for ``repo_class``."""
        ...

    def repository(self, model: type[T]) -> object:
        """Hand back generic CRUD helpers for ``model`` (``BaseRepository`` in Phobos)."""
        ...
|