activemodel 0.3.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,46 @@
1
+ from sqlmodel import Column, Field
2
+ from typeid import TypeID
3
+
4
+ from activemodel.types.typeid import TypeIDType
5
+
6
# Every prefix ever handed to TypeIDMixin; used to reject duplicates across calls.
_prefixes = []


def TypeIDMixin(prefix: str):
    """
    Build a mixin class that gives a model a TypeID primary key named ``id``.

    The generated column stores the UUID part through ``TypeIDType(prefix)``; it is the primary key,
    NOT NULL, and new instances default to ``TypeID(prefix)``.

    ``prefix`` must be non-empty and must not have been used by an earlier call; both conditions are
    checked with ``assert`` (so they are skipped when Python runs with ``-O``). The prefix is only
    recorded once the mixin class has been built successfully.
    """
    assert prefix
    assert prefix not in _prefixes, (
        f"prefix {prefix} already exists, pick a different one"
    )

    def _generate_id():
        # closes over ``prefix`` so every model built from this mixin gets IDs with its own prefix
        return TypeID(prefix)

    id_column = Column(TypeIDType(prefix), primary_key=True, nullable=False)

    class _TypeIDMixin:
        id: TypeIDType = Field(sa_column=id_column, default_factory=_generate_id)

    _prefixes.append(prefix)

    return _TypeIDMixin
25
+
26
+
27
+ # TODO not sure if I love the idea of a dynamic class for each mixin as used above
28
+ # may give this approach another shot in the future
29
+ # class TypeIDMixin2:
30
+ # """
31
+ # Mixin class that adds a TypeID primary key to models.
32
+
33
+
34
+ # >>> class MyModel(BaseModel, TypeIDMixin, prefix="xyz", table=True):
35
+ # >>> name: str
36
+
37
+ # Will automatically have an `id` field with prefix "xyz"
38
+ # """
39
+
40
+ # def __init_subclass__(cls, *, prefix: str, **kwargs):
41
+ # super().__init_subclass__(**kwargs)
42
+
43
+ # cls.id: uuid.UUID = Field(
44
+ # sa_column=Column(TypeIDType(prefix), primary_key=True),
45
+ # default_factory=lambda: TypeID(prefix),
46
+ # )
@@ -0,0 +1,2 @@
1
+ from .transaction import database_reset_transaction
2
+ from .truncate import database_reset_truncate
@@ -0,0 +1,51 @@
1
+ from activemodel import SessionManager
2
+
3
+
4
def database_reset_transaction():
    """
    Wrap all database interactions for a given test in a nested transaction and roll it back after the test.

    >>> from activemodel.pytest import database_reset_transaction
    >>> pytest.fixture(scope="function", autouse=True)(database_reset_transaction)

    While the fixture is active, ``SessionManager.session_connection`` points at the test connection, so every
    ``get_session()`` call binds its Session to it. The previous value is restored when the fixture finishes, so
    code running afterwards does not receive sessions bound to the (by then closed) test connection.

    References:

    - https://stackoverflow.com/questions/62433018/how-to-make-sqlalchemy-transaction-rollback-drop-tables-it-created
    - https://aalvarez.me/posts/setting-up-a-sqlalchemy-and-pytest-based-test-suite/
    - https://github.com/nickjj/docker-flask-example/blob/93af9f4fbf185098ffb1d120ee0693abcd77a38b/test/conftest.py#L77
    - https://github.com/caiola/vinhos.com/blob/c47d0a5d7a4bf290c1b726561d1e8f5d2ac29bc8/backend/test/conftest.py#L46
    """

    manager = SessionManager.get_instance()
    engine = manager.get_engine()
    previous_connection = manager.session_connection

    with engine.begin() as connection:
        # SAVEPOINT that all work done during the test happens inside of
        transaction = connection.begin_nested()

        manager.session_connection = connection

        try:
            yield
        finally:
            # restore first so the override never outlives the fixture, even if the rollback below fails
            manager.session_connection = previous_connection
            transaction.rollback()
            # Closing the connection before ``engine.begin()`` exits ends the outer transaction with a rollback
            # (per SQLAlchemy's Connection.close() semantics) instead of letting the context manager commit it.
            connection.close()
