activemodel 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- activemodel/__init__.py +2 -0
- activemodel/base_model.py +105 -29
- activemodel/celery.py +28 -0
- activemodel/errors.py +6 -0
- activemodel/get_column_from_field_patch.py +137 -0
- activemodel/mixins/__init__.py +2 -0
- activemodel/mixins/pydantic_json.py +69 -0
- activemodel/mixins/soft_delete.py +17 -0
- activemodel/mixins/typeid.py +27 -17
- activemodel/pytest/truncate.py +1 -1
- activemodel/query_wrapper.py +24 -10
- activemodel/session_manager.py +75 -5
- activemodel/types/__init__.py +1 -0
- activemodel/types/typeid.py +140 -5
- activemodel/utils.py +51 -1
- activemodel-0.7.0.dist-info/METADATA +235 -0
- activemodel-0.7.0.dist-info/RECORD +24 -0
- {activemodel-0.5.0.dist-info → activemodel-0.7.0.dist-info}/WHEEL +1 -2
- activemodel/_session_manager.py +0 -153
- activemodel-0.5.0.dist-info/METADATA +0 -66
- activemodel-0.5.0.dist-info/RECORD +0 -20
- activemodel-0.5.0.dist-info/top_level.txt +0 -1
- {activemodel-0.5.0.dist-info → activemodel-0.7.0.dist-info}/entry_points.txt +0 -0
- {activemodel-0.5.0.dist-info → activemodel-0.7.0.dist-info/licenses}/LICENSE +0 -0
activemodel/__init__.py
CHANGED
activemodel/base_model.py
CHANGED
|
@@ -1,33 +1,38 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import typing as t
|
|
3
|
-
from
|
|
3
|
+
from uuid import UUID
|
|
4
4
|
|
|
5
5
|
import pydash
|
|
6
6
|
import sqlalchemy as sa
|
|
7
7
|
import sqlmodel as sm
|
|
8
8
|
from sqlalchemy import Connection, event
|
|
9
9
|
from sqlalchemy.orm import Mapper, declared_attr
|
|
10
|
-
from sqlmodel import Session, SQLModel, select
|
|
10
|
+
from sqlmodel import Field, MetaData, Session, SQLModel, select
|
|
11
11
|
from typeid import TypeID
|
|
12
12
|
|
|
13
|
+
# NOTE: this patches a core method in sqlmodel to support db comments
|
|
14
|
+
from . import get_column_from_field_patch # noqa: F401
|
|
13
15
|
from .logger import logger
|
|
14
16
|
from .query_wrapper import QueryWrapper
|
|
15
17
|
from .session_manager import get_session
|
|
16
18
|
|
|
19
|
+
# SQLAlchemy naming-convention templates keyed by constraint kind:
# ix = index, uq = unique, ck = check, fk = foreign key, pk = primary key.
POSTGRES_INDEXES_NAMING_CONVENTION = dict(
    ix="%(column_0_label)s_idx",
    uq="%(table_name)s_%(column_0_name)s_key",
    ck="%(table_name)s_%(constraint_name)s_check",
    fk="%(table_name)s_%(column_0_name)s_fkey",
    pk="%(table_name)s_pkey",
)
|
|
26
|
+
"""
|
|
27
|
+
By default, the foreign key naming convention in sqlalchemy do not create unique identifiers when there are multiple
|
|
28
|
+
foreign keys in a table. This naming convention is a workaround to fix this issue:
|
|
17
29
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
Helper class to ease validation in SQLModel classes with table=True
|
|
23
|
-
"""
|
|
30
|
+
- https://github.com/zhanymkanov/fastapi-best-practices?tab=readme-ov-file#set-db-keys-naming-conventions
|
|
31
|
+
- https://github.com/fastapi/sqlmodel/discussions/1213
|
|
32
|
+
- Implementation lifted from: https://github.com/AlexanderZharyuk/billing-service/blob/3c8aaf19ab7546b97cc4db76f60335edec9fc79d/src/models.py#L24
|
|
33
|
+
"""
|
|
24
34
|
|
|
25
|
-
|
|
26
|
-
def create(cls, **kwargs):
|
|
27
|
-
"""
|
|
28
|
-
Forces validation to take place, even for SQLModel classes with table=True
|
|
29
|
-
"""
|
|
30
|
-
return cls(**cls.__bases__[0](**kwargs).model_dump())
|
|
35
|
+
# Install the convention on SQLModel's shared MetaData so constraint/index names generated for tables
# registered on it follow POSTGRES_INDEXES_NAMING_CONVENTION instead of SQLAlchemy's defaults.
SQLModel.metadata.naming_convention = POSTGRES_INDEXES_NAMING_CONVENTION
|
|
31
36
|
|
|
32
37
|
|
|
33
38
|
class BaseModel(SQLModel):
|
|
@@ -36,10 +41,14 @@ class BaseModel(SQLModel):
|
|
|
36
41
|
|
|
37
42
|
https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
|
|
38
43
|
|
|
39
|
-
{before,after} hooks are modeled after Rails.
|
|
44
|
+
- {before,after} lifecycle hooks are modeled after Rails.
|
|
45
|
+
- class docstrings are converd to table-level comments
|
|
46
|
+
- save(), delete(), select(), where(), and other easy methods you would expect
|
|
47
|
+
- Fixes foreign key naming conventions
|
|
40
48
|
"""
|
|
41
49
|
|
|
42
|
-
#
|
|
50
|
+
# this is used for table-level comments
|
|
51
|
+
__table_args__ = None
|
|
43
52
|
|
|
44
53
|
@classmethod
|
|
45
54
|
def __init_subclass__(cls, **kwargs):
|
|
@@ -47,6 +56,14 @@ class BaseModel(SQLModel):
|
|
|
47
56
|
|
|
48
57
|
super().__init_subclass__(**kwargs)
|
|
49
58
|
|
|
59
|
+
from sqlmodel._compat import set_config_value
|
|
60
|
+
|
|
61
|
+
# enables field-level docstrings on the pydanatic `description` field, which we then copy into
|
|
62
|
+
# sa_args, which is persisted to sql table comments
|
|
63
|
+
set_config_value(model=cls, parameter="use_attribute_docstrings", value=True)
|
|
64
|
+
|
|
65
|
+
cls._apply_class_doc()
|
|
66
|
+
|
|
50
67
|
def event_wrapper(method_name: str):
|
|
51
68
|
"""
|
|
52
69
|
This does smart heavy lifting for us to make sqlalchemy lifecycle events nicer to work with:
|
|
@@ -96,11 +113,37 @@ class BaseModel(SQLModel):
|
|
|
96
113
|
event.listen(cls, "after_insert", event_wrapper("after_save"))
|
|
97
114
|
event.listen(cls, "after_update", event_wrapper("after_save"))
|
|
98
115
|
|
|
116
|
+
# def foreign_key()
|
|
117
|
+
# table.id
|
|
118
|
+
|
|
119
|
+
@classmethod
def _apply_class_doc(cls):
    """
    Pull class-level docstring into a table comment.

    This will help AI SQL writers like: https://github.com/iloveitaly/sql-ai-prompt-generator

    A ``comment`` already present in ``__table_args__`` is kept. ``__table_args__`` may be ``None``, a
    dict of table keyword arguments, or a tuple of positional table arguments optionally ending in
    such a dict. The result is always assigned to ``cls`` as a *new* object, so an ``__table_args__``
    inherited from a parent class is never mutated.

    Raises:
        ValueError: if ``__table_args__`` is of any other type.
    """

    doc = cls.__doc__.strip() if cls.__doc__ else None

    if not doc:
        return

    table_args = getattr(cls, "__table_args__", None)

    if table_args is None:
        cls.__table_args__ = {"comment": doc}
    elif isinstance(table_args, dict):
        # copy instead of setdefault(): getattr() may have returned a parent class's dict, and mutating
        # it would stamp this class's docstring onto the parent and every sibling sharing that dict
        cls.__table_args__ = {"comment": doc, **table_args}
    elif isinstance(table_args, tuple):
        # tuple form: positional table args, optionally followed by a dict of keyword args
        if table_args and isinstance(table_args[-1], dict):
            cls.__table_args__ = (*table_args[:-1], {"comment": doc, **table_args[-1]})
        else:
            cls.__table_args__ = (*table_args, {"comment": doc})
    else:
        raise ValueError("Unexpected __table_args__ type")
|
|
138
|
+
|
|
99
139
|
# TODO no type check decorator here
|
|
100
140
|
@declared_attr
|
|
101
141
|
def __tablename__(cls) -> str:
|
|
102
142
|
"""
|
|
103
143
|
Automatically generates the table name for the model by converting the class name from camel case to snake case.
|
|
144
|
+
This is the recommended format for table names:
|
|
145
|
+
|
|
146
|
+
https://wiki.postgresql.org/wiki/Don%27t_Do_This#Don.27t_use_upper_case_table_or_column_names
|
|
104
147
|
|
|
105
148
|
By default, the class is lower cased which makes it harder to read.
|
|
106
149
|
|
|
@@ -111,10 +154,35 @@ class BaseModel(SQLModel):
|
|
|
111
154
|
"""
|
|
112
155
|
return pydash.strings.snake_case(cls.__name__)
|
|
113
156
|
|
|
157
|
+
@classmethod
def foreign_key(cls, **kwargs):
    """
    Returns a Field object referencing the foreign key of the model.

    The field points at ``<tablename>.id``, reuses the SQL type of this model's ``id`` column and
    defaults to ``nullable=False``. Any ``kwargs`` are forwarded to ``Field`` and override those
    defaults.
    """

    field_options = {"nullable": False} | kwargs

    # TODO id field is hard coded, should pick the PK field in case it's different
    id_field = cls.model_fields["id"]

    # ids declared with an explicit `sa_column` (e.g. TypeIDMixin) carry the Column on the field;
    # a plain `Field(primary_key=True)` id may not, so fall back to the column on the mapped table
    id_column = getattr(id_field, "sa_column", None)
    if not isinstance(id_column, sa.Column):
        id_column = cls.__table__.c["id"]  # type: ignore

    return Field(
        sa_type=id_column.type,  # type: ignore
        foreign_key=f"{cls.__tablename__}.id",
        **field_options,
    )
|
|
171
|
+
|
|
114
172
|
@classmethod
def select(cls, *args):
    """
    Create a query wrapper to easily run sqlmodel queries on this model.

    ``args`` (if any) are handed to the wrapper unchanged, together with ``cls`` itself.
    """
    typed_wrapper = QueryWrapper[cls]
    return typed_wrapper(cls, *args)
|
|
117
176
|
|
|
177
|
+
def delete(self):
    """
    Delete this record and commit immediately.

    If the instance is still attached to a different session it is expunged from it first, so the
    delete can be issued through the session from ``get_session()``. Always returns ``True``.
    """
    with get_session() as session:
        previous_session = Session.object_session(self)
        if previous_session:
            previous_session.expunge(self)

        session.delete(self)
        session.commit()
        return True
|
|
185
|
+
|
|
118
186
|
def save(self):
|
|
119
187
|
with get_session() as session:
|
|
120
188
|
if old_session := Session.object_session(self):
|
|
@@ -145,10 +213,16 @@ class BaseModel(SQLModel):
|
|
|
145
213
|
"""
|
|
146
214
|
Returns the number of records in the database.
|
|
147
215
|
"""
|
|
148
|
-
|
|
216
|
+
with get_session() as session:
|
|
217
|
+
return session.scalar(sm.select(sm.func.count()).select_from(cls))
|
|
218
|
+
|
|
219
|
+
# TODO got to be a better way to fwd these along...
|
|
220
|
+
@classmethod
def first(cls):
    """Return the first row of an unfiltered ``select()`` on this model (no ordering applied)."""
    query = cls.select()
    return query.first()
|
|
149
223
|
|
|
150
224
|
# TODO throw an error if this field is set on the model
def is_new(self) -> bool:
    """
    True when SQLAlchemy has not assigned this instance a database identity yet
    (its instance state reports ``has_identity`` as false).
    """
    instance_state = self._sa_instance_state
    return not instance_state.has_identity
|
|
153
227
|
|
|
154
228
|
@classmethod
|
|
@@ -186,29 +260,31 @@ class BaseModel(SQLModel):
|
|
|
186
260
|
# field it will result in `True`, which will return all records, and not give you any typing
|
|
187
261
|
# errors. Dangerous when iterating on structure quickly
|
|
188
262
|
# TODO can we pass the generic of the superclass in?
|
|
189
|
-
|
|
263
|
+
# TODO can we type the method signature a bit better?
|
|
190
264
|
# def get(cls, *args: sa.BinaryExpression, **kwargs: t.Any):
|
|
265
|
+
# NOTE(review): this @classmethod was added in 0.7.0 directly above `def`; confirm the older decorator
# (presumably above the comment block preceding this method, outside the visible hunk) was removed,
# otherwise `get` ends up wrapped in classmethod twice.
@classmethod
def get(cls, *args: t.Any, **kwargs: t.Any):
    """
    Gets a single record from the database. Pass an PK ID or a kwarg to filter by.

    - a single int / str / UUID / TypeID positional argument is matched against the ``id`` column;
    - any other positional arguments are passed to ``filter()``;
    - keyword arguments are passed to ``filter_by()``.

    Returns the first matching row (``None`` when there is none).
    """

    # TODO id is hardcoded, not good! Need to dynamically pick the best uid field
    id_field_name = "id"

    # special case for getting by ID
    looks_like_pk = len(args) == 1 and isinstance(args[0], (int, TypeID, str, UUID))
    if looks_like_pk:
        kwargs[id_field_name] = args[0]
        args = ()

    statement = select(cls).filter(*args).filter_by(**kwargs)

    with get_session() as session:
        return session.exec(statement).first()
|
|
207
283
|
|
|
208
284
|
@classmethod
|
|
209
285
|
def all(cls):
|
|
210
286
|
with get_session() as session:
|
|
211
|
-
results = session.exec(
|
|
287
|
+
results = session.exec(sm.select(cls))
|
|
212
288
|
|
|
213
289
|
# TODO do we need this or can we just return results?
|
|
214
290
|
for result in results:
|
|
@@ -222,7 +298,7 @@ class BaseModel(SQLModel):
|
|
|
222
298
|
Helpful for testing and console debugging.
|
|
223
299
|
"""
|
|
224
300
|
|
|
225
|
-
query =
|
|
301
|
+
query = sm.select(cls).order_by(sa.sql.func.random()).limit(1)
|
|
226
302
|
|
|
227
303
|
with get_session() as session:
|
|
228
304
|
return session.exec(query).one()
|
activemodel/celery.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Do not import unless you have Celery/Kombu installed.
|
|
3
|
+
|
|
4
|
+
In order for TypeID objects to be properly handled by celery, a custom encoder must be registered.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from kombu.utils.json import register_type
|
|
8
|
+
from typeid import TypeID
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def register_celery_typeid_encoder():
    """
    Register a kombu JSON type so TypeID objects passed as arguments to a delayed function are
    properly serialized.

    A TypeID is encoded as ``str(typeid)`` and decoded back with ``TypeID.from_string``.
    """
    # marker string registered with kombu for this type: the fully-qualified class name
    typeid_marker = f"{TypeID.__module__}.{TypeID.__qualname__}"

    register_type(
        TypeID,
        typeid_marker,
        encoder=str,
        decoder=TypeID.from_string,
    )
