sqliter-py 0.3.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqliter/sqliter.py CHANGED
@@ -1,21 +1,34 @@
1
- """This is the main module for the sqliter package."""
1
+ """Core module for SQLiter, providing the main database interaction class.
2
+
3
+ This module defines the SqliterDB class, which serves as the primary
4
+ interface for all database operations in SQLiter. It handles connection
5
+ management, table creation, and CRUD operations, bridging the gap between
6
+ Pydantic models and SQLite database interactions.
7
+ """
2
8
 
3
9
  from __future__ import annotations
4
10
 
11
+ import logging
5
12
  import sqlite3
6
- from typing import TYPE_CHECKING, Optional
13
+ import time
14
+ from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
7
15
 
8
16
  from typing_extensions import Self
9
17
 
10
18
  from sqliter.exceptions import (
11
19
  DatabaseConnectionError,
20
+ InvalidIndexError,
12
21
  RecordDeletionError,
13
22
  RecordFetchError,
14
23
  RecordInsertionError,
15
24
  RecordNotFoundError,
16
25
  RecordUpdateError,
26
+ SqlExecutionError,
17
27
  TableCreationError,
28
+ TableDeletionError,
18
29
  )
30
+ from sqliter.helpers import infer_sqlite_type
31
+ from sqliter.model.unique import Unique
19
32
  from sqliter.query.query import QueryBuilder
20
33
 
21
34
  if TYPE_CHECKING: # pragma: no cover
@@ -23,20 +36,52 @@ if TYPE_CHECKING: # pragma: no cover
23
36
 
24
37
  from sqliter.model.model import BaseDBModel
25
38
 
39
+ T = TypeVar("T", bound="BaseDBModel")
40
+
26
41
 
27
42
  class SqliterDB:
28
- """Class to manage SQLite database interactions."""
43
+ """Main class for interacting with SQLite databases.
44
+
45
+ This class provides methods for connecting to a SQLite database,
46
+ creating tables, and performing CRUD operations.
47
+
48
+ Arguements:
49
+ db_filename (str): The filename of the SQLite database.
50
+ auto_commit (bool): Whether to automatically commit transactions.
51
+ debug (bool): Whether to enable debug logging.
52
+ logger (Optional[logging.Logger]): Custom logger for debug output.
53
+ """
54
+
55
+ MEMORY_DB = ":memory:"
29
56
 
30
- def __init__(
57
+ def __init__( # noqa: PLR0913
31
58
  self,
32
59
  db_filename: Optional[str] = None,
33
60
  *,
34
61
  memory: bool = False,
35
62
  auto_commit: bool = True,
63
+ debug: bool = False,
64
+ logger: Optional[logging.Logger] = None,
65
+ reset: bool = False,
66
+ return_local_time: bool = True,
36
67
  ) -> None:
37
- """Initialize the class and options."""
68
+ """Initialize a new SqliterDB instance.
69
+
70
+ Args:
71
+ db_filename: The filename of the SQLite database.
72
+ memory: If True, create an in-memory database.
73
+ auto_commit: Whether to automatically commit transactions.
74
+ debug: Whether to enable debug logging.
75
+ logger: Custom logger for debug output.
76
+ reset: Whether to reset the database on initialization. This will
77
+ basically drop all existing tables.
78
+ return_local_time: Whether to return local time for datetime fields.
79
+
80
+ Raises:
81
+ ValueError: If no filename is provided for a non-memory database.
82
+ """
38
83
  if memory:
39
- self.db_filename = ":memory:"
84
+ self.db_filename = self.MEMORY_DB
40
85
  elif db_filename:
41
86
  self.db_filename = db_filename
42
87
  else:
@@ -46,10 +91,149 @@ class SqliterDB:
46
91
  )
47
92
  raise ValueError(err)
48
93
  self.auto_commit = auto_commit
94
+ self.debug = debug
95
+ self.logger = logger
49
96
  self.conn: Optional[sqlite3.Connection] = None
