activemodel 0.3.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- activemodel/__init__.py +2 -4
- activemodel/base_model.py +207 -40
- activemodel/celery.py +28 -0
- activemodel/errors.py +6 -0
- activemodel/get_column_from_field_patch.py +137 -0
- activemodel/logger.py +3 -0
- activemodel/mixins/__init__.py +4 -0
- activemodel/mixins/pydantic_json.py +69 -0
- activemodel/mixins/soft_delete.py +17 -0
- activemodel/{timestamps.py → mixins/timestamps.py} +3 -4
- activemodel/mixins/typeid.py +46 -0
- activemodel/pytest/__init__.py +2 -0
- activemodel/pytest/transaction.py +51 -0
- activemodel/pytest/truncate.py +46 -0
- activemodel/query_wrapper.py +23 -17
- activemodel/session_manager.py +132 -0
- activemodel/types/__init__.py +1 -0
- activemodel/types/typeid.py +191 -0
- activemodel/utils.py +65 -0
- activemodel-0.7.0.dist-info/METADATA +235 -0
- activemodel-0.7.0.dist-info/RECORD +24 -0
- {activemodel-0.3.0.dist-info → activemodel-0.7.0.dist-info}/WHEEL +1 -2
- activemodel-0.3.0.dist-info/METADATA +0 -34
- activemodel-0.3.0.dist-info/RECORD +0 -10
- activemodel-0.3.0.dist-info/top_level.txt +0 -1
- {activemodel-0.3.0.dist-info → activemodel-0.7.0.dist-info}/entry_points.txt +0 -0
- {activemodel-0.3.0.dist-info → activemodel-0.7.0.dist-info/licenses}/LICENSE +0 -0
activemodel/mixins/typeid.py
ADDED
@@ -0,0 +1,46 @@
+from sqlmodel import Column, Field
+from typeid import TypeID
+
+from activemodel.types.typeid import TypeIDType
+
+# global list of prefixes to ensure uniqueness
+_prefixes = []
+
+
+def TypeIDMixin(prefix: str):
+    assert prefix
+    assert prefix not in _prefixes, (
+        f"prefix {prefix} already exists, pick a different one"
+    )
+
+    class _TypeIDMixin:
+        id: TypeIDType = Field(
+            sa_column=Column(TypeIDType(prefix), primary_key=True, nullable=False),
+            default_factory=lambda: TypeID(prefix),
+        )
+
+    _prefixes.append(prefix)
+
+    return _TypeIDMixin
+
+
+# TODO not sure if I love the idea of a dynamic class for each mixin as used above
+# may give this approach another shot in the future
+# class TypeIDMixin2:
+#     """
+#     Mixin class that adds a TypeID primary key to models.
+
+
+#     >>> class MyModel(BaseModel, TypeIDMixin, prefix="xyz", table=True):
+#     >>>     name: str
+
+#     Will automatically have an `id` field with prefix "xyz"
+#     """
+
+#     def __init_subclass__(cls, *, prefix: str, **kwargs):
+#         super().__init_subclass__(**kwargs)
+
+#         cls.id: uuid.UUID = Field(
+#             sa_column=Column(TypeIDType(prefix), primary_key=True),
+#             default_factory=lambda: TypeID(prefix),
+#         )
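For orientation, here is a usage sketch of this factory (not part of the diff; the import paths and the `User`/`"user"` names are assumptions):

    from activemodel import BaseModel
    from activemodel.mixins import TypeIDMixin

    class User(BaseModel, TypeIDMixin("user"), table=True):
        name: str

    user = User(name="example")
    print(user.id)  # e.g. user_01h45ytscbebyvny4gc8cr8ma2

Because the factory closes over `prefix`, each call returns a fresh mixin class, and the module-level `_prefixes` list rejects a prefix being claimed by two models.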
activemodel/pytest/transaction.py
ADDED
@@ -0,0 +1,51 @@
+from activemodel import SessionManager
+
+
+def database_reset_transaction():
+    """
+    Wrap all database interactions for a given test in a nested transaction and roll it back after the test.
+
+    >>> from activemodel.pytest import database_reset_transaction
+    >>> pytest.fixture(scope="function", autouse=True)(database_reset_transaction)
+
+    References:
+
+    - https://stackoverflow.com/questions/62433018/how-to-make-sqlalchemy-transaction-rollback-drop-tables-it-created
+    - https://aalvarez.me/posts/setting-up-a-sqlalchemy-and-pytest-based-test-suite/
+    - https://github.com/nickjj/docker-flask-example/blob/93af9f4fbf185098ffb1d120ee0693abcd77a38b/test/conftest.py#L77
+    - https://github.com/caiola/vinhos.com/blob/c47d0a5d7a4bf290c1b726561d1e8f5d2ac29bc8/backend/test/conftest.py#L46
+    """
+
+    engine = SessionManager.get_instance().get_engine()
+
+    with engine.begin() as connection:
+        transaction = connection.begin_nested()
+
+        SessionManager.get_instance().session_connection = connection
+
+        try:
+            yield
+        finally:
+            transaction.rollback()
+            # TODO is this necessary?
+            connection.close()
+
+
+# TODO unsure if this adds any value beyond the above approach
+# def database_reset_named_truncation():
+#     start_truncation_query = """
+#     BEGIN;
+#     SAVEPOINT test_truncation_savepoint;
+#     """
+
+#     raw_sql_exec(start_truncation_query)
+
+#     yield
+
+#     end_truncation_query = """
+#     ROLLBACK TO SAVEPOINT test_truncation_savepoint;
+#     RELEASE SAVEPOINT test_truncation_savepoint;
+#     ROLLBACK;
+#     """
+
+#     raw_sql_exec(end_truncation_query)
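A sketch of how this generator might be registered in a test suite's conftest.py, following the docstring above (the fixture variable name is illustrative):

    import pytest

    from activemodel.pytest import database_reset_transaction

    db_transaction = pytest.fixture(scope="function", autouse=True)(
        database_reset_transaction
    )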
activemodel/pytest/truncate.py
ADDED
@@ -0,0 +1,46 @@
+from sqlmodel import SQLModel
+
+from ..logger import logger
+from ..session_manager import get_engine
+
+
+def database_reset_truncate():
+    """
+    Transaction is most likely the better way to go, but there are some scenarios where the session override
+    logic does not work properly and you need to truncate tables back to their original state.
+
+    Here's how to do this once at the start of the test:
+
+    >>> from activemodel.pytest import database_reset_truncation
+    >>> def pytest_configure(config):
+    >>>     database_reset_truncation()
+
+    Or, if you want to use this as a fixture:
+
+    >>> pytest.fixture(scope="function")(database_reset_truncation)
+    >>> def test_the_thing(database_reset_truncation)
+
+    This approach has a couple of problems:
+
+    * You can't run multiple tests in parallel without separate databases
+    * If you have important seed data and want to truncate those tables, the seed data will be lost
+    """
+
+    logger.info("truncating database")
+
+    # TODO get additional tables to preserve from config
+    exception_tables = ["alembic_version"]
+
+    assert (
+        SQLModel.metadata.sorted_tables
+    ), "No model metadata. Ensure model metadata is imported before running truncate_db"
+
+    with get_engine().connect() as connection:
+        for table in reversed(SQLModel.metadata.sorted_tables):
+            transaction = connection.begin()
+
+            if table.name not in exception_tables:
+                logger.debug(f"truncating table={table.name}")
+                connection.execute(table.delete())
+
+            transaction.commit()
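One detail worth noting in the loop above: `SQLModel.metadata.sorted_tables` is ordered by foreign-key dependency (parent tables first), so iterating with `reversed()` deletes child rows before the rows they reference, avoiding foreign-key violations. A tiny illustration, assuming a hypothetical `post` table with a foreign key to `user`:

    from sqlmodel import SQLModel

    # sorted_tables yields parents first, e.g. [user, post]
    for table in reversed(SQLModel.metadata.sorted_tables):
        print(table.name)  # -> "post", then "user"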
activemodel/query_wrapper.py
CHANGED
@@ -1,28 +1,26 @@
-
-
+import sqlmodel as sm
 from sqlmodel.sql.expression import SelectOfScalar
 
-
-
-
-def compile_sql(target: SelectOfScalar):
-    return str(target.compile(get_engine().connect()))
+from .session_manager import get_session
+from .utils import compile_sql
 
 
-class QueryWrapper(Generic[WrappedModelType]):
+class QueryWrapper[T: sm.SQLModel]:
     """
     Make it easy to run queries off of a model
     """
 
-
+    target: SelectOfScalar[T]
+
+    def __init__(self, cls: T, *args) -> None:
         # TODO add generics here
         # self.target: SelectOfScalar[T] = sql.select(cls)
 
         if args:
             # very naive, let's assume the args are specific select statements
-            self.target =
+            self.target = sm.select(*args).select_from(cls)
         else:
-            self.target =
+            self.target = sm.select(cls)
 
         # TODO the .exec results should be handled in one shot
 
@@ -31,6 +29,7 @@ class QueryWrapper(Generic[WrappedModelType]):
         return session.exec(self.target).first()
 
     def one(self):
+        "requires exactly one result in the dataset"
         with get_session() as session:
             return session.exec(self.target).one()
 
@@ -40,8 +39,14 @@ class QueryWrapper(Generic[WrappedModelType]):
         for row in result:
             yield row
 
+    def count(self):
+        """
+        I did some basic tests
+        """
+        with get_session() as session:
+            return session.scalar(sm.select(sm.func.count()).select_from(self.target))
+
     def exec(self):
-        # TODO do we really need a unique session each time?
         with get_session() as session:
             return session.exec(self.target)
 
@@ -51,19 +56,20 @@ class QueryWrapper(Generic[WrappedModelType]):
 
     def __getattr__(self, name):
         """
-        This
+        This implements the magic that forwards function calls to sqlalchemy.
        """
 
         # TODO prefer methods defined in this class
+
         if not hasattr(self.target, name):
             return super().__getattribute__(name)
 
-
+        sqlalchemy_target = getattr(self.target, name)
 
-        if callable(
+        if callable(sqlalchemy_target):
 
             def wrapper(*args, **kwargs):
-                result =
+                result = sqlalchemy_target(*args, **kwargs)
                 self.target = result
                 return self
 
@@ -75,7 +81,7 @@ class QueryWrapper(Generic[WrappedModelType]):
 
     def sql(self):
         """
-        Output the raw SQL of the query
+        Output the raw SQL of the query for debugging
         """
 
         return compile_sql(self.target)
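To make the `__getattr__` forwarding concrete, a hypothetical chain (assuming the `User` model from the earlier sketch; `where` and `limit` are standard select-statement methods picked up by the forwarding and returned as the wrapper itself):

    qw = QueryWrapper(User).where(User.name == "lena").limit(10)

    print(qw.sql())    # compiled SQL via compile_sql(), for debugging
    print(qw.count())  # wraps the current query in SELECT count(*)
    user = qw.first()  # executes inside a get_session() block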
activemodel/session_manager.py
ADDED
@@ -0,0 +1,132 @@
+"""
+Class to make managing sessions with SQL Model easy. Also provides a common entrypoint to make it easy to mutate the
+database environment when testing.
+"""
+
+import contextlib
+import contextvars
+import json
+import typing as t
+
+from decouple import config
+from pydantic import BaseModel
+from sqlalchemy import Connection, Engine
+from sqlmodel import Session, create_engine
+
+
+def _serialize_pydantic_model(model: BaseModel | list[BaseModel] | None) -> str | None:
+    """
+    Pydantic models do not serialize to JSON. You'll get an error such as:
+
+    'TypeError: Object of type TranscriptEntry is not JSON serializable'
+
+    https://github.com/fastapi/sqlmodel/issues/63#issuecomment-2581016387
+
+    This custom serializer is passed to the DB engine to properly serialize pydantic models to
+    JSON for storage in a JSONB column.
+    """
+
+    # TODO I bet this will fail on lists with mixed types
+
+    if isinstance(model, BaseModel):
+        return model.model_dump_json()
+    if isinstance(model, list):
+        # not everything in a list is a pydantic model
+        def dump_if_model(m):
+            if isinstance(m, BaseModel):
+                return m.model_dump()
+            return m
+
+        return json.dumps([dump_if_model(m) for m in model])
+    else:
+        return json.dumps(model)
+
+
+class SessionManager:
+    _instance: t.ClassVar[t.Optional["SessionManager"]] = None
+
+    session_connection: Connection | None
+    "optionally specify a specific session connection to use for all get_session() calls, useful for testing"
+
+    @classmethod
+    def get_instance(cls, database_url: str | None = None) -> "SessionManager":
+        if cls._instance is None:
+            assert database_url is not None, (
+                "Database URL required for first initialization"
+            )
+            cls._instance = cls(database_url)
+
+        return cls._instance
+
+    def __init__(self, database_url: str):
+        self._database_url = database_url
+        self._engine = None
+
+        self.session_connection = None
+
+    # TODO why is this type not reimported?
+    def get_engine(self) -> Engine:
+        if not self._engine:
+            self._engine = create_engine(
+                self._database_url,
+                json_serializer=_serialize_pydantic_model,
+                echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False),
+                # https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic
+                pool_pre_ping=True,
+                # some implementations include `future=True` but it's not required anymore
+            )
+
+        return self._engine
+
+    def get_session(self):
+        if gsession := _session_context.get():
+
+            @contextlib.contextmanager
+            def _reuse_session():
+                yield gsession
+
+            return _reuse_session()
+
+        if self.session_connection:
+            return Session(bind=self.session_connection)
+
+        return Session(self.get_engine())
+
+
+def init(database_url: str):
+    return SessionManager.get_instance(database_url)
+
+
+def get_engine():
+    return SessionManager.get_instance().get_engine()
+
+
+def get_session():
+    return SessionManager.get_instance().get_session()
+
+
+# contextvars must be at the top-level of a module! You will not get a warning if you don't do this.
+_session_context = contextvars.ContextVar[Session | None](
+    "session_context", default=None
+)
+
+
+@contextlib.contextmanager
+def global_session():
+    with SessionManager.get_instance().get_session() as s:
+        token = _session_context.set(s)
+
+        try:
+            yield
+        finally:
+            _session_context.reset(token)
+
+
+async def aglobal_session():
+    with SessionManager.get_instance().get_session() as s:
+        token = _session_context.set(s)
+
+        try:
+            yield
+        finally:
+            _session_context.reset(token)
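A minimal bootstrap sketch for the module above (the database URL is illustrative): `init` seeds the singleton, and `global_session` pins one session in the contextvar so nested `get_session()` calls reuse it:

    from activemodel.session_manager import get_session, global_session, init

    init("postgresql://localhost:5432/app_development")  # hypothetical URL

    with get_session() as session:
        ...  # a fresh Session bound to the shared engine

    with global_session():
        # any get_session() call in this block reuses the pinned session
        with get_session() as session:
            ...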
activemodel/types/__init__.py
ADDED
@@ -0,0 +1 @@
+from .typeid import TypeIDType
activemodel/types/typeid.py
ADDED
@@ -0,0 +1,191 @@
+"""
+Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sqlalchemy.py
+"""
+
+from typing import Optional
+from uuid import UUID
+
+from pydantic import (
+    GetJsonSchemaHandler,
+)
+from pydantic_core import CoreSchema, core_schema
+from sqlalchemy import types
+from sqlalchemy.util import generic_repr
+from typeid import TypeID
+
+from activemodel.errors import TypeIDValidationError
+
+
+class TypeIDType(types.TypeDecorator):
+    """
+    A SQLAlchemy TypeDecorator that allows storing TypeIDs in the database.
+    The prefix will not be persisted, instead the database-native UUID field will be used.
+    At retrieval time a TypeID will be constructed based on the configured prefix and the
+    UUID value from the database.
+
+    Usage:
+        # will result in TypeIDs such as "user_01h45ytscbebyvny4gc8cr8ma2"
+        id = mapped_column(
+            TypeIDType("user"),
+            primary_key=True,
+            default=lambda: TypeID("user")
+        )
+    """
+
+    impl = types.Uuid
+    # impl = uuid.UUID
+    cache_ok = True
+    prefix: Optional[str] = None
+
+    def __init__(self, prefix: Optional[str], *args, **kwargs):
+        self.prefix = prefix
+        super().__init__(*args, **kwargs)
+
+    def __repr__(self) -> str:
+        # Customize __repr__ to ensure that auto-generated code e.g. from alembic includes
+        # the right __init__ params (otherwise by default prefix will be omitted because
+        # uuid.__init__ does not have such an argument).
+        # TODO this makes it so inspected code does NOT include the suffix
+        return generic_repr(
+            self,
+            to_inspect=TypeID(self.prefix),
+        )
+
+    def process_bind_param(self, value, dialect):
+        """
+        This is run when a search query is built or ...
+        """
+
+        if value is None:
+            return None
+
+        if isinstance(value, UUID):
+            # then it's a UUID class, such as UUID('01942886-7afc-7129-8f57-db09137ed002')
+            return value
+
+        if isinstance(value, str) and value.startswith(self.prefix + "_"):
+            # then it's a TypeID such as 'user_01h45ytscbebyvny4gc8cr8ma2'
+            value = TypeID.from_string(value)
+
+        if isinstance(value, str):
+            # no prefix, raw UUID, let's coerce it into a UUID which SQLAlchemy can handle
+            # ex: '01942886-7afc-7129-8f57-db09137ed002'
+            return UUID(value)
+
+        if isinstance(value, TypeID):
+            # TODO in what case could this None prefix ever occur?
+            if self.prefix is None:
+                if value.prefix is None:
+                    raise TypeIDValidationError(
+                        "Must have a valid prefix set on the class"
+                    )
+            else:
+                if value.prefix != self.prefix:
+                    raise TypeIDValidationError(
+                        f"Expected '{self.prefix}' but got '{value.prefix}'"
+                    )
+
+            return value.uuid
+
+        raise ValueError("Unexpected input type")
+
+    def process_result_value(self, value, dialect):
+        if value is None:
+            return None
+
+        return TypeID.from_uuid(value, self.prefix)
+
+    # def coerce_compared_value(self, op, value):
+    #     """
+    #     This method is called when SQLAlchemy needs to compare a column to a value.
+    #     By returning self, we indicate that this type can handle TypeID instances.
+    #     """
+    #     if isinstance(value, TypeID):
+    #         return self
+
+    #     return super().coerce_compared_value(op, value)
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ) -> CoreSchema:
+        """
+        This fixes the following error: 'Unable to serialize unknown type' by telling pydantic how to serialize this field.
+
+        Note that TypeIDType MUST be the type of the field in SQLModel otherwise you'll get serialization errors.
+        This is done automatically for the mixin but for any relationship fields you'll need to specify the type explicitly.
+
+        - https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/server/value_objects/steam_id.py#L88
+        - https://github.com/hhimanshu/uv-workspaces/blob/main/packages/api/src/_lib/dto/typeid_field.py
+        - https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/base/typeid/type_id.py#L14
+        - https://github.com/pydantic/pydantic/issues/10060
+        - https://github.com/fastapi/fastapi/discussions/10027
+        - https://github.com/alice-biometrics/petisco/blob/b01ef1b84949d156f73919e126ed77aa8e0b48dd/petisco/base/domain/model/uuid.py#L50
+        """
+
+        from_uuid_schema = core_schema.chain_schema(
+            [
+                # TODO not sure how this is different from the UUID schema, should try it out.
+                # core_schema.is_instance_schema(TypeID),
+                # core_schema.uuid_schema(),
+                core_schema.no_info_plain_validator_function(
+                    TypeID.from_string,
+                    json_schema_input_schema=core_schema.str_schema(),
+                ),
+            ]
+        )
+
+        return core_schema.json_or_python_schema(
+            json_schema=from_uuid_schema,
+            # metadata=core_schema.str_schema(
+            #     pattern="^[0-9a-f]{24}$",
+            #     min_length=24,
+            #     max_length=24,
+            # ),
+            # metadata={
+            #     "pydantic_js_input_core_schema": core_schema.str_schema(
+            #         pattern="^[0-9a-f]{24}$",
+            #         min_length=24,
+            #         max_length=24,
+            #     )
+            # },
+            python_schema=core_schema.union_schema([from_uuid_schema]),
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda x: str(x)
+            ),
+        )
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, schema: CoreSchema, handler: GetJsonSchemaHandler
+    ):
+        """
+        Called when generating the openapi schema. This overrides the `function-plain` type which
+        is generated by the `no_info_plain_validator_function`.
+
+        This logic seems to be a hot part of the codebase, so I'd expect this to break as pydantic
+        and fastapi continue to evolve.
+
+        Note that this method can return multiple types. A return value can be as simple as:
+
+            {"type": "string"}
+
+        Or, you could return a more specific JSON schema type:
+
+            core_schema.uuid_schema()
+
+        The problem with using something like uuid_schema is the specific patterns
+
+        https://github.com/BeanieODM/beanie/blob/2190cd9d1fc047af477d5e6897cc283799f54064/beanie/odm/fields.py#L153
+        """
+
+        return {
+            "type": "string",
+            # TODO implement a more strict pattern in regex
+            # https://github.com/jetify-com/typeid/blob/3d182feed5687c21bb5ab93d5f457ff96749b68b/spec/README.md?plain=1#L38
+            # "pattern": "^[0-9a-f]{24}$",
+            # "minLength": 24,
+            # "maxLength": 24,
+        }
+
+        return core_schema.uuid_schema()
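To illustrate the round trip the docstring describes (a sketch, not part of the diff): the bind step strips the prefix down to a bare UUID, and the result step re-attaches it:

    from typeid import TypeID

    column_type = TypeIDType("user")
    tid = TypeID.from_string("user_01h45ytscbebyvny4gc8cr8ma2")

    stored = column_type.process_bind_param(tid, None)          # plain UUID for the database
    restored = column_type.process_result_value(stored, None)   # TypeID with the "user" prefix

    assert str(restored) == str(tid)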
activemodel/utils.py
ADDED
@@ -0,0 +1,65 @@
+import inspect
+import pkgutil
+import sys
+from types import ModuleType
+
+from sqlalchemy import text
+from sqlmodel import SQLModel
+from sqlmodel.sql.expression import SelectOfScalar
+
+from .logger import logger
+from .session_manager import get_engine, get_session
+
+
+def compile_sql(target: SelectOfScalar):
+    "convert a query into SQL, helpful for debugging"
+    dialect = get_engine().dialect
+    # TODO I wonder if we could store the dialect to avoid getting an engine reference
+    compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
+    return str(compiled)
+
+
+# TODO document further, lots of risks here
+def raw_sql_exec(raw_query: str):
+    with get_session() as session:
+        session.execute(text(raw_query))
+
+
+def find_all_sqlmodels(module: ModuleType):
+    """Import all model classes from module and submodules into current namespace."""
+
+    logger.debug(f"Starting model import from module: {module.__name__}")
+    model_classes = {}
+
+    # Walk through all submodules
+    for loader, module_name, is_pkg in pkgutil.walk_packages(module.__path__):
+        full_name = f"{module.__name__}.{module_name}"
+        logger.debug(f"Importing submodule: {full_name}")
+
+        # Check if module is already imported
+        if full_name in sys.modules:
+            submodule = sys.modules[full_name]
+        else:
+            logger.warning(
+                f"Module not found in sys.modules, not importing: {full_name}"
+            )
+            continue
+
+        # Get all classes from module
+        for name, obj in inspect.getmembers(submodule):
+            if inspect.isclass(obj) and issubclass(obj, SQLModel) and obj != SQLModel:
+                logger.debug(f"Found model class: {name}")
+                model_classes[name] = obj
+
+    logger.debug(f"Completed model import. Found {len(model_classes)} models")
+    return model_classes
+
+
+def hash_function_code(func):
+    "get sha of a function to easily assert that it hasn't changed"
+
+    import hashlib
+    import inspect
+
+    source = inspect.getsource(func)
+    return hashlib.sha256(source.encode()).hexdigest()