|
activemodel/get_column_from_field_patch.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pydantic has a great DX for adding docstrings to fields. This allows devs to easily document the fields of a model.
|
|
3
|
+
|
|
4
|
+
Making sure these docstrings make their way to the DB schema is helpful for a bunch of reasons (LLM understanding being one of them).
|
|
5
|
+
|
|
6
|
+
This patch mutates a core sqlmodel function which translates pydantic FieldInfo objects into sqlalchemy Column objects. It adds the field description as a comment to the column.
|
|
7
|
+
|
|
8
|
+
Note that FieldInfo *from pydantic* is used when a "bare" field is defined. This can be confusing, because when inspecting model fields, the class name looks exactly the same.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import (
|
|
12
|
+
TYPE_CHECKING,
|
|
13
|
+
Any,
|
|
14
|
+
Dict,
|
|
15
|
+
Sequence,
|
|
16
|
+
cast,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
import sqlmodel
|
|
20
|
+
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
|
21
|
+
from sqlalchemy import (
|
|
22
|
+
Column,
|
|
23
|
+
ForeignKey,
|
|
24
|
+
)
|
|
25
|
+
from sqlmodel._compat import ( # type: ignore[attr-defined]
|
|
26
|
+
IS_PYDANTIC_V2,
|
|
27
|
+
ModelMetaclass,
|
|
28
|
+
Representation,
|
|
29
|
+
Undefined,
|
|
30
|
+
UndefinedType,
|
|
31
|
+
is_field_noneable,
|
|
32
|
+
)
|
|
33
|
+
from sqlmodel.main import FieldInfo, get_sqlalchemy_type
|
|
34
|
+
|
|
35
|
+
from activemodel.utils import hash_function_code
|
|
36
|
+
|
|
37
|
+
if TYPE_CHECKING:
|
|
38
|
+
from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
|
|
39
|
+
from pydantic._internal._repr import Representation as Representation
|
|
40
|
+
from pydantic_core import PydanticUndefined as Undefined
|
|
41
|
+
from pydantic_core import PydanticUndefinedType as UndefinedType
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Guard against silent breakage: get_column_from_field below is a modified copy of sqlmodel's
# implementation, so refuse to load when the installed sqlmodel's function no longer hashes to the
# expected value. An explicit raise (instead of `assert`) keeps the check active under `python -O`.
_UPSTREAM_GET_COLUMN_FROM_FIELD_HASH = (
    "398006ef8fd8da191ca1a271ef25b6e135da0f400a80df2f29526d8674f9ec51"
)

if (
    hash_function_code(sqlmodel.main.get_column_from_field)
    != _UPSTREAM_GET_COLUMN_FROM_FIELD_HASH
):
    raise RuntimeError(
        "sqlmodel.main.get_column_from_field does not match the version activemodel patches; "
        "update activemodel's get_column_from_field patch for the installed sqlmodel"
    )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def get_column_from_field(field: PydanticFieldInfo | FieldInfo) -> Column:  # type: ignore
    """
    Takes a field definition, which can either come from the sqlmodel FieldInfo class or the pydantic variant of that class,
    and converts it into a sqlalchemy Column object.

    Kept as close as possible to the upstream sqlmodel function it replaces; the only intended
    differences (marked "IMPORTANT" below) copy the field's ``description`` into the column comment
    when no comment was set explicitly.
    """
    if IS_PYDANTIC_V2:
        field_info = field
    else:
        field_info = field.field_info

    sa_column = getattr(field_info, "sa_column", Undefined)
    if isinstance(sa_column, Column):
        # IMPORTANT: change from the original function
        if not sa_column.comment and (field_comment := field_info.description):
            sa_column.comment = field_comment
        return sa_column

    primary_key = getattr(field_info, "primary_key", Undefined)
    if primary_key is Undefined:
        primary_key = False

    index = getattr(field_info, "index", Undefined)
    if index is Undefined:
        index = False

    nullable = not primary_key and is_field_noneable(field)
    # Override derived nullability if the nullable property is set explicitly
    # on the field
    field_nullable = getattr(field_info, "nullable", Undefined)  # noqa: B009
    if field_nullable is not Undefined:
        assert not isinstance(field_nullable, UndefinedType)
        nullable = field_nullable
    args = []
    foreign_key = getattr(field_info, "foreign_key", Undefined)
    if foreign_key is Undefined:
        foreign_key = None
    unique = getattr(field_info, "unique", Undefined)
    if unique is Undefined:
        unique = False
    if foreign_key:
        if field_info.ondelete == "SET NULL" and not nullable:
            raise RuntimeError('ondelete="SET NULL" requires nullable=True')
        assert isinstance(foreign_key, str)
        ondelete = getattr(field_info, "ondelete", Undefined)
        if ondelete is Undefined:
            ondelete = None
        assert isinstance(ondelete, (str, type(None)))  # for typing
        args.append(ForeignKey(foreign_key, ondelete=ondelete))
    kwargs = {
        "primary_key": primary_key,
        "nullable": nullable,
        "index": index,
        "unique": unique,
    }

    sa_default = Undefined
    if field_info.default_factory:
        sa_default = field_info.default_factory
    elif field_info.default is not Undefined:
        sa_default = field_info.default
    if sa_default is not Undefined:
        kwargs["default"] = sa_default

    sa_column_args = getattr(field_info, "sa_column_args", Undefined)
    if sa_column_args is not Undefined:
        args.extend(list(cast(Sequence[Any], sa_column_args)))

    sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)

    # IMPORTANT: change from the original function
    if field_info.description:
        if sa_column_kwargs is Undefined:
            sa_column_kwargs = {}

        assert isinstance(sa_column_kwargs, dict)

        # only update comments if not already set. Build a new dict instead of writing into the
        # field's own `sa_column_kwargs`: that dict can be shared (e.g. one constant passed to several
        # Field() calls), and mutating it would leak this field's comment into the other columns.
        if "comment" not in sa_column_kwargs:
            sa_column_kwargs = {**sa_column_kwargs, "comment": field_info.description}

    if sa_column_kwargs is not Undefined:
        kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))

    sa_type = get_sqlalchemy_type(field)
    return Column(sa_type, *args, **kwargs)  # type: ignore
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# Replace sqlmodel's implementation with the patched copy above. This only affects code that looks the
# function up through the sqlmodel.main module at call time (presumably sqlmodel's model metaclass --
# confirm), and only models created after this module has been imported.
sqlmodel.main.get_column_from_field = get_column_from_field
|
activemodel/mixins/pydantic_json.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from types import UnionType
|
|
2
|
+
from typing import get_args, get_origin
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel as PydanticBaseModel
|
|
5
|
+
from sqlalchemy.orm import reconstructor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _resolve_pydantic_model_annotation(annotation):
    """
    Resolve a field annotation to ``(model_cls, is_list)``.

    ``model_cls`` is the pydantic model class stored JSON should be parsed into, or ``None`` when the
    annotation is not a (possibly optional) pydantic model or list of pydantic models.
    """
    import typing

    origin = get_origin(annotation)

    # strip optionality: `X | None` has origin types.UnionType, `Optional[X]` has typing.Union
    if origin is UnionType or origin is typing.Union:
        non_none_types = [arg for arg in get_args(annotation) if arg is not type(None)]

        # a real union of several types: there is no single model to parse into
        if len(non_none_types) != 1:
            return None, False

        annotation = non_none_types[0]
        origin = get_origin(annotation)

    # e.g. list[SomePydanticModel] -- the item type is the model we care about
    is_list = origin is list
    if is_list:
        list_args = get_args(annotation)
        if len(list_args) != 1:
            return None, False
        annotation = list_args[0]

    # dict[...], Literal[...] and other generic aliases are not classes; issubclass() would raise on them
    if isinstance(annotation, type) and issubclass(annotation, PydanticBaseModel):
        return annotation, is_list

    return None, False