97
+ self.reset = reset
98
+ self.return_local_time = return_local_time
99
+
100
+ self._in_transaction = False
101
+
102
+ if self.debug:
103
+ self._setup_logger()
104
+
105
+ if self.reset:
106
+ self._reset_database()
107
+
108
+ @property
109
+ def filename(self) -> Optional[str]:
110
+ """Returns the filename of the current database or None if in-memory."""
111
+ return None if self.db_filename == self.MEMORY_DB else self.db_filename
112
+
113
+ @property
114
+ def is_memory(self) -> bool:
115
+ """Returns True if the database is in-memory."""
116
+ return self.db_filename == self.MEMORY_DB
117
+
118
+ @property
119
+ def is_autocommit(self) -> bool:
120
+ """Returns True if auto-commit is enabled."""
121
+ return self.auto_commit
122
+
123
+ @property
124
+ def is_connected(self) -> bool:
125
+ """Returns True if the database is connected, False otherwise."""
126
+ return self.conn is not None
127
+
128
+ @property
129
+ def table_names(self) -> list[str]:
130
+ """Returns a list of all table names in the database.
131
+
132
+ Temporarily connects to the database if not connected and restores
133
+ the connection state afterward.
134
+ """
135
+ was_connected = self.is_connected
136
+ if not was_connected:
137
+ self.connect()
138
+
139
+ if self.conn is None:
140
+ err_msg = "Failed to establish a database connection."
141
+ raise DatabaseConnectionError(err_msg)
142
+
143
+ cursor = self.conn.cursor()
144
+ cursor.execute(
145
+ "SELECT name FROM sqlite_master WHERE type='table' "
146
+ "AND name NOT LIKE 'sqlite_%';"
147
+ )
148
+ tables = [row[0] for row in cursor.fetchall()]
149
+
150
+ # Restore the connection state
151
+ if not was_connected:
152
+ self.close()
153
+
154
+ return tables
155
+
156
+ def _reset_database(self) -> None:
157
+ """Drop all user-created tables in the database."""
158
+ with self.connect() as conn:
159
+ cursor = conn.cursor()
160
+
161
+ # Get all table names, excluding SQLite system tables
162
+ cursor.execute(
163
+ "SELECT name FROM sqlite_master WHERE type='table' "
164
+ "AND name NOT LIKE 'sqlite_%';"
165
+ )
166
+ tables = cursor.fetchall()
167
+
168
+ # Drop each user-created table
169
+ for table in tables:
170
+ cursor.execute(f"DROP TABLE IF EXISTS {table[0]}")
171
+
172
+ conn.commit()
173
+
174
+ if self.debug and self.logger:
175
+ self.logger.debug(
176
+ "Database reset: %s user-created tables dropped.", len(tables)
177
+ )
178
+
179
+ def _setup_logger(self) -> None:
180
+ """Set up the logger for debug output.
181
+
182
+ This method configures a logger for the SqliterDB instance, either
183
+ using an existing logger or creating a new one specifically for
184
+ SQLiter.
185
+ """
186
+ # Check if the root logger is already configured
187
+ root_logger = logging.getLogger()
188
+
189
+ if root_logger.hasHandlers():
190
+ # If the root logger has handlers, use it without modifying the root
191
+ # configuration
192
+ self.logger = root_logger.getChild("sqliter")
193
+ else:
194
+ # If no root logger is configured, set up a new logger specific to
195
+ # SqliterDB
196
+ self.logger = logging.getLogger("sqliter")
197
+
198
+ handler = logging.StreamHandler() # Output to console
199
+ formatter = logging.Formatter(
200
+ "%(levelname)-8s%(message)s"
201
+ ) # Custom format
202
+ handler.setFormatter(formatter)
203
+ self.logger.addHandler(handler)
204
+
205
+ self.logger.setLevel(logging.DEBUG)
206
+ self.logger.propagate = False
207
+
208
+ def _log_sql(self, sql: str, values: list[Any]) -> None:
209
+ """Log the SQL query and its values if debug mode is enabled.
210
+
211
+ The values are inserted into the SQL query string to replace the
212
+ placeholders.
213
+
214
+ Args:
215
+ sql: The SQL query string.
216
+ values: The list of values to be inserted into the query.
217
+ """
218
+ if self.debug and self.logger:
219
+ formatted_sql = sql
220
+ for value in values:
221
+ if isinstance(value, str):
222
+ formatted_sql = formatted_sql.replace("?", f"'{value}'", 1)
223
+ else:
224
+ formatted_sql = formatted_sql.replace("?", str(value), 1)
225
+
226
+ self.logger.debug("Executing SQL: %s", formatted_sql)
50
227
 
51
228
  def connect(self) -> sqlite3.Connection:
52
- """Create or return a connection to the SQLite database."""
229
+ """Establish a connection to the SQLite database.
230
+
231
+ Returns:
232
+ The SQLite connection object.
233
+
234
+ Raises:
235
+ DatabaseConnectionError: If unable to connect to the database.
236
+ """
53
237
  if not self.conn:
