activemodel 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
activemodel/__init__.py CHANGED
@@ -1,2 +1,4 @@
1
1
  from .base_model import BaseModel
2
+
3
+ # from .field import Field
2
4
  from .session_manager import SessionManager, get_engine, get_session, init
activemodel/base_model.py CHANGED
@@ -1,33 +1,38 @@
1
1
  import json
2
2
  import typing as t
3
- from contextlib import contextmanager
3
+ from uuid import UUID
4
4
 
5
5
  import pydash
6
6
  import sqlalchemy as sa
7
7
  import sqlmodel as sm
8
8
  from sqlalchemy import Connection, event
9
9
  from sqlalchemy.orm import Mapper, declared_attr
10
- from sqlmodel import Session, SQLModel, select
10
+ from sqlmodel import Field, MetaData, Session, SQLModel, select
11
11
  from typeid import TypeID
12
12
 
13
+ # NOTE: this patches a core method in sqlmodel to support db comments
14
+ from . import get_column_from_field_patch # noqa: F401
13
15
  from .logger import logger
14
16
  from .query_wrapper import QueryWrapper
15
17
  from .session_manager import get_session
16
18
 
19
+ POSTGRES_INDEXES_NAMING_CONVENTION = {
20
+ "ix": "%(column_0_label)s_idx",
21
+ "uq": "%(table_name)s_%(column_0_name)s_key",
22
+ "ck": "%(table_name)s_%(constraint_name)s_check",
23
+ "fk": "%(table_name)s_%(column_0_name)s_fkey",
24
+ "pk": "%(table_name)s_pkey",
25
+ }
26
+ """
27
+ By default, the foreign key naming conventions in sqlalchemy do not create unique identifiers when there are multiple
28
+ foreign keys in a table. This naming convention is a workaround to fix this issue:
17
29
 
18
- # TODO this does not seem to work with the latest 2.9.x pydantic and sqlmodel
19
- # https://github.com/SE-Sustainability-OSS/ecodev-core/blob/main/ecodev_core/sqlmodel_utils.py
20
- class SQLModelWithValidation(SQLModel):
21
- """
22
- Helper class to ease validation in SQLModel classes with table=True
23
- """
30
+ - https://github.com/zhanymkanov/fastapi-best-practices?tab=readme-ov-file#set-db-keys-naming-conventions
31
+ - https://github.com/fastapi/sqlmodel/discussions/1213
32
+ - Implementation lifted from: https://github.com/AlexanderZharyuk/billing-service/blob/3c8aaf19ab7546b97cc4db76f60335edec9fc79d/src/models.py#L24
33
+ """
24
34
 
25
- @classmethod
26
- def create(cls, **kwargs):
27
- """
28
- Forces validation to take place, even for SQLModel classes with table=True
29
- """
30
- return cls(**cls.__bases__[0](**kwargs).model_dump())
35
+ SQLModel.metadata.naming_convention = POSTGRES_INDEXES_NAMING_CONVENTION
31
36
 
32
37
 
33
38
  class BaseModel(SQLModel):
@@ -36,10 +41,14 @@ class BaseModel(SQLModel):
36
41
 
37
42
  https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
38
43
 
39
- {before,after} hooks are modeled after Rails.
44
+ - {before,after} lifecycle hooks are modeled after Rails.
45
+ - class docstrings are converted to table-level comments
46
+ - save(), delete(), select(), where(), and other easy methods you would expect
47
+ - Fixes foreign key naming conventions
40
48
  """
41
49
 
42
- # TODO implement actually calling these hooks
50
+ # this is used for table-level comments
51
+ __table_args__ = None
43
52
 
44
53
  @classmethod
45
54
  def __init_subclass__(cls, **kwargs):
@@ -47,6 +56,14 @@ class BaseModel(SQLModel):
47
56
 
48
57
  super().__init_subclass__(**kwargs)
49
58
 
59
+ from sqlmodel._compat import set_config_value
60
+
61
+ # enables field-level docstrings on the pydantic `description` field, which we then copy into
62
+ # sa_args, which is persisted to sql table comments
63
+ set_config_value(model=cls, parameter="use_attribute_docstrings", value=True)
64
+
65
+ cls._apply_class_doc()
66
+
50
67
  def event_wrapper(method_name: str):
51
68
  """
52
69
  This does smart heavy lifting for us to make sqlalchemy lifecycle events nicer to work with:
