kodit 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/application/factories/server_factory.py +54 -32
- kodit/application/services/code_search_application_service.py +89 -12
- kodit/application/services/commit_indexing_application_service.py +314 -195
- kodit/application/services/enrichment_query_service.py +274 -43
- kodit/application/services/indexing_worker_service.py +1 -1
- kodit/application/services/queue_service.py +15 -10
- kodit/application/services/sync_scheduler.py +2 -1
- kodit/domain/enrichments/architecture/architecture.py +1 -1
- kodit/domain/enrichments/architecture/physical/physical.py +1 -1
- kodit/domain/enrichments/development/development.py +1 -1
- kodit/domain/enrichments/development/snippet/snippet.py +12 -5
- kodit/domain/enrichments/enrichment.py +31 -4
- kodit/domain/enrichments/usage/api_docs.py +1 -1
- kodit/domain/enrichments/usage/usage.py +1 -1
- kodit/domain/entities/git.py +30 -25
- kodit/domain/factories/git_repo_factory.py +20 -5
- kodit/domain/protocols.py +56 -125
- kodit/domain/services/embedding_service.py +14 -16
- kodit/domain/services/git_repository_service.py +60 -38
- kodit/domain/services/git_service.py +18 -11
- kodit/domain/tracking/resolution_service.py +6 -16
- kodit/domain/value_objects.py +2 -9
- kodit/infrastructure/api/v1/dependencies.py +12 -3
- kodit/infrastructure/api/v1/query_params.py +27 -0
- kodit/infrastructure/api/v1/routers/commits.py +91 -85
- kodit/infrastructure/api/v1/routers/repositories.py +53 -37
- kodit/infrastructure/api/v1/routers/search.py +1 -1
- kodit/infrastructure/api/v1/schemas/enrichment.py +14 -0
- kodit/infrastructure/api/v1/schemas/repository.py +1 -1
- kodit/infrastructure/providers/litellm_provider.py +23 -1
- kodit/infrastructure/slicing/api_doc_extractor.py +0 -2
- kodit/infrastructure/sqlalchemy/embedding_repository.py +44 -34
- kodit/infrastructure/sqlalchemy/enrichment_association_repository.py +73 -0
- kodit/infrastructure/sqlalchemy/enrichment_v2_repository.py +116 -97
- kodit/infrastructure/sqlalchemy/entities.py +12 -116
- kodit/infrastructure/sqlalchemy/git_branch_repository.py +52 -244
- kodit/infrastructure/sqlalchemy/git_commit_repository.py +35 -324
- kodit/infrastructure/sqlalchemy/git_file_repository.py +70 -0
- kodit/infrastructure/sqlalchemy/git_repository.py +60 -230
- kodit/infrastructure/sqlalchemy/git_tag_repository.py +53 -240
- kodit/infrastructure/sqlalchemy/query.py +331 -0
- kodit/infrastructure/sqlalchemy/repository.py +203 -0
- kodit/infrastructure/sqlalchemy/task_repository.py +79 -58
- kodit/infrastructure/sqlalchemy/task_status_repository.py +45 -52
- kodit/migrations/versions/4b1a3b2c8fa5_refactor_git_tracking.py +190 -0
- {kodit-0.5.3.dist-info → kodit-0.5.5.dist-info}/METADATA +1 -1
- {kodit-0.5.3.dist-info → kodit-0.5.5.dist-info}/RECORD +51 -49
- kodit/infrastructure/mappers/enrichment_mapper.py +0 -83
- kodit/infrastructure/mappers/git_mapper.py +0 -193
- kodit/infrastructure/mappers/snippet_mapper.py +0 -104
- kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +0 -479
- {kodit-0.5.3.dist-info → kodit-0.5.5.dist-info}/WHEEL +0 -0
- {kodit-0.5.3.dist-info → kodit-0.5.5.dist-info}/entry_points.txt +0 -0
- {kodit-0.5.3.dist-info → kodit-0.5.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Base query for SQLAlchemy repositories."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any, Self
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import Select
|
|
9
|
+
|
|
10
|
+
from kodit.domain.enrichments.enrichment import EnrichmentV2
|
|
11
|
+
from kodit.infrastructure.api.v1.query_params import PaginationParams
|
|
12
|
+
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Query(ABC):
    """Specification object that narrows a SQLAlchemy ``Select`` statement.

    Concrete subclasses encapsulate reusable query logic. Repositories pass in
    a base ``select(Model)`` statement plus the model class and execute the
    statement that comes back.
    """

    @abstractmethod
    def apply(self, stmt: Select, model_type: type) -> Select:
        """Return ``stmt`` with this query's criteria applied for ``model_type``.

        Args:
            stmt: The statement to refine.
            model_type: Mapped class whose attributes the criteria refer to.
        """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FilterOperator(Enum):
    """Comparison operators a ``FilterCriteria`` can apply to a column.

    ``FilterCriteria`` dispatches on the enum member, not on its string value.
    The values are short operator names such as ``ge`` and ``in_``.
    """

    # Equality / inequality
    EQ = "eq"
    NE = "ne"

    # Ordering comparisons
    GT = "gt"
    GTE = "ge"
    LT = "lt"
    LTE = "le"

    # Membership in a collection of values
    IN = "in_"

    # SQL pattern matching; ILIKE is the case-insensitive variant
    LIKE = "like"
    ILIKE = "ilike"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class FilterCriteria:
|
|
39
|
+
"""Filter criteria for a query."""
|
|
40
|
+
|
|
41
|
+
field: str
|
|
42
|
+
operator: FilterOperator
|
|
43
|
+
value: Any
|
|
44
|
+
|
|
45
|
+
def apply(self, model_type: type, stmt: Select) -> Select: # noqa: C901
|
|
46
|
+
"""Apply filter to statement."""
|
|
47
|
+
column = getattr(model_type, self.field)
|
|
48
|
+
|
|
49
|
+
# Convert AnyUrl to string for SQLAlchemy comparison
|
|
50
|
+
value = self.value
|
|
51
|
+
if hasattr(value, "__str__") and type(value).__module__ == "pydantic.networks":
|
|
52
|
+
value = str(value)
|
|
53
|
+
|
|
54
|
+
# Use column comparison methods instead of operators module
|
|
55
|
+
condition = None
|
|
56
|
+
match self.operator:
|
|
57
|
+
case FilterOperator.EQ:
|
|
58
|
+
condition = column == value
|
|
59
|
+
case FilterOperator.NE:
|
|
60
|
+
condition = column != value
|
|
61
|
+
case FilterOperator.GT:
|
|
62
|
+
condition = column > value
|
|
63
|
+
case FilterOperator.GTE:
|
|
64
|
+
condition = column >= value
|
|
65
|
+
case FilterOperator.LT:
|
|
66
|
+
condition = column < value
|
|
67
|
+
case FilterOperator.LTE:
|
|
68
|
+
condition = column <= value
|
|
69
|
+
case FilterOperator.IN:
|
|
70
|
+
condition = column.in_(value)
|
|
71
|
+
case FilterOperator.LIKE:
|
|
72
|
+
condition = column.like(value)
|
|
73
|
+
case FilterOperator.ILIKE:
|
|
74
|
+
condition = column.ilike(value)
|
|
75
|
+
|
|
76
|
+
return stmt.where(condition)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class SortCriteria:
    """One ``ORDER BY`` term: a mapped attribute name and its direction."""

    field: str
    descending: bool = False

    def apply(self, model_type: type, stmt: Select) -> Select:
        """Append this ordering term to ``stmt`` and return the new statement.

        Args:
            model_type: Mapped class that has an attribute named ``self.field``.
            stmt: Statement to extend.
        """
        attribute = getattr(model_type, self.field)
        if self.descending:
            ordering = attribute.desc()
        else:
            ordering = attribute.asc()
        return stmt.order_by(ordering)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class PaginationCriteria:
    """``OFFSET``/``LIMIT`` window for a query.

    ``limit`` of ``None`` means no ``LIMIT`` clause is added. ``offset`` is
    always applied, even when it is ``0``.
    """

    limit: int | None = None
    offset: int = 0

    def apply(self, stmt: Select) -> Select:
        """Return ``stmt`` restricted to this offset and optional limit."""
        windowed = stmt.offset(self.offset)
        return windowed if self.limit is None else windowed.limit(self.limit)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class QueryBuilder(Query):
    """Composable, fluent builder for repository queries.

    ``filter``, ``sort`` and ``paginate`` record criteria and return ``self``
    so calls can be chained. No SQL is produced until ``apply`` (or
    ``apply_filters_only``) is called with a base statement and model class.
    Applying a builder never changes its recorded state.
    """

    # Ordering used by ``apply`` when no explicit sort was requested.
    DEFAULT_SORT_FIELD = "created_at"
    DEFAULT_SORT_DESCENDING = True

    def __init__(self) -> None:
        """Initialize an empty builder: no filters, sorts or pagination."""
        self._filters: list[FilterCriteria] = []
        self._sorts: list[SortCriteria] = []
        self._pagination: PaginationCriteria | None = None

    def filter(self, field: str, operator: FilterOperator, value: Any) -> Self:
        """Add a ``field <operator> value`` predicate.

        Each predicate becomes its own ``where`` call, so they are ANDed.
        """
        self._filters.append(FilterCriteria(field, operator, value))
        return self

    def sort(self, field: str, *, descending: bool = False) -> Self:
        """Add an ``ORDER BY`` term; terms are applied in the order added."""
        self._sorts.append(SortCriteria(field, descending))
        return self

    def paginate(self, pagination: PaginationParams) -> Self:
        """Set the offset/limit window, replacing any previously set window."""
        self._pagination = PaginationCriteria(
            limit=pagination.limit, offset=pagination.offset
        )
        return self

    def apply_filters_only(self, stmt: Select, model_type: type) -> Select:
        """Apply only the filter predicates, with no sorting or pagination."""
        for filter_criteria in self._filters:
            stmt = filter_criteria.apply(model_type, stmt)
        return stmt

    def _effective_sorts(self, model_type: type) -> list[SortCriteria]:
        """Return the requested sorts, or the default sort if none were requested.

        The default is built per call and never stored. Otherwise a later
        ``sort()`` call would be appended after a stale default. The default
        is skipped for models without a ``DEFAULT_SORT_FIELD`` attribute.
        """
        if self._sorts:
            return self._sorts
        if not hasattr(model_type, self.DEFAULT_SORT_FIELD):
            return []
        return [
            SortCriteria(
                field=self.DEFAULT_SORT_FIELD,
                descending=self.DEFAULT_SORT_DESCENDING,
            )
        ]

    def apply(self, stmt: Select, model_type: type) -> Select:
        """Apply filters, then sorts (or the default sort), then pagination."""
        stmt = self.apply_filters_only(stmt, model_type)

        for sort_criteria in self._effective_sorts(model_type):
            stmt = sort_criteria.apply(model_type, stmt)

        if self._pagination is not None:
            stmt = self._pagination.apply(stmt)

        return stmt
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class EnrichmentAssociationQueryBuilder(QueryBuilder):
    """Query builder for ``EnrichmentAssociation`` rows.

    An association row carries an ``enrichment_id`` and a target described by
    ``entity_type`` and ``entity_id``. In this module ``entity_type`` is
    compared against table names such as the enrichment and commit tables.
    The static factories return builders of this class, so the instance
    helpers below stay available for further chaining.
    """

    def for_enrichment_ids(self, enrichment_ids: list[int]) -> Self:
        """Restrict to associations whose ``enrichment_id`` is in ``enrichment_ids``."""
        self.filter(
            db_entities.EnrichmentAssociation.enrichment_id.key,
            FilterOperator.IN,
            enrichment_ids,
        )
        return self

    def for_enrichments(self, enrichments: list[EnrichmentV2]) -> Self:
        """Restrict to associations whose ``enrichment_id`` is one of ``enrichments``.

        Enrichments whose ``id`` is ``None`` (not yet persisted) are ignored.
        """
        return self.for_enrichment_ids(
            [enrichment.id for enrichment in enrichments if enrichment.id is not None]
        )

    def for_enrichment_type(self) -> Self:
        """Restrict to associations whose ``entity_type`` is the enrichment table."""
        return self.for_entity_type(db_entities.EnrichmentV2.__tablename__)

    def for_entity_type(self, entity_type: str) -> Self:
        """Restrict to associations whose ``entity_type`` equals ``entity_type``."""
        self.filter(
            db_entities.EnrichmentAssociation.entity_type.key,
            FilterOperator.EQ,
            entity_type,
        )
        return self

    def for_entity_ids(self, entity_ids: list[str]) -> Self:
        """Restrict to associations whose ``entity_id`` is in ``entity_ids``."""
        self.filter(
            db_entities.EnrichmentAssociation.entity_id.key,
            FilterOperator.IN,
            entity_ids,
        )
        return self

    @staticmethod
    def for_enrichment_association(
        entity_type: str,
        entity_id: str,
    ) -> "EnrichmentAssociationQueryBuilder":
        """Build a query for the associations of a single target entity."""
        return EnrichmentAssociationQueryBuilder.for_enrichment_associations(
            entity_type,
            [entity_id],
        )

    @staticmethod
    def for_enrichment_associations(
        entity_type: str, entity_ids: list[str]
    ) -> "EnrichmentAssociationQueryBuilder":
        """Build a query for associations by ``entity_type`` and ``entity_id`` list."""
        return (
            EnrichmentAssociationQueryBuilder()
            .for_entity_type(entity_type)
            .for_entity_ids(entity_ids)
        )

    @staticmethod
    def type_and_ids(
        entity_type: str,
        enrichment_ids: list[int],
    ) -> "EnrichmentAssociationQueryBuilder":
        """Build a query by ``entity_type`` and a list of ``enrichment_id`` values."""
        return (
            EnrichmentAssociationQueryBuilder()
            .for_entity_type(entity_type)
            .for_enrichment_ids(enrichment_ids)
        )

    @staticmethod
    def associations_pointing_to_these_enrichments(
        enrichment_ids: list[int],
    ) -> "EnrichmentAssociationQueryBuilder":
        """Build a query for associations of these enrichments that target enrichments.

        Matches rows with ``entity_type`` equal to the enrichment table and
        ``enrichment_id`` in ``enrichment_ids``.

        NOTE(review): this filters on ``enrichment_id``, not ``entity_id``.
        Associations whose *target* (``entity_id``) is one of these
        enrichments are therefore not matched. Confirm this is intended.
        """
        return EnrichmentAssociationQueryBuilder.type_and_ids(
            entity_type=db_entities.EnrichmentV2.__tablename__,
            enrichment_ids=enrichment_ids,
        )

    def for_commit(self, commit_sha: str) -> Self:
        """Restrict to associations attached to the commit ``commit_sha``."""
        self.for_entity_type(db_entities.GitCommit.__tablename__)
        self.filter(
            db_entities.EnrichmentAssociation.entity_id.key,
            FilterOperator.EQ,
            commit_sha,
        )
        return self
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class EnrichmentQueryBuilder(QueryBuilder):
    """Fluent filters over ``EnrichmentV2`` rows: by id, type and subtype."""

    def for_ids(self, enrichment_ids: list[int]) -> Self:
        """Keep only enrichments whose ``id`` is in ``enrichment_ids``."""
        id_field = db_entities.EnrichmentV2.id.key
        return self.filter(id_field, FilterOperator.IN, enrichment_ids)

    def for_type(self, enrichment_type: str) -> Self:
        """Keep only enrichments whose ``type`` equals ``enrichment_type``."""
        type_field = db_entities.EnrichmentV2.type.key
        return self.filter(type_field, FilterOperator.EQ, enrichment_type)

    def for_subtype(self, enrichment_subtype: str) -> Self:
        """Keep only enrichments whose ``subtype`` equals ``enrichment_subtype``."""
        subtype_field = db_entities.EnrichmentV2.subtype.key
        return self.filter(subtype_field, FilterOperator.EQ, enrichment_subtype)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class GitFileQueryBuilder(QueryBuilder):
    """Fluent filters over ``GitCommitFile`` rows: by commit or blob SHA."""

    def for_commit_sha(self, commit_sha: str) -> Self:
        """Keep only files whose ``commit_sha`` equals ``commit_sha``."""
        sha_field = db_entities.GitCommitFile.commit_sha.key
        return self.filter(sha_field, FilterOperator.EQ, commit_sha)

    def for_blob_sha(self, blob_sha: str) -> Self:
        """Keep only files whose ``blob_sha`` equals ``blob_sha``."""
        sha_field = db_entities.GitCommitFile.blob_sha.key
        return self.filter(sha_field, FilterOperator.EQ, blob_sha)
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""Abstract base classes for repositories."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import Callable, Generator
|
|
5
|
+
from typing import Any, Generic, TypeVar
|
|
6
|
+
|
|
7
|
+
from sqlalchemy import func, inspect, select
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
|
|
10
|
+
from kodit.infrastructure.sqlalchemy.query import Query
|
|
11
|
+
from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
|
|
12
|
+
|
|
13
|
+
# Type of the domain objects a repository accepts and returns.
DomainEntityType = TypeVar("DomainEntityType")
# Type of the SQLAlchemy model objects a repository reads and writes.
DatabaseEntityType = TypeVar("DatabaseEntityType")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SqlAlchemyRepository(ABC, Generic[DomainEntityType, DatabaseEntityType]):
    """Base repository implementing common async SQLAlchemy persistence patterns.

    Subclasses must provide the following:

    - the mapped model class (``db_entity_type``);
    - primary-key extraction (``_get_id``);
    - the domain <-> database mappers (``to_domain`` / ``to_db``).

    Every public method opens its own ``SqlAlchemyUnitOfWork`` built from
    ``session_factory``.
    """

    def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the repository.

        Args:
            session_factory: Callable returning a new ``AsyncSession``. It is
                passed to ``SqlAlchemyUnitOfWork`` by every operation.

        """
        self.session_factory = session_factory
        # Maximum number of entities handled per flush in bulk operations.
        self._chunk_size = 1000

    @abstractmethod
    def _get_id(self, entity: DomainEntityType) -> Any:
        """Return the primary-key value of ``entity``.

        ``None`` means the entity has not been persisted yet. ``save`` and
        ``save_bulk`` then insert it without looking it up first.
        """

    @property
    @abstractmethod
    def db_entity_type(self) -> type[DatabaseEntityType]:
        """The SQLAlchemy model class this repository persists."""

    @staticmethod
    @abstractmethod
    def to_domain(db_entity: DatabaseEntityType) -> DomainEntityType:
        """Map a database entity to its domain entity."""

    @staticmethod
    @abstractmethod
    def to_db(domain_entity: DomainEntityType) -> DatabaseEntityType:
        """Map a domain entity to a database entity."""

    def _update_db_entity(
        self, existing: DatabaseEntityType, new: DatabaseEntityType
    ) -> None:
        """Copy mapped column values from ``new`` onto ``existing``.

        Primary-key columns and the auto-managed ``created_at``,
        ``updated_at`` and ``id`` columns are left untouched.
        """
        mapper = inspect(type(existing))
        if mapper is None:
            return
        # Skip auto-managed columns.
        skip_columns = {"created_at", "updated_at", "id"}
        # ``Mapper.columns`` is keyed by the *attribute* name, which may differ
        # from ``Column.key`` when a column is mapped under another name.
        # Always read and write through the attribute name.
        for attr_name, column in mapper.columns.items():
            if column.primary_key:
                continue
            if attr_name in skip_columns or column.key in skip_columns:
                continue
            setattr(existing, attr_name, getattr(new, attr_name))

    async def get(self, entity_id: Any) -> DomainEntityType:
        """Return the entity whose primary key is ``entity_id``.

        Raises:
            ValueError: If no such entity exists.

        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            db_entity = await session.get(self.db_entity_type, entity_id)
            if db_entity is None:
                raise ValueError(f"Entity with id {entity_id} not found")
            return self.to_domain(db_entity)

    async def find(self, query: Query) -> list[DomainEntityType]:
        """Return all entities selected by ``query``."""
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            stmt = select(self.db_entity_type)
            stmt = query.apply(stmt, self.db_entity_type)
            db_entities = (await session.scalars(stmt)).all()
            return [self.to_domain(db) for db in db_entities]

    async def count(self, query: Query) -> int:
        """Count the entities matching ``query``.

        For a ``QueryBuilder`` only the filters are applied; sorting and
        pagination are skipped. Any other ``Query`` is applied in full to the
        count statement.
        """
        # Deferred import: only ``Query`` is imported at module level here.
        from kodit.infrastructure.sqlalchemy.query import QueryBuilder

        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            # ``select_from`` pins the FROM clause explicitly. The previous form,
            # ``select(Model).with_only_columns(func.count())``, drops it, so an
            # unfiltered count became ``SELECT count(*)`` with no table (returns 1).
            stmt = select(func.count()).select_from(self.db_entity_type)
            if isinstance(query, QueryBuilder):
                stmt = query.apply_filters_only(stmt, self.db_entity_type)
            else:
                stmt = query.apply(stmt, self.db_entity_type)
            result = await session.scalar(stmt)
            return result or 0

    async def save(self, entity: DomainEntityType) -> DomainEntityType:
        """Insert ``entity``, or update the existing row with the same primary key.

        Returns the entity mapped back from the flushed database entity.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            entity_id = self._get_id(entity)
            # Skip the lookup when the entity has no id yet (never persisted).
            existing_db_entity = (
                await session.get(self.db_entity_type, entity_id)
                if entity_id is not None
                else None
            )

            if existing_db_entity is not None:
                self._update_db_entity(existing_db_entity, self.to_db(entity))
                db_entity = existing_db_entity
            else:
                db_entity = self.to_db(entity)
                session.add(db_entity)

            await session.flush()
            return self.to_domain(db_entity)

    async def save_bulk(
        self, entities: list[DomainEntityType]
    ) -> list[DomainEntityType]:
        """Insert or update many entities, flushing once per chunk.

        Returns the saved entities in input order, mapped back from the
        flushed database entities.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            all_saved_db_entities: list[DatabaseEntityType] = []

            for chunk in self._chunked_domain(entities):
                # Compute each primary key once; ``None`` marks a new entity.
                chunk_ids = [self._get_id(entity) for entity in chunk]

                # Look up rows that already exist. This issues one
                # ``session.get`` per id, not a single query.
                existing_entities: dict[Any, DatabaseEntityType] = {}
                for entity_id in chunk_ids:
                    if entity_id is None:
                        continue
                    existing = await session.get(self.db_entity_type, entity_id)
                    if existing is not None:
                        existing_entities[entity_id] = existing

                new_entities: list[DatabaseEntityType] = []
                chunk_db_entities: list[DatabaseEntityType] = []
                for entity, entity_id in zip(chunk, chunk_ids, strict=True):
                    new_db_entity = self.to_db(entity)
                    if entity_id in existing_entities:
                        existing = existing_entities[entity_id]
                        self._update_db_entity(existing, new_db_entity)
                        chunk_db_entities.append(existing)
                    else:
                        new_entities.append(new_db_entity)
                        chunk_db_entities.append(new_db_entity)

                if new_entities:
                    session.add_all(new_entities)

                await session.flush()
                all_saved_db_entities.extend(chunk_db_entities)

            return [self.to_domain(db) for db in all_saved_db_entities]

    async def exists(self, entity_id: Any) -> bool:
        """Return whether an entity with primary key ``entity_id`` exists."""
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            db_entity = await session.get(self.db_entity_type, entity_id)
            return db_entity is not None

    async def delete(self, entity: DomainEntityType) -> None:
        """Delete ``entity``'s row if present; a missing row is ignored."""
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            db_entity = await session.get(self.db_entity_type, self._get_id(entity))
            if db_entity is not None:
                await session.delete(db_entity)

    async def delete_by_query(self, query: Query) -> None:
        """Delete every entity selected by ``query``, flushing once per chunk.

        Rows are loaded and deleted through the ORM (``session.delete``), not
        with a bulk ``DELETE`` statement.
        """
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            stmt = select(self.db_entity_type)
            stmt = query.apply(stmt, self.db_entity_type)
            db_entities = list((await session.scalars(stmt)).all())
            if not db_entities:
                return
            for chunk in self._chunked_db(db_entities):
                for db_entity in chunk:
                    await session.delete(db_entity)
                await session.flush()

    def _chunked_domain(
        self,
        items: list[DomainEntityType],
        chunk_size: int | None = None,
    ) -> Generator[list[DomainEntityType], None, None]:
        """Yield consecutive slices of ``items`` (default size ``self._chunk_size``)."""
        chunk_size = chunk_size or self._chunk_size
        for i in range(0, len(items), chunk_size):
            yield items[i : i + chunk_size]

    def _chunked_db(
        self,
        items: list[DatabaseEntityType],
        chunk_size: int | None = None,
    ) -> Generator[list[DatabaseEntityType], None, None]:
        """Yield consecutive slices of ``items`` (default size ``self._chunk_size``)."""
        chunk_size = chunk_size or self._chunk_size
        for i in range(0, len(items), chunk_size):
            yield items[i : i + chunk_size]
|