fastapi-repository 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,36 @@
1
+ Metadata-Version: 2.4
2
+ Name: fastapi-repository
3
+ Version: 0.0.1
4
+ Summary: A base repository for FastAPI projects, inspired by Ruby on Rails' Active Record and Ransack.
5
+ Author-email: Seiya Takeda <takedaseiya@gmail.com>
6
+ Project-URL: Homepage, https://github.com/seiyat/fastapi-repository
7
+ Project-URL: Bug Tracker, https://github.com/seiyat/fastapi-repository/issues
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.8
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: sqlalchemy>=1.4.0
14
+ Requires-Dist: fastapi>=0.70.0
15
+
16
+ # FastAPI Repository
17
+
18
+ A base repository for FastAPI projects, inspired by Ruby on Rails' Active Record and Ransack.
19
+
20
+ ## Installation
21
+
22
+ ```bash
23
+ pip install fastapi-repository
24
+ ```
25
+
26
+ ## Usage
27
+
28
+ ```python
29
+ from fastapi_repository import BaseRepository
30
+ from sqlalchemy.ext.asyncio import AsyncSession
31
+ from .models import User
32
+
33
+ class UserRepository(BaseRepository):
34
+ def __init__(self, session: AsyncSession):
35
+ super().__init__(session, model=User)
36
+ ```
@@ -0,0 +1,21 @@
1
+ # FastAPI Repository
2
+
3
+ A base repository for FastAPI projects, inspired by Ruby on Rails' Active Record and Ransack.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install fastapi-repository
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ from fastapi_repository import BaseRepository
15
+ from sqlalchemy.ext.asyncio import AsyncSession
16
+ from .models import User
17
+
18
+ class UserRepository(BaseRepository):
19
+ def __init__(self, session: AsyncSession):
20
+ super().__init__(session, model=User)
21
+ ```
@@ -0,0 +1,3 @@
1
+ from .base import BaseRepository, OPERATORS
2
+
3
# Public API of the package: the repository base class and the lookup-operator table.
__all__ = [
    "BaseRepository",
    "OPERATORS",
]
@@ -0,0 +1,341 @@
1
+ from sqlalchemy.ext.asyncio import AsyncSession
2
+ from sqlalchemy.exc import NoResultFound
3
+ from sqlalchemy.future import select
4
+ from uuid import UUID
5
+ from sqlalchemy.orm import joinedload, lazyload
6
+ from sqlalchemy import func, update, delete
7
+ from typing import Optional, List, Union, Dict, Any
8
+
9
# Escape character for LIKE/ILIKE patterns. "/" is also what SQLAlchemy uses
# for ``autoescape=True``, and it avoids backslash-quoting differences between
# database backends.
_LIKE_ESCAPE = "/"


def _escape_like(value: Any) -> Any:
    """Escape LIKE wildcards in ``value`` so it is matched literally.

    ``%``, ``_`` and the escape character itself are prefixed with
    ``_LIKE_ESCAPE``. Non-string values are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _as_list(value: Any) -> Any:
    """Normalize the operand of the ``in`` operator to a list.

    Lists are passed through as-is; tuples, sets and frozensets are converted
    to lists; any other value is wrapped in a single-element list.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, (tuple, set, frozenset)):
        return list(value)
    return [value]


# Ransack/Django-style lookup operators, used as ``field__operator=value``.
# Each entry maps an operator name to a function ``(column, value) -> clause``.
# LIKE-based operators escape ``%`` and ``_`` in string values so that user
# input is matched literally instead of being interpreted as wildcards.
OPERATORS = {
    # Exact match
    "exact": lambda col, val: col == val,
    "iexact": lambda col, val: col.ilike(_escape_like(val), escape=_LIKE_ESCAPE),
    # Substring match (autoescape is only valid for string operands)
    "contains": lambda col, val: col.contains(val, autoescape=isinstance(val, str)),
    "icontains": lambda col, val: col.ilike(
        f"%{_escape_like(val)}%", escape=_LIKE_ESCAPE
    ),
    # Membership (IN); accepts a list, tuple, set/frozenset or a single value
    "in": lambda col, val: col.in_(_as_list(val)),
    # Comparisons
    "gt": lambda col, val: col > val,
    "gte": lambda col, val: col >= val,
    "lt": lambda col, val: col < val,
    "lte": lambda col, val: col <= val,
    # Prefix / suffix match
    "startswith": lambda col, val: col.startswith(val, autoescape=isinstance(val, str)),
    "istartswith": lambda col, val: col.ilike(
        f"{_escape_like(val)}%", escape=_LIKE_ESCAPE
    ),
    "endswith": lambda col, val: col.endswith(val, autoescape=isinstance(val, str)),
    "iendswith": lambda col, val: col.ilike(
        f"%{_escape_like(val)}", escape=_LIKE_ESCAPE
    ),
}
29
+
30
+
31
class BaseRepository:
    """Async, Active Record-style repository bound to a single SQLAlchemy model.

    Subclasses pass the mapped class as ``model`` and may override
    ``default_scope`` with filter keyword arguments using the same
    ``field__operator=value`` syntax accepted by :meth:`where`. The default
    scope is applied by ``find``, ``find_by``, ``find_by_or_raise``, ``where``,
    ``count`` and ``exists`` unless ``disable_default_scope=True`` is passed.
    ``update_all`` and ``destroy_all`` do NOT apply it.

    The write methods (``create``, ``update``, ``update_all``, ``destroy``,
    ``destroy_all``) commit the session themselves and roll it back if the
    commit fails.
    """

    # Filter kwargs applied to read queries. Shared class attribute:
    # override it in a subclass, never mutate it in place.
    default_scope: dict = {}

    def __init__(self, session: AsyncSession, model=None):
        """Bind the repository to ``session`` and the mapped class ``model``.

        Raises:
            ValueError: if ``model`` is not provided.
        """
        self.session = session
        if not model:
            raise ValueError("Model is not set for this repository.")
        self.model = model

    async def find(
        self,
        id: Union[int, UUID],
        sorted_by: Optional[str] = None,
        sorted_order: str = "asc",
        joinedload_models: Optional[List] = None,
        lazyload_models: Optional[List] = None,
        disable_default_scope: bool = False,
    ):
        """Return the record whose ``id`` attribute equals ``id``.

        Raises:
            NoResultFound: if no matching record exists (within the default
                scope unless ``disable_default_scope`` is True).
        """
        query = await self.__generate_query(
            limit=1,
            offset=0,
            sorted_by=sorted_by,
            sorted_order=sorted_order,
            joinedload_models=joinedload_models,
            lazyload_models=lazyload_models,
            disable_default_scope=disable_default_scope,
            id=id,
        )
        result = await self.session.execute(query)
        # unique() is required by SQLAlchemy when joinedload() targets a
        # collection relationship; it is harmless otherwise.
        instance = result.unique().scalars().first()
        if instance is None:
            raise NoResultFound(f"{self.model.__name__} with id {id} not found.")
        return instance

    async def find_by(
        self,
        sorted_by: Optional[str] = None,
        sorted_order: str = "asc",
        joinedload_models: Optional[List] = None,
        lazyload_models: Optional[List] = None,
        disable_default_scope: bool = False,
        **search_params,
    ):
        """Return the first record matching ``search_params``, or ``None``."""
        query = await self.__generate_query(
            limit=1,
            offset=0,
            sorted_by=sorted_by,
            sorted_order=sorted_order,
            joinedload_models=joinedload_models,
            lazyload_models=lazyload_models,
            disable_default_scope=disable_default_scope,
            **search_params,
        )
        result = await self.session.execute(query)
        # See find(): unique() keeps joined collection eager loads working.
        return result.unique().scalars().first()

    async def find_by_or_raise(
        self,
        sorted_by: Optional[str] = None,
        sorted_order: str = "asc",
        joinedload_models: Optional[List] = None,
        lazyload_models: Optional[List] = None,
        disable_default_scope: bool = False,
        **search_params,
    ):
        """Return the first record matching ``search_params``.

        Raises:
            NoResultFound: if no record matches.
        """
        instance = await self.find_by(
            sorted_by=sorted_by,
            sorted_order=sorted_order,
            joinedload_models=joinedload_models,
            lazyload_models=lazyload_models,
            disable_default_scope=disable_default_scope,
            **search_params,
        )
        if instance is None:
            raise NoResultFound(
                f"{self.model.__name__} with attributes {search_params} not found."
            )
        return instance

    async def where(
        self,
        limit: int = 100,
        offset: int = 0,
        sorted_by: Optional[str] = None,
        sorted_order: str = "asc",
        joinedload_models: Optional[List] = None,
        lazyload_models: Optional[List] = None,
        disable_default_scope: bool = False,
        **search_params,
    ):
        """Return records matching ``search_params`` with sorting and pagination.

        Results are de-duplicated with ``Result.unique()`` so joined eager
        loads of collections do not yield repeated parent rows.
        """
        query = await self.__generate_query(
            limit=limit,
            offset=offset,
            sorted_by=sorted_by,
            sorted_order=sorted_order,
            joinedload_models=joinedload_models,
            lazyload_models=lazyload_models,
            disable_default_scope=disable_default_scope,
            **search_params,
        )
        result = await self.session.execute(query)
        return result.unique().scalars().all()

    async def count(self, disable_default_scope: bool = False, **search_params) -> int:
        """Return the number of records matching ``search_params``."""
        conditions = await self._scoped_conditions(disable_default_scope, search_params)
        # func.count() with no argument renders COUNT(*).
        query = select(func.count()).select_from(self.model).where(*conditions)
        result = await self.session.execute(query)
        return result.scalar() or 0

    async def exists(
        self, disable_default_scope: bool = False, **search_params
    ) -> bool:
        """Return True if at least one record matches ``search_params``."""
        counted = await self.count(
            disable_default_scope=disable_default_scope, **search_params
        )
        return counted > 0

    async def _scoped_conditions(
        self, disable_default_scope: bool, search_params: Dict[str, Any]
    ) -> list:
        """Build filter clauses: the default scope (unless disabled) followed
        by the clauses for ``search_params``."""
        conditions = []
        if not disable_default_scope:
            conditions.extend(await self.__get_conditions(**self.default_scope))
        conditions.extend(await self.__get_conditions(**search_params))
        return conditions

    async def __generate_query(
        self,
        limit: int = 100,
        offset: int = 0,
        sorted_by: Optional[str] = None,
        sorted_order: str = "asc",
        joinedload_models: Optional[List] = None,
        lazyload_models: Optional[List] = None,
        disable_default_scope: bool = False,
        **search_params,
    ):
        """Build a SELECT for ``self.model`` with filtering (default scope
        included unless disabled), eager-loading options, optional ordering
        and LIMIT/OFFSET.
        """
        conditions = await self._scoped_conditions(disable_default_scope, search_params)
        query = select(self.model).where(*conditions)

        for relationship in joinedload_models or []:
            query = query.options(joinedload(relationship))
        for relationship in lazyload_models or []:
            query = query.options(lazyload(relationship))

        if sorted_by:
            query = self._apply_order_by(query, sorted_by, sorted_order)

        return query.limit(limit).offset(offset)

    def _apply_order_by(self, query, sorted_by: str, sorted_order: str):
        """Return ``query`` ordered by the model attribute named ``sorted_by``.

        ``sorted_order`` is case-insensitive: ``"asc"`` sorts ascending and any
        other value sorts descending.

        Raises:
            AttributeError: if the model has no attribute named ``sorted_by``.
        """
        column = getattr(self.model, sorted_by, None)
        if column is None:
            raise AttributeError(
                f"{self.model.__name__} has no attribute '{sorted_by}'"
            )
        ordering = column.asc() if sorted_order.lower() == "asc" else column.desc()
        return query.order_by(ordering)

    def _build_path_condition(self, model, path: List[str], op: str, value):
        """Resolve an attribute ``path`` starting at ``model`` into a filter clause.

        A single-element path names a column on ``model`` and is compared with
        ``OPERATORS[op]``. Longer paths walk relationships hop by hop: each hop
        is wrapped in ``.any()`` for collection relationships or ``.has()`` for
        scalar ones, so ``rel__sub__field__op=value`` works at any depth.

        Raises:
            AttributeError: if a column or relationship in the path is missing.
        """
        name, rest = path[0], path[1:]
        attr = getattr(model, name, None)

        if not rest:
            if attr is None:
                raise AttributeError(f"{model.__name__} has no attribute '{name}'")
            return OPERATORS[op](attr, value)

        # Relationship properties expose their target class via
        # ``property.mapper``; anything else cannot be traversed.
        prop = getattr(attr, "property", None)
        target_mapper = getattr(prop, "mapper", None)
        if target_mapper is None:
            raise AttributeError(f"{model.__name__} has no relationship '{name}'")

        inner = self._build_path_condition(target_mapper.class_, rest, op, value)
        # any() only applies to collections; has() only to scalar references.
        return attr.any(inner) if prop.uselist else attr.has(inner)

    async def __get_conditions(self, **search_params):
        """Translate keyword filters into SQL clauses.

        Keys use Ransack/Django-style syntax:

        * ``field=value``: equality;
        * ``field__op=value``: ``op`` is a key of ``OPERATORS``;
        * ``rel__field[__op]=value``: any number of relationship hops.

        A trailing segment that is not a known operator is treated as part of
        the attribute path and compared with ``exact``.

        Raises:
            AttributeError: if a referenced column or relationship does not exist.
        """
        conditions = []
        for key, value in search_params.items():
            path = key.split("__")
            op = "exact"
            if len(path) > 1 and path[-1] in OPERATORS:
                op = path.pop()
            conditions.append(self._build_path_condition(self.model, path, op, value))
        return conditions

    async def _commit(self) -> None:
        """Commit the session; on failure roll back (so the session stays
        usable) and re-raise the original exception."""
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    async def create(self, **create_params):
        """Instantiate ``self.model`` with ``create_params``, persist it and
        return the refreshed instance.

        Errors raised by the model constructor (e.g. for unknown fields)
        propagate before anything is added to the session.
        """
        instance = self.model(**create_params)
        self.session.add(instance)
        await self._commit()
        await self.session.refresh(instance)
        return instance

    async def update(self, id: Union[int, UUID], **update_params):
        """Update a single record by its ``id`` and return the refreshed instance.

        Usage:
            await repository.update(some_id, field1='value1', field2='value2')

        Raises:
            AttributeError: if a key in ``update_params`` is not an attribute
                of the model.
            NoResultFound: if the record doesn't exist.
        """
        # Reject unknown field names up front: setattr() would otherwise just
        # attach a plain Python attribute that is never persisted.
        for field in update_params:
            if not hasattr(self.model, field):
                raise AttributeError(
                    f"{self.model.__name__} has no attribute '{field}'"
                )
        instance = await self.find(id)
        for field, value in update_params.items():
            setattr(instance, field, value)
        await self._commit()
        await self.session.refresh(instance)
        return instance

    async def update_all(self, updates: Dict[str, Any], **search_params) -> int:
        """Update every record matching ``search_params`` in one UPDATE statement.

        The default scope is NOT applied. Returns the driver-reported row count.

        Usage:
            await repository.update_all(
                {"field1": "new_value", "field2": 123},
                some_field__gte=10,
                other_field="foo"
            )
        """
        conditions = await self.__get_conditions(**search_params)
        stmt = update(self.model).where(*conditions).values(**updates)
        result = await self.session.execute(stmt)
        await self._commit()
        return result.rowcount

    async def destroy(self, id: Union[int, UUID]) -> None:
        """Delete a single record by its ``id``.

        Raises:
            NoResultFound: if the record doesn't exist.
        """
        instance = await self.find(id)
        await self.session.delete(instance)
        await self._commit()

    async def destroy_all(self, **search_params) -> int:
        """Delete every record matching ``search_params`` in one DELETE statement.

        The default scope is NOT applied; with no conditions every row of the
        table is deleted. Returns the driver-reported row count.

        Usage:
            await repository.destroy_all(field1="value1", field2__gte=10)
        """
        conditions = await self.__get_conditions(**search_params)
        stmt = delete(self.model).where(*conditions)
        result = await self.session.execute(stmt)
        await self._commit()
        return result.rowcount
@@ -0,0 +1,36 @@
1
+ Metadata-Version: 2.4
2
+ Name: fastapi-repository
3
+ Version: 0.0.1
4
+ Summary: A base repository for FastAPI projects, inspired by Ruby on Rails' Active Record and Ransack.
5
+ Author-email: Seiya Takeda <takedaseiya@gmail.com>
6
+ Project-URL: Homepage, https://github.com/seiyat/fastapi-repository
7
+ Project-URL: Bug Tracker, https://github.com/seiyat/fastapi-repository/issues
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.8
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: sqlalchemy>=1.4.0
14
+ Requires-Dist: fastapi>=0.70.0
15
+
16
+ # FastAPI Repository
17
+
18
+ A base repository for FastAPI projects, inspired by Ruby on Rails' Active Record and Ransack.
19
+
20
+ ## Installation
21
+
22
+ ```bash
23
+ pip install fastapi-repository
24
+ ```
25
+
26
+ ## Usage
27
+
28
+ ```python
29
+ from fastapi_repository import BaseRepository
30
+ from sqlalchemy.ext.asyncio import AsyncSession
31
+ from .models import User
32
+
33
+ class UserRepository(BaseRepository):
34
+ def __init__(self, session: AsyncSession):
35
+ super().__init__(session, model=User)
36
+ ```
@@ -0,0 +1,10 @@
1
+ README.md
2
+ pyproject.toml
3
+ fastapi_repository/__init__.py
4
+ fastapi_repository/base.py
5
+ fastapi_repository.egg-info/PKG-INFO
6
+ fastapi_repository.egg-info/SOURCES.txt
7
+ fastapi_repository.egg-info/dependency_links.txt
8
+ fastapi_repository.egg-info/requires.txt
9
+ fastapi_repository.egg-info/top_level.txt
10
+ tests/test_base_repository.py
@@ -0,0 +1,2 @@
1
+ sqlalchemy>=1.4.0
2
+ fastapi>=0.70.0
@@ -0,0 +1 @@
1
+ fastapi_repository
@@ -0,0 +1,26 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "fastapi-repository"
7
+ version = "0.0.1"
8
+ authors = [
9
+ { name="Seiya Takeda", email="takedaseiya@gmail.com" },
10
+ ]
11
+ description = "A base repository for FastAPI projects, inspired by Ruby on Rails' Active Record and Ransack."
12
+ readme = "README.md"
13
+ requires-python = ">=3.8"
14
+ classifiers = [
15
+ "Programming Language :: Python :: 3",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ dependencies = [
20
+ "sqlalchemy>=1.4.0",
21
+ "fastapi>=0.70.0",
22
+ ]
23
+
24
+ [project.urls]
25
+ "Homepage" = "https://github.com/seiyat/fastapi-repository"
26
+ "Bug Tracker" = "https://github.com/seiyat/fastapi-repository/issues"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,336 @@
1
+ from sqlalchemy.exc import NoResultFound
2
+ import pytest
3
+ from uuid import uuid4
4
+ from faker import Faker
5
+
6
# Shared random-data generator for emails that should not exist in the DB.
fake: Faker = Faker()


@pytest.mark.asyncio
async def test_find_user_by_id(user_repository, user):
    """find() returns the record with the given primary key."""
    result = await user_repository.find(user.id)
    assert result is not None
    assert result.id == user.id


@pytest.mark.asyncio
async def test_find_user_not_found(user_repository):
    """find() raises NoResultFound for an id that does not exist."""
    missing_id = uuid4()
    with pytest.raises(NoResultFound):
        await user_repository.find(missing_id)
19
+
20
+
21
@pytest.mark.asyncio
async def test_count_users(user_repository, users, user):
    """count() covers the `users` batch plus the single `user` fixture."""
    expected = len(users) + 1
    assert await user_repository.count() == expected


@pytest.mark.asyncio
async def test_count_users_with_filter(user_repository, users):
    """count() honours equality filters."""
    target_email = users[0].email
    assert await user_repository.count(email=target_email) == 1


@pytest.mark.asyncio
async def test_count_users_with_non_existent_filter(user_repository):
    """count() is zero when nothing matches."""
    total = await user_repository.count(email="non_existent_email@example.com")
    assert total == 0


@pytest.mark.asyncio
async def test_exists_user(user_repository, user):
    """exists() is True for a stored email."""
    assert (await user_repository.exists(email=user.email)) is True


@pytest.mark.asyncio
async def test_exists_user_not_found(user_repository):
    """exists() is False for a random email."""
    random_email = fake.email()
    assert (await user_repository.exists(email=random_email)) is False
49
+
50
+
51
@pytest.mark.asyncio
async def test_find_by_or_raise_user(user_repository, user):
    """find_by_or_raise() returns the matching record."""
    match = await user_repository.find_by_or_raise(email=user.email)
    assert match is not None
    assert (match.id, match.email) == (user.id, user.email)


@pytest.mark.asyncio
async def test_find_by_or_raise_user_not_found(user_repository):
    """find_by_or_raise() raises NoResultFound when nothing matches."""
    unknown_email = fake.email()
    with pytest.raises(NoResultFound):
        await user_repository.find_by_or_raise(email=unknown_email)


@pytest.mark.asyncio
async def test_find_user_by_email(user_repository, user):
    """find_by() returns the matching record."""
    match = await user_repository.find_by(email=user.email)
    assert match is not None
    assert (match.id, match.email) == (user.id, user.email)


@pytest.mark.asyncio
async def test_find_user_by_email_not_found(user_repository):
    """find_by() returns None when nothing matches."""
    assert await user_repository.find_by(email=fake.email()) is None
79
+
80
+
81
@pytest.mark.asyncio
async def test_where_exact(user_repository, user):
    """`__exact` matches exactly one row."""
    rows = await user_repository.where(email__exact=user.email)
    assert [row.id for row in rows] == [user.id]


@pytest.mark.asyncio
async def test_where_iexact(user_repository, user):
    """`__iexact` ignores case."""
    rows = await user_repository.where(email__iexact=user.email.upper())
    assert [row.id for row in rows] == [user.id]


@pytest.mark.asyncio
async def test_where_contains(user_repository, user):
    """`__contains` finds a substring."""
    rows = await user_repository.where(email__contains=user.email[:3])
    assert user.id in [row.id for row in rows]


@pytest.mark.asyncio
async def test_where_icontains(user_repository, user):
    """`__icontains` finds a substring regardless of case."""
    needle = user.email[:3].upper()
    rows = await user_repository.where(email__icontains=needle)
    assert user.id in [row.id for row in rows]


@pytest.mark.asyncio
async def test_where_in(user_repository, user):
    """`__in` matches any value in the list."""
    candidates = [user.email, fake.email()]
    rows = await user_repository.where(email__in=candidates)
    assert [row.id for row in rows] == [user.id]


@pytest.mark.asyncio
async def test_where_startswith(user_repository, user):
    """`__startswith` matches a prefix of the local part."""
    prefix = user.email.partition("@")[0][:3]
    rows = await user_repository.where(email__startswith=prefix)
    assert user.id in [row.id for row in rows]


@pytest.mark.asyncio
async def test_where_istartswith(user_repository, user):
    """`__istartswith` matches a prefix regardless of case."""
    prefix = user.email.partition("@")[0][:3].upper()
    rows = await user_repository.where(email__istartswith=prefix)
    assert user.id in [row.id for row in rows]


@pytest.mark.asyncio
async def test_where_endswith(user_repository, user):
    """`__endswith` matches the domain suffix."""
    domain = user.email.rpartition("@")[2]
    rows = await user_repository.where(email__endswith=domain)
    assert user.id in [row.id for row in rows]


@pytest.mark.asyncio
async def test_where_iendswith(user_repository, user):
    """`__iendswith` matches the domain suffix regardless of case."""
    domain = user.email.rpartition("@")[2].upper()
    rows = await user_repository.where(email__iendswith=domain)
    assert user.id in [row.id for row in rows]
143
+
144
+
145
@pytest.mark.asyncio
async def test_where_repository(user_repository, user):
    """Multiple equality filters are AND-ed together."""
    rows = await user_repository.where(email=user.email, id=user.id)
    assert [row.id for row in rows] == [user.id]


@pytest.mark.asyncio
async def test_where_repository_no_results(user_repository):
    """where() returns an empty list when nothing matches."""
    assert await user_repository.where(email=fake.email()) == []


@pytest.mark.asyncio
async def test_where_repository_limit(user_repository, users):
    """`limit` caps the number of rows."""
    assert len(await user_repository.where(limit=2)) == 2


@pytest.mark.asyncio
async def test_where_repository_offset(user_repository, users, user):
    """`offset` skips leading rows."""
    skipped = 5
    total = len(users) + 1
    rows = await user_repository.where(offset=skipped)
    assert len(rows) == total - skipped


@pytest.mark.asyncio
async def test_where_repository_offset_and_limit(user_repository, users):
    """`limit` and `offset` combine."""
    rows = await user_repository.where(limit=3, offset=2)
    assert len(rows) == 3


@pytest.mark.asyncio
async def test_where_repository_sorted(user_repository, users, user):
    """Ascending sort by email."""
    rows = await user_repository.where(sorted_by="email", sorted_order="asc")
    assert len(rows) == len(users) + 1
    emails = [row.email for row in rows]
    assert emails == sorted(emails)


@pytest.mark.asyncio
async def test_where_repository_sorted_desc(user_repository, users, user):
    """Descending sort by email."""
    rows = await user_repository.where(sorted_by="email", sorted_order="desc")
    assert len(rows) == len(users) + 1
    emails = [row.email for row in rows]
    assert emails == sorted(emails, reverse=True)


@pytest.mark.asyncio
async def test_where_repository_attribute_error(user_repository):
    """Filtering on an unknown column raises AttributeError."""
    expected = "User has no attribute 'non_existent_column'"
    with pytest.raises(AttributeError, match=expected):
        await user_repository.where(non_existent_column="value")


@pytest.mark.asyncio
async def test_where_repository_sorted_attribute_error(user_repository):
    """Sorting on an unknown column raises AttributeError."""
    expected = "User has no attribute 'non_existent_column'"
    with pytest.raises(AttributeError, match=expected):
        await user_repository.where(sorted_by="non_existent_column")
200
+
201
@pytest.mark.asyncio
async def test_create_user(user_repository):
    """create() persists a row that can be found again by id."""
    email = "test_create@example.com"
    hashed_password = "hashed_password_example"
    created = await user_repository.create(
        email=email, hashed_password=hashed_password
    )

    assert created.id is not None
    assert created.email == email
    assert created.hashed_password == hashed_password

    reloaded = await user_repository.find(created.id)
    assert reloaded is not None
    assert reloaded.id == created.id


@pytest.mark.asyncio
async def test_create_user_with_invalid_field(user_repository):
    """create() rejects keyword arguments that are not model fields."""
    with pytest.raises(TypeError):
        await user_repository.create(
            email="test_invalid@example.com",
            hashed_password="hashed_password_example",
            non_existent_field="some_value",
        )
227
+
228
+
229
@pytest.mark.asyncio
async def test_destroy_user(user_repository, user):
    """destroy() removes the row so find() no longer sees it."""
    assert await user_repository.find(user.id) is not None

    await user_repository.destroy(user.id)

    with pytest.raises(NoResultFound):
        await user_repository.find(user.id)


@pytest.mark.asyncio
async def test_destroy_user_not_found(user_repository):
    """destroy() raises NoResultFound for an unknown id."""
    with pytest.raises(NoResultFound):
        await user_repository.destroy(uuid4())


@pytest.mark.asyncio
async def test_destroy_all_no_conditions(user_repository, users, user):
    """destroy_all() without filters empties the table."""
    expected_total = len(users) + 1
    assert await user_repository.count() == expected_total

    assert await user_repository.destroy_all() == expected_total
    assert await user_repository.count() == 0


@pytest.mark.asyncio
async def test_destroy_all_with_conditions(user_repository, users, user):
    """destroy_all() with a filter deletes exactly the matching rows."""
    active_before = await user_repository.count(is_active=True)
    assert active_before > 0

    removed = await user_repository.destroy_all(is_active=True)
    assert removed == active_before
    assert await user_repository.count(is_active=True) == 0


@pytest.mark.asyncio
async def test_destroy_all_no_match(user_repository):
    """destroy_all() reports zero when nothing matches."""
    removed = await user_repository.destroy_all(
        email="this_should_not_match_any@example.com"
    )
    assert removed == 0
277
+
278
+
279
@pytest.mark.asyncio
async def test_update_user(user_repository, user):
    """update() changes and persists the given fields."""
    new_email = "updated_email@example.com"
    updated = await user_repository.update(user.id, email=new_email)
    assert (updated.id, updated.email) == (user.id, new_email)

    reloaded = await user_repository.find(user.id)
    assert reloaded.email == new_email


@pytest.mark.asyncio
async def test_update_user_not_found(user_repository):
    """update() raises NoResultFound for an unknown id."""
    with pytest.raises(NoResultFound):
        await user_repository.update(uuid4(), email="doesnotexist@example.com")


@pytest.mark.asyncio
async def test_update_all_no_conditions(user_repository, users, user):
    """update_all() without filters touches every row."""
    assert await user_repository.update_all({"failed_attempts": 2}) == len(users) + 1

    # Every row should now carry the new value.
    remaining = await user_repository.where()
    assert all(row.failed_attempts == 2 for row in remaining)
306
+
307
+
308
@pytest.mark.asyncio
async def test_update_all_with_conditions(user_repository, users, user):
    """update_all() with a filter updates only the matching rows."""
    # Create the inactive user BEFORE the bulk update so the assertions below
    # actually prove that rows outside the condition are left untouched.
    inactive_user = await user_repository.create(
        email="inactive@example.com",
        hashed_password="hashed_password_example",
        is_active=False,
    )
    # Capture the id now: the commit inside update_all() may expire the
    # instance, and lazy-loading an expired attribute is not possible in async.
    inactive_id = inactive_user.id
    active_count = await user_repository.count(is_active=True)

    updated_count = await user_repository.update_all(
        {"failed_attempts": 3}, is_active=True
    )
    assert updated_count > 0
    # Exactly the active rows were targeted.
    assert updated_count == active_count

    # Every row now carrying the new value must be active ...
    updated_users = await user_repository.where(failed_attempts=3)
    for u in updated_users:
        assert u.is_active is True

    # ... and the pre-existing inactive row must not have been updated.
    inactive_users = await user_repository.where(is_active=False)
    assert inactive_id in [iu.id for iu in inactive_users]
    for iu in inactive_users:
        assert iu.failed_attempts != 3
329
+
330
+
331
@pytest.mark.asyncio
async def test_update_all_no_match(user_repository):
    """update_all() reports zero when the filter matches nothing."""
    # Expects no fixture row to have failed_attempts == 5.
    changed = await user_repository.update_all({"failed_attempts": 4}, failed_attempts=5)
    assert changed == 0