@@ -96,11 +113,37 @@ class BaseModel(SQLModel):
96
113
  event.listen(cls, "after_insert", event_wrapper("after_save"))
97
114
  event.listen(cls, "after_update", event_wrapper("after_save"))
98
115
 
116
+ # def foreign_key()
117
+ # table.id
118
+
119
+ @classmethod
120
+ def _apply_class_doc(cls):
121
+ """
122
+ Pull class-level docstring into a table comment.
123
+
124
+ This will help AI SQL writers like: https://github.com/iloveitaly/sql-ai-prompt-generator
125
+ """
126
+
127
+ doc = cls.__doc__.strip() if cls.__doc__ else None
128
+
129
+ if doc:
130
+ table_args = getattr(cls, "__table_args__", None)
131
+
132
+ if table_args is None:
133
+ cls.__table_args__ = {"comment": doc}
134
+ elif isinstance(table_args, dict):
135
+ table_args.setdefault("comment", doc)
136
+ else:
137
+ raise ValueError("Unexpected __table_args__ type")
138
+
99
139
  # TODO no type check decorator here
100
140
  @declared_attr
101
141
  def __tablename__(cls) -> str:
102
142
  """
103
143
  Automatically generates the table name for the model by converting the class name from camel case to snake case.
144
+ This is the recommended format for table names:
145
+
146
+ https://wiki.postgresql.org/wiki/Don%27t_Do_This#Don.27t_use_upper_case_table_or_column_names
104
147
 
105
148
  By default, the class is lower cased which makes it harder to read.
106
149
 
@@ -111,10 +154,35 @@ class BaseModel(SQLModel):
111
154
  """
112
155
  return pydash.strings.snake_case(cls.__name__)
113
156
 
157
+ @classmethod
158
+ def foreign_key(cls, **kwargs):
159
+ """
160
+ Returns a Field object referencing the foreign key of the model.
161
+ """
162
+
163
+ field_options = {"nullable": False} | kwargs
164
+
165
+ return Field(
166
+ # TODO id field is hard coded, should pick the PK field in case it's different
167
+ sa_type=cls.model_fields["id"].sa_column.type, # type: ignore
168
+ foreign_key=f"{cls.__tablename__}.id",
169
+ **field_options,
170
+ )
171
+
114
172
  @classmethod
115
173
  def select(cls, *args):
174
+ "create a query wrapper to easily run sqlmodel queries on this model"
116
175
  return QueryWrapper[cls](cls, *args)
117
176
 
177
+ def delete(self):
178
+ with get_session() as session:
179
+ if old_session := Session.object_session(self):
180
+ old_session.expunge(self)
181
+
182
+ session.delete(self)
183
+ session.commit()
184
+ return True
185
+
118
186
  def save(self):
119
187
  with get_session() as session:
120
188
  if old_session := Session.object_session(self):
@@ -145,10 +213,16 @@ class BaseModel(SQLModel):
145
213
  """
146
214
  Returns the number of records in the database.
147
215
  """
148
- return get_session().exec(sm.select(sm.func.count()).select_from(cls)).one()
216
+ with get_session() as session:
217
+ return session.scalar(sm.select(sm.func.count()).select_from(cls))
218
+
219
+ # TODO got to be a better way to fwd these along...
220
+ @classmethod
221
+ def first(cls):
222
+ return cls.select().first()
149
223
 
150
224
  # TODO throw an error if this field is set on the model
151
- def is_new(self):
225
+ def is_new(self) -> bool:
152
226
  return not self._sa_instance_state.has_identity
153
227
 
154
228
  @classmethod
@@ -186,29 +260,31 @@ class BaseModel(SQLModel):
186
260
  # field it will result in `True`, which will return all records, and not give you any typing
187
261
  # errors. Dangerous when iterating on structure quickly
188
262
  # TODO can we pass the generic of the superclass in?
189
- @classmethod
263
+ # TODO can we type the method signature a bit better?
190
264
  # def get(cls, *args: sa.BinaryExpression, **kwargs: t.Any):
265
+ @classmethod
191
266
  def get(cls, *args: t.Any, **kwargs: t.Any):
192
267
  """
193
268
  Gets a single record from the database. Pass an PK ID or a kwarg to filter by.
194
269
  """
195
270
 
271
+ # TODO id is hardcoded, not good! Need to dynamically pick the best uid field
272
+ id_field_name = "id"
273
+
196
274
  # special case for getting by ID