32
+
33
+
34
+ # TODO unsure if this adds any value beyond the above approach
35
+ # def database_reset_named_truncation():
36
+ # start_truncation_query = """
37
+ # BEGIN;
38
+ # SAVEPOINT test_truncation_savepoint;
39
+ # """
40
+
41
+ # raw_sql_exec(start_truncation_query)
42
+
43
+ # yield
44
+
45
+ # end_truncation_query = """
46
+ # ROLLBACK TO SAVEPOINT test_truncation_savepoint;
47
+ # RELEASE SAVEPOINT test_truncation_savepoint;
48
+ # ROLLBACK;
49
+ # """
50
+
51
+ # raw_sql_exec(end_truncation_query)
@@ -0,0 +1,46 @@
1
+ from sqlmodel import SQLModel
2
+
3
+ from ..logger import logger
4
+ from ..session_manager import get_engine
5
+
6
+
7
def database_reset_truncate():
    """
    Delete every row from every table registered on ``SQLModel.metadata``, except ``alembic_version``.

    The transaction-based reset (``database_reset_transaction``) is most likely the better way to go, but there
    are some scenarios where the session override logic does not work properly and you need to truncate tables
    back to their original state.

    Here's how to do this once at the start of the test run:

    >>> from activemodel.pytest import database_reset_truncate
    >>> def pytest_configure(config):
    ...     database_reset_truncate()

    Or, if you want to use this as a fixture:

    >>> pytest.fixture(scope="function")(database_reset_truncate)
    >>> def test_the_thing(database_reset_truncate): ...

    This approach has a couple of problems:

    * You can't run multiple tests in parallel without separate databases
    * If you have important seed data and want to truncate those tables, the seed data will be lost

    Rows are removed with ``DELETE`` (``table.delete()``), one committed transaction per table.
    """

    logger.info("truncating database")

    # TODO get additional tables to preserve from config
    exception_tables = {"alembic_version"}

    assert (
        SQLModel.metadata.sorted_tables
    ), "No model metadata. Ensure model metadata is imported before running database_reset_truncate"

    with get_engine().connect() as connection:
        # sorted_tables lists tables in foreign-key dependency order; reversing it empties dependent tables
        # before the tables they reference
        for table in reversed(SQLModel.metadata.sorted_tables):
            if table.name in exception_tables:
                continue

            logger.debug(f"truncating table={table.name}")
            # commits on success, rolls back if the DELETE raises
            with connection.begin():
                connection.execute(table.delete())
@@ -1,28 +1,26 @@
1
- from typing import Generic, TypeVar
2
-
1
+ import sqlmodel as sm
3
2
  from sqlmodel.sql.expression import SelectOfScalar
4
3
 
5
- WrappedModelType = TypeVar("WrappedModelType")
6
-
7
-
8
- def compile_sql(target: SelectOfScalar):
9
- return str(target.compile(get_engine().connect()))
4
+ from .session_manager import get_session
5
+ from .utils import compile_sql
10
6
 
11
7
 
12
- class QueryWrapper(Generic[WrappedModelType]):
8
+ class QueryWrapper[T: sm.SQLModel]:
13
9
  """
14
10
  Make it easy to run queries off of a model
15
11
  """
16
12
 
17
- def __init__(self, cls, *args) -> None:
13
+ target: SelectOfScalar[T]
14
+
15
+ def __init__(self, cls: T, *args) -> None:
18
16
  # TODO add generics here
19
17
  # self.target: SelectOfScalar[T] = sql.select(cls)
20
18
 
21
19
  if args:
