activemodel 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,19 @@
1
1
  from activemodel import SessionManager
2
2
 
3
+ from ..logger import logger
4
+
3
5
 
4
6
  def database_reset_transaction():
5
7
  """
6
8
  Wrap all database interactions for a given test in a nested transaction and roll it back after the test.
7
9
 
8
10
  >>> from activemodel.pytest import database_reset_transaction
9
- >>> pytest.fixture(scope="function", autouse=True)(database_reset_transaction)
11
+ >>> database_reset_transaction = pytest.fixture(scope="function", autouse=True)(database_reset_transaction)
12
+
13
+ Transaction-based DB cleaning does *not* work if the DB mutations are happening in a separate process, which should
14
+ use spawn, because the same session is not shared across processes. Note that using `fork` is dangerous.
15
+
16
+ In this case, you should use the truncate strategy instead.
10
17
 
11
18
  References:
12
19
 
@@ -14,38 +21,43 @@ def database_reset_transaction():
14
21
  - https://aalvarez.me/posts/setting-up-a-sqlalchemy-and-pytest-based-test-suite/
15
22
  - https://github.com/nickjj/docker-flask-example/blob/93af9f4fbf185098ffb1d120ee0693abcd77a38b/test/conftest.py#L77
16
23
  - https://github.com/caiola/vinhos.com/blob/c47d0a5d7a4bf290c1b726561d1e8f5d2ac29bc8/backend/test/conftest.py#L46
24
+ - https://stackoverflow.com/questions/64095876/multiprocessing-fork-vs-spawn
25
+
26
+ Using a named SAVEPOINT does not give us anything extra, so we are not using it.
17
27
  """
18
28
 
19
29
  engine = SessionManager.get_instance().get_engine()
20
30
 
31
+ logger.info("starting database transaction")
32
+
21
33
  with engine.begin() as connection:
22
34
  transaction = connection.begin_nested()
23
35
 
36
+ if SessionManager.get_instance().session_connection is not None:
37
+ logger.warning("session override already exists")
38
+ # TODO should we throw an exception here?
39
+
24
40
  SessionManager.get_instance().session_connection = connection
25
41
 
26
42
  try:
27
- yield
43
+ with SessionManager.get_instance().get_session() as factory_session:
44
+ try:
45
+ from factory.alchemy import SQLAlchemyModelFactory
46
+
47
+ # Ensure that all factories use the same session
48
+ for factory in SQLAlchemyModelFactory.__subclasses__():
49
+ factory._meta.sqlalchemy_session = factory_session
50
+ factory._meta.sqlalchemy_session_persistence = "commit"
51
+ except ImportError:
52
+ pass
53
+
54
+ yield
28
55
  finally:
29
- transaction.rollback()
30
- # TODO is this necessary?
31
- connection.close()
32
-
33
-
34
- # TODO unsure if this adds any value beyond the above approach
35
- # def database_reset_named_truncation():
36
- # start_truncation_query = """
37
- # BEGIN;
38
- # SAVEPOINT test_truncation_savepoint;
39
- # """
56
+ logger.debug("rolling back transaction")
40
57
 
41
- # raw_sql_exec(start_truncation_query)
42
-
43
- # yield
58
+ transaction.rollback()
44
59
 
45
- # end_truncation_query = """
46
- # ROLLBACK TO SAVEPOINT test_truncation_savepoint;
47
- # RELEASE SAVEPOINT test_truncation_savepoint;
48
- # ROLLBACK;
49
- # """
60
+ # TODO is this necessary? unclear
61
+ connection.close()
50
62
 
51
- # raw_sql_exec(end_truncation_query)
63
+ SessionManager.get_instance().session_connection = None
@@ -40,7 +40,7 @@ def database_reset_truncate():
40
40
  transaction = connection.begin()
41
41
 
42
42
  if table.name not in exception_tables:
43
- logger.debug("truncating table=%s", table.name)
43
+ logger.debug(f"truncating table={table.name}")
44
44
  connection.execute(table.delete())
45
45
 
46
46
  transaction.commit()
@@ -1,20 +1,26 @@
1
- import sqlmodel
1
+ import sqlmodel as sm
2
+ from sqlmodel.sql.expression import SelectOfScalar
2
3
 
4
+ from .session_manager import get_session
5
+ from .utils import compile_sql
3
6
 