class PydanticJSONMixin:
    """
    By default, SQLModel does not convert JSONB columns into pydantic models when they are loaded from the database.

    This mixin, combined with a custom serializer, fixes that issue.

    When SQLAlchemy reconstructs an instance from a row, each field annotated as ``Model``,
    ``Model | None``, ``Optional[Model]``, ``list[Model]`` or ``list[Model] | None`` (where ``Model`` is a
    pydantic model) has its raw dict / list-of-dicts value parsed into model instances. ``None`` values,
    non-dict values, and fields with any other annotation are left untouched.
    """

    @reconstructor
    def init_on_load(self):
        # TODO do we need to inspect sa_type
        # model_fields is a class attribute, so read it from the class
        for field_name, field_info in type(self).model_fields.items():
            raw_value = getattr(self, field_name, None)

            if raw_value is None:
                continue

            model_cls, is_list = _resolve_pydantic_model_annotation(field_info.annotation)

            if model_cls is None:
                continue

            # e.g. list[SomePydanticModel] or list[SomePydanticModel] | None
            # iterate through the list and run each item through the pydantic model
            if is_list:
                if isinstance(raw_value, list):
                    parsed_value = [
                        model_cls(**item) if isinstance(item, dict) else item
                        for item in raw_value
                    ]
                    setattr(self, field_name, parsed_value)

                continue

            # single class; only raw dicts need parsing
            if isinstance(raw_value, dict):
                setattr(self, field_name, model_cls(**raw_value))
