sqliter-py 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqliter/constants.py +20 -0
- sqliter/exceptions.py +6 -0
- sqliter/model/model.py +74 -8
- sqliter/query/query.py +362 -84
- sqliter/sqliter.py +71 -19
- sqliter_py-0.3.0.dist-info/METADATA +601 -0
- sqliter_py-0.3.0.dist-info/RECORD +12 -0
- sqliter_py-0.3.0.dist-info/licenses/LICENSE.txt +20 -0
- sqliter_py-0.1.1.dist-info/METADATA +0 -204
- sqliter_py-0.1.1.dist-info/RECORD +0 -10
- {sqliter_py-0.1.1.dist-info → sqliter_py-0.3.0.dist-info}/WHEEL +0 -0
sqliter/constants.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Define constants used in the library."""
|
|
2
|
+
|
|
3
|
+
OPERATOR_MAPPING = {
|
|
4
|
+
"__lt": "<",
|
|
5
|
+
"__lte": "<=",
|
|
6
|
+
"__gt": ">",
|
|
7
|
+
"__gte": ">=",
|
|
8
|
+
"__eq": "=",
|
|
9
|
+
"__ne": "!=",
|
|
10
|
+
"__in": "IN",
|
|
11
|
+
"__not_in": "NOT IN",
|
|
12
|
+
"__isnull": "IS NULL",
|
|
13
|
+
"__notnull": "IS NOT NULL",
|
|
14
|
+
"__startswith": "LIKE",
|
|
15
|
+
"__endswith": "LIKE",
|
|
16
|
+
"__contains": "LIKE",
|
|
17
|
+
"__istartswith": "LIKE",
|
|
18
|
+
"__iendswith": "LIKE",
|
|
19
|
+
"__icontains": "LIKE",
|
|
20
|
+
}
|
sqliter/exceptions.py
CHANGED
|
@@ -72,6 +72,12 @@ class InvalidOffsetError(SqliterError):
|
|
|
72
72
|
)
|
|
73
73
|
|
|
74
74
|
|
|
75
|
+
class InvalidOrderError(SqliterError):
|
|
76
|
+
"""Raised when an invalid order value is used."""
|
|
77
|
+
|
|
78
|
+
message_template = "Invalid order value - {}"
|
|
79
|
+
|
|
80
|
+
|
|
75
81
|
class TableCreationError(SqliterError):
|
|
76
82
|
"""Raised when a table cannot be created in the database."""
|
|
77
83
|
|
sqliter/model/model.py
CHANGED
|
@@ -2,30 +2,96 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
import re
|
|
6
|
+
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
|
6
7
|
|
|
7
|
-
from pydantic import BaseModel
|
|
8
|
+
from pydantic import BaseModel, ConfigDict
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T", bound="BaseDBModel")
|
|
8
11
|
|
|
9
12
|
|
|
10
13
|
class BaseDBModel(BaseModel):
|
|
11
14
|
"""Custom base model for database models."""
|
|
12
15
|
|
|
16
|
+
model_config = ConfigDict(
|
|
17
|
+
extra="ignore",
|
|
18
|
+
populate_by_name=True,
|
|
19
|
+
validate_assignment=False,
|
|
20
|
+
from_attributes=True,
|
|
21
|
+
)
|
|
22
|
+
|
|
13
23
|
class Meta:
|
|
14
24
|
"""Configure the base model with default options."""
|
|
15
25
|
|
|
16
|
-
|
|
17
|
-
|
|
26
|
+
create_pk: bool = (
|
|
27
|
+
True # Whether to create an auto-increment primary key
|
|
28
|
+
)
|
|
29
|
+
primary_key: str = "id" # Default primary key name
|
|
18
30
|
table_name: Optional[str] = (
|
|
19
31
|
None # Table name, defaults to class name if not set
|
|
20
32
|
)
|
|
21
33
|
|
|
34
|
+
@classmethod
|
|
35
|
+
def model_validate_partial(cls: type[T], obj: dict[str, Any]) -> T:
|
|
36
|
+
"""Validate a partial model object."""
|
|
37
|
+
converted_obj: dict[str, Any] = {}
|
|
38
|
+
for field_name, value in obj.items():
|
|
39
|
+
field = cls.model_fields[field_name]
|
|
40
|
+
field_type: Optional[type] = field.annotation
|
|
41
|
+
if (
|
|
42
|
+
field_type is None or value is None
|
|
43
|
+
): # Direct check for None values here
|
|
44
|
+
converted_obj[field_name] = None
|
|
45
|
+
else:
|
|
46
|
+
origin = get_origin(field_type)
|
|
47
|
+
if origin is Union:
|
|
48
|
+
args = get_args(field_type)
|
|
49
|
+
for arg in args:
|
|
50
|
+
try:
|
|
51
|
+
# Try converting the value to the type
|
|
52
|
+
converted_obj[field_name] = arg(value)
|
|
53
|
+
break
|
|
54
|
+
except (ValueError, TypeError):
|
|
55
|
+
pass
|
|
56
|
+
else:
|
|
57
|
+
converted_obj[field_name] = value
|
|
58
|
+
else:
|
|
59
|
+
converted_obj[field_name] = field_type(value)
|
|
60
|
+
|
|
61
|
+
return cls.model_construct(**converted_obj)
|
|
62
|
+
|
|
22
63
|
@classmethod
|
|
23
64
|
def get_table_name(cls) -> str:
|
|
24
|
-
"""Get the table name from the Meta, or
|
|
65
|
+
"""Get the table name from the Meta, or generate one.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
str: The table name, either specified in the Meta class or
|
|
69
|
+
generated by converting the class name to pluralized snake_case
|
|
70
|
+
and removing any 'Model' suffix.
|
|
71
|
+
"""
|
|
25
72
|
table_name: str | None = getattr(cls.Meta, "table_name", None)
|
|
26
73
|
if table_name is not None:
|
|
27
74
|
return table_name
|
|
28
|
-
|
|
75
|
+
|
|
76
|
+
# Get class name and remove 'Model' suffix if present
|
|
77
|
+
class_name = cls.__name__.removesuffix("Model")
|
|
78
|
+
|
|
79
|
+
# Convert to snake_case
|
|
80
|
+
snake_case_name = re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()
|
|
81
|
+
|
|
82
|
+
# Pluralize the table name
|
|
83
|
+
try:
|
|
84
|
+
import inflect
|
|
85
|
+
|
|
86
|
+
p = inflect.engine()
|
|
87
|
+
return p.plural(snake_case_name)
|
|
88
|
+
except ImportError:
|
|
89
|
+
# Fallback to simple pluralization by adding 's'
|
|
90
|
+
return (
|
|
91
|
+
snake_case_name
|
|
92
|
+
if snake_case_name.endswith("s")
|
|
93
|
+
else snake_case_name + "s"
|
|
94
|
+
)
|
|
29
95
|
|
|
30
96
|
@classmethod
|
|
31
97
|
def get_primary_key(cls) -> str:
|
|
@@ -33,6 +99,6 @@ class BaseDBModel(BaseModel):
|
|
|
33
99
|
return getattr(cls.Meta, "primary_key", "id")
|
|
34
100
|
|
|
35
101
|
@classmethod
|
|
36
|
-
def
|
|
102
|
+
def should_create_pk(cls) -> bool:
|
|
37
103
|
"""Check whether the model should create an auto-increment ID."""
|
|
38
|
-
return getattr(cls.Meta, "
|
|
104
|
+
return getattr(cls.Meta, "create_pk", True)
|
sqliter/query/query.py
CHANGED
|
@@ -3,45 +3,256 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import sqlite3
|
|
6
|
-
|
|
6
|
+
import warnings
|
|
7
|
+
from typing import (
|
|
8
|
+
TYPE_CHECKING,
|
|
9
|
+
Any,
|
|
10
|
+
Callable,
|
|
11
|
+
Literal,
|
|
12
|
+
Optional,
|
|
13
|
+
Union,
|
|
14
|
+
overload,
|
|
15
|
+
)
|
|
7
16
|
|
|
8
|
-
from typing_extensions import Self
|
|
17
|
+
from typing_extensions import LiteralString, Self
|
|
9
18
|
|
|
19
|
+
from sqliter.constants import OPERATOR_MAPPING
|
|
10
20
|
from sqliter.exceptions import (
|
|
11
21
|
InvalidFilterError,
|
|
12
22
|
InvalidOffsetError,
|
|
23
|
+
InvalidOrderError,
|
|
13
24
|
RecordFetchError,
|
|
14
25
|
)
|
|
15
26
|
|
|
16
27
|
if TYPE_CHECKING: # pragma: no cover
|
|
28
|
+
from pydantic.fields import FieldInfo
|
|
29
|
+
|
|
17
30
|
from sqliter import SqliterDB
|
|
18
31
|
from sqliter.model import BaseDBModel
|
|
19
32
|
|
|
33
|
+
# Define a type alias for the possible value types
|
|
34
|
+
FilterValue = Union[
|
|
35
|
+
str, int, float, bool, None, list[Union[str, int, float, bool]]
|
|
36
|
+
]
|
|
37
|
+
|
|
20
38
|
|
|
21
39
|
class QueryBuilder:
|
|
22
40
|
"""Functions to build and execute queries for a given model."""
|
|
23
41
|
|
|
24
|
-
def __init__(
|
|
25
|
-
|
|
42
|
+
def __init__(
|
|
43
|
+
self,
|
|
44
|
+
db: SqliterDB,
|
|
45
|
+
model_class: type[BaseDBModel],
|
|
46
|
+
fields: Optional[list[str]] = None,
|
|
47
|
+
) -> None:
|
|
48
|
+
"""Initialize the query builder.
|
|
49
|
+
|
|
50
|
+
Pass the database, model class, and optional fields.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
db: The SqliterDB instance.
|
|
54
|
+
model_class: The model class to query.
|
|
55
|
+
fields: Optional list of field names to select. If None, all fields
|
|
56
|
+
are selected.
|
|
57
|
+
"""
|
|
26
58
|
self.db = db
|
|
27
59
|
self.model_class = model_class
|
|
28
60
|
self.table_name = model_class.get_table_name() # Use model_class method
|
|
29
|
-
self.filters: list[tuple[str, Any]] = []
|
|
61
|
+
self.filters: list[tuple[str, Any, str]] = []
|
|
30
62
|
self._limit: Optional[int] = None
|
|
31
63
|
self._offset: Optional[int] = None
|
|
32
64
|
self._order_by: Optional[str] = None
|
|
65
|
+
self._fields: Optional[list[str]] = fields
|
|
66
|
+
|
|
67
|
+
if self._fields:
|
|
68
|
+
self._validate_fields()
|
|
69
|
+
|
|
70
|
+
def _validate_fields(self) -> None:
|
|
71
|
+
"""Validate that the specified fields exist in the model."""
|
|
72
|
+
if self._fields is None:
|
|
73
|
+
return
|
|
74
|
+
valid_fields = set(self.model_class.model_fields.keys())
|
|
75
|
+
invalid_fields = set(self._fields) - valid_fields
|
|
76
|
+
if invalid_fields:
|
|
77
|
+
err_message = (
|
|
78
|
+
f"Invalid fields specified: {', '.join(invalid_fields)}"
|
|
79
|
+
)
|
|
80
|
+
raise ValueError(err_message)
|
|
33
81
|
|
|
34
|
-
def filter(self, **conditions: str | float | None) ->
|
|
82
|
+
def filter(self, **conditions: str | float | None) -> QueryBuilder:
|
|
35
83
|
"""Add filter conditions to the query."""
|
|
36
84
|
valid_fields = self.model_class.model_fields
|
|
37
85
|
|
|
38
86
|
for field, value in conditions.items():
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
87
|
+
field_name, operator = self._parse_field_operator(field)
|
|
88
|
+
self._validate_field(field_name, valid_fields)
|
|
89
|
+
|
|
90
|
+
handler = self._get_operator_handler(operator)
|
|
91
|
+
handler(field_name, value, operator)
|
|
92
|
+
|
|
93
|
+
return self
|
|
94
|
+
|
|
95
|
+
def fields(self, fields: Optional[list[str]] = None) -> QueryBuilder:
|
|
96
|
+
"""Select specific fields to return in the query."""
|
|
97
|
+
if fields:
|
|
98
|
+
self._fields = fields
|
|
99
|
+
self._validate_fields()
|
|
100
|
+
return self
|
|
101
|
+
|
|
102
|
+
def exclude(self, fields: Optional[list[str]] = None) -> QueryBuilder:
|
|
103
|
+
"""Exclude specific fields from the query output."""
|
|
104
|
+
if fields:
|
|
105
|
+
all_fields = set(self.model_class.model_fields.keys())
|
|
106
|
+
|
|
107
|
+
# Check for invalid fields before subtraction
|
|
108
|
+
invalid_fields = set(fields) - all_fields
|
|
109
|
+
if invalid_fields:
|
|
110
|
+
err = (
|
|
111
|
+
"Invalid fields specified for exclusion: "
|
|
112
|
+
f"{', '.join(invalid_fields)}"
|
|
113
|
+
)
|
|
114
|
+
raise ValueError(err)
|
|
115
|
+
|
|
116
|
+
# Subtract the fields specified for exclusion
|
|
117
|
+
self._fields = list(all_fields - set(fields))
|
|
118
|
+
|
|
119
|
+
# Explicit check: raise an error if no fields remain
|
|
120
|
+
if not self._fields:
|
|
121
|
+
err = "Exclusion results in no fields being selected."
|
|
122
|
+
raise ValueError(err)
|
|
123
|
+
|
|
124
|
+
# Now validate the remaining fields to ensure they are all valid
|
|
125
|
+
self._validate_fields()
|
|
126
|
+
|
|
127
|
+
return self
|
|
128
|
+
|
|
129
|
+
def only(self, field: str) -> QueryBuilder:
|
|
130
|
+
"""Return only the specified single field."""
|
|
131
|
+
all_fields = set(self.model_class.model_fields.keys())
|
|
132
|
+
|
|
133
|
+
# Validate that the field exists
|
|
134
|
+
if field not in all_fields:
|
|
135
|
+
err = f"Invalid field specified: {field}"
|
|
136
|
+
raise ValueError(err)
|
|
42
137
|
|
|
138
|
+
# Set self._fields to just the single field
|
|
139
|
+
self._fields = [field]
|
|
43
140
|
return self
|
|
44
141
|
|
|
142
|
+
def _get_operator_handler(
|
|
143
|
+
self, operator: str
|
|
144
|
+
) -> Callable[[str, Any, str], None]:
|
|
145
|
+
handlers = {
|
|
146
|
+
"__isnull": self._handle_null,
|
|
147
|
+
"__notnull": self._handle_null,
|
|
148
|
+
"__in": self._handle_in,
|
|
149
|
+
"__not_in": self._handle_in,
|
|
150
|
+
"__startswith": self._handle_like,
|
|
151
|
+
"__endswith": self._handle_like,
|
|
152
|
+
"__contains": self._handle_like,
|
|
153
|
+
"__istartswith": self._handle_like,
|
|
154
|
+
"__iendswith": self._handle_like,
|
|
155
|
+
"__icontains": self._handle_like,
|
|
156
|
+
"__lt": self._handle_comparison,
|
|
157
|
+
"__lte": self._handle_comparison,
|
|
158
|
+
"__gt": self._handle_comparison,
|
|
159
|
+
"__gte": self._handle_comparison,
|
|
160
|
+
"__ne": self._handle_comparison,
|
|
161
|
+
}
|
|
162
|
+
return handlers.get(operator, self._handle_equality)
|
|
163
|
+
|
|
164
|
+
def _validate_field(
|
|
165
|
+
self, field_name: str, valid_fields: dict[str, FieldInfo]
|
|
166
|
+
) -> None:
|
|
167
|
+
if field_name not in valid_fields:
|
|
168
|
+
raise InvalidFilterError(field_name)
|
|
169
|
+
|
|
170
|
+
def _handle_equality(
|
|
171
|
+
self, field_name: str, value: FilterValue, operator: str
|
|
172
|
+
) -> None:
|
|
173
|
+
if value is None:
|
|
174
|
+
self.filters.append((f"{field_name} IS NULL", None, "__isnull"))
|
|
175
|
+
else:
|
|
176
|
+
self.filters.append((field_name, value, operator))
|
|
177
|
+
|
|
178
|
+
def _handle_null(
|
|
179
|
+
self, field_name: str, _: FilterValue, operator: str
|
|
180
|
+
) -> None:
|
|
181
|
+
condition = (
|
|
182
|
+
f"{field_name} IS NOT NULL"
|
|
183
|
+
if operator == "__notnull"
|
|
184
|
+
else f"{field_name} IS NULL"
|
|
185
|
+
)
|
|
186
|
+
self.filters.append((condition, None, operator))
|
|
187
|
+
|
|
188
|
+
def _handle_in(
|
|
189
|
+
self, field_name: str, value: FilterValue, operator: str
|
|
190
|
+
) -> None:
|
|
191
|
+
if not isinstance(value, list):
|
|
192
|
+
err = f"{field_name} requires a list for '{operator}'"
|
|
193
|
+
raise TypeError(err)
|
|
194
|
+
sql_operator = OPERATOR_MAPPING.get(operator, "IN")
|
|
195
|
+
placeholder_list = ", ".join(["?"] * len(value))
|
|
196
|
+
self.filters.append(
|
|
197
|
+
(
|
|
198
|
+
f"{field_name} {sql_operator} ({placeholder_list})",
|
|
199
|
+
value,
|
|
200
|
+
operator,
|
|
201
|
+
)
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
def _handle_like(
|
|
205
|
+
self, field_name: str, value: FilterValue, operator: str
|
|
206
|
+
) -> None:
|
|
207
|
+
if not isinstance(value, str):
|
|
208
|
+
err = f"{field_name} requires a string value for '{operator}'"
|
|
209
|
+
raise TypeError(err)
|
|
210
|
+
formatted_value = self._format_string_for_operator(operator, value)
|
|
211
|
+
if operator in ["__startswith", "__endswith", "__contains"]:
|
|
212
|
+
self.filters.append(
|
|
213
|
+
(
|
|
214
|
+
f"{field_name} GLOB ?",
|
|
215
|
+
[formatted_value],
|
|
216
|
+
operator,
|
|
217
|
+
)
|
|
218
|
+
)
|
|
219
|
+
elif operator in ["__istartswith", "__iendswith", "__icontains"]:
|
|
220
|
+
self.filters.append(
|
|
221
|
+
(
|
|
222
|
+
f"{field_name} LIKE ?",
|
|
223
|
+
[formatted_value],
|
|
224
|
+
operator,
|
|
225
|
+
)
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
def _handle_comparison(
|
|
229
|
+
self, field_name: str, value: FilterValue, operator: str
|
|
230
|
+
) -> None:
|
|
231
|
+
sql_operator = OPERATOR_MAPPING[operator]
|
|
232
|
+
self.filters.append((f"{field_name} {sql_operator} ?", value, operator))
|
|
233
|
+
|
|
234
|
+
# Helper method for parsing field and operator
|
|
235
|
+
def _parse_field_operator(self, field: str) -> tuple[str, str]:
|
|
236
|
+
for operator in OPERATOR_MAPPING:
|
|
237
|
+
if field.endswith(operator):
|
|
238
|
+
return field[: -len(operator)], operator
|
|
239
|
+
return field, "__eq" # Default to equality if no operator is found
|
|
240
|
+
|
|
241
|
+
# Helper method for formatting string operators (like startswith)
|
|
242
|
+
def _format_string_for_operator(self, operator: str, value: str) -> str:
|
|
243
|
+
# Mapping operators to their corresponding string format
|
|
244
|
+
format_map = {
|
|
245
|
+
"__startswith": f"{value}*",
|
|
246
|
+
"__endswith": f"*{value}",
|
|
247
|
+
"__contains": f"*{value}*",
|
|
248
|
+
"__istartswith": f"{value.lower()}%",
|
|
249
|
+
"__iendswith": f"%{value.lower()}",
|
|
250
|
+
"__icontains": f"%{value.lower()}%",
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
# Return the formatted string or the original value if no match
|
|
254
|
+
return format_map.get(operator, value)
|
|
255
|
+
|
|
45
256
|
def limit(self, limit_value: int) -> Self:
|
|
46
257
|
"""Limit the number of results returned by the query."""
|
|
47
258
|
self._limit = limit_value
|
|
@@ -49,7 +260,7 @@ class QueryBuilder:
|
|
|
49
260
|
|
|
50
261
|
def offset(self, offset_value: int) -> Self:
|
|
51
262
|
"""Set an offset value for the query."""
|
|
52
|
-
if offset_value
|
|
263
|
+
if offset_value < 0:
|
|
53
264
|
raise InvalidOffsetError(offset_value)
|
|
54
265
|
self._offset = offset_value
|
|
55
266
|
|
|
@@ -57,28 +268,83 @@ class QueryBuilder:
|
|
|
57
268
|
self._limit = -1
|
|
58
269
|
return self
|
|
59
270
|
|
|
60
|
-
def order(
|
|
61
|
-
|
|
62
|
-
|
|
271
|
+
def order(
|
|
272
|
+
self,
|
|
273
|
+
order_by_field: str,
|
|
274
|
+
direction: Optional[str] = None,
|
|
275
|
+
*,
|
|
276
|
+
reverse: bool = False,
|
|
277
|
+
) -> Self:
|
|
278
|
+
"""Order the query results by the specified field.
|
|
279
|
+
|
|
280
|
+
Args:
|
|
281
|
+
order_by_field (str): The field to order by.
|
|
282
|
+
direction (Optional[str]): The ordering direction ('ASC' or 'DESC').
|
|
283
|
+
This is deprecated in favor of 'reverse'.
|
|
284
|
+
reverse (bool): Whether to reverse the order (True for descending,
|
|
285
|
+
False for ascending).
|
|
286
|
+
|
|
287
|
+
Raises:
|
|
288
|
+
InvalidOrderError: If the field doesn't exist in the model fields
|
|
289
|
+
or if both 'direction' and 'reverse' are specified.
|
|
290
|
+
|
|
291
|
+
Returns:
|
|
292
|
+
QueryBuilder: The current query builder instance with updated
|
|
293
|
+
ordering.
|
|
294
|
+
"""
|
|
295
|
+
if direction:
|
|
296
|
+
warnings.warn(
|
|
297
|
+
"'direction' argument is deprecated and will be removed in a "
|
|
298
|
+
"future version. Use 'reverse' instead.",
|
|
299
|
+
DeprecationWarning,
|
|
300
|
+
stacklevel=2,
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
if order_by_field not in self.model_class.model_fields:
|
|
304
|
+
err = f"'{order_by_field}' does not exist in the model fields."
|
|
305
|
+
raise InvalidOrderError(err)
|
|
306
|
+
# Raise an exception if both 'direction' and 'reverse' are specified
|
|
307
|
+
if direction and reverse:
|
|
308
|
+
err = (
|
|
309
|
+
"Cannot specify both 'direction' and 'reverse' as it "
|
|
310
|
+
"is ambiguous."
|
|
311
|
+
)
|
|
312
|
+
raise InvalidOrderError(err)
|
|
313
|
+
|
|
314
|
+
# Determine the sorting direction
|
|
315
|
+
if reverse:
|
|
316
|
+
sort_order = "DESC"
|
|
317
|
+
elif direction:
|
|
318
|
+
sort_order = direction.upper()
|
|
319
|
+
if sort_order not in {"ASC", "DESC"}:
|
|
320
|
+
err = f"'{direction}' is not a valid sorting direction."
|
|
321
|
+
raise InvalidOrderError(err)
|
|
322
|
+
else:
|
|
323
|
+
sort_order = "ASC"
|
|
324
|
+
|
|
325
|
+
self._order_by = f'"{order_by_field}" {sort_order}'
|
|
63
326
|
return self
|
|
64
327
|
|
|
65
328
|
def _execute_query(
|
|
66
329
|
self,
|
|
67
330
|
*,
|
|
68
331
|
fetch_one: bool = False,
|
|
332
|
+
count_only: bool = False,
|
|
69
333
|
) -> list[tuple[Any, ...]] | Optional[tuple[Any, ...]]:
|
|
70
334
|
"""Helper function to execute the query with filters."""
|
|
71
|
-
|
|
335
|
+
if count_only:
|
|
336
|
+
fields = "COUNT(*)"
|
|
337
|
+
elif self._fields:
|
|
338
|
+
fields = ", ".join(f'"{field}"' for field in self._fields)
|
|
339
|
+
else:
|
|
340
|
+
fields = ", ".join(
|
|
341
|
+
f'"{field}"' for field in self.model_class.model_fields
|
|
342
|
+
)
|
|
72
343
|
|
|
73
|
-
|
|
74
|
-
where_clause = " AND ".join(
|
|
75
|
-
[
|
|
76
|
-
f"{field} IS NULL" if value is None else f"{field} = ?"
|
|
77
|
-
for field, value in self.filters
|
|
78
|
-
]
|
|
79
|
-
)
|
|
344
|
+
sql = f'SELECT {fields} FROM "{self.table_name}"' # noqa: S608 # nosec
|
|
80
345
|
|
|
81
|
-
|
|
346
|
+
# Build the WHERE clause with special handling for None (NULL in SQL)
|
|
347
|
+
values, where_clause = self._parse_filter()
|
|
82
348
|
|
|
83
349
|
if self.filters:
|
|
84
350
|
sql += f" WHERE {where_clause}"
|
|
@@ -87,13 +353,12 @@ class QueryBuilder:
|
|
|
87
353
|
sql += f" ORDER BY {self._order_by}"
|
|
88
354
|
|
|
89
355
|
if self._limit is not None:
|
|
90
|
-
sql +=
|
|
356
|
+
sql += " LIMIT ?"
|
|
357
|
+
values.append(self._limit)
|
|
91
358
|
|
|
92
359
|
if self._offset is not None:
|
|
93
|
-
sql +=
|
|
94
|
-
|
|
95
|
-
# Only include non-None values in the values list
|
|
96
|
-
values = [value for _, value in self.filters if value is not None]
|
|
360
|
+
sql += " OFFSET ?"
|
|
361
|
+
values.append(self._offset)
|
|
97
362
|
|
|
98
363
|
try:
|
|
99
364
|
with self.db.connect() as conn:
|
|
@@ -103,80 +368,93 @@ class QueryBuilder:
|
|
|
103
368
|
except sqlite3.Error as exc:
|
|
104
369
|
raise RecordFetchError(self.table_name) from exc
|
|
105
370
|
|
|
106
|
-
def
|
|
107
|
-
"""
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
371
|
+
def _parse_filter(self) -> tuple[list[Any], LiteralString]:
|
|
372
|
+
"""Actually parse the filters."""
|
|
373
|
+
where_clauses = []
|
|
374
|
+
values = []
|
|
375
|
+
for field, value, operator in self.filters:
|
|
376
|
+
if operator == "__eq":
|
|
377
|
+
where_clauses.append(f"{field} = ?")
|
|
378
|
+
values.append(value)
|
|
379
|
+
else:
|
|
380
|
+
where_clauses.append(field)
|
|
381
|
+
if operator not in ["__isnull", "__notnull"]:
|
|
382
|
+
if isinstance(value, list):
|
|
383
|
+
values.extend(value)
|
|
384
|
+
else:
|
|
385
|
+
values.append(value)
|
|
386
|
+
|
|
387
|
+
where_clause = " AND ".join(where_clauses)
|
|
388
|
+
return values, where_clause
|
|
389
|
+
|
|
390
|
+
def _convert_row_to_model(self, row: tuple[Any, ...]) -> BaseDBModel:
|
|
391
|
+
"""Convert a result row tuple into a Pydantic model."""
|
|
392
|
+
if self._fields:
|
|
393
|
+
return self.model_class.model_validate_partial(
|
|
394
|
+
{field: row[idx] for idx, field in enumerate(self._fields)}
|
|
119
395
|
)
|
|
120
|
-
for row in results
|
|
121
|
-
]
|
|
122
|
-
|
|
123
|
-
def fetch_one(self) -> BaseDBModel | None:
|
|
124
|
-
"""Fetch exactly one result."""
|
|
125
|
-
result = self._execute_query(fetch_one=True)
|
|
126
|
-
if not result:
|
|
127
|
-
return None
|
|
128
396
|
return self.model_class(
|
|
129
397
|
**{
|
|
130
|
-
field:
|
|
398
|
+
field: row[idx]
|
|
131
399
|
for idx, field in enumerate(self.model_class.model_fields)
|
|
132
400
|
}
|
|
133
401
|
)
|
|
134
402
|
|
|
135
|
-
|
|
403
|
+
@overload
|
|
404
|
+
def _fetch_result(
|
|
405
|
+
self, *, fetch_one: Literal[True]
|
|
406
|
+
) -> Optional[BaseDBModel]: ...
|
|
407
|
+
|
|
408
|
+
@overload
|
|
409
|
+
def _fetch_result(
|
|
410
|
+
self, *, fetch_one: Literal[False]
|
|
411
|
+
) -> list[BaseDBModel]: ...
|
|
412
|
+
|
|
413
|
+
def _fetch_result(
|
|
414
|
+
self, *, fetch_one: bool = False
|
|
415
|
+
) -> Union[list[BaseDBModel], Optional[BaseDBModel]]:
|
|
416
|
+
"""Fetch one or all results and convert them to Pydantic models."""
|
|
417
|
+
result = self._execute_query(fetch_one=fetch_one)
|
|
418
|
+
|
|
419
|
+
if not result:
|
|
420
|
+
if fetch_one:
|
|
421
|
+
return None
|
|
422
|
+
return []
|
|
423
|
+
|
|
424
|
+
if fetch_one:
|
|
425
|
+
# Ensure we pass a tuple, not a list, to _convert_row_to_model
|
|
426
|
+
if isinstance(result, list):
|
|
427
|
+
result = result[
|
|
428
|
+
0
|
|
429
|
+
] # Get the first (and only) result if it's wrapped in a list.
|
|
430
|
+
return self._convert_row_to_model(result)
|
|
431
|
+
|
|
432
|
+
return [self._convert_row_to_model(row) for row in result]
|
|
433
|
+
|
|
434
|
+
def fetch_all(self) -> list[BaseDBModel]:
|
|
435
|
+
"""Fetch all results matching the filters."""
|
|
436
|
+
return self._fetch_result(fetch_one=False)
|
|
437
|
+
|
|
438
|
+
def fetch_one(self) -> Optional[BaseDBModel]:
|
|
439
|
+
"""Fetch exactly one result."""
|
|
440
|
+
return self._fetch_result(fetch_one=True)
|
|
441
|
+
|
|
442
|
+
def fetch_first(self) -> Optional[BaseDBModel]:
|
|
136
443
|
"""Fetch the first result of the query."""
|
|
137
444
|
self._limit = 1
|
|
138
|
-
|
|
139
|
-
if not result:
|
|
140
|
-
return None
|
|
141
|
-
return self.model_class(
|
|
142
|
-
**{
|
|
143
|
-
field: result[0][idx]
|
|
144
|
-
for idx, field in enumerate(self.model_class.model_fields)
|
|
145
|
-
}
|
|
146
|
-
)
|
|
445
|
+
return self._fetch_result(fetch_one=True)
|
|
147
446
|
|
|
148
|
-
def fetch_last(self) -> BaseDBModel
|
|
447
|
+
def fetch_last(self) -> Optional[BaseDBModel]:
|
|
149
448
|
"""Fetch the last result of the query (based on the insertion order)."""
|
|
150
449
|
self._limit = 1
|
|
151
450
|
self._order_by = "rowid DESC"
|
|
152
|
-
|
|
153
|
-
if not result:
|
|
154
|
-
return None
|
|
155
|
-
return self.model_class(
|
|
156
|
-
**{
|
|
157
|
-
field: result[0][idx]
|
|
158
|
-
for idx, field in enumerate(self.model_class.model_fields)
|
|
159
|
-
}
|
|
160
|
-
)
|
|
451
|
+
return self._fetch_result(fetch_one=True)
|
|
161
452
|
|
|
162
453
|
def count(self) -> int:
|
|
163
454
|
"""Return the count of records matching the filters."""
|
|
164
|
-
|
|
165
|
-
[f"{field} = ?" for field, _ in self.filters]
|
|
166
|
-
)
|
|
167
|
-
sql = f"SELECT COUNT(*) FROM {self.table_name}" # noqa: S608
|
|
168
|
-
|
|
169
|
-
if self.filters:
|
|
170
|
-
sql += f" WHERE {where_clause}"
|
|
171
|
-
|
|
172
|
-
values = [value for _, value in self.filters]
|
|
173
|
-
|
|
174
|
-
with self.db.connect() as conn:
|
|
175
|
-
cursor = conn.cursor()
|
|
176
|
-
cursor.execute(sql, values)
|
|
177
|
-
result = cursor.fetchone()
|
|
455
|
+
result = self._execute_query(count_only=True)
|
|
178
456
|
|
|
179
|
-
return int(result[0]) if result else 0
|
|
457
|
+
return int(result[0][0]) if result else 0
|
|
180
458
|
|
|
181
459
|
def exists(self) -> bool:
|
|
182
460
|
"""Return True if any record matches the filters."""
|