activemodel 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- activemodel/__init__.py +2 -0
- activemodel/base_model.py +141 -33
- activemodel/celery.py +33 -0
- activemodel/errors.py +6 -0
- activemodel/get_column_from_field_patch.py +139 -0
- activemodel/mixins/__init__.py +2 -0
- activemodel/mixins/pydantic_json.py +82 -0
- activemodel/mixins/soft_delete.py +17 -0
- activemodel/mixins/typeid.py +27 -17
- activemodel/pytest/transaction.py +34 -22
- activemodel/pytest/truncate.py +1 -1
- activemodel/query_wrapper.py +24 -10
- activemodel/session_manager.py +92 -5
- activemodel/types/__init__.py +1 -0
- activemodel/types/typeid.py +141 -5
- activemodel/utils.py +51 -1
- activemodel-0.8.0.dist-info/METADATA +282 -0
- activemodel-0.8.0.dist-info/RECORD +24 -0
- {activemodel-0.5.0.dist-info → activemodel-0.8.0.dist-info}/WHEEL +1 -2
- activemodel/_session_manager.py +0 -153
- activemodel-0.5.0.dist-info/METADATA +0 -66
- activemodel-0.5.0.dist-info/RECORD +0 -20
- activemodel-0.5.0.dist-info/top_level.txt +0 -1
- {activemodel-0.5.0.dist-info → activemodel-0.8.0.dist-info}/entry_points.txt +0 -0
- {activemodel-0.5.0.dist-info → activemodel-0.8.0.dist-info/licenses}/LICENSE +0 -0
activemodel/__init__.py
CHANGED
activemodel/base_model.py
CHANGED
@@ -1,33 +1,41 @@
 import json
 import typing as t
-from
+from uuid import UUID
 
 import pydash
 import sqlalchemy as sa
 import sqlmodel as sm
 from sqlalchemy import Connection, event
 from sqlalchemy.orm import Mapper, declared_attr
-from sqlmodel import Session, SQLModel, select
+from sqlmodel import Field, MetaData, Session, SQLModel, select
 from typeid import TypeID
+from inspect import isclass
 
+from activemodel.mixins.pydantic_json import PydanticJSONMixin
+
+# NOTE: this patches a core method in sqlmodel to support db comments
+from . import get_column_from_field_patch  # noqa: F401
 from .logger import logger
 from .query_wrapper import QueryWrapper
 from .session_manager import get_session
 
+POSTGRES_INDEXES_NAMING_CONVENTION = {
+    "ix": "%(column_0_label)s_idx",
+    "uq": "%(table_name)s_%(column_0_name)s_key",
+    "ck": "%(table_name)s_%(constraint_name)s_check",
+    "fk": "%(table_name)s_%(column_0_name)s_fkey",
+    "pk": "%(table_name)s_pkey",
+}
+"""
+By default, the foreign key naming convention in sqlalchemy do not create unique identifiers when there are multiple
+foreign keys in a table. This naming convention is a workaround to fix this issue:
 
-
-
-
-
-    Helper class to ease validation in SQLModel classes with table=True
-    """
+- https://github.com/zhanymkanov/fastapi-best-practices?tab=readme-ov-file#set-db-keys-naming-conventions
+- https://github.com/fastapi/sqlmodel/discussions/1213
+- Implementation lifted from: https://github.com/AlexanderZharyuk/billing-service/blob/3c8aaf19ab7546b97cc4db76f60335edec9fc79d/src/models.py#L24
+"""
 
-
-    def create(cls, **kwargs):
-        """
-        Forces validation to take place, even for SQLModel classes with table=True
-        """
-        return cls(**cls.__bases__[0](**kwargs).model_dump())
+SQLModel.metadata.naming_convention = POSTGRES_INDEXES_NAMING_CONVENTION
 
 
 class BaseModel(SQLModel):
@@ -36,10 +44,14 @@ class BaseModel(SQLModel):
 
     https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
 
-    {before,after} hooks are modeled after Rails.
+    - {before,after} lifecycle hooks are modeled after Rails.
+    - class docstrings are converted to table-level comments
+    - save(), delete(), select(), where(), and other easy methods you would expect
+    - Fixes foreign key naming conventions
     """
 
-    #
+    # this is used for table-level comments
+    __table_args__ = None
 
     @classmethod
     def __init_subclass__(cls, **kwargs):
@@ -47,6 +59,14 @@ class BaseModel(SQLModel):
 
         super().__init_subclass__(**kwargs)
 
+        from sqlmodel._compat import set_config_value
+
+        # enables field-level docstrings on the pydantic `description` field, which we then copy into
+        # sa_args, which is persisted to sql table comments
+        set_config_value(model=cls, parameter="use_attribute_docstrings", value=True)
+
+        cls._apply_class_doc()
+
         def event_wrapper(method_name: str):
             """
             This does smart heavy lifting for us to make sqlalchemy lifecycle events nicer to work with:
@@ -96,11 +116,37 @@ class BaseModel(SQLModel):
         event.listen(cls, "after_insert", event_wrapper("after_save"))
         event.listen(cls, "after_update", event_wrapper("after_save"))
 
+    # def foreign_key()
+    # table.id
+
+    @classmethod
+    def _apply_class_doc(cls):
+        """
+        Pull class-level docstring into a table comment.
+
+        This will help AI SQL writers like: https://github.com/iloveitaly/sql-ai-prompt-generator
+        """
+
+        doc = cls.__doc__.strip() if cls.__doc__ else None
+
+        if doc:
+            table_args = getattr(cls, "__table_args__", None)
+
+            if table_args is None:
+                cls.__table_args__ = {"comment": doc}
+            elif isinstance(table_args, dict):
+                table_args.setdefault("comment", doc)
+            else:
+                raise ValueError("Unexpected __table_args__ type")
+
     # TODO no type check decorator here
     @declared_attr
     def __tablename__(cls) -> str:
         """
         Automatically generates the table name for the model by converting the class name from camel case to snake case.
+        This is the recommended format for table names:
+
+        https://wiki.postgresql.org/wiki/Don%27t_Do_This#Don.27t_use_upper_case_table_or_column_names
 
         By default, the class is lower cased which makes it harder to read.
 
@@ -111,10 +157,43 @@ class BaseModel(SQLModel):
         """
         return pydash.strings.snake_case(cls.__name__)
 
+    @classmethod
+    def foreign_key(cls, **kwargs):
+        """
+        Returns a `Field` object referencing the foreign key of the model.
+
+        >>> other_model_id: int
+        >>> other_model = OtherModel.foreign_key()
+        """
+
+        field_options = {"nullable": False} | kwargs
+
+        return Field(
+            # TODO id field is hard coded, should pick the PK field in case it's different
+            sa_type=cls.model_fields["id"].sa_column.type,  # type: ignore
+            foreign_key=f"{cls.__tablename__}.id",
+            **field_options,
+        )
+
     @classmethod
     def select(cls, *args):
+        "create a query wrapper to easily run sqlmodel queries on this model"
         return QueryWrapper[cls](cls, *args)
 
+    @classmethod
+    def where(cls, *args):
+        "convenience method to avoid having to write .select().where() in order to add conditions"
+        return cls.select().where(*args)
+
+    def delete(self):
+        with get_session() as session:
+            if old_session := Session.object_session(self):
+                old_session.expunge(self)
+
+            session.delete(self)
+            session.commit()
+            return True
+
     def save(self):
         with get_session() as session:
             if old_session := Session.object_session(self):
@@ -129,11 +208,11 @@ class BaseModel(SQLModel):
             session.commit()
             session.refresh(self)
 
-
+            # Only call the transform method if the class is a subclass of PydanticJSONMixin
+            if issubclass(self.__class__, PydanticJSONMixin):
+                self.__class__.__transform_dict_to_pydantic__(self)
 
-
-            # log.quiet(f"{self} already exists in the database.")
-            # session.rollback()
+        return self
 
     # TODO shouldn't this be handled by pydantic?
     def json(self, **kwargs):
@@ -145,10 +224,16 @@ class BaseModel(SQLModel):
         """
         Returns the number of records in the database.
         """
-
+        with get_session() as session:
+            return session.scalar(sm.select(sm.func.count()).select_from(cls))
+
+    # TODO got to be a better way to fwd these along...
+    @classmethod
+    def first(cls):
+        return cls.select().first()
 
     # TODO throw an error if this field is set on the model
-    def is_new(self):
+    def is_new(self) -> bool:
         return not self._sa_instance_state.has_identity
 
     @classmethod
@@ -182,33 +267,56 @@ class BaseModel(SQLModel):
         new_model = cls(**kwargs)
         return new_model
 
+    @classmethod
+    def primary_key_field(cls):
+        """
+        Returns the primary key column of the model by inspecting SQLAlchemy field information.
+
+        >>> ExampleModel.primary_key_field().name
+        """
+        # TODO note_schema.__class__.__table__.primary_key
+
+        pk_columns = list(cls.__table__.primary_key.columns)
+
+        if not pk_columns:
+            raise ValueError("No primary key defined for the model.")
+
+        if len(pk_columns) > 1:
+            raise ValueError(
+                "Multiple primary keys defined. This method supports only single primary key models."
+            )
+
+        return pk_columns[0]
+
     # TODO what's super dangerous here is you pass a kwarg which does not map to a specific
     # field it will result in `True`, which will return all records, and not give you any typing
     # errors. Dangerous when iterating on structure quickly
     # TODO can we pass the generic of the superclass in?
-
+    # TODO can we type the method signature a bit better?
     # def get(cls, *args: sa.BinaryExpression, **kwargs: t.Any):
+    @classmethod
     def get(cls, *args: t.Any, **kwargs: t.Any):
         """
         Gets a single record from the database. Pass an PK ID or a kwarg to filter by.
         """
 
+        # TODO id is hardcoded, not good! Need to dynamically pick the best uid field
+        id_field_name = "id"
+
         # special case for getting by ID
-        if len(args) == 1 and isinstance(args[0], int):
-
-
-            args = []
-        elif len(args) == 1 and isinstance(args[0], TypeID):
-            kwargs["id"] = args[0]
-            args = []
+        if len(args) == 1 and isinstance(args[0], (int, TypeID, str, UUID)):
+            kwargs[id_field_name] = args[0]
+            args = ()
 
         statement = select(cls).filter(*args).filter_by(**kwargs)
-
+
+        with get_session() as session:
+            return session.exec(statement).first()
 
     @classmethod
     def all(cls):
         with get_session() as session:
-            results = session.exec(
+            results = session.exec(sm.select(cls))
 
         # TODO do we need this or can we just return results?
         for result in results:
@@ -222,7 +330,7 @@ class BaseModel(SQLModel):
         Helpful for testing and console debugging.
         """
 
-        query =
+        query = sm.select(cls).order_by(sa.sql.func.random()).limit(1)
 
         with get_session() as session:
             return session.exec(query).one()
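
Taken together, the new `BaseModel` surface behaves roughly like the sketch below. This is an illustrative usage sketch, not code from the package: the model names, fields, and prefixes are invented, the `TypeIDType` annotation on the foreign key is our choice, and it assumes an engine/session has already been configured via the session manager.

from sqlmodel import Field

from activemodel.base_model import BaseModel
from activemodel.mixins.typeid import TypeIDMixin
from activemodel.types.typeid import TypeIDType


class Team(BaseModel, TypeIDMixin("team"), table=True):
    """Teams that own projects."""  # becomes the table comment via _apply_class_doc

    name: str


class Project(BaseModel, TypeIDMixin("project"), table=True):
    name: str
    # foreign_key() builds a Field pointing at team.id with a matching column type
    team_id: TypeIDType = Team.foreign_key()


team = Team(name="platform").save()            # save() returns self
project = Project(name="migration", team_id=team.id).save()

Project.count()                                # session-backed count
Project.get(project.id)                        # single positional PK lookup
Project.where(Project.name == "migration")     # QueryWrapper with a where clause applied
project.delete()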
activemodel/celery.py
ADDED
@@ -0,0 +1,33 @@
+"""
+Do not import unless you have Celery/Kombu installed.
+
+In order for TypeID objects to be properly handled by celery, a custom encoder must be registered.
+"""
+
+# this is not an explicit dependency, only import this file if you have Celery installed
+from kombu.utils.json import register_type
+from typeid import TypeID
+
+
+def register_celery_typeid_encoder():
+    """
+    Ensures TypeID objects passed as arguments to a delayed function are properly serialized.
+
+    Run at the top of your celery initialization script.
+    """
+
+    def class_full_name(clz) -> str:
+        return ".".join([clz.__module__, clz.__qualname__])
+
+    def _encoder(obj: TypeID) -> str:
+        return str(obj)
+
+    def _decoder(data: str) -> TypeID:
+        return TypeID.from_string(data)
+
+    register_type(
+        TypeID,
+        class_full_name(TypeID),
+        encoder=_encoder,
+        decoder=_decoder,
+    )
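
A sketch of how the new helper is meant to be wired in. The app name, broker URL, and task below are placeholders; only `register_celery_typeid_encoder` comes from the package.

from celery import Celery

from activemodel.celery import register_celery_typeid_encoder

# register the TypeID encoder/decoder with kombu before tasks are defined or dispatched
register_celery_typeid_encoder()

app = Celery("worker", broker="redis://localhost:6379/0")


@app.task
def sync_record(record_id):
    # a TypeID passed to sync_record.delay(...) is serialized to a string by the
    # registered encoder and decoded back into a TypeID on the worker side
    ...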
activemodel/errors.py
ADDED
activemodel/get_column_from_field_patch.py
ADDED
@@ -0,0 +1,139 @@
+"""
+Pydantic has a great DX for adding docstrings to fields. This allows devs to easily document the fields of a model.
+
+Making sure these docstrings make their way to the DB schema is helpful for a bunch of reasons (LLM understanding being one of them).
+
+This patch mutates a core sqlmodel function which translates pydantic FieldInfo objects into sqlalchemy Column objects. It adds the field description as a comment to the column.
+
+Note that FieldInfo *from pydantic* is used when a "bare" field is defined. This can be confusing, because when inspecting model fields, the class name looks exactly the same.
+
+Some ideas for this originally sourced from: https://github.com/fastapi/sqlmodel/issues/492#issuecomment-2489858633
+"""
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Sequence,
+    cast,
+)
+
+import sqlmodel
+from pydantic.fields import FieldInfo as PydanticFieldInfo
+from sqlalchemy import (
+    Column,
+    ForeignKey,
+)
+from sqlmodel._compat import (  # type: ignore[attr-defined]
+    IS_PYDANTIC_V2,
+    ModelMetaclass,
+    Representation,
+    Undefined,
+    UndefinedType,
+    is_field_noneable,
+)
+from sqlmodel.main import FieldInfo, get_sqlalchemy_type
+
+from activemodel.utils import hash_function_code
+
+if TYPE_CHECKING:
+    from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
+    from pydantic._internal._repr import Representation as Representation
+    from pydantic_core import PydanticUndefined as Undefined
+    from pydantic_core import PydanticUndefinedType as UndefinedType
+
+
+assert (
+    hash_function_code(sqlmodel.main.get_column_from_field)
+    == "398006ef8fd8da191ca1a271ef25b6e135da0f400a80df2f29526d8674f9ec51"
+)
+
+
+def get_column_from_field(field: PydanticFieldInfo | FieldInfo) -> Column:  # type: ignore
+    """
+    Takes a field definition, which can either come from the sqlmodel FieldInfo class or the pydantic variant of that class,
+    and converts it into a sqlalchemy Column object.
+    """
+    if IS_PYDANTIC_V2:
+        field_info = field
+    else:
+        field_info = field.field_info
+
+    sa_column = getattr(field_info, "sa_column", Undefined)
+    if isinstance(sa_column, Column):
+        # IMPORTANT: change from the original function
+        if not sa_column.comment and (field_comment := field_info.description):
+            sa_column.comment = field_comment
+        return sa_column
+
+    primary_key = getattr(field_info, "primary_key", Undefined)
+    if primary_key is Undefined:
+        primary_key = False
+
+    index = getattr(field_info, "index", Undefined)
+    if index is Undefined:
+        index = False
+
+    nullable = not primary_key and is_field_noneable(field)
+    # Override derived nullability if the nullable property is set explicitly
+    # on the field
+    field_nullable = getattr(field_info, "nullable", Undefined)  # noqa: B009
+    if field_nullable is not Undefined:
+        assert not isinstance(field_nullable, UndefinedType)
+        nullable = field_nullable
+    args = []
+    foreign_key = getattr(field_info, "foreign_key", Undefined)
+    if foreign_key is Undefined:
+        foreign_key = None
+    unique = getattr(field_info, "unique", Undefined)
+    if unique is Undefined:
+        unique = False
+    if foreign_key:
+        if field_info.ondelete == "SET NULL" and not nullable:
+            raise RuntimeError('ondelete="SET NULL" requires nullable=True')
+        assert isinstance(foreign_key, str)
+        ondelete = getattr(field_info, "ondelete", Undefined)
+        if ondelete is Undefined:
+            ondelete = None
+        assert isinstance(ondelete, (str, type(None)))  # for typing
+        args.append(ForeignKey(foreign_key, ondelete=ondelete))
+    kwargs = {
+        "primary_key": primary_key,
+        "nullable": nullable,
+        "index": index,
+        "unique": unique,
+    }
+
+    sa_default = Undefined
+    if field_info.default_factory:
+        sa_default = field_info.default_factory
+    elif field_info.default is not Undefined:
+        sa_default = field_info.default
+    if sa_default is not Undefined:
+        kwargs["default"] = sa_default
+
+    sa_column_args = getattr(field_info, "sa_column_args", Undefined)
+    if sa_column_args is not Undefined:
+        args.extend(list(cast(Sequence[Any], sa_column_args)))
+
+    sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
+
+    # IMPORTANT: change from the original function
+    if field_info.description:
+        if sa_column_kwargs is Undefined:
+            sa_column_kwargs = {}
+
+        assert isinstance(sa_column_kwargs, dict)
+
+        # only update comments if not already set
+        if "comment" not in sa_column_kwargs:
+            sa_column_kwargs["comment"] = field_info.description
+
+    if sa_column_kwargs is not Undefined:
+        kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
+
+    sa_type = get_sqlalchemy_type(field)
+    return Column(sa_type, *args, **kwargs)  # type: ignore
+
+
+sqlmodel.main.get_column_from_field = get_column_from_field
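
With this patch imported (base_model.py imports it automatically), field docstrings surface as column comments. An illustrative sketch, assuming the attribute-docstring config set in `__init_subclass__` above; the model and field are invented:

from sqlmodel import Field

from activemodel.base_model import BaseModel
from activemodel.mixins.typeid import TypeIDMixin


class Invoice(BaseModel, TypeIDMixin("invoice"), table=True):
    """Invoices issued to customers."""  # class docstring -> table-level comment

    total_cents: int = Field(default=0)
    """Total amount in cents, including tax."""
    # the patched get_column_from_field copies this description into the
    # column's SQL comment, so it is visible in the database schema itself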
activemodel/mixins/__init__.py
CHANGED
activemodel/mixins/pydantic_json.py
ADDED
@@ -0,0 +1,82 @@
+"""
+https://github.com/fastapi/sqlmodel/issues/63
+"""
+
+from types import UnionType
+from typing import get_args, get_origin
+
+from pydantic import BaseModel as PydanticBaseModel
+from sqlalchemy.orm import reconstructor
+
+
+class PydanticJSONMixin:
+    """
+    By default, SQLModel does not convert JSONB columns into pydantic models when they are loaded from the database.
+
+    This mixin, combined with a custom serializer (`_serialize_pydantic_model`), fixes that issue.
+
+    >>> class ExampleWithJSON(BaseModel, PydanticJSONMixin, table=True):
+    >>>     list_field: list[SubObject] = Field(sa_type=JSONB()
+    """
+
+    @reconstructor
+    def __transform_dict_to_pydantic__(self):
+        """
+        Transforms dictionary fields into Pydantic models upon loading.
+
+        - Reconstructor only runs once, when the object is loaded.
+        - We manually call this method on save(), etc to ensure the pydantic types are maintained
+        """
+        # TODO do we need to inspect sa_type
+        for field_name, field_info in self.model_fields.items():
+            raw_value = getattr(self, field_name, None)
+
+            if raw_value is None:
+                continue
+
+            annotation = field_info.annotation
+            origin = get_origin(annotation)
+
+            # e.g. `dict` or `dict[str, str]`, we don't want to do anything with these
+            if origin is dict:
+                continue
+
+            annotation_args = get_args(annotation)
+            is_top_level_list = origin is list
+
+            # if origin is not None:
+            #     assert annotation.__class__ == origin
+
+            model_cls = annotation
+
+            # e.g. SomePydanticModel | None or list[SomePydanticModel] | None
+            # annotation_args are (type, NoneType) in this case
+            if isinstance(annotation, UnionType):
+                non_none_types = [t for t in annotation_args if t is not type(None)]
+
+                if len(non_none_types) == 1:
+                    model_cls = non_none_types[0]
+
+                    # e.g. list[SomePydanticModel] | None, we have to unpack it
+                    # model_cls will print as a list, but it contains a subtype if you dig into it
+                    if (
+                        get_origin(model_cls) is list
+                        and len(list_annotation_args := get_args(model_cls)) == 1
+                    ):
+                        model_cls = list_annotation_args[0]
+                        is_top_level_list = True
+
+            # e.g. list[SomePydanticModel] or list[SomePydanticModel] | None
+            # iterate through the list and run each item through the pydantic model
+            if is_top_level_list:
+                if isinstance(raw_value, list) and issubclass(
+                    model_cls, PydanticBaseModel
+                ):
+                    parsed_value = [model_cls(**item) for item in raw_value]
+                    setattr(self, field_name, parsed_value)
+
+                continue
+
+            # single class
+            if issubclass(model_cls, PydanticBaseModel):
+                setattr(self, field_name, model_cls(**raw_value))
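
A usage sketch for the mixin; the model and field names are illustrative and it assumes a Postgres JSONB column, as in the mixin's own docstring:

from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field

from activemodel.base_model import BaseModel
from activemodel.mixins.pydantic_json import PydanticJSONMixin
from activemodel.mixins.typeid import TypeIDMixin


class LineItem(PydanticBaseModel):
    description: str
    amount_cents: int


class Order(BaseModel, PydanticJSONMixin, TypeIDMixin("order"), table=True):
    # stored as JSONB; the reconstructor (and save()) rebuilds list[LineItem]
    # from the raw dicts that come back from the database
    line_items: list[LineItem] | None = Field(default=None, sa_type=JSONB())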
activemodel/mixins/soft_delete.py
ADDED
@@ -0,0 +1,17 @@
+from datetime import datetime
+
+import sqlalchemy as sa
+from sqlmodel import Field
+
+
+class SoftDeletionMixin:
+    deleted_at: datetime = Field(
+        default=None,
+        nullable=True,
+        # TODO https://github.com/fastapi/sqlmodel/discussions/1228
+        sa_type=sa.DateTime(timezone=True),  # type: ignore
+    )
+
+    def soft_delete(self):
+        self.deleted_at = datetime.now()
+        raise NotImplementedError("Soft deletion is not implemented")
activemodel/mixins/typeid.py
CHANGED
@@ -1,36 +1,46 @@
-import uuid
-
 from sqlmodel import Column, Field
 from typeid import TypeID
 
 from activemodel.types.typeid import TypeIDType
 
+# global list of prefixes to ensure uniqueness
+_prefixes = []
+
 
 def TypeIDMixin(prefix: str):
+    assert prefix
+    assert prefix not in _prefixes, (
+        f"prefix {prefix} already exists, pick a different one"
+    )
+
     class _TypeIDMixin:
-        id:
-            sa_column=Column(TypeIDType(prefix), primary_key=True),
+        id: TypeIDType = Field(
+            sa_column=Column(TypeIDType(prefix), primary_key=True, nullable=False),
             default_factory=lambda: TypeID(prefix),
         )
 
+    _prefixes.append(prefix)
+
     return _TypeIDMixin
 
 
-class
-
-
+# TODO not sure if I love the idea of a dynamic class for each mixin as used above
+# may give this approach another shot in the future
+# class TypeIDMixin2:
+#     """
+#     Mixin class that adds a TypeID primary key to models.
 
 
-
-
+#     >>> class MyModel(BaseModel, TypeIDMixin, prefix="xyz", table=True):
+#     >>>     name: str
 
-
-
+#     Will automatically have an `id` field with prefix "xyz"
+#     """
 
-
-
+#     def __init_subclass__(cls, *, prefix: str, **kwargs):
+#         super().__init_subclass__(**kwargs)
 
-
-
-
-
+#         cls.id: uuid.UUID = Field(
+#             sa_column=Column(TypeIDType(prefix), primary_key=True),
+#             default_factory=lambda: TypeID(prefix),
+#         )