|
|
activemodel/mixins/soft_delete.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
|
|
3
|
+
import sqlalchemy as sa
|
|
4
|
+
from sqlmodel import Field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SoftDeletionMixin:
    """Adds a nullable, timezone-aware ``deleted_at`` column intended for soft deletion."""

    # NULL until the row is soft-deleted; annotated Optional because the default is None
    deleted_at: datetime | None = Field(
        default=None,
        nullable=True,
        # TODO https://github.com/fastapi/sqlmodel/discussions/1228
        sa_type=sa.DateTime(timezone=True),  # type: ignore
    )

    def soft_delete(self):
        """
        Mark this record as deleted without removing its row.

        Not implemented yet: always raises ``NotImplementedError``. It raises *before* touching
        ``deleted_at`` so a failed call does not leave the in-memory object looking deleted while
        nothing was persisted. When implemented, the timestamp should be timezone-aware
        (e.g. ``datetime.now(timezone.utc)``) to match the ``DateTime(timezone=True)`` column.
        """
        raise NotImplementedError("Soft deletion is not implemented")
|
activemodel/mixins/typeid.py
CHANGED
|
@@ -1,36 +1,46 @@
|
|
|
1
|
-
import uuid
|
|
2
|
-
|
|
3
1
|
from sqlmodel import Column, Field
|
|
4
2
|
from typeid import TypeID
|
|
5
3
|
|
|
6
4
|
from activemodel.types.typeid import TypeIDType
|
|
7
5
|
|
|
6
|
+
# global list of prefixes to ensure uniqueness
|
|
7
|
+
# checked and appended to by TypeIDMixin(); being module-level, uniqueness is enforced per interpreter
_prefixes: list[str] = []
|
|
8
|
+
|
|
8
9
|
|
|
9
10
|
def TypeIDMixin(prefix: str):
    """
    Build a mixin class that gives a model a TypeID primary key named ``id``.

    New ids are generated as ``TypeID(prefix)``. Each prefix is recorded in the module-level
    ``_prefixes`` list so two models cannot accidentally share one.

    Raises:
        ValueError: if ``prefix`` is empty or was already registered by an earlier call.
    """
    # explicit raises rather than `assert`, so the checks still run under `python -O`
    if not prefix:
        raise ValueError("a TypeID prefix is required")

    if prefix in _prefixes:
        raise ValueError(f"prefix {prefix} already exists, pick a different one")

    class _TypeIDMixin:
        id: TypeIDType = Field(
            sa_column=Column(TypeIDType(prefix), primary_key=True, nullable=False),
            default_factory=lambda: TypeID(prefix),
        )

    _prefixes.append(prefix)

    return _TypeIDMixin