4
- class QueryWrapper[T]:
7
+
8
+ class QueryWrapper[T: sm.SQLModel]:
5
9
  """
6
10
  Make it easy to run queries off of a model
7
11
  """
8
12
 
9
- def __init__(self, cls, *args) -> None:
13
+ target: SelectOfScalar[T]
14
+
15
+ def __init__(self, cls: T, *args) -> None:
10
16
  # TODO add generics here
11
17
  # self.target: SelectOfScalar[T] = sql.select(cls)
12
18
 
13
19
  if args:
14
20
  # very naive, let's assume the args are specific select statements
15
- self.target = sqlmodel.sql.select(*args).select_from(cls)
21
+ self.target = sm.select(*args).select_from(cls)
16
22
  else:
17
- self.target = sql.select(cls)
23
+ self.target = sm.select(cls)
18
24
 
19
25
  # TODO the .exec results should be handled in one shot
20
26
 
@@ -23,6 +29,7 @@ class QueryWrapper[T]:
23
29
  return session.exec(self.target).first()
24
30
 
25
31
  def one(self):
32
+ "requires exactly one result in the dataset"
26
33
  with get_session() as session:
27
34
  return session.exec(self.target).one()
28
35
 
@@ -32,8 +39,14 @@ class QueryWrapper[T]:
32
39
  for row in result:
33
40
  yield row
34
41
 
42
+ def count(self):
43
+ """
44
+ I did some basic tests
45
+ """
46
+ with get_session() as session:
47
+ return session.scalar(sm.select(sm.func.count()).select_from(self.target))
48
+
35
49
  def exec(self):
36
- # TODO do we really need a unique session each time?
37
50
  with get_session() as session:
38
51
  return session.exec(self.target)
39
52
 
@@ -43,19 +56,20 @@ class QueryWrapper[T]:
43
56
 
44
57
  def __getattr__(self, name):
45
58
  """
46
- This is called to retrieve the function to execute
59
+ This implements the magic that forwards function calls to sqlalchemy.
47
60
  """
48
61
 
49
62
  # TODO prefer methods defined in this class
63
+
50
64
  if not hasattr(self.target, name):
51
65
  return super().__getattribute__(name)
52
66
 
53
- attr = getattr(self.target, name)
67
+ sqlalchemy_target = getattr(self.target, name)
54
68
 
55
- if callable(attr):
69
+ if callable(sqlalchemy_target):
56
70
 
57
71
  def wrapper(*args, **kwargs):
58
- result = attr(*args, **kwargs)
72
+ result = sqlalchemy_target(*args, **kwargs)
59
73
  self.target = result
60
74
  return self
61
75
 
@@ -3,24 +3,58 @@ Class to make managing sessions with SQL Model easy. Also provides a common entr
3
3
  database environment when testing.