22
20
  # very naive, let's assume the args are specific select statements
23
- self.target = sql.select(*args).select_from(cls)
21
+ self.target = sm.select(*args).select_from(cls)
24
22
  else:
25
- self.target = sql.select(cls)
23
+ self.target = sm.select(cls)
26
24
 
27
25
  # TODO the .exec results should be handled in one shot
28
26
 
@@ -31,6 +29,7 @@ class QueryWrapper(Generic[WrappedModelType]):
31
29
  return session.exec(self.target).first()
32
30
 
33
31
  def one(self):
32
+ "requires exactly one result in the dataset"
34
33
  with get_session() as session:
35
34
  return session.exec(self.target).one()
36
35
 
@@ -40,8 +39,14 @@ class QueryWrapper(Generic[WrappedModelType]):
40
39
  for row in result:
41
40
  yield row
42
41
 
42
def count(self):
    """
    Return the number of rows the current query would produce.

    The wrapped query is used as a subquery, i.e. ``SELECT count(*) FROM (<query>)``, so any filters, joins,
    limits, etc. already applied to this wrapper are respected.
    """
    with get_session() as session:
        # call .subquery() explicitly: handing a Select straight to select_from() relies on SQLAlchemy's
        # deprecated implicit SELECT -> FROM coercion
        return session.scalar(
            sm.select(sm.func.count()).select_from(self.target.subquery())
        )
48
+
43
49
  def exec(self):
44
- # TODO do we really need a unique session each time?
45
50
  with get_session() as session:
46
51
  return session.exec(self.target)
47
52
 
@@ -51,19 +56,20 @@ class QueryWrapper(Generic[WrappedModelType]):
51
56
 
52
57
  def __getattr__(self, name):
53
58
  """
54
- This is called to retrieve the function to execute
59
+ This implements the magic that forwards function calls to sqlalchemy.
55
60
  """
56
61
 
57
62
  # TODO prefer methods defined in this class
63
+
58
64
  if not hasattr(self.target, name):
59
65
  return super().__getattribute__(name)
60
66
 
61
- attr = getattr(self.target, name)
67
+ sqlalchemy_target = getattr(self.target, name)
62
68
 
63
- if callable(attr):
69
+ if callable(sqlalchemy_target):
64
70
 
65
71
  def wrapper(*args, **kwargs):
66
- result = attr(*args, **kwargs)
72
+ result = sqlalchemy_target(*args, **kwargs)
67
73
  self.target = result
68
74
  return self
69
75
 
@@ -75,7 +81,7 @@ class QueryWrapper(Generic[WrappedModelType]):
75
81
 
76
82
  def sql(self):
77
83
  """
78
- Output the raw SQL of the query
84
+ Output the raw SQL of the query for debugging
79
85
  """
80
86
 
81
87
  return compile_sql(self.target)
