activemodel 0.11.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
activemodel/base_model.py CHANGED
@@ -1,25 +1,23 @@
  import json
  import typing as t
+ import textcase
  from uuid import UUID
+ from contextlib import nullcontext

- import pydash
  import sqlalchemy as sa
  import sqlmodel as sm
- from sqlalchemy import Connection, event
- from sqlalchemy.orm import Mapper, declared_attr
+ from sqlalchemy.dialects.postgresql import insert as postgres_insert
  from sqlalchemy.orm.attributes import flag_modified as sa_flag_modified
- from sqlalchemy.orm.base import instance_state
  from sqlmodel import Column, Field, Session, SQLModel, inspect, select
  from typeid import TypeID
+ from sqlalchemy.orm import declared_attr

  from activemodel.mixins.pydantic_json import PydanticJSONMixin

  # NOTE: this patches a core method in sqlmodel to support db comments
  from . import get_column_from_field_patch # noqa: F401
- from .logger import logger
  from .query_wrapper import QueryWrapper
  from .session_manager import get_session
- from sqlalchemy.dialects.postgresql import insert as postgres_insert

  POSTGRES_INDEXES_NAMING_CONVENTION = {
  "ix": "%(column_0_label)s_idx",
@@ -42,85 +40,46 @@ SQLModel.metadata.naming_convention = POSTGRES_INDEXES_NAMING_CONVENTION

  class BaseModel(SQLModel):
  """
- Base model class to inherit from so we can hate python less
+ Base model class to inherit from so we can hate python less.

- https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
+ Some notes:

- - {before,after} lifecycle hooks are modeled after Rails.
- - class docstrings are converd to table-level comments
- - save(), delete(), select(), where(), and other easy methods you would expect
+ - Inspired by https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
+ - lifecycle hooks are modeled after Rails.
+ - class docstrings are converted to table-level comments
+ - save(), delete(), select(), where(), and other easy methods you would expect in a real ORM
  - Fixes foreign key naming conventions
+ - Sane table names
+
+ Here's how hooks work:
+
+ Create/Update: before_create, after_create, before_update, after_update, before_save, after_save, around_save
+ Delete: before_delete, after_delete, around_delete
+
+ around_* hooks must be context managers (method returning a CM or a CM attribute).
+ Ordering (create): before_create -> before_save -> (enter around_save) -> persist -> after_create -> after_save -> (exit around_save)
+ Ordering (update): before_update -> before_save -> (enter around_save) -> persist -> after_update -> after_save -> (exit around_save)
+ Delete: before_delete -> (enter around_delete) -> delete -> after_delete -> (exit around_delete)
+
+ # TODO document this in activemodel, this is an interesting edge case
+ # https://claude.ai/share/f09e4f70-2ff7-4cd0-abff-44645134693a
+
  """

- # this is used for table-level comments
  __table_args__ = None

  @classmethod
  def __init_subclass__(cls, **kwargs):
- "Setup automatic sqlalchemy lifecycle events for the class"
-
  super().__init_subclass__(**kwargs)

  from sqlmodel._compat import set_config_value

- # enables field-level docstrings on the pydanatic `description` field, which we then copy into
- # sa_args, which is persisted to sql table comments
+ # Enables field-level docstrings on the pydantic `description` field, which we
+ # copy into table/column comments by patching SQLModel internals elsewhere.
  set_config_value(model=cls, parameter="use_attribute_docstrings", value=True)

  cls._apply_class_doc()

- def event_wrapper(method_name: str):
- """
- This does smart heavy lifting for us to make sqlalchemy lifecycle events nicer to work with:
-
- * Passes the target first to the lifecycle method, so it feels like an instance method
- * Allows as little as a single positional argument, so methods can be simple
- * Removes the need for decorators or anything fancy on the subclass
- """
-
- def wrapper(mapper: Mapper, connection: Connection, target: BaseModel):
- if hasattr(cls, method_name):
- method = getattr(cls, method_name)
-
- if callable(method):
- arg_count = method.__code__.co_argcount
-
- if arg_count == 1: # Just self/cls
- method(target)
- elif arg_count == 2: # Self, mapper
- method(target, mapper)
- elif arg_count == 3: # Full signature
- method(target, mapper, connection)
- else:
- raise TypeError(
- f"Method {method_name} must accept either 1 to 3 arguments, got {arg_count}"
- )
- else:
- logger.warning(
- "SQLModel lifecycle hook found, but not callable hook_name=%s",
- method_name,
- )
-
- return wrapper
-
- event.listen(cls, "before_insert", event_wrapper("before_insert"))
- event.listen(cls, "before_update", event_wrapper("before_update"))
-
- # before_save maps to two type of events
- event.listen(cls, "before_insert", event_wrapper("before_save"))
- event.listen(cls, "before_update", event_wrapper("before_save"))
-
- # now, let's handle after_* variants
- event.listen(cls, "after_insert", event_wrapper("after_insert"))
- event.listen(cls, "after_update", event_wrapper("after_update"))
-
- # after_save maps to two type of events
- event.listen(cls, "after_insert", event_wrapper("after_save"))
- event.listen(cls, "after_update", event_wrapper("after_save"))
-
- # def foreign_key()
- # table.id
-
  @classmethod
  def _apply_class_doc(cls):
  """
@@ -152,19 +111,19 @@ class BaseModel(SQLModel):
  @declared_attr
  def __tablename__(cls) -> str:
  """
- Automatically generates the table name for the model by converting the class name from camel case to snake case.
- This is the recommended format for table names:
+ Automatically generates the table name for the model by converting the model's class name from camel case to snake case.
+ This is the recommended text case style for table names:

  https://wiki.postgresql.org/wiki/Don%27t_Do_This#Don.27t_use_upper_case_table_or_column_names

- By default, the class is lower cased which makes it harder to read.
+ By default, the model's class name is lowercased, which makes it harder to read.

- Many snake_case libraries struggle with snake case for names like LLMCache, which is why we are using a more
- complicated implementation from pydash.
+ Also, many text case conversion libraries struggle with words like "LLMCache", which is why we use
+ a more precise library that handles such acronyms: [`textcase`](https://pypi.org/project/textcase/).

  https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
  """
- return pydash.strings.snake_case(cls.__name__)
+ return textcase.snake(cls.__name__)

  @classmethod
  def foreign_key(cls, **kwargs):
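
The practical effect of the switch to textcase, as a sketch (the expected values follow the docstring's acronym claim rather than the package's test suite):

    import textcase

    textcase.snake("User")      # "user"
    textcase.snake("LLMCache")  # expected "llm_cache" instead of an awkward split of the acronym

    # so a model class named LLMCache should map to a table named "llm_cache"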
@@ -234,34 +193,81 @@ class BaseModel(SQLModel):
  return result

  def delete(self):
+ """Delete the instance, running delete hooks and the optional around_delete context manager."""
+
+ cm = self._get_around_context_manager("around_delete") or nullcontext()
+
  with get_session() as session:
- if old_session := Session.object_session(self):
+ if (
+ old_session := Session.object_session(self)
+ ) and old_session is not session:
  old_session.expunge(self)
-
  session.delete(self)
- session.commit()
- return True
+
+ self._call_hook("before_delete")
+ with cm:
+ session.commit()
+ self._call_hook("after_delete")
+
+ return True

  def save(self):
+ """Persist the instance, running create/update hooks and the optional around_save context manager."""
+
+ is_new = self.is_new()
+ cm = self._get_around_context_manager("around_save") or nullcontext()
+
  with get_session() as session:
- if old_session := Session.object_session(self):
- # I was running into an issue where the object was already
- # associated with a session, but the session had been closed,
- # to get around this, you need to remove it from the old one,
- # then add it to the new one (below)
+ if (
+ old_session := Session.object_session(self)
+ ) and old_session is not session:
  old_session.expunge(self)

  session.add(self)
- # NOTE very important method! This triggers sqlalchemy lifecycle hooks automatically
- session.commit()
- session.refresh(self)
+
+ # the order and placement of these hooks is really important
+ # we need the current object to be in a session otherwise it will not be able to
+ # load any relationships.
+ self._call_hook("before_create" if is_new else "before_update")
+ self._call_hook("before_save")
+
+ with cm:
+ session.commit()
+ session.refresh(self)
+
+ self._call_hook("after_create" if is_new else "after_update")
+ self._call_hook("after_save")

  # Only call the transform method if the class is a subclass of PydanticJSONMixin
  if issubclass(self.__class__, PydanticJSONMixin):
  self.__class__.__transform_dict_to_pydantic__(self)
-
  return self

+ def _call_hook(self, hook_name: str) -> None:
+ method = getattr(self, hook_name, None)
+ if callable(method):
+ if method.__code__.co_argcount != 1:
+ raise TypeError(
+ f"Hook '{hook_name}' must accept exactly 1 positional argument (self)"
+ )
+ method()
+
+ def _get_around_context_manager(self, name: str) -> t.ContextManager | None:
+ obj = getattr(self, name, None)
+ if obj is None:
+ return None
+
+ # If it's a callable (method/function), call it to obtain the CM
+ if callable(obj):
+ obj = obj()
+
+ cm = obj
+ if not (hasattr(cm, "__enter__") and hasattr(cm, "__exit__")):
+ raise TypeError(
+ f"{name} must return or be a context manager implementing __enter__/__exit__"
+ )
+ return t.cast(t.ContextManager, cm)
+
  def refresh(self):
  "Refreshes an object from the database"

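Since _get_around_context_manager accepts a method that returns a context manager, an around_save hook can be a generator wrapped with contextlib.contextmanager; a sketch (the Invoice model and the timing logic are illustrative):

    import time
    from contextlib import contextmanager

    from sqlmodel import Field
    from activemodel.base_model import BaseModel

    class Invoice(BaseModel, table=True):
        id: int | None = Field(default=None, primary_key=True)

        @contextmanager
        def around_save(self):
            start = time.monotonic()
            # per the documented ordering, the commit, refresh, and after_* hooks run inside this block
            yield
            print(f"invoice saved in {time.monotonic() - start:.3f}s")
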
@@ -282,6 +288,7 @@ class BaseModel(SQLModel):

  # TODO shouldn't this be handled by pydantic?
  # TODO where is this actually used? shoudl prob remove this
+ # TODO should we even do this? Can we specify a better json rendering class?
  def json(self, **kwargs):
  return json.dumps(self.dict(), default=str, **kwargs)

@@ -297,7 +304,12 @@ class BaseModel(SQLModel):
  # TODO got to be a better way to fwd these along...
  @classmethod
  def first(cls):
- return cls.select().first()
+ # TODO should use dynamic pk
+ return cls.select().order_by(sa.desc(cls.id)).first()
+
+ # @classmethod
+ # def last(cls):
+ # return cls.select().first()

  # TODO throw an error if this field is set on the model
  def is_new(self) -> bool:
@@ -404,6 +416,19 @@ class BaseModel(SQLModel):
  with get_session() as session:
  return session.exec(statement).first()

+ @classmethod
+ def one_or_none(cls, *args: t.Any, **kwargs: t.Any):
+ """
+ Gets a single record from the database. Pass a PK ID or a kwarg to filter by.
+ Returns None if no record is found. Throws an error if more than one record is found.
+ """
+
+ args, kwargs = cls.__process_filter_args__(*args, **kwargs)
+ statement = select(cls).filter(*args).filter_by(**kwargs)
+
+ with get_session() as session:
+ return session.exec(statement).one_or_none()
+
  @classmethod
  def one(cls, *args: t.Any, **kwargs: t.Any):
  """
@@ -0,0 +1,147 @@
+ """
+ This module provides utilities for generating Protocol type definitions for SQLAlchemy's
+ SelectOfScalar methods, as well as formatting and fixing Python files using ruff.
+ """
+
+ import inspect
+ import logging
+ import os
+ import subprocess
+ from pathlib import Path
+ from typing import Any # already imported in header of generated file
+
+ import sqlmodel as sm
+ from sqlmodel.sql.expression import SelectOfScalar
+
+ from test.test_wrapper import QueryWrapper
+
+ # Set up logging
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
+ QUERY_WRAPPER_CLASS_NAME = QueryWrapper.__name__
+
+
+ def format_python_file(file_path: str | Path) -> bool:
+ """
+ Format a Python file using ruff.
+
+ Args:
+ file_path: Path to the Python file to format
+
+ Returns:
+ bool: True if formatting was successful, False otherwise
+ """
+ try:
+ subprocess.run(["ruff", "format", str(file_path)], check=True)
+ logger.info(f"Formatted file using ruff at {file_path}")
+ return True
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Error running ruff to format the file: {e}")
+ return False
+
+
+ def fix_python_file(file_path: str | Path) -> bool:
+ """
+ Fix linting issues in a Python file using ruff.
+
+ Args:
+ file_path: Path to the Python file to fix
+
+ Returns:
+ bool: True if fixing was successful, False otherwise
+ """
+ try:
+ subprocess.run(["ruff", "check", str(file_path), "--fix"], check=True)
+ logger.info(f"Fixed linting issues using ruff at {file_path}")
+ return True
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Error running ruff to fix the file: {e}")
+ return False
+
+
+ def generate_sqlalchemy_protocol():
+ """Generate Protocol type definitions for SQLAlchemy SelectOfScalar methods"""
+ logger.info("Starting SQLAlchemy protocol generation")
+
+ header = """
+ # IMPORTANT: This file is auto-generated. Do not edit directly.
+
+ from typing import Protocol, TypeVar, Any, Generic
+ import sqlmodel as sm
+ from sqlalchemy.sql.base import _NoArg
+
+ T = TypeVar('T', bound=sm.SQLModel, covariant=True)
+
+ class SQLAlchemyQueryMethods(Protocol, Generic[T]):
+ \"""Protocol defining SQLAlchemy query methods forwarded by QueryWrapper.__getattr__\"""
+ """
+ # Initialize output list for generated method signatures
+ output: list = []
+
+ try:
+ # Get all methods from SelectOfScalar
+ methods = inspect.getmembers(SelectOfScalar)
+ logger.debug(f"Discovered {len(methods)} methods from SelectOfScalar")
+
+ for name, method in methods:
+ # Skip private/dunder methods
+ if name.startswith("_"):
+ continue
+
+ if not inspect.isfunction(method) and not inspect.ismethod(method):
+ logger.debug(f"Skipping non-method: {name}")
+ continue
+
+ logger.debug(f"Processing method: {name}")
+ try:
+ signature = inspect.signature(method)
+ params = []
+
+ # Process parameters, skipping 'self'
+ for param_name, param in list(signature.parameters.items())[1:]:
+ if param.kind == param.VAR_POSITIONAL:
+ params.append(f"*{param_name}: Any")
+ elif param.kind == param.VAR_KEYWORD:
+ params.append(f"**{param_name}: Any")
+ else:
+ if param.default is inspect.Parameter.empty:
+ params.append(f"{param_name}: Any")
+ else:
+ default_repr = repr(param.default)
+ params.append(f"{param_name}: Any = {default_repr}")
+
+ params_str = ", ".join(params)
+ output.append(
+ f' def {name}(self, {params_str}) -> "{QUERY_WRAPPER_CLASS_NAME}[T]": ...'
+ )
+ except (ValueError, TypeError) as e:
+ logger.warning(f"Could not get signature for {name}: {e}")
+ # Some methods might not have proper signatures
+ output.append(
+ f' def {name}(self, *args: Any, **kwargs: Any) -> "{QUERY_WRAPPER_CLASS_NAME}[T]": ...'
+ )
+
+ # Write the output to a file
+ protocol_path = (
+ Path(__file__).parent.parent / "types" / "sqlalchemy_protocol.py"
+ )
+
+ # Ensure directory exists
+ os.makedirs(protocol_path.parent, exist_ok=True)
+
+ with open(protocol_path, "w") as f:
+ f.write(header + "\n".join(output))
+
+ logger.info(f"Generated SQLAlchemy protocol at {protocol_path}")
+
+ # Format and fix the generated file with ruff
+ format_python_file(protocol_path)
+ fix_python_file(protocol_path)
+ except Exception as e:
+ logger.error(f"Error generating SQLAlchemy protocol: {e}", exc_info=True)
+ raise
+
+
+ if __name__ == "__main__":
+ generate_sqlalchemy_protocol()
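
The generated file is the header above followed by one stub per public SelectOfScalar method; an illustrative fragment of the expected output (the exact method names and parameters depend on the installed SQLAlchemy/SQLModel versions):

    class SQLAlchemyQueryMethods(Protocol, Generic[T]):
        """Protocol defining SQLAlchemy query methods forwarded by QueryWrapper.__getattr__"""

        def limit(self, limit: Any) -> "QueryWrapper[T]": ...
        def order_by(self, *clauses: Any) -> "QueryWrapper[T]": ...
        def where(self, *whereclause: Any) -> "QueryWrapper[T]": ...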
@@ -20,7 +20,10 @@ class TimestampsMixin:
  >>> class MyModel(TimestampsMixin, SQLModel):
  >>> pass

- Originally pulled from: https://github.com/tiangolo/sqlmodel/issues/252
+ Notes:
+
+ - Originally pulled from: https://github.com/tiangolo/sqlmodel/issues/252
+ - Related issue: https://github.com/fastapi/sqlmodel/issues/539
  """

  created_at: datetime | None = Field(
@@ -17,7 +17,7 @@ def TypeIDMixin(prefix: str):
  # NOTE this will cause issues on code reloads
  assert prefix
  assert prefix not in _prefixes, (
- f"prefix {prefix} already exists, pick a different one"
+ f"TypeID prefix '{prefix}' already exists, pick a different one"
  )

  class _TypeIDMixin:
@@ -1,2 +1,2 @@
- from .transaction import database_reset_transaction
+ from .transaction import database_reset_transaction, test_session
  from .truncate import database_reset_truncate
@@ -0,0 +1,102 @@
+ """
+ Notes on polyfactory:
+
+ 1. is_supported_type validates that the class can be used to generate a factory
+ https://github.com/litestar-org/polyfactory/issues/655#issuecomment-2727450854
+ """
+
+ import typing as t
+
+ from polyfactory.factories.pydantic_factory import ModelFactory
+ from polyfactory.field_meta import FieldMeta
+ from typeid import TypeID
+
+ from activemodel.session_manager import global_session
+
+ # TODO not currently used
+ # def type_id_provider(cls, field_meta):
+ # # TODO this doesn't work well with __ args:
+ # # https://github.com/litestar-org/polyfactory/pull/666/files
+ # return str(TypeID("hi"))
+
+
+ # BaseFactory.add_provider(TypeIDType, type_id_provider)
+
+
+ class SQLModelFactory[T](ModelFactory[T]):
+ """
+ Base factory for SQLModel models:
+
+ 1. Ability to ignore all relationship fks
+ 2. Option to ignore all pks
+ """
+
+ __is_base_factory__ = True
+
+ @classmethod
+ def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: t.Any) -> bool:
+ # TODO what is this checking for?
+ has_object_override = hasattr(cls, field_meta.name)
+
+ # TODO this should be more intelligent, its goal is to detect all of the relationship fields and avoid setting them
+ if not has_object_override and (
+ field_meta.name == "id" or field_meta.name.endswith("_id")
+ ):
+ return False
+
+ return super().should_set_field_value(field_meta, **kwargs)
+
+
+ # TODO we need to think through how to handle relationships and autogenerate them
+ class ActiveModelFactory[T](SQLModelFactory[T]):
+ __is_base_factory__ = True
+ __sqlalchemy_session__ = None
+
+ # TODO we shouldn't have to type this, but `save()` typing is not working
+ @classmethod
+ def save(cls, *args, **kwargs) -> T:
+ """
+ Where this gets tricky is that this can be called multiple times within the same call stack. This can happen when
+ a factory uses other factories to create relationships.
+
+ In a truncation strategy, the __sqlalchemy_session__ is set to None.
+ """
+ with global_session(cls.__sqlalchemy_session__):
+ return cls.build(*args, **kwargs).save()
+
+ @classmethod
+ def foreign_key_typeid(cls):
+ """
+ Return a random type id for the foreign key on this model.
+
+ This is helpful for generating TypeIDs for testing 404s, parsing, manually setting them, etc.
+ """
+ # TODO right now assumes the model is typeid, maybe we should assert against this?
+ primary_key_name = cls.__model__.primary_key_column().name
+ return TypeID(
+ cls.__model__.model_fields[primary_key_name].sa_column.type.prefix
+ )
+
+ @classmethod
+ def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: t.Any) -> bool:
+ # do not default the deleted_at mixin field to deleted!
+ # TODO should be smarter about detecting if the mixin is in place
+ if field_meta.name in ["deleted_at", "updated_at", "created_at"]:
+ return False
+
+ return super().should_set_field_value(field_meta, **kwargs)
+
+ # @classmethod
+ # def build(
+ # cls,
+ # factory_use_construct: bool | None = None,
+ # sqlmodel_save: bool = False,
+ # **kwargs: t.Any,
+ # ) -> T:
+ # result = super().build(factory_use_construct=factory_use_construct, **kwargs)
+
+ # # TODO allow magic dunder method here
+ # if sqlmodel_save:
+ # result.save()
+
+ # return result
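
A concrete factory then only needs a __model__; a sketch of the intended usage (the User model, its name field, and the import location of ActiveModelFactory are assumptions):

    class UserFactory(ActiveModelFactory[User]):
        __model__ = User

    user = UserFactory.save(name="Ada")            # build + persist through BaseModel.save()
    draft = UserFactory.build()                    # unsaved instance; pk/fk fields left unset
    missing_id = UserFactory.foreign_key_typeid()  # random TypeID, e.g. for 404-style tests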
@@ -0,0 +1,81 @@
+ """Pytest plugin integration for activemodel.
+
+ Currently provides:
+
+ * ``db_session`` fixture – quick access to a database session (see ``test_session``)
+ * ``activemodel_preserve_tables`` ini option – configure tables to preserve when using
+ ``database_reset_truncate`` (comma separated list or multiple lines depending on config style)
+
+ Configuration examples:
+
+ pytest.ini::
+
+ [pytest]
+ activemodel_preserve_tables = alembic_version,zip_code,seed_table
+
+ pyproject.toml::
+
+ [tool.pytest.ini_options]
+ activemodel_preserve_tables = [
+ "alembic_version",
+ "zip_code",
+ "seed_table",
+ ]
+
+ The list always implicitly includes ``alembic_version`` even if not specified.
+ """
+
+ from activemodel.session_manager import global_session
+ import pytest
+
+ from .transaction import set_factory_session, set_polyfactory_session, test_session
+
+
+ def pytest_addoption(
+ parser: pytest.Parser,
+ ) -> None: # pragma: no cover - executed during collection
+ """Register custom ini options.
+
+ We treat this as a *linelist* so pyproject.toml list syntax works. Comma separated works too because
+ pytest splits lines first; users can still provide one line with commas.
+ """
+
+ parser.addini(
+ "activemodel_preserve_tables",
+ help=(
+ "Tables to preserve when calling activemodel.pytest.database_reset_truncate. "
+ ),
+ type="linelist",
+ default=["alembic_version"],
+ )
+
+
+ @pytest.fixture(scope="function")
+ def db_session():
+ """
+ Helpful for tests that are more similar to unit tests. If you are doing a routing or integration test, you
+ probably don't need this. If your unit test is simple (you are just creating a couple of models) you
+ can most likely skip this.
+
+ This is helpful if you are doing a lot of lazy-loaded params or need a database session to be in place
+ for testing code that will run within a celery worker or something similar.
+
+ >>> def the_test(db_session):
+ """
+ with test_session() as session:
+ yield session
+
+
+ @pytest.fixture(scope="function")
+ def db_truncate_session():
+ """
+ Provides a database session for testing when using a truncation cleaning strategy.
+
+ When not using a transaction cleaning strategy, no global test session is set.
+ """
+ with global_session() as session:
+ # set global database sessions for model factories to avoid lazy loading issues
+ set_factory_session(session)
+ set_polyfactory_session(session)
+
+ yield session
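
A test leaning on the db_session fixture might look like this sketch (the Widget model and its name field are assumptions; save() and one_or_none() come from BaseModel above):

    def test_widget_lookup(db_session):
        Widget(name="sprocket").save()
        assert Widget.one_or_none(name="sprocket") is not None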