|
|
17
25
|
|
|
18
26
|
|
|
19
|
-
class
|
|
20
|
-
|
|
21
|
-
|
|
27
|
+
# TODO not sure if I love the idea of a dynamic class for each mixin as used above
|
|
28
|
+
# may give this approach another shot in the future
|
|
29
|
+
# class TypeIDMixin2:
|
|
30
|
+
# """
|
|
31
|
+
# Mixin class that adds a TypeID primary key to models.
|
|
22
32
|
|
|
23
33
|
|
|
24
|
-
|
|
25
|
-
|
|
34
|
+
# >>> class MyModel(BaseModel, TypeIDMixin, prefix="xyz", table=True):
|
|
35
|
+
# >>> name: str
|
|
26
36
|
|
|
27
|
-
|
|
28
|
-
|
|
37
|
+
# Will automatically have an `id` field with prefix "xyz"
|
|
38
|
+
# """
|
|
29
39
|
|
|
30
|
-
|
|
31
|
-
|
|
40
|
+
# def __init_subclass__(cls, *, prefix: str, **kwargs):
|
|
41
|
+
# super().__init_subclass__(**kwargs)
|
|
32
42
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
43
|
+
# cls.id: uuid.UUID = Field(
|
|
44
|
+
# sa_column=Column(TypeIDType(prefix), primary_key=True),
|
|
45
|
+
# default_factory=lambda: TypeID(prefix),
|
|
46
|
+
# )
|
activemodel/pytest/truncate.py
CHANGED
|
@@ -40,7 +40,7 @@ def database_reset_truncate():
|
|
|
40
40
|
transaction = connection.begin()
|
|
41
41
|
|
|
42
42
|
if table.name not in exception_tables:
|
|
43
|
-
logger.debug("truncating table
|
|
43
|
+
logger.debug(f"truncating table={table.name}")
|
|
44
44
|
connection.execute(table.delete())
|
|
45
45
|
|
|
46
46
|
transaction.commit()
|
activemodel/query_wrapper.py
CHANGED
|
@@ -1,20 +1,26 @@
|
|
|
1
|
-
import sqlmodel
|
|
1
|
+
import sqlmodel as sm
|
|
2
|
+
from sqlmodel.sql.expression import SelectOfScalar
|
|
2
3
|
|
|
4
|
+
from .session_manager import get_session
|
|
5
|
+
from .utils import compile_sql
|
|
3
6
|
|
|
4
|
-
|
|
7
|
+
|
|
8
|
+
class QueryWrapper[T: sm.SQLModel]:
|
|
5
9
|
"""
|
|
6
10
|
Make it easy to run queries off of a model
|
|
7
11
|
"""
|
|
8
12
|
|
|
9
|
-
|
|
13
|
+
target: SelectOfScalar[T]
|
|
14
|
+
|
|
15
|
+
def __init__(self, cls: type[T], *args) -> None:
    """
    Build the underlying SELECT for the model class ``cls``.

    With no ``args`` the statement is ``select(cls)``; otherwise it is ``select(*args)`` with ``cls``
    added via ``select_from``.
    """
    # TODO add generics here
    # self.target: SelectOfScalar[T] = sql.select(cls)

    if args:
        # very naive, let's assume the args are specific select statements
        self.target = sm.select(*args).select_from(cls)
    else:
        self.target = sm.select(cls)
|
|
18
24
|
|
|
19
25
|
# TODO the .exec results should be handled in one shot
|
|
20
26
|
|
|
@@ -23,6 +29,7 @@ class QueryWrapper[T]:
|
|
|
23
29
|
return session.exec(self.target).first()
|
|
24
30
|
|
|
25
31
|
def one(self):
    """Run the query and return its single row; the result's ``one()`` raises unless exactly one row matches."""
    with get_session() as session:
        result = session.exec(self.target)
        return result.one()
|
|
28
35
|
|
|
@@ -32,8 +39,14 @@ class QueryWrapper[T]:
|
|
|
32
39
|
for row in result:
|
|
33
40
|
yield row
|
|
34
41
|
|
|
42
|
+
def count(self):
    """
    Return the number of rows the current query would produce.

    Executes ``SELECT count(*) FROM (<current query>)``, so filters and limits already applied to the
    wrapper are taken into account.
    """
    with get_session() as session:
        # wrap explicitly: handing a Select straight to select_from() relies on SQLAlchemy's
        # deprecated implicit SELECT-to-subquery coercion
        return session.scalar(
            sm.select(sm.func.count()).select_from(self.target.subquery())
        )
|
|
48
|
+
|
|
35
49
|
def exec(self):
    """
    Execute the current query and return sqlmodel's result object.

    NOTE(review): the result is produced inside ``get_session()`` and consumed by the caller after that
    block exits; whether rows can still be fetched then depends on what ``get_session`` does on exit --
    confirm before relying on lazy iteration.
    """
    with get_session() as session:
        query_result = session.exec(self.target)
        return query_result
|
|
39
52
|
|
|
@@ -43,19 +56,20 @@ class QueryWrapper[T]:
|
|
|
43
56
|
|
|
44
57
|
def __getattr__(self, name):
|
|
45
58
|
"""
|
|
46
|
-
This
|
|
59
|
+
This implements the magic that forwards function calls to sqlalchemy.
|
|
47
60
|
"""
|
|
48
61
|
|
|
49
62
|
# TODO prefer methods defined in this class
|
|
63
|
+
|
|
50
64
|
if not hasattr(self.target, name):
|
|
51
65
|
return super().__getattribute__(name)
|
|
52
66
|
|
|
53
|
-
|
|
67
|
+
sqlalchemy_target = getattr(self.target, name)
|
|
54
68
|
|
|
55
|
-
if callable(
|
|
69
|
+
if callable(sqlalchemy_target):
|
|
56
70
|
|
|
57
71
|
def wrapper(*args, **kwargs):
|
|
58
|
-
result =
|
|
72
|
+
result = sqlalchemy_target(*args, **kwargs)
|
|
59
73
|
self.target = result
|
|
60
74
|
return self
|
|
61
75
|
|