sqliter-py 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqliter/__init__.py +9 -0
- sqliter/constants.py +45 -0
- sqliter/exceptions.py +198 -0
- sqliter/helpers.py +100 -0
- sqliter/model/__init__.py +46 -0
- sqliter/model/foreign_key.py +153 -0
- sqliter/model/model.py +236 -0
- sqliter/model/unique.py +28 -0
- sqliter/py.typed +0 -0
- sqliter/query/__init__.py +9 -0
- sqliter/query/query.py +891 -0
- sqliter/sqliter.py +1087 -0
- sqliter_py-0.12.0.dist-info/METADATA +209 -0
- sqliter_py-0.12.0.dist-info/RECORD +15 -0
- sqliter_py-0.12.0.dist-info/WHEEL +4 -0
sqliter/query/query.py
ADDED
@@ -0,0 +1,891 @@
"""Implements the query building and execution logic for SQLiter.

This module defines the QueryBuilder class, which provides a fluent
interface for constructing SQL queries. It supports operations such
as filtering, ordering, limiting, and various data retrieval methods,
allowing for flexible and expressive database queries without writing
raw SQL.
"""

from __future__ import annotations

import hashlib
import json
import sqlite3
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    Literal,
    Optional,
    TypeVar,
    Union,
    cast,
    overload,
)

from typing_extensions import LiteralString, Self

from sqliter.constants import OPERATOR_MAPPING
from sqliter.exceptions import (
    InvalidFilterError,
    InvalidOffsetError,
    InvalidOrderError,
    RecordDeletionError,
    RecordFetchError,
)

if TYPE_CHECKING:  # pragma: no cover
    from pydantic.fields import FieldInfo

    from sqliter import SqliterDB
    from sqliter.model import BaseDBModel, SerializableField

# TypeVar for generic QueryBuilder
T = TypeVar("T", bound="BaseDBModel")

# Define a type alias for the possible value types
FilterValue = Union[
    str, int, float, bool, None, list[Union[str, int, float, bool]]
]


class QueryBuilder(Generic[T]):
    """Builds and executes database queries for a specific model.

    This class provides methods to construct SQL queries, apply filters,
    set ordering, and execute the queries against the database.

    Attributes:
        db (SqliterDB): The database connection object.
        model_class (type[T]): The Pydantic model class.
        table_name (str): The name of the database table.
        filters (list): List of applied filter conditions.
        _limit (Optional[int]): The LIMIT clause value, if any.
        _offset (Optional[int]): The OFFSET clause value, if any.
        _order_by (Optional[str]): The ORDER BY clause, if any.
        _fields (Optional[list[str]]): List of fields to select, if specified.
    """

    def __init__(
        self,
        db: SqliterDB,
        model_class: type[T],
        fields: Optional[list[str]] = None,
    ) -> None:
        """Initialize a new QueryBuilder instance.

        Args:
            db: The database connection object.
            model_class: The Pydantic model class for the table.
            fields: Optional list of field names to select. If None, all fields
                are selected.
        """
        self.db = db
        self.model_class: type[T] = model_class
        self.table_name = model_class.get_table_name()  # Use model_class method
        self.filters: list[tuple[str, Any, str]] = []
        self._limit: Optional[int] = None
        self._offset: Optional[int] = None
        self._order_by: Optional[str] = None
        self._fields: Optional[list[str]] = fields
        self._bypass_cache: bool = False
        self._query_cache_ttl: Optional[int] = None

        if self._fields:
            self._validate_fields()

    def _validate_fields(self) -> None:
        """Validate that the specified fields exist in the model.

        Raises:
            ValueError: If any specified field is not in the model.
        """
        if self._fields is None:
            return
        valid_fields = set(self.model_class.model_fields.keys())
        invalid_fields = set(self._fields) - valid_fields
        if invalid_fields:
            err_message = (
                f"Invalid fields specified: {', '.join(invalid_fields)}"
            )
            raise ValueError(err_message)

    def filter(self, **conditions: str | float | None) -> Self:
        """Apply filter conditions to the query.

        This method allows adding one or more filter conditions to the query.
        Each condition is specified as a keyword argument, where the key is
        the field name and the value is the condition to apply.

        Args:
            **conditions: Arbitrary keyword arguments representing filter
                conditions. The key is the field name, and the value is the
                condition to apply. Supported operators include equality,
                comparison, and special operators like __in, __isnull, etc.

        Returns:
            QueryBuilder: The current QueryBuilder instance for method
                chaining.

        Examples:
            >>> query.filter(name="John", age__gt=30)
            >>> query.filter(status__in=["active", "pending"])
        """
        valid_fields = self.model_class.model_fields

        for field, value in conditions.items():
            field_name, operator = self._parse_field_operator(field)
            self._validate_field(field_name, valid_fields)

            if operator in ["__isnull", "__notnull"]:
                self._handle_null(field_name, value, operator)
            else:
                handler = self._get_operator_handler(operator)
                handler(field_name, value, operator)

        return self

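    # Minimal usage sketch of filter() chaining (editorial illustration, not
    # definitive; `User` is a hypothetical BaseDBModel subclass and `db` an
    # open SqliterDB instance, neither defined in this module):
    #
    #     active_adults = (
    #         db.select(User)
    #         .filter(age__gte=18, status__in=["active", "pending"])
    #         .fetch_all()
    #     )
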
    def fields(self, fields: Optional[list[str]] = None) -> Self:
        """Specify which fields to select in the query.

        Args:
            fields: List of field names to select. If None, all fields are
                selected.

        Returns:
            The QueryBuilder instance for method chaining.
        """
        if fields:
            if "pk" not in fields:
                fields.append("pk")
        self._fields = fields
        self._validate_fields()
        return self

    def exclude(self, fields: Optional[list[str]] = None) -> Self:
        """Specify which fields to exclude from the query results.

        Args:
            fields: List of field names to exclude. If None, no fields are
                excluded.

        Returns:
            The QueryBuilder instance for method chaining.

        Raises:
            ValueError: If exclusion results in no fields being selected or if
                invalid fields are specified.
        """
        if fields:
            if "pk" in fields:
                err = "The primary key 'pk' cannot be excluded."
                raise ValueError(err)
            all_fields = set(self.model_class.model_fields.keys())

            # Check for invalid fields before subtraction
            invalid_fields = set(fields) - all_fields
            if invalid_fields:
                err = (
                    "Invalid fields specified for exclusion: "
                    f"{', '.join(invalid_fields)}"
                )
                raise ValueError(err)

            # Subtract the fields specified for exclusion
            self._fields = list(all_fields - set(fields))

            # Explicit check: raise an error if no fields remain
            if self._fields == ["pk"]:
                err = "Exclusion results in no fields being selected."
                raise ValueError(err)

            # Now validate the remaining fields to ensure they are all valid
            self._validate_fields()

        return self

    def only(self, field: str) -> Self:
        """Specify a single field to select in the query.

        Args:
            field: The name of the field to select.

        Returns:
            The QueryBuilder instance for method chaining.

        Raises:
            ValueError: If the specified field is invalid.
        """
        all_fields = set(self.model_class.model_fields.keys())

        # Validate that the field exists
        if field not in all_fields:
            err = f"Invalid field specified: {field}"
            raise ValueError(err)

        # Set self._fields to just the single field
        self._fields = [field, "pk"]
        return self

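    # Field-selection sketch (editorial illustration; `User` and `db` are the
    # same hypothetical names as above). fields()/only() narrow the SELECT
    # list, exclude() removes columns, and the primary key "pk" is always
    # retained:
    #
    #     names_only = db.select(User).only("name").fetch_all()
    #     trimmed = db.select(User).exclude(["email"]).fetch_all()
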
    def _get_operator_handler(
        self, operator: str
    ) -> Callable[[str, Any, str], None]:
        """Get the appropriate handler function for the given operator.

        Args:
            operator: The filter operator string.

        Returns:
            A callable that handles the specific operator type.
        """
        handlers = {
            "__isnull": self._handle_null,
            "__notnull": self._handle_null,
            "__in": self._handle_in,
            "__not_in": self._handle_in,
            "__startswith": self._handle_like,
            "__endswith": self._handle_like,
            "__contains": self._handle_like,
            "__istartswith": self._handle_like,
            "__iendswith": self._handle_like,
            "__icontains": self._handle_like,
            "__lt": self._handle_comparison,
            "__lte": self._handle_comparison,
            "__gt": self._handle_comparison,
            "__gte": self._handle_comparison,
            "__ne": self._handle_comparison,
        }
        return handlers.get(operator, self._handle_equality)

    def _validate_field(
        self, field_name: str, valid_fields: dict[str, FieldInfo]
    ) -> None:
        """Validate that a field exists in the model.

        Args:
            field_name: The name of the field to validate.
            valid_fields: Dictionary of valid fields from the model.

        Raises:
            InvalidFilterError: If the field is not in the model.
        """
        if field_name not in valid_fields:
            raise InvalidFilterError(field_name)

    def _handle_equality(
        self, field_name: str, value: FilterValue, operator: str
    ) -> None:
        """Handle equality filter conditions.

        Args:
            field_name: The name of the field to filter on.
            value: The value to compare against.
            operator: The operator string (usually '__eq').

        This method adds an equality condition to the filters list, handling
        NULL values separately.
        """
        if value is None:
            self.filters.append((f"{field_name} IS NULL", None, "__isnull"))
        else:
            self.filters.append((field_name, value, operator))

    def _handle_null(
        self, field_name: str, value: Union[str, float, None], operator: str
    ) -> None:
        """Handle IS NULL and IS NOT NULL filter conditions.

        Args:
            field_name: The name of the field to filter on.
            value: The value that determines whether to check for NULL or
                NOT NULL.
            operator: The operator string ('__isnull' or '__notnull').

        This method adds an IS NULL or IS NOT NULL condition to the filters
        list.
        """
        is_null = operator == "__isnull"
        check_null = bool(value) if is_null else not bool(value)
        condition = f"{field_name} IS {'NOT ' if not check_null else ''}NULL"
        self.filters.append((condition, None, operator))

    def _handle_in(
        self, field_name: str, value: FilterValue, operator: str
    ) -> None:
        """Handle IN and NOT IN filter conditions.

        Args:
            field_name: The name of the field to filter on.
            value: A list of values to check against.
            operator: The operator string ('__in' or '__not_in').

        Raises:
            TypeError: If the value is not a list.

        This method adds an IN or NOT IN condition to the filters list.
        """
        if not isinstance(value, list):
            err = f"{field_name} requires a list for '{operator}'"
            raise TypeError(err)
        sql_operator = OPERATOR_MAPPING.get(operator, "IN")
        placeholder_list = ", ".join(["?"] * len(value))
        self.filters.append(
            (
                f"{field_name} {sql_operator} ({placeholder_list})",
                value,
                operator,
            )
        )

    def _handle_like(
        self, field_name: str, value: FilterValue, operator: str
    ) -> None:
        """Handle LIKE and GLOB filter conditions.

        Args:
            field_name: The name of the field to filter on.
            value: The pattern to match against.
            operator: The operator string (e.g., '__startswith', '__contains').

        Raises:
            TypeError: If the value is not a string.

        This method adds a LIKE or GLOB condition to the filters list, depending
        on whether the operation is case-sensitive or not.
        """
        if not isinstance(value, str):
            err = f"{field_name} requires a string value for '{operator}'"
            raise TypeError(err)
        formatted_value = self._format_string_for_operator(operator, value)
        if operator in ["__startswith", "__endswith", "__contains"]:
            self.filters.append(
                (
                    f"{field_name} GLOB ?",
                    [formatted_value],
                    operator,
                )
            )
        elif operator in ["__istartswith", "__iendswith", "__icontains"]:
            self.filters.append(
                (
                    f"{field_name} LIKE ?",
                    [formatted_value],
                    operator,
                )
            )

    def _handle_comparison(
        self, field_name: str, value: FilterValue, operator: str
    ) -> None:
        """Handle comparison filter conditions.

        Args:
            field_name: The name of the field to filter on.
            value: The value to compare against.
            operator: The comparison operator string (e.g., '__lt', '__gte').

        This method adds a comparison condition to the filters list.
        """
        sql_operator = OPERATOR_MAPPING[operator]
        self.filters.append((f"{field_name} {sql_operator} ?", value, operator))

    # Helper method for parsing field and operator
    def _parse_field_operator(self, field: str) -> tuple[str, str]:
        """Parse a field string to separate the field name and operator.

        Args:
            field: The field string, potentially including an operator.

        Returns:
            A tuple containing the field name and the operator (or '__eq' if
            no operator was specified).
        """
        for operator in OPERATOR_MAPPING:
            if field.endswith(operator):
                return field[: -len(operator)], operator
        return field, "__eq"  # Default to equality if no operator is found

    # Helper method for formatting string operators (like startswith)
    def _format_string_for_operator(self, operator: str, value: str) -> str:
        """Format a string value based on the specified operator.

        Args:
            operator: The operator string (e.g., '__startswith', '__contains').
            value: The original string value.

        Returns:
            The formatted string value suitable for the given operator.
        """
        format_map = {
            "__startswith": f"{value}*",
            "__endswith": f"*{value}",
            "__contains": f"*{value}*",
            "__istartswith": f"{value.lower()}%",
            "__iendswith": f"%{value.lower()}",
            "__icontains": f"%{value.lower()}%",
        }

        # Return the formatted string or the original value if no match
        return format_map.get(operator, value)

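    # How a filter keyword becomes SQL (editorial sketch): a condition such as
    # age__gte=18 is split by _parse_field_operator() into ("age", "__gte"),
    # routed to _handle_comparison(), and stored as ("age >= ?", 18, "__gte"),
    # assuming OPERATOR_MAPPING (defined in sqliter.constants, not shown here)
    # maps "__gte" to ">=". A string operator such as name__startswith="Al"
    # instead becomes ("name GLOB ?", ["Al*"], "__startswith").
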
    def limit(self, limit_value: int) -> Self:
        """Limit the number of results returned by the query.

        Args:
            limit_value: The maximum number of records to return.

        Returns:
            The QueryBuilder instance for method chaining.
        """
        self._limit = limit_value
        return self

    def offset(self, offset_value: int) -> Self:
        """Set an offset value for the query.

        Args:
            offset_value: The number of records to skip.

        Returns:
            The QueryBuilder instance for method chaining.

        Raises:
            InvalidOffsetError: If the offset value is negative.
        """
        if offset_value < 0:
            raise InvalidOffsetError(offset_value)
        self._offset = offset_value

        if self._limit is None:
            self._limit = -1
        return self

    def order(
        self,
        order_by_field: Optional[str] = None,
        direction: Optional[str] = None,
        *,
        reverse: bool = False,
    ) -> Self:
        """Order the query results by the specified field.

        Args:
            order_by_field: The field to order by [optional].
            direction: Deprecated. Use 'reverse' instead.
            reverse: If True, sort in descending order.

        Returns:
            The QueryBuilder instance for method chaining.

        Raises:
            InvalidOrderError: If the field doesn't exist or if both 'direction'
                and 'reverse' are specified.

        Warns:
            DeprecationWarning: If 'direction' is used instead of 'reverse'.
        """
        if direction:
            warnings.warn(
                "'direction' argument is deprecated and will be removed in a "
                "future version. Use 'reverse' instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        if order_by_field is None:
            order_by_field = self.model_class.get_primary_key()

        if order_by_field not in self.model_class.model_fields:
            err = f"'{order_by_field}' does not exist in the model fields."
            raise InvalidOrderError(err)
        # Raise an exception if both 'direction' and 'reverse' are specified
        if direction and reverse:
            err = (
                "Cannot specify both 'direction' and 'reverse' as it "
                "is ambiguous."
            )
            raise InvalidOrderError(err)

        # Determine the sorting direction
        if reverse:
            sort_order = "DESC"
        elif direction:
            sort_order = direction.upper()
            if sort_order not in {"ASC", "DESC"}:
                err = f"'{direction}' is not a valid sorting direction."
                raise InvalidOrderError(err)
        else:
            sort_order = "ASC"

        self._order_by = f'"{order_by_field}" {sort_order}'
        return self

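    # Ordering/pagination sketch (editorial illustration; `User` and `db` are
    # hypothetical, as above). Note that offset() with no explicit limit sets
    # LIMIT to -1 (unlimited) so that SQLite accepts the OFFSET clause:
    #
    #     page_two = (
    #         db.select(User)
    #         .order("name", reverse=True)
    #         .limit(10)
    #         .offset(10)
    #         .fetch_all()
    #     )
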
    def _execute_query(
        self,
        *,
        fetch_one: bool = False,
        count_only: bool = False,
    ) -> list[tuple[Any, ...]] | Optional[tuple[Any, ...]]:
        """Execute the constructed SQL query.

        Args:
            fetch_one: If True, fetch only one result.
            count_only: If True, return only the count of results.

        Returns:
            A list of tuples (all results), a single tuple (one result),
            or None if no results are found.

        Raises:
            RecordFetchError: If there's an error executing the query.
        """
        if count_only:
            fields = "COUNT(*)"
        elif self._fields:
            if "pk" not in self._fields:
                self._fields.append("pk")
            fields = ", ".join(f'"{field}"' for field in self._fields)
        else:
            fields = ", ".join(
                f'"{field}"' for field in self.model_class.model_fields
            )

        sql = f'SELECT {fields} FROM "{self.table_name}"'  # noqa: S608 # nosec

        # Build the WHERE clause with special handling for None (NULL in SQL)
        values, where_clause = self._parse_filter()

        if self.filters:
            sql += f" WHERE {where_clause}"

        if self._order_by:
            sql += f" ORDER BY {self._order_by}"

        if self._limit is not None:
            sql += " LIMIT ?"
            values.append(self._limit)

        if self._offset is not None:
            sql += " OFFSET ?"
            values.append(self._offset)

        # Log the raw SQL and values if debug is enabled
        if self.db.debug:
            self.db._log_sql(sql, values)  # noqa: SLF001

        try:
            with self.db.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(sql, values)
                return cursor.fetchall() if not fetch_one else cursor.fetchone()
        except sqlite3.Error as exc:
            raise RecordFetchError(self.table_name) from exc

    def _parse_filter(self) -> tuple[list[Any], LiteralString]:
        """Parse the filter conditions into SQL clauses and values.

        Returns:
            A tuple containing:
                - A list of values to be used in the SQL query.
                - A string representing the WHERE clause of the SQL query.
        """
        where_clauses = []
        values = []
        for field, value, operator in self.filters:
            if operator == "__eq":
                where_clauses.append(f"{field} = ?")
                values.append(value)
            else:
                where_clauses.append(field)
                if operator not in ["__isnull", "__notnull"]:
                    if isinstance(value, list):
                        values.extend(value)
                    else:
                        values.append(value)

        where_clause = " AND ".join(where_clauses)
        return values, where_clause

    def _convert_row_to_model(self, row: tuple[Any, ...]) -> T:
        """Convert a database row to a model instance.

        Args:
            row: A tuple representing a database row.

        Returns:
            An instance of the model class populated with the row data.
        """
        if self._fields:
            data = {
                field: self._deserialize(field, row[idx])
                for idx, field in enumerate(self._fields)
            }
            return self.model_class.model_validate_partial(data)

        data = {
            field: self._deserialize(field, row[idx])
            for idx, field in enumerate(self.model_class.model_fields)
        }
        return self.model_class(**data)

    def _deserialize(
        self, field_name: str, value: SerializableField
    ) -> SerializableField:
        """Deserialize a field value if needed.

        Args:
            field_name: Name of the field being deserialized.
            value: Value from the database.

        Returns:
            The deserialized value.
        """
        return self.model_class.deserialize_field(
            field_name, value, return_local_time=self.db.return_local_time
        )

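    # Shape of the generated SQL (editorial sketch, derived from
    # _execute_query/_parse_filter above): a chained query such as
    # db.select(User).filter(age__gte=18).order("name").limit(5).offset(10)
    # would execute roughly
    #
    #     SELECT "pk", "name", "age", ... FROM "users"
    #         WHERE age >= ? ORDER BY "name" ASC LIMIT ? OFFSET ?
    #
    # with parameters [18, 5, 10]; the exact column list and the table name
    # ("users" here is hypothetical) depend on the model definition.
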
    def bypass_cache(self) -> Self:
        """Bypass the cache for this specific query.

        When called, the query will always hit the database regardless of
        the global cache setting. This is useful for queries that require
        fresh data.

        Returns:
            The QueryBuilder instance for method chaining.

        Example:
            >>> db.select(User).filter(name="Alice").bypass_cache().fetch_one()
        """
        self._bypass_cache = True
        return self

    def cache_ttl(self, ttl: int) -> Self:
        """Set a custom TTL (time-to-live) for this specific query.

        When called, the cached result of this query will expire after the
        specified number of seconds, overriding the global cache_ttl setting.

        Args:
            ttl: Time-to-live in seconds for the cached result.

        Returns:
            The QueryBuilder instance for method chaining.

        Raises:
            ValueError: If ttl is negative.

        Example:
            >>> db.select(User).cache_ttl(60).fetch_all()
        """
        if ttl < 0:
            msg = "TTL must be non-negative"
            raise ValueError(msg)
        self._query_cache_ttl = ttl
        return self

    def _make_cache_key(self, *, fetch_one: bool) -> str:
        """Generate a cache key from the current query state.

        Args:
            fetch_one: Whether this is a fetch_one or fetch_all query.

        Returns:
            A SHA256 hash representing the current query state.

        Raises:
            ValueError: If filters contain incomparable types that prevent
                cache key generation (e.g., filtering the same field with
                both string and numeric values).
        """
        # Sort filters for consistent cache keys
        # Note: This requires filter values to be comparable. Avoid filtering
        # the same field with incompatible types (e.g., name="Alice" and
        # name=42 in the same query).
        try:
            sorted_filters = sorted(self.filters)
        except TypeError as exc:
            msg = (
                "Cannot generate cache key: filters contain incomparable "
                "types. Avoid filtering the same field with incompatible "
                "value types (e.g., strings and numbers)."
            )
            raise ValueError(msg) from exc

        # Create a deterministic representation of the query
        key_parts = {
            "table": self.table_name,
            "filters": sorted_filters,
            "limit": self._limit,
            "offset": self._offset,
            "order_by": self._order_by,
            "fields": tuple(sorted(self._fields)) if self._fields else None,
            "fetch_one": fetch_one,
        }

        # Hash the key parts
        key_json = json.dumps(key_parts, sort_keys=True, default=str)
        return hashlib.sha256(key_json.encode()).hexdigest()

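    # Cache-control sketch (editorial illustration; `User` and `db` are
    # hypothetical, as above). Per-query settings override the global cache
    # behaviour configured on the SqliterDB instance:
    #
    #     fresh = db.select(User).bypass_cache().fetch_all()   # always hits DB
    #     brief = db.select(User).cache_ttl(30).fetch_all()    # expires in 30s
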
    @overload
    def _fetch_result(self, *, fetch_one: Literal[True]) -> Optional[T]: ...

    @overload
    def _fetch_result(self, *, fetch_one: Literal[False]) -> list[T]: ...

    def _fetch_result(
        self, *, fetch_one: bool = False
    ) -> Union[list[T], Optional[T]]:
        """Fetch and convert query results to model instances.

        Args:
            fetch_one: If True, fetch only one result.

        Returns:
            A list of model instances, a single model instance, or None if no
            results are found.
        """
        # Check cache first (unless bypass is enabled)
        if not self._bypass_cache:
            cache_key = self._make_cache_key(fetch_one=fetch_one)
            hit, cached = self.db._cache_get(self.table_name, cache_key)  # noqa: SLF001
            if hit:
                # Cache stores correctly typed data, cast from Any
                return cast("Union[list[T], Optional[T]]", cached)

        result = self._execute_query(fetch_one=fetch_one)

        if not result:
            if not self._bypass_cache:
                if fetch_one:
                    # Cache empty result
                    self.db._cache_set(  # noqa: SLF001
                        self.table_name,
                        cache_key,
                        None,
                        ttl=self._query_cache_ttl,
                    )
                    return None
                # Cache empty list
                self.db._cache_set(  # noqa: SLF001
                    self.table_name, cache_key, [], ttl=self._query_cache_ttl
                )
                return []
            return None if fetch_one else []

        if fetch_one:
            # Ensure we pass a tuple, not a list, to _convert_row_to_model
            if isinstance(result, list):
                result = result[
                    0
                ]  # Get the first (and only) result if it's wrapped in a list.
            single_result = self._convert_row_to_model(result)
            # Cache single result (unless bypass is enabled)
            if not self._bypass_cache:
                self.db._cache_set(  # noqa: SLF001
                    self.table_name,
                    cache_key,
                    single_result,
                    ttl=self._query_cache_ttl,
                )
            return single_result

        list_results = [self._convert_row_to_model(row) for row in result]
        # Cache list result (unless bypass is enabled)
        if not self._bypass_cache:
            self.db._cache_set(  # noqa: SLF001
                self.table_name,
                cache_key,
                list_results,
                ttl=self._query_cache_ttl,
            )
        return list_results

    def fetch_all(self) -> list[T]:
        """Fetch all results of the query.

        Returns:
            A list of model instances representing all query results.
        """
        return self._fetch_result(fetch_one=False)

    def fetch_one(self) -> Optional[T]:
        """Fetch a single result of the query.

        Returns:
            A single model instance or None if no result is found.
        """
        return self._fetch_result(fetch_one=True)

    def fetch_first(self) -> Optional[T]:
        """Fetch the first result of the query.

        Returns:
            The first model instance or None if no result is found.
        """
        self._limit = 1
        return self._fetch_result(fetch_one=True)

    def fetch_last(self) -> Optional[T]:
        """Fetch the last result of the query.

        Returns:
            The last model instance or None if no result is found.
        """
        self._limit = 1
        self._order_by = "rowid DESC"
        return self._fetch_result(fetch_one=True)

    def count(self) -> int:
        """Count the number of results for the current query.

        Returns:
            The number of results that match the current query conditions.
        """
        result = self._execute_query(count_only=True)

        return int(result[0][0]) if result else 0

    def exists(self) -> bool:
        """Check if any results exist for the current query.

        Returns:
            True if at least one result exists, False otherwise.
        """
        return self.count() > 0

    def delete(self) -> int:
        """Delete records that match the current query conditions.

        Returns:
            The number of records deleted.

        Raises:
            RecordDeletionError: If there's an error deleting the records.
        """
        sql = f'DELETE FROM "{self.table_name}"'  # noqa: S608 # nosec

        # Build the WHERE clause with special handling for None (NULL in SQL)
        values, where_clause = self._parse_filter()

        if self.filters:
            sql += f" WHERE {where_clause}"

        # Print the raw SQL and values if debug is enabled
        if self.db.debug:
            self.db._log_sql(sql, values)  # noqa: SLF001

        try:
            with self.db.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(sql, values)
                deleted_count = cursor.rowcount
                self.db._maybe_commit()  # noqa: SLF001
                self.db._cache_invalidate_table(self.table_name)  # noqa: SLF001
                return deleted_count
        except sqlite3.Error as exc:
            raise RecordDeletionError(self.table_name) from exc
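    # End-to-end sketch of the retrieval and deletion API (editorial
    # illustration; `User` and `db` are hypothetical, as above):
    #
    #     pending = db.select(User).filter(status="pending")
    #     if pending.exists():
    #         total = pending.count()
    #         first = pending.fetch_first()
    #         removed = pending.delete()  # returns the number of rows deleted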