54
238
  try:
55
239
  self.conn = sqlite3.connect(self.db_filename)
@@ -58,41 +242,77 @@ class SqliterDB:
58
242
  return self.conn
59
243
 
60
244
  def close(self) -> None:
61
- """Close the connection to the SQLite database."""
245
+ """Close the database connection.
246
+
247
+ This method commits any pending changes if auto_commit is True,
248
+ then closes the connection. If the connection is already closed or does
249
+ not exist, this method silently does nothing.
250
+ """
62
251
  if self.conn:
63
252
  self._maybe_commit()
64
253
  self.conn.close()
65
254
  self.conn = None
66
255
 
67
256
  def commit(self) -> None:
68
- """Commit any pending transactions."""
257
+ """Commit the current transaction.
258
+
259
+ This method explicitly commits any pending changes to the database.
260
+ """
69
261
  if self.conn:
70
262
  self.conn.commit()
71
263
 
72
- def create_table(self, model_class: type[BaseDBModel]) -> None:
73
- """Create a table based on the Pydantic model."""
264
+ def create_table(
265
+ self,
266
+ model_class: type[BaseDBModel],
267
+ *,
268
+ exists_ok: bool = True,
269
+ force: bool = False,
270
+ ) -> None:
271
+ """Create a table in the database based on the given model class.
272
+
273
+ Args:
274
+ model_class: The Pydantic model class representing the table.
275
+ exists_ok: If True, do not raise an error if the table already
276
+ exists. Default is True which is the original behavior.
277
+ force: If True, drop the table if it exists before creating.
278
+ Defaults to False.
279
+
280
+ Raises:
281
+ TableCreationError: If there's an error creating the table.
282
+ ValueError: If the primary key field is not found in the model.
283
+ """
74
284
  table_name = model_class.get_table_name()
75
285
  primary_key = model_class.get_primary_key()
76
- create_pk = model_class.should_create_pk()
77
286
 
78
- fields = ", ".join(
79
- f"{field_name} TEXT" for field_name in model_class.model_fields
80
- )
287
+ if force:
288
+ drop_table_sql = f"DROP TABLE IF EXISTS {table_name}"
289
+ self._execute_sql(drop_table_sql)
290
+
291
+ fields = [f'"{primary_key}" INTEGER PRIMARY KEY AUTOINCREMENT']
81
292
 
82
- if create_pk:
83
- create_table_sql = f"""
84
- CREATE TABLE IF NOT EXISTS {table_name} (
85
- {primary_key} INTEGER PRIMARY KEY AUTOINCREMENT,
86
- {fields}
293
+ # Add remaining fields
294
+ for field_name, field_info in model_class.model_fields.items():
295
+ if field_name != primary_key:
296
+ sqlite_type = infer_sqlite_type(field_info.annotation)
297
+ unique_constraint = (
298
+ "UNIQUE" if isinstance(field_info, Unique) else ""
87
299
  )
88
- """
89
- else:
90
- create_table_sql = f"""
91
- CREATE TABLE IF NOT EXISTS {table_name} (
92
- {fields},
93
- PRIMARY KEY ({primary_key})
300
+ fields.append(
301
+ f"{field_name} {sqlite_type} {unique_constraint}".strip()
94
302
  )
95
- """
303
+
304
+ create_str = (
305
+ "CREATE TABLE IF NOT EXISTS" if exists_ok else "CREATE TABLE"
306
+ )
307
+
308
+ create_table_sql = f"""
309
+ {create_str} {table_name} (
310
+ {", ".join(fields)}
311
+ )
312
+ """
313
+
314
+ if self.debug:
315
+ self._log_sql(create_table_sql, [])
96
316
 
97
317
  try:
98
318
  with self.connect() as conn:
@@ -102,17 +322,166 @@ class SqliterDB:
102
322
  except sqlite3.Error as exc:
103
323
  raise TableCreationError(table_name) from exc
104
324
 