4
4
  """
5
5
 
6
+ import contextlib
7
+ import contextvars
8
+ import json
6
9
  import typing as t
7
10
 
8
11
  from decouple import config
9
- from sqlalchemy import Engine
12
+ from pydantic import BaseModel
13
+ from sqlalchemy import Connection, Engine
10
14
  from sqlmodel import Session, create_engine
11
15
 
12
16
 
17
+ def _serialize_pydantic_model(model: BaseModel | list[BaseModel] | None) -> str | None:
18
+ """
19
+ Pydantic models do not serialize to JSON. You'll get an error such as:
20
+
21
+ 'TypeError: Object of type TranscriptEntry is not JSON serializable'
22
+
23
+ https://github.com/fastapi/sqlmodel/issues/63#issuecomment-2581016387
24
+
25
+ This custom serializer is passed to the DB engine to properly serialize pydantic models to
26
+ JSON for storage in a JSONB column.
27
+ """
28
+
29
+ # TODO I bet this will fail on lists with mixed types
30
+
31
+ if isinstance(model, BaseModel):
32
+ return model.model_dump_json()
33
+ if isinstance(model, list):
34
+ # not everything in a list is a pydantic model
35
+ def dump_if_model(m):
36
+ if isinstance(m, BaseModel):
37
+ return m.model_dump()
38
+ return m
39
+
40
+ return json.dumps([dump_if_model(m) for m in model])
41
+ else:
42
+ return json.dumps(model)
43
+
44
+
13
45
  class SessionManager:
14
46
  _instance: t.ClassVar[t.Optional["SessionManager"]] = None
47
+ "singleton instance of SessionManager"
15
48
 
16
- session_connection: str
49
+ session_connection: Connection | None
50
+ "optionally specify a specific session connection to use for all get_session() calls, useful for testing"
17
51
 
18
52
  @classmethod
19
53
  def get_instance(cls, database_url: str | None = None) -> "SessionManager":
20
54
  if cls._instance is None:
21
- assert (
22
- database_url is not None
23
- ), "Database URL required for first initialization"
55
+ assert database_url is not None, (
56
+ "Database URL required for first initialization"
57
+ )
24
58
  cls._instance = cls(database_url)
25
59
 
26
60
  return cls._instance
@@ -28,6 +62,7 @@ class SessionManager:
28
62
  def __init__(self, database_url: str):
29
63
  self._database_url = database_url
30
64
  self._engine = None
65
+
31
66
  self.session_connection = None
32
67
 
33
68
  # TODO why is this type not reimported?
@@ -35,6 +70,8 @@ class SessionManager:
35
70
  if not self._engine:
36
71
  self._engine = create_engine(
37
72
  self._database_url,
73
+ # NOTE very important! This enables pydantic models to be serialized for JSONB columns
74
+ json_serializer=_serialize_pydantic_model,
38
75
  echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False),
39
76
  # https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic
40
77
  pool_pre_ping=True,
@@ -44,6 +81,15 @@ class SessionManager:
44
81
  return self._engine
45
82
 
46
83
  def get_session(self):
84
+ if gsession := _session_context.get():
85
+
86
+ @contextlib.contextmanager
87
+ def _reuse_session():
88
+ yield gsession
89
+
90
+ return _reuse_session()
91
+
92
+ # a connection can generate nested transactions
47
93
  if self.session_connection:
48
94
  return Session(bind=self.session_connection)
49
95
 
@@ -51,6 +97,7 @@ class SessionManager:
51
97
 
52
98
 
53
99
  def init(database_url: str):
100
+ "configure activemodel to connect to a specific database"
54
101
  return SessionManager.get_instance(database_url)
55
102
 
56
103
 
@@ -60,3 +107,43 @@ def get_engine():
60
107
 
61
108
  def get_session():
62
109
  return SessionManager.get_instance().get_session()
110
+
111
+
112
+ # contextvars must be at the top-level of a module! You will not get a warning if you don't do this.
113
+ # ContextVar is implemented in C, so it's very special and is both thread-safe and asyncio safe. This variable gives us
114
+ # a place to persist a session to use globally across the application.
115
+ _session_context = contextvars.ContextVar[Session | None](
116
+ "session_context", default=None
117
+ )
118
+
119
+
120
+ @contextlib.contextmanager
121
+ def global_session():
122
+ with SessionManager.get_instance().get_session() as s:
123
+ token = _session_context.set(s)
124
+
125
+ try:
126
+ yield s
127
+ finally:
128
+ _session_context.reset(token)
129
+
130
+
131
+ async def aglobal_session():
132
+ """
133
+ Use this as a fastapi dependency to get a session that is shared across the request:
134
+
135
+ >>> APIRouter(
136
+ >>> prefix="/internal/v1",
137
+ >>> dependencies=[
138
+ >>> Depends(aglobal_session),
139
+ >>> ]
140
+ >>> )
141
+ """
142
+
143
+ with SessionManager.get_instance().get_session() as s:
144
+ token = _session_context.set(s)
145
+
146
+ try:
147
+ yield
148
+ finally:
149
+ _session_context.reset(token)
@@ -0,0 +1 @@
1
+ from .typeid import TypeIDType
@@ -3,11 +3,18 @@ Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sql
3
3
  """
4
4
 
5
5
  from typing import Optional
6
+ from uuid import UUID
6
7
 
8
+ from pydantic import (
9
+ GetJsonSchemaHandler,
10
+ )
11
+ from pydantic_core import CoreSchema, core_schema
7
12
  from sqlalchemy import types
8
13
  from sqlalchemy.util import generic_repr
9
14
  from typeid import TypeID
10
15
 
16
+ from activemodel.errors import TypeIDValidationError
17
+
11
18
 
12
19
  class TypeIDType(types.TypeDecorator):