@@ -0,0 +1,132 @@
1
+ """
2
+ Class to make managing sessions with SQL Model easy. Also provides a common entrypoint to make it easy to mutate the
3
+ database environment when testing.
4
+ """
5
+
6
+ import contextlib
7
+ import contextvars
8
+ import json
9
+ import typing as t
10
+
11
+ from decouple import config
12
+ from pydantic import BaseModel
13
+ from sqlalchemy import Connection, Engine
14
+ from sqlmodel import Session, create_engine
15
+
16
+
17
def _serialize_pydantic_model(model: BaseModel | list[BaseModel] | None) -> str | None:
    """
    JSON serializer passed to the DB engine (``json_serializer=``) so pydantic models can be stored in JSON/JSONB
    columns.

    The stdlib ``json`` module does not know how to encode pydantic models. You'll get an error such as:

        'TypeError: Object of type TranscriptEntry is not JSON serializable'

    https://github.com/fastapi/sqlmodel/issues/63#issuecomment-2581016387

    A top-level model is encoded with ``model_dump_json()``. Everything else (lists, dicts, scalars, ``None``)
    goes through ``json.dumps``, which calls back into ``_encode_model`` for every pydantic model it meets, at any
    nesting depth and in containers holding mixed types.

    Raises:
        TypeError: if a value that is neither JSON-native nor a pydantic model is encountered.
    """

    if isinstance(model, BaseModel):
        return model.model_dump_json()

    def _encode_model(value):
        # json.dumps only calls this for objects it cannot encode natively
        if isinstance(value, BaseModel):
            # mode="json" turns datetimes, UUIDs, enums, sets, ... into JSON-compatible values
            return value.model_dump(mode="json")
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    return json.dumps(model, default=_encode_model)
43
+
44
+
45
class SessionManager:
    """
    Singleton that owns the database URL, the lazily created engine and an optional connection override used
    when handing out sessions.
    """

    _instance: t.ClassVar[t.Optional["SessionManager"]] = None

    # When set, every get_session() call binds its Session to this connection instead of the engine.
    # Useful for testing.
    session_connection: Connection | None

    @classmethod
    def get_instance(cls, database_url: str | None = None) -> "SessionManager":
        """
        Return the singleton, creating it on first use.

        ``database_url`` is required (asserted) the first time; once the instance exists it is ignored.
        """
        if cls._instance is not None:
            return cls._instance

        assert database_url is not None, (
            "Database URL required for first initialization"
        )
        cls._instance = cls(database_url)
        return cls._instance

    def __init__(self, database_url: str):
        self._database_url = database_url
        # created on demand by get_engine()
        self._engine = None

        self.session_connection = None

    # TODO why is this type not reimported?
    def get_engine(self) -> Engine:
        """Return the engine, creating and caching it on the first call."""
        if self._engine:
            return self._engine

        self._engine = create_engine(
            self._database_url,
            # lets pydantic models be written to JSON columns
            json_serializer=_serialize_pydantic_model,
            # set ACTIVEMODEL_LOG_SQL to echo emitted SQL
            echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False),
            # https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic
            pool_pre_ping=True,
            # some implementations include `future=True` but it's not required anymore
        )
        return self._engine

    def get_session(self):
        """
        Return a context manager that yields a Session.

        Lookup order:
        1. the Session installed by global_session()/aglobal_session() in the current context; it is yielded
           as-is and is not closed when the block exits;
        2. a new Session bound to ``session_connection``, if one is set;
        3. a new Session on the engine.
        """
        shared_session = _session_context.get()

        if shared_session:

            @contextlib.contextmanager
            def _yield_shared_session():
                yield shared_session

            return _yield_shared_session()

        if self.session_connection:
            return Session(bind=self.session_connection)

        return Session(self.get_engine())
94
+
95
+
96
def init(database_url: str):
    """
    Initialize the SessionManager singleton with ``database_url`` and return it.

    If the singleton already exists, it is returned unchanged and ``database_url`` is ignored.
    """
    manager = SessionManager.get_instance(database_url)
    return manager
98
+
99
+
100
def get_engine():
    """Return the engine of the SessionManager singleton (which must already have been initialized)."""
    manager = SessionManager.get_instance()
    return manager.get_engine()
102
+
103
+
104
def get_session():
    """Return a Session context manager from the SessionManager singleton (which must already be initialized)."""
    manager = SessionManager.get_instance()
    return manager.get_session()
106
+
107
+
108
# contextvars must be at the top-level of a module! You will not get a warning if you don't do this.
# Holds the Session installed by global_session()/aglobal_session(); None when no shared session is active.
_session_context: contextvars.ContextVar[Session | None] = contextvars.ContextVar(
    "session_context", default=None
)
+ )
112
+
113
+
114
@contextlib.contextmanager
def global_session():
    """
    Open a Session and make get_session() calls in the current context reuse it until the block exits.

    The context variable is reset to its previous value on exit, even if the block raises.
    """
    with SessionManager.get_instance().get_session() as shared:
        reset_token = _session_context.set(shared)

        try:
            yield
        finally:
            _session_context.reset(reset_token)
