sqliter-py 0.9.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. sqliter/constants.py +4 -3
  2. sqliter/exceptions.py +43 -0
  3. sqliter/model/__init__.py +38 -3
  4. sqliter/model/foreign_key.py +153 -0
  5. sqliter/model/model.py +42 -3
  6. sqliter/model/unique.py +20 -11
  7. sqliter/orm/__init__.py +16 -0
  8. sqliter/orm/fields.py +412 -0
  9. sqliter/orm/foreign_key.py +8 -0
  10. sqliter/orm/model.py +243 -0
  11. sqliter/orm/query.py +221 -0
  12. sqliter/orm/registry.py +169 -0
  13. sqliter/query/query.py +720 -69
  14. sqliter/sqliter.py +533 -76
  15. sqliter/tui/__init__.py +62 -0
  16. sqliter/tui/__main__.py +6 -0
  17. sqliter/tui/app.py +179 -0
  18. sqliter/tui/demos/__init__.py +96 -0
  19. sqliter/tui/demos/base.py +114 -0
  20. sqliter/tui/demos/caching.py +283 -0
  21. sqliter/tui/demos/connection.py +150 -0
  22. sqliter/tui/demos/constraints.py +211 -0
  23. sqliter/tui/demos/crud.py +154 -0
  24. sqliter/tui/demos/errors.py +231 -0
  25. sqliter/tui/demos/field_selection.py +150 -0
  26. sqliter/tui/demos/filters.py +389 -0
  27. sqliter/tui/demos/models.py +248 -0
  28. sqliter/tui/demos/ordering.py +156 -0
  29. sqliter/tui/demos/orm.py +460 -0
  30. sqliter/tui/demos/results.py +241 -0
  31. sqliter/tui/demos/string_filters.py +210 -0
  32. sqliter/tui/demos/timestamps.py +126 -0
  33. sqliter/tui/demos/transactions.py +177 -0
  34. sqliter/tui/runner.py +116 -0
  35. sqliter/tui/styles/app.tcss +130 -0
  36. sqliter/tui/widgets/__init__.py +7 -0
  37. sqliter/tui/widgets/code_display.py +81 -0
  38. sqliter/tui/widgets/demo_list.py +65 -0
  39. sqliter/tui/widgets/output_display.py +92 -0
  40. {sqliter_py-0.9.0.dist-info → sqliter_py-0.16.0.dist-info}/METADATA +27 -11
  41. sqliter_py-0.16.0.dist-info/RECORD +47 -0
  42. {sqliter_py-0.9.0.dist-info → sqliter_py-0.16.0.dist-info}/WHEEL +2 -2
  43. sqliter_py-0.16.0.dist-info/entry_points.txt +3 -0
  44. sqliter_py-0.9.0.dist-info/RECORD +0 -14
sqliter/sqliter.py CHANGED
@@ -10,13 +10,16 @@ from __future__ import annotations
 
 import logging
 import sqlite3
+import sys
 import time
-from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
 
 from typing_extensions import Self
 
 from sqliter.exceptions import (
     DatabaseConnectionError,
+    ForeignKeyConstraintError,
     InvalidIndexError,
     RecordDeletionError,
     RecordFetchError,
@@ -28,15 +31,16 @@ from sqliter.exceptions import (
     TableDeletionError,
 )
 from sqliter.helpers import infer_sqlite_type
-from sqliter.model.unique import Unique
+from sqliter.model.foreign_key import ForeignKeyInfo, get_foreign_key_info
+from sqliter.model.model import BaseDBModel
 from sqliter.query.query import QueryBuilder
 
 if TYPE_CHECKING:  # pragma: no cover
     from types import TracebackType
 
-    from sqliter.model.model import BaseDBModel
+    from pydantic.fields import FieldInfo
 
-T = TypeVar("T", bound="BaseDBModel")
+T = TypeVar("T", bound=BaseDBModel)
 
 
 class SqliterDB:
@@ -64,6 +68,10 @@ class SqliterDB:
         logger: Optional[logging.Logger] = None,
         reset: bool = False,
         return_local_time: bool = True,
+        cache_enabled: bool = False,
+        cache_max_size: int = 1000,
+        cache_ttl: Optional[int] = None,
+        cache_max_memory_mb: Optional[int] = None,
     ) -> None:
         """Initialize a new SqliterDB instance.
 
@@ -76,6 +84,12 @@
             reset: Whether to reset the database on initialization. This will
                 basically drop all existing tables.
             return_local_time: Whether to return local time for datetime fields.
+            cache_enabled: Whether to enable query result caching. Default is
+                False.
+            cache_max_size: Maximum number of cached queries per table (LRU).
+            cache_ttl: Optional time-to-live for cache entries in seconds.
+            cache_max_memory_mb: Optional maximum memory usage for cache in
+                megabytes. When exceeded, oldest entries are evicted.
 
         Raises:
             ValueError: If no filename is provided for a non-memory database.
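
Note: the new cache arguments are easiest to see in use. A minimal sketch, assuming the first positional argument is still the database filename as in earlier releases:

    from sqliter import SqliterDB

    # Enable the per-table LRU result cache with a TTL and a memory ceiling.
    db = SqliterDB(
        "example.db",
        cache_enabled=True,       # off by default
        cache_max_size=500,       # at most 500 cached queries per table (LRU)
        cache_ttl=60,             # entries expire after 60 seconds
        cache_max_memory_mb=16,   # evict oldest entries once 16 MB is exceeded
    )

Out-of-range values (cache_max_size <= 0, a negative cache_ttl, or cache_max_memory_mb <= 0) raise ValueError, as the validation block in the next hunk shows.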
@@ -99,6 +113,38 @@
 
         self._in_transaction = False
 
+        # Initialize cache
+        self._cache_enabled = cache_enabled
+        self._cache_max_size = cache_max_size
+        self._cache_ttl = cache_ttl
+        self._cache_max_memory_mb = cache_max_memory_mb
+
+        # Validate cache parameters
+        if self._cache_max_size <= 0:
+            msg = "cache_max_size must be greater than 0"
+            raise ValueError(msg)
+        if self._cache_ttl is not None and self._cache_ttl < 0:
+            msg = "cache_ttl must be non-negative"
+            raise ValueError(msg)
+        if (
+            self._cache_max_memory_mb is not None
+            and self._cache_max_memory_mb <= 0
+        ):
+            msg = "cache_max_memory_mb must be greater than 0"
+            raise ValueError(msg)
+        self._cache: OrderedDict[
+            str,
+            OrderedDict[
+                str,
+                tuple[
+                    Union[BaseDBModel, list[BaseDBModel], None],
+                    Optional[float],
+                ],
+            ],
+        ] = OrderedDict()  # {table: {cache_key: (result, expiration)}}
+        self._cache_hits = 0
+        self._cache_misses = 0
+
         if self.debug:
             self._setup_logger()
 
@@ -237,10 +283,215 @@ class SqliterDB:
         if not self.conn:
             try:
                 self.conn = sqlite3.connect(self.db_filename)
+                # Enable foreign key constraint enforcement
+                self.conn.execute("PRAGMA foreign_keys = ON")
             except sqlite3.Error as exc:
                 raise DatabaseConnectionError(self.db_filename) from exc
         return self.conn
 
+    def _cache_get(
+        self,
+        table_name: str,
+        cache_key: str,
+    ) -> tuple[bool, Any]:
+        """Get cached result if valid and not expired.
+
+        Args:
+            table_name: The name of the table.
+            cache_key: The cache key for the query.
+
+        Returns:
+            A tuple of (hit, result) where hit is True if cache hit,
+            False if miss. Result is the cached value (which may be None
+            or an empty list) on a hit, or None on a miss.
+        """
+        if not self._cache_enabled:
+            return False, None
+        if table_name not in self._cache:
+            self._cache_misses += 1
+            return False, None
+        if cache_key not in self._cache[table_name]:
+            self._cache_misses += 1
+            return False, None
+
+        result, expiration = self._cache[table_name][cache_key]
+
+        # Check TTL expiration
+        if expiration is not None and time.time() > expiration:
+            self._cache_misses += 1
+            del self._cache[table_name][cache_key]
+            return False, None
+
+        # Mark as recently used (LRU)
+        self._cache[table_name].move_to_end(cache_key)
+        self._cache_hits += 1
+        return True, result
+
+    def _cache_set(
+        self,
+        table_name: str,
+        cache_key: str,
+        result: Any,  # noqa: ANN401
+        ttl: Optional[int] = None,
+    ) -> None:
+        """Store result in cache with optional expiration.
+
+        Args:
+            table_name: The name of the table.
+            cache_key: The cache key for the query.
+            result: The result to cache.
+            ttl: Optional TTL override for this specific entry.
+        """
+        if not self._cache_enabled:
+            return
+
+        if table_name not in self._cache:
+            self._cache[table_name] = OrderedDict()
+
+        # Calculate expiration (use query-specific TTL if provided)
+        expiration = None
+        effective_ttl = ttl if ttl is not None else self._cache_ttl
+        if effective_ttl is not None:
+            expiration = time.time() + effective_ttl
+
+        self._cache[table_name][cache_key] = (result, expiration)
+        # Mark as most-recently-used
+        self._cache[table_name].move_to_end(cache_key)
+
+        # Enforce memory limit if set
+        if self._cache_max_memory_mb is not None:
+            max_bytes = self._cache_max_memory_mb * 1024 * 1024
+            # Evict LRU entries until under the memory limit
+            while (
+                table_name in self._cache
+                and self._get_table_memory_usage(table_name) > max_bytes
+            ):
+                self._cache[table_name].popitem(last=False)
+
+        # Enforce LRU by size
+        if len(self._cache[table_name]) > self._cache_max_size:
+            self._cache[table_name].popitem(last=False)
+
+    def _cache_invalidate_table(self, table_name: str) -> None:
+        """Clear all cached queries for a specific table.
+
+        Args:
+            table_name: The name of the table to invalidate.
+        """
+        if not self._cache_enabled:
+            return
+        self._cache.pop(table_name, None)
+
+    def _get_table_memory_usage(  # noqa: C901
+        self, table_name: str
+    ) -> int:
+        """Calculate the actual memory usage for a table's cache.
+
+        This method recalculates memory usage on-demand by measuring the
+        size of all cached entries including tuple and dict overhead.
+
+        Args:
+            table_name: The name of the table.
+
+        Returns:
+            The memory usage in bytes.
+        """
+        if table_name not in self._cache:
+            return 0
+
+        total = 0
+        seen: dict[int, int] = {}
+
+        for key, (result, _expiration) in self._cache[table_name].items():
+            # Measure the tuple (result, expiration)
+            total += sys.getsizeof((result, _expiration))
+
+            # Measure the dict key (cache_key string)
+            total += sys.getsizeof(key)
+
+            # Dict entry overhead (approximately 72 bytes for a dict entry)
+            total += 72
+
+            # Recursively measure the result object
+            def measure_size(obj: Any) -> int:  # noqa: C901, ANN401
+                """Recursively measure object size with memoization."""
+                obj_id = id(obj)
+                if obj_id in seen:
+                    return 0  # Already counted
+
+                size = sys.getsizeof(obj)
+                seen[obj_id] = size
+
+                # Handle lists
+                if isinstance(obj, list):
+                    for item in obj:
+                        size += measure_size(item)
+
+                # Handle Pydantic models - measure their fields
+                elif hasattr(type(obj), "model_fields"):
+                    for field_name in type(obj).model_fields:
+                        field_value = getattr(obj, field_name, None)
+                        if field_value is not None:
+                            size += measure_size(field_value)
+                    # Also measure __dict__ if present
+                    if hasattr(obj, "__dict__"):
+                        size += measure_size(obj.__dict__)
+
+                # Handle dicts
+                elif isinstance(obj, dict):
+                    for k, v in obj.items():
+                        size += measure_size(k)
+                        size += measure_size(v)
+
+                # Handle sets and tuples
+                elif isinstance(obj, (set, tuple)):
+                    for item in obj:
+                        size += measure_size(item)
+
+                return size
+
+            total += measure_size(result)
+
+        return total
+
+    def get_cache_stats(self) -> dict[str, int | float]:
+        """Get cache performance statistics.
+
+        Returns:
+            A dictionary containing cache statistics with keys:
+            - hits: Number of cache hits
+            - misses: Number of cache misses
+            - total: Total number of cache lookups
+            - hit_rate: Cache hit rate as a percentage (0-100)
+        """
+        total = self._cache_hits + self._cache_misses
+        hit_rate = (self._cache_hits / total * 100) if total > 0 else 0.0
+        return {
+            "hits": self._cache_hits,
+            "misses": self._cache_misses,
+            "total": total,
+            "hit_rate": round(hit_rate, 2),
+        }
+
+    def clear_cache(self) -> None:
+        """Clear all cached query results.
+
+        This method removes all cached data from memory, freeing up resources
+        and forcing subsequent queries to fetch fresh data from the database.
+
+        Use this when you want to:
+        - Free memory used by the cache
+        - Force fresh queries after external data changes
+
+        Note:
+            Cache statistics (hits/misses) are preserved. To reset statistics,
+            create a new database connection.
+
+        Example:
+            >>> db.clear_cache()
+        """
+        self._cache.clear()
 
     def close(self) -> None:
         """Close the database connection.
 
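Note: a sketch of the observable behaviour of the cache machinery added above, using only the two public methods this hunk introduces (get_cache_stats and clear_cache):

    # After a few reads, inspect the hit/miss counters...
    stats = db.get_cache_stats()
    print(stats)  # e.g. {"hits": 9, "misses": 3, "total": 12, "hit_rate": 75.0}

    # ...and drop all cached rows. Per the docstring, the counters survive
    # clear_cache(); only close() or the context-manager exit resets them.
    db.clear_cache()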
@@ -252,6 +503,9 @@ class SqliterDB:
             self._maybe_commit()
             self.conn.close()
             self.conn = None
+        self._cache.clear()
+        self._cache_hits = 0
+        self._cache_misses = 0
 
     def commit(self) -> None:
         """Commit the current transaction.
@@ -261,6 +515,94 @@ class SqliterDB:
         if self.conn:
             self.conn.commit()
 
+    def _build_field_definitions(
+        self,
+        model_class: type[BaseDBModel],
+        primary_key: str,
+    ) -> tuple[list[str], list[str], list[str]]:
+        """Build SQL field definitions for table creation.
+
+        Args:
+            model_class: The Pydantic model class.
+            primary_key: The name of the primary key field.
+
+        Returns:
+            A tuple of (fields, foreign_keys, fk_columns) where:
+            - fields: List of column definitions
+            - foreign_keys: List of FK constraint definitions
+            - fk_columns: List of FK column names for index creation
+        """
+        fields = [f'"{primary_key}" INTEGER PRIMARY KEY AUTOINCREMENT']
+        foreign_keys: list[str] = []
+        fk_columns: list[str] = []
+
+        for field_name, field_info in model_class.model_fields.items():
+            if field_name == primary_key:
+                continue
+
+            fk_info = get_foreign_key_info(field_info)
+            if fk_info is not None:
+                col, constraint = self._build_fk_field(field_name, fk_info)
+                fields.append(col)
+                foreign_keys.append(constraint)
+                fk_columns.append(fk_info.db_column or field_name)
+            else:
+                fields.append(self._build_regular_field(field_name, field_info))
+
+        return fields, foreign_keys, fk_columns
+
+    def _build_fk_field(
+        self, field_name: str, fk_info: ForeignKeyInfo
+    ) -> tuple[str, str]:
+        """Build FK column definition and constraint.
+
+        Args:
+            field_name: The name of the field.
+            fk_info: The ForeignKeyInfo metadata.
+
+        Returns:
+            A tuple of (column_def, fk_constraint).
+        """
+        column_name = fk_info.db_column or field_name
+        null_str = "" if fk_info.null else "NOT NULL"
+        unique_str = "UNIQUE" if fk_info.unique else ""
+
+        field_def = f'"{column_name}" INTEGER {null_str} {unique_str}'
+        column_def = " ".join(field_def.split())
+
+        target_table = fk_info.to_model.get_table_name()
+        fk_constraint = (
+            f'FOREIGN KEY ("{column_name}") '
+            f'REFERENCES "{target_table}"("pk") '
+            f"ON DELETE {fk_info.on_delete} "
+            f"ON UPDATE {fk_info.on_update}"
+        )
+
+        return column_def, fk_constraint
+
+    def _build_regular_field(
+        self, field_name: str, field_info: FieldInfo
+    ) -> str:
+        """Build a regular (non-FK) column definition.
+
+        Args:
+            field_name: The name of the field.
+            field_info: The Pydantic field info.
+
+        Returns:
+            The column definition string.
+        """
+        sqlite_type = infer_sqlite_type(field_info.annotation)
+        unique_constraint = ""
+        if (
+            hasattr(field_info, "json_schema_extra")
+            and field_info.json_schema_extra
+            and isinstance(field_info.json_schema_extra, dict)
+            and field_info.json_schema_extra.get("unique", False)
+        ):
+            unique_constraint = "UNIQUE"
+        return f'"{field_name}" {sqlite_type} {unique_constraint}'.strip()
+
     def create_table(
         self,
         model_class: type[BaseDBModel],
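Note: a standalone sketch of the DDL strings _build_fk_field assembles, using hypothetical values (column "author_id", target table "users", CASCADE rules). It mirrors the f-strings above rather than calling the private method:

    column_name, target_table = "author_id", "users"
    on_delete = on_update = "CASCADE"

    # Non-null FK column definition plus the table-level constraint, both quoted.
    column_def = f'"{column_name}" INTEGER NOT NULL'
    fk_constraint = (
        f'FOREIGN KEY ("{column_name}") REFERENCES "{target_table}"("pk") '
        f"ON DELETE {on_delete} ON UPDATE {on_update}"
    )
    print(column_def)     # "author_id" INTEGER NOT NULL
    print(fk_constraint)  # FOREIGN KEY ("author_id") REFERENCES "users"("pk") ...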
@@ -288,26 +630,20 @@ class SqliterDB:
             drop_table_sql = f"DROP TABLE IF EXISTS {table_name}"
             self._execute_sql(drop_table_sql)
 
-        fields = [f'"{primary_key}" INTEGER PRIMARY KEY AUTOINCREMENT']
+        fields, foreign_keys, fk_columns = self._build_field_definitions(
+            model_class, primary_key
+        )
 
-        # Add remaining fields
-        for field_name, field_info in model_class.model_fields.items():
-            if field_name != primary_key:
-                sqlite_type = infer_sqlite_type(field_info.annotation)
-                unique_constraint = (
-                    "UNIQUE" if isinstance(field_info, Unique) else ""
-                )
-                fields.append(
-                    f"{field_name} {sqlite_type} {unique_constraint}".strip()
-                )
+        # Combine field definitions and FK constraints
+        all_definitions = fields + foreign_keys
 
         create_str = (
             "CREATE TABLE IF NOT EXISTS" if exists_ok else "CREATE TABLE"
         )
 
         create_table_sql = f"""
-            {create_str} {table_name} (
-                {", ".join(fields)}
+            {create_str} "{table_name}" (
+                {", ".join(all_definitions)}
             )
         """
@@ -322,6 +658,14 @@ class SqliterDB:
         except sqlite3.Error as exc:
             raise TableCreationError(table_name) from exc
 
+        # Create indexes for FK columns
+        for column_name in fk_columns:
+            index_sql = (
+                f'CREATE INDEX IF NOT EXISTS "idx_{table_name}_{column_name}" '
+                f'ON "{table_name}" ("{column_name}")'
+            )
+            self._execute_sql(index_sql)
+
         # Create regular indexes
         if hasattr(model_class.Meta, "indexes"):
             self._create_indexes(
@@ -373,11 +717,14 @@ class SqliterDB:
             index_postfix = "_unique" if unique else ""
             index_type = " UNIQUE " if unique else " "
 
+            # Quote field names for index creation
+            quoted_fields = ", ".join(f'"{field}"' for field in fields)
+
             create_index_sql = (
                 f"CREATE{index_type}INDEX IF NOT EXISTS "
                 f"idx_{model_class.get_table_name()}"
                 f"_{index_name}{index_postfix} "
-                f"ON {model_class.get_table_name()} ({', '.join(fields)})"
+                f'ON "{model_class.get_table_name()}" ({quoted_fields})'
             )
             self._execute_sql(create_index_sql)
 
@@ -433,6 +780,64 @@ class SqliterDB:
         if not self._in_transaction and self.auto_commit and self.conn:
             self.conn.commit()
 
+    def _set_insert_timestamps(
+        self, model_instance: T, *, timestamp_override: bool
+    ) -> None:
+        """Set created_at and updated_at timestamps for insert.
+
+        Args:
+            model_instance: The model instance to update.
+            timestamp_override: If True, respect provided non-zero values.
+        """
+        current_timestamp = int(time.time())
+
+        if not timestamp_override:
+            model_instance.created_at = current_timestamp
+            model_instance.updated_at = current_timestamp
+        else:
+            if model_instance.created_at == 0:
+                model_instance.created_at = current_timestamp
+            if model_instance.updated_at == 0:
+                model_instance.updated_at = current_timestamp
+
+    def _create_instance_from_data(
+        self,
+        model_class: type[T],
+        data: dict[str, Any],
+        pk: Optional[int] = None,
+    ) -> T:
+        """Create a model instance from deserialized data.
+
+        Handles ORM-specific field exclusions and db_context setup.
+
+        Args:
+            model_class: The model class to instantiate.
+            data: Raw data dictionary from the database.
+            pk: Optional primary key value to set.
+
+        Returns:
+            A new model instance with db_context set if applicable.
+        """
+        # Deserialize each field before creating the model instance
+        deserialized_data: dict[str, Any] = {}
+        for field_name, value in data.items():
+            deserialized_data[field_name] = model_class.deserialize_field(
+                field_name, value, return_local_time=self.return_local_time
+            )
+        # For ORM mode, exclude FK descriptor fields from data
+        for fk_field in getattr(model_class, "fk_descriptors", {}):
+            deserialized_data.pop(fk_field, None)
+
+        if pk is not None:
+            instance = model_class(pk=pk, **deserialized_data)
+        else:
+            instance = model_class(**deserialized_data)
+
+        # Set db_context for ORM lazy loading and reverse relationships
+        if hasattr(instance, "db_context"):
+            instance.db_context = self
+        return instance
+
     def insert(
         self, model_instance: T, *, timestamp_override: bool = False
     ) -> T:
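Note: the extracted helper keeps insert()'s original timestamp semantics. A sketch (Item is a hypothetical BaseDBModel subclass):

    item = db.insert(Item(name="widget"))
    # default: created_at and updated_at are both overwritten with "now"

    item = db.insert(
        Item(name="widget", created_at=1700000000),
        timestamp_override=True,  # keep the provided non-zero value;
    )                             # zero-valued fields still fall back to "now"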
@@ -455,20 +860,9 @@ class SqliterDB:
         model_class = type(model_instance)
         table_name = model_class.get_table_name()
 
-        # Always set created_at and updated_at timestamps
-        current_timestamp = int(time.time())
-
-        # Handle the case where timestamp_override is False
-        if not timestamp_override:
-            # Always override both timestamps with the current time
-            model_instance.created_at = current_timestamp
-            model_instance.updated_at = current_timestamp
-        else:
-            # Respect provided values, but set to current time if they are 0
-            if model_instance.created_at == 0:
-                model_instance.created_at = current_timestamp
-            if model_instance.updated_at == 0:
-                model_instance.updated_at = current_timestamp
+        self._set_insert_timestamps(
+            model_instance, timestamp_override=timestamp_override
+        )
 
         # Get the data from the model
         data = model_instance.model_dump()
@@ -494,31 +888,50 @@ class SqliterDB:
         """  # noqa: S608
 
         try:
-            with self.connect() as conn:
-                cursor = conn.cursor()
-                cursor.execute(insert_sql, values)
-                self._maybe_commit()
+            conn = self.connect()
+            cursor = conn.cursor()
+            cursor.execute(insert_sql, values)
+            self._maybe_commit()
 
+        except sqlite3.IntegrityError as exc:
+            # Rollback implicit transaction if not in user-managed transaction
+            if not self._in_transaction and self.conn:
+                self.conn.rollback()
+            # Check for foreign key constraint violation
+            if "FOREIGN KEY constraint failed" in str(exc):
+                fk_operation = "insert"
+                fk_reason = "does not exist in referenced table"
+                raise ForeignKeyConstraintError(
+                    fk_operation, fk_reason
+                ) from exc
+            raise RecordInsertionError(table_name) from exc
         except sqlite3.Error as exc:
+            # Rollback implicit transaction if not in user-managed transaction
+            if not self._in_transaction and self.conn:
+                self.conn.rollback()
             raise RecordInsertionError(table_name) from exc
         else:
+            self._cache_invalidate_table(table_name)
            data.pop("pk", None)
-            # Deserialize each field before creating the model instance
-            deserialized_data = {}
-            for field_name, value in data.items():
-                deserialized_data[field_name] = model_class.deserialize_field(
-                    field_name, value, return_local_time=self.return_local_time
-                )
-            return model_class(pk=cursor.lastrowid, **deserialized_data)
+            return self._create_instance_from_data(
+                model_class, data, pk=cursor.lastrowid
+            )
 
     def get(
-        self, model_class: type[BaseDBModel], primary_key_value: int
-    ) -> BaseDBModel | None:
+        self,
+        model_class: type[T],
+        primary_key_value: int,
+        *,
+        bypass_cache: bool = False,
+        cache_ttl: Optional[int] = None,
+    ) -> T | None:
         """Retrieve a single record from the database by its primary key.
 
         Args:
             model_class: The Pydantic model class representing the table.
             primary_key_value: The value of the primary key to look up.
+            bypass_cache: If True, skip reading/writing cache for this call.
+            cache_ttl: Optional TTL override for this specific lookup.
 
         Returns:
             An instance of the model class if found, None otherwise.
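Note: a sketch of the per-call cache controls on get() (User is a hypothetical model class):

    user = db.get(User, 1)                      # cached on first fetch, served from cache after
    fresh = db.get(User, 1, bypass_cache=True)  # always reads the database, never touches the cache
    brief = db.get(User, 1, cache_ttl=5)        # override the instance-wide TTL for this lookup only

The next two hunks show the cache read before the SELECT and the write after it; note that a miss (row not found) is cached as None as well, so repeated lookups of a missing pk are also answered from the cache until the table is invalidated.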
@@ -526,8 +939,18 @@ class SqliterDB:
         Raises:
             RecordFetchError: If there's an error fetching the record.
         """
+        if cache_ttl is not None and cache_ttl < 0:
+            msg = "cache_ttl must be non-negative"
+            raise ValueError(msg)
+
         table_name = model_class.get_table_name()
         primary_key = model_class.get_primary_key()
+        cache_key = f"pk:{primary_key_value}"
+
+        if not bypass_cache:
+            hit, cached = self._cache_get(table_name, cache_key)
+            if hit:
+                return cast("Optional[T]", cached)
 
         fields = ", ".join(model_class.model_fields)
 
@@ -536,30 +959,29 @@ class SqliterDB:
         """  # noqa: S608
 
         try:
-            with self.connect() as conn:
-                cursor = conn.cursor()
-                cursor.execute(select_sql, (primary_key_value,))
-                result = cursor.fetchone()
+            conn = self.connect()
+            cursor = conn.cursor()
+            cursor.execute(select_sql, (primary_key_value,))
+            result = cursor.fetchone()
 
             if result:
                 result_dict = {
                     field: result[idx]
                     for idx, field in enumerate(model_class.model_fields)
                 }
-                # Deserialize each field before creating the model instance
-                deserialized_data = {}
-                for field_name, value in result_dict.items():
-                    deserialized_data[field_name] = (
-                        model_class.deserialize_field(
-                            field_name,
-                            value,
-                            return_local_time=self.return_local_time,
-                        )
+                instance = self._create_instance_from_data(
+                    model_class, result_dict
+                )
+                if not bypass_cache:
+                    self._cache_set(
+                        table_name, cache_key, instance, ttl=cache_ttl
                     )
-                return model_class(**deserialized_data)
+                return instance
         except sqlite3.Error as exc:
             raise RecordFetchError(table_name) from exc
         else:
+            if not bypass_cache:
+                self._cache_set(table_name, cache_key, None, ttl=cache_ttl)
             return None
 
     def update(self, model_instance: BaseDBModel) -> None:
@@ -582,6 +1004,7 @@ class SqliterDB:
 
         # Get the data and serialize any datetime/date fields
         data = model_instance.model_dump()
+
         for field_name, value in list(data.items()):
             data[field_name] = model_instance.serialize_field(value)
 
@@ -599,21 +1022,30 @@ class SqliterDB:
         """  # noqa: S608
 
         try:
-            with self.connect() as conn:
-                cursor = conn.cursor()
-                cursor.execute(update_sql, (*values, primary_key_value))
+            conn = self.connect()
+            cursor = conn.cursor()
+            cursor.execute(update_sql, (*values, primary_key_value))
 
-                # Check if any rows were updated
-                if cursor.rowcount == 0:
-                    raise RecordNotFoundError(primary_key_value)
+            # Check if any rows were updated
+            if cursor.rowcount == 0:
+                raise RecordNotFoundError(primary_key_value)  # noqa: TRY301
 
-                self._maybe_commit()
+            self._maybe_commit()
+            self._cache_invalidate_table(table_name)
 
+        except RecordNotFoundError:
+            # Rollback implicit transaction if not in user-managed transaction
+            if not self._in_transaction and self.conn:
+                self.conn.rollback()
+            raise
         except sqlite3.Error as exc:
+            # Rollback implicit transaction if not in user-managed transaction
+            if not self._in_transaction and self.conn:
+                self.conn.rollback()
             raise RecordUpdateError(table_name) from exc
 
     def delete(
-        self, model_class: type[BaseDBModel], primary_key_value: str
+        self, model_class: type[BaseDBModel], primary_key_value: Union[int, str]
     ) -> None:
         """Delete a record from the database by its primary key.
 
@@ -634,22 +1066,43 @@ class SqliterDB:
         """  # noqa: S608
 
         try:
-            with self.connect() as conn:
-                cursor = conn.cursor()
-                cursor.execute(delete_sql, (primary_key_value,))
+            conn = self.connect()
+            cursor = conn.cursor()
+            cursor.execute(delete_sql, (primary_key_value,))
 
-                if cursor.rowcount == 0:
-                    raise RecordNotFoundError(primary_key_value)
-                self._maybe_commit()
+            if cursor.rowcount == 0:
+                raise RecordNotFoundError(primary_key_value)  # noqa: TRY301
+            self._maybe_commit()
+            self._cache_invalidate_table(table_name)
+        except RecordNotFoundError:
+            # Rollback implicit transaction if not in user-managed transaction
+            if not self._in_transaction and self.conn:
+                self.conn.rollback()
+            raise
+        except sqlite3.IntegrityError as exc:
+            # Rollback implicit transaction if not in user-managed transaction
+            if not self._in_transaction and self.conn:
+                self.conn.rollback()
+            # Check for foreign key constraint violation (RESTRICT)
+            if "FOREIGN KEY constraint failed" in str(exc):
+                fk_operation = "delete"
+                fk_reason = "is still referenced by other records"
+                raise ForeignKeyConstraintError(
+                    fk_operation, fk_reason
+                ) from exc
+            raise RecordDeletionError(table_name) from exc
         except sqlite3.Error as exc:
+            # Rollback implicit transaction if not in user-managed transaction
+            if not self._in_transaction and self.conn:
+                self.conn.rollback()
            raise RecordDeletionError(table_name) from exc
 
     def select(
         self,
-        model_class: type[BaseDBModel],
+        model_class: type[T],
         fields: Optional[list[str]] = None,
         exclude: Optional[list[str]] = None,
-    ) -> QueryBuilder:
+    ) -> QueryBuilder[T]:
         """Create a QueryBuilder instance for selecting records.
 
         Args:
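Note: delete() now distinguishes FK violations from other integrity errors and rolls back the implicit transaction before re-raising (update() above gets the same rollback treatment). A sketch with hypothetical Author/Book models joined by a RESTRICT relationship:

    from sqliter.exceptions import ForeignKeyConstraintError

    try:
        db.delete(Author, author_pk)  # author_pk is still referenced by Book rows
    except ForeignKeyConstraintError:
        # the failed DELETE was rolled back; the author row is untouched
        ...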
@@ -660,7 +1113,7 @@ class SqliterDB:
         Returns:
             A QueryBuilder instance for further query construction.
         """
-        query_builder = QueryBuilder(self, model_class, fields)
+        query_builder: QueryBuilder[T] = QueryBuilder(self, model_class, fields)
 
         # If exclude is provided, apply the exclude method
         if exclude:
@@ -722,3 +1175,7 @@ class SqliterDB:
             self.conn.close()
             self.conn = None
         self._in_transaction = False
+        # Clear cache when exiting context
+        self._cache.clear()
+        self._cache_hits = 0
+        self._cache_misses = 0
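
Note: with these __exit__ changes, every `with` block starts with a cold cache and zeroed statistics. A sketch (User is a hypothetical model class):

    with SqliterDB("example.db", cache_enabled=True) as db:
        db.get(User, 1)   # miss: read from SQLite, then cached
        db.get(User, 1)   # hit: served from the cache
    # on exit: connection closed, cache emptied, hit/miss counters reset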