activemodel 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- activemodel/__init__.py +2 -0
- activemodel/base_model.py +141 -33
- activemodel/celery.py +33 -0
- activemodel/errors.py +6 -0
- activemodel/get_column_from_field_patch.py +139 -0
- activemodel/mixins/__init__.py +2 -0
- activemodel/mixins/pydantic_json.py +82 -0
- activemodel/mixins/soft_delete.py +17 -0
- activemodel/mixins/typeid.py +27 -17
- activemodel/pytest/transaction.py +34 -22
- activemodel/pytest/truncate.py +1 -1
- activemodel/query_wrapper.py +24 -10
- activemodel/session_manager.py +92 -5
- activemodel/types/__init__.py +1 -0
- activemodel/types/typeid.py +141 -5
- activemodel/utils.py +51 -1
- activemodel-0.8.0.dist-info/METADATA +282 -0
- activemodel-0.8.0.dist-info/RECORD +24 -0
- {activemodel-0.5.0.dist-info → activemodel-0.8.0.dist-info}/WHEEL +1 -2
- activemodel/_session_manager.py +0 -153
- activemodel-0.5.0.dist-info/METADATA +0 -66
- activemodel-0.5.0.dist-info/RECORD +0 -20
- activemodel-0.5.0.dist-info/top_level.txt +0 -1
- {activemodel-0.5.0.dist-info → activemodel-0.8.0.dist-info}/entry_points.txt +0 -0
- {activemodel-0.5.0.dist-info → activemodel-0.8.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -1,12 +1,19 @@
|
|
|
1
1
|
from activemodel import SessionManager
|
|
2
2
|
|
|
3
|
+
from ..logger import logger
|
|
4
|
+
|
|
3
5
|
|
|
4
6
|
def database_reset_transaction():
|
|
5
7
|
"""
|
|
6
8
|
Wrap all database interactions for a given test in a nested transaction and roll it back after the test.
|
|
7
9
|
|
|
8
10
|
>>> from activemodel.pytest import database_reset_transaction
|
|
9
|
-
>>> pytest.fixture(scope="function", autouse=True)(database_reset_transaction)
|
|
11
|
+
>>> database_reset_transaction = pytest.fixture(scope="function", autouse=True)(database_reset_transaction)
|
|
12
|
+
|
|
13
|
+
Transaction-based DB cleaning does *not* work if the DB mutations are happening in a separate process, which should
|
|
14
|
+
use spawn, because the same session is not shared across processes. Note that using `fork` is dangerous.
|
|
15
|
+
|
|
16
|
+
In this case, you should use the truncation-based cleaner (`database_reset_truncate`) instead.
|
|
10
17
|
|
|
11
18
|
References:
|
|
12
19
|
|
|
@@ -14,38 +21,43 @@ def database_reset_transaction():
|
|
|
14
21
|
- https://aalvarez.me/posts/setting-up-a-sqlalchemy-and-pytest-based-test-suite/
|
|
15
22
|
- https://github.com/nickjj/docker-flask-example/blob/93af9f4fbf185098ffb1d120ee0693abcd77a38b/test/conftest.py#L77
|
|
16
23
|
- https://github.com/caiola/vinhos.com/blob/c47d0a5d7a4bf290c1b726561d1e8f5d2ac29bc8/backend/test/conftest.py#L46
|
|
24
|
+
- https://stackoverflow.com/questions/64095876/multiprocessing-fork-vs-spawn
|
|
25
|
+
|
|
26
|
+
Using a named SAVEPOINT does not give us anything extra, so we are not using it.
|
|
17
27
|
"""
|
|
18
28
|
|
|
19
29
|
engine = SessionManager.get_instance().get_engine()
|
|
20
30
|
|
|
31
|
+
logger.info("starting database transaction")
|
|
32
|
+
|
|
21
33
|
with engine.begin() as connection:
|
|
22
34
|
transaction = connection.begin_nested()
|
|
23
35
|
|
|
36
|
+
if SessionManager.get_instance().session_connection is not None:
|
|
37
|
+
logger.warning("session override already exists")
|
|
38
|
+
# TODO should we throw an exception here?
|
|
39
|
+
|
|
24
40
|
SessionManager.get_instance().session_connection = connection
|
|
25
41
|
|
|
26
42
|
try:
|
|
27
|
-
|
|
43
|
+
with SessionManager.get_instance().get_session() as factory_session:
|
|
44
|
+
try:
|
|
45
|
+
from factory.alchemy import SQLAlchemyModelFactory
|
|
46
|
+
|
|
47
|
+
# Ensure that all factories use the same session
|
|
48
|
+
for factory in SQLAlchemyModelFactory.__subclasses__():
|
|
49
|
+
factory._meta.sqlalchemy_session = factory_session
|
|
50
|
+
factory._meta.sqlalchemy_session_persistence = "commit"
|
|
51
|
+
except ImportError:
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
yield
|
|
28
55
|
finally:
|
|
29
|
-
|
|
30
|
-
# TODO is this necessary?
|
|
31
|
-
connection.close()
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
# TODO unsure if this adds any value beyond the above approach
|
|
35
|
-
# def database_reset_named_truncation():
|
|
36
|
-
# start_truncation_query = """
|
|
37
|
-
# BEGIN;
|
|
38
|
-
# SAVEPOINT test_truncation_savepoint;
|
|
39
|
-
# """
|
|
56
|
+
logger.debug("rolling back transaction")
|
|
40
57
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
# yield
|
|
58
|
+
transaction.rollback()
|
|
44
59
|
|
|
45
|
-
#
|
|
46
|
-
|
|
47
|
-
# RELEASE SAVEPOINT test_truncation_savepoint;
|
|
48
|
-
# ROLLBACK;
|
|
49
|
-
# """
|
|
60
|
+
# TODO is this necessary? unclear
|
|
61
|
+
connection.close()
|
|
50
62
|
|
|
51
|
-
|
|
63
|
+
SessionManager.get_instance().session_connection = None
|
activemodel/pytest/truncate.py
CHANGED
|
@@ -40,7 +40,7 @@ def database_reset_truncate():
|
|
|
40
40
|
transaction = connection.begin()
|
|
41
41
|
|
|
42
42
|
if table.name not in exception_tables:
|
|
43
|
-
logger.debug("truncating table
|
|
43
|
+
logger.debug(f"truncating table={table.name}")
|
|
44
44
|
connection.execute(table.delete())
|
|
45
45
|
|
|
46
46
|
transaction.commit()
|
activemodel/query_wrapper.py
CHANGED
|
@@ -1,20 +1,26 @@
|
|
|
1
|
-
import sqlmodel
|
|
1
|
+
import sqlmodel as sm
|
|
2
|
+
from sqlmodel.sql.expression import SelectOfScalar
|
|
2
3
|
|
|
4
|
+
from .session_manager import get_session
|
|
5
|
+
from .utils import compile_sql
|
|
3
6
|
|
|
4
|
-
|
|
7
|
+
|
|
8
|
+
class QueryWrapper[T: sm.SQLModel]:
|
|
5
9
|
"""
|
|
6
10
|
Make it easy to run queries off of a model
|
|
7
11
|
"""
|
|
8
12
|
|
|
9
|
-
|
|
13
|
+
target: SelectOfScalar[T]
|
|
14
|
+
|
|
15
|
+
def __init__(self, cls: T, *args) -> None:
|
|
10
16
|
# TODO add generics here
|
|
11
17
|
# self.target: SelectOfScalar[T] = sql.select(cls)
|
|
12
18
|
|
|
13
19
|
if args:
|
|
14
20
|
# very naive, let's assume the args are specific select statements
|
|
15
|
-
self.target =
|
|
21
|
+
self.target = sm.select(*args).select_from(cls)
|
|
16
22
|
else:
|
|
17
|
-
self.target =
|
|
23
|
+
self.target = sm.select(cls)
|
|
18
24
|
|
|
19
25
|
# TODO the .exec results should be handled in one shot
|
|
20
26
|
|
|
@@ -23,6 +29,7 @@ class QueryWrapper[T]:
|
|
|
23
29
|
return session.exec(self.target).first()
|
|
24
30
|
|
|
25
31
|
def one(self):
|
|
32
|
+
"requires exactly one result in the dataset"
|
|
26
33
|
with get_session() as session:
|
|
27
34
|
return session.exec(self.target).one()
|
|
28
35
|
|
|
@@ -32,8 +39,14 @@ class QueryWrapper[T]:
|
|
|
32
39
|
for row in result:
|
|
33
40
|
yield row
|
|
34
41
|
|
|
42
|
+
def count(self):
|
|
43
|
+
"""
|
|
44
|
+
I did some basic tests
|
|
45
|
+
"""
|
|
46
|
+
with get_session() as session:
|
|
47
|
+
return session.scalar(sm.select(sm.func.count()).select_from(self.target))
|
|
48
|
+
|
|
35
49
|
def exec(self):
|
|
36
|
-
# TODO do we really need a unique session each time?
|
|
37
50
|
with get_session() as session:
|
|
38
51
|
return session.exec(self.target)
|
|
39
52
|
|
|
@@ -43,19 +56,20 @@ class QueryWrapper[T]:
|
|
|
43
56
|
|
|
44
57
|
def __getattr__(self, name):
|
|
45
58
|
"""
|
|
46
|
-
This
|
|
59
|
+
This implements the magic that forwards function calls to sqlalchemy.
|
|
47
60
|
"""
|
|
48
61
|
|
|
49
62
|
# TODO prefer methods defined in this class
|
|
63
|
+
|
|
50
64
|
if not hasattr(self.target, name):
|
|
51
65
|
return super().__getattribute__(name)
|
|
52
66
|
|
|
53
|
-
|
|
67
|
+
sqlalchemy_target = getattr(self.target, name)
|
|
54
68
|
|
|
55
|
-
if callable(
|
|
69
|
+
if callable(sqlalchemy_target):
|
|
56
70
|
|
|
57
71
|
def wrapper(*args, **kwargs):
|
|
58
|
-
result =
|
|
72
|
+
result = sqlalchemy_target(*args, **kwargs)
|
|
59
73
|
self.target = result
|
|
60
74
|
return self
|
|
61
75
|
|
activemodel/session_manager.py
CHANGED
|
@@ -3,24 +3,58 @@ Class to make managing sessions with SQL Model easy. Also provides a common entr
|
|
|
3
3
|
database environment when testing.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import contextlib
|
|
7
|
+
import contextvars
|
|
8
|
+
import json
|
|
6
9
|
import typing as t
|
|
7
10
|
|
|
8
11
|
from decouple import config
|
|
9
|
-
from
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
from sqlalchemy import Connection, Engine
|
|
10
14
|
from sqlmodel import Session, create_engine
|
|
11
15
|
|
|
12
16
|
|
|
17
|
+
def _serialize_pydantic_model(model: BaseModel | list[BaseModel] | None) -> str | None:
|
|
18
|
+
"""
|
|
19
|
+
Pydantic models are not JSON serializable by default. You'll get an error such as:
|
|
20
|
+
|
|
21
|
+
'TypeError: Object of type TranscriptEntry is not JSON serializable'
|
|
22
|
+
|
|
23
|
+
https://github.com/fastapi/sqlmodel/issues/63#issuecomment-2581016387
|
|
24
|
+
|
|
25
|
+
This custom serializer is passed to the DB engine to properly serialize pydantic models to
|
|
26
|
+
JSON for storage in a JSONB column.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
# TODO I bet this will fail on lists with mixed types
|
|
30
|
+
|
|
31
|
+
if isinstance(model, BaseModel):
|
|
32
|
+
return model.model_dump_json()
|
|
33
|
+
if isinstance(model, list):
|
|
34
|
+
# not everything in a list is a pydantic model
|
|
35
|
+
def dump_if_model(m):
|
|
36
|
+
if isinstance(m, BaseModel):
|
|
37
|
+
return m.model_dump()
|
|
38
|
+
return m
|
|
39
|
+
|
|
40
|
+
return json.dumps([dump_if_model(m) for m in model])
|
|
41
|
+
else:
|
|
42
|
+
return json.dumps(model)
|
|
43
|
+
|
|
44
|
+
|
|
13
45
|
class SessionManager:
|
|
14
46
|
_instance: t.ClassVar[t.Optional["SessionManager"]] = None
|
|
47
|
+
"singleton instance of SessionManager"
|
|
15
48
|
|
|
16
|
-
session_connection:
|
|
49
|
+
session_connection: Connection | None
|
|
50
|
+
"optionally specify a specific session connection to use for all get_session() calls, useful for testing"
|
|
17
51
|
|
|
18
52
|
@classmethod
|
|
19
53
|
def get_instance(cls, database_url: str | None = None) -> "SessionManager":
|
|
20
54
|
if cls._instance is None:
|
|
21
|
-
assert (
|
|
22
|
-
|
|
23
|
-
)
|
|
55
|
+
assert database_url is not None, (
|
|
56
|
+
"Database URL required for first initialization"
|
|
57
|
+
)
|
|
24
58
|
cls._instance = cls(database_url)
|
|
25
59
|
|
|
26
60
|
return cls._instance
|
|
@@ -28,6 +62,7 @@ class SessionManager:
|
|
|
28
62
|
def __init__(self, database_url: str):
|
|
29
63
|
self._database_url = database_url
|
|
30
64
|
self._engine = None
|
|
65
|
+
|
|
31
66
|
self.session_connection = None
|
|
32
67
|
|
|
33
68
|
# TODO why is this type not reimported?
|
|
@@ -35,6 +70,8 @@ class SessionManager:
|
|
|
35
70
|
if not self._engine:
|
|
36
71
|
self._engine = create_engine(
|
|
37
72
|
self._database_url,
|
|
73
|
+
# NOTE very important! This enables pydantic models to be serialized for JSONB columns
|
|
74
|
+
json_serializer=_serialize_pydantic_model,
|
|
38
75
|
echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False),
|
|
39
76
|
# https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic
|
|
40
77
|
pool_pre_ping=True,
|
|
@@ -44,6 +81,15 @@ class SessionManager:
|
|
|
44
81
|
return self._engine
|
|
45
82
|
|
|
46
83
|
def get_session(self):
|
|
84
|
+
if gsession := _session_context.get():
|
|
85
|
+
|
|
86
|
+
@contextlib.contextmanager
|
|
87
|
+
def _reuse_session():
|
|
88
|
+
yield gsession
|
|
89
|
+
|
|
90
|
+
return _reuse_session()
|
|
91
|
+
|
|
92
|
+
# a connection can generate nested transactions
|
|
47
93
|
if self.session_connection:
|
|
48
94
|
return Session(bind=self.session_connection)
|
|
49
95
|
|
|
@@ -51,6 +97,7 @@ class SessionManager:
|
|
|
51
97
|
|
|
52
98
|
|
|
53
99
|
def init(database_url: str):
|
|
100
|
+
"configure activemodel to connect to a specific database"
|
|
54
101
|
return SessionManager.get_instance(database_url)
|
|
55
102
|
|
|
56
103
|
|
|
@@ -60,3 +107,43 @@ def get_engine():
|
|
|
60
107
|
|
|
61
108
|
def get_session():
|
|
62
109
|
return SessionManager.get_instance().get_session()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# contextvars must be at the top-level of a module! You will not get a warning if you don't do this.
|
|
113
|
+
# ContextVar is implemented in C, so it's very special and is both thread-safe and asyncio safe. This variable gives us
|
|
114
|
+
# a place to persist a session to use globally across the application.
|
|
115
|
+
_session_context = contextvars.ContextVar[Session | None](
|
|
116
|
+
"session_context", default=None
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@contextlib.contextmanager
|
|
121
|
+
def global_session():
|
|
122
|
+
with SessionManager.get_instance().get_session() as s:
|
|
123
|
+
token = _session_context.set(s)
|
|
124
|
+
|
|
125
|
+
try:
|
|
126
|
+
yield s
|
|
127
|
+
finally:
|
|
128
|
+
_session_context.reset(token)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
async def aglobal_session():
|
|
132
|
+
"""
|
|
133
|
+
Use this as a fastapi dependency to get a session that is shared across the request:
|
|
134
|
+
|
|
135
|
+
>>> APIRouter(
|
|
136
|
+
>>> prefix="/internal/v1",
|
|
137
|
+
>>> dependencies=[
|
|
138
|
+
>>> Depends(aglobal_session),
|
|
139
|
+
>>> ]
|
|
140
|
+
>>> )
|
|
141
|
+
"""
|
|
142
|
+
|
|
143
|
+
with SessionManager.get_instance().get_session() as s:
|
|
144
|
+
token = _session_context.set(s)
|
|
145
|
+
|
|
146
|
+
try:
|
|
147
|
+
yield
|
|
148
|
+
finally:
|
|
149
|
+
_session_context.reset(token)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .typeid import TypeIDType
|
activemodel/types/typeid.py
CHANGED
|
@@ -3,11 +3,18 @@ Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sql
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from typing import Optional
|
|
6
|
+
from uuid import UUID
|
|
6
7
|
|
|
8
|
+
from pydantic import (
|
|
9
|
+
GetJsonSchemaHandler,
|
|
10
|
+
)
|
|
11
|
+
from pydantic_core import CoreSchema, core_schema
|
|
7
12
|
from sqlalchemy import types
|
|
8
13
|
from sqlalchemy.util import generic_repr
|
|
9
14
|
from typeid import TypeID
|
|
10
15
|
|
|
16
|
+
from activemodel.errors import TypeIDValidationError
|
|
17
|
+
|
|
11
18
|
|
|
12
19
|
class TypeIDType(types.TypeDecorator):
|
|
13
20
|
"""
|
|
@@ -45,12 +52,141 @@ class TypeIDType(types.TypeDecorator):
|
|
|
45
52
|
)
|
|
46
53
|
|
|
47
54
|
def process_bind_param(self, value, dialect):
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
55
|
+
"""
|
|
56
|
+
This is run when a search query is built or ...
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
if value is None:
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
if isinstance(value, UUID):
|
|
63
|
+
# then it's a UUID class, such as UUID('01942886-7afc-7129-8f57-db09137ed002')
|
|
64
|
+
return value
|
|
65
|
+
|
|
66
|
+
if isinstance(value, str) and value.startswith(self.prefix + "_"):
|
|
67
|
+
# then it's a TypeID such as 'user_01h45ytscbebyvny4gc8cr8ma2'
|
|
68
|
+
value = TypeID.from_string(value)
|
|
69
|
+
|
|
70
|
+
if isinstance(value, str):
|
|
71
|
+
# no prefix, raw UUID, let's coerce it into a UUID which SQLAlchemy can handle
|
|
72
|
+
# ex: '01942886-7afc-7129-8f57-db09137ed002'
|
|
73
|
+
return UUID(value)
|
|
74
|
+
|
|
75
|
+
if isinstance(value, TypeID):
|
|
76
|
+
# TODO in what case could this None prefix ever occur?
|
|
77
|
+
if self.prefix is None:
|
|
78
|
+
if value.prefix is None:
|
|
79
|
+
raise TypeIDValidationError(
|
|
80
|
+
"Must have a valid prefix set on the class"
|
|
81
|
+
)
|
|
82
|
+
else:
|
|
83
|
+
if value.prefix != self.prefix:
|
|
84
|
+
raise TypeIDValidationError(
|
|
85
|
+
f"Expected '{self.prefix}' but got '{value.prefix}'"
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
return value.uuid
|
|
52
89
|
|
|
53
|
-
|
|
90
|
+
raise ValueError("Unexpected input type")
|
|
54
91
|
|
|
55
92
|
def process_result_value(self, value, dialect):
|
|
93
|
+
if value is None:
|
|
94
|
+
return None
|
|
95
|
+
|
|
56
96
|
return TypeID.from_uuid(value, self.prefix)
|
|
97
|
+
|
|
98
|
+
# def coerce_compared_value(self, op, value):
|
|
99
|
+
# """
|
|
100
|
+
# This method is called when SQLAlchemy needs to compare a column to a value.
|
|
101
|
+
# By returning self, we indicate that this type can handle TypeID instances.
|
|
102
|
+
# """
|
|
103
|
+
# if isinstance(value, TypeID):
|
|
104
|
+
# return self
|
|
105
|
+
|
|
106
|
+
# return super().coerce_compared_value(op, value)
|
|
107
|
+
|
|
108
|
+
@classmethod
|
|
109
|
+
def __get_pydantic_core_schema__(
|
|
110
|
+
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
|
111
|
+
) -> CoreSchema:
|
|
112
|
+
"""
|
|
113
|
+
This fixes the following error: 'Unable to serialize unknown type' by telling pydantic how to serialize this field.
|
|
114
|
+
|
|
115
|
+
Note that TypeIDType MUST be the type of the field in SQLModel otherwise you'll get serialization errors.
|
|
116
|
+
This is done automatically for the mixin but for any relationship fields you'll need to specify the type explicitly.
|
|
117
|
+
|
|
118
|
+
- https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/server/value_objects/steam_id.py#L88
|
|
119
|
+
- https://github.com/hhimanshu/uv-workspaces/blob/main/packages/api/src/_lib/dto/typeid_field.py
|
|
120
|
+
- https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/base/typeid/type_id.py#L14
|
|
121
|
+
- https://github.com/pydantic/pydantic/issues/10060
|
|
122
|
+
- https://github.com/fastapi/fastapi/discussions/10027
|
|
123
|
+
- https://github.com/alice-biometrics/petisco/blob/b01ef1b84949d156f73919e126ed77aa8e0b48dd/petisco/base/domain/model/uuid.py#L50
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
from_uuid_schema = core_schema.chain_schema(
|
|
127
|
+
[
|
|
128
|
+
# TODO not sure how this is different from the UUID schema, should try it out.
|
|
129
|
+
# core_schema.is_instance_schema(TypeID),
|
|
130
|
+
# core_schema.uuid_schema(),
|
|
131
|
+
core_schema.no_info_plain_validator_function(
|
|
132
|
+
TypeID.from_string,
|
|
133
|
+
json_schema_input_schema=core_schema.str_schema(),
|
|
134
|
+
),
|
|
135
|
+
]
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
return core_schema.json_or_python_schema(
|
|
139
|
+
json_schema=from_uuid_schema,
|
|
140
|
+
# TODO in the future we could add more exact types
|
|
141
|
+
# metadata=core_schema.str_schema(
|
|
142
|
+
# pattern="^[0-9a-f]{24}$",
|
|
143
|
+
# min_length=24,
|
|
144
|
+
# max_length=24,
|
|
145
|
+
# ),
|
|
146
|
+
# metadata={
|
|
147
|
+
# "pydantic_js_input_core_schema": core_schema.str_schema(
|
|
148
|
+
# pattern="^[0-9a-f]{24}$",
|
|
149
|
+
# min_length=24,
|
|
150
|
+
# max_length=24,
|
|
151
|
+
# )
|
|
152
|
+
# },
|
|
153
|
+
python_schema=core_schema.union_schema([from_uuid_schema]),
|
|
154
|
+
serialization=core_schema.plain_serializer_function_ser_schema(
|
|
155
|
+
lambda x: str(x)
|
|
156
|
+
),
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
@classmethod
|
|
160
|
+
def __get_pydantic_json_schema__(
|
|
161
|
+
cls, schema: CoreSchema, handler: GetJsonSchemaHandler
|
|
162
|
+
):
|
|
163
|
+
"""
|
|
164
|
+
Called when generating the openapi schema. This overrides the `function-plain` type which
|
|
165
|
+
is generated by the `no_info_plain_validator_function`.
|
|
166
|
+
|
|
167
|
+
This logic seems to be a hot part of the codebase, so I'd expect this to break as pydantic and
|
|
168
|
+
fastapi continue to evolve.
|
|
169
|
+
|
|
170
|
+
Note that this method can return multiple types. A return value can be as simple as:
|
|
171
|
+
|
|
172
|
+
{"type": "string"}
|
|
173
|
+
|
|
174
|
+
Or, you could return a more specific JSON schema type:
|
|
175
|
+
|
|
176
|
+
core_schema.uuid_schema()
|
|
177
|
+
|
|
178
|
+
The problem with using something like uuid_schema is the specific patterns
|
|
179
|
+
|
|
180
|
+
https://github.com/BeanieODM/beanie/blob/2190cd9d1fc047af477d5e6897cc283799f54064/beanie/odm/fields.py#L153
|
|
181
|
+
"""
|
|
182
|
+
|
|
183
|
+
return {
|
|
184
|
+
"type": "string",
|
|
185
|
+
# TODO implement a more strict pattern in regex
|
|
186
|
+
# https://github.com/jetify-com/typeid/blob/3d182feed5687c21bb5ab93d5f457ff96749b68b/spec/README.md?plain=1#L38
|
|
187
|
+
# "pattern": "^[0-9a-f]{24}$",
|
|
188
|
+
# "minLength": 24,
|
|
189
|
+
# "maxLength": 24,
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
return core_schema.uuid_schema()
|
activemodel/utils.py
CHANGED
|
@@ -1,15 +1,65 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import pkgutil
|
|
3
|
+
import sys
|
|
4
|
+
from types import ModuleType
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import text
|
|
7
|
+
from sqlmodel import SQLModel
|
|
1
8
|
from sqlmodel.sql.expression import SelectOfScalar
|
|
2
9
|
|
|
3
|
-
from
|
|
10
|
+
from .logger import logger
|
|
11
|
+
from .session_manager import get_engine, get_session
|
|
4
12
|
|
|
5
13
|
|
|
6
14
|
def compile_sql(target: SelectOfScalar):
|
|
15
|
+
"convert a query into SQL, helpful for debugging"
|
|
7
16
|
dialect = get_engine().dialect
|
|
8
17
|
# TODO I wonder if we could store the dialect to avoid getting an engine reference
|
|
9
18
|
compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
|
|
10
19
|
return str(compiled)
|
|
11
20
|
|
|
12
21
|
|
|
22
|
+
# TODO document further, lots of risks here
|
|
13
23
|
def raw_sql_exec(raw_query: str):
|
|
14
24
|
with get_session() as session:
|
|
15
25
|
session.execute(text(raw_query))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def find_all_sqlmodels(module: ModuleType):
|
|
29
|
+
"""Import all model classes from module and submodules into current namespace."""
|
|
30
|
+
|
|
31
|
+
logger.debug(f"Starting model import from module: {module.__name__}")
|
|
32
|
+
model_classes = {}
|
|
33
|
+
|
|
34
|
+
# Walk through all submodules
|
|
35
|
+
for loader, module_name, is_pkg in pkgutil.walk_packages(module.__path__):
|
|
36
|
+
full_name = f"{module.__name__}.{module_name}"
|
|
37
|
+
logger.debug(f"Importing submodule: {full_name}")
|
|
38
|
+
|
|
39
|
+
# Check if module is already imported
|
|
40
|
+
if full_name in sys.modules:
|
|
41
|
+
submodule = sys.modules[full_name]
|
|
42
|
+
else:
|
|
43
|
+
logger.warning(
|
|
44
|
+
f"Module not found in sys.modules, not importing: {full_name}"
|
|
45
|
+
)
|
|
46
|
+
continue
|
|
47
|
+
|
|
48
|
+
# Get all classes from module
|
|
49
|
+
for name, obj in inspect.getmembers(submodule):
|
|
50
|
+
if inspect.isclass(obj) and issubclass(obj, SQLModel) and obj != SQLModel:
|
|
51
|
+
logger.debug(f"Found model class: {name}")
|
|
52
|
+
model_classes[name] = obj
|
|
53
|
+
|
|
54
|
+
logger.debug(f"Completed model import. Found {len(model_classes)} models")
|
|
55
|
+
return model_classes
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def hash_function_code(func):
|
|
59
|
+
"get sha of a function to easily assert that it hasn't changed"
|
|
60
|
+
|
|
61
|
+
import hashlib
|
|
62
|
+
import inspect
|
|
63
|
+
|
|
64
|
+
source = inspect.getsource(func)
|
|
65
|
+
return hashlib.sha256(source.encode()).hexdigest()
|