13
20
  """
@@ -45,12 +52,141 @@ class TypeIDType(types.TypeDecorator):
45
52
  )
46
53
 
47
54
  def process_bind_param(self, value, dialect):
48
- if self.prefix is None:
49
- assert value.prefix is None
50
- else:
51
- assert value.prefix == self.prefix
55
+ """
56
+ This is run when a search query is built or ...
57
+ """
58
+
59
+ if value is None:
60
+ return None
61
+
62
+ if isinstance(value, UUID):
63
+ # then it's a UUID class, such as UUID('01942886-7afc-7129-8f57-db09137ed002')
64
+ return value
65
+
66
+ if isinstance(value, str) and value.startswith(self.prefix + "_"):
67
+ # then it's a TypeID such as 'user_01h45ytscbebyvny4gc8cr8ma2'
68
+ value = TypeID.from_string(value)
69
+
70
+ if isinstance(value, str):
71
+ # no prefix, raw UUID, let's coerce it into a UUID which SQLAlchemy can handle
72
+ # ex: '01942886-7afc-7129-8f57-db09137ed002'
73
+ return UUID(value)
74
+
75
+ if isinstance(value, TypeID):
76
+ # TODO in what case could this None prefix ever occur?
77
+ if self.prefix is None:
78
+ if value.prefix is None:
79
+ raise TypeIDValidationError(
80
+ "Must have a valid prefix set on the class"
81
+ )
82
+ else:
83
+ if value.prefix != self.prefix:
84
+ raise TypeIDValidationError(
85
+ f"Expected '{self.prefix}' but got '{value.prefix}'"
86
+ )
87
+
88
+ return value.uuid
52
89
 
53
- return value.uuid
90
+ raise ValueError("Unexpected input type")
54
91
 
55
92
  def process_result_value(self, value, dialect):
93
+ if value is None:
94
+ return None
95
+
56
96
  return TypeID.from_uuid(value, self.prefix)
97
+
98
+ # def coerce_compared_value(self, op, value):
99
+ # """
100
+ # This method is called when SQLAlchemy needs to compare a column to a value.
101
+ # By returning self, we indicate that this type can handle TypeID instances.
102
+ # """
103
+ # if isinstance(value, TypeID):
104
+ # return self
105
+
106
+ # return super().coerce_compared_value(op, value)
107
+
108
+ @classmethod
109
+ def __get_pydantic_core_schema__(
110
+ cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
111
+ ) -> CoreSchema:
112
+ """
113
+ This fixes the following error: 'Unable to serialize unknown type' by telling pydantic how to serialize this field.
114
+
115
+ Note that TypeIDType MUST be the type of the field in SQLModel otherwise you'll get serialization errors.
116
+ This is done automatically for the mixin but for any relationship fields you'll need to specify the type explicitly.
117
+
118
+ - https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/server/value_objects/steam_id.py#L88
119
+ - https://github.com/hhimanshu/uv-workspaces/blob/main/packages/api/src/_lib/dto/typeid_field.py
120
+ - https://github.com/karma-dev-team/karma-system/blob/ee0c1a06ab2cb7aaca6dc4818312e68c5c623365/app/base/typeid/type_id.py#L14
121
+ - https://github.com/pydantic/pydantic/issues/10060
122
+ - https://github.com/fastapi/fastapi/discussions/10027
123
+ - https://github.com/alice-biometrics/petisco/blob/b01ef1b84949d156f73919e126ed77aa8e0b48dd/petisco/base/domain/model/uuid.py#L50
124
+ """
125
+
126
+ from_uuid_schema = core_schema.chain_schema(
127
+ [
128
+ # TODO not sure how this is different from the UUID schema, should try it out.
129
+ # core_schema.is_instance_schema(TypeID),
130
+ # core_schema.uuid_schema(),
131
+ core_schema.no_info_plain_validator_function(
132
+ TypeID.from_string,
133
+ json_schema_input_schema=core_schema.str_schema(),
134
+ ),
135
+ ]
136
+ )
137
+
138
+ return core_schema.json_or_python_schema(
139
+ json_schema=from_uuid_schema,
140
+ # TODO in the future we could add more exact types
141
+ # metadata=core_schema.str_schema(
142
+ # pattern="^[0-9a-f]{24}$",
143
+ # min_length=24,
144
+ # max_length=24,
145
+ # ),
146
+ # metadata={
147
+ # "pydantic_js_input_core_schema": core_schema.str_schema(
148
+ # pattern="^[0-9a-f]{24}$",
149
+ # min_length=24,
150
+ # max_length=24,
151
+ # )
152
+ # },
153
+ python_schema=core_schema.union_schema([from_uuid_schema]),
154
+ serialization=core_schema.plain_serializer_function_ser_schema(
155
+ lambda x: str(x)
156
+ ),
157
+ )
158
+
159
+ @classmethod
160
+ def __get_pydantic_json_schema__(
161
+ cls, schema: CoreSchema, handler: GetJsonSchemaHandler
162
+ ):
163
+ """
164
+ Called when generating the openapi schema. This overrides the `function-plain` type which
165
+ is generated by the `no_info_plain_validator_function`.
166
+
167
+ This logic seems to be a hot part of the codebase, so I'd expect this to break as pydantic
168
+ and fastapi continue to evolve.
169
+
170
+ Note that this method can return multiple types. A return value can be as simple as:
171
+
172
+ {"type": "string"}
173
+
174
+ Or, you could return a more specific JSON schema type:
175
+
176
+ core_schema.uuid_schema()
177
+
178
+ The problem with using something like uuid_schema is the specific patterns it enforces.
179
+
180
+ https://github.com/BeanieODM/beanie/blob/2190cd9d1fc047af477d5e6897cc283799f54064/beanie/odm/fields.py#L153
181
+ """
182
+
183
+ return {
184
+ "type": "string",
185
+ # TODO implement a more strict pattern in regex
186
+ # https://github.com/jetify-com/typeid/blob/3d182feed5687c21bb5ab93d5f457ff96749b68b/spec/README.md?plain=1#L38
187
+ # "pattern": "^[0-9a-f]{24}$",
188
+ # "minLength": 24,
189
+ # "maxLength": 24,
190
+ }
191
+
192
+ return core_schema.uuid_schema()
activemodel/utils.py CHANGED
@@ -1,15 +1,65 @@
1
+ import inspect
2
+ import pkgutil
3
+ import sys
4
+ from types import ModuleType
5
+
6
+ from sqlalchemy import text
7
+ from sqlmodel import SQLModel
1
8
  from sqlmodel.sql.expression import SelectOfScalar
