sqliter-py 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqliter/__init__.py +9 -0
- sqliter/constants.py +45 -0
- sqliter/exceptions.py +198 -0
- sqliter/helpers.py +100 -0
- sqliter/model/__init__.py +46 -0
- sqliter/model/foreign_key.py +153 -0
- sqliter/model/model.py +236 -0
- sqliter/model/unique.py +28 -0
- sqliter/py.typed +0 -0
- sqliter/query/__init__.py +9 -0
- sqliter/query/query.py +891 -0
- sqliter/sqliter.py +1087 -0
- sqliter_py-0.12.0.dist-info/METADATA +209 -0
- sqliter_py-0.12.0.dist-info/RECORD +15 -0
- sqliter_py-0.12.0.dist-info/WHEEL +4 -0
sqliter/sqliter.py
ADDED
@@ -0,0 +1,1087 @@
"""Core module for SQLiter, providing the main database interaction class.

This module defines the SqliterDB class, which serves as the primary
interface for all database operations in SQLiter. It handles connection
management, table creation, and CRUD operations, bridging the gap between
Pydantic models and SQLite database interactions.
"""

from __future__ import annotations

import logging
import sqlite3
import sys
import time
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

from typing_extensions import Self

from sqliter.exceptions import (
    DatabaseConnectionError,
    ForeignKeyConstraintError,
    InvalidIndexError,
    RecordDeletionError,
    RecordFetchError,
    RecordInsertionError,
    RecordNotFoundError,
    RecordUpdateError,
    SqlExecutionError,
    TableCreationError,
    TableDeletionError,
)
from sqliter.helpers import infer_sqlite_type
from sqliter.model.foreign_key import ForeignKeyInfo, get_foreign_key_info
from sqliter.query.query import QueryBuilder

if TYPE_CHECKING:  # pragma: no cover
    from types import TracebackType

    from pydantic.fields import FieldInfo

    from sqliter.model.model import BaseDBModel

T = TypeVar("T", bound="BaseDBModel")


class SqliterDB:
    """Main class for interacting with SQLite databases.

    This class provides methods for connecting to a SQLite database,
    creating tables, and performing CRUD operations.

    Arguments:
        db_filename (str): The filename of the SQLite database.
        auto_commit (bool): Whether to automatically commit transactions.
        debug (bool): Whether to enable debug logging.
        logger (Optional[logging.Logger]): Custom logger for debug output.
    """

    MEMORY_DB = ":memory:"

    def __init__(  # noqa: PLR0913
        self,
        db_filename: Optional[str] = None,
        *,
        memory: bool = False,
        auto_commit: bool = True,
        debug: bool = False,
        logger: Optional[logging.Logger] = None,
        reset: bool = False,
        return_local_time: bool = True,
        cache_enabled: bool = False,
        cache_max_size: int = 1000,
        cache_ttl: Optional[int] = None,
        cache_max_memory_mb: Optional[int] = None,
    ) -> None:
        """Initialize a new SqliterDB instance.

        Args:
            db_filename: The filename of the SQLite database.
            memory: If True, create an in-memory database.
            auto_commit: Whether to automatically commit transactions.
            debug: Whether to enable debug logging.
            logger: Custom logger for debug output.
            reset: Whether to reset the database on initialization, dropping
                all existing user tables.
            return_local_time: Whether to return local time for datetime
                fields.
            cache_enabled: Whether to enable query result caching. Default is
                False.
            cache_max_size: Maximum number of cached queries per table (LRU).
            cache_ttl: Optional time-to-live for cache entries in seconds.
            cache_max_memory_mb: Optional maximum memory usage for cache in
                megabytes. When exceeded, the oldest entries are evicted.

        Raises:
            ValueError: If no filename is provided for a non-memory database.
        """
        if memory:
            self.db_filename = self.MEMORY_DB
        elif db_filename:
            self.db_filename = db_filename
        else:
            err = (
                "Database name must be provided if not using an in-memory "
                "database."
            )
            raise ValueError(err)
        self.auto_commit = auto_commit
        self.debug = debug
        self.logger = logger
        self.conn: Optional[sqlite3.Connection] = None
        self.reset = reset
        self.return_local_time = return_local_time

        self._in_transaction = False

        # Initialize cache
        self._cache_enabled = cache_enabled
        self._cache_max_size = cache_max_size
        self._cache_ttl = cache_ttl
        self._cache_max_memory_mb = cache_max_memory_mb

        # Validate cache parameters
        if self._cache_max_size <= 0:
            msg = "cache_max_size must be greater than 0"
            raise ValueError(msg)
        if self._cache_ttl is not None and self._cache_ttl < 0:
            msg = "cache_ttl must be non-negative"
            raise ValueError(msg)
        if (
            self._cache_max_memory_mb is not None
            and self._cache_max_memory_mb <= 0
        ):
            msg = "cache_max_memory_mb must be greater than 0"
            raise ValueError(msg)
        self._cache: OrderedDict[
            str,
            OrderedDict[
                str,
                tuple[
                    Union[BaseDBModel, list[BaseDBModel], None],
                    Optional[float],
                ],
            ],
        ] = OrderedDict()  # {table: {cache_key: (result, expiration)}}
        self._cache_hits = 0
        self._cache_misses = 0

        if self.debug:
            self._setup_logger()

        if self.reset:
            self._reset_database()

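    # Illustrative construction sketches (editor's examples, not part of the
    # package source; they use only the parameters defined in __init__ above):
    #
    #     db = SqliterDB("app.db")                  # file-backed, auto-commit
    #     mem = SqliterDB(memory=True, debug=True)  # in-memory, SQL logged
    #     cached = SqliterDB(
    #         "app.db",
    #         cache_enabled=True,
    #         cache_max_size=500,
    #         cache_ttl=60,  # cached results expire after 60 seconds
    #     )
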
    @property
    def filename(self) -> Optional[str]:
        """Returns the filename of the current database or None if in-memory."""
        return None if self.db_filename == self.MEMORY_DB else self.db_filename

    @property
    def is_memory(self) -> bool:
        """Returns True if the database is in-memory."""
        return self.db_filename == self.MEMORY_DB

    @property
    def is_autocommit(self) -> bool:
        """Returns True if auto-commit is enabled."""
        return self.auto_commit

    @property
    def is_connected(self) -> bool:
        """Returns True if the database is connected, False otherwise."""
        return self.conn is not None

    @property
    def table_names(self) -> list[str]:
        """Returns a list of all table names in the database.

        Temporarily connects to the database if not connected and restores
        the connection state afterward.
        """
        was_connected = self.is_connected
        if not was_connected:
            self.connect()

        if self.conn is None:
            err_msg = "Failed to establish a database connection."
            raise DatabaseConnectionError(err_msg)

        cursor = self.conn.cursor()
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' "
            "AND name NOT LIKE 'sqlite_%';"
        )
        tables = [row[0] for row in cursor.fetchall()]

        # Restore the connection state
        if not was_connected:
            self.close()

        return tables

    def _reset_database(self) -> None:
        """Drop all user-created tables in the database."""
        with self.connect() as conn:
            cursor = conn.cursor()

            # Get all table names, excluding SQLite system tables
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' "
                "AND name NOT LIKE 'sqlite_%';"
            )
            tables = cursor.fetchall()

            # Drop each user-created table
            for table in tables:
                cursor.execute(f"DROP TABLE IF EXISTS {table[0]}")

            conn.commit()

        if self.debug and self.logger:
            self.logger.debug(
                "Database reset: %s user-created tables dropped.", len(tables)
            )

    def _setup_logger(self) -> None:
        """Set up the logger for debug output.

        This method configures a logger for the SqliterDB instance, either
        using an existing logger or creating a new one specifically for
        SQLiter.
        """
        # Check if the root logger is already configured
        root_logger = logging.getLogger()

        if root_logger.hasHandlers():
            # If the root logger has handlers, use it without modifying the
            # root configuration
            self.logger = root_logger.getChild("sqliter")
        else:
            # If no root logger is configured, set up a new logger specific
            # to SqliterDB
            self.logger = logging.getLogger("sqliter")

            handler = logging.StreamHandler()  # Output to console
            formatter = logging.Formatter(
                "%(levelname)-8s%(message)s"
            )  # Custom format
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        self.logger.setLevel(logging.DEBUG)
        self.logger.propagate = False

    def _log_sql(self, sql: str, values: list[Any]) -> None:
        """Log the SQL query and its values if debug mode is enabled.

        The values are inserted into the SQL query string to replace the
        placeholders.

        Args:
            sql: The SQL query string.
            values: The list of values to be inserted into the query.
        """
        if self.debug and self.logger:
            formatted_sql = sql
            for value in values:
                if isinstance(value, str):
                    formatted_sql = formatted_sql.replace("?", f"'{value}'", 1)
                else:
                    formatted_sql = formatted_sql.replace("?", str(value), 1)

            self.logger.debug("Executing SQL: %s", formatted_sql)

    def connect(self) -> sqlite3.Connection:
        """Establish a connection to the SQLite database.

        Returns:
            The SQLite connection object.

        Raises:
            DatabaseConnectionError: If unable to connect to the database.
        """
        if not self.conn:
            try:
                self.conn = sqlite3.connect(self.db_filename)
                # Enable foreign key constraint enforcement
                self.conn.execute("PRAGMA foreign_keys = ON")
            except sqlite3.Error as exc:
                raise DatabaseConnectionError(self.db_filename) from exc
        return self.conn

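    # Editor's note (illustrative sketch, not package source): SQLite ships
    # with foreign-key enforcement disabled, and the PRAGMA applies per
    # connection, which is why connect() issues it for every new connection:
    #
    #     conn = sqlite3.connect(":memory:")
    #     conn.execute("PRAGMA foreign_keys = ON")
    #     assert conn.execute("PRAGMA foreign_keys").fetchone() == (1,)
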
    def _cache_get(
        self,
        table_name: str,
        cache_key: str,
    ) -> tuple[bool, Any]:
        """Get cached result if valid and not expired.

        Args:
            table_name: The name of the table.
            cache_key: The cache key for the query.

        Returns:
            A tuple of (hit, result) where hit is True if cache hit,
            False if miss. Result is the cached value (which may be None
            or an empty list) on a hit, or None on a miss.
        """
        if not self._cache_enabled:
            return False, None
        if table_name not in self._cache:
            self._cache_misses += 1
            return False, None
        if cache_key not in self._cache[table_name]:
            self._cache_misses += 1
            return False, None

        result, expiration = self._cache[table_name][cache_key]

        # Check TTL expiration
        if expiration is not None and time.time() > expiration:
            self._cache_misses += 1
            del self._cache[table_name][cache_key]
            return False, None

        # Mark as recently used (LRU)
        self._cache[table_name].move_to_end(cache_key)
        self._cache_hits += 1
        return True, result

    def _cache_set(
        self,
        table_name: str,
        cache_key: str,
        result: Any,  # noqa: ANN401
        ttl: Optional[int] = None,
    ) -> None:
        """Store result in cache with optional expiration.

        Args:
            table_name: The name of the table.
            cache_key: The cache key for the query.
            result: The result to cache.
            ttl: Optional TTL override for this specific entry.
        """
        if not self._cache_enabled:
            return

        if table_name not in self._cache:
            self._cache[table_name] = OrderedDict()

        # Calculate expiration (use query-specific TTL if provided)
        expiration = None
        effective_ttl = ttl if ttl is not None else self._cache_ttl
        if effective_ttl is not None:
            expiration = time.time() + effective_ttl

        self._cache[table_name][cache_key] = (result, expiration)
        # Mark as most-recently-used
        self._cache[table_name].move_to_end(cache_key)

        # Enforce memory limit if set
        if self._cache_max_memory_mb is not None:
            max_bytes = self._cache_max_memory_mb * 1024 * 1024
            # Evict LRU entries until under the memory limit
            while (
                table_name in self._cache
                and self._get_table_memory_usage(table_name) > max_bytes
            ):
                self._cache[table_name].popitem(last=False)

        # Enforce LRU by size
        if len(self._cache[table_name]) > self._cache_max_size:
            self._cache[table_name].popitem(last=False)

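    # Editor's sketch (not package source): the LRU behaviour above rests on
    # two OrderedDict operations -- move_to_end() marks a key as most
    # recently used, and popitem(last=False) evicts the least recently used:
    #
    #     from collections import OrderedDict
    #     lru = OrderedDict(a=1, b=2, c=3)
    #     lru.move_to_end("a")     # order is now b, c, a
    #     lru.popitem(last=False)  # evicts ("b", 2), the oldest entry
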
    def _cache_invalidate_table(self, table_name: str) -> None:
        """Clear all cached queries for a specific table.

        Args:
            table_name: The name of the table to invalidate.
        """
        if not self._cache_enabled:
            return
        self._cache.pop(table_name, None)

    def _get_table_memory_usage(  # noqa: C901
        self, table_name: str
    ) -> int:
        """Calculate the actual memory usage for a table's cache.

        This method recalculates memory usage on-demand by measuring the
        size of all cached entries including tuple and dict overhead.

        Args:
            table_name: The name of the table.

        Returns:
            The memory usage in bytes.
        """
        if table_name not in self._cache:
            return 0

        total = 0
        seen: dict[int, int] = {}

        for key, (result, _expiration) in self._cache[table_name].items():
            # Measure the tuple (result, expiration)
            total += sys.getsizeof((result, _expiration))

            # Measure the dict key (cache_key string)
            total += sys.getsizeof(key)

            # Dict entry overhead (approximately 72 bytes for a dict entry)
            total += 72

            # Recursively measure the result object
            def measure_size(obj: Any) -> int:  # noqa: C901, ANN401
                """Recursively measure object size with memoization."""
                obj_id = id(obj)
                if obj_id in seen:
                    return 0  # Already counted

                size = sys.getsizeof(obj)
                seen[obj_id] = size

                # Handle lists
                if isinstance(obj, list):
                    for item in obj:
                        size += measure_size(item)

                # Handle Pydantic models - measure their fields
                elif hasattr(type(obj), "model_fields"):
                    for field_name in type(obj).model_fields:
                        field_value = getattr(obj, field_name, None)
                        if field_value is not None:
                            size += measure_size(field_value)
                    # Also measure __dict__ if present
                    if hasattr(obj, "__dict__"):
                        size += measure_size(obj.__dict__)

                # Handle dicts
                elif isinstance(obj, dict):
                    for k, v in obj.items():
                        size += measure_size(k)
                        size += measure_size(v)

                # Handle sets and tuples
                elif isinstance(obj, (set, tuple)):
                    for item in obj:
                        size += measure_size(item)

                return size

            total += measure_size(result)

        return total

    def get_cache_stats(self) -> dict[str, int | float]:
        """Get cache performance statistics.

        Returns:
            A dictionary containing cache statistics with keys:
            - hits: Number of cache hits
            - misses: Number of cache misses
            - total: Total number of cache lookups
            - hit_rate: Cache hit rate as a percentage (0-100)
        """
        total = self._cache_hits + self._cache_misses
        hit_rate = (self._cache_hits / total * 100) if total > 0 else 0.0
        return {
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "total": total,
            "hit_rate": round(hit_rate, 2),
        }

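    # Illustrative use of the statistics above (editor's sketch, not package
    # source; the numbers shown are hypothetical):
    #
    #     db = SqliterDB(memory=True, cache_enabled=True)
    #     # ... run some queries twice so the second pass hits the cache ...
    #     stats = db.get_cache_stats()
    #     # e.g. {"hits": 8, "misses": 2, "total": 10, "hit_rate": 80.0}
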
    def close(self) -> None:
        """Close the database connection.

        This method commits any pending changes if auto_commit is True,
        then closes the connection. If the connection is already closed or
        does not exist, this method silently does nothing.
        """
        if self.conn:
            self._maybe_commit()
            self.conn.close()
            self.conn = None
            self._cache.clear()
            self._cache_hits = 0
            self._cache_misses = 0

    def commit(self) -> None:
        """Commit the current transaction.

        This method explicitly commits any pending changes to the database.
        """
        if self.conn:
            self.conn.commit()

    def _build_field_definitions(
        self,
        model_class: type[BaseDBModel],
        primary_key: str,
    ) -> tuple[list[str], list[str], list[str]]:
        """Build SQL field definitions for table creation.

        Args:
            model_class: The Pydantic model class.
            primary_key: The name of the primary key field.

        Returns:
            A tuple of (fields, foreign_keys, fk_columns) where:
            - fields: List of column definitions
            - foreign_keys: List of FK constraint definitions
            - fk_columns: List of FK column names for index creation
        """
        fields = [f'"{primary_key}" INTEGER PRIMARY KEY AUTOINCREMENT']
        foreign_keys: list[str] = []
        fk_columns: list[str] = []

        for field_name, field_info in model_class.model_fields.items():
            if field_name == primary_key:
                continue

            fk_info = get_foreign_key_info(field_info)
            if fk_info is not None:
                col, constraint = self._build_fk_field(field_name, fk_info)
                fields.append(col)
                foreign_keys.append(constraint)
                fk_columns.append(fk_info.db_column or field_name)
            else:
                fields.append(self._build_regular_field(field_name, field_info))

        return fields, foreign_keys, fk_columns

    def _build_fk_field(
        self, field_name: str, fk_info: ForeignKeyInfo
    ) -> tuple[str, str]:
        """Build FK column definition and constraint.

        Args:
            field_name: The name of the field.
            fk_info: The ForeignKeyInfo metadata.

        Returns:
            A tuple of (column_def, fk_constraint).
        """
        column_name = fk_info.db_column or field_name
        null_str = "" if fk_info.null else "NOT NULL"
        unique_str = "UNIQUE" if fk_info.unique else ""

        field_def = f'"{column_name}" INTEGER {null_str} {unique_str}'
        column_def = " ".join(field_def.split())

        target_table = fk_info.to_model.get_table_name()
        fk_constraint = (
            f'FOREIGN KEY ("{column_name}") '
            f'REFERENCES "{target_table}"("pk") '
            f"ON DELETE {fk_info.on_delete} "
            f"ON UPDATE {fk_info.on_update}"
        )

        return column_def, fk_constraint

    def _build_regular_field(
        self, field_name: str, field_info: FieldInfo
    ) -> str:
        """Build a regular (non-FK) column definition.

        Args:
            field_name: The name of the field.
            field_info: The Pydantic field info.

        Returns:
            The column definition string.
        """
        sqlite_type = infer_sqlite_type(field_info.annotation)
        unique_constraint = ""
        if (
            hasattr(field_info, "json_schema_extra")
            and field_info.json_schema_extra
            and isinstance(field_info.json_schema_extra, dict)
            and field_info.json_schema_extra.get("unique", False)
        ):
            unique_constraint = "UNIQUE"
        return f'"{field_name}" {sqlite_type} {unique_constraint}'.strip()

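    # Editor's sketch of the DDL these helpers emit (illustrative; the exact
    # strings depend on the model). For a nullable, non-unique FK field
    # "owner" referencing a "users" table, _build_fk_field() yields roughly:
    #
    #     '"owner" INTEGER'                                  # column_def
    #     'FOREIGN KEY ("owner") REFERENCES "users"("pk") '
    #     'ON DELETE CASCADE ON UPDATE CASCADE'              # fk_constraint
    #
    # where the ON DELETE / ON UPDATE actions come from the ForeignKeyInfo
    # metadata (CASCADE is shown here only as an assumed example).
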
    def create_table(
        self,
        model_class: type[BaseDBModel],
        *,
        exists_ok: bool = True,
        force: bool = False,
    ) -> None:
        """Create a table in the database based on the given model class.

        Args:
            model_class: The Pydantic model class representing the table.
            exists_ok: If True, do not raise an error if the table already
                exists. Default is True, which is the original behavior.
            force: If True, drop the table if it exists before creating.
                Defaults to False.

        Raises:
            TableCreationError: If there's an error creating the table.
            ValueError: If the primary key field is not found in the model.
        """
        table_name = model_class.get_table_name()
        primary_key = model_class.get_primary_key()

        if force:
            drop_table_sql = f"DROP TABLE IF EXISTS {table_name}"
            self._execute_sql(drop_table_sql)

        fields, foreign_keys, fk_columns = self._build_field_definitions(
            model_class, primary_key
        )

        # Combine field definitions and FK constraints
        all_definitions = fields + foreign_keys

        create_str = (
            "CREATE TABLE IF NOT EXISTS" if exists_ok else "CREATE TABLE"
        )

        create_table_sql = f"""
            {create_str} "{table_name}" (
                {", ".join(all_definitions)}
            )
        """

        if self.debug:
            self._log_sql(create_table_sql, [])

        try:
            with self.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(create_table_sql)
                conn.commit()
        except sqlite3.Error as exc:
            raise TableCreationError(table_name) from exc

        # Create indexes for FK columns
        for column_name in fk_columns:
            index_sql = (
                f'CREATE INDEX IF NOT EXISTS "idx_{table_name}_{column_name}" '
                f'ON "{table_name}" ("{column_name}")'
            )
            self._execute_sql(index_sql)

        # Create regular indexes
        if hasattr(model_class.Meta, "indexes"):
            self._create_indexes(
                model_class, model_class.Meta.indexes, unique=False
            )

        # Create unique indexes
        if hasattr(model_class.Meta, "unique_indexes"):
            self._create_indexes(
                model_class, model_class.Meta.unique_indexes, unique=True
            )

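    # Usage sketch (editor's illustration; "User" stands for a hypothetical
    # BaseDBModel subclass, not defined in this file):
    #
    #     db = SqliterDB("app.db")
    #     db.create_table(User)                   # no-op if it already exists
    #     db.create_table(User, force=True)       # drop and recreate
    #     db.create_table(User, exists_ok=False)  # TableCreationError if present
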
    def _create_indexes(
        self,
        model_class: type[BaseDBModel],
        indexes: list[Union[str, tuple[str]]],
        *,
        unique: bool = False,
    ) -> None:
        """Helper method to create regular or unique indexes.

        Args:
            model_class: The model class defining the table.
            indexes: List of fields or tuples of fields to create indexes for.
            unique: If True, creates UNIQUE indexes; otherwise, creates
                regular indexes.

        Raises:
            InvalidIndexError: If any fields specified for indexing do not
                exist in the model.
        """
        valid_fields = set(
            model_class.model_fields.keys()
        )  # Get valid fields from the model

        for index in indexes:
            # Handle multiple fields in tuple form
            fields = list(index) if isinstance(index, tuple) else [index]

            # Check if all fields exist in the model
            invalid_fields = [
                field for field in fields if field not in valid_fields
            ]
            if invalid_fields:
                raise InvalidIndexError(invalid_fields, model_class.__name__)

            # Build the SQL string
            index_name = "_".join(fields)
            index_postfix = "_unique" if unique else ""
            index_type = " UNIQUE " if unique else " "

            # Quote field names for index creation
            quoted_fields = ", ".join(f'"{field}"' for field in fields)

            create_index_sql = (
                f"CREATE{index_type}INDEX IF NOT EXISTS "
                f"idx_{model_class.get_table_name()}"
                f"_{index_name}{index_postfix} "
                f'ON "{model_class.get_table_name()}" ({quoted_fields})'
            )
            self._execute_sql(create_index_sql)

    def _execute_sql(self, sql: str) -> None:
        """Execute an SQL statement.

        Args:
            sql: The SQL statement to execute.

        Raises:
            SqlExecutionError: If the SQL execution fails.
        """
        if self.debug:
            self._log_sql(sql, [])

        try:
            with self.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(sql)
                conn.commit()
        except (sqlite3.Error, sqlite3.Warning) as exc:
            raise SqlExecutionError(sql) from exc

    def drop_table(self, model_class: type[BaseDBModel]) -> None:
        """Drop the table associated with the given model class.

        Args:
            model_class: The model class for which to drop the table.

        Raises:
            TableDeletionError: If there's an error dropping the table.
        """
        table_name = model_class.get_table_name()
        drop_table_sql = f"DROP TABLE IF EXISTS {table_name}"

        if self.debug:
            self._log_sql(drop_table_sql, [])

        try:
            with self.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(drop_table_sql)
                self.commit()
        except sqlite3.Error as exc:
            raise TableDeletionError(table_name) from exc

    def _maybe_commit(self) -> None:
        """Commit changes if auto_commit is enabled.

        This method is called after operations that modify the database,
        committing changes only if auto_commit is set to True.
        """
        if not self._in_transaction and self.auto_commit and self.conn:
            self.conn.commit()

    def _set_insert_timestamps(
        self, model_instance: T, *, timestamp_override: bool
    ) -> None:
        """Set created_at and updated_at timestamps for insert.

        Args:
            model_instance: The model instance to update.
            timestamp_override: If True, respect provided non-zero values.
        """
        current_timestamp = int(time.time())

        if not timestamp_override:
            model_instance.created_at = current_timestamp
            model_instance.updated_at = current_timestamp
        else:
            if model_instance.created_at == 0:
                model_instance.created_at = current_timestamp
            if model_instance.updated_at == 0:
                model_instance.updated_at = current_timestamp

    def insert(
        self, model_instance: T, *, timestamp_override: bool = False
    ) -> T:
        """Insert a new record into the database.

        Args:
            model_instance: The instance of the model class to insert.
            timestamp_override: If True, respect the provided created_at and
                updated_at values, setting any value left at zero to the
                current time. Default is False, in which case both timestamps
                are always set to the current time, even if values were
                provided.

        Returns:
            The updated model instance with the primary key (pk) set.

        Raises:
            RecordInsertionError: If an error occurs during the insertion.
        """
        model_class = type(model_instance)
        table_name = model_class.get_table_name()

        self._set_insert_timestamps(
            model_instance, timestamp_override=timestamp_override
        )

        # Get the data from the model
        data = model_instance.model_dump()

        # Serialize the data
        for field_name, value in list(data.items()):
            data[field_name] = model_instance.serialize_field(value)

        # Remove the primary key field if it is unset (0); otherwise we'd get
        # a TypeError from passing pk twice when rebuilding the model below.
        if data.get("pk", None) == 0:
            data.pop("pk")

        fields = ", ".join(data.keys())
        placeholders = ", ".join(
            ["?" if value is not None else "NULL" for value in data.values()]
        )
        values = tuple(value for value in data.values() if value is not None)

        insert_sql = f"""
            INSERT INTO {table_name} ({fields})
            VALUES ({placeholders})
        """  # noqa: S608

        try:
            with self.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(insert_sql, values)
                self._maybe_commit()

        except sqlite3.IntegrityError as exc:
            # Check for foreign key constraint violation
            if "FOREIGN KEY constraint failed" in str(exc):
                fk_operation = "insert"
                fk_reason = "does not exist in referenced table"
                raise ForeignKeyConstraintError(
                    fk_operation, fk_reason
                ) from exc
            raise RecordInsertionError(table_name) from exc
        except sqlite3.Error as exc:
            raise RecordInsertionError(table_name) from exc
        else:
            self._cache_invalidate_table(table_name)
            data.pop("pk", None)
            # Deserialize each field before creating the model instance
            deserialized_data = {}
            for field_name, value in data.items():
                deserialized_data[field_name] = model_class.deserialize_field(
                    field_name, value, return_local_time=self.return_local_time
                )
            return model_class(pk=cursor.lastrowid, **deserialized_data)

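    # Usage sketch (editor's illustration; "User" is a hypothetical model):
    #
    #     user = db.insert(User(name="Alice"))
    #     user.pk        # primary key assigned by SQLite
    #
    # created_at and updated_at are stamped with the current time unless
    # timestamp_override=True is passed together with non-zero values.
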
    def get(
        self, model_class: type[BaseDBModel], primary_key_value: int
    ) -> BaseDBModel | None:
        """Retrieve a single record from the database by its primary key.

        Args:
            model_class: The Pydantic model class representing the table.
            primary_key_value: The value of the primary key to look up.

        Returns:
            An instance of the model class if found, None otherwise.

        Raises:
            RecordFetchError: If there's an error fetching the record.
        """
        table_name = model_class.get_table_name()
        primary_key = model_class.get_primary_key()

        fields = ", ".join(model_class.model_fields)

        select_sql = f"""
            SELECT {fields} FROM {table_name} WHERE {primary_key} = ?
        """  # noqa: S608

        try:
            with self.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(select_sql, (primary_key_value,))
                result = cursor.fetchone()

                if result:
                    result_dict = {
                        field: result[idx]
                        for idx, field in enumerate(model_class.model_fields)
                    }
                    # Deserialize each field before creating the model
                    # instance
                    deserialized_data = {}
                    for field_name, value in result_dict.items():
                        deserialized_data[field_name] = (
                            model_class.deserialize_field(
                                field_name,
                                value,
                                return_local_time=self.return_local_time,
                            )
                        )
                    return model_class(**deserialized_data)
        except sqlite3.Error as exc:
            raise RecordFetchError(table_name) from exc
        else:
            return None

    def update(self, model_instance: BaseDBModel) -> None:
        """Update an existing record in the database.

        Args:
            model_instance: An instance of a Pydantic model to be updated.

        Raises:
            RecordUpdateError: If there's an error updating the record or if
                it is not found.
        """
        model_class = type(model_instance)
        table_name = model_class.get_table_name()
        primary_key = model_class.get_primary_key()

        # Set updated_at timestamp
        current_timestamp = int(time.time())
        model_instance.updated_at = current_timestamp

        # Get the data and serialize any datetime/date fields
        data = model_instance.model_dump()
        for field_name, value in list(data.items()):
            data[field_name] = model_instance.serialize_field(value)

        # Remove the primary key from the update data
        primary_key_value = data.pop(primary_key)

        # Create the SQL using the processed data
        fields = ", ".join(f"{field} = ?" for field in data)
        values = tuple(data.values())

        update_sql = f"""
            UPDATE {table_name}
            SET {fields}
            WHERE {primary_key} = ?
        """  # noqa: S608

        try:
            with self.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(update_sql, (*values, primary_key_value))

                # Check if any rows were updated
                if cursor.rowcount == 0:
                    raise RecordNotFoundError(primary_key_value)

                self._maybe_commit()
                self._cache_invalidate_table(table_name)

        except sqlite3.Error as exc:
            raise RecordUpdateError(table_name) from exc

    def delete(
        self, model_class: type[BaseDBModel], primary_key_value: str
    ) -> None:
        """Delete a record from the database by its primary key.

        Args:
            model_class: The Pydantic model class representing the table.
            primary_key_value: The value of the primary key of the record to
                delete.

        Raises:
            RecordDeletionError: If there's an error deleting the record.
            RecordNotFoundError: If the record to delete is not found.
        """
        table_name = model_class.get_table_name()
        primary_key = model_class.get_primary_key()

        delete_sql = f"""
            DELETE FROM {table_name} WHERE {primary_key} = ?
        """  # noqa: S608

        try:
            with self.connect() as conn:
                cursor = conn.cursor()
                cursor.execute(delete_sql, (primary_key_value,))

                if cursor.rowcount == 0:
                    raise RecordNotFoundError(primary_key_value)
                self._maybe_commit()
                self._cache_invalidate_table(table_name)
        except sqlite3.IntegrityError as exc:
            # Check for foreign key constraint violation (RESTRICT)
            if "FOREIGN KEY constraint failed" in str(exc):
                fk_operation = "delete"
                fk_reason = "is still referenced by other records"
                raise ForeignKeyConstraintError(
                    fk_operation, fk_reason
                ) from exc
            raise RecordDeletionError(table_name) from exc
        except sqlite3.Error as exc:
            raise RecordDeletionError(table_name) from exc

    def select(
        self,
        model_class: type[T],
        fields: Optional[list[str]] = None,
        exclude: Optional[list[str]] = None,
    ) -> QueryBuilder[T]:
        """Create a QueryBuilder instance for selecting records.

        Args:
            model_class: The Pydantic model class representing the table.
            fields: Optional list of fields to include in the query.
            exclude: Optional list of fields to exclude from the query.

        Returns:
            A QueryBuilder instance for further query construction.
        """
        query_builder: QueryBuilder[T] = QueryBuilder(self, model_class, fields)

        # If exclude is provided, apply the exclude method
        if exclude:
            query_builder.exclude(exclude)

        return query_builder

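    # Usage sketch (editor's illustration; "User" is a hypothetical model,
    # and only the fields/exclude arguments shown above are assumed here --
    # further QueryBuilder methods live in sqliter/query/query.py):
    #
    #     names_only = db.select(User, fields=["name"])
    #     no_stamps = db.select(User, exclude=["created_at", "updated_at"])
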
    # --- Context manager methods ---
    def __enter__(self) -> Self:
        """Enter the runtime context for the SqliterDB instance.

        This method is called when entering a 'with' statement. It ensures
        that a database connection is established.

        Note that this method should never be called explicitly, but will be
        called by the 'with' statement when entering the context.

        Returns:
            The SqliterDB instance.
        """
        self.connect()
        self._in_transaction = True
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Exit the runtime context for the SqliterDB instance.

        This method is called when exiting a 'with' statement. It handles
        committing or rolling back transactions based on whether an exception
        occurred, and closes the database connection.

        Args:
            exc_type: The type of the exception that caused the context to be
                exited, or None if no exception was raised.
            exc_value: The instance of the exception that caused the context
                to be exited, or None if no exception was raised.
            traceback: A traceback object encoding the stack trace, or None
                if no exception was raised.

        Note that this method should never be called explicitly, but will be
        called by the 'with' statement when exiting the context.
        """
        if self.conn:
            try:
                if exc_type:
                    # Roll back the transaction if there was an exception
                    self.conn.rollback()
                else:
                    self.conn.commit()
            finally:
                # Close the connection and reset the instance variable
                self.conn.close()
                self.conn = None
                self._in_transaction = False
                # Clear cache when exiting context
                self._cache.clear()
                self._cache_hits = 0
                self._cache_misses = 0
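
# Editor's sketch of the context-manager behaviour above (illustrative, not
# part of the package source): while inside the "with" block auto-commit is
# suspended (_in_transaction is True), a normal exit commits, an exception
# rolls back, and the connection is closed either way.
#
#     with SqliterDB("app.db") as db:       # "User" below is hypothetical
#         db.insert(User(name="Alice"))
#         # an exception raised here would roll the insert back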