197
- if len(args) == 1 and isinstance(args[0], int):
198
- # TODO id is hardcoded, not good! Need to dynamically pick the best uid field
199
- kwargs["id"] = args[0]
200
- args = []
201
- elif len(args) == 1 and isinstance(args[0], TypeID):
202
- kwargs["id"] = args[0]
203
- args = []
275
+ if len(args) == 1 and isinstance(args[0], (int, TypeID, str, UUID)):
276
+ kwargs[id_field_name] = args[0]
277
+ args = ()
204
278
 
205
279
  statement = select(cls).filter(*args).filter_by(**kwargs)
206
- return get_session().exec(statement).first()
280
+
281
+ with get_session() as session:
282
+ return session.exec(statement).first()
207
283
 
208
284
  @classmethod
209
285
  def all(cls):
210
286
  with get_session() as session:
211
- results = session.exec(sa.sql.select(cls))
287
+ results = session.exec(sm.select(cls))
212
288
 
213
289
  # TODO do we need this or can we just return results?
214
290
  for result in results:
@@ -222,7 +298,7 @@ class BaseModel(SQLModel):
222
298
  Helpful for testing and console debugging.
223
299
  """
224
300
 
225
- query = sql.select(cls).order_by(sql.func.random()).limit(1)
301
+ query = sm.select(cls).order_by(sa.sql.func.random()).limit(1)
226
302
 
227
303
  with get_session() as session:
228
304
  return session.exec(query).one()
activemodel/celery.py ADDED
@@ -0,0 +1,28 @@
1
+ """
2
+ Do not import unless you have Celery/Kombu installed.
3
+
4
+ In order for TypeID objects to be properly handled by celery, a custom encoder must be registered.
5
+ """
6
+
7
+ from kombu.utils.json import register_type
8
+ from typeid import TypeID
9
+
10
+
11
+ def register_celery_typeid_encoder():
12
+ "this ensures TypeID objects passed as arguments to a delayed function are properly serialized"
13
+
14
+ def class_full_name(clz) -> str:
15
+ return ".".join([clz.__module__, clz.__qualname__])
16
+
17
+ def _encoder(obj: TypeID) -> str:
18
+ return str(obj)
19
+
20
+ def _decoder(data: str) -> TypeID:
21
+ return TypeID.from_string(data)
22
+
23
+ register_type(
24
+ TypeID,
25
+ class_full_name(TypeID),
26
+ encoder=_encoder,
27
+ decoder=_decoder,
28
+ )
activemodel/errors.py ADDED
@@ -0,0 +1,6 @@
1
+ class TypeIDValidationError(ValueError):
2
+ """
3
+ Raised when a TypeID is invalid in some way
4
+ """
5
+
6
+ pass
@@ -0,0 +1,137 @@
1
+ """
2
+ Pydantic has a great DX for adding docstrings to fields. This allows devs to easily document the fields of a model.
3
+
4
+ Making sure these docstrings make their way to the DB schema is helpful for a bunch of reasons (LLM understanding being one of them).
5
+
6
+ This patch mutates a core sqlmodel function which translates pydantic FieldInfo objects into sqlalchemy Column objects. It adds the field description as a comment to the column.
7
+
8
+ Note that FieldInfo *from pydantic* is used when a "bare" field is defined. This can be confusing, because when inspecting model fields, the class name looks exactly the same.
9
+ """
10
+
11
+ from typing import (
12
+ TYPE_CHECKING,
13
+ Any,
14
+ Dict,
15
+ Sequence,
16
+ cast,
17
+ )
18
+
19
+ import sqlmodel
20
+ from pydantic.fields import FieldInfo as PydanticFieldInfo
21
+ from sqlalchemy import (
22
+ Column,
23
+ ForeignKey,
24
+ )
25
+ from sqlmodel._compat import ( # type: ignore[attr-defined]
26
+ IS_PYDANTIC_V2,
27
+ ModelMetaclass,
28
+ Representation,
29
+ Undefined,
30
+ UndefinedType,
31
+ is_field_noneable,
32
+ )
33
+ from sqlmodel.main import FieldInfo, get_sqlalchemy_type
34
+
35
+ from activemodel.utils import hash_function_code
36
+
37
+ if TYPE_CHECKING:
38
+ from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
39
+ from pydantic._internal._repr import Representation as Representation
40
+ from pydantic_core import PydanticUndefined as Undefined
41
+ from pydantic_core import PydanticUndefinedType as UndefinedType
42
+
43
+
44
+ assert (
45
+ hash_function_code(sqlmodel.main.get_column_from_field)
46
+ == "398006ef8fd8da191ca1a271ef25b6e135da0f400a80df2f29526d8674f9ec51"
47
+ )
48
+
49
+
50
+ def get_column_from_field(field: PydanticFieldInfo | FieldInfo) -> Column: # type: ignore
51
+ """
52
+ Takes a field definition, which can either come from the sqlmodel FieldInfo class or the pydantic variant of that class,
53
+ and converts it into a sqlalchemy Column object.
54
+ """
55
+ if IS_PYDANTIC_V2:
56
+ field_info = field
57
+ else:
58
+ field_info = field.field_info
59
+
60
+ sa_column = getattr(field_info, "sa_column", Undefined)
61
+ if isinstance(sa_column, Column):
62
+ # IMPORTANT: change from the original function
63
+ if not sa_column.comment and (field_comment := field_info.description):
64
+ sa_column.comment = field_comment
65
+ return sa_column
66
+
67
+ primary_key = getattr(field_info, "primary_key", Undefined)
68
+ if primary_key is Undefined:
69
+ primary_key = False
70
+
71
+ index = getattr(field_info, "index", Undefined)
72
+ if index is Undefined:
73
+ index = False
74
+
75
+ nullable = not primary_key and is_field_noneable(field)
76
+ # Override derived nullability if the nullable property is set explicitly
77
+ # on the field
78
+ field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009
79
+ if field_nullable is not Undefined:
80
+ assert not isinstance(field_nullable, UndefinedType)
81
+ nullable = field_nullable
82
+ args = []
83
+ foreign_key = getattr(field_info, "foreign_key", Undefined)
84
+ if foreign_key is Undefined:
85
+ foreign_key = None
86
+ unique = getattr(field_info, "unique", Undefined)
87
+ if unique is Undefined:
88
+ unique = False
89
+ if foreign_key:
90
+ if field_info.ondelete == "SET NULL" and not nullable:
91
+ raise RuntimeError('ondelete="SET NULL" requires nullable=True')
92
+ assert isinstance(foreign_key, str)
93
+ ondelete = getattr(field_info, "ondelete", Undefined)
94
+ if ondelete is Undefined:
95
+ ondelete = None
96
+ assert isinstance(ondelete, (str, type(None))) # for typing
97
+ args.append(ForeignKey(foreign_key, ondelete=ondelete))
98
+ kwargs = {
99
+ "primary_key": primary_key,
100
+ "nullable": nullable,
101
+ "index": index,
102
+ "unique": unique,
103
+ }
104
+
105
+ sa_default = Undefined
106
+ if field_info.default_factory:
107
+ sa_default = field_info.default_factory
108
+ elif field_info.default is not Undefined:
109
+ sa_default = field_info.default
110
+ if sa_default is not Undefined:
111
+ kwargs["default"] = sa_default
112
+
113
+ sa_column_args = getattr(field_info, "sa_column_args", Undefined)
114
+ if sa_column_args is not Undefined:
115
+ args.extend(list(cast(Sequence[Any], sa_column_args)))
116
+
117
+ sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
118
+
119
+ # IMPORTANT: change from the original function
120
+ if field_info.description:
121
+ if sa_column_kwargs is Undefined:
122
+ sa_column_kwargs = {}
123
+
124
+ assert isinstance(sa_column_kwargs, dict)
125
+
126
+ # only update comments if not already set
127
+ if "comment" not in sa_column_kwargs:
128
+ sa_column_kwargs["comment"] = field_info.description
129
+
130
+ if sa_column_kwargs is not Undefined:
131
+ kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
132
+
133
+ sa_type = get_sqlalchemy_type(field)
134
+ return Column(sa_type, *args, **kwargs) # type: ignore
135
+
136
+
137
+ sqlmodel.main.get_column_from_field = get_column_from_field
@@ -1,2 +1,4 @@
1
+ from .pydantic_json import PydanticJSONMixin
2
+ from .soft_delete import SoftDeletionMixin
1
3
  from .timestamps import TimestampsMixin
2
4
  from .typeid import TypeIDMixin
@@ -0,0 +1,69 @@
1
+ from types import UnionType
2
+ from typing import get_args, get_origin
3
+
4
+ from pydantic import BaseModel as PydanticBaseModel
5
+ from sqlalchemy.orm import reconstructor
6
+
7
+
8
+ class PydanticJSONMixin:
9
+ """
10
+ By default, SQLModel does not convert JSONB columns into pydantic models when they are loaded from the database.
11
+
12
+ This mixin, combined with a custom serializer, fixes that issue.
13
+ """
14
+
15
+ @reconstructor
16
+ def init_on_load(self):
17
+ # TODO do we need to inspect sa_type
18
+ for field_name, field_info in self.model_fields.items():
19
+ raw_value = getattr(self, field_name, None)
20
+
21
+ if raw_value is None:
22
+ continue
23
+
24
+ annotation = field_info.annotation
25
+ origin = get_origin(annotation)
26
+
27
+ # e.g. `dict` or `dict[str, str]`, we don't want to do anything with these
28
+ if origin is dict:
29
+ continue
30
+
31
+ annotation_args = get_args(annotation)
32
+ is_top_level_list = origin is list
33
+
34
+ # if origin is not None:
35
+ # assert annotation.__class__ == origin
36
+
37
+ model_cls = annotation
38
+
39
+ # e.g. SomePydanticModel | None or list[SomePydanticModel] | None
40
+ # annotation_args are (type, NoneType) in this case
41
+ if isinstance(annotation, UnionType):
42
+ non_none_types = [t for t in annotation_args if t is not type(None)]
43
+
44
+ if len(non_none_types) == 1:
45
+ model_cls = non_none_types[0]
46
+
47
+ # e.g. list[SomePydanticModel] | None, we have to unpack it
48
+ # model_cls will print as a list, but it contains a subtype if you dig into it
49
+ if (
50
+ get_origin(model_cls) is list
51
+ and len(list_annotation_args := get_args(model_cls)) == 1
52
+ ):
53
+ model_cls = list_annotation_args[0]
54
+ is_top_level_list = True
55
+
56
+ # e.g. list[SomePydanticModel] or list[SomePydanticModel] | None
57
+ # iterate through the list and run each item through the pydantic model
58
+ if is_top_level_list:
59
+ if isinstance(raw_value, list) and issubclass(
60
+ model_cls, PydanticBaseModel
61
+ ):
62
+ parsed_value = [model_cls(**item) for item in raw_value]
63
+ setattr(self, field_name, parsed_value)
64
+
65
+ continue
66
+
67
+ # single class
68
+ if issubclass(model_cls, PydanticBaseModel):
69
+ setattr(self, field_name, model_cls(**raw_value))
@@ -0,0 +1,17 @@
1
+ from datetime import datetime
2
+
3
+ import sqlalchemy as sa
4
+ from sqlmodel import Field
5
+
6
+
7
+ class SoftDeletionMixin:
8
+ deleted_at: datetime = Field(
9
+ default=None,
10
+ nullable=True,
11
+ # TODO https://github.com/fastapi/sqlmodel/discussions/1228
12
+ sa_type=sa.DateTime(timezone=True), # type: ignore
13
+ )
14
+
15
+ def soft_delete(self):
16
+ self.deleted_at = datetime.now()
17
+ raise NotImplementedError("Soft deletion is not implemented")
@@ -1,36 +1,46 @@
1
- import uuid
2
-
3
1
  from sqlmodel import Column, Field
4
2
  from typeid import TypeID
5
3
 
6
4
  from activemodel.types.typeid import TypeIDType
7
5
 
6
+ # global list of prefixes to ensure uniqueness
7
+ _prefixes = []
8
+
8
9
 
9
10
  def TypeIDMixin(prefix: str):
11
+ assert prefix
12
+ assert prefix not in _prefixes, (
13
+ f"prefix {prefix} already exists, pick a different one"
14
+ )
15
+
10
16
  class _TypeIDMixin:
11
- id: uuid.UUID = Field(
12
- sa_column=Column(TypeIDType(prefix), primary_key=True),
17
+ id: TypeIDType = Field(
18
+ sa_column=Column(TypeIDType(prefix), primary_key=True, nullable=False),
13
19
  default_factory=lambda: TypeID(prefix),
14
20
  )
15
21
 
22
+ _prefixes.append(prefix)
23
+
16
24
  return _TypeIDMixin
17
25
 
18
26
 
19
- class TypeIDMixin2:
20
- """
21
- Mixin class that adds a TypeID primary key to models.
27
+ # TODO not sure if I love the idea of a dynamic class for each mixin as used above
28
+ # may give this approach another shot in the future
29
+ # class TypeIDMixin2:
30
+ # """
31
+ # Mixin class that adds a TypeID primary key to models.
22
32
 
