sqliter-py 0.4.0__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqliter-py might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sqliter-py
3
- Version: 0.4.0
3
+ Version: 0.6.0
4
4
  Summary: Interact with SQLite databases using Python and Pydantic
5
5
  Project-URL: Pull Requests, https://github.com/seapagan/sqliter-py/pulls
6
6
  Project-URL: Bug Tracker, https://github.com/seapagan/sqliter-py/issues
@@ -52,14 +52,16 @@ Website](https://sqliter.grantramsay.dev)
52
52
 
53
53
  > [!CAUTION]
54
54
  > This project is still in the early stages of development and is lacking some
55
- > planned functionality. Please use with caution.
55
+ > planned functionality. Please use with caution: classes and methods may
56
+ > change until a stable release is made. I'll try to keep this to an absolute
57
+ > minimum and the releases and documentation will be very clear about any
58
+ > breaking changes.
56
59
  >
57
60
  > Also, structures like `list`, `dict`, `set`, etc. are not supported **at this
58
61
  > time** as field types, since SQLite does not have a native column type for
59
- > these. I will look at implementing these in the future, probably by
60
- > serializing them to JSON or pickling them and storing in a text field. For
61
- > now, you can actually do this manually when creating your Model (use `TEXT` or
62
- > `BLOB` fields), then serialize before saving and deserialize after retrieving data.
62
+ > these. This is the **next planned enhancement**. These will need to be
63
+ > `pickled` first and then stored as a BLOB in the database. I'll also support
64
+ > `date`, which can be stored as a Unix timestamp in an integer field.
63
65
  >
64
66
  > See the [TODO](TODO.md) for planned features and improvements.
65
67
 
@@ -73,11 +75,15 @@ Website](https://sqliter.grantramsay.dev)
73
75
  ## Features
74
76
 
75
77
  - Table creation based on Pydantic models
78
+ - Automatic primary key generation
79
+ - User defined indexes on any field
80
+ - Set any field as UNIQUE
76
81
  - CRUD operations (Create, Read, Update, Delete)
77
- - Basic query building with filtering, ordering, and pagination
82
+ - Chained query building with filtering, ordering, and pagination
78
83
  - Transaction support
79
84
  - Custom exceptions for better error handling
80
85
  - Full type hinting and type checking
86
+ - Detailed documentation and examples
81
87
  - No external dependencies other than Pydantic
82
88
  - Full test coverage
83
89
  - Can optionally output the raw SQL queries being executed for debugging
@@ -95,16 +101,16 @@ virtual environments (`uv` is used for developing this project and in the CI):
95
101
  uv add sqliter-py
96
102
  ```
97
103
 
98
- With `pip`:
104
+ With `Poetry`:
99
105
 
100
106
  ```bash
101
- pip install sqliter-py
107
+ poetry add sqliter-py
102
108
  ```
103
109
 
104
- Or with `Poetry`:
110
+ Or with `pip`:
105
111
 
106
112
  ```bash
107
- poetry add sqliter-py
113
+ pip install sqliter-py
108
114
  ```
109
115
 
110
116
  ### Optional Dependencies
@@ -113,9 +119,9 @@ Currently by default, the only external dependency is Pydantic. However, there
113
119
  are some optional dependencies that can be installed to enable additional
114
120
  features:
115
121
 
116
- - `inflect`: For pluralizing table names (if not specified). This just offers a
117
- more-advanced pluralization than the default method used. In most cases you
118
- will not need this.
122
+ - `inflect`: For pluralizing the auto-generated table names (if not explicitly
123
+ set in the Model). This just offers a more advanced pluralization than the
124
+ default method used. In most cases you will not need this.
119
125
 
120
126
  See [Installing Optional
121
127
  Dependencies](https://sqliter.grantramsay.dev/installation#optional-dependencies)
@@ -142,7 +148,7 @@ db.create_table(User)
142
148
 
143
149
  # Insert a record
144
150
  user = User(name="John Doe", age=30)
145
- db.insert(user)
151
+ new_user = db.insert(user)
146
152
 
147
153
  # Query records
148
154
  results = db.select(User).filter(name="John Doe").fetch_all()
@@ -150,11 +156,11 @@ for user in results:
150
156
  print(f"User: {user.name}, Age: {user.age}")
151
157
 
152
158
  # Update a record
153
- user.age = 31
154
- db.update(user)
159
+ new_user.age = 31
160
+ db.update(new_user)
155
161
 
156
162
  # Delete a record
157
- db.delete(User, "John Doe")
163
+ db.delete(User, new_user.pk)
158
164
  ```
159
165
 
160
166
  See the [Usage](https://sqliter.grantramsay.dev/usage) section of the documentation
@@ -24,14 +24,16 @@ Website](https://sqliter.grantramsay.dev)
24
24
 
25
25
  > [!CAUTION]
26
26
  > This project is still in the early stages of development and is lacking some
27
- > planned functionality. Please use with caution.
27
+ > planned functionality. Please use with caution: classes and methods may
28
+ > change until a stable release is made. I'll try to keep this to an absolute
29
+ > minimum and the releases and documentation will be very clear about any
30
+ > breaking changes.
28
31
  >
29
32
  > Also, structures like `list`, `dict`, `set`, etc. are not supported **at this
30
33
  > time** as field types, since SQLite does not have a native column type for
31
- > these. I will look at implementing these in the future, probably by
32
- > serializing them to JSON or pickling them and storing in a text field. For
33
- > now, you can actually do this manually when creating your Model (use `TEXT` or
34
- > `BLOB` fields), then serialize before saving and deserialize after retrieving data.
34
+ > these. This is the **next planned enhancement**. These will need to be
35
+ > `pickled` first and then stored as a BLOB in the database. I'll also support
35
+ > `date`, which can be stored as a Unix timestamp in an integer field.
35
37
  >
36
38
  > See the [TODO](TODO.md) for planned features and improvements.
37
39
 
@@ -45,11 +47,15 @@ Website](https://sqliter.grantramsay.dev)
45
47
  ## Features
46
48
 
47
49
  - Table creation based on Pydantic models
50
+ - Automatic primary key generation
51
+ - User defined indexes on any field
52
+ - Set any field as UNIQUE
48
53
  - CRUD operations (Create, Read, Update, Delete)
49
- - Basic query building with filtering, ordering, and pagination
54
+ - Chained query building with filtering, ordering, and pagination
50
55
  - Transaction support
51
56
  - Custom exceptions for better error handling
52
57
  - Full type hinting and type checking
58
+ - Detailed documentation and examples
53
59
  - No external dependencies other than Pydantic
54
60
  - Full test coverage
55
61
  - Can optionally output the raw SQL queries being executed for debugging
@@ -67,16 +73,16 @@ virtual environments (`uv` is used for developing this project and in the CI):
67
73
  uv add sqliter-py
68
74
  ```
69
75
 
70
- With `pip`:
76
+ With `Poetry`:
71
77
 
72
78
  ```bash
73
- pip install sqliter-py
79
+ poetry add sqliter-py
74
80
  ```
75
81
 
76
- Or with `Poetry`:
82
+ Or with `pip`:
77
83
 
78
84
  ```bash
79
- poetry add sqliter-py
85
+ pip install sqliter-py
80
86
  ```
81
87
 
82
88
  ### Optional Dependencies
@@ -85,9 +91,9 @@ Currently by default, the only external dependency is Pydantic. However, there
85
91
  are some optional dependencies that can be installed to enable additional
86
92
  features:
87
93
 
88
- - `inflect`: For pluralizing table names (if not specified). This just offers a
89
- more-advanced pluralization than the default method used. In most cases you
90
- will not need this.
94
+ - `inflect`: For pluralizing the auto-generated table names (if not explicitly
95
+ set in the Model). This just offers a more advanced pluralization than the
96
+ default method used. In most cases you will not need this.
91
97
 
92
98
  See [Installing Optional
93
99
  Dependencies](https://sqliter.grantramsay.dev/installation#optional-dependencies)
@@ -114,7 +120,7 @@ db.create_table(User)
114
120
 
115
121
  # Insert a record
116
122
  user = User(name="John Doe", age=30)
117
- db.insert(user)
123
+ new_user = db.insert(user)
118
124
 
119
125
  # Query records
120
126
  results = db.select(User).filter(name="John Doe").fetch_all()
@@ -122,11 +128,11 @@ for user in results:
122
128
  print(f"User: {user.name}, Age: {user.age}")
123
129
 
124
130
  # Update a record
125
- user.age = 31
126
- db.update(user)
131
+ new_user.age = 31
132
+ db.update(new_user)
127
133
 
128
134
  # Delete a record
129
- db.delete(User, "John Doe")
135
+ db.delete(User, new_user.pk)
130
136
  ```
131
137
 
132
138
  See the [Usage](https://sqliter.grantramsay.dev/usage) section of the documentation
@@ -3,7 +3,7 @@
3
3
 
4
4
  [project]
5
5
  name = "sqliter-py"
6
- version = "0.4.0"
6
+ version = "0.6.0"
7
7
  description = "Interact with SQLite databases using Python and Pydantic"
8
8
  readme = "README.md"
9
9
  requires-python = ">=3.9"
@@ -144,9 +144,10 @@ known-first-party = ["sqliter"]
144
144
  keep-runtime-typing = true
145
145
 
146
146
  [tool.mypy]
147
+ plugins = ["pydantic.mypy"]
148
+
147
149
  python_version = "3.9"
148
150
  exclude = ["docs"]
149
-
150
151
  [[tool.mypy.overrides]]
151
152
  disable_error_code = ["method-assign", "no-untyped-def", "attr-defined"]
152
153
  module = "tests.*"
@@ -114,7 +114,7 @@ class RecordUpdateError(SqliterError):
114
114
  class RecordNotFoundError(SqliterError):
115
115
  """Exception raised when a requested record is not found in the database."""
116
116
 
117
- message_template = "Failed to find a record for key '{}' "
117
+ message_template = "Failed to find that record in the table (key '{}') "
118
118
 
119
119
 
120
120
  class RecordFetchError(SqliterError):
@@ -145,3 +145,24 @@ class SqlExecutionError(SqliterError):
145
145
  """Raised when an SQL execution fails."""
146
146
 
147
147
  message_template = "Failed to execute SQL: '{}'"
148
+
149
+
150
+ class InvalidIndexError(SqliterError):
151
+ """Exception raised when an invalid index field is specified.
152
+
153
+ This error is triggered if one or more fields specified for an index
154
+ do not exist in the model's fields.
155
+
156
+ Attributes:
157
+ invalid_fields (list[str]): The list of fields that were invalid.
158
+ model_class (str): The name of the model where the error occurred.
159
+ """
160
+
161
+ message_template = "Invalid fields for indexing in model '{}': {}"
162
+
163
+ def __init__(self, invalid_fields: list[str], model_class: str) -> None:
164
+ """Tidy up the error message by joining the invalid fields."""
165
+ # Join invalid fields into a comma-separated string
166
+ invalid_fields_str = ", ".join(invalid_fields)
167
+ # Pass the formatted message to the parent class
168
+ super().__init__(model_class, invalid_fields_str)
@@ -1,9 +1,11 @@
1
1
  """This module provides the base model class for SQLiter database models.
2
2
 
3
3
  It exports the BaseDBModel class, which is used to define database
4
- models in SQLiter applications.
4
+ models in SQLiter applications, and the Unique class, which is used to
5
+ define unique constraints on model fields.
5
6
  """
6
7
 
7
8
  from .model import BaseDBModel
9
+ from .unique import Unique
8
10
 
9
- __all__ = ["BaseDBModel"]
11
+ __all__ = ["BaseDBModel", "Unique"]
@@ -10,9 +10,18 @@ in SQLiter applications.
10
10
  from __future__ import annotations
11
11
 
12
12
  import re
13
- from typing import Any, Optional, TypeVar, Union, get_args, get_origin
14
-
15
- from pydantic import BaseModel, ConfigDict
13
+ from typing import (
14
+ Any,
15
+ ClassVar,
16
+ Optional,
17
+ TypeVar,
18
+ Union,
19
+ cast,
20
+ get_args,
21
+ get_origin,
22
+ )
23
+
24
+ from pydantic import BaseModel, ConfigDict, Field
16
25
 
17
26
  T = TypeVar("T", bound="BaseDBModel")
18
27
 
@@ -28,6 +37,8 @@ class BaseDBModel(BaseModel):
28
37
  representing database models.
29
38
  """
30
39
 
40
+ pk: int = Field(0, description="The mandatory primary key of the table.")
41
+
31
42
  model_config = ConfigDict(
32
43
  extra="ignore",
33
44
  populate_by_name=True,
@@ -39,18 +50,24 @@ class BaseDBModel(BaseModel):
39
50
  """Metadata class for configuring database-specific attributes.
40
51
 
41
52
  Attributes:
42
- create_pk (bool): Whether to create a primary key field.
43
- primary_key (str): The name of the primary key field.
44
- table_name (Optional[str]): The name of the database table.
53
+ table_name (Optional[str]): The name of the database table. If not
54
+ specified, the table name will be inferred from the model class
55
+ name and converted to snake_case.
56
+ indexes (ClassVar[list[Union[str, tuple[str]]]]): A list of fields
57
+ or tuples of fields for which regular (non-unique) indexes
58
+ should be created. Indexes improve query performance on these
59
+ fields.
60
+ unique_indexes (ClassVar[list[Union[str, tuple[str]]]]): A list of
61
+ fields or tuples of fields for which unique indexes should be
62
+ created. Unique indexes enforce that all values in these fields
63
+ are distinct across the table.
45
64
  """
46
65
 
47
- create_pk: bool = (
48
- True # Whether to create an auto-increment primary key
49
- )
50
- primary_key: str = "id" # Default primary key name
51
66
  table_name: Optional[str] = (
52
67
  None # Table name, defaults to class name if not set
53
68
  )
69
+ indexes: ClassVar[list[Union[str, tuple[str]]]] = []
70
+ unique_indexes: ClassVar[list[Union[str, tuple[str]]]] = []
54
71
 
55
72
  @classmethod
56
73
  def model_validate_partial(cls: type[T], obj: dict[str, Any]) -> T:
@@ -89,7 +106,7 @@ class BaseDBModel(BaseModel):
89
106
  else:
90
107
  converted_obj[field_name] = field_type(value)
91
108
 
92
- return cls.model_construct(**converted_obj)
109
+ return cast(T, cls.model_construct(**converted_obj))
93
110
 
94
111
  @classmethod
95
112
  def get_table_name(cls) -> str:
@@ -127,18 +144,10 @@ class BaseDBModel(BaseModel):
127
144
 
128
145
  @classmethod
129
146
  def get_primary_key(cls) -> str:
130
- """Get the primary key field name for the model.
131
-
132
- Returns:
133
- The name of the primary key field.
134
- """
135
- return getattr(cls.Meta, "primary_key", "id")
147
+ """Returns the mandatory primary key, always 'pk'."""
148
+ return "pk"
136
149
 
137
150
  @classmethod
138
151
  def should_create_pk(cls) -> bool:
139
- """Determine if a primary key should be automatically created.
140
-
141
- Returns:
142
- True if a primary key should be created, False otherwise.
143
- """
144
- return getattr(cls.Meta, "create_pk", True)
152
+ """Returns True since the primary key is always created."""
153
+ return True
@@ -0,0 +1,19 @@
1
+ """Define a custom field type for unique constraints in SQLiter."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic.fields import FieldInfo
6
+
7
+
8
+ class Unique(FieldInfo):
9
+ """A custom field type for unique constraints in SQLiter."""
10
+
11
+ def __init__(self, default: Any = ..., **kwargs: Any) -> None: # noqa: ANN401
12
+ """Initialize a Unique field.
13
+
14
+ Args:
15
+ default: The default value for the field.
16
+ **kwargs: Additional keyword arguments to pass to FieldInfo.
17
+ """
18
+ super().__init__(default=default, **kwargs)
19
+ self.unique = True
@@ -129,8 +129,11 @@ class QueryBuilder:
129
129
  field_name, operator = self._parse_field_operator(field)
130
130
  self._validate_field(field_name, valid_fields)
131
131
 
132
- handler = self._get_operator_handler(operator)
133
- handler(field_name, value, operator)
132
+ if operator in ["__isnull", "__notnull"]:
133
+ self._handle_null(field_name, value, operator)
134
+ else:
135
+ handler = self._get_operator_handler(operator)
136
+ handler(field_name, value, operator)
134
137
 
135
138
  return self
136
139
 
@@ -145,6 +148,8 @@ class QueryBuilder:
145
148
  The QueryBuilder instance for method chaining.
146
149
  """
147
150
  if fields:
151
+ if "pk" not in fields:
152
+ fields.append("pk")
148
153
  self._fields = fields
149
154
  self._validate_fields()
150
155
  return self
@@ -164,6 +169,9 @@ class QueryBuilder:
164
169
  invalid fields are specified.
165
170
  """
166
171
  if fields:
172
+ if "pk" in fields:
173
+ err = "The primary key 'pk' cannot be excluded."
174
+ raise ValueError(err)
167
175
  all_fields = set(self.model_class.model_fields.keys())
168
176
 
169
177
  # Check for invalid fields before subtraction
@@ -179,7 +187,7 @@ class QueryBuilder:
179
187
  self._fields = list(all_fields - set(fields))
180
188
 
181
189
  # Explicit check: raise an error if no fields remain
182
- if not self._fields:
190
+ if self._fields == ["pk"]:
183
191
  err = "Exclusion results in no fields being selected."
184
192
  raise ValueError(err)
185
193
 
@@ -208,7 +216,7 @@ class QueryBuilder:
208
216
  raise ValueError(err)
209
217
 
210
218
  # Set self._fields to just the single field
211
- self._fields = [field]
219
+ self._fields = [field, "pk"]
212
220
  return self
213
221
 
214
222
  def _get_operator_handler(
@@ -275,7 +283,7 @@ class QueryBuilder:
275
283
  self.filters.append((field_name, value, operator))
276
284
 
277
285
  def _handle_null(
278
- self, field_name: str, _: FilterValue, operator: str
286
+ self, field_name: str, value: Union[str, float, None], operator: str
279
287
  ) -> None:
280
288
  """Handle IS NULL and IS NOT NULL filter conditions.
281
289
 
@@ -283,15 +291,14 @@ class QueryBuilder:
283
291
  field_name: The name of the field to filter on. _: Placeholder for
284
292
  unused value parameter.
285
293
  operator: The operator string ('__isnull' or '__notnull').
294
+ value: The value to check for.
286
295
 
287
296
  This method adds an IS NULL or IS NOT NULL condition to the filters
288
297
  list.
289
298
  """
290
- condition = (
291
- f"{field_name} IS NOT NULL"
292
- if operator == "__notnull"
293
- else f"{field_name} IS NULL"
294
- )
299
+ is_null = operator == "__isnull"
300
+ check_null = bool(value) if is_null else not bool(value)
301
+ condition = f"{field_name} IS {'NOT ' if not check_null else ''}NULL"
295
302
  self.filters.append((condition, None, operator))
296
303
 
297
304
  def _handle_in(
@@ -527,6 +534,8 @@ class QueryBuilder:
527
534
  if count_only:
528
535
  fields = "COUNT(*)"
529
536
  elif self._fields:
537
+ if "pk" not in self._fields:
538
+ self._fields.append("pk")
530
539
  fields = ", ".join(f'"{field}"' for field in self._fields)
531
540
  else:
532
541
  fields = ", ".join(
@@ -10,12 +10,13 @@ from __future__ import annotations
10
10
 
11
11
  import logging
12
12
  import sqlite3
13
- from typing import TYPE_CHECKING, Any, Optional
13
+ from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
14
14
 
15
15
  from typing_extensions import Self
16
16
 
17
17
  from sqliter.exceptions import (
18
18
  DatabaseConnectionError,
19
+ InvalidIndexError,
19
20
  RecordDeletionError,
20
21
  RecordFetchError,
21
22
  RecordInsertionError,
@@ -26,6 +27,7 @@ from sqliter.exceptions import (
26
27
  TableDeletionError,
27
28
  )
28
29
  from sqliter.helpers import infer_sqlite_type
30
+ from sqliter.model.unique import Unique
29
31
  from sqliter.query.query import QueryBuilder
30
32
 
31
33
  if TYPE_CHECKING: # pragma: no cover
@@ -33,6 +35,8 @@ if TYPE_CHECKING: # pragma: no cover
33
35
 
34
36
  from sqliter.model.model import BaseDBModel
35
37
 
38
+ T = TypeVar("T", bound="BaseDBModel")
39
+
36
40
 
37
41
  class SqliterDB:
38
42
  """Main class for interacting with SQLite databases.
@@ -87,6 +91,8 @@ class SqliterDB:
87
91
  self.conn: Optional[sqlite3.Connection] = None
88
92
  self.reset = reset
89
93
 
94
+ self._in_transaction = False
95
+
90
96
  if self.debug:
91
97
  self._setup_logger()
92
98
 
@@ -223,34 +229,23 @@ class SqliterDB:
223
229
  """
224
230
  table_name = model_class.get_table_name()
225
231
  primary_key = model_class.get_primary_key()
226
- create_pk = model_class.should_create_pk()
227
232
 
228
233
  if force:
229
234
  drop_table_sql = f"DROP TABLE IF EXISTS {table_name}"
230
235
  self._execute_sql(drop_table_sql)
231
236
 
232
- fields = []
233
-
234
- # Always add the primary key field first
235
- if create_pk:
236
- fields.append(f"{primary_key} INTEGER PRIMARY KEY AUTOINCREMENT")
237
- else:
238
- field_info = model_class.model_fields.get(primary_key)
239
- if field_info is not None:
240
- sqlite_type = infer_sqlite_type(field_info.annotation)
241
- fields.append(f"{primary_key} {sqlite_type} PRIMARY KEY")
242
- else:
243
- err = (
244
- f"Primary key field '{primary_key}' not found in model "
245
- "fields."
246
- )
247
- raise ValueError(err)
237
+ fields = [f'"{primary_key}" INTEGER PRIMARY KEY AUTOINCREMENT']
248
238
 
249
239
  # Add remaining fields
250
240
  for field_name, field_info in model_class.model_fields.items():
251
241
  if field_name != primary_key:
252
242
  sqlite_type = infer_sqlite_type(field_info.annotation)
253
- fields.append(f"{field_name} {sqlite_type}")
243
+ unique_constraint = (
244
+ "UNIQUE" if isinstance(field_info, Unique) else ""
245
+ )
246
+ fields.append(
247
+ f"{field_name} {sqlite_type} {unique_constraint}".strip()
248
+ )
254
249
 
255
250
  create_str = (
256
251
  "CREATE TABLE IF NOT EXISTS" if exists_ok else "CREATE TABLE"
@@ -273,6 +268,65 @@ class SqliterDB:
273
268
  except sqlite3.Error as exc:
274
269
  raise TableCreationError(table_name) from exc
275
270
 
271
+ # Create regular indexes
272
+ if hasattr(model_class.Meta, "indexes"):
273
+ self._create_indexes(
274
+ model_class, model_class.Meta.indexes, unique=False
275
+ )
276
+
277
+ # Create unique indexes
278
+ if hasattr(model_class.Meta, "unique_indexes"):
279
+ self._create_indexes(
280
+ model_class, model_class.Meta.unique_indexes, unique=True
281
+ )
282
+
283
+ def _create_indexes(
284
+ self,
285
+ model_class: type[BaseDBModel],
286
+ indexes: list[Union[str, tuple[str]]],
287
+ *,
288
+ unique: bool = False,
289
+ ) -> None:
290
+ """Helper method to create regular or unique indexes.
291
+
292
+ Args:
293
+ model_class: The model class defining the table.
294
+ indexes: List of fields or tuples of fields to create indexes for.
295
+ unique: If True, creates UNIQUE indexes; otherwise, creates regular
296
+ indexes.
297
+
298
+ Raises:
299
+ InvalidIndexError: If any fields specified for indexing do not exist
300
+ in the model.
301
+ """
302
+ valid_fields = set(
303
+ model_class.model_fields.keys()
304
+ ) # Get valid fields from the model
305
+
306
+ for index in indexes:
307
+ # Handle multiple fields in tuple form
308
+ fields = list(index) if isinstance(index, tuple) else [index]
309
+
310
+ # Check if all fields exist in the model
311
+ invalid_fields = [
312
+ field for field in fields if field not in valid_fields
313
+ ]
314
+ if invalid_fields:
315
+ raise InvalidIndexError(invalid_fields, model_class.__name__)
316
+
317
+ # Build the SQL string
318
+ index_name = "_".join(fields)
319
+ index_postfix = "_unique" if unique else ""
320
+ index_type = " UNIQUE " if unique else " "
321
+
322
+ create_index_sql = (
323
+ f"CREATE{index_type}INDEX IF NOT EXISTS "
324
+ f"idx_{model_class.get_table_name()}"
325
+ f"_{index_name}{index_postfix} "
326
+ f"ON {model_class.get_table_name()} ({', '.join(fields)})"
327
+ )
328
+ self._execute_sql(create_index_sql)
329
+
276
330
  def _execute_sql(self, sql: str) -> None:
277
331
  """Execute an SQL statement.
278
332
 
@@ -322,22 +376,31 @@ class SqliterDB:
322
376
  This method is called after operations that modify the database,
323
377
  committing changes only if auto_commit is set to True.
324
378
  """
325
- if self.auto_commit and self.conn:
379
+ if not self._in_transaction and self.auto_commit and self.conn:
326
380
  self.conn.commit()
327
381
 
328
- def insert(self, model_instance: BaseDBModel) -> None:
382
+ def insert(self, model_instance: T) -> T:
329
383
  """Insert a new record into the database.
330
384
 
331
385
  Args:
332
- model_instance: An instance of a Pydantic model to be inserted.
386
+ model_instance: The instance of the model class to insert.
387
+
388
+ Returns:
389
+ The updated model instance with the primary key (pk) set.
333
390
 
334
391
  Raises:
335
- RecordInsertionError: If there's an error inserting the record.
392
+ RecordInsertionError: If an error occurs during the insertion.
336
393
  """
337
394
  model_class = type(model_instance)
338
395
  table_name = model_class.get_table_name()
339
396
 
397
+ # Get the data from the model
340
398
  data = model_instance.model_dump()
399
+ # remove the primary key field if it exists, otherwise we'll get
400
+ # TypeErrors as multiple primary keys will exist
401
+ if data.get("pk", None) == 0:
402
+ data.pop("pk")
403
+
341
404
  fields = ", ".join(data.keys())
342
405
  placeholders = ", ".join(
343
406
  ["?" if value is not None else "NULL" for value in data.values()]
@@ -354,11 +417,15 @@ class SqliterDB:
354
417
  cursor = conn.cursor()
355
418
  cursor.execute(insert_sql, values)
356
419
  self._maybe_commit()
420
+
357
421
  except sqlite3.Error as exc:
358
422
  raise RecordInsertionError(table_name) from exc
423
+ else:
424
+ data.pop("pk", None)
425
+ return model_class(pk=cursor.lastrowid, **data)
359
426
 
360
427
  def get(
361
- self, model_class: type[BaseDBModel], primary_key_value: str
428
+ self, model_class: type[BaseDBModel], primary_key_value: int
362
429
  ) -> BaseDBModel | None:
363
430
  """Retrieve a single record from the database by its primary key.
364
431
 
@@ -405,11 +472,12 @@ class SqliterDB:
405
472
  model_instance: An instance of a Pydantic model to be updated.
406
473
 
407
474
  Raises:
408
- RecordUpdateError: If there's an error updating the record.
409
- RecordNotFoundError: If the record to update is not found.
475
+ RecordUpdateError: If there's an error updating the record or if it
476
+ is not found.
410
477
  """
411
478
  model_class = type(model_instance)
412
479
  table_name = model_class.get_table_name()
480
+
413
481
  primary_key = model_class.get_primary_key()
414
482
 
415
483
  fields = ", ".join(
@@ -515,6 +583,7 @@ class SqliterDB:
515
583
 
516
584
  """
517
585
  self.connect()
586
+ self._in_transaction = True
518
587
  return self
519
588
 
520
589
  def __exit__(
@@ -552,3 +621,4 @@ class SqliterDB:
552
621
  # Close the connection and reset the instance variable
553
622
  self.conn.close()
554
623
  self.conn = None
624
+ self._in_transaction = False
File without changes
File without changes