activemodel 0.11.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- activemodel/base_model.py +113 -88
- activemodel/cli/__init__.py +147 -0
- activemodel/mixins/timestamps.py +4 -1
- activemodel/mixins/typeid.py +1 -1
- activemodel/pytest/__init__.py +1 -1
- activemodel/pytest/factories.py +102 -0
- activemodel/pytest/plugin.py +81 -0
- activemodel/pytest/transaction.py +103 -16
- activemodel/pytest/truncate.py +108 -7
- activemodel/query_wrapper.py +12 -2
- activemodel/session_manager.py +45 -11
- activemodel/types/sqlalchemy_protocol.py +10 -0
- activemodel/types/sqlalchemy_protocol.pyi +132 -0
- activemodel/types/typeid.py +1 -0
- activemodel/utils.py +3 -39
- {activemodel-0.11.0.dist-info → activemodel-0.13.0.dist-info}/METADATA +27 -6
- activemodel-0.13.0.dist-info/RECORD +30 -0
- activemodel-0.13.0.dist-info/entry_points.txt +2 -0
- activemodel-0.11.0.dist-info/RECORD +0 -25
- activemodel-0.11.0.dist-info/entry_points.txt +0 -2
- {activemodel-0.11.0.dist-info → activemodel-0.13.0.dist-info}/WHEEL +0 -0
- {activemodel-0.11.0.dist-info → activemodel-0.13.0.dist-info}/licenses/LICENSE +0 -0
activemodel/base_model.py
CHANGED
|
@@ -1,25 +1,23 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import typing as t
|
|
3
|
+
import textcase
|
|
3
4
|
from uuid import UUID
|
|
5
|
+
from contextlib import nullcontext
|
|
4
6
|
|
|
5
|
-
import pydash
|
|
6
7
|
import sqlalchemy as sa
|
|
7
8
|
import sqlmodel as sm
|
|
8
|
-
from sqlalchemy import
|
|
9
|
-
from sqlalchemy.orm import Mapper, declared_attr
|
|
9
|
+
from sqlalchemy.dialects.postgresql import insert as postgres_insert
|
|
10
10
|
from sqlalchemy.orm.attributes import flag_modified as sa_flag_modified
|
|
11
|
-
from sqlalchemy.orm.base import instance_state
|
|
12
11
|
from sqlmodel import Column, Field, Session, SQLModel, inspect, select
|
|
13
12
|
from typeid import TypeID
|
|
13
|
+
from sqlalchemy.orm import declared_attr
|
|
14
14
|
|
|
15
15
|
from activemodel.mixins.pydantic_json import PydanticJSONMixin
|
|
16
16
|
|
|
17
17
|
# NOTE: this patches a core method in sqlmodel to support db comments
|
|
18
18
|
from . import get_column_from_field_patch # noqa: F401
|
|
19
|
-
from .logger import logger
|
|
20
19
|
from .query_wrapper import QueryWrapper
|
|
21
20
|
from .session_manager import get_session
|
|
22
|
-
from sqlalchemy.dialects.postgresql import insert as postgres_insert
|
|
23
21
|
|
|
24
22
|
POSTGRES_INDEXES_NAMING_CONVENTION = {
|
|
25
23
|
"ix": "%(column_0_label)s_idx",
|
|
@@ -42,85 +40,46 @@ SQLModel.metadata.naming_convention = POSTGRES_INDEXES_NAMING_CONVENTION
|
|
|
42
40
|
|
|
43
41
|
class BaseModel(SQLModel):
|
|
44
42
|
"""
|
|
45
|
-
Base model class to inherit from so we can hate python less
|
|
43
|
+
Base model class to inherit from so we can hate python less.
|
|
46
44
|
|
|
47
|
-
|
|
45
|
+
Some notes:
|
|
48
46
|
|
|
49
|
-
-
|
|
50
|
-
-
|
|
51
|
-
-
|
|
47
|
+
- Inspired by https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
|
|
48
|
+
- lifecycle hooks are modeled after Rails.
|
|
49
|
+
- class docstrings are converted to table-level comments
|
|
50
|
+
- save(), delete(), select(), where(), and other easy methods you would expect in a real ORM
|
|
52
51
|
- Fixes foreign key naming conventions
|
|
52
|
+
- Sane table names
|
|
53
|
+
|
|
54
|
+
Here's how hooks work:
|
|
55
|
+
|
|
56
|
+
Create/Update: before_create, after_create, before_update, after_update, before_save, after_save, around_save
|
|
57
|
+
Delete: before_delete, after_delete, around_delete
|
|
58
|
+
|
|
59
|
+
around_* hooks must be context managers (method returning a CM or a CM attribute).
|
|
60
|
+
Ordering (create): before_create -> before_save -> (enter around_save) -> persist -> after_create -> after_save -> (exit around_save)
|
|
61
|
+
Ordering (update): before_update -> before_save -> (enter around_save) -> persist -> after_update -> after_save -> (exit around_save)
|
|
62
|
+
Delete: before_delete -> (enter around_delete) -> delete -> after_delete -> (exit around_delete)
|
|
63
|
+
|
|
64
|
+
# TODO document this in activemodel, this is an interesting edge case
|
|
65
|
+
# https://claude.ai/share/f09e4f70-2ff7-4cd0-abff-44645134693a
|
|
66
|
+
|
|
53
67
|
"""
|
|
54
68
|
|
|
55
|
-
# this is used for table-level comments
|
|
56
69
|
__table_args__ = None
|
|
57
70
|
|
|
58
71
|
@classmethod
|
|
59
72
|
def __init_subclass__(cls, **kwargs):
|
|
60
|
-
"Setup automatic sqlalchemy lifecycle events for the class"
|
|
61
|
-
|
|
62
73
|
super().__init_subclass__(**kwargs)
|
|
63
74
|
|
|
64
75
|
from sqlmodel._compat import set_config_value
|
|
65
76
|
|
|
66
|
-
#
|
|
67
|
-
#
|
|
77
|
+
# Enables field-level docstrings on the pydantic `description` field, which we
|
|
78
|
+
# copy into table/column comments by patching SQLModel internals elsewhere.
|
|
68
79
|
set_config_value(model=cls, parameter="use_attribute_docstrings", value=True)
|
|
69
80
|
|
|
70
81
|
cls._apply_class_doc()
|
|
71
82
|
|
|
72
|
-
def event_wrapper(method_name: str):
|
|
73
|
-
"""
|
|
74
|
-
This does smart heavy lifting for us to make sqlalchemy lifecycle events nicer to work with:
|
|
75
|
-
|
|
76
|
-
* Passes the target first to the lifecycle method, so it feels like an instance method
|
|
77
|
-
* Allows as little as a single positional argument, so methods can be simple
|
|
78
|
-
* Removes the need for decorators or anything fancy on the subclass
|
|
79
|
-
"""
|
|
80
|
-
|
|
81
|
-
def wrapper(mapper: Mapper, connection: Connection, target: BaseModel):
|
|
82
|
-
if hasattr(cls, method_name):
|
|
83
|
-
method = getattr(cls, method_name)
|
|
84
|
-
|
|
85
|
-
if callable(method):
|
|
86
|
-
arg_count = method.__code__.co_argcount
|
|
87
|
-
|
|
88
|
-
if arg_count == 1: # Just self/cls
|
|
89
|
-
method(target)
|
|
90
|
-
elif arg_count == 2: # Self, mapper
|
|
91
|
-
method(target, mapper)
|
|
92
|
-
elif arg_count == 3: # Full signature
|
|
93
|
-
method(target, mapper, connection)
|
|
94
|
-
else:
|
|
95
|
-
raise TypeError(
|
|
96
|
-
f"Method {method_name} must accept either 1 to 3 arguments, got {arg_count}"
|
|
97
|
-
)
|
|
98
|
-
else:
|
|
99
|
-
logger.warning(
|
|
100
|
-
"SQLModel lifecycle hook found, but not callable hook_name=%s",
|
|
101
|
-
method_name,
|
|
102
|
-
)
|
|
103
|
-
|
|
104
|
-
return wrapper
|
|
105
|
-
|
|
106
|
-
event.listen(cls, "before_insert", event_wrapper("before_insert"))
|
|
107
|
-
event.listen(cls, "before_update", event_wrapper("before_update"))
|
|
108
|
-
|
|
109
|
-
# before_save maps to two type of events
|
|
110
|
-
event.listen(cls, "before_insert", event_wrapper("before_save"))
|
|
111
|
-
event.listen(cls, "before_update", event_wrapper("before_save"))
|
|
112
|
-
|
|
113
|
-
# now, let's handle after_* variants
|
|
114
|
-
event.listen(cls, "after_insert", event_wrapper("after_insert"))
|
|
115
|
-
event.listen(cls, "after_update", event_wrapper("after_update"))
|
|
116
|
-
|
|
117
|
-
# after_save maps to two type of events
|
|
118
|
-
event.listen(cls, "after_insert", event_wrapper("after_save"))
|
|
119
|
-
event.listen(cls, "after_update", event_wrapper("after_save"))
|
|
120
|
-
|
|
121
|
-
# def foreign_key()
|
|
122
|
-
# table.id
|
|
123
|
-
|
|
124
83
|
@classmethod
|
|
125
84
|
def _apply_class_doc(cls):
|
|
126
85
|
"""
|
|
@@ -152,19 +111,19 @@ class BaseModel(SQLModel):
|
|
|
152
111
|
@declared_attr
|
|
153
112
|
def __tablename__(cls) -> str:
|
|
154
113
|
"""
|
|
155
|
-
Automatically generates the table name for the model by converting the class name from camel case to snake case.
|
|
156
|
-
This is the recommended
|
|
114
|
+
Automatically generates the table name for the model by converting the model's class name from camel case to snake case.
|
|
115
|
+
This is the recommended text case style for table names:
|
|
157
116
|
|
|
158
117
|
https://wiki.postgresql.org/wiki/Don%27t_Do_This#Don.27t_use_upper_case_table_or_column_names
|
|
159
118
|
|
|
160
|
-
By default, the class is lower cased which makes it harder to read.
|
|
119
|
+
By default, the model's class name is lower cased which makes it harder to read.
|
|
161
120
|
|
|
162
|
-
|
|
163
|
-
|
|
121
|
+
Also, many text case conversion libraries struggle to handle words like "LLMCache"; this is why we are using
|
|
122
|
+
a more precise library which processes such acronyms: [`textcase`](https://pypi.org/project/textcase/).
|
|
164
123
|
|
|
165
124
|
https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
|
|
166
125
|
"""
|
|
167
|
-
return
|
|
126
|
+
return textcase.snake(cls.__name__)
|
|
168
127
|
|
|
169
128
|
@classmethod
|
|
170
129
|
def foreign_key(cls, **kwargs):
|
|
@@ -234,34 +193,81 @@ class BaseModel(SQLModel):
|
|
|
234
193
|
return result
|
|
235
194
|
|
|
236
195
|
def delete(self):
|
|
196
|
+
"""Delete instance running delete hooks and optional around_delete context manager."""
|
|
197
|
+
|
|
198
|
+
cm = self._get_around_context_manager("around_delete") or nullcontext()
|
|
199
|
+
|
|
237
200
|
with get_session() as session:
|
|
238
|
-
if
|
|
201
|
+
if (
|
|
202
|
+
old_session := Session.object_session(self)
|
|
203
|
+
) and old_session is not session:
|
|
239
204
|
old_session.expunge(self)
|
|
240
|
-
|
|
241
205
|
session.delete(self)
|
|
242
|
-
|
|
243
|
-
|
|
206
|
+
|
|
207
|
+
self._call_hook("before_delete")
|
|
208
|
+
with cm:
|
|
209
|
+
session.commit()
|
|
210
|
+
self._call_hook("after_delete")
|
|
211
|
+
|
|
212
|
+
return True
|
|
244
213
|
|
|
245
214
|
def save(self):
|
|
215
|
+
"""Persist instance running create/update hooks and optional around_save context manager."""
|
|
216
|
+
|
|
217
|
+
is_new = self.is_new()
|
|
218
|
+
cm = self._get_around_context_manager("around_save") or nullcontext()
|
|
219
|
+
|
|
246
220
|
with get_session() as session:
|
|
247
|
-
if
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
# to get around this, you need to remove it from the old one,
|
|
251
|
-
# then add it to the new one (below)
|
|
221
|
+
if (
|
|
222
|
+
old_session := Session.object_session(self)
|
|
223
|
+
) and old_session is not session:
|
|
252
224
|
old_session.expunge(self)
|
|
253
225
|
|
|
254
226
|
session.add(self)
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
session
|
|
227
|
+
|
|
228
|
+
# the order and placement of these hooks is really important
|
|
229
|
+
# we need the current object to be in a session otherwise it will not be able to
|
|
230
|
+
# load any relationships.
|
|
231
|
+
self._call_hook("before_create" if is_new else "before_update")
|
|
232
|
+
self._call_hook("before_save")
|
|
233
|
+
|
|
234
|
+
with cm:
|
|
235
|
+
session.commit()
|
|
236
|
+
session.refresh(self)
|
|
237
|
+
|
|
238
|
+
self._call_hook("after_create" if is_new else "after_update")
|
|
239
|
+
self._call_hook("after_save")
|
|
258
240
|
|
|
259
241
|
# Only call the transform method if the class is a subclass of PydanticJSONMixin
|
|
260
242
|
if issubclass(self.__class__, PydanticJSONMixin):
|
|
261
243
|
self.__class__.__transform_dict_to_pydantic__(self)
|
|
262
|
-
|
|
263
244
|
return self
|
|
264
245
|
|
|
246
|
+
def _call_hook(self, hook_name: str) -> None:
|
|
247
|
+
method = getattr(self, hook_name, None)
|
|
248
|
+
if callable(method):
|
|
249
|
+
if method.__code__.co_argcount != 1:
|
|
250
|
+
raise TypeError(
|
|
251
|
+
f"Hook '{hook_name}' must accept exactly 1 positional argument (self)"
|
|
252
|
+
)
|
|
253
|
+
method()
|
|
254
|
+
|
|
255
|
+
def _get_around_context_manager(self, name: str) -> t.ContextManager | None:
|
|
256
|
+
obj = getattr(self, name, None)
|
|
257
|
+
if obj is None:
|
|
258
|
+
return None
|
|
259
|
+
|
|
260
|
+
# If it's a callable (method/function), call it to obtain the CM
|
|
261
|
+
if callable(obj):
|
|
262
|
+
obj = obj()
|
|
263
|
+
|
|
264
|
+
cm = obj
|
|
265
|
+
if not (hasattr(cm, "__enter__") and hasattr(cm, "__exit__")):
|
|
266
|
+
raise TypeError(
|
|
267
|
+
f"{name} must return or be a context manager implementing __enter__/__exit__"
|
|
268
|
+
)
|
|
269
|
+
return t.cast(t.ContextManager, cm)
|
|
270
|
+
|
|
265
271
|
def refresh(self):
|
|
266
272
|
"Refreshes an object from the database"
|
|
267
273
|
|
|
@@ -282,6 +288,7 @@ class BaseModel(SQLModel):
|
|
|
282
288
|
|
|
283
289
|
# TODO shouldn't this be handled by pydantic?
|
|
284
290
|
# TODO where is this actually used? should probably remove this
|
|
291
|
+
# TODO should we even do this? Can we specify a better json rendering class?
|
|
285
292
|
def json(self, **kwargs):
|
|
286
293
|
return json.dumps(self.dict(), default=str, **kwargs)
|
|
287
294
|
|
|
@@ -297,7 +304,12 @@ class BaseModel(SQLModel):
|
|
|
297
304
|
# TODO got to be a better way to fwd these along...
|
|
298
305
|
@classmethod
|
|
299
306
|
def first(cls):
|
|
300
|
-
|
|
307
|
+
# TODO should use dynamic pk
|
|
308
|
+
return cls.select().order_by(sa.desc(cls.id)).first()
|
|
309
|
+
|
|
310
|
+
# @classmethod
|
|
311
|
+
# def last(cls):
|
|
312
|
+
# return cls.select().first()
|
|
301
313
|
|
|
302
314
|
# TODO throw an error if this field is set on the model
|
|
303
315
|
def is_new(self) -> bool:
|
|
@@ -404,6 +416,19 @@ class BaseModel(SQLModel):
|
|
|
404
416
|
with get_session() as session:
|
|
405
417
|
return session.exec(statement).first()
|
|
406
418
|
|
|
419
|
+
@classmethod
|
|
420
|
+
def one_or_none(cls, *args: t.Any, **kwargs: t.Any):
|
|
421
|
+
"""
|
|
422
|
+
Gets a single record from the database. Pass a PK ID or a kwarg to filter by.
|
|
423
|
+
Returns None if no record is found. Throws an error if more than one record is found.
|
|
424
|
+
"""
|
|
425
|
+
|
|
426
|
+
args, kwargs = cls.__process_filter_args__(*args, **kwargs)
|
|
427
|
+
statement = select(cls).filter(*args).filter_by(**kwargs)
|
|
428
|
+
|
|
429
|
+
with get_session() as session:
|
|
430
|
+
return session.exec(statement).one_or_none()
|
|
431
|
+
|
|
407
432
|
@classmethod
|
|
408
433
|
def one(cls, *args: t.Any, **kwargs: t.Any):
|
|
409
434
|
"""
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides utilities for generating Protocol type definitions for SQLAlchemy's
|
|
3
|
+
SelectOfScalar methods, as well as formatting and fixing Python files using ruff.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import inspect
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import subprocess
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any # already imported in header of generated file
|
|
12
|
+
|
|
13
|
+
import sqlmodel as sm
|
|
14
|
+
from sqlmodel.sql.expression import SelectOfScalar
|
|
15
|
+
|
|
16
|
+
from test.test_wrapper import QueryWrapper
|
|
17
|
+
|
|
18
|
+
# Set up logging
|
|
19
|
+
logging.basicConfig(level=logging.DEBUG)
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
QUERY_WRAPPER_CLASS_NAME = QueryWrapper.__name__
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def format_python_file(file_path: str | Path) -> bool:
|
|
26
|
+
"""
|
|
27
|
+
Format a Python file using ruff.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
file_path: Path to the Python file to format
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
bool: True if formatting was successful, False otherwise
|
|
34
|
+
"""
|
|
35
|
+
try:
|
|
36
|
+
subprocess.run(["ruff", "format", str(file_path)], check=True)
|
|
37
|
+
logger.info(f"Formatted file using ruff at {file_path}")
|
|
38
|
+
return True
|
|
39
|
+
except subprocess.CalledProcessError as e:
|
|
40
|
+
logger.error(f"Error running ruff to format the file: {e}")
|
|
41
|
+
return False
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def fix_python_file(file_path: str | Path) -> bool:
|
|
45
|
+
"""
|
|
46
|
+
Fix linting issues in a Python file using ruff.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
file_path: Path to the Python file to fix
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
bool: True if fixing was successful, False otherwise
|
|
53
|
+
"""
|
|
54
|
+
try:
|
|
55
|
+
subprocess.run(["ruff", "check", str(file_path), "--fix"], check=True)
|
|
56
|
+
logger.info(f"Fixed linting issues using ruff at {file_path}")
|
|
57
|
+
return True
|
|
58
|
+
except subprocess.CalledProcessError as e:
|
|
59
|
+
logger.error(f"Error running ruff to fix the file: {e}")
|
|
60
|
+
return False
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def generate_sqlalchemy_protocol():
|
|
64
|
+
"""Generate Protocol type definitions for SQLAlchemy SelectOfScalar methods"""
|
|
65
|
+
logger.info("Starting SQLAlchemy protocol generation")
|
|
66
|
+
|
|
67
|
+
header = """
|
|
68
|
+
# IMPORTANT: This file is auto-generated. Do not edit directly.
|
|
69
|
+
|
|
70
|
+
from typing import Protocol, TypeVar, Any, Generic
|
|
71
|
+
import sqlmodel as sm
|
|
72
|
+
from sqlalchemy.sql.base import _NoArg
|
|
73
|
+
|
|
74
|
+
T = TypeVar('T', bound=sm.SQLModel, covariant=True)
|
|
75
|
+
|
|
76
|
+
class SQLAlchemyQueryMethods(Protocol, Generic[T]):
|
|
77
|
+
\"""Protocol defining SQLAlchemy query methods forwarded by QueryWrapper.__getattr__\"""
|
|
78
|
+
"""
|
|
79
|
+
# Initialize output list for generated method signatures
|
|
80
|
+
output: list = []
|
|
81
|
+
|
|
82
|
+
try:
|
|
83
|
+
# Get all methods from SelectOfScalar
|
|
84
|
+
methods = inspect.getmembers(SelectOfScalar)
|
|
85
|
+
logger.debug(f"Discovered {len(methods)} methods from SelectOfScalar")
|
|
86
|
+
|
|
87
|
+
for name, method in methods:
|
|
88
|
+
# Skip private/dunder methods
|
|
89
|
+
if name.startswith("_"):
|
|
90
|
+
continue
|
|
91
|
+
|
|
92
|
+
if not inspect.isfunction(method) and not inspect.ismethod(method):
|
|
93
|
+
logger.debug(f"Skipping non-method: {name}")
|
|
94
|
+
continue
|
|
95
|
+
|
|
96
|
+
logger.debug(f"Processing method: {name}")
|
|
97
|
+
try:
|
|
98
|
+
signature = inspect.signature(method)
|
|
99
|
+
params = []
|
|
100
|
+
|
|
101
|
+
# Process parameters, skipping 'self'
|
|
102
|
+
for param_name, param in list(signature.parameters.items())[1:]:
|
|
103
|
+
if param.kind == param.VAR_POSITIONAL:
|
|
104
|
+
params.append(f"*{param_name}: Any")
|
|
105
|
+
elif param.kind == param.VAR_KEYWORD:
|
|
106
|
+
params.append(f"**{param_name}: Any")
|
|
107
|
+
else:
|
|
108
|
+
if param.default is inspect.Parameter.empty:
|
|
109
|
+
params.append(f"{param_name}: Any")
|
|
110
|
+
else:
|
|
111
|
+
default_repr = repr(param.default)
|
|
112
|
+
params.append(f"{param_name}: Any = {default_repr}")
|
|
113
|
+
|
|
114
|
+
params_str = ", ".join(params)
|
|
115
|
+
output.append(
|
|
116
|
+
f' def {name}(self, {params_str}) -> "{QUERY_WRAPPER_CLASS_NAME}[T]": ...'
|
|
117
|
+
)
|
|
118
|
+
except (ValueError, TypeError) as e:
|
|
119
|
+
logger.warning(f"Could not get signature for {name}: {e}")
|
|
120
|
+
# Some methods might not have proper signatures
|
|
121
|
+
output.append(
|
|
122
|
+
f' def {name}(self, *args: Any, **kwargs: Any) -> "{QUERY_WRAPPER_CLASS_NAME}[T]": ...'
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
# Write the output to a file
|
|
126
|
+
protocol_path = (
|
|
127
|
+
Path(__file__).parent.parent / "types" / "sqlalchemy_protocol.py"
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
# Ensure directory exists
|
|
131
|
+
os.makedirs(protocol_path.parent, exist_ok=True)
|
|
132
|
+
|
|
133
|
+
with open(protocol_path, "w") as f:
|
|
134
|
+
f.write(header + "\n".join(output))
|
|
135
|
+
|
|
136
|
+
logger.info(f"Generated SQLAlchemy protocol at {protocol_path}")
|
|
137
|
+
|
|
138
|
+
# Format and fix the generated file with ruff
|
|
139
|
+
format_python_file(protocol_path)
|
|
140
|
+
fix_python_file(protocol_path)
|
|
141
|
+
except Exception as e:
|
|
142
|
+
logger.error(f"Error generating SQLAlchemy protocol: {e}", exc_info=True)
|
|
143
|
+
raise
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
if __name__ == "__main__":
|
|
147
|
+
generate_sqlalchemy_protocol()
|
activemodel/mixins/timestamps.py
CHANGED
|
@@ -20,7 +20,10 @@ class TimestampsMixin:
|
|
|
20
20
|
>>> class MyModel(TimestampsMixin, SQLModel):
|
|
21
21
|
>>> pass
|
|
22
22
|
|
|
23
|
-
|
|
23
|
+
Notes:
|
|
24
|
+
|
|
25
|
+
- Originally pulled from: https://github.com/tiangolo/sqlmodel/issues/252
|
|
26
|
+
- Related issue: https://github.com/fastapi/sqlmodel/issues/539
|
|
24
27
|
"""
|
|
25
28
|
|
|
26
29
|
created_at: datetime | None = Field(
|
activemodel/mixins/typeid.py
CHANGED
|
@@ -17,7 +17,7 @@ def TypeIDMixin(prefix: str):
|
|
|
17
17
|
# NOTE this will cause issues on code reloads
|
|
18
18
|
assert prefix
|
|
19
19
|
assert prefix not in _prefixes, (
|
|
20
|
-
f"prefix {prefix} already exists, pick a different one"
|
|
20
|
+
f"TypeID prefix '{prefix}' already exists, pick a different one"
|
|
21
21
|
)
|
|
22
22
|
|
|
23
23
|
class _TypeIDMixin:
|
activemodel/pytest/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
from .transaction import database_reset_transaction
|
|
1
|
+
from .transaction import database_reset_transaction, test_session
|
|
2
2
|
from .truncate import database_reset_truncate
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Notes on polyfactory:
|
|
3
|
+
|
|
4
|
+
1. is_supported_type validates that the class can be used to generate a factory
|
|
5
|
+
https://github.com/litestar-org/polyfactory/issues/655#issuecomment-2727450854
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import typing as t
|
|
9
|
+
|
|
10
|
+
from polyfactory.factories.pydantic_factory import ModelFactory
|
|
11
|
+
from polyfactory.field_meta import FieldMeta
|
|
12
|
+
from typeid import TypeID
|
|
13
|
+
|
|
14
|
+
from activemodel.session_manager import global_session
|
|
15
|
+
|
|
16
|
+
# TODO not currently used
|
|
17
|
+
# def type_id_provider(cls, field_meta):
|
|
18
|
+
# # TODO this doesn't work well with __ args:
|
|
19
|
+
# # https://github.com/litestar-org/polyfactory/pull/666/files
|
|
20
|
+
# return str(TypeID("hi"))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# BaseFactory.add_provider(TypeIDType, type_id_provider)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SQLModelFactory[T](ModelFactory[T]):
|
|
27
|
+
"""
|
|
28
|
+
Base factory for SQLModel models:
|
|
29
|
+
|
|
30
|
+
1. Ability to ignore all relationship fks
|
|
31
|
+
2. Option to ignore all pks
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
__is_base_factory__ = True
|
|
35
|
+
|
|
36
|
+
@classmethod
|
|
37
|
+
def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: t.Any) -> bool:
|
|
38
|
+
# TODO what is this checking for?
|
|
39
|
+
has_object_override = hasattr(cls, field_meta.name)
|
|
40
|
+
|
|
41
|
+
# TODO this should be more intelligent; its goal is to detect all of the relationship fields and avoid setting them
|
|
42
|
+
if not has_object_override and (
|
|
43
|
+
field_meta.name == "id" or field_meta.name.endswith("_id")
|
|
44
|
+
):
|
|
45
|
+
return False
|
|
46
|
+
|
|
47
|
+
return super().should_set_field_value(field_meta, **kwargs)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# TODO we need to think through how to handle relationships and autogenerate them
|
|
51
|
+
class ActiveModelFactory[T](SQLModelFactory[T]):
|
|
52
|
+
__is_base_factory__ = True
|
|
53
|
+
__sqlalchemy_session__ = None
|
|
54
|
+
|
|
55
|
+
# TODO we shouldn't have to type this, but `save()` typing is not working
|
|
56
|
+
@classmethod
|
|
57
|
+
def save(cls, *args, **kwargs) -> T:
|
|
58
|
+
"""
|
|
59
|
+
Where this gets tricky is that this can be called multiple times within the same call stack. This can happen when
|
|
60
|
+
a factory uses other factories to create relationships.
|
|
61
|
+
|
|
62
|
+
In a truncation strategy, the __sqlalchemy_session__ is set to None.
|
|
63
|
+
"""
|
|
64
|
+
with global_session(cls.__sqlalchemy_session__):
|
|
65
|
+
return cls.build(*args, **kwargs).save()
|
|
66
|
+
|
|
67
|
+
@classmethod
|
|
68
|
+
def foreign_key_typeid(cls):
|
|
69
|
+
"""
|
|
70
|
+
Return a random type id for the foreign key on this model.
|
|
71
|
+
|
|
72
|
+
This is helpful for generating TypeIDs for testing 404s, parsing, manually setting values, etc.
|
|
73
|
+
"""
|
|
74
|
+
# TODO right now assumes the model is typeid, maybe we should assert against this?
|
|
75
|
+
primary_key_name = cls.__model__.primary_key_column().name
|
|
76
|
+
return TypeID(
|
|
77
|
+
cls.__model__.model_fields[primary_key_name].sa_column.type.prefix
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
@classmethod
|
|
81
|
+
def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: t.Any) -> bool:
|
|
82
|
+
# do not default deleted at mixin to deleted!
|
|
83
|
+
# TODO should be smarter about detecting if the mixin is in place
|
|
84
|
+
if field_meta.name in ["deleted_at", "updated_at", "created_at"]:
|
|
85
|
+
return False
|
|
86
|
+
|
|
87
|
+
return super().should_set_field_value(field_meta, **kwargs)
|
|
88
|
+
|
|
89
|
+
# @classmethod
|
|
90
|
+
# def build(
|
|
91
|
+
# cls,
|
|
92
|
+
# factory_use_construct: bool | None = None,
|
|
93
|
+
# sqlmodel_save: bool = False,
|
|
94
|
+
# **kwargs: t.Any,
|
|
95
|
+
# ) -> T:
|
|
96
|
+
# result = super().build(factory_use_construct=factory_use_construct, **kwargs)
|
|
97
|
+
|
|
98
|
+
# # TODO allow magic dunder method here
|
|
99
|
+
# if sqlmodel_save:
|
|
100
|
+
# result.save()
|
|
101
|
+
|
|
102
|
+
# return result
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Pytest plugin integration for activemodel.
|
|
2
|
+
|
|
3
|
+
Currently provides:
|
|
4
|
+
|
|
5
|
+
* ``db_session`` fixture – quick access to a database session (see ``test_session``)
|
|
6
|
+
* ``activemodel_preserve_tables`` ini option – configure tables to preserve when using
|
|
7
|
+
``database_reset_truncate`` (comma separated list or multiple lines depending on config style)
|
|
8
|
+
|
|
9
|
+
Configuration examples:
|
|
10
|
+
|
|
11
|
+
pytest.ini::
|
|
12
|
+
|
|
13
|
+
[pytest]
|
|
14
|
+
activemodel_preserve_tables = alembic_version,zip_code,seed_table
|
|
15
|
+
|
|
16
|
+
pyproject.toml::
|
|
17
|
+
|
|
18
|
+
[tool.pytest.ini_options]
|
|
19
|
+
activemodel_preserve_tables = [
|
|
20
|
+
"alembic_version",
|
|
21
|
+
"zip_code",
|
|
22
|
+
"seed_table",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
The list always implicitly includes ``alembic_version`` even if not specified.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from activemodel.session_manager import global_session
|
|
29
|
+
import pytest
|
|
30
|
+
|
|
31
|
+
from .transaction import set_factory_session, set_polyfactory_session, test_session
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def pytest_addoption(
|
|
35
|
+
parser: pytest.Parser,
|
|
36
|
+
) -> None: # pragma: no cover - executed during collection
|
|
37
|
+
"""Register custom ini options.
|
|
38
|
+
|
|
39
|
+
We treat this as a *linelist* so pyproject.toml list syntax works. Comma separated works too because
|
|
40
|
+
pytest splits lines first; users can still provide one line with commas.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
parser.addini(
|
|
44
|
+
"activemodel_preserve_tables",
|
|
45
|
+
help=(
|
|
46
|
+
"Tables to preserve when calling activemodel.pytest.database_reset_truncate. "
|
|
47
|
+
),
|
|
48
|
+
type="linelist",
|
|
49
|
+
default=["alembic_version"],
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.fixture(scope="function")
|
|
54
|
+
def db_session():
|
|
55
|
+
"""
|
|
56
|
+
Helpful for tests that are more similar to unit tests. If you are doing a routing or integration test, you
|
|
57
|
+
probably don't need this. If your unit test is simple (you are just creating a couple of models) you
|
|
58
|
+
can most likely skip this.
|
|
59
|
+
|
|
60
|
+
This is helpful if you are doing a lot of lazy-loaded params or need a database session to be in place
|
|
61
|
+
for testing code that will run within a celery worker or something similar.
|
|
62
|
+
|
|
63
|
+
>>> def the_test(db_session):
|
|
64
|
+
"""
|
|
65
|
+
with test_session() as session:
|
|
66
|
+
yield session
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.fixture(scope="function")
|
|
70
|
+
def db_truncate_session():
|
|
71
|
+
"""
|
|
72
|
+
Provides a database session for testing when using a truncation cleaning strategy.
|
|
73
|
+
|
|
74
|
+
When not using a transaction cleaning strategy, no global test session is set
|
|
75
|
+
"""
|
|
76
|
+
with global_session() as session:
|
|
77
|
+
# set global database sessions for model factories to avoid lazy loading issues
|
|
78
|
+
set_factory_session(session)
|
|
79
|
+
set_polyfactory_session(session)
|
|
80
|
+
|
|
81
|
+
yield session
|