23
33
 
24
- >>> class MyModel(BaseModel, TypeIDMixin, prefix="xyz", table=True):
25
- >>> name: str
34
+ # >>> class MyModel(BaseModel, TypeIDMixin, prefix="xyz", table=True):
35
+ # >>> name: str
26
36
 
27
- Will automatically have an `id` field with prefix "xyz"
28
- """
37
+ # Will automatically have an `id` field with prefix "xyz"
38
+ # """
29
39
 
30
- def __init_subclass__(cls, *, prefix: str, **kwargs):
31
- super().__init_subclass__(**kwargs)
40
+ # def __init_subclass__(cls, *, prefix: str, **kwargs):
41
+ # super().__init_subclass__(**kwargs)
32
42
 
33
- cls.id: uuid.UUID = Field(
34
- sa_column=Column(TypeIDType(prefix), primary_key=True),
35
- default_factory=lambda: TypeID(prefix),
36
- )
43
+ # cls.id: uuid.UUID = Field(
44
+ # sa_column=Column(TypeIDType(prefix), primary_key=True),
45
+ # default_factory=lambda: TypeID(prefix),
46
+ # )
@@ -40,7 +40,7 @@ def database_reset_truncate():
40
40
  transaction = connection.begin()
41
41
 
42
42
  if table.name not in exception_tables:
