activemodel 0.3.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- activemodel/__init__.py +2 -4
- activemodel/base_model.py +207 -40
- activemodel/celery.py +28 -0
- activemodel/errors.py +6 -0
- activemodel/get_column_from_field_patch.py +137 -0
- activemodel/logger.py +3 -0
- activemodel/mixins/__init__.py +4 -0
- activemodel/mixins/pydantic_json.py +69 -0
- activemodel/mixins/soft_delete.py +17 -0
- activemodel/{timestamps.py → mixins/timestamps.py} +3 -4
- activemodel/mixins/typeid.py +46 -0
- activemodel/pytest/__init__.py +2 -0
- activemodel/pytest/transaction.py +51 -0
- activemodel/pytest/truncate.py +46 -0
- activemodel/query_wrapper.py +23 -17
- activemodel/session_manager.py +132 -0
- activemodel/types/__init__.py +1 -0
- activemodel/types/typeid.py +191 -0
- activemodel/utils.py +65 -0
- activemodel-0.7.0.dist-info/METADATA +235 -0
- activemodel-0.7.0.dist-info/RECORD +24 -0
- {activemodel-0.3.0.dist-info → activemodel-0.7.0.dist-info}/WHEEL +1 -2
- activemodel-0.3.0.dist-info/METADATA +0 -34
- activemodel-0.3.0.dist-info/RECORD +0 -10
- activemodel-0.3.0.dist-info/top_level.txt +0 -1
- {activemodel-0.3.0.dist-info → activemodel-0.7.0.dist-info}/entry_points.txt +0 -0
- {activemodel-0.3.0.dist-info → activemodel-0.7.0.dist-info/licenses}/LICENSE +0 -0
activemodel/__init__.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
1
|
from .base_model import BaseModel
|
|
2
|
-
from .query_wrapper import QueryWrapper
|
|
3
|
-
from .timestamps import TimestampMixin
|
|
4
2
|
|
|
5
|
-
#
|
|
6
|
-
|
|
3
|
+
# from .field import Field
|
|
4
|
+
from .session_manager import SessionManager, get_engine, get_session, init
|
activemodel/base_model.py
CHANGED
|
@@ -1,12 +1,38 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import typing as t
|
|
3
|
+
from uuid import UUID
|
|
3
4
|
|
|
4
5
|
import pydash
|
|
5
6
|
import sqlalchemy as sa
|
|
6
|
-
|
|
7
|
-
from
|
|
8
|
-
|
|
7
|
+
import sqlmodel as sm
|
|
8
|
+
from sqlalchemy import Connection, event
|
|
9
|
+
from sqlalchemy.orm import Mapper, declared_attr
|
|
10
|
+
from sqlmodel import Field, MetaData, Session, SQLModel, select
|
|
11
|
+
from typeid import TypeID
|
|
12
|
+
|
|
13
|
+
# NOTE: this patches a core method in sqlmodel to support db comments
|
|
14
|
+
from . import get_column_from_field_patch # noqa: F401
|
|
15
|
+
from .logger import logger
|
|
9
16
|
from .query_wrapper import QueryWrapper
|
|
17
|
+
from .session_manager import get_session
|
|
18
|
+
|
|
19
|
+
POSTGRES_INDEXES_NAMING_CONVENTION = {
|
|
20
|
+
"ix": "%(column_0_label)s_idx",
|
|
21
|
+
"uq": "%(table_name)s_%(column_0_name)s_key",
|
|
22
|
+
"ck": "%(table_name)s_%(constraint_name)s_check",
|
|
23
|
+
"fk": "%(table_name)s_%(column_0_name)s_fkey",
|
|
24
|
+
"pk": "%(table_name)s_pkey",
|
|
25
|
+
}
|
|
26
|
+
"""
|
|
27
|
+
By default, the foreign key naming convention in sqlalchemy do not create unique identifiers when there are multiple
|
|
28
|
+
foreign keys in a table. This naming convention is a workaround to fix this issue:
|
|
29
|
+
|
|
30
|
+
- https://github.com/zhanymkanov/fastapi-best-practices?tab=readme-ov-file#set-db-keys-naming-conventions
|
|
31
|
+
- https://github.com/fastapi/sqlmodel/discussions/1213
|
|
32
|
+
- Implementation lifted from: https://github.com/AlexanderZharyuk/billing-service/blob/3c8aaf19ab7546b97cc4db76f60335edec9fc79d/src/models.py#L24
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
SQLModel.metadata.naming_convention = POSTGRES_INDEXES_NAMING_CONVENTION
|
|
10
36
|
|
|
11
37
|
|
|
12
38
|
class BaseModel(SQLModel):
|
|
@@ -14,32 +40,110 @@ class BaseModel(SQLModel):
|
|
|
14
40
|
Base model class to inherit from so we can hate python less
|
|
15
41
|
|
|
16
42
|
https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
|
|
43
|
+
|
|
44
|
+
- {before,after} lifecycle hooks are modeled after Rails.
|
|
45
|
+
- class docstrings are converd to table-level comments
|
|
46
|
+
- save(), delete(), select(), where(), and other easy methods you would expect
|
|
47
|
+
- Fixes foreign key naming conventions
|
|
17
48
|
"""
|
|
18
49
|
|
|
19
|
-
#
|
|
50
|
+
# this is used for table-level comments
|
|
51
|
+
__table_args__ = None
|
|
52
|
+
|
|
53
|
+
@classmethod
|
|
54
|
+
def __init_subclass__(cls, **kwargs):
|
|
55
|
+
"Setup automatic sqlalchemy lifecycle events for the class"
|
|
56
|
+
|
|
57
|
+
super().__init_subclass__(**kwargs)
|
|
58
|
+
|
|
59
|
+
from sqlmodel._compat import set_config_value
|
|
60
|
+
|
|
61
|
+
# enables field-level docstrings on the pydanatic `description` field, which we then copy into
|
|
62
|
+
# sa_args, which is persisted to sql table comments
|
|
63
|
+
set_config_value(model=cls, parameter="use_attribute_docstrings", value=True)
|
|
64
|
+
|
|
65
|
+
cls._apply_class_doc()
|
|
66
|
+
|
|
67
|
+
def event_wrapper(method_name: str):
|
|
68
|
+
"""
|
|
69
|
+
This does smart heavy lifting for us to make sqlalchemy lifecycle events nicer to work with:
|
|
70
|
+
|
|
71
|
+
* Passes the target first to the lifecycle method, so it feels like an instance method
|
|
72
|
+
* Allows as little as a single positional argument, so methods can be simple
|
|
73
|
+
* Removes the need for decorators or anything fancy on the subclass
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
def wrapper(mapper: Mapper, connection: Connection, target: BaseModel):
|
|
77
|
+
if hasattr(cls, method_name):
|
|
78
|
+
method = getattr(cls, method_name)
|
|
79
|
+
|
|
80
|
+
if callable(method):
|
|
81
|
+
arg_count = method.__code__.co_argcount
|
|
20
82
|
|
|
21
|
-
|
|
22
|
-
|
|
83
|
+
if arg_count == 1: # Just self/cls
|
|
84
|
+
method(target)
|
|
85
|
+
elif arg_count == 2: # Self, mapper
|
|
86
|
+
method(target, mapper)
|
|
87
|
+
elif arg_count == 3: # Full signature
|
|
88
|
+
method(target, mapper, connection)
|
|
89
|
+
else:
|
|
90
|
+
raise TypeError(
|
|
91
|
+
f"Method {method_name} must accept either 1 to 3 arguments, got {arg_count}"
|
|
92
|
+
)
|
|
93
|
+
else:
|
|
94
|
+
logger.warning(
|
|
95
|
+
"SQLModel lifecycle hook found, but not callable hook_name=%s",
|
|
96
|
+
method_name,
|
|
97
|
+
)
|
|
23
98
|
|
|
24
|
-
|
|
25
|
-
pass
|
|
99
|
+
return wrapper
|
|
26
100
|
|
|
27
|
-
|
|
28
|
-
|
|
101
|
+
event.listen(cls, "before_insert", event_wrapper("before_insert"))
|
|
102
|
+
event.listen(cls, "before_update", event_wrapper("before_update"))
|
|
29
103
|
|
|
30
|
-
|
|
31
|
-
|
|
104
|
+
# before_save maps to two type of events
|
|
105
|
+
event.listen(cls, "before_insert", event_wrapper("before_save"))
|
|
106
|
+
event.listen(cls, "before_update", event_wrapper("before_save"))
|
|
32
107
|
|
|
33
|
-
|
|
34
|
-
|
|
108
|
+
# now, let's handle after_* variants
|
|
109
|
+
event.listen(cls, "after_insert", event_wrapper("after_insert"))
|
|
110
|
+
event.listen(cls, "after_update", event_wrapper("after_update"))
|
|
35
111
|
|
|
36
|
-
|
|
37
|
-
|
|
112
|
+
# after_save maps to two type of events
|
|
113
|
+
event.listen(cls, "after_insert", event_wrapper("after_save"))
|
|
114
|
+
event.listen(cls, "after_update", event_wrapper("after_save"))
|
|
38
115
|
|
|
116
|
+
# def foreign_key()
|
|
117
|
+
# table.id
|
|
118
|
+
|
|
119
|
+
@classmethod
|
|
120
|
+
def _apply_class_doc(cls):
|
|
121
|
+
"""
|
|
122
|
+
Pull class-level docstring into a table comment.
|
|
123
|
+
|
|
124
|
+
This will help AI SQL writers like: https://github.com/iloveitaly/sql-ai-prompt-generator
|
|
125
|
+
"""
|
|
126
|
+
|
|
127
|
+
doc = cls.__doc__.strip() if cls.__doc__ else None
|
|
128
|
+
|
|
129
|
+
if doc:
|
|
130
|
+
table_args = getattr(cls, "__table_args__", None)
|
|
131
|
+
|
|
132
|
+
if table_args is None:
|
|
133
|
+
cls.__table_args__ = {"comment": doc}
|
|
134
|
+
elif isinstance(table_args, dict):
|
|
135
|
+
table_args.setdefault("comment", doc)
|
|
136
|
+
else:
|
|
137
|
+
raise ValueError("Unexpected __table_args__ type")
|
|
138
|
+
|
|
139
|
+
# TODO no type check decorator here
|
|
39
140
|
@declared_attr
|
|
40
141
|
def __tablename__(cls) -> str:
|
|
41
142
|
"""
|
|
42
143
|
Automatically generates the table name for the model by converting the class name from camel case to snake case.
|
|
144
|
+
This is the recommended format for table names:
|
|
145
|
+
|
|
146
|
+
https://wiki.postgresql.org/wiki/Don%27t_Do_This#Don.27t_use_upper_case_table_or_column_names
|
|
43
147
|
|
|
44
148
|
By default, the class is lower cased which makes it harder to read.
|
|
45
149
|
|
|
@@ -50,74 +154,137 @@ class BaseModel(SQLModel):
|
|
|
50
154
|
"""
|
|
51
155
|
return pydash.strings.snake_case(cls.__name__)
|
|
52
156
|
|
|
157
|
+
@classmethod
|
|
158
|
+
def foreign_key(cls, **kwargs):
|
|
159
|
+
"""
|
|
160
|
+
Returns a Field object referencing the foreign key of the model.
|
|
161
|
+
"""
|
|
162
|
+
|
|
163
|
+
field_options = {"nullable": False} | kwargs
|
|
164
|
+
|
|
165
|
+
return Field(
|
|
166
|
+
# TODO id field is hard coded, should pick the PK field in case it's different
|
|
167
|
+
sa_type=cls.model_fields["id"].sa_column.type, # type: ignore
|
|
168
|
+
foreign_key=f"{cls.__tablename__}.id",
|
|
169
|
+
**field_options,
|
|
170
|
+
)
|
|
171
|
+
|
|
53
172
|
@classmethod
|
|
54
173
|
def select(cls, *args):
|
|
174
|
+
"create a query wrapper to easily run sqlmodel queries on this model"
|
|
55
175
|
return QueryWrapper[cls](cls, *args)
|
|
56
176
|
|
|
177
|
+
def delete(self):
|
|
178
|
+
with get_session() as session:
|
|
179
|
+
if old_session := Session.object_session(self):
|
|
180
|
+
old_session.expunge(self)
|
|
181
|
+
|
|
182
|
+
session.delete(self)
|
|
183
|
+
session.commit()
|
|
184
|
+
return True
|
|
185
|
+
|
|
57
186
|
def save(self):
|
|
58
|
-
old_session = Session.object_session(self)
|
|
59
187
|
with get_session() as session:
|
|
60
|
-
if old_session:
|
|
188
|
+
if old_session := Session.object_session(self):
|
|
61
189
|
# I was running into an issue where the object was already
|
|
62
190
|
# associated with a session, but the session had been closed,
|
|
63
191
|
# to get around this, you need to remove it from the old one,
|
|
64
192
|
# then add it to the new one (below)
|
|
65
193
|
old_session.expunge(self)
|
|
66
194
|
|
|
67
|
-
self.before_update()
|
|
68
|
-
# self.before_save()
|
|
69
|
-
|
|
70
195
|
session.add(self)
|
|
196
|
+
# NOTE very important method! This triggers sqlalchemy lifecycle hooks automatically
|
|
71
197
|
session.commit()
|
|
72
198
|
session.refresh(self)
|
|
73
199
|
|
|
74
|
-
|
|
75
|
-
# self.after_save()
|
|
76
|
-
|
|
77
|
-
return self
|
|
200
|
+
return self
|
|
78
201
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
202
|
+
# except IntegrityError:
|
|
203
|
+
# log.quiet(f"{self} already exists in the database.")
|
|
204
|
+
# session.rollback()
|
|
82
205
|
|
|
83
206
|
# TODO shouldn't this be handled by pydantic?
|
|
84
207
|
def json(self, **kwargs):
|
|
85
208
|
return json.dumps(self.dict(), default=str, **kwargs)
|
|
86
209
|
|
|
210
|
+
# TODO should move this to the wrapper
|
|
87
211
|
@classmethod
|
|
88
|
-
def count(cls):
|
|
212
|
+
def count(cls) -> int:
|
|
89
213
|
"""
|
|
90
214
|
Returns the number of records in the database.
|
|
91
215
|
"""
|
|
92
|
-
# TODO should move this to the wrapper
|
|
93
216
|
with get_session() as session:
|
|
94
|
-
|
|
95
|
-
|
|
217
|
+
return session.scalar(sm.select(sm.func.count()).select_from(cls))
|
|
218
|
+
|
|
219
|
+
# TODO got to be a better way to fwd these along...
|
|
220
|
+
@classmethod
|
|
221
|
+
def first(cls):
|
|
222
|
+
return cls.select().first()
|
|
223
|
+
|
|
224
|
+
# TODO throw an error if this field is set on the model
|
|
225
|
+
def is_new(self) -> bool:
|
|
226
|
+
return not self._sa_instance_state.has_identity
|
|
227
|
+
|
|
228
|
+
@classmethod
|
|
229
|
+
def find_or_create_by(cls, **kwargs):
|
|
230
|
+
"""
|
|
231
|
+
Find record or create it with the passed args if it doesn't exist.
|
|
232
|
+
"""
|
|
233
|
+
|
|
234
|
+
result = cls.get(**kwargs)
|
|
235
|
+
|
|
236
|
+
if result:
|
|
237
|
+
return result
|
|
238
|
+
|
|
239
|
+
new_model = cls(**kwargs)
|
|
240
|
+
new_model.save()
|
|
241
|
+
|
|
242
|
+
return new_model
|
|
243
|
+
|
|
244
|
+
@classmethod
|
|
245
|
+
def find_or_initialize_by(cls, **kwargs):
|
|
246
|
+
"""
|
|
247
|
+
Unfortunately, unlike ruby, python does not have a great lambda story. This makes writing convenience methods
|
|
248
|
+
like this a bit more difficult.
|
|
249
|
+
"""
|
|
250
|
+
|
|
251
|
+
result = cls.get(**kwargs)
|
|
252
|
+
|
|
253
|
+
if result:
|
|
254
|
+
return result
|
|
255
|
+
|
|
256
|
+
new_model = cls(**kwargs)
|
|
257
|
+
return new_model
|
|
96
258
|
|
|
97
259
|
# TODO what's super dangerous here is you pass a kwarg which does not map to a specific
|
|
98
260
|
# field it will result in `True`, which will return all records, and not give you any typing
|
|
99
261
|
# errors. Dangerous when iterating on structure quickly
|
|
100
262
|
# TODO can we pass the generic of the superclass in?
|
|
263
|
+
# TODO can we type the method signature a bit better?
|
|
264
|
+
# def get(cls, *args: sa.BinaryExpression, **kwargs: t.Any):
|
|
101
265
|
@classmethod
|
|
102
|
-
def get(cls, *args:
|
|
266
|
+
def get(cls, *args: t.Any, **kwargs: t.Any):
|
|
103
267
|
"""
|
|
104
268
|
Gets a single record from the database. Pass an PK ID or a kwarg to filter by.
|
|
105
269
|
"""
|
|
106
270
|
|
|
271
|
+
# TODO id is hardcoded, not good! Need to dynamically pick the best uid field
|
|
272
|
+
id_field_name = "id"
|
|
273
|
+
|
|
107
274
|
# special case for getting by ID
|
|
108
|
-
if len(args) == 1 and isinstance(args[0], int):
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
275
|
+
if len(args) == 1 and isinstance(args[0], (int, TypeID, str, UUID)):
|
|
276
|
+
kwargs[id_field_name] = args[0]
|
|
277
|
+
args = ()
|
|
278
|
+
|
|
279
|
+
statement = select(cls).filter(*args).filter_by(**kwargs)
|
|
112
280
|
|
|
113
|
-
statement = sql.select(cls).filter(*args).filter_by(**kwargs)
|
|
114
281
|
with get_session() as session:
|
|
115
282
|
return session.exec(statement).first()
|
|
116
283
|
|
|
117
284
|
@classmethod
|
|
118
285
|
def all(cls):
|
|
119
286
|
with get_session() as session:
|
|
120
|
-
results = session.exec(
|
|
287
|
+
results = session.exec(sm.select(cls))
|
|
121
288
|
|
|
122
289
|
# TODO do we need this or can we just return results?
|
|
123
290
|
for result in results:
|
|
@@ -131,7 +298,7 @@ class BaseModel(SQLModel):
|
|
|
131
298
|
Helpful for testing and console debugging.
|
|
132
299
|
"""
|
|
133
300
|
|
|
134
|
-
query =
|
|
301
|
+
query = sm.select(cls).order_by(sa.sql.func.random()).limit(1)
|
|
135
302
|
|
|
136
303
|
with get_session() as session:
|
|
137
304
|
return session.exec(query).one()
|
activemodel/celery.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Do not import unless you have Celery/Kombu installed.
|
|
3
|
+
|
|
4
|
+
In order for TypeID objects to be properly handled by celery, a custom encoder must be registered.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from kombu.utils.json import register_type
|
|
8
|
+
from typeid import TypeID
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def register_celery_typeid_encoder():
|
|
12
|
+
"this ensures TypeID objects passed as arguments to a delayed function are properly serialized"
|
|
13
|
+
|
|
14
|
+
def class_full_name(clz) -> str:
|
|
15
|
+
return ".".join([clz.__module__, clz.__qualname__])
|
|
16
|
+
|
|
17
|
+
def _encoder(obj: TypeID) -> str:
|
|
18
|
+
return str(obj)
|
|
19
|
+
|
|
20
|
+
def _decoder(data: str) -> TypeID:
|
|
21
|
+
return TypeID.from_string(data)
|
|
22
|
+
|
|
23
|
+
register_type(
|
|
24
|
+
TypeID,
|
|
25
|
+
class_full_name(TypeID),
|
|
26
|
+
encoder=_encoder,
|
|
27
|
+
decoder=_decoder,
|
|
28
|
+
)
|
activemodel/errors.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pydantic has a great DX for adding docstrings to fields. This allows devs to easily document the fields of a model.
|
|
3
|
+
|
|
4
|
+
Making sure these docstrings make their way to the DB schema is helpful for a bunch of reasons (LLM understanding being one of them).
|
|
5
|
+
|
|
6
|
+
This patch mutates a core sqlmodel function which translates pydantic FieldInfo objects into sqlalchemy Column objects. It adds the field description as a comment to the column.
|
|
7
|
+
|
|
8
|
+
Note that FieldInfo *from pydantic* is used when a "bare" field is defined. This can be confusing, because when inspecting model fields, the class name looks exactly the same.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import (
|
|
12
|
+
TYPE_CHECKING,
|
|
13
|
+
Any,
|
|
14
|
+
Dict,
|
|
15
|
+
Sequence,
|
|
16
|
+
cast,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
import sqlmodel
|
|
20
|
+
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
|
21
|
+
from sqlalchemy import (
|
|
22
|
+
Column,
|
|
23
|
+
ForeignKey,
|
|
24
|
+
)
|
|
25
|
+
from sqlmodel._compat import ( # type: ignore[attr-defined]
|
|
26
|
+
IS_PYDANTIC_V2,
|
|
27
|
+
ModelMetaclass,
|
|
28
|
+
Representation,
|
|
29
|
+
Undefined,
|
|
30
|
+
UndefinedType,
|
|
31
|
+
is_field_noneable,
|
|
32
|
+
)
|
|
33
|
+
from sqlmodel.main import FieldInfo, get_sqlalchemy_type
|
|
34
|
+
|
|
35
|
+
from activemodel.utils import hash_function_code
|
|
36
|
+
|
|
37
|
+
if TYPE_CHECKING:
|
|
38
|
+
from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
|
|
39
|
+
from pydantic._internal._repr import Representation as Representation
|
|
40
|
+
from pydantic_core import PydanticUndefined as Undefined
|
|
41
|
+
from pydantic_core import PydanticUndefinedType as UndefinedType
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
assert (
|
|
45
|
+
hash_function_code(sqlmodel.main.get_column_from_field)
|
|
46
|
+
== "398006ef8fd8da191ca1a271ef25b6e135da0f400a80df2f29526d8674f9ec51"
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def get_column_from_field(field: PydanticFieldInfo | FieldInfo) -> Column: # type: ignore
|
|
51
|
+
"""
|
|
52
|
+
Takes a field definition, which can either come from the sqlmodel FieldInfo class or the pydantic variant of that class,
|
|
53
|
+
and converts it into a sqlalchemy Column object.
|
|
54
|
+
"""
|
|
55
|
+
if IS_PYDANTIC_V2:
|
|
56
|
+
field_info = field
|
|
57
|
+
else:
|
|
58
|
+
field_info = field.field_info
|
|
59
|
+
|
|
60
|
+
sa_column = getattr(field_info, "sa_column", Undefined)
|
|
61
|
+
if isinstance(sa_column, Column):
|
|
62
|
+
# IMPORTANT: change from the original function
|
|
63
|
+
if not sa_column.comment and (field_comment := field_info.description):
|
|
64
|
+
sa_column.comment = field_comment
|
|
65
|
+
return sa_column
|
|
66
|
+
|
|
67
|
+
primary_key = getattr(field_info, "primary_key", Undefined)
|
|
68
|
+
if primary_key is Undefined:
|
|
69
|
+
primary_key = False
|
|
70
|
+
|
|
71
|
+
index = getattr(field_info, "index", Undefined)
|
|
72
|
+
if index is Undefined:
|
|
73
|
+
index = False
|
|
74
|
+
|
|
75
|
+
nullable = not primary_key and is_field_noneable(field)
|
|
76
|
+
# Override derived nullability if the nullable property is set explicitly
|
|
77
|
+
# on the field
|
|
78
|
+
field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009
|
|
79
|
+
if field_nullable is not Undefined:
|
|
80
|
+
assert not isinstance(field_nullable, UndefinedType)
|
|
81
|
+
nullable = field_nullable
|
|
82
|
+
args = []
|
|
83
|
+
foreign_key = getattr(field_info, "foreign_key", Undefined)
|
|
84
|
+
if foreign_key is Undefined:
|
|
85
|
+
foreign_key = None
|
|
86
|
+
unique = getattr(field_info, "unique", Undefined)
|
|
87
|
+
if unique is Undefined:
|
|
88
|
+
unique = False
|
|
89
|
+
if foreign_key:
|
|
90
|
+
if field_info.ondelete == "SET NULL" and not nullable:
|
|
91
|
+
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
|
|
92
|
+
assert isinstance(foreign_key, str)
|
|
93
|
+
ondelete = getattr(field_info, "ondelete", Undefined)
|
|
94
|
+
if ondelete is Undefined:
|
|
95
|
+
ondelete = None
|
|
96
|
+
assert isinstance(ondelete, (str, type(None))) # for typing
|
|
97
|
+
args.append(ForeignKey(foreign_key, ondelete=ondelete))
|
|
98
|
+
kwargs = {
|
|
99
|
+
"primary_key": primary_key,
|
|
100
|
+
"nullable": nullable,
|
|
101
|
+
"index": index,
|
|
102
|
+
"unique": unique,
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
sa_default = Undefined
|
|
106
|
+
if field_info.default_factory:
|
|
107
|
+
sa_default = field_info.default_factory
|
|
108
|
+
elif field_info.default is not Undefined:
|
|
109
|
+
sa_default = field_info.default
|
|
110
|
+
if sa_default is not Undefined:
|
|
111
|
+
kwargs["default"] = sa_default
|
|
112
|
+
|
|
113
|
+
sa_column_args = getattr(field_info, "sa_column_args", Undefined)
|
|
114
|
+
if sa_column_args is not Undefined:
|
|
115
|
+
args.extend(list(cast(Sequence[Any], sa_column_args)))
|
|
116
|
+
|
|
117
|
+
sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
|
|
118
|
+
|
|
119
|
+
# IMPORTANT: change from the original function
|
|
120
|
+
if field_info.description:
|
|
121
|
+
if sa_column_kwargs is Undefined:
|
|
122
|
+
sa_column_kwargs = {}
|
|
123
|
+
|
|
124
|
+
assert isinstance(sa_column_kwargs, dict)
|
|
125
|
+
|
|
126
|
+
# only update comments if not already set
|
|
127
|
+
if "comment" not in sa_column_kwargs:
|
|
128
|
+
sa_column_kwargs["comment"] = field_info.description
|
|
129
|
+
|
|
130
|
+
if sa_column_kwargs is not Undefined:
|
|
131
|
+
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
|
|
132
|
+
|
|
133
|
+
sa_type = get_sqlalchemy_type(field)
|
|
134
|
+
return Column(sa_type, *args, **kwargs) # type: ignore
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
sqlmodel.main.get_column_from_field = get_column_from_field
|
activemodel/logger.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from types import UnionType
|
|
2
|
+
from typing import get_args, get_origin
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel as PydanticBaseModel
|
|
5
|
+
from sqlalchemy.orm import reconstructor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PydanticJSONMixin:
|
|
9
|
+
"""
|
|
10
|
+
By default, SQLModel does not convert JSONB columns into pydantic models when they are loaded from the database.
|
|
11
|
+
|
|
12
|
+
This mixin, combined with a custom serializer, fixes that issue.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
@reconstructor
|
|
16
|
+
def init_on_load(self):
|
|
17
|
+
# TODO do we need to inspect sa_type
|
|
18
|
+
for field_name, field_info in self.model_fields.items():
|
|
19
|
+
raw_value = getattr(self, field_name, None)
|
|
20
|
+
|
|
21
|
+
if raw_value is None:
|
|
22
|
+
continue
|
|
23
|
+
|
|
24
|
+
annotation = field_info.annotation
|
|
25
|
+
origin = get_origin(annotation)
|
|
26
|
+
|
|
27
|
+
# e.g. `dict` or `dict[str, str]`, we don't want to do anything with these
|
|
28
|
+
if origin is dict:
|
|
29
|
+
continue
|
|
30
|
+
|
|
31
|
+
annotation_args = get_args(annotation)
|
|
32
|
+
is_top_level_list = origin is list
|
|
33
|
+
|
|
34
|
+
# if origin is not None:
|
|
35
|
+
# assert annotation.__class__ == origin
|
|
36
|
+
|
|
37
|
+
model_cls = annotation
|
|
38
|
+
|
|
39
|
+
# e.g. SomePydanticModel | None or list[SomePydanticModel] | None
|
|
40
|
+
# annotation_args are (type, NoneType) in this case
|
|
41
|
+
if isinstance(annotation, UnionType):
|
|
42
|
+
non_none_types = [t for t in annotation_args if t is not type(None)]
|
|
43
|
+
|
|
44
|
+
if len(non_none_types) == 1:
|
|
45
|
+
model_cls = non_none_types[0]
|
|
46
|
+
|
|
47
|
+
# e.g. list[SomePydanticModel] | None, we have to unpack it
|
|
48
|
+
# model_cls will print as a list, but it contains a subtype if you dig into it
|
|
49
|
+
if (
|
|
50
|
+
get_origin(model_cls) is list
|
|
51
|
+
and len(list_annotation_args := get_args(model_cls)) == 1
|
|
52
|
+
):
|
|
53
|
+
model_cls = list_annotation_args[0]
|
|
54
|
+
is_top_level_list = True
|
|
55
|
+
|
|
56
|
+
# e.g. list[SomePydanticModel] or list[SomePydanticModel] | None
|
|
57
|
+
# iterate through the list and run each item through the pydantic model
|
|
58
|
+
if is_top_level_list:
|
|
59
|
+
if isinstance(raw_value, list) and issubclass(
|
|
60
|
+
model_cls, PydanticBaseModel
|
|
61
|
+
):
|
|
62
|
+
parsed_value = [model_cls(**item) for item in raw_value]
|
|
63
|
+
setattr(self, field_name, parsed_value)
|
|
64
|
+
|
|
65
|
+
continue
|
|
66
|
+
|
|
67
|
+
# single class
|
|
68
|
+
if issubclass(model_cls, PydanticBaseModel):
|
|
69
|
+
setattr(self, field_name, model_cls(**raw_value))
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
|
|
3
|
+
import sqlalchemy as sa
|
|
4
|
+
from sqlmodel import Field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SoftDeletionMixin:
|
|
8
|
+
deleted_at: datetime = Field(
|
|
9
|
+
default=None,
|
|
10
|
+
nullable=True,
|
|
11
|
+
# TODO https://github.com/fastapi/sqlmodel/discussions/1228
|
|
12
|
+
sa_type=sa.DateTime(timezone=True), # type: ignore
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
def soft_delete(self):
|
|
16
|
+
self.deleted_at = datetime.now()
|
|
17
|
+
raise NotImplementedError("Soft deletion is not implemented")
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
|
|
3
3
|
import sqlalchemy as sa
|
|
4
|
-
|
|
5
|
-
# TODO raw sql https://github.com/tiangolo/sqlmodel/discussions/772
|
|
6
4
|
from sqlmodel import Field
|
|
7
5
|
|
|
6
|
+
# TODO raw sql https://github.com/tiangolo/sqlmodel/discussions/772
|
|
8
7
|
# @classmethod
|
|
9
8
|
# def select(cls):
|
|
10
9
|
# with get_session() as session:
|
|
@@ -14,11 +13,11 @@ from sqlmodel import Field
|
|
|
14
13
|
# yield result
|
|
15
14
|
|
|
16
15
|
|
|
17
|
-
class
|
|
16
|
+
class TimestampsMixin:
|
|
18
17
|
"""
|
|
19
18
|
Simple created at and updated at timestamps. Mix them into your model:
|
|
20
19
|
|
|
21
|
-
>>> class MyModel(
|
|
20
|
+
>>> class MyModel(TimestampsMixin, SQLModel):
|
|
22
21
|
>>> pass
|
|
23
22
|
|
|
24
23
|
Originally pulled from: https://github.com/tiangolo/sqlmodel/issues/252
|