2
9
 
3
- from activemodel import get_engine
10
+ from .logger import logger
11
+ from .session_manager import get_engine, get_session
4
12
 
5
13
 
6
14
  def compile_sql(target: SelectOfScalar):
15
+ "convert a query into SQL, helpful for debugging"
7
16
  dialect = get_engine().dialect
8
17
  # TODO I wonder if we could store the dialect to avoid getting an engine reference
9
18
  compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
10
19
  return str(compiled)
11
20
 
12
21
 
22
+ # TODO document further, lots of risks here
13
23
  def raw_sql_exec(raw_query: str):
14
24
  with get_session() as session:
15
25
  session.execute(text(raw_query))
26
+
27
+
28
+ def find_all_sqlmodels(module: ModuleType):
29
+ """Import all model classes from module and submodules into current namespace."""
30
+
31
+ logger.debug(f"Starting model import from module: {module.__name__}")
32
+ model_classes = {}
33
+
34
+ # Walk through all submodules
35
+ for loader, module_name, is_pkg in pkgutil.walk_packages(module.__path__):
36
+ full_name = f"{module.__name__}.{module_name}"
37
+ logger.debug(f"Importing submodule: {full_name}")
38
+
39
+ # Check if module is already imported
40
+ if full_name in sys.modules:
41
+ submodule = sys.modules[full_name]
42
+ else:
43
+ logger.warning(
44
+ f"Module not found in sys.modules, not importing: {full_name}"
45
+ )
46
+ continue
47
+
48
+ # Get all classes from module
49
+ for name, obj in inspect.getmembers(submodule):
50
+ if inspect.isclass(obj) and issubclass(obj, SQLModel) and obj != SQLModel:
51
+ logger.debug(f"Found model class: {name}")
52
+ model_classes[name] = obj
53
+
54
+ logger.debug(f"Completed model import. Found {len(model_classes)} models")
55
+ return model_classes
56
+
57
+
58
+ def hash_function_code(func):
59
+ "get sha of a function to easily assert that it hasn't changed"
60
+
61
+ import hashlib
62
+ import inspect
63
+
64
+ source = inspect.getsource(func)
65
+ return hashlib.sha256(source.encode()).hexdigest()