43
- logger.debug("truncating table=%s", table.name)
43
+ logger.debug(f"truncating table={table.name}")
44
44
  connection.execute(table.delete())
45
45
 
46
46
  transaction.commit()
@@ -1,20 +1,26 @@
1
- import sqlmodel
1
+ import sqlmodel as sm
2
+ from sqlmodel.sql.expression import SelectOfScalar
2
3
 
4
+ from .session_manager import get_session
5
+ from .utils import compile_sql
3
6
 
4
- class QueryWrapper[T]:
7
+
8
+ class QueryWrapper[T: sm.SQLModel]:
5
9
  """
6
10
  Make it easy to run queries off of a model
7
11
  """
8
12
 
9
- def __init__(self, cls, *args) -> None:
13
+ target: SelectOfScalar[T]
14
+
15
+ def __init__(self, cls: T, *args) -> None:
10
16
  # TODO add generics here
11
17
  # self.target: SelectOfScalar[T] = sql.select(cls)
12
18
 
13
19
  if args:
14
20
  # very naive, let's assume the args are specific select statements
15
- self.target = sqlmodel.sql.select(*args).select_from(cls)
21
+ self.target = sm.select(*args).select_from(cls)
16
22
  else:
17
- self.target = sql.select(cls)
23
+ self.target = sm.select(cls)
18
24
 
