sqliter-py 0.3.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqliter/query/query.py CHANGED
@@ -1,4 +1,11 @@
1
- """Define the 'QueryBuilder' class for building SQL queries."""
1
+ """Implements the query building and execution logic for SQLiter.
2
+
3
+ This module defines the QueryBuilder class, which provides a fluent
4
+ interface for constructing SQL queries. It supports operations such
5
+ as filtering, ordering, limiting, and various data retrieval methods,
6
+ allowing for flexible and expressive database queries without writing
7
+ raw SQL.
8
+ """
2
9
 
3
10
  from __future__ import annotations
4
11
 
@@ -21,6 +28,7 @@ from sqliter.exceptions import (
21
28
  InvalidFilterError,
22
29
  InvalidOffsetError,
23
30
  InvalidOrderError,
31
+ RecordDeletionError,
24
32
  RecordFetchError,
25
33
  )
26
34
 
@@ -28,7 +36,7 @@ if TYPE_CHECKING: # pragma: no cover
28
36
  from pydantic.fields import FieldInfo
29
37
 
30
38
  from sqliter import SqliterDB
31
- from sqliter.model import BaseDBModel
39
+ from sqliter.model import BaseDBModel, SerializableField
32
40
 
33
41
  # Define a type alias for the possible value types
34
42
  FilterValue = Union[
@@ -37,7 +45,21 @@ FilterValue = Union[
37
45
 
38
46
 
39
47
  class QueryBuilder:
40
- """Functions to build and execute queries for a given model."""
48
+ """Builds and executes database queries for a specific model.
49
+
50
+ This class provides methods to construct SQL queries, apply filters,
51
+ set ordering, and execute the queries against the database.
52
+
53
+ Attributes:
54
+ db (SqliterDB): The database connection object.
55
+ model_class (type[BaseDBModel]): The Pydantic model class.
56
+ table_name (str): The name of the database table.
57
+ filters (list): List of applied filter conditions.
58
+ _limit (Optional[int]): The LIMIT clause value, if any.
59
+ _offset (Optional[int]): The OFFSET clause value, if any.
60
+ _order_by (Optional[str]): The ORDER BY clause, if any.
61
+ _fields (Optional[list[str]]): List of fields to select, if specified.
62
+ """
41
63
 
42
64
  def __init__(
43
65
  self,
@@ -45,13 +67,11 @@ class QueryBuilder:
45
67
  model_class: type[BaseDBModel],
46
68
  fields: Optional[list[str]] = None,
47
69
  ) -> None:
48
- """Initialize the query builder.
49
-
50
- Pass the database, model class, and optional fields.
70
+ """Initialize a new QueryBuilder instance.
51
71
 
52
72
  Args:
53
- db: The SqliterDB instance.
54
- model_class: The model class to query.
73
+ db: The database connection object.
74
+ model_class: The Pydantic model class for the table.
55
75
  fields: Optional list of field names to select. If None, all fields
56
76
  are selected.
57
77
  """
@@ -68,7 +88,11 @@ class QueryBuilder:
68
88
  self._validate_fields()
69
89
 
70
90
  def _validate_fields(self) -> None:
71
- """Validate that the specified fields exist in the model."""
91
+ """Validate that the specified fields exist in the model.
92
+
93
+ Raises:
94
+ ValueError: If any specified field is not in the model.
95
+ """
72
96
  if self._fields is None:
73
97
  return
74
98
  valid_fields = set(self.model_class.model_fields.keys())
@@ -80,28 +104,75 @@ class QueryBuilder:
80
104
  raise ValueError(err_message)
81
105
 
82
106
  def filter(self, **conditions: str | float | None) -> QueryBuilder:
83
- """Add filter conditions to the query."""
107
+ """Apply filter conditions to the query.
108
+
109
+ This method allows adding one or more filter conditions to the query.
110
+ Each condition is specified as a keyword argument, where the key is
111
+ the field name and the value is the condition to apply.
112
+
113
+ Args:
114
+ **conditions: Arbitrary keyword arguments representing filter
115
+ conditions. The key is the field name, and the value is the
116
+ condition to apply. Supported operators include equality,
117
+ comparison, and special operators like __in, __isnull, etc.
118
+
119
+ Returns:
120
+ QueryBuilder: The current QueryBuilder instance for method
121
+ chaining.
122
+
123
+ Examples:
124
+ >>> query.filter(name="John", age__gt=30)
125
+ >>> query.filter(status__in=["active", "pending"])
126
+ """
84
127
  valid_fields = self.model_class.model_fields
85
128
 
86
129
  for field, value in conditions.items():
87
130
  field_name, operator = self._parse_field_operator(field)
88
131
  self._validate_field(field_name, valid_fields)
89
132
 
90
- handler = self._get_operator_handler(operator)
91
- handler(field_name, value, operator)
133
+ if operator in ["__isnull", "__notnull"]:
134
+ self._handle_null(field_name, value, operator)
135
+ else:
136
+ handler = self._get_operator_handler(operator)
137
+ handler(field_name, value, operator)
92
138
 
93
139
  return self
94
140
 
95
141
  def fields(self, fields: Optional[list[str]] = None) -> QueryBuilder:
96
- """Select specific fields to return in the query."""
142
+ """Specify which fields to select in the query.
143
+
144
+ Args:
145
+ fields: List of field names to select. If None, all fields are
146
+ selected.
147
+
148
+ Returns:
149
+ The QueryBuilder instance for method chaining.
150
+ """
97
151
  if fields:
152
+ if "pk" not in fields:
153
+ fields.append("pk")
98
154
  self._fields = fields
99
155
  self._validate_fields()
100
156
  return self
101
157
 
102
158
  def exclude(self, fields: Optional[list[str]] = None) -> QueryBuilder:
103
- """Exclude specific fields from the query output."""
159
+ """Specify which fields to exclude from the query results.
160
+
161
+ Args:
162
+ fields: List of field names to exclude. If None, no fields are
163
+ excluded.
164
+
165
+ Returns:
166
+ The QueryBuilder instance for method chaining.
167
+
168
+ Raises:
169
+ ValueError: If exclusion results in no fields being selected or if
170
+ invalid fields are specified.
171
+ """
104
172
  if fields:
173
+ if "pk" in fields:
174
+ err = "The primary key 'pk' cannot be excluded."
175
+ raise ValueError(err)
105
176
  all_fields = set(self.model_class.model_fields.keys())
106
177
 
107
178
  # Check for invalid fields before subtraction
@@ -117,7 +188,7 @@ class QueryBuilder:
117
188
  self._fields = list(all_fields - set(fields))
118
189
 
119
190
  # Explicit check: raise an error if no fields remain
120
- if not self._fields:
191
+ if self._fields == ["pk"]:
121
192
  err = "Exclusion results in no fields being selected."
122
193
  raise ValueError(err)
123
194
 
@@ -127,7 +198,17 @@ class QueryBuilder:
127
198
  return self
128
199
 
129
200
  def only(self, field: str) -> QueryBuilder:
130
- """Return only the specified single field."""
201
+ """Specify a single field to select in the query.
202
+
203
+ Args:
204
+ field: The name of the field to select.
205
+
206
+ Returns:
207
+ The QueryBuilder instance for method chaining.
208
+
209
+ Raises:
210
+ ValueError: If the specified field is invalid.
211
+ """
131
212
  all_fields = set(self.model_class.model_fields.keys())
132
213
 
133
214
  # Validate that the field exists
@@ -136,12 +217,20 @@ class QueryBuilder:
136
217
  raise ValueError(err)
137
218
 
138
219
  # Set self._fields to just the single field
139
- self._fields = [field]
220
+ self._fields = [field, "pk"]
140
221
  return self
141
222
 
142
223
  def _get_operator_handler(
143
224
  self, operator: str
144
225
  ) -> Callable[[str, Any, str], None]:
226
+ """Get the appropriate handler function for the given operator.
227
+
228
+ Args:
229
+ operator: The filter operator string.
230
+
231
+ Returns:
232
+ A callable that handles the specific operator type.
233
+ """
145
234
  handlers = {
146
235
  "__isnull": self._handle_null,
147
236
  "__notnull": self._handle_null,
@@ -164,30 +253,70 @@ class QueryBuilder:
164
253
  def _validate_field(
165
254
  self, field_name: str, valid_fields: dict[str, FieldInfo]
166
255
  ) -> None:
256
+ """Validate that a field exists in the model.
257
+
258
+ Args:
259
+ field_name: The name of the field to validate.
260
+ valid_fields: Dictionary of valid fields from the model.
261
+
262
+ Raises:
263
+ InvalidFilterError: If the field is not in the model.
264
+ """
167
265
  if field_name not in valid_fields:
168
266
  raise InvalidFilterError(field_name)
169
267
 
170
268
  def _handle_equality(
171
269
  self, field_name: str, value: FilterValue, operator: str
172
270
  ) -> None:
271
+ """Handle equality filter conditions.
272
+
273
+ Args:
274
+ field_name: The name of the field to filter on.
275
+ value: The value to compare against.
276
+ operator: The operator string (usually '__eq').
277
+
278
+ This method adds an equality condition to the filters list, handling
279
+ NULL values separately.
280
+ """
173
281
  if value is None:
174
282
  self.filters.append((f"{field_name} IS NULL", None, "__isnull"))
175
283
  else:
176
284
  self.filters.append((field_name, value, operator))
177
285
 
178
286
  def _handle_null(
179
- self, field_name: str, _: FilterValue, operator: str
287
+ self, field_name: str, value: Union[str, float, None], operator: str
180
288
  ) -> None:
181
- condition = (
182
- f"{field_name} IS NOT NULL"
183
- if operator == "__notnull"
184
- else f"{field_name} IS NULL"
185
- )
289
+ """Handle IS NULL and IS NOT NULL filter conditions.
290
+
291
+ Args:
292
+ field_name: The name of the field to filter on. _: Placeholder for
293
+ unused value parameter.
294
+ operator: The operator string ('__isnull' or '__notnull').
295
+ value: The value to check for.
296
+
297
+ This method adds an IS NULL or IS NOT NULL condition to the filters
298
+ list.
299
+ """
300
+ is_null = operator == "__isnull"
301
+ check_null = bool(value) if is_null else not bool(value)
302
+ condition = f"{field_name} IS {'NOT ' if not check_null else ''}NULL"
186
303
  self.filters.append((condition, None, operator))
187
304
 
188
305
  def _handle_in(
189
306
  self, field_name: str, value: FilterValue, operator: str
190
307
  ) -> None:
308
+ """Handle IN and NOT IN filter conditions.
309
+
310
+ Args:
311
+ field_name: The name of the field to filter on.
312
+ value: A list of values to check against.
313
+ operator: The operator string ('__in' or '__not_in').
314
+
315
+ Raises:
316
+ TypeError: If the value is not a list.
317
+
318
+ This method adds an IN or NOT IN condition to the filters list.
319
+ """
191
320
  if not isinstance(value, list):
192
321
  err = f"{field_name} requires a list for '{operator}'"
193
322
  raise TypeError(err)
@@ -204,6 +333,19 @@ class QueryBuilder:
204
333
  def _handle_like(
205
334
  self, field_name: str, value: FilterValue, operator: str
206
335
  ) -> None:
336
+ """Handle LIKE and GLOB filter conditions.
337
+
338
+ Args:
339
+ field_name: The name of the field to filter on.
340
+ value: The pattern to match against.
341
+ operator: The operator string (e.g., '__startswith', '__contains').
342
+
343
+ Raises:
344
+ TypeError: If the value is not a string.
345
+
346
+ This method adds a LIKE or GLOB condition to the filters list, depending
347
+ on whether the operation is case-sensitive or not.
348
+ """
207
349
  if not isinstance(value, str):
208
350
  err = f"{field_name} requires a string value for '{operator}'"
209
351
  raise TypeError(err)
@@ -228,11 +370,29 @@ class QueryBuilder:
228
370
  def _handle_comparison(
229
371
  self, field_name: str, value: FilterValue, operator: str
230
372
  ) -> None:
373
+ """Handle comparison filter conditions.
374
+
375
+ Args:
376
+ field_name: The name of the field to filter on.
377
+ value: The value to compare against.
378
+ operator: The comparison operator string (e.g., '__lt', '__gte').
379
+
380
+ This method adds a comparison condition to the filters list.
381
+ """
231
382
  sql_operator = OPERATOR_MAPPING[operator]
232
383
  self.filters.append((f"{field_name} {sql_operator} ?", value, operator))
233
384
 
234
385
  # Helper method for parsing field and operator
235
386
  def _parse_field_operator(self, field: str) -> tuple[str, str]:
387
+ """Parse a field string to separate the field name and operator.
388
+
389
+ Args:
390
+ field: The field string, potentially including an operator.
391
+
392
+ Returns:
393
+ A tuple containing the field name and the operator (or '__eq' if
394
+ no operator was specified).
395
+ """
236
396
  for operator in OPERATOR_MAPPING:
237
397
  if field.endswith(operator):
238
398
  return field[: -len(operator)], operator
@@ -240,7 +400,15 @@ class QueryBuilder:
240
400
 
241
401
  # Helper method for formatting string operators (like startswith)
242
402
  def _format_string_for_operator(self, operator: str, value: str) -> str:
243
- # Mapping operators to their corresponding string format
403
+ """Format a string value based on the specified operator.
404
+
405
+ Args:
406
+ operator: The operator string (e.g., '__startswith', '__contains').
407
+ value: The original string value.
408
+
409
+ Returns:
410
+ The formatted string value suitable for the given operator.
411
+ """
244
412
  format_map = {
245
413
  "__startswith": f"{value}*",
246
414
  "__endswith": f"*{value}",
@@ -254,12 +422,29 @@ class QueryBuilder:
254
422
  return format_map.get(operator, value)
255
423
 
256
424
  def limit(self, limit_value: int) -> Self:
257
- """Limit the number of results returned by the query."""
425
+ """Limit the number of results returned by the query.
426
+
427
+ Args:
428
+ limit_value: The maximum number of records to return.
429
+
430
+ Returns:
431
+ The QueryBuilder instance for method chaining.
432
+ """
258
433
  self._limit = limit_value
259
434
  return self
260
435
 
261
436
  def offset(self, offset_value: int) -> Self:
262
- """Set an offset value for the query."""
437
+ """Set an offset value for the query.
438
+
439
+ Args:
440
+ offset_value: The number of records to skip.
441
+
442
+ Returns:
443
+ The QueryBuilder instance for method chaining.
444
+
445
+ Raises:
446
+ InvalidOffsetError: If the offset value is negative.
447
+ """
263
448
  if offset_value < 0:
264
449
  raise InvalidOffsetError(offset_value)
265
450
  self._offset = offset_value
@@ -270,7 +455,7 @@ class QueryBuilder:
270
455
 
271
456
  def order(
272
457
  self,
273
- order_by_field: str,
458
+ order_by_field: Optional[str] = None,
274
459
  direction: Optional[str] = None,
275
460
  *,
276
461
  reverse: bool = False,
@@ -278,19 +463,19 @@ class QueryBuilder:
278
463
  """Order the query results by the specified field.
279
464
 
280
465
  Args:
281
- order_by_field (str): The field to order by.
282
- direction (Optional[str]): The ordering direction ('ASC' or 'DESC').
283
- This is deprecated in favor of 'reverse'.
284
- reverse (bool): Whether to reverse the order (True for descending,
285
- False for ascending).
466
+ order_by_field: The field to order by [optional].
467
+ direction: Deprecated. Use 'reverse' instead.
468
+ reverse: If True, sort in descending order.
469
+
470
+ Returns:
471
+ The QueryBuilder instance for method chaining.
286
472
 
287
473
  Raises:
288
- InvalidOrderError: If the field doesn't exist in the model fields
289
- or if both 'direction' and 'reverse' are specified.
474
+ InvalidOrderError: If the field doesn't exist or if both 'direction'
475
+ and 'reverse' are specified.
290
476
 
291
- Returns:
292
- QueryBuilder: The current query builder instance with updated
293
- ordering.
477
+ Warns:
478
+ DeprecationWarning: If 'direction' is used instead of 'reverse'.
294
479
  """
295
480
  if direction:
296
481
  warnings.warn(
@@ -300,6 +485,9 @@ class QueryBuilder:
300
485
  stacklevel=2,
301
486
  )
302
487
 
488
+ if order_by_field is None:
489
+ order_by_field = self.model_class.get_primary_key()
490
+
303
491
  if order_by_field not in self.model_class.model_fields:
304
492
  err = f"'{order_by_field}' does not exist in the model fields."
305
493
  raise InvalidOrderError(err)
@@ -331,10 +519,24 @@ class QueryBuilder:
331
519
  fetch_one: bool = False,
332
520
  count_only: bool = False,
333
521
  ) -> list[tuple[Any, ...]] | Optional[tuple[Any, ...]]:
334
- """Helper function to execute the query with filters."""
522
+ """Execute the constructed SQL query.
523
+
524
+ Args:
525
+ fetch_one: If True, fetch only one result.
526
+ count_only: If True, return only the count of results.
527
+
528
+ Returns:
529
+ A list of tuples (all results), a single tuple (one result),
530
+ or None if no results are found.
531
+
532
+ Raises:
533
+ RecordFetchError: If there's an error executing the query.
534
+ """
335
535
  if count_only:
336
536
  fields = "COUNT(*)"
337
537
  elif self._fields:
538
+ if "pk" not in self._fields:
539
+ self._fields.append("pk")
338
540
  fields = ", ".join(f'"{field}"' for field in self._fields)
339
541
  else:
340
542
  fields = ", ".join(
@@ -360,6 +562,11 @@ class QueryBuilder:
360
562
  sql += " OFFSET ?"
361
563
  values.append(self._offset)
362
564
 
565
+ # Print the raw SQL and values if debug is enabled
566
+ # Log the SQL if debug is enabled
567
+ if self.db.debug:
568
+ self.db._log_sql(sql, values) # noqa: SLF001
569
+
363
570
  try:
364
571
  with self.db.connect() as conn:
365
572
  cursor = conn.cursor()
@@ -369,7 +576,13 @@ class QueryBuilder:
369
576
  raise RecordFetchError(self.table_name) from exc
370
577
 
371
578
  def _parse_filter(self) -> tuple[list[Any], LiteralString]:
372
- """Actually parse the filters."""
579
+ """Parse the filter conditions into SQL clauses and values.
580
+
581
+ Returns:
582
+ A tuple containing:
583
+ - A list of values to be used in the SQL query.
584
+ - A string representing the WHERE clause of the SQL query.
585
+ """
373
586
  where_clauses = []
374
587
  values = []
375
588
  for field, value, operator in self.filters:
@@ -388,16 +601,41 @@ class QueryBuilder:
388
601
  return values, where_clause
389
602
 
390
603
  def _convert_row_to_model(self, row: tuple[Any, ...]) -> BaseDBModel:
391
- """Convert a result row tuple into a Pydantic model."""
604
+ """Convert a database row to a model instance.
605
+
606
+ Args:
607
+ row: A tuple representing a database row.
608
+
609
+ Returns:
610
+ An instance of the model class populated with the row data.
611
+ """
392
612
  if self._fields:
393
- return self.model_class.model_validate_partial(
394
- {field: row[idx] for idx, field in enumerate(self._fields)}
395
- )
396
- return self.model_class(
397
- **{
398
- field: row[idx]
399
- for idx, field in enumerate(self.model_class.model_fields)
613
+ data = {
614
+ field: self._deserialize(field, row[idx])
615
+ for idx, field in enumerate(self._fields)
400
616
  }
617
+ return self.model_class.model_validate_partial(data)
618
+
619
+ data = {
620
+ field: self._deserialize(field, row[idx])
621
+ for idx, field in enumerate(self.model_class.model_fields)
622
+ }
623
+ return self.model_class(**data)
624
+
625
+ def _deserialize(
626
+ self, field_name: str, value: SerializableField
627
+ ) -> SerializableField:
628
+ """Deserialize a field value if needed.
629
+
630
+ Args:
631
+ field_name: Name of the field being deserialized.
632
+ value: Value from the database.
633
+
634
+ Returns:
635
+ The deserialized value.
636
+ """
637
+ return self.model_class.deserialize_field(
638
+ field_name, value, return_local_time=self.db.return_local_time
401
639
  )
402
640
 
403
641
  @overload
@@ -413,7 +651,15 @@ class QueryBuilder:
413
651
  def _fetch_result(
414
652
  self, *, fetch_one: bool = False
415
653
  ) -> Union[list[BaseDBModel], Optional[BaseDBModel]]:
416
- """Fetch one or all results and convert them to Pydantic models."""
654
+ """Fetch and convert query results to model instances.
655
+
656
+ Args:
657
+ fetch_one: If True, fetch only one result.
658
+
659
+ Returns:
660
+ A list of model instances, a single model instance, or None if no
661
+ results are found.
662
+ """
417
663
  result = self._execute_query(fetch_one=fetch_one)
418
664
 
419
665
  if not result:
@@ -432,30 +678,85 @@ class QueryBuilder:
432
678
  return [self._convert_row_to_model(row) for row in result]
433
679
 
434
680
  def fetch_all(self) -> list[BaseDBModel]:
435
- """Fetch all results matching the filters."""
681
+ """Fetch all results of the query.
682
+
683
+ Returns:
684
+ A list of model instances representing all query results.
685
+ """
436
686
  return self._fetch_result(fetch_one=False)
437
687
 
438
688
  def fetch_one(self) -> Optional[BaseDBModel]:
439
- """Fetch exactly one result."""
689
+ """Fetch a single result of the query.
690
+
691
+ Returns:
692
+ A single model instance or None if no result is found.
693
+ """
440
694
  return self._fetch_result(fetch_one=True)
441
695
 
442
696
  def fetch_first(self) -> Optional[BaseDBModel]:
443
- """Fetch the first result of the query."""
697
+ """Fetch the first result of the query.
698
+
699
+ Returns:
700
+ The first model instance or None if no result is found.
701
+ """
444
702
  self._limit = 1
445
703
  return self._fetch_result(fetch_one=True)
446
704
 
447
705
  def fetch_last(self) -> Optional[BaseDBModel]:
448
- """Fetch the last result of the query (based on the insertion order)."""
706
+ """Fetch the last result of the query.
707
+
708
+ Returns:
709
+ The last model instance or None if no result is found.
710
+ """
449
711
  self._limit = 1
450
712
  self._order_by = "rowid DESC"
451
713
  return self._fetch_result(fetch_one=True)
452
714
 
453
715
  def count(self) -> int:
454
- """Return the count of records matching the filters."""
716
+ """Count the number of results for the current query.
717
+
718
+ Returns:
719
+ The number of results that match the current query conditions.
720
+ """
455
721
  result = self._execute_query(count_only=True)
456
722
 
457
723
  return int(result[0][0]) if result else 0
458
724
 
459
725
  def exists(self) -> bool:
460
- """Return True if any record matches the filters."""
726
+ """Check if any results exist for the current query.
727
+
728
+ Returns:
729
+ True if at least one result exists, False otherwise.
730
+ """
461
731
  return self.count() > 0
732
+
733
+ def delete(self) -> int:
734
+ """Delete records that match the current query conditions.
735
+
736
+ Returns:
737
+ The number of records deleted.
738
+
739
+ Raises:
740
+ RecordDeletionError: If there's an error deleting the records.
741
+ """
742
+ sql = f'DELETE FROM "{self.table_name}"' # noqa: S608 # nosec
743
+
744
+ # Build the WHERE clause with special handling for None (NULL in SQL)
745
+ values, where_clause = self._parse_filter()
746
+
747
+ if self.filters:
748
+ sql += f" WHERE {where_clause}"
749
+
750
+ # Print the raw SQL and values if debug is enabled
751
+ if self.db.debug:
752
+ self.db._log_sql(sql, values) # noqa: SLF001
753
+
754
+ try:
755
+ with self.db.connect() as conn:
756
+ cursor = conn.cursor()
757
+ cursor.execute(sql, values)
758
+ deleted_count = cursor.rowcount
759
+ self.db._maybe_commit() # noqa: SLF001
760
+ return deleted_count
761
+ except sqlite3.Error as exc:
762
+ raise RecordDeletionError(self.table_name) from exc