sqliter-py 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqliter/sqliter.py ADDED
@@ -0,0 +1,1087 @@
1
+ """Core module for SQLiter, providing the main database interaction class.
2
+
3
+ This module defines the SqliterDB class, which serves as the primary
4
+ interface for all database operations in SQLiter. It handles connection
5
+ management, table creation, and CRUD operations, bridging the gap between
6
+ Pydantic models and SQLite database interactions.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ import sqlite3
13
+ import sys
14
+ import time
15
+ from collections import OrderedDict
16
+ from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
17
+
18
+ from typing_extensions import Self
19
+
20
+ from sqliter.exceptions import (
21
+ DatabaseConnectionError,
22
+ ForeignKeyConstraintError,
23
+ InvalidIndexError,
24
+ RecordDeletionError,
25
+ RecordFetchError,
26
+ RecordInsertionError,
27
+ RecordNotFoundError,
28
+ RecordUpdateError,
29
+ SqlExecutionError,
30
+ TableCreationError,
31
+ TableDeletionError,
32
+ )
33
+ from sqliter.helpers import infer_sqlite_type
34
+ from sqliter.model.foreign_key import ForeignKeyInfo, get_foreign_key_info
35
+ from sqliter.query.query import QueryBuilder
36
+
37
+ if TYPE_CHECKING: # pragma: no cover
38
+ from types import TracebackType
39
+
40
+ from pydantic.fields import FieldInfo
41
+
42
+ from sqliter.model.model import BaseDBModel
43
+
44
+ T = TypeVar("T", bound="BaseDBModel")
45
+
46
+
47
class SqliterDB:
    """Main class for interacting with SQLite databases.

    This class provides methods for connecting to a SQLite database,
    creating tables, and performing CRUD operations.

    Args:
        db_filename (str): The filename of the SQLite database.
        auto_commit (bool): Whether to automatically commit transactions.
        debug (bool): Whether to enable debug logging.
        logger (Optional[logging.Logger]): Custom logger for debug output.
    """

    # Filename stored in ``db_filename`` when an in-memory database is
    # requested (``memory=True``).
    MEMORY_DB = ":memory:"
61
+
62
+ def __init__( # noqa: PLR0913
63
+ self,
64
+ db_filename: Optional[str] = None,
65
+ *,
66
+ memory: bool = False,
67
+ auto_commit: bool = True,
68
+ debug: bool = False,
69
+ logger: Optional[logging.Logger] = None,
70
+ reset: bool = False,
71
+ return_local_time: bool = True,
72
+ cache_enabled: bool = False,
73
+ cache_max_size: int = 1000,
74
+ cache_ttl: Optional[int] = None,
75
+ cache_max_memory_mb: Optional[int] = None,
76
+ ) -> None:
77
+ """Initialize a new SqliterDB instance.
78
+
79
+ Args:
80
+ db_filename: The filename of the SQLite database.
81
+ memory: If True, create an in-memory database.
82
+ auto_commit: Whether to automatically commit transactions.
83
+ debug: Whether to enable debug logging.
84
+ logger: Custom logger for debug output.
85
+ reset: Whether to reset the database on initialization. This will
86
+ basically drop all existing tables.
87
+ return_local_time: Whether to return local time for datetime fields.
88
+ cache_enabled: Whether to enable query result caching. Default is
89
+ False.
90
+ cache_max_size: Maximum number of cached queries per table (LRU).
91
+ cache_ttl: Optional time-to-live for cache entries in seconds.
92
+ cache_max_memory_mb: Optional maximum memory usage for cache in
93
+ megabytes. When exceeded, oldest entries are evicted.
94
+
95
+ Raises:
96
+ ValueError: If no filename is provided for a non-memory database.
97
+ """
98
+ if memory:
99
+ self.db_filename = self.MEMORY_DB
100
+ elif db_filename:
101
+ self.db_filename = db_filename
102
+ else:
103
+ err = (
104
+ "Database name must be provided if not using an in-memory "
105
+ "database."
106
+ )
107
+ raise ValueError(err)
108
+ self.auto_commit = auto_commit
109
+ self.debug = debug
110
+ self.logger = logger
111
+ self.conn: Optional[sqlite3.Connection] = None
112
+ self.reset = reset
113
+ self.return_local_time = return_local_time
114
+
115
+ self._in_transaction = False
116
+
117
+ # Initialize cache
118
+ self._cache_enabled = cache_enabled
119
+ self._cache_max_size = cache_max_size
120
+ self._cache_ttl = cache_ttl
121
+ self._cache_max_memory_mb = cache_max_memory_mb
122
+
123
+ # Validate cache parameters
124
+ if self._cache_max_size <= 0:
125
+ msg = "cache_max_size must be greater than 0"
126
+ raise ValueError(msg)
127
+ if self._cache_ttl is not None and self._cache_ttl < 0:
128
+ msg = "cache_ttl must be non-negative"
129
+ raise ValueError(msg)
130
+ if (
131
+ self._cache_max_memory_mb is not None
132
+ and self._cache_max_memory_mb <= 0
133
+ ):
134
+ msg = "cache_max_memory_mb must be greater than 0"
135
+ raise ValueError(msg)
136
+ self._cache: OrderedDict[
137
+ str,
138
+ OrderedDict[
139
+ str,
140
+ tuple[
141
+ Union[BaseDBModel, list[BaseDBModel], None],
142
+ Optional[float],
143
+ ],
144
+ ],
145
+ ] = OrderedDict() # {table: {cache_key: (result, expiration)}}
146
+ self._cache_hits = 0
147
+ self._cache_misses = 0
148
+
149
+ if self.debug:
150
+ self._setup_logger()
151
+
152
+ if self.reset:
153
+ self._reset_database()
154
+
155
+ @property
156
+ def filename(self) -> Optional[str]:
157
+ """Returns the filename of the current database or None if in-memory."""
158
+ return None if self.db_filename == self.MEMORY_DB else self.db_filename
159
+
160
+ @property
161
+ def is_memory(self) -> bool:
162
+ """Returns True if the database is in-memory."""
163
+ return self.db_filename == self.MEMORY_DB
164
+
165
+ @property
166
+ def is_autocommit(self) -> bool:
167
+ """Returns True if auto-commit is enabled."""
168
+ return self.auto_commit
169
+
170
+ @property
171
+ def is_connected(self) -> bool:
172
+ """Returns True if the database is connected, False otherwise."""
173
+ return self.conn is not None
174
+
175
+ @property
176
+ def table_names(self) -> list[str]:
177
+ """Returns a list of all table names in the database.
178
+
179
+ Temporarily connects to the database if not connected and restores
180
+ the connection state afterward.
181
+ """
182
+ was_connected = self.is_connected
183
+ if not was_connected:
184
+ self.connect()
185
+
186
+ if self.conn is None:
187
+ err_msg = "Failed to establish a database connection."
188
+ raise DatabaseConnectionError(err_msg)
189
+
190
+ cursor = self.conn.cursor()
191
+ cursor.execute(
192
+ "SELECT name FROM sqlite_master WHERE type='table' "
193
+ "AND name NOT LIKE 'sqlite_%';"
194
+ )
195
+ tables = [row[0] for row in cursor.fetchall()]
196
+
197
+ # Restore the connection state
198
+ if not was_connected:
199
+ self.close()
200
+
201
+ return tables
202
+
203
+ def _reset_database(self) -> None:
204
+ """Drop all user-created tables in the database."""
205
+ with self.connect() as conn:
206
+ cursor = conn.cursor()
207
+
208
+ # Get all table names, excluding SQLite system tables
209
+ cursor.execute(
210
+ "SELECT name FROM sqlite_master WHERE type='table' "
211
+ "AND name NOT LIKE 'sqlite_%';"
212
+ )
213
+ tables = cursor.fetchall()
214
+
215
+ # Drop each user-created table
216
+ for table in tables:
217
+ cursor.execute(f"DROP TABLE IF EXISTS {table[0]}")
218
+
219
+ conn.commit()
220
+
221
+ if self.debug and self.logger:
222
+ self.logger.debug(
223
+ "Database reset: %s user-created tables dropped.", len(tables)
224
+ )
225
+
226
+ def _setup_logger(self) -> None:
227
+ """Set up the logger for debug output.
228
+
229
+ This method configures a logger for the SqliterDB instance, either
230
+ using an existing logger or creating a new one specifically for
231
+ SQLiter.
232
+ """
233
+ # Check if the root logger is already configured
234
+ root_logger = logging.getLogger()
235
+
236
+ if root_logger.hasHandlers():
237
+ # If the root logger has handlers, use it without modifying the root
238
+ # configuration
239
+ self.logger = root_logger.getChild("sqliter")
240
+ else:
241
+ # If no root logger is configured, set up a new logger specific to
242
+ # SqliterDB
243
+ self.logger = logging.getLogger("sqliter")
244
+
245
+ handler = logging.StreamHandler() # Output to console
246
+ formatter = logging.Formatter(
247
+ "%(levelname)-8s%(message)s"
248
+ ) # Custom format
249
+ handler.setFormatter(formatter)
250
+ self.logger.addHandler(handler)
251
+
252
+ self.logger.setLevel(logging.DEBUG)
253
+ self.logger.propagate = False
254
+
255
+ def _log_sql(self, sql: str, values: list[Any]) -> None:
256
+ """Log the SQL query and its values if debug mode is enabled.
257
+
258
+ The values are inserted into the SQL query string to replace the
259
+ placeholders.
260
+
261
+ Args:
262
+ sql: The SQL query string.
263
+ values: The list of values to be inserted into the query.
264
+ """
265
+ if self.debug and self.logger:
266
+ formatted_sql = sql
267
+ for value in values:
268
+ if isinstance(value, str):
269
+ formatted_sql = formatted_sql.replace("?", f"'{value}'", 1)
270
+ else:
271
+ formatted_sql = formatted_sql.replace("?", str(value), 1)
272
+
273
+ self.logger.debug("Executing SQL: %s", formatted_sql)
274
+
275
+ def connect(self) -> sqlite3.Connection:
276
+ """Establish a connection to the SQLite database.
277
+
278
+ Returns:
279
+ The SQLite connection object.
280
+
281
+ Raises:
282
+ DatabaseConnectionError: If unable to connect to the database.
283
+ """
284
+ if not self.conn:
285
+ try:
286
+ self.conn = sqlite3.connect(self.db_filename)
287
+ # Enable foreign key constraint enforcement
288
+ self.conn.execute("PRAGMA foreign_keys = ON")
289
+ except sqlite3.Error as exc:
290
+ raise DatabaseConnectionError(self.db_filename) from exc
291
+ return self.conn
292
+
293
+ def _cache_get(
294
+ self,
295
+ table_name: str,
296
+ cache_key: str,
297
+ ) -> tuple[bool, Any]:
298
+ """Get cached result if valid and not expired.
299
+
300
+ Args:
301
+ table_name: The name of the table.
302
+ cache_key: The cache key for the query.
303
+
304
+ Returns:
305
+ A tuple of (hit, result) where hit is True if cache hit,
306
+ False if miss. Result is the cached value (which may be None
307
+ or an empty list) on a hit, or None on a miss.
308
+ """
309
+ if not self._cache_enabled:
310
+ return False, None
311
+ if table_name not in self._cache:
312
+ self._cache_misses += 1
313
+ return False, None
314
+ if cache_key not in self._cache[table_name]:
315
+ self._cache_misses += 1
316
+ return False, None
317
+
318
+ result, expiration = self._cache[table_name][cache_key]
319
+
320
+ # Check TTL expiration
321
+ if expiration is not None and time.time() > expiration:
322
+ self._cache_misses += 1
323
+ del self._cache[table_name][cache_key]
324
+ return False, None
325
+
326
+ # Mark as recently used (LRU)
327
+ self._cache[table_name].move_to_end(cache_key)
328
+ self._cache_hits += 1
329
+ return True, result
330
+
331
+ def _cache_set(
332
+ self,
333
+ table_name: str,
334
+ cache_key: str,
335
+ result: Any, # noqa: ANN401
336
+ ttl: Optional[int] = None,
337
+ ) -> None:
338
+ """Store result in cache with optional expiration.
339
+
340
+ Args:
341
+ table_name: The name of the table.
342
+ cache_key: The cache key for the query.
343
+ result: The result to cache.
344
+ ttl: Optional TTL override for this specific entry.
345
+ """
346
+ if not self._cache_enabled:
347
+ return
348
+
349
+ if table_name not in self._cache:
350
+ self._cache[table_name] = OrderedDict()
351
+
352
+ # Calculate expiration (use query-specific TTL if provided)
353
+ expiration = None
354
+ effective_ttl = ttl if ttl is not None else self._cache_ttl
355
+ if effective_ttl is not None:
356
+ expiration = time.time() + effective_ttl
357
+
358
+ self._cache[table_name][cache_key] = (result, expiration)
359
+ # Mark as most-recently-used
360
+ self._cache[table_name].move_to_end(cache_key)
361
+
362
+ # Enforce memory limit if set
363
+ if self._cache_max_memory_mb is not None:
364
+ max_bytes = self._cache_max_memory_mb * 1024 * 1024
365
+ # Evict LRU entries until under the memory limit
366
+ while (
367
+ table_name in self._cache
368
+ and self._get_table_memory_usage(table_name) > max_bytes
369
+ ):
370
+ self._cache[table_name].popitem(last=False)
371
+
372
+ # Enforce LRU by size
373
+ if len(self._cache[table_name]) > self._cache_max_size:
374
+ self._cache[table_name].popitem(last=False)
375
+
376
+ def _cache_invalidate_table(self, table_name: str) -> None:
377
+ """Clear all cached queries for a specific table.
378
+
379
+ Args:
380
+ table_name: The name of the table to invalidate.
381
+ """
382
+ if not self._cache_enabled:
383
+ return
384
+ self._cache.pop(table_name, None)
385
+
386
+ def _get_table_memory_usage( # noqa: C901
387
+ self, table_name: str
388
+ ) -> int:
389
+ """Calculate the actual memory usage for a table's cache.
390
+
391
+ This method recalculates memory usage on-demand by measuring the
392
+ size of all cached entries including tuple and dict overhead.
393
+
394
+ Args:
395
+ table_name: The name of the table.
396
+
397
+ Returns:
398
+ The memory usage in bytes.
399
+ """
400
+ if table_name not in self._cache:
401
+ return 0
402
+
403
+ total = 0
404
+ seen: dict[int, int] = {}
405
+
406
+ for key, (result, _expiration) in self._cache[table_name].items():
407
+ # Measure the tuple (result, expiration)
408
+ total += sys.getsizeof((result, _expiration))
409
+
410
+ # Measure the dict key (cache_key string)
411
+ total += sys.getsizeof(key)
412
+
413
+ # Dict entry overhead (approximately 72 bytes for a dict entry)
414
+ total += 72
415
+
416
+ # Recursively measure the result object
417
+ def measure_size(obj: Any) -> int: # noqa: C901, ANN401
418
+ """Recursively measure object size with memoization."""
419
+ obj_id = id(obj)
420
+ if obj_id in seen:
421
+ return 0 # Already counted
422
+
423
+ size = sys.getsizeof(obj)
424
+ seen[obj_id] = size
425
+
426
+ # Handle lists
427
+ if isinstance(obj, list):
428
+ for item in obj:
429
+ size += measure_size(item)
430
+
431
+ # Handle Pydantic models - measure their fields
432
+ elif hasattr(type(obj), "model_fields"):
433
+ for field_name in type(obj).model_fields:
434
+ field_value = getattr(obj, field_name, None)
435
+ if field_value is not None:
436
+ size += measure_size(field_value)
437
+ # Also measure __dict__ if present
438
+ if hasattr(obj, "__dict__"):
439
+ size += measure_size(obj.__dict__)
440
+
441
+ # Handle dicts
442
+ elif isinstance(obj, dict):
443
+ for k, v in obj.items():
444
+ size += measure_size(k)
445
+ size += measure_size(v)
446
+
447
+ # Handle sets and tuples
448
+ elif isinstance(obj, (set, tuple)):
449
+ for item in obj:
450
+ size += measure_size(item)
451
+
452
+ return size
453
+
454
+ total += measure_size(result)
455
+
456
+ return total
457
+
458
+ def get_cache_stats(self) -> dict[str, int | float]:
459
+ """Get cache performance statistics.
460
+
461
+ Returns:
462
+ A dictionary containing cache statistics with keys:
463
+ - hits: Number of cache hits
464
+ - misses: Number of cache misses
465
+ - total: Total number of cache lookups
466
+ - hit_rate: Cache hit rate as a percentage (0-100)
467
+ """
468
+ total = self._cache_hits + self._cache_misses
469
+ hit_rate = (self._cache_hits / total * 100) if total > 0 else 0.0
470
+ return {
471
+ "hits": self._cache_hits,
472
+ "misses": self._cache_misses,
473
+ "total": total,
474
+ "hit_rate": round(hit_rate, 2),
475
+ }
476
+
477
+ def close(self) -> None:
478
+ """Close the database connection.
479
+
480
+ This method commits any pending changes if auto_commit is True,
481
+ then closes the connection. If the connection is already closed or does
482
+ not exist, this method silently does nothing.
483
+ """
484
+ if self.conn:
485
+ self._maybe_commit()
486
+ self.conn.close()
487
+ self.conn = None
488
+ self._cache.clear()
489
+ self._cache_hits = 0
490
+ self._cache_misses = 0
491
+
492
+ def commit(self) -> None:
493
+ """Commit the current transaction.
494
+
495
+ This method explicitly commits any pending changes to the database.
496
+ """
497
+ if self.conn:
498
+ self.conn.commit()
499
+
500
def _build_field_definitions(
    self,
    model_class: type[BaseDBModel],
    primary_key: str,
) -> tuple[list[str], list[str], list[str]]:
    """Assemble the column and constraint clauses for CREATE TABLE.

    The primary key always comes first as an auto-incrementing integer;
    every other model field becomes either a foreign key column (plus a
    table-level constraint) or a regular column.

    Args:
        model_class: The Pydantic model class.
        primary_key: The name of the primary key field.

    Returns:
        A tuple of (fields, foreign_keys, fk_columns) where:
            - fields: List of column definitions
            - foreign_keys: List of FK constraint definitions
            - fk_columns: List of FK column names for index creation
    """
    column_defs = [f'"{primary_key}" INTEGER PRIMARY KEY AUTOINCREMENT']
    fk_constraints: list[str] = []
    fk_column_names: list[str] = []

    for name, info in model_class.model_fields.items():
        if name == primary_key:
            continue

        fk_info = get_foreign_key_info(info)
        if fk_info is None:
            column_defs.append(self._build_regular_field(name, info))
            continue

        column_def, constraint = self._build_fk_field(name, fk_info)
        column_defs.append(column_def)
        fk_constraints.append(constraint)
        fk_column_names.append(fk_info.db_column or name)

    return column_defs, fk_constraints, fk_column_names
535
+
536
+ def _build_fk_field(
537
+ self, field_name: str, fk_info: ForeignKeyInfo
538
+ ) -> tuple[str, str]:
539
+ """Build FK column definition and constraint.
540
+
541
+ Args:
542
+ field_name: The name of the field.
543
+ fk_info: The ForeignKeyInfo metadata.
544
+
545
+ Returns:
546
+ A tuple of (column_def, fk_constraint).
547
+ """
548
+ column_name = fk_info.db_column or field_name
549
+ null_str = "" if fk_info.null else "NOT NULL"
550
+ unique_str = "UNIQUE" if fk_info.unique else ""
551
+
552
+ field_def = f'"{column_name}" INTEGER {null_str} {unique_str}'
553
+ column_def = " ".join(field_def.split())
554
+
555
+ target_table = fk_info.to_model.get_table_name()
556
+ fk_constraint = (
557
+ f'FOREIGN KEY ("{column_name}") '
558
+ f'REFERENCES "{target_table}"("pk") '
559
+ f"ON DELETE {fk_info.on_delete} "
560
+ f"ON UPDATE {fk_info.on_update}"
561
+ )
562
+
563
+ return column_def, fk_constraint
564
+
565
def _build_regular_field(
    self, field_name: str, field_info: FieldInfo
) -> str:
    """Build the column definition for a non-foreign-key field.

    The column is marked UNIQUE when the field's ``json_schema_extra``
    is a dict carrying a truthy ``"unique"`` entry.

    Args:
        field_name: The name of the field.
        field_info: The Pydantic field info.

    Returns:
        The column definition string.
    """
    sqlite_type = infer_sqlite_type(field_info.annotation)

    extra = getattr(field_info, "json_schema_extra", None)
    wants_unique = isinstance(extra, dict) and extra.get("unique", False)
    unique_constraint = "UNIQUE" if wants_unique else ""

    return f'"{field_name}" {sqlite_type} {unique_constraint}'.strip()
587
+
588
def create_table(
    self,
    model_class: type[BaseDBModel],
    *,
    exists_ok: bool = True,
    force: bool = False,
) -> None:
    """Create a table in the database based on the given model class.

    After the table itself, an index is created for every foreign key
    column, followed by the regular and unique indexes declared on the
    model's ``Meta`` class.

    Args:
        model_class: The Pydantic model class representing the table.
        exists_ok: If True, do not raise an error if the table already
            exists. Default is True which is the original behavior.
        force: If True, drop the table if it exists before creating.
            Defaults to False.

    Raises:
        TableCreationError: If there's an error creating the table.
        SqlExecutionError: If dropping the table (``force``) or creating
            one of the indexes fails.
    """
    table_name = model_class.get_table_name()
    primary_key = model_class.get_primary_key()

    if force:
        # Quote the name exactly as the CREATE statement below does, so
        # both statements always address the same table.
        self._execute_sql(f'DROP TABLE IF EXISTS "{table_name}"')
        # Results cached for the dropped table are no longer valid.
        self._cache_invalidate_table(table_name)

    fields, foreign_keys, fk_columns = self._build_field_definitions(
        model_class, primary_key
    )

    # Combine field definitions and FK constraints
    all_definitions = fields + foreign_keys

    create_str = (
        "CREATE TABLE IF NOT EXISTS" if exists_ok else "CREATE TABLE"
    )

    create_table_sql = f"""
        {create_str} "{table_name}" (
            {", ".join(all_definitions)}
        )
    """

    if self.debug:
        self._log_sql(create_table_sql, [])

    try:
        with self.connect() as conn:
            cursor = conn.cursor()
            cursor.execute(create_table_sql)
            conn.commit()
    except sqlite3.Error as exc:
        raise TableCreationError(table_name) from exc

    # Create indexes for FK columns
    for column_name in fk_columns:
        index_sql = (
            f'CREATE INDEX IF NOT EXISTS "idx_{table_name}_{column_name}" '
            f'ON "{table_name}" ("{column_name}")'
        )
        self._execute_sql(index_sql)

    # Create regular indexes
    if hasattr(model_class.Meta, "indexes"):
        self._create_indexes(
            model_class, model_class.Meta.indexes, unique=False
        )

    # Create unique indexes
    if hasattr(model_class.Meta, "unique_indexes"):
        self._create_indexes(
            model_class, model_class.Meta.unique_indexes, unique=True
        )
662
+
663
+ def _create_indexes(
664
+ self,
665
+ model_class: type[BaseDBModel],
666
+ indexes: list[Union[str, tuple[str]]],
667
+ *,
668
+ unique: bool = False,
669
+ ) -> None:
670
+ """Helper method to create regular or unique indexes.
671
+
672
+ Args:
673
+ model_class: The model class defining the table.
674
+ indexes: List of fields or tuples of fields to create indexes for.
675
+ unique: If True, creates UNIQUE indexes; otherwise, creates regular
676
+ indexes.
677
+
678
+ Raises:
679
+ InvalidIndexError: If any fields specified for indexing do not exist
680
+ in the model.
681
+ """
682
+ valid_fields = set(
683
+ model_class.model_fields.keys()
684
+ ) # Get valid fields from the model
685
+
686
+ for index in indexes:
687
+ # Handle multiple fields in tuple form
688
+ fields = list(index) if isinstance(index, tuple) else [index]
689
+
690
+ # Check if all fields exist in the model
691
+ invalid_fields = [
692
+ field for field in fields if field not in valid_fields
693
+ ]
694
+ if invalid_fields:
695
+ raise InvalidIndexError(invalid_fields, model_class.__name__)
696
+
697
+ # Build the SQL string
698
+ index_name = "_".join(fields)
699
+ index_postfix = "_unique" if unique else ""
700
+ index_type = " UNIQUE " if unique else " "
701
+
702
+ # Quote field names for index creation
703
+ quoted_fields = ", ".join(f'"{field}"' for field in fields)
704
+
705
+ create_index_sql = (
706
+ f"CREATE{index_type}INDEX IF NOT EXISTS "
707
+ f"idx_{model_class.get_table_name()}"
708
+ f"_{index_name}{index_postfix} "
709
+ f'ON "{model_class.get_table_name()}" ({quoted_fields})'
710
+ )
711
+ self._execute_sql(create_index_sql)
712
+
713
+ def _execute_sql(self, sql: str) -> None:
714
+ """Execute an SQL statement.
715
+
716
+ Args:
717
+ sql: The SQL statement to execute.
718
+
719
+ Raises:
720
+ SqlExecutionError: If the SQL execution fails.
721
+ """
722
+ if self.debug:
723
+ self._log_sql(sql, [])
724
+
725
+ try:
726
+ with self.connect() as conn:
727
+ cursor = conn.cursor()
728
+ cursor.execute(sql)
729
+ conn.commit()
730
+ except (sqlite3.Error, sqlite3.Warning) as exc:
731
+ raise SqlExecutionError(sql) from exc
732
+
733
+ def drop_table(self, model_class: type[BaseDBModel]) -> None:
734
+ """Drop the table associated with the given model class.
735
+
736
+ Args:
737
+ model_class: The model class for which to drop the table.
738
+
739
+ Raises:
740
+ TableDeletionError: If there's an error dropping the table.
741
+ """
742
+ table_name = model_class.get_table_name()
743
+ drop_table_sql = f"DROP TABLE IF EXISTS {table_name}"
744
+
745
+ if self.debug:
746
+ self._log_sql(drop_table_sql, [])
747
+
748
+ try:
749
+ with self.connect() as conn:
750
+ cursor = conn.cursor()
751
+ cursor.execute(drop_table_sql)
752
+ self.commit()
753
+ except sqlite3.Error as exc:
754
+ raise TableDeletionError(table_name) from exc
755
+
756
+ def _maybe_commit(self) -> None:
757
+ """Commit changes if auto_commit is enabled.
758
+
759
+ This method is called after operations that modify the database,
760
+ committing changes only if auto_commit is set to True.
761
+ """
762
+ if not self._in_transaction and self.auto_commit and self.conn:
763
+ self.conn.commit()
764
+
765
+ def _set_insert_timestamps(
766
+ self, model_instance: T, *, timestamp_override: bool
767
+ ) -> None:
768
+ """Set created_at and updated_at timestamps for insert.
769
+
770
+ Args:
771
+ model_instance: The model instance to update.
772
+ timestamp_override: If True, respect provided non-zero values.
773
+ """
774
+ current_timestamp = int(time.time())
775
+
776
+ if not timestamp_override:
777
+ model_instance.created_at = current_timestamp
778
+ model_instance.updated_at = current_timestamp
779
+ else:
780
+ if model_instance.created_at == 0:
781
+ model_instance.created_at = current_timestamp
782
+ if model_instance.updated_at == 0:
783
+ model_instance.updated_at = current_timestamp
784
+
785
+ def insert(
786
+ self, model_instance: T, *, timestamp_override: bool = False
787
+ ) -> T:
788
+ """Insert a new record into the database.
789
+
790
+ Args:
791
+ model_instance: The instance of the model class to insert.
792
+ timestamp_override: If True, override the created_at and updated_at
793
+ timestamps with provided values. Default is False. If the values
794
+ are not provided, they will be set to the current time as
795
+ normal. Without this flag, the timestamps will always be set to
796
+ the current time, even if provided.
797
+
798
+ Returns:
799
+ The updated model instance with the primary key (pk) set.
800
+
801
+ Raises:
802
+ RecordInsertionError: If an error occurs during the insertion.
803
+ """
804
+ model_class = type(model_instance)
805
+ table_name = model_class.get_table_name()
806
+
807
+ self._set_insert_timestamps(
808
+ model_instance, timestamp_override=timestamp_override
809
+ )
810
+
811
+ # Get the data from the model
812
+ data = model_instance.model_dump()
813
+
814
+ # Serialize the data
815
+ for field_name, value in list(data.items()):
816
+ data[field_name] = model_instance.serialize_field(value)
817
+
818
+ # remove the primary key field if it exists, otherwise we'll get
819
+ # TypeErrors as multiple primary keys will exist
820
+ if data.get("pk", None) == 0:
821
+ data.pop("pk")
822
+
823
+ fields = ", ".join(data.keys())
824
+ placeholders = ", ".join(
825
+ ["?" if value is not None else "NULL" for value in data.values()]
826
+ )
827
+ values = tuple(value for value in data.values() if value is not None)
828
+
829
+ insert_sql = f"""
830
+ INSERT INTO {table_name} ({fields})
831
+ VALUES ({placeholders})
832
+ """ # noqa: S608
833
+
834
+ try:
835
+ with self.connect() as conn:
836
+ cursor = conn.cursor()
837
+ cursor.execute(insert_sql, values)
838
+ self._maybe_commit()
839
+
840
+ except sqlite3.IntegrityError as exc:
841
+ # Check for foreign key constraint violation
842
+ if "FOREIGN KEY constraint failed" in str(exc):
843
+ fk_operation = "insert"
844
+ fk_reason = "does not exist in referenced table"
845
+ raise ForeignKeyConstraintError(
846
+ fk_operation, fk_reason
847
+ ) from exc
848
+ raise RecordInsertionError(table_name) from exc
849
+ except sqlite3.Error as exc:
850
+ raise RecordInsertionError(table_name) from exc
851
+ else:
852
+ self._cache_invalidate_table(table_name)
853
+ data.pop("pk", None)
854
+ # Deserialize each field before creating the model instance
855
+ deserialized_data = {}
856
+ for field_name, value in data.items():
857
+ deserialized_data[field_name] = model_class.deserialize_field(
858
+ field_name, value, return_local_time=self.return_local_time
859
+ )
860
+ return model_class(pk=cursor.lastrowid, **deserialized_data)
861
+
862
+ def get(
863
+ self, model_class: type[BaseDBModel], primary_key_value: int
864
+ ) -> BaseDBModel | None:
865
+ """Retrieve a single record from the database by its primary key.
866
+
867
+ Args:
868
+ model_class: The Pydantic model class representing the table.
869
+ primary_key_value: The value of the primary key to look up.
870
+
871
+ Returns:
872
+ An instance of the model class if found, None otherwise.
873
+
874
+ Raises:
875
+ RecordFetchError: If there's an error fetching the record.
876
+ """
877
+ table_name = model_class.get_table_name()
878
+ primary_key = model_class.get_primary_key()
879
+
880
+ fields = ", ".join(model_class.model_fields)
881
+
882
+ select_sql = f"""
883
+ SELECT {fields} FROM {table_name} WHERE {primary_key} = ?
884
+ """ # noqa: S608
885
+
886
+ try:
887
+ with self.connect() as conn:
888
+ cursor = conn.cursor()
889
+ cursor.execute(select_sql, (primary_key_value,))
890
+ result = cursor.fetchone()
891
+
892
+ if result:
893
+ result_dict = {
894
+ field: result[idx]
895
+ for idx, field in enumerate(model_class.model_fields)
896
+ }
897
+ # Deserialize each field before creating the model instance
898
+ deserialized_data = {}
899
+ for field_name, value in result_dict.items():
900
+ deserialized_data[field_name] = (
901
+ model_class.deserialize_field(
902
+ field_name,
903
+ value,
904
+ return_local_time=self.return_local_time,
905
+ )
906
+ )
907
+ return model_class(**deserialized_data)
908
+ except sqlite3.Error as exc:
909
+ raise RecordFetchError(table_name) from exc
910
+ else:
911
+ return None
912
+
913
+ def update(self, model_instance: BaseDBModel) -> None:
914
+ """Update an existing record in the database.
915
+
916
+ Args:
917
+ model_instance: An instance of a Pydantic model to be updated.
918
+
919
+ Raises:
920
+ RecordUpdateError: If there's an error updating the record or if it
921
+ is not found.
922
+ """
923
+ model_class = type(model_instance)
924
+ table_name = model_class.get_table_name()
925
+ primary_key = model_class.get_primary_key()
926
+
927
+ # Set updated_at timestamp
928
+ current_timestamp = int(time.time())
929
+ model_instance.updated_at = current_timestamp
930
+
931
+ # Get the data and serialize any datetime/date fields
932
+ data = model_instance.model_dump()
933
+ for field_name, value in list(data.items()):
934
+ data[field_name] = model_instance.serialize_field(value)
935
+
936
+ # Remove the primary key from the update data
937
+ primary_key_value = data.pop(primary_key)
938
+
939
+ # Create the SQL using the processed data
940
+ fields = ", ".join(f"{field} = ?" for field in data)
941
+ values = tuple(data.values())
942
+
943
+ update_sql = f"""
944
+ UPDATE {table_name}
945
+ SET {fields}
946
+ WHERE {primary_key} = ?
947
+ """ # noqa: S608
948
+
949
+ try:
950
+ with self.connect() as conn:
951
+ cursor = conn.cursor()
952
+ cursor.execute(update_sql, (*values, primary_key_value))
953
+
954
+ # Check if any rows were updated
955
+ if cursor.rowcount == 0:
956
+ raise RecordNotFoundError(primary_key_value)
957
+
958
+ self._maybe_commit()
959
+ self._cache_invalidate_table(table_name)
960
+
961
+ except sqlite3.Error as exc:
962
+ raise RecordUpdateError(table_name) from exc
963
+
964
+ def delete(
965
+ self, model_class: type[BaseDBModel], primary_key_value: str
966
+ ) -> None:
967
+ """Delete a record from the database by its primary key.
968
+
969
+ Args:
970
+ model_class: The Pydantic model class representing the table.
971
+ primary_key_value: The value of the primary key of the record to
972
+ delete.
973
+
974
+ Raises:
975
+ RecordDeletionError: If there's an error deleting the record.
976
+ RecordNotFoundError: If the record to delete is not found.
977
+ """
978
+ table_name = model_class.get_table_name()
979
+ primary_key = model_class.get_primary_key()
980
+
981
+ delete_sql = f"""
982
+ DELETE FROM {table_name} WHERE {primary_key} = ?
983
+ """ # noqa: S608
984
+
985
+ try:
986
+ with self.connect() as conn:
987
+ cursor = conn.cursor()
988
+ cursor.execute(delete_sql, (primary_key_value,))
989
+
990
+ if cursor.rowcount == 0:
991
+ raise RecordNotFoundError(primary_key_value)
992
+ self._maybe_commit()
993
+ self._cache_invalidate_table(table_name)
994
+ except sqlite3.IntegrityError as exc:
995
+ # Check for foreign key constraint violation (RESTRICT)
996
+ if "FOREIGN KEY constraint failed" in str(exc):
997
+ fk_operation = "delete"
998
+ fk_reason = "is still referenced by other records"
999
+ raise ForeignKeyConstraintError(
1000
+ fk_operation, fk_reason
1001
+ ) from exc
1002
+ raise RecordDeletionError(table_name) from exc
1003
+ except sqlite3.Error as exc:
1004
+ raise RecordDeletionError(table_name) from exc
1005
+
1006
def select(
    self,
    model_class: type[T],
    fields: Optional[list[str]] = None,
    exclude: Optional[list[str]] = None,
) -> QueryBuilder[T]:
    """Start a query against the table of the given model.

    Args:
        model_class: The Pydantic model class representing the table.
        fields: Optional list of fields to include in the query.
        exclude: Optional list of fields to exclude from the query.

    Returns:
        A QueryBuilder instance for further query construction.
    """
    builder: QueryBuilder[T] = QueryBuilder(self, model_class, fields)

    # Exclusions are only applied when at least one field is listed.
    if not exclude:
        return builder

    builder.exclude(exclude)
    return builder
1029
+
1030
+ # --- Context manager methods ---
1031
+ def __enter__(self) -> Self:
1032
+ """Enter the runtime context for the SqliterDB instance.
1033
+
1034
+ This method is called when entering a 'with' statement. It ensures
1035
+ that a database connection is established.
1036
+
1037
+ Note that this method should never be called explicitly, but will be
1038
+ called by the 'with' statement when entering the context.
1039
+
1040
+ Returns:
1041
+ The SqliterDB instance.
1042
+
1043
+ """
1044
+ self.connect()
1045
+ self._in_transaction = True
1046
+ return self
1047
+
1048
+ def __exit__(
1049
+ self,
1050
+ exc_type: Optional[type[BaseException]],
1051
+ exc_value: Optional[BaseException],
1052
+ traceback: Optional[TracebackType],
1053
+ ) -> None:
1054
+ """Exit the runtime context for the SqliterDB instance.
1055
+
1056
+ This method is called when exiting a 'with' statement. It handles
1057
+ committing or rolling back transactions based on whether an exception
1058
+ occurred, and closes the database connection.
1059
+
1060
+ Args:
1061
+ exc_type: The type of the exception that caused the context to be
1062
+ exited, or None if no exception was raised.
1063
+ exc_value: The instance of the exception that caused the context
1064
+ to be exited, or None if no exception was raised.
1065
+ traceback: A traceback object encoding the stack trace, or None
1066
+ if no exception was raised.
1067
+
1068
+ Note that this method should never be called explicitly, but will be
1069
+ called by the 'with' statement when exiting the context.
1070
+
1071
+ """
1072
+ if self.conn:
1073
+ try:
1074
+ if exc_type:
1075
+ # Roll back the transaction if there was an exception
1076
+ self.conn.rollback()
1077
+ else:
1078
+ self.conn.commit()
1079
+ finally:
1080
+ # Close the connection and reset the instance variable
1081
+ self.conn.close()
1082
+ self.conn = None
1083
+ self._in_transaction = False
1084
+ # Clear cache when exiting context
1085
+ self._cache.clear()
1086
+ self._cache_hits = 0
1087
+ self._cache_misses = 0