123
+
124
+
125
async def aglobal_session():
    """
    Async-generator counterpart of global_session(): installs a shared Session in the current context while the
    generator is suspended at its ``yield``, and resets the context variable when it finishes.

    NOTE(review): this is a bare async generator (not wrapped in ``contextlib.asynccontextmanager``), so it cannot
    be used with ``async with`` directly. Presumably it is meant to be used as a framework dependency (e.g.
    FastAPI ``Depends``); confirm with callers.
    """
    with SessionManager.get_instance().get_session() as shared:
        reset_token = _session_context.set(shared)

        try:
            yield
        finally:
            _session_context.reset(reset_token)
@@ -0,0 +1 @@
1
+ from .typeid import TypeIDType
@@ -0,0 +1,191 @@
1
+ """
2
+ Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sqlalchemy.py
3
+ """
4
+
5
+ from typing import Optional
6
+ from uuid import UUID
7
+
8
+ from pydantic import (
9
+ GetJsonSchemaHandler,
10
+ )
11
+ from pydantic_core import CoreSchema, core_schema
12
+ from sqlalchemy import types
13
+ from sqlalchemy.util import generic_repr
14
+ from typeid import TypeID
15
+
16
+ from activemodel.errors import TypeIDValidationError
17
+
18
+
19
class TypeIDType(types.TypeDecorator):
    """
    A SQLAlchemy TypeDecorator that allows storing TypeIDs in the database.

    The prefix is not persisted; the database-native UUID type (``types.Uuid``) stores the value instead.
    At retrieval time a TypeID is constructed from the configured prefix and the UUID value from the database.

    Usage:
        # will result in TypeIDs such as "user_01h45ytscbebyvny4gc8cr8ma2"
        id = mapped_column(
            TypeIDType("user"),
            primary_key=True,
            default=lambda: TypeID("user")
        )
    """

    impl = types.Uuid
    # impl = uuid.UUID
    cache_ok = True
    prefix: Optional[str] = None

    def __init__(self, prefix: Optional[str], *args, **kwargs):
        self.prefix = prefix
        super().__init__(*args, **kwargs)

    def __repr__(self) -> str:
        # Customize __repr__ to ensure that auto-generated code e.g. from alembic includes
        # the right __init__ params (otherwise by default prefix will be omitted because
        # uuid.__init__ does not have such an argument).
        # TODO this makes it so inspected code does NOT include the suffix
        return generic_repr(
            self,
            to_inspect=TypeID(self.prefix),
        )

    def process_bind_param(self, value, dialect):
        """
        Convert a Python value into the UUID that is sent to the database.

        SQLAlchemy calls this whenever a value is bound to a statement parameter for this type, e.g. INSERT/UPDATE
        values or WHERE-clause comparisons. Accepted inputs:

        - ``None`` -> ``None``
        - a ``UUID`` -> returned unchanged, e.g. ``UUID('01942886-7afc-7129-8f57-db09137ed002')``
        - a string starting with ``"<prefix>_"`` -> parsed as a TypeID, e.g. ``'user_01h45ytscbebyvny4gc8cr8ma2'``
        - any other string -> parsed as a raw UUID string, e.g. ``'01942886-7afc-7129-8f57-db09137ed002'``
        - a ``TypeID`` -> its UUID, after the prefix check below

        Raises:
            TypeIDValidationError: a TypeID's prefix differs from the configured prefix, or neither the column
                nor the TypeID has a prefix.
            ValueError: a string is neither a matching TypeID nor a valid UUID (raised by ``UUID()``), or the
                value is of an unsupported type.
        """

        if value is None:
            return None

        if isinstance(value, UUID):
            return value

        if isinstance(value, str):
            # ``prefix`` may be None (``__init__`` allows it); only look for a TypeID prefix when one is
            # configured, otherwise ``None + "_"`` would raise TypeError
            if self.prefix is not None and value.startswith(self.prefix + "_"):
                value = TypeID.from_string(value)
            else:
                # no prefix, raw UUID, let's coerce it into a UUID which SQLAlchemy can handle
                return UUID(value)

        if isinstance(value, TypeID):
            # TODO in what case could this None prefix ever occur?
            if self.prefix is None:
                # with no configured prefix, only TypeIDs that carry their own prefix are accepted
                if value.prefix is None:
                    raise TypeIDValidationError(
                        "Must have a valid prefix set on the class"
                    )
            elif value.prefix != self.prefix:
                raise TypeIDValidationError(
                    f"Expected '{self.prefix}' but got '{value.prefix}'"
                )

            return value.uuid

        raise ValueError("Unexpected input type")

    def process_result_value(self, value, dialect):
        """Rebuild a TypeID from the UUID loaded from the database, using this column's prefix."""
        if value is None:
            return None

        return TypeID.from_uuid(value, self.prefix)

    # def coerce_compared_value(self, op, value):
    #     """
    #     This method is called when SQLAlchemy needs to compare a column to a value.
    #     By returning self, we indicate that this type can handle TypeID instances.
    #     """
    #     if isinstance(value, TypeID):
    #         return self
    #
    #     return super().coerce_compared_value(op, value)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> CoreSchema:
        """
        This fixes the following error: 'Unable to serialize unknown type' by telling pydantic how to validate and
        serialize this field.

        - JSON input: a TypeID string, parsed with ``TypeID.from_string``.
        - Python input: an existing ``TypeID`` instance (kept as-is) or a TypeID string.
        - Output: ``str(value)`` in both Python and JSON mode.

        Note that TypeIDType MUST be the type of the field in SQLModel otherwise you'll get serialization errors.
        This is done automatically for the mixin but for any relationship fields you'll need to specify the type
        explicitly.

        - https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/server/value_objects/steam_id.py#L88
        - https://github.com/hhimanshu/uv-workspaces/blob/main/packages/api/src/_lib/dto/typeid_field.py
        - https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/base/typeid/type_id.py#L14
        - https://github.com/pydantic/pydantic/issues/10060
        - https://github.com/fastapi/fastapi/discussions/10027
        - https://github.com/alice-biometrics/petisco/blob/b01ef1b84949d156f73919e126ed77aa8e0b48dd/petisco/base/domain/model/uuid.py#L50
        """

        from_string_schema = core_schema.no_info_plain_validator_function(
            TypeID.from_string,
            json_schema_input_schema=core_schema.str_schema(),
        )

        return core_schema.json_or_python_schema(
            json_schema=from_string_schema,
            # TypeID.from_string only understands strings, so TypeID instances (e.g. values loaded through
            # process_result_value) are accepted directly instead of being handed to it
            python_schema=core_schema.union_schema(
                [core_schema.is_instance_schema(TypeID), from_string_schema]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda x: str(x)
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: CoreSchema, handler: GetJsonSchemaHandler
    ):
        """
        Called when generating the OpenAPI schema. This overrides the `function-plain` type which
        is generated by the `no_info_plain_validator_function`.

        This logic is a fragile spot, so I'd expect it to break as pydantic and
        fastapi continue to evolve.

        Note that this method can return multiple types. A return value can be as simple as:

            {"type": "string"}

        Or, you could return a more specific JSON schema type:

            core_schema.uuid_schema()

        The problem with using something like uuid_schema is its specific format: it describes a bare UUID,
        not a prefixed TypeID string such as "user_01h45ytscbebyvny4gc8cr8ma2".

        https://github.com/BeanieODM/beanie/blob/2190cd9d1fc047af477d5e6897cc283799f54064/beanie/odm/fields.py#L153
        """

        return {
            "type": "string",
            # TODO implement a more strict pattern in regex
            # https://github.com/jetify-com/typeid/blob/3d182feed5687c21bb5ab93d5f457ff96749b68b/spec/README.md?plain=1#L38
            # "pattern": "^[0-9a-f]{24}$",
            # "minLength": 24,
            # "maxLength": 24,
        }
activemodel/utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import inspect
2
+ import pkgutil
3
+ import sys
4
+ from types import ModuleType
5
+
6
+ from sqlalchemy import text
7
+ from sqlmodel import SQLModel
8
+ from sqlmodel.sql.expression import SelectOfScalar
9
+
10
+ from .logger import logger
11
+ from .session_manager import get_engine, get_session
12
+
13
+
14
def compile_sql(target: SelectOfScalar):
    """Render ``target`` as SQL for the configured engine's dialect, with bound values inlined (for debugging)."""
    # TODO I wonder if we could store the dialect to avoid getting an engine reference
    return str(
        target.compile(
            dialect=get_engine().dialect,
            # inline parameter values so the output reads as complete SQL
            compile_kwargs={"literal_binds": True},
        )
    )
20
+
21
+
22
# TODO document further, lots of risks here
def raw_sql_exec(raw_query: str, *, commit: bool = True):
    """
    Execute a raw SQL string in a session from ``get_session()``.

    By default the session is committed afterwards. A Session opened by ``get_session()`` that is closed without
    committing rolls its work back, so without the commit any writes would be discarded. Pass ``commit=False``
    to leave the transaction open, e.g. inside ``global_session()``, where the shared session's owner should
    decide when to commit.

    ``raw_query`` is wrapped in ``text()`` as-is: never build it from untrusted input.
    """
    with get_session() as session:
        session.execute(text(raw_query))
        if commit:
            session.commit()
26
+
27
+
28
def find_all_sqlmodels(module: ModuleType):
    """
    Collect SQLModel subclasses from the already-imported submodules of a package.

    Walks every submodule and subpackage below ``module`` (which must be a package, i.e. have ``__path__``).
    Submodules that are not already in ``sys.modules`` are skipped with a warning instead of being imported.
    Note that ``pkgutil.walk_packages`` itself imports subpackages in order to recurse into them.

    Returns:
        dict mapping class name -> class. Classes re-exported by several modules appear once; when two distinct
        classes share a name, the one found last wins.
    """

    logger.debug(f"Starting model import from module: {module.__name__}")
    model_classes = {}

    # prefix= makes walk_packages yield fully-qualified names; without it, nested subpackages are imported and
    # reported under their bare (wrong) names
    for module_info in pkgutil.walk_packages(module.__path__, prefix=f"{module.__name__}."):
        full_name = module_info.name
        logger.debug(f"Importing submodule: {full_name}")

        submodule = sys.modules.get(full_name)
        if submodule is None:
            logger.warning(
                f"Module not found in sys.modules, not importing: {full_name}"
            )
            continue

        for name, obj in inspect.getmembers(submodule):
            if inspect.isclass(obj) and issubclass(obj, SQLModel) and obj != SQLModel:
                logger.debug(f"Found model class: {name}")
                model_classes[name] = obj

    logger.debug(f"Completed model import. Found {len(model_classes)} models")
    return model_classes
56
+
57
+
58
def hash_function_code(func):
    """
    Return the SHA-256 hex digest of ``func``'s source code, to easily assert that it hasn't changed.

    Raises whatever ``inspect.getsource`` raises (e.g. ``OSError``/``TypeError``) when the source is unavailable.
    """
    # ``inspect`` comes from the module-level import; only ``hashlib`` is imported locally
    import hashlib

    source = inspect.getsource(func)
    return hashlib.sha256(source.encode()).hexdigest()
+ return hashlib.sha256(source.encode()).hexdigest()