325
+ # Create regular indexes
326
+ if hasattr(model_class.Meta, "indexes"):
327
+ self._create_indexes(
328
+ model_class, model_class.Meta.indexes, unique=False
329
+ )
330
+
331
+ # Create unique indexes
332
+ if hasattr(model_class.Meta, "unique_indexes"):
333
+ self._create_indexes(
334
+ model_class, model_class.Meta.unique_indexes, unique=True
335
+ )
336
+
337
+ def _create_indexes(
338
+ self,
339
+ model_class: type[BaseDBModel],
340
+ indexes: list[Union[str, tuple[str]]],
341
+ *,
342
+ unique: bool = False,
343
+ ) -> None:
344
+ """Helper method to create regular or unique indexes.
345
+
346
+ Args:
347
+ model_class: The model class defining the table.
348
+ indexes: List of fields or tuples of fields to create indexes for.
349
+ unique: If True, creates UNIQUE indexes; otherwise, creates regular
350
+ indexes.
351
+
352
+ Raises:
353
+ InvalidIndexError: If any fields specified for indexing do not exist
354
+ in the model.
355
+ """
356
+ valid_fields = set(
357
+ model_class.model_fields.keys()
358
+ ) # Get valid fields from the model
359
+
360
+ for index in indexes:
361
+ # Handle multiple fields in tuple form
362
+ fields = list(index) if isinstance(index, tuple) else [index]
363
+
364
+ # Check if all fields exist in the model
365
+ invalid_fields = [
366
+ field for field in fields if field not in valid_fields
367
+ ]
368
+ if invalid_fields:
369
+ raise InvalidIndexError(invalid_fields, model_class.__name__)
370
+
371
+ # Build the SQL string
372
+ index_name = "_".join(fields)
373
+ index_postfix = "_unique" if unique else ""
374
+ index_type = " UNIQUE " if unique else " "
375
+
376
+ create_index_sql = (
377
+ f"CREATE{index_type}INDEX IF NOT EXISTS "
378
+ f"idx_{model_class.get_table_name()}"
379
+ f"_{index_name}{index_postfix} "
380
+ f"ON {model_class.get_table_name()} ({', '.join(fields)})"
381
+ )
382
+ self._execute_sql(create_index_sql)
383
+
384
+ def _execute_sql(self, sql: str) -> None:
385
+ """Execute an SQL statement.
386
+
387
+ Args:
388
+ sql: The SQL statement to execute.
389
+
390
+ Raises:
391
+ SqlExecutionError: If the SQL execution fails.
392
+ """
393
+ if self.debug:
394
+ self._log_sql(sql, [])
395
+
396
+ try:
397
+ with self.connect() as conn:
398
+ cursor = conn.cursor()
399
+ cursor.execute(sql)
400
+ conn.commit()
401
+ except (sqlite3.Error, sqlite3.Warning) as exc:
402
+ raise SqlExecutionError(sql) from exc
403
+
404
+ def drop_table(self, model_class: type[BaseDBModel]) -> None:
405
+ """Drop the table associated with the given model class.
406
+
407
+ Args:
408
+ model_class: The model class for which to drop the table.
409
+
410
+ Raises:
411
+ TableDeletionError: If there's an error dropping the table.
412
+ """
413
+ table_name = model_class.get_table_name()
414
+ drop_table_sql = f"DROP TABLE IF EXISTS {table_name}"
415
+
416
+ if self.debug:
417
+ self._log_sql(drop_table_sql, [])
418
+
419
+ try:
420
+ with self.connect() as conn:
421
+ cursor = conn.cursor()
422
+ cursor.execute(drop_table_sql)
423
+ self.commit()
424
+ except sqlite3.Error as exc:
425
+ raise TableDeletionError(table_name) from exc
426
+
105
427
  def _maybe_commit(self) -> None:
106
- """Commit changes if auto_commit is True."""
107
- if self.auto_commit and self.conn:
428
+ """Commit changes if auto_commit is enabled.
429
+
430
+ This method is called after operations that modify the database,
431
+ committing changes only if auto_commit is set to True.
432
+ """
433
+ if not self._in_transaction and self.auto_commit and self.conn:
108
434
  self.conn.commit()
109
435
 
110
- def insert(self, model_instance: BaseDBModel) -> None:
111
- """Insert a new record into the table defined by the Pydantic model."""
436
+ def insert(
437
+ self, model_instance: T, *, timestamp_override: bool = False
438
+ ) -> T:
439
+ """Insert a new record into the database.
440
+
441
+ Args:
442
+ model_instance: The instance of the model class to insert.
443
+ timestamp_override: If True, override the created_at and updated_at
444
+ timestamps with provided values. Default is False. If the values
445
+ are not provided, they will be set to the current time as
446
+ normal. Without this flag, the timestamps will always be set to
447
+ the current time, even if provided.
448
+
449
+ Returns:
450
+ The updated model instance with the primary key (pk) set.
451
+
452
+ Raises:
453
+ RecordInsertionError: If an error occurs during the insertion.
454
+ """
112
455
  model_class = type(model_instance)