19
25
  # TODO the .exec results should be handled in one shot
20
26
 
@@ -23,6 +29,7 @@ class QueryWrapper[T]:
23
29
  return session.exec(self.target).first()
24
30
 
25
31
  def one(self):
32
+ "requires exactly one result in the dataset"
26
33
  with get_session() as session:
27
34
  return session.exec(self.target).one()
28
35
 
@@ -32,8 +39,14 @@ class QueryWrapper[T]:
32
39
  for row in result:
33
40
  yield row
34
41
 
42
+ def count(self):
43
+ """
44
+ Return the number of records matching the wrapped query.
45
+ """
46
+ with get_session() as session:
47
+ return session.scalar(sm.select(sm.func.count()).select_from(self.target))
48
+
35
49
  def exec(self):
36
- # TODO do we really need a unique session each time?
37
50
  with get_session() as session:
38
51
  return session.exec(self.target)
39
52
 
@@ -43,19 +56,20 @@ class QueryWrapper[T]:
43
56
 
44
57
  def __getattr__(self, name):
45
58
  """
46
- This is called to retrieve the function to execute
59
+ This implements the magic that forwards function calls to sqlalchemy.
47
60
  """
48
61
 
49
62
  # TODO prefer methods defined in this class
63
+
50
64
  if not hasattr(self.target, name):
51
65
  return super().__getattribute__(name)
52
66
 
53
- attr = getattr(self.target, name)
67
+ sqlalchemy_target = getattr(self.target, name)
54
68
 
55
- if callable(attr):
69
+ if callable(sqlalchemy_target):
56
70
 
57
71
  def wrapper(*args, **kwargs):
58
- result = attr(*args, **kwargs)
72
+ result = sqlalchemy_target(*args, **kwargs)
59
73
  self.target = result
60
74
  return self
61
75