113
456
  table_name = model_class.get_table_name()
114
457
 
458
+ # Always set created_at and updated_at timestamps
459
+ current_timestamp = int(time.time())
460
+
461
+ # Handle the case where timestamp_override is False
462
+ if not timestamp_override:
463
+ # Always override both timestamps with the current time
464
+ model_instance.created_at = current_timestamp
465
+ model_instance.updated_at = current_timestamp
466
+ else:
467
+ # Respect provided values, but set to current time if they are 0
468
+ if model_instance.created_at == 0:
469
+ model_instance.created_at = current_timestamp
470
+ if model_instance.updated_at == 0:
471
+ model_instance.updated_at = current_timestamp
472
+
473
+ # Get the data from the model
115
474
  data = model_instance.model_dump()
475
+
476
+ # Serialize the data
477
+ for field_name, value in list(data.items()):
478
+ data[field_name] = model_instance.serialize_field(value)
479
+
480
+ # remove the primary key field if it exists, otherwise we'll get
481
+ # TypeErrors as multiple primary keys will exist
482
+ if data.get("pk", None) == 0:
483
+ data.pop("pk")
484
+
116
485
  fields = ", ".join(data.keys())
117
486
  placeholders = ", ".join(
118
487
  ["?" if value is not None else "NULL" for value in data.values()]
@@ -129,13 +498,34 @@ class SqliterDB:
129
498
  cursor = conn.cursor()
130
499
  cursor.execute(insert_sql, values)
131
500
  self._maybe_commit()
501
+
132
502
  except sqlite3.Error as exc:
133
503
  raise RecordInsertionError(table_name) from exc
504
+ else:
505
+ data.pop("pk", None)
506
+ # Deserialize each field before creating the model instance
507
+ deserialized_data = {}
508
+ for field_name, value in data.items():
509
+ deserialized_data[field_name] = model_class.deserialize_field(
510
+ field_name, value, return_local_time=self.return_local_time
511
+ )
512
+ return model_class(pk=cursor.lastrowid, **deserialized_data)
134
513
 
135
514
  def get(
136
- self, model_class: type[BaseDBModel], primary_key_value: str
515
+ self, model_class: type[BaseDBModel], primary_key_value: int
137
516
  ) -> BaseDBModel | None:
138
- """Retrieve a record by its PK and return a Pydantic instance."""
517
+ """Retrieve a single record from the database by its primary key.
518
+
519
+ Args:
520
+ model_class: The Pydantic model class representing the table.
521
+ primary_key_value: The value of the primary key to look up.
522
+
523
+ Returns:
524
+ An instance of the model class if found, None otherwise.
525
+
526
+ Raises:
527
+ RecordFetchError: If there's an error fetching the record.
528
+ """
139
529
  table_name = model_class.get_table_name()
140
530
  primary_key = model_class.get_primary_key()
141
531
 
@@ -156,29 +546,51 @@ class SqliterDB:
156
546
  field: result[idx]
157
547
  for idx, field in enumerate(model_class.model_fields)
158
548
  }
159
- return model_class(**result_dict)
549
+ # Deserialize each field before creating the model instance
550
+ deserialized_data = {}
551
+ for field_name, value in result_dict.items():
552
+ deserialized_data[field_name] = (
553
+ model_class.deserialize_field(
554
+ field_name,
555
+ value,
556
+ return_local_time=self.return_local_time,
557
+ )
558
+ )
559
+ return model_class(**deserialized_data)
160
560
  except sqlite3.Error as exc:
161
561
  raise RecordFetchError(table_name) from exc
162
562
  else:
163
563
  return None
164
564
 
165
565
  def update(self, model_instance: BaseDBModel) -> None:
166
- """Update an existing record using the Pydantic model."""
566
+ """Update an existing record in the database.
567
+
568
+ Args:
569
+ model_instance: An instance of a Pydantic model to be updated.
570
+
571
+ Raises:
572
+ RecordUpdateError: If there's an error updating the record or if it
573
+ is not found.
574
+ """
167
575
  model_class = type(model_instance)
168
576
  table_name = model_class.get_table_name()
169
577
  primary_key = model_class.get_primary_key()
170
578
 
171
- fields = ", ".join(
172
- f"{field} = ?"
173
- for field in model_class.model_fields
174
- if field != primary_key
175
- )
176
- values = tuple(
177
- getattr(model_instance, field)
178
- for field in model_class.model_fields
179
- if field != primary_key
180
- )
181
- primary_key_value = getattr(model_instance, primary_key)
579
+ # Set updated_at timestamp
580
+ current_timestamp = int(time.time())
581
+ model_instance.updated_at = current_timestamp
582
+
583
+ # Get the data and serialize any datetime/date fields
584
+ data = model_instance.model_dump()
585
+ for field_name, value in list(data.items()):
586
+ data[field_name] = model_instance.serialize_field(value)
587
+
588
+ # Remove the primary key from the update data
589
+ primary_key_value = data.pop(primary_key)
590
+
591
+ # Create the SQL using the processed data
592
+ fields = ", ".join(f"{field} = ?" for field in data)
593
+ values = tuple(data.values())
182
594
 
183
595
  update_sql = f"""
184
596
  UPDATE {table_name}
@@ -203,7 +615,17 @@ class SqliterDB:
203
615
  def delete(
204
616
  self, model_class: type[BaseDBModel], primary_key_value: str
205
617
  ) -> None:
206
- """Delete a record by its primary key."""
618
+ """Delete a record from the database by its primary key.
619
+
620
+ Args:
621
+ model_class: The Pydantic model class representing the table.
622
+ primary_key_value: The value of the primary key of the record to
623
+ delete.
624
+
625
+ Raises:
626
+ RecordDeletionError: If there's an error deleting the record.
627
+ RecordNotFoundError: If the record to delete is not found.
628
+ """
207
629
  table_name = model_class.get_table_name()
208
630
  primary_key = model_class.get_primary_key()
209
631
 
@@ -228,18 +650,15 @@ class SqliterDB:
228
650
  fields: Optional[list[str]] = None,
229
651
  exclude: Optional[list[str]] = None,
230
652
  ) -> QueryBuilder:
231
- """Start a query for the given model.
653
+ """Create a QueryBuilder instance for selecting records.
232
654
 
233
655
  Args:
234
- model_class: The model class to query.
235
- fields: Optional list of field names to select. If None, all fields
236
- are selected.
237
- exclude: Optional list of field names to exclude from the query
238
- output.
656
+ model_class: The Pydantic model class representing the table.
657
+ fields: Optional list of fields to include in the query.
658
+ exclude: Optional list of fields to exclude from the query.
239
659
 
240
660
  Returns:
241
- QueryBuilder: An instance of QueryBuilder for the given model and
242
- fields.
661
+ A QueryBuilder instance for further query construction.
243
662
  """
244
663
  query_builder = QueryBuilder(self, model_class, fields)
245
664
 
@@ -251,8 +670,20 @@ class SqliterDB:
251
670
 
252
671
  # --- Context manager methods ---
253
672
  def __enter__(self) -> Self:
254
- """Enter the runtime context for the 'with' statement."""
673
+ """Enter the runtime context for the SqliterDB instance.
674
+
675
+ This method is called when entering a 'with' statement. It ensures
676
+ that a database connection is established.
677
+
678
+ Note that this method should never be called explicitly, but will be
679
+ called by the 'with' statement when entering the context.
680
+
681
+ Returns:
682
+ The SqliterDB instance.
683
+
684
+ """
255
685
  self.connect()
686
+ self._in_transaction = True
256
687
  return self
257
688
 
258
689
  def __exit__(
@@ -261,7 +692,24 @@ class SqliterDB:
261
692
  exc_value: Optional[BaseException],
262
693
  traceback: Optional[TracebackType],
263
694
  ) -> None:
264
- """Exit the runtime context and close the connection."""
695
+ """Exit the runtime context for the SqliterDB instance.
696
+
697
+ This method is called when exiting a 'with' statement. It handles
698
+ committing or rolling back transactions based on whether an exception
699
+ occurred, and closes the database connection.
700
+
701
+ Args:
702
+ exc_type: The type of the exception that caused the context to be
703
+ exited, or None if no exception was raised.
704
+ exc_value: The instance of the exception that caused the context
705
+ to be exited, or None if no exception was raised.
706
+ traceback: A traceback object encoding the stack trace, or None
707
+ if no exception was raised.
708
+
709
+ Note that this method should never be called explicitly, but will be
710
+ called by the 'with' statement when exiting the context.
711
+
712
+ """
265
713
  if self.conn:
266
714
  try:
267
715
  if exc_type:
@@ -273,3 +721,4 @@ class SqliterDB:
273
721
  # Close the connection and reset the instance variable
274
722
  self.conn.close()
275
723
  self.conn = None
